Skip to main content

agentic_tools_core/
schema.rs

1//! Schema engine for runtime transforms.
2
3use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6
7/// Field-level constraint to apply to a schema.
8#[derive(Clone, Debug)]
9pub enum FieldConstraint {
10    /// Restrict field to specific enum values.
11    Enum(Vec<Json>),
12
13    /// Apply numeric range constraints.
14    Range {
15        minimum: Option<Json>,
16        maximum: Option<Json>,
17    },
18
19    /// Apply string pattern constraint.
20    Pattern(String),
21
22    /// Apply a JSON merge-patch to the field schema.
23    MergePatch(Json),
24}
25
26/// Trait for custom schema transforms.
27pub trait SchemaTransform: Send + Sync {
28    /// Apply the transform to a tool's schema.
29    fn apply(&self, tool: &str, schema: &mut Json);
30}
31
32/// Engine for applying runtime transforms to tool schemas.
33///
34/// Schemars derive generates base schemas at compile time.
35/// SchemaEngine applies transforms at runtime for provider flexibility.
36///
37/// # Clone behavior
38/// When cloned, `custom_transforms` are **not** carried over (they are not `Clone`).
39/// Only `per_tool` constraints and `global_strict` settings are cloned.
40#[derive(Default)]
41pub struct SchemaEngine {
42    per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
43    global_strict: bool,
44    custom_transforms: Vec<Box<dyn SchemaTransform>>,
45}
46
47impl Clone for SchemaEngine {
48    fn clone(&self) -> Self {
49        // Custom transforms cannot be cloned, so we only clone the config
50        Self {
51            per_tool: self.per_tool.clone(),
52            global_strict: self.global_strict,
53            custom_transforms: Vec::new(), // Transforms are not cloned
54        }
55    }
56}
57
58impl std::fmt::Debug for SchemaEngine {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("SchemaEngine")
61            .field("per_tool", &self.per_tool)
62            .field("global_strict", &self.global_strict)
63            .field(
64                "custom_transforms",
65                &format!("[{} transforms]", self.custom_transforms.len()),
66            )
67            .finish()
68    }
69}
70
71impl SchemaEngine {
72    /// Create a new schema engine.
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    /// Enable strict mode (additionalProperties=false) globally.
78    pub fn with_strict(mut self, strict: bool) -> Self {
79        self.global_strict = strict;
80        self
81    }
82
83    /// Get global strict mode setting.
84    pub fn is_strict(&self) -> bool {
85        self.global_strict
86    }
87
88    /// Add a field constraint for a specific tool.
89    ///
90    /// The `json_path` is a list of property names to traverse to reach the field.
91    /// For example, `["properties", "count"]` would target the "count" property.
92    pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
93        self.per_tool
94            .entry(tool.to_string())
95            .or_default()
96            .push((json_path, c));
97    }
98
99    /// Add a custom transform.
100    pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
101        self.custom_transforms.push(Box::new(transform));
102    }
103
104    /// Transform a tool's schema applying all constraints and transforms.
105    pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
106        let mut v = serde_json::to_value(&schema).expect("serialize schema");
107
108        // Apply global strict mode
109        if self.global_strict
110            && let Some(obj) = v.as_object_mut()
111        {
112            obj.insert("additionalProperties".to_string(), Json::Bool(false));
113        }
114
115        // Apply per-tool constraints
116        if let Some(entries) = self.per_tool.get(tool) {
117            for (path, constraint) in entries {
118                Self::apply_constraint(&mut v, path, constraint);
119            }
120        }
121
122        // Apply custom transforms
123        for transform in &self.custom_transforms {
124            transform.apply(tool, &mut v);
125        }
126
127        // try_from only rejects non-object/non-bool JSON values.  Since we start
128        // from a valid Schema (always an object) and built-in transforms only mutate
129        // sub-nodes, failure here means a custom SchemaTransform replaced the root
130        // type — a programming error that must surface immediately.
131        Schema::try_from(v).expect("schema transform must produce a valid schema")
132    }
133
134    fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
135        let Some(node) = Self::find_node_mut(root, path) else {
136            return;
137        };
138        let Some(obj) = node.as_object_mut() else {
139            return;
140        };
141        match constraint {
142            FieldConstraint::Enum(vals) => {
143                obj.insert("enum".into(), Json::Array(vals.clone()));
144            }
145            FieldConstraint::Range { minimum, maximum } => {
146                if let Some(m) = minimum {
147                    obj.insert("minimum".into(), m.clone());
148                }
149                if let Some(m) = maximum {
150                    obj.insert("maximum".into(), m.clone());
151                }
152            }
153            FieldConstraint::Pattern(p) => {
154                obj.insert("pattern".into(), Json::String(p.clone()));
155            }
156            FieldConstraint::MergePatch(patch) => {
157                json_patch::merge(node, patch);
158            }
159        }
160    }
161
162    fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
163        let mut cur = root;
164        for seg in path {
165            cur = cur.as_object_mut()?.get_mut(seg)?;
166        }
167        Some(cur)
168    }
169}
170
171// ============================================================================
172// Centralized Draft 2020-12 Generator for MCP + Registry
173// ============================================================================
174
175/// Centralized schema generation using Draft 2020-12 with AddNullable transform.
176///
177/// This module provides cached schema generation matching the MCP Rust SDK pattern:
178/// - JSON Schema Draft 2020-12 (MCP protocol requirement)
179/// - AddNullable transform for `Option<T>` fields
180/// - Thread-local caching keyed by TypeId for performance
181pub mod mcp_schema {
182    use schemars::JsonSchema;
183    use schemars::Schema;
184    use schemars::generate::SchemaSettings;
185    use schemars::transform::AddNullable;
186    use schemars::transform::RestrictFormats;
187    use schemars::transform::Transform;
188    use std::any::TypeId;
189    use std::cell::RefCell;
190    use std::collections::HashMap;
191    use std::sync::Arc;
192
193    thread_local! {
194        static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
195        static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
196    }
197
198    /// Sanitizes null-only schema branches that AddNullable produces.
199    /// Converts `{"const": null, "nullable": true}` (no type) → `{"type": "null"}`
200    #[derive(Clone, Copy, Default)]
201    struct SanitizeNullBranches;
202
203    impl Transform for SanitizeNullBranches {
204        fn transform(&mut self, schema: &mut Schema) {
205            // Serialize to JSON, sanitize recursively, deserialize back
206            let mut v = serde_json::to_value(&*schema).expect("serialize schema for sanitize");
207            sanitize_null_branches_recursive(&mut v);
208            *schema = Schema::try_from(v).expect("rebuild sanitized schema");
209        }
210    }
211
212    fn sanitize_null_branches_recursive(node: &mut serde_json::Value) {
213        use serde_json::Value as Json;
214        match node {
215            Json::Object(map) => {
216                // Fix pattern: {"const": null, "nullable": true} without "type"
217                let has_nullable_true = map
218                    .get("nullable")
219                    .and_then(|v| v.as_bool())
220                    .unwrap_or(false);
221                let const_is_null = map.get("const").map(|v| v.is_null()).unwrap_or(false);
222                let has_type = map.contains_key("type");
223
224                if has_nullable_true && const_is_null && !has_type {
225                    map.remove("const");
226                    map.remove("nullable");
227                    map.insert("type".to_string(), Json::String("null".to_string()));
228                }
229
230                // Recurse into all values (covers subschemas at arbitrary keys)
231                for value in map.values_mut() {
232                    sanitize_null_branches_recursive(value);
233                }
234            }
235            Json::Array(arr) => {
236                for elem in arr {
237                    sanitize_null_branches_recursive(elem);
238                }
239            }
240            _ => {}
241        }
242    }
243
244    fn settings() -> SchemaSettings {
245        SchemaSettings::draft2020_12()
246            .with_transform(AddNullable::default())
247            .with_transform(RestrictFormats::default())
248            .with_transform(SanitizeNullBranches)
249    }
250
251    /// Generate a cached schema for type T using Draft 2020-12 + AddNullable.
252    pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
253        CACHE_FOR_TYPE.with(|cache| {
254            let mut cache = cache.borrow_mut();
255            if let Some(x) = cache.get(&TypeId::of::<T>()) {
256                return x.clone();
257            }
258            let generator = settings().into_generator();
259            let root = generator.into_root_schema_for::<T>();
260            let arc = Arc::new(root);
261            cache.insert(TypeId::of::<T>(), arc.clone());
262            arc
263        })
264    }
265
266    /// Generate a cached output schema for type T, validating root type is "object".
267    /// Returns Err if the root type is not "object" (per MCP spec requirement).
268    pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
269        CACHE_FOR_OUTPUT.with(|cache| {
270            let mut cache = cache.borrow_mut();
271            if let Some(r) = cache.get(&TypeId::of::<T>()) {
272                return r.clone();
273            }
274            let root = cached_schema_for::<T>();
275            let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
276            let result = match json.get("type") {
277                Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
278                Some(serde_json::Value::String(t)) => Err(format!(
279                    "MCP requires output_schema root type 'object', found '{}'",
280                    t
281                )),
282                None => {
283                    // Schema might use $ref or other patterns without explicit type
284                    // Accept if it has properties (likely an object schema)
285                    if json.get("properties").is_some() {
286                        Ok(root.clone())
287                    } else {
288                        Err(
289                            "Schema missing 'type' — output_schema must have root type 'object'"
290                                .to_string(),
291                        )
292                    }
293                }
294                Some(other) => Err(format!(
295                    "Unexpected 'type' format: {:?} — expected string 'object'",
296                    other
297                )),
298            };
299            cache.insert(TypeId::of::<T>(), result.clone());
300            result
301        })
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    /// Minimal input type used to generate a real schemars schema.
    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    /// Strict mode must mark the root object with `additionalProperties: false`.
    #[test]
    fn test_strict_mode() {
        let strict_engine = SchemaEngine::new().with_strict(true);
        let output = strict_engine.transform("test", schemars::schema_for!(TestInput));

        let as_json = serde_json::to_value(&output).unwrap();
        assert_eq!(
            as_json.get("additionalProperties"),
            Some(&Json::Bool(false))
        );
    }

    /// The getter reflects the default and the builder-set value.
    #[test]
    fn test_is_strict_getter() {
        let lenient = SchemaEngine::new();
        assert!(!lenient.is_strict());

        let strict = SchemaEngine::new().with_strict(true);
        assert!(strict.is_strict());
    }

    /// An enum constraint adds an `enum` keyword at the targeted path.
    #[test]
    fn test_enum_constraint() {
        // A hand-written schema keeps this test independent of schemars output.
        let raw: Json = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let allowed = vec![Json::String("a".into()), Json::String("b".into())];
        let target_path = vec!["properties".into(), "name".into()];

        let mut engine = SchemaEngine::new();
        engine.constrain_field("test", target_path, FieldConstraint::Enum(allowed));

        let input: Schema = Schema::try_from(raw).unwrap();
        let output = engine.transform("test", input);

        let as_json = serde_json::to_value(&output).unwrap();
        let name_schema = &as_json["properties"]["name"];
        assert!(name_schema.get("enum").is_some());
    }

    /// A range constraint writes both bounds onto a schemars-generated schema.
    #[test]
    fn test_range_constraint() {
        let bounds = FieldConstraint::Range {
            minimum: Some(Json::from(0)),
            maximum: Some(Json::from(100)),
        };

        let mut engine = SchemaEngine::new();
        engine.constrain_field("test", vec!["properties".into(), "count".into()], bounds);

        // Run the engine over a real generated schema.
        let output = engine.transform("test", schemars::schema_for!(TestInput));

        let as_json = serde_json::to_value(&output).unwrap();
        let count_schema = &as_json["properties"]["count"];

        // Compare as f64 since schemars may use floats.
        let lower = count_schema.get("minimum").and_then(|v| v.as_f64());
        let upper = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(lower, Some(0.0), "minimum constraint should be applied");
        assert_eq!(upper, Some(100.0), "maximum constraint should be applied");
    }

    // ========================================================================
    // mcp_schema module tests
    // ========================================================================

    mod mcp_schema_tests {
        use super::mcp_schema;
        use serde::Serialize;

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        /// `Option<T>` fields must come out of the central generator as nullable.
        #[test]
        fn test_central_generator_addnullable() {
            let schema = mcp_schema::cached_schema_for::<WithOption>();
            let as_json = serde_json::to_value(schema.as_ref()).unwrap();
            let field = &as_json["properties"]["a"];

            // AddNullable should add "nullable": true
            assert_eq!(
                field.get("nullable"),
                Some(&serde_json::Value::Bool(true)),
                "Option<T> fields should have nullable: true"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        /// Struct types have an object root and therefore validate.
        #[test]
        fn test_output_schema_validation_object() {
            let verdict = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                verdict.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        /// A bare `String` has a non-object root and must be rejected.
        #[test]
        fn test_output_schema_validation_non_object() {
            let verdict = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                verdict.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        /// The generated schema must declare Draft 2020-12.
        #[test]
        fn test_draft_2020_12_uses_defs() {
            let schema = mcp_schema::cached_schema_for::<WithOption>();
            let as_json = serde_json::to_value(schema.as_ref()).unwrap();

            // Draft 2020-12 uses $defs rather than definitions, but simple types
            // may have no $defs at all, so only the overall shape and the
            // declared dialect are checked here.
            assert!(as_json.is_object(), "Schema should be an object");

            let declares_2020_12 = as_json
                .get("$schema")
                .and_then(|s| s.as_str())
                .is_some_and(|s| s.contains("2020-12"));
            assert!(declares_2020_12, "Schema should reference Draft 2020-12");
        }

        /// Two lookups of the same type on one thread share a single allocation.
        #[test]
        fn test_caching_returns_same_arc() {
            let a = mcp_schema::cached_schema_for::<OutputObj>();
            let b = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&a, &b),
                "Cached schemas should return the same Arc"
            );
        }

        // ====================================================================
        // SanitizeNullBranches and RestrictFormats transform tests
        // ====================================================================

        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        /// The null branch of an `Option<Enum>` must be typed, never a bare nullable.
        #[test]
        fn test_option_enum_anyof_null_branch_has_type() {
            let schema = mcp_schema::cached_schema_for::<HasOptEnum>();
            let as_json = serde_json::to_value(schema.as_ref()).unwrap();
            let branches = as_json["properties"]["e"]["anyOf"]
                .as_array()
                .expect("Option<Enum> should generate anyOf");

            // There must be a branch with explicit type "null"
            let null_type = serde_json::json!("null");
            let has_null_branch = branches.iter().any(|b| b.get("type") == Some(&null_type));
            assert!(
                has_null_branch,
                "anyOf for Option<Enum> must include a branch with type:\"null\""
            );

            // No branch should have nullable:true without a type
            let yes = serde_json::json!(true);
            for branch in branches {
                let is_nullable = branch.get("nullable") == Some(&yes);
                let is_typed = branch.get("type").is_some() || branch.get("$ref").is_some();
                assert!(
                    !is_nullable || is_typed,
                    "No branch may contain nullable:true without a type or $ref"
                );
            }
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        /// Unsigned integers lose their non-standard `format` but keep `minimum: 0`.
        #[test]
        fn test_strip_uint_formats() {
            let schema = mcp_schema::cached_schema_for::<Unsigneds>();
            let as_json = serde_json::to_value(schema.as_ref()).unwrap();
            let field_a = &as_json["properties"]["a"];
            let field_b = &as_json["properties"]["b"];

            assert!(
                field_a.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                field_b.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                field_a.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                field_b.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        /// Sanitizing must leave an ordinary nullable string untouched.
        #[test]
        fn test_option_string_preserves_nullable() {
            let schema = mcp_schema::cached_schema_for::<HasOptString>();
            let as_json = serde_json::to_value(schema.as_ref()).unwrap();
            let field = &as_json["properties"]["s"];

            assert_eq!(
                field.get("type"),
                Some(&serde_json::json!("string")),
                "Option<String> should have type: string"
            );
            assert_eq!(
                field.get("nullable"),
                Some(&serde_json::json!(true)),
                "Option<String> should retain nullable: true"
            );
        }
    }
}