Skip to main content

gproxy_protocol/transform/
utils.rs

1use std::borrow::Cow;
2use std::collections::{BTreeMap, HashSet};
3use std::error::Error;
4use std::fmt::{Display, Formatter};
5
6use serde_json::Value;
7
/// Error returned by the protocol transform layer — both by the `TryFrom`
/// conversion impls and by the runtime transform dispatcher.
///
/// The message is stored as a `Cow` so that compile-time literals are kept
/// without allocating, while runtime-formatted messages own their `String`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransformError {
    /// Human-readable description of what could not be transformed.
    pub message: Cow<'static, str>,
}

impl TransformError {
    /// Builds an error from a `'static` string without allocating.
    ///
    /// Retained for backwards compatibility with `TryFrom` impls that report
    /// "not yet supported" cases using compile-time string literals. Being a
    /// `const fn`, it can also initialise `const` items.
    pub const fn not_implemented(message: &'static str) -> Self {
        let message = Cow::Borrowed(message);
        Self { message }
    }

    /// Builds an error from a message assembled at runtime.
    ///
    /// The runtime transform dispatcher in `crate::transform::dispatch` uses
    /// this for errors such as "no stream aggregation for protocol: {protocol}".
    /// The message is always converted into an owned `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let owned: String = message.into();
        Self {
            message: Cow::Owned(owned),
        }
    }
}

impl Display for TransformError {
    /// Writes the raw message. Width/fill flags are deliberately ignored
    /// (`write_str` rather than `pad`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message.as_ref())
    }
}

impl Error for TransformError {}

/// Shorthand for results whose error type is [`TransformError`].
pub type TransformResult<T> = Result<T, TransformError>;
44
// `push_message_block` and `ORPHAN_TOOL_USE_PLACEHOLDER_NAME` live next to the
// other Claude-side helpers in `transform::claude::utils`. They are re-exported
// here so callers can reach them through the generic `transform::utils` path
// instead of depending on the `claude` submodule directly.
49pub use crate::transform::claude::utils::{ORPHAN_TOOL_USE_PLACEHOLDER_NAME, push_message_block};
50
51/// Patch a JSON Schema in place so it satisfies Anthropic's strict-mode
52/// requirements: every `object` node gets `additionalProperties: false`, and
53/// every key in `properties` is added to `required`.
54///
55/// Anthropic rejects object schemas that omit `additionalProperties`, set it
56/// to `true`, or fail to list every property in `required`. OpenAI/Gemini
57/// schemas typically don't satisfy these constraints.
58pub fn enforce_anthropic_strict_schema(schema: &mut BTreeMap<String, Value>) {
59    let mut tmp: serde_json::Map<String, Value> = std::mem::take(schema).into_iter().collect();
60    enforce_anthropic_strict_value_map(&mut tmp);
61    *schema = tmp.into_iter().collect();
62}
63
64fn enforce_anthropic_strict_value(value: &mut Value) {
65    match value {
66        Value::Object(map) => enforce_anthropic_strict_value_map(map),
67        Value::Array(arr) => {
68            for v in arr.iter_mut() {
69                enforce_anthropic_strict_value(v);
70            }
71        }
72        _ => {}
73    }
74}
75
76fn enforce_anthropic_strict_value_map(map: &mut serde_json::Map<String, Value>) {
77    if let Some(Value::Object(props)) = map.get_mut("properties") {
78        for (_, v) in props.iter_mut() {
79            enforce_anthropic_strict_value(v);
80        }
81    }
82    if let Some(items) = map.get_mut("items") {
83        enforce_anthropic_strict_value(items);
84    }
85    for key in ["$defs", "definitions"] {
86        if let Some(Value::Object(defs)) = map.get_mut(key) {
87            for (_, v) in defs.iter_mut() {
88                enforce_anthropic_strict_value(v);
89            }
90        }
91    }
92    for key in ["allOf", "anyOf", "oneOf"] {
93        if let Some(Value::Array(arr)) = map.get_mut(key) {
94            for v in arr.iter_mut() {
95                enforce_anthropic_strict_value(v);
96            }
97        }
98    }
99
100    let is_object_schema = map.get("type").and_then(|v| v.as_str()) == Some("object")
101        || map.contains_key("properties");
102    if !is_object_schema {
103        return;
104    }
105
106    map.insert("additionalProperties".to_string(), Value::Bool(false));
107
108    let prop_keys: Vec<String> = map
109        .get("properties")
110        .and_then(|v| v.as_object())
111        .map(|props| props.keys().cloned().collect())
112        .unwrap_or_default();
113    if prop_keys.is_empty() {
114        return;
115    }
116
117    let required = map
118        .entry("required".to_string())
119        .or_insert_with(|| Value::Array(Vec::new()));
120    if let Value::Array(arr) = required {
121        let existing: HashSet<String> = arr
122            .iter()
123            .filter_map(|v| v.as_str().map(str::to_string))
124            .collect();
125        for key in prop_keys {
126            if !existing.contains(&key) {
127                arr.push(Value::String(key));
128            }
129        }
130    }
131}
132
#[cfg(test)]
mod enforce_anthropic_strict_schema_tests {
    use super::*;
    use serde_json::json;

    /// Round-trips a JSON object literal through `enforce_anthropic_strict_schema`.
    fn run(input: serde_json::Value) -> serde_json::Value {
        let mut schema = input
            .as_object()
            .unwrap()
            .clone()
            .into_iter()
            .collect::<BTreeMap<String, Value>>();
        enforce_anthropic_strict_schema(&mut schema);
        Value::Object(schema.into_iter().collect())
    }

    /// Collects a `required` array into a set so assertions ignore ordering.
    fn required_set(required: &Value) -> HashSet<String> {
        required
            .as_array()
            .unwrap()
            .iter()
            .map(|name| name.as_str().unwrap().to_owned())
            .collect()
    }

    /// Builds the expected set of property names.
    fn names(list: &[&str]) -> HashSet<String> {
        list.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn top_level_object_gets_additional_properties_and_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            }
        });
        let out = run(schema);
        assert_eq!(out["additionalProperties"], json!(false));
        assert_eq!(required_set(&out["required"]), names(&["name", "age"]));
    }

    #[test]
    fn nested_objects_in_properties_and_array_items_are_patched() {
        let schema = json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}}
                },
                "tags": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {"id": {"type": "string"}}
                    }
                }
            }
        });
        let out = run(schema);

        let user = &out["properties"]["user"];
        assert_eq!(user["additionalProperties"], json!(false));
        assert_eq!(user["required"], json!(["name"]));

        let item = &out["properties"]["tags"]["items"];
        assert_eq!(item["additionalProperties"], json!(false));
        assert_eq!(item["required"], json!(["id"]));
    }

    #[test]
    fn defs_and_anyof_branches_are_patched() {
        let schema = json!({
            "type": "object",
            "properties": {"x": {"$ref": "#/$defs/X"}},
            "$defs": {
                "X": {"type": "object", "properties": {"a": {"type": "string"}}}
            },
            "anyOf": [
                {"type": "object", "properties": {"b": {"type": "integer"}}}
            ]
        });
        let out = run(schema);

        let def_x = &out["$defs"]["X"];
        assert_eq!(def_x["additionalProperties"], json!(false));
        assert_eq!(def_x["required"], json!(["a"]));

        let branch = &out["anyOf"][0];
        assert_eq!(branch["additionalProperties"], json!(false));
        assert_eq!(branch["required"], json!(["b"]));
    }

    #[test]
    fn existing_additional_properties_true_is_overwritten() {
        let schema = json!({
            "type": "object",
            "additionalProperties": true,
            "properties": {"k": {"type": "string"}}
        });
        let out = run(schema);
        assert_eq!(out["additionalProperties"], json!(false));
    }

    #[test]
    fn existing_required_is_extended_not_replaced() {
        let schema = json!({
            "type": "object",
            "required": ["a"],
            "properties": {"a": {"type": "string"}, "b": {"type": "string"}}
        });
        let out = run(schema);
        assert_eq!(required_set(&out["required"]), names(&["a", "b"]));
    }

    #[test]
    fn non_object_schemas_are_left_alone() {
        let out = run(json!({"type": "string", "format": "uuid"}));
        assert!(out.get("additionalProperties").is_none());
        assert!(out.get("required").is_none());
    }
}