gproxy_protocol/transform/
utils.rs1use std::borrow::Cow;
2use std::collections::{BTreeMap, HashSet};
3use std::error::Error;
4use std::fmt::{Display, Formatter};
5
6use serde_json::Value;
7
/// Error value produced by the transform layer; it carries only a message.
///
/// The message is a `Cow` so that compile-time constant messages can be stored
/// without allocating, while dynamically built messages are stored as owned
/// strings.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransformError {
    /// Text returned verbatim by the `Display` implementation.
    pub message: Cow<'static, str>,
}

impl TransformError {
    /// Builds an error that borrows a `'static` message.
    ///
    /// Being a `const fn`, this can be used to initialise constants.
    pub const fn not_implemented(message: &'static str) -> Self {
        Self {
            message: Cow::Borrowed(message),
        }
    }

    /// Builds an error from anything convertible into a `String`; the message
    /// is stored as an owned string.
    pub fn new(message: impl Into<String>) -> Self {
        let owned: String = message.into();
        Self {
            message: Cow::Owned(owned),
        }
    }
}

impl Display for TransformError {
    /// Writes the raw message; formatter flags such as width are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let text: &str = &self.message;
        f.write_str(text)
    }
}

// No underlying cause is tracked, so the default `Error` methods suffice.
impl Error for TransformError {}
42
43pub type TransformResult<T> = Result<T, TransformError>;
44
45pub use crate::transform::claude::utils::{ORPHAN_TOOL_USE_PLACEHOLDER_NAME, push_message_block};
50
51pub fn enforce_anthropic_strict_schema(schema: &mut BTreeMap<String, Value>) {
59 let mut tmp: serde_json::Map<String, Value> = std::mem::take(schema).into_iter().collect();
60 enforce_anthropic_strict_value_map(&mut tmp);
61 *schema = tmp.into_iter().collect();
62}
63
64fn enforce_anthropic_strict_value(value: &mut Value) {
65 match value {
66 Value::Object(map) => enforce_anthropic_strict_value_map(map),
67 Value::Array(arr) => {
68 for v in arr.iter_mut() {
69 enforce_anthropic_strict_value(v);
70 }
71 }
72 _ => {}
73 }
74}
75
76fn enforce_anthropic_strict_value_map(map: &mut serde_json::Map<String, Value>) {
77 if let Some(Value::Object(props)) = map.get_mut("properties") {
78 for (_, v) in props.iter_mut() {
79 enforce_anthropic_strict_value(v);
80 }
81 }
82 if let Some(items) = map.get_mut("items") {
83 enforce_anthropic_strict_value(items);
84 }
85 for key in ["$defs", "definitions"] {
86 if let Some(Value::Object(defs)) = map.get_mut(key) {
87 for (_, v) in defs.iter_mut() {
88 enforce_anthropic_strict_value(v);
89 }
90 }
91 }
92 for key in ["allOf", "anyOf", "oneOf"] {
93 if let Some(Value::Array(arr)) = map.get_mut(key) {
94 for v in arr.iter_mut() {
95 enforce_anthropic_strict_value(v);
96 }
97 }
98 }
99
100 let is_object_schema = map.get("type").and_then(|v| v.as_str()) == Some("object")
101 || map.contains_key("properties");
102 if !is_object_schema {
103 return;
104 }
105
106 map.insert("additionalProperties".to_string(), Value::Bool(false));
107
108 let prop_keys: Vec<String> = map
109 .get("properties")
110 .and_then(|v| v.as_object())
111 .map(|props| props.keys().cloned().collect())
112 .unwrap_or_default();
113 if prop_keys.is_empty() {
114 return;
115 }
116
117 let required = map
118 .entry("required".to_string())
119 .or_insert_with(|| Value::Array(Vec::new()));
120 if let Value::Array(arr) = required {
121 let existing: HashSet<String> = arr
122 .iter()
123 .filter_map(|v| v.as_str().map(str::to_string))
124 .collect();
125 for key in prop_keys {
126 if !existing.contains(&key) {
127 arr.push(Value::String(key));
128 }
129 }
130 }
131}
132
#[cfg(test)]
mod enforce_anthropic_strict_schema_tests {
    use super::*;
    use serde_json::json;

    /// Runs `enforce_anthropic_strict_schema` over a JSON object literal and
    /// returns the patched schema as a `Value` so tests can look into it.
    fn run(input: serde_json::Value) -> serde_json::Value {
        let mut schema = BTreeMap::new();
        for (key, value) in input.as_object().unwrap() {
            schema.insert(key.clone(), value.clone());
        }
        enforce_anthropic_strict_schema(&mut schema);
        Value::Object(schema.into_iter().collect())
    }

    /// Collects the top-level `required` entries into a set (order-agnostic).
    fn required_names(schema: &Value) -> HashSet<String> {
        schema["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|entry| entry.as_str().unwrap().to_owned())
            .collect()
    }

    /// Builds the expected set of names from string literals.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn top_level_object_gets_additional_properties_and_required() {
        let out = run(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            }
        }));

        assert_eq!(out.pointer("/additionalProperties"), Some(&json!(false)));
        assert_eq!(required_names(&out), name_set(&["name", "age"]));
    }

    #[test]
    fn nested_objects_in_properties_and_array_items_are_patched() {
        let out = run(json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}}
                },
                "tags": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {"id": {"type": "string"}}
                    }
                }
            }
        }));

        // Object nested directly under `properties`.
        assert_eq!(
            out.pointer("/properties/user/additionalProperties"),
            Some(&json!(false))
        );
        assert_eq!(
            out.pointer("/properties/user/required"),
            Some(&json!(["name"]))
        );

        // Object nested under an array's `items`.
        assert_eq!(
            out.pointer("/properties/tags/items/additionalProperties"),
            Some(&json!(false))
        );
        assert_eq!(
            out.pointer("/properties/tags/items/required"),
            Some(&json!(["id"]))
        );
    }

    #[test]
    fn defs_and_anyof_branches_are_patched() {
        let out = run(json!({
            "type": "object",
            "properties": {"x": {"$ref": "#/$defs/X"}},
            "$defs": {
                "X": {"type": "object", "properties": {"a": {"type": "string"}}}
            },
            "anyOf": [
                {"type": "object", "properties": {"b": {"type": "integer"}}}
            ]
        }));

        assert_eq!(
            out.pointer("/$defs/X/additionalProperties"),
            Some(&json!(false))
        );
        assert_eq!(out.pointer("/$defs/X/required"), Some(&json!(["a"])));
        assert_eq!(
            out.pointer("/anyOf/0/additionalProperties"),
            Some(&json!(false))
        );
        assert_eq!(out.pointer("/anyOf/0/required"), Some(&json!(["b"])));
    }

    #[test]
    fn existing_additional_properties_true_is_overwritten() {
        let out = run(json!({
            "type": "object",
            "additionalProperties": true,
            "properties": {"k": {"type": "string"}}
        }));

        assert_eq!(out.pointer("/additionalProperties"), Some(&json!(false)));
    }

    #[test]
    fn existing_required_is_extended_not_replaced() {
        let out = run(json!({
            "type": "object",
            "required": ["a"],
            "properties": {"a": {"type": "string"}, "b": {"type": "string"}}
        }));

        assert_eq!(required_names(&out), name_set(&["a", "b"]));
    }

    #[test]
    fn non_object_schemas_are_left_alone() {
        let out = run(json!({"type": "string", "format": "uuid"}));

        for absent in ["additionalProperties", "required"] {
            assert!(out.get(absent).is_none());
        }
    }
}