Skip to main content

agentic_tools_core/
schema.rs

1//! Schema engine for runtime transforms.
2
3use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6
7/// Field-level constraint to apply to a schema.
8#[derive(Clone, Debug)]
9pub enum FieldConstraint {
10    /// Restrict field to specific enum values.
11    Enum(Vec<Json>),
12
13    /// Apply numeric range constraints.
14    Range {
15        minimum: Option<Json>,
16        maximum: Option<Json>,
17    },
18
19    /// Apply string pattern constraint.
20    Pattern(String),
21
22    /// Apply a JSON merge-patch to the field schema.
23    MergePatch(Json),
24}
25
26/// Trait for custom schema transforms.
27pub trait SchemaTransform: Send + Sync {
28    /// Apply the transform to a tool's schema.
29    fn apply(&self, tool: &str, schema: &mut Json);
30}
31
32/// Engine for applying runtime transforms to tool schemas.
33///
34/// Schemars derive generates base schemas at compile time.
35/// SchemaEngine applies transforms at runtime for provider flexibility.
36///
37/// # Clone behavior
38/// When cloned, `custom_transforms` are **not** carried over (they are not `Clone`).
39/// Only `per_tool` constraints and `global_strict` settings are cloned.
40#[derive(Default)]
41pub struct SchemaEngine {
42    per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
43    global_strict: bool,
44    custom_transforms: Vec<Box<dyn SchemaTransform>>,
45}
46
47impl Clone for SchemaEngine {
48    fn clone(&self) -> Self {
49        // Custom transforms cannot be cloned, so we only clone the config
50        Self {
51            per_tool: self.per_tool.clone(),
52            global_strict: self.global_strict,
53            custom_transforms: Vec::new(), // Transforms are not cloned
54        }
55    }
56}
57
58impl std::fmt::Debug for SchemaEngine {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("SchemaEngine")
61            .field("per_tool", &self.per_tool)
62            .field("global_strict", &self.global_strict)
63            .field(
64                "custom_transforms",
65                &format!("[{} transforms]", self.custom_transforms.len()),
66            )
67            .finish()
68    }
69}
70
71impl SchemaEngine {
72    /// Create a new schema engine.
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    /// Enable strict mode (additionalProperties=false) globally.
78    pub fn with_strict(mut self, strict: bool) -> Self {
79        self.global_strict = strict;
80        self
81    }
82
83    /// Get global strict mode setting.
84    pub fn is_strict(&self) -> bool {
85        self.global_strict
86    }
87
88    /// Add a field constraint for a specific tool.
89    ///
90    /// The `json_path` is a list of property names to traverse to reach the field.
91    /// For example, `["properties", "count"]` would target the "count" property.
92    pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
93        self.per_tool
94            .entry(tool.to_string())
95            .or_default()
96            .push((json_path, c));
97    }
98
99    /// Add a custom transform.
100    pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
101        self.custom_transforms.push(Box::new(transform));
102    }
103
104    /// Transform a tool's schema applying all constraints and transforms.
105    pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
106        let mut v = serde_json::to_value(&schema).expect("serialize schema");
107
108        // Apply global strict mode
109        if self.global_strict
110            && let Some(obj) = v.as_object_mut()
111        {
112            obj.insert("additionalProperties".to_string(), Json::Bool(false));
113        }
114
115        // Apply per-tool constraints
116        if let Some(entries) = self.per_tool.get(tool) {
117            for (path, constraint) in entries {
118                Self::apply_constraint(&mut v, path, constraint);
119            }
120        }
121
122        // Apply custom transforms
123        for transform in &self.custom_transforms {
124            transform.apply(tool, &mut v);
125        }
126
127        // try_from only rejects non-object/non-bool JSON values.  Since we start
128        // from a valid Schema (always an object) and built-in transforms only mutate
129        // sub-nodes, failure here means a custom SchemaTransform replaced the root
130        // type — a programming error that must surface immediately.
131        Schema::try_from(v).expect("schema transform must produce a valid schema")
132    }
133
134    fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
135        let Some(node) = Self::find_node_mut(root, path) else {
136            return;
137        };
138        let Some(obj) = node.as_object_mut() else {
139            return;
140        };
141        match constraint {
142            FieldConstraint::Enum(vals) => {
143                obj.insert("enum".into(), Json::Array(vals.clone()));
144            }
145            FieldConstraint::Range { minimum, maximum } => {
146                if let Some(m) = minimum {
147                    obj.insert("minimum".into(), m.clone());
148                }
149                if let Some(m) = maximum {
150                    obj.insert("maximum".into(), m.clone());
151                }
152            }
153            FieldConstraint::Pattern(p) => {
154                obj.insert("pattern".into(), Json::String(p.clone()));
155            }
156            FieldConstraint::MergePatch(patch) => {
157                json_patch::merge(node, patch);
158            }
159        }
160    }
161
162    fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
163        let mut cur = root;
164        for seg in path {
165            cur = cur.as_object_mut()?.get_mut(seg)?;
166        }
167        Some(cur)
168    }
169}
170
171// ============================================================================
172// Centralized Draft 2020-12 Generator for MCP + Registry
173// ============================================================================
174
175/// Centralized schema generation using Draft 2020-12 with AddNullable transform.
176///
177/// This module provides cached schema generation matching the MCP Rust SDK pattern:
178/// - JSON Schema Draft 2020-12 (MCP protocol requirement)
179/// - AddNullable transform for `Option<T>` fields
180/// - Thread-local caching keyed by TypeId for performance
181pub mod mcp_schema {
182    use schemars::generate::SchemaSettings;
183    use schemars::transform::{AddNullable, RestrictFormats, Transform};
184    use schemars::{JsonSchema, Schema};
185    use std::any::TypeId;
186    use std::cell::RefCell;
187    use std::collections::HashMap;
188    use std::sync::Arc;
189
190    thread_local! {
191        static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
192        static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
193    }
194
195    /// Sanitizes null-only schema branches that AddNullable produces.
196    /// Converts `{"const": null, "nullable": true}` (no type) → `{"type": "null"}`
197    #[derive(Clone, Copy, Default)]
198    struct SanitizeNullBranches;
199
200    impl Transform for SanitizeNullBranches {
201        fn transform(&mut self, schema: &mut Schema) {
202            // Serialize to JSON, sanitize recursively, deserialize back
203            let mut v = serde_json::to_value(&*schema).expect("serialize schema for sanitize");
204            sanitize_null_branches_recursive(&mut v);
205            *schema = Schema::try_from(v).expect("rebuild sanitized schema");
206        }
207    }
208
209    fn sanitize_null_branches_recursive(node: &mut serde_json::Value) {
210        use serde_json::Value as Json;
211        match node {
212            Json::Object(map) => {
213                // Fix pattern: {"const": null, "nullable": true} without "type"
214                let has_nullable_true = map
215                    .get("nullable")
216                    .and_then(|v| v.as_bool())
217                    .unwrap_or(false);
218                let const_is_null = map.get("const").map(|v| v.is_null()).unwrap_or(false);
219                let has_type = map.contains_key("type");
220
221                if has_nullable_true && const_is_null && !has_type {
222                    map.remove("const");
223                    map.remove("nullable");
224                    map.insert("type".to_string(), Json::String("null".to_string()));
225                }
226
227                // Recurse into all values (covers subschemas at arbitrary keys)
228                for value in map.values_mut() {
229                    sanitize_null_branches_recursive(value);
230                }
231            }
232            Json::Array(arr) => {
233                for elem in arr {
234                    sanitize_null_branches_recursive(elem);
235                }
236            }
237            _ => {}
238        }
239    }
240
241    fn settings() -> SchemaSettings {
242        SchemaSettings::draft2020_12()
243            .with_transform(AddNullable::default())
244            .with_transform(RestrictFormats::default())
245            .with_transform(SanitizeNullBranches)
246    }
247
248    /// Generate a cached schema for type T using Draft 2020-12 + AddNullable.
249    pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
250        CACHE_FOR_TYPE.with(|cache| {
251            let mut cache = cache.borrow_mut();
252            if let Some(x) = cache.get(&TypeId::of::<T>()) {
253                return x.clone();
254            }
255            let generator = settings().into_generator();
256            let root = generator.into_root_schema_for::<T>();
257            let arc = Arc::new(root);
258            cache.insert(TypeId::of::<T>(), arc.clone());
259            arc
260        })
261    }
262
263    /// Generate a cached output schema for type T, validating root type is "object".
264    /// Returns Err if the root type is not "object" (per MCP spec requirement).
265    pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
266        CACHE_FOR_OUTPUT.with(|cache| {
267            let mut cache = cache.borrow_mut();
268            if let Some(r) = cache.get(&TypeId::of::<T>()) {
269                return r.clone();
270            }
271            let root = cached_schema_for::<T>();
272            let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
273            let result = match json.get("type") {
274                Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
275                Some(serde_json::Value::String(t)) => Err(format!(
276                    "MCP requires output_schema root type 'object', found '{}'",
277                    t
278                )),
279                None => {
280                    // Schema might use $ref or other patterns without explicit type
281                    // Accept if it has properties (likely an object schema)
282                    if json.get("properties").is_some() {
283                        Ok(root.clone())
284                    } else {
285                        Err(
286                            "Schema missing 'type' — output_schema must have root type 'object'"
287                                .to_string(),
288                        )
289                    }
290                }
291                Some(other) => Err(format!(
292                    "Unexpected 'type' format: {:?} — expected string 'object'",
293                    other
294                )),
295            };
296            cache.insert(TypeId::of::<T>(), result.clone());
297            result
298        })
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    /// Small derived input used to exercise `SchemaEngine` on a real schemars schema.
    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    #[test]
    fn test_strict_mode() {
        let engine = SchemaEngine::new().with_strict(true);
        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(json.get("additionalProperties"), Some(&Json::Bool(false)));
    }

    #[test]
    fn test_is_strict_getter() {
        let e = SchemaEngine::new();
        assert!(!e.is_strict());
        let e2 = SchemaEngine::new().with_strict(true);
        assert!(e2.is_strict());
    }

    #[test]
    fn test_enum_constraint() {
        let mut engine = SchemaEngine::new();

        // Use a simple schema object for testing
        let test_schema: Json = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        engine.constrain_field(
            "test",
            vec!["properties".into(), "name".into()],
            FieldConstraint::Enum(vec![Json::String("a".into()), Json::String("b".into())]),
        );

        let schema: Schema = Schema::try_from(test_schema).unwrap();
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let name_schema = &json["properties"]["name"];
        assert!(name_schema.get("enum").is_some());
    }

    #[test]
    fn test_range_constraint() {
        // Range constraints must land on the schema node the path addresses.
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::Range {
                minimum: Some(Json::Number(0.into())),
                maximum: Some(Json::Number(100.into())),
            },
        );

        // Use schemars to generate a real schema
        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let count_schema = &json["properties"]["count"];

        // Compare as f64 since schemars may use floats
        let min = count_schema.get("minimum").and_then(|v| v.as_f64());
        let max = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(min, Some(0.0), "minimum constraint should be applied");
        assert_eq!(max, Some(100.0), "maximum constraint should be applied");
    }

    // ========================================================================
    // mcp_schema module tests
    // ========================================================================

    mod mcp_schema_tests {
        use super::mcp_schema;
        use serde::Serialize;

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        #[test]
        fn test_central_generator_addnullable() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let a = &v["properties"]["a"];
            // AddNullable should add "nullable": true
            assert_eq!(
                a.get("nullable"),
                Some(&serde_json::Value::Bool(true)),
                "Option<T> fields should have nullable: true"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        #[test]
        fn test_output_schema_validation_object() {
            let ok = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                ok.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        #[test]
        fn test_output_schema_validation_non_object() {
            // String is not an object type
            let bad = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                bad.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        #[test]
        fn test_draft_2020_12_uses_defs() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            // Simple types need no $defs, so check the dialect marker and that the
            // legacy pre-2019 `definitions` keyword is not used.
            assert!(v.is_object(), "Schema should be an object");
            assert!(
                v.get("$schema")
                    .and_then(|s| s.as_str())
                    .is_some_and(|s| s.contains("2020-12")),
                "Schema should reference Draft 2020-12"
            );
            assert!(
                v.get("definitions").is_none(),
                "Draft 2020-12 schemas must not use the legacy 'definitions' keyword"
            );
        }

        #[test]
        fn test_draft_2020_12_referenced_types_live_in_defs() {
            // `TestEnum` is referenced (not inlined), so it must be emitted under `$defs`.
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            assert!(
                v.get("$defs")
                    .and_then(|d| d.as_object())
                    .is_some_and(|d| !d.is_empty()),
                "Referenced types should be placed under '$defs'"
            );
            assert!(
                v.get("definitions").is_none(),
                "Draft 2020-12 schemas must not use the legacy 'definitions' keyword"
            );
        }

        #[test]
        fn test_caching_returns_same_arc() {
            let first = mcp_schema::cached_schema_for::<OutputObj>();
            let second = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&first, &second),
                "Cached schemas should return the same Arc"
            );
        }

        // ====================================================================
        // SanitizeNullBranches and RestrictFormats transform tests
        // ====================================================================

        // Variants are never constructed; the enum exists only for its schema.
        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        #[test]
        fn test_option_enum_anyof_null_branch_has_type() {
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let any_of = v["properties"]["e"]["anyOf"]
                .as_array()
                .expect("Option<Enum> should generate anyOf");

            // There must be a branch with explicit type "null"
            assert!(
                any_of
                    .iter()
                    .any(|b| b.get("type") == Some(&serde_json::json!("null"))),
                "anyOf for Option<Enum> must include a branch with type:\"null\""
            );

            // No branch should have nullable:true without a type
            for branch in any_of {
                let has_nullable = branch.get("nullable") == Some(&serde_json::json!(true));
                let has_type = branch.get("type").is_some() || branch.get("$ref").is_some();
                assert!(
                    !has_nullable || has_type,
                    "No branch may contain nullable:true without a type or $ref"
                );
            }
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        #[test]
        fn test_strip_uint_formats() {
            let root = mcp_schema::cached_schema_for::<Unsigneds>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let pa = &v["properties"]["a"];
            let pb = &v["properties"]["b"];

            assert!(
                pa.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                pb.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                pa.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                pb.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        #[test]
        fn test_option_string_preserves_nullable() {
            let root = mcp_schema::cached_schema_for::<HasOptString>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let s = &v["properties"]["s"];

            assert_eq!(
                s.get("type"),
                Some(&serde_json::json!("string")),
                "Option<String> should have type: string"
            );
            assert_eq!(
                s.get("nullable"),
                Some(&serde_json::json!(true)),
                "Option<String> should retain nullable: true"
            );
        }
    }
}