Skip to main content

agentic_tools_core/
schema.rs

1//! Schema engine for runtime transforms.
2
3use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6
7/// Field-level constraint to apply to a schema.
8#[derive(Clone, Debug)]
9pub enum FieldConstraint {
10    /// Restrict field to specific enum values.
11    Enum(Vec<Json>),
12
13    /// Apply numeric range constraints.
14    Range {
15        minimum: Option<Json>,
16        maximum: Option<Json>,
17    },
18
19    /// Apply string pattern constraint.
20    Pattern(String),
21
22    /// Apply a JSON merge-patch to the field schema.
23    MergePatch(Json),
24}
25
26/// Trait for custom schema transforms.
27pub trait SchemaTransform: Send + Sync {
28    /// Apply the transform to a tool's schema.
29    fn apply(&self, tool: &str, schema: &mut Json);
30}
31
32/// Engine for applying runtime transforms to tool schemas.
33///
34/// Schemars derive generates base schemas at compile time.
35/// SchemaEngine applies transforms at runtime for provider flexibility.
36///
37/// # Clone behavior
38/// When cloned, `custom_transforms` are **not** carried over (they are not `Clone`).
39/// Only `per_tool` constraints and `global_strict` settings are cloned.
40#[derive(Default)]
41pub struct SchemaEngine {
42    per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
43    global_strict: bool,
44    custom_transforms: Vec<Box<dyn SchemaTransform>>,
45}
46
47impl Clone for SchemaEngine {
48    fn clone(&self) -> Self {
49        // Custom transforms cannot be cloned, so we only clone the config
50        Self {
51            per_tool: self.per_tool.clone(),
52            global_strict: self.global_strict,
53            custom_transforms: Vec::new(), // Transforms are not cloned
54        }
55    }
56}
57
58impl std::fmt::Debug for SchemaEngine {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("SchemaEngine")
61            .field("per_tool", &self.per_tool)
62            .field("global_strict", &self.global_strict)
63            .field(
64                "custom_transforms",
65                &format!("[{} transforms]", self.custom_transforms.len()),
66            )
67            .finish()
68    }
69}
70
71impl SchemaEngine {
72    /// Create a new schema engine.
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    /// Enable strict mode (additionalProperties=false) globally.
78    pub fn with_strict(mut self, strict: bool) -> Self {
79        self.global_strict = strict;
80        self
81    }
82
83    /// Get global strict mode setting.
84    pub fn is_strict(&self) -> bool {
85        self.global_strict
86    }
87
88    /// Add a field constraint for a specific tool.
89    ///
90    /// The `json_path` is a list of property names to traverse to reach the field.
91    /// For example, `["properties", "count"]` would target the "count" property.
92    pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
93        self.per_tool
94            .entry(tool.to_string())
95            .or_default()
96            .push((json_path, c));
97    }
98
99    /// Add a custom transform.
100    pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
101        self.custom_transforms.push(Box::new(transform));
102    }
103
104    /// Transform a tool's schema applying all constraints and transforms.
105    pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
106        let mut v = serde_json::to_value(&schema).expect("serialize schema");
107
108        // Apply global strict mode
109        if self.global_strict
110            && let Some(obj) = v.as_object_mut()
111        {
112            obj.insert("additionalProperties".to_string(), Json::Bool(false));
113        }
114
115        // Apply per-tool constraints
116        if let Some(entries) = self.per_tool.get(tool) {
117            for (path, constraint) in entries {
118                Self::apply_constraint(&mut v, path, constraint);
119            }
120        }
121
122        // Apply custom transforms
123        for transform in &self.custom_transforms {
124            transform.apply(tool, &mut v);
125        }
126
127        // try_from only rejects non-object/non-bool JSON values.  Since we start
128        // from a valid Schema (always an object) and built-in transforms only mutate
129        // sub-nodes, failure here means a custom SchemaTransform replaced the root
130        // type — a programming error that must surface immediately.
131        Schema::try_from(v).expect("schema transform must produce a valid schema")
132    }
133
134    fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
135        let Some(node) = Self::find_node_mut(root, path) else {
136            return;
137        };
138        let Some(obj) = node.as_object_mut() else {
139            return;
140        };
141        match constraint {
142            FieldConstraint::Enum(vals) => {
143                obj.insert("enum".into(), Json::Array(vals.clone()));
144            }
145            FieldConstraint::Range { minimum, maximum } => {
146                if let Some(m) = minimum {
147                    obj.insert("minimum".into(), m.clone());
148                }
149                if let Some(m) = maximum {
150                    obj.insert("maximum".into(), m.clone());
151                }
152            }
153            FieldConstraint::Pattern(p) => {
154                obj.insert("pattern".into(), Json::String(p.clone()));
155            }
156            FieldConstraint::MergePatch(patch) => {
157                json_patch::merge(node, patch);
158            }
159        }
160    }
161
162    fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
163        let mut cur = root;
164        for seg in path {
165            cur = cur.as_object_mut()?.get_mut(seg)?;
166        }
167        Some(cur)
168    }
169}
170
171// ============================================================================
172// Centralized Draft 2020-12 Generator for MCP + Registry
173// ============================================================================
174
175/// Centralized schema generation using Draft 2020-12.
176///
177/// This module provides cached schema generation for MCP:
178/// - JSON Schema Draft 2020-12 (MCP protocol requirement)
179/// - `Option<T>` fields include `null` in schema shape (for example, `type`
180///   arrays for simple scalar cases and `anyOf`/`$ref` forms for complex types)
181/// - Thread-local caching keyed by TypeId for performance
182pub mod mcp_schema {
183    use schemars::JsonSchema;
184    use schemars::Schema;
185    use schemars::generate::SchemaSettings;
186    use schemars::transform::RestrictFormats;
187    use std::any::TypeId;
188    use std::cell::RefCell;
189    use std::collections::HashMap;
190    use std::sync::Arc;
191
192    thread_local! {
193        static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
194        static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
195    }
196
197    fn settings() -> SchemaSettings {
198        SchemaSettings::draft2020_12().with_transform(RestrictFormats::default())
199    }
200
201    /// Generate a cached schema for type T using Draft 2020-12.
202    pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
203        CACHE_FOR_TYPE.with(|cache| {
204            let mut cache = cache.borrow_mut();
205            if let Some(x) = cache.get(&TypeId::of::<T>()) {
206                return x.clone();
207            }
208            let generator = settings().into_generator();
209            let root = generator.into_root_schema_for::<T>();
210            let arc = Arc::new(root);
211            cache.insert(TypeId::of::<T>(), arc.clone());
212            arc
213        })
214    }
215
216    /// Generate a cached output schema for type T, validating root type is "object".
217    /// Returns Err if the root type is not "object" (per MCP spec requirement).
218    pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
219        CACHE_FOR_OUTPUT.with(|cache| {
220            let mut cache = cache.borrow_mut();
221            if let Some(r) = cache.get(&TypeId::of::<T>()) {
222                return r.clone();
223            }
224            let root = cached_schema_for::<T>();
225            let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
226            let result = match json.get("type") {
227                Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
228                Some(serde_json::Value::String(t)) => Err(format!(
229                    "MCP requires output_schema root type 'object', found '{}'",
230                    t
231                )),
232                None => {
233                    // Schema might use $ref or other patterns without explicit type
234                    // Accept if it has properties (likely an object schema)
235                    if json.get("properties").is_some() {
236                        Ok(root.clone())
237                    } else {
238                        Err(
239                            "Schema missing 'type' — output_schema must have root type 'object'"
240                                .to_string(),
241                        )
242                    }
243                }
244                Some(other) => Err(format!(
245                    "Unexpected 'type' format: {:?} — expected string 'object'",
246                    other
247                )),
248            };
249            cache.insert(TypeId::of::<T>(), result.clone());
250            result
251        })
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    /// Hand-written object schema with a single string property `name`.
    fn name_schema_json() -> Json {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        })
    }

    /// Run `schema` through `engine` for `tool` and return the result as JSON.
    fn transformed_json(engine: &SchemaEngine, tool: &str, schema: Json) -> Json {
        let schema = Schema::try_from(schema).unwrap();
        let transformed = engine.transform(tool, schema);
        serde_json::to_value(&transformed).unwrap()
    }

    /// Path to the `name` property of the schema built by `name_schema_json`.
    fn name_path() -> Vec<String> {
        vec!["properties".into(), "name".into()]
    }

    #[test]
    fn test_strict_mode() {
        let engine = SchemaEngine::new().with_strict(true);
        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(json.get("additionalProperties"), Some(&Json::Bool(false)));
    }

    #[test]
    fn test_non_strict_engine_leaves_schema_untouched() {
        let engine = SchemaEngine::new();
        let json = transformed_json(&engine, "test", name_schema_json());
        assert_eq!(json, name_schema_json());
    }

    #[test]
    fn test_is_strict_getter() {
        let e = SchemaEngine::new();
        assert!(!e.is_strict());
        let e2 = SchemaEngine::new().with_strict(true);
        assert!(e2.is_strict());
    }

    #[test]
    fn test_enum_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            name_path(),
            FieldConstraint::Enum(vec![Json::String("a".into()), Json::String("b".into())]),
        );

        let json = transformed_json(&engine, "test", name_schema_json());
        let name_schema = &json["properties"]["name"];

        // Check the exact values, not merely that the keyword exists.
        assert_eq!(
            name_schema.get("enum"),
            Some(&serde_json::json!(["a", "b"]))
        );
        assert_eq!(
            name_schema.get("type"),
            Some(&serde_json::json!("string")),
            "existing keywords must be preserved"
        );
    }

    #[test]
    fn test_range_constraint() {
        // Test that range constraints are applied to the correct schema path
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::Range {
                minimum: Some(Json::Number(0.into())),
                maximum: Some(Json::Number(100.into())),
            },
        );

        // Use schemars to generate a real schema
        let schema = schemars::schema_for!(TestInput);

        // The transform function modifies the schema
        let transformed = engine.transform("test", schema);

        // Verify the range constraints were applied
        let json = serde_json::to_value(&transformed).unwrap();
        let count_schema = &json["properties"]["count"];

        // Verify range was applied (compare as f64 since schemars may use floats)
        let min = count_schema.get("minimum").and_then(|v| v.as_f64());
        let max = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(min, Some(0.0), "minimum constraint should be applied");
        assert_eq!(max, Some(100.0), "maximum constraint should be applied");
    }

    #[test]
    fn test_pattern_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            name_path(),
            FieldConstraint::Pattern("^[a-z]+$".into()),
        );

        let json = transformed_json(&engine, "test", name_schema_json());
        assert_eq!(
            json["properties"]["name"].get("pattern"),
            Some(&serde_json::json!("^[a-z]+$"))
        );
    }

    #[test]
    fn test_merge_patch_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            name_path(),
            FieldConstraint::MergePatch(serde_json::json!({
                "description": "Display name",
                "maxLength": 8
            })),
        );

        let json = transformed_json(&engine, "test", name_schema_json());
        assert_eq!(
            json["properties"]["name"],
            serde_json::json!({
                "type": "string",
                "description": "Display name",
                "maxLength": 8
            })
        );
    }

    #[test]
    fn test_missing_path_is_ignored() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "missing".into()],
            FieldConstraint::Pattern("x".into()),
        );

        let json = transformed_json(&engine, "test", name_schema_json());
        assert_eq!(json, name_schema_json());
    }

    #[test]
    fn test_constraints_only_apply_to_named_tool() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field("test", name_path(), FieldConstraint::Pattern("x".into()));

        let json = transformed_json(&engine, "other", name_schema_json());
        assert_eq!(json, name_schema_json());
    }

    /// Custom transform that records the tool name on the schema root.
    struct ToolTag;

    impl SchemaTransform for ToolTag {
        fn apply(&self, tool: &str, schema: &mut Json) {
            if let Some(obj) = schema.as_object_mut() {
                obj.insert("x-tool".to_string(), Json::String(tool.to_string()));
            }
        }
    }

    #[test]
    fn test_custom_transform_runs_and_is_not_cloned() {
        let mut engine = SchemaEngine::new();
        engine.add_transform(ToolTag);

        let tagged = transformed_json(&engine, "test", name_schema_json());
        assert_eq!(tagged.get("x-tool"), Some(&serde_json::json!("test")));

        // Documented clone behavior: custom transforms are dropped.
        let cloned = engine.clone();
        let untagged = transformed_json(&cloned, "test", name_schema_json());
        assert!(untagged.get("x-tool").is_none());
    }

    // ========================================================================
    // mcp_schema module tests
    // ========================================================================

    mod mcp_schema_tests {
        use super::mcp_schema;
        use serde::Serialize;

        /// Assert that `prop` is exactly `{"type": ["string", "null"]}`-shaped
        /// as far as its `type` keyword is concerned.
        fn assert_string_or_null(prop: &serde_json::Value, label: &str) {
            let ty = prop
                .get("type")
                .and_then(|v| v.as_array())
                .unwrap_or_else(|| panic!("{label} should emit a type array"));
            assert!(ty.contains(&serde_json::json!("string")));
            assert!(ty.contains(&serde_json::json!("null")));
            assert_eq!(ty.len(), 2, "{label} should contain only string|null");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        #[test]
        fn test_option_generates_type_array() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            // Option<String> should produce {"type": ["string", "null"]}
            assert_string_or_null(&v["properties"]["a"], "Option<T>");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        #[test]
        fn test_output_schema_validation_object() {
            let ok = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                ok.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        #[test]
        fn test_output_schema_validation_non_object() {
            // String is not an object type
            let bad = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                bad.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        #[test]
        fn test_draft_2020_12_uses_defs() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            // Draft 2020-12 should use $defs, not definitions.
            // Note: simple types may not have $defs at all, so we verify the
            // dialect URI and that the legacy keyword is never emitted.
            assert!(v.is_object(), "Schema should be an object");
            assert!(
                v.get("$schema")
                    .and_then(|s| s.as_str())
                    .is_some_and(|s| s.contains("2020-12")),
                "Schema should reference Draft 2020-12"
            );
            assert!(
                v.get("definitions").is_none(),
                "Draft 2020-12 schemas must not use the legacy 'definitions' keyword"
            );
        }

        #[test]
        fn test_caching_returns_same_arc() {
            let first = mcp_schema::cached_schema_for::<OutputObj>();
            let second = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&first, &second),
                "Cached schemas should return the same Arc"
            );
        }

        #[test]
        fn test_output_caching_returns_same_arc_as_type_cache() {
            let plain = mcp_schema::cached_schema_for::<OutputObj>();
            let output = mcp_schema::cached_output_schema_for::<OutputObj>()
                .expect("object schema should validate");
            assert!(
                std::sync::Arc::ptr_eq(&plain, &output),
                "Output schema should reuse the cached type schema"
            );
        }

        // ====================================================================
        // RestrictFormats transform and Option<Enum> tests
        // ====================================================================

        // Variants exist only to shape the generated schema; never constructed.
        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        #[test]
        fn test_option_enum_anyof_null_branch_has_type() {
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let any_of = v["properties"]["e"]["anyOf"]
                .as_array()
                .expect("Option<Enum> should generate anyOf");

            // There must be a branch with explicit type "null"
            assert!(
                any_of
                    .iter()
                    .any(|b| b.get("type") == Some(&serde_json::json!("null"))),
                "anyOf for Option<Enum> must include a branch with type:\"null\""
            );

            // No branch should have nullable:true without a type
            for branch in any_of {
                let has_nullable = branch.get("nullable") == Some(&serde_json::json!(true));
                let has_type = branch.get("type").is_some() || branch.get("$ref").is_some();
                assert!(
                    !has_nullable || has_type,
                    "No branch may contain nullable:true without a type or $ref"
                );
            }
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        #[test]
        fn test_strip_uint_formats() {
            let root = mcp_schema::cached_schema_for::<Unsigneds>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let pa = &v["properties"]["a"];
            let pb = &v["properties"]["b"];

            assert!(
                pa.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                pb.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                pa.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                pb.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        #[test]
        fn test_option_string_uses_type_array() {
            let root = mcp_schema::cached_schema_for::<HasOptString>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let s = &v["properties"]["s"];

            // Option<String> should produce {"type": ["string", "null"]}
            assert_string_or_null(s, "Option<String>");
            // Should NOT have nullable keyword (that was from AddNullable)
            assert!(
                s.get("nullable").is_none(),
                "Option<String> should not have nullable keyword"
            );
        }
    }
}