Skip to main content

lash_sansio/
schema_contract.rs

1//! JSON Schema contracts and provider dialect projection.
2//!
3//! A [`SchemaContract`] keeps the canonical schema used for runtime
4//! validation separate from the provider wire schema. Providers declare the
5//! dialects they accept for each purpose and resolve contracts lazily at the
6//! request boundary.
7
8use serde_json::{Map, Value, json};
9
10#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
11pub struct SchemaContract {
12    pub canonical: Value,
13    #[serde(default, skip_serializing_if = "SchemaProjectionPolicy::is_default")]
14    pub projection: SchemaProjectionPolicy,
15}
16
17impl SchemaContract {
18    pub fn new(canonical: Value) -> Self {
19        Self {
20            canonical,
21            projection: SchemaProjectionPolicy::default(),
22        }
23    }
24
25    pub fn with_projection(mut self, projection: SchemaProjectionPolicy) -> Self {
26        self.projection = projection;
27        self
28    }
29
30    pub fn with_override(mut self, dialect: impl Into<String>, schema: Value) -> Self {
31        self.projection
32            .set_override(SchemaProjectionOverride::new(dialect, schema));
33        self
34    }
35
36    pub fn canonical(&self) -> &Value {
37        &self.canonical
38    }
39}
40
41impl Default for SchemaContract {
42    fn default() -> Self {
43        Self::new(Value::Null)
44    }
45}
46
47impl From<Value> for SchemaContract {
48    fn from(value: Value) -> Self {
49        Self::new(value)
50    }
51}
52
53#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
54pub struct SchemaProjectionPolicy {
55    #[serde(default, skip_serializing_if = "ProjectionMode::is_auto")]
56    pub mode: ProjectionMode,
57    #[serde(default, skip_serializing_if = "Vec::is_empty")]
58    pub overrides: Vec<SchemaProjectionOverride>,
59}
60
61impl SchemaProjectionPolicy {
62    pub fn is_default(&self) -> bool {
63        self.mode == ProjectionMode::Auto && self.overrides.is_empty()
64    }
65
66    pub fn set_override(&mut self, override_schema: SchemaProjectionOverride) {
67        self.overrides
68            .retain(|projection| projection.dialect != override_schema.dialect);
69        self.overrides.push(override_schema);
70    }
71}
72
73impl Default for SchemaProjectionPolicy {
74    fn default() -> Self {
75        Self {
76            mode: ProjectionMode::Auto,
77            overrides: Vec::new(),
78        }
79    }
80}
81
82#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
83#[serde(rename_all = "snake_case")]
84pub enum ProjectionMode {
85    #[default]
86    Auto,
87    ExplicitOnly,
88    Exact,
89}
90
91impl ProjectionMode {
92    fn is_auto(&self) -> bool {
93        *self == Self::Auto
94    }
95}
96
97#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
98pub struct SchemaProjectionOverride {
99    pub dialect: String,
100    pub schema: Value,
101}
102
103impl SchemaProjectionOverride {
104    pub fn new(dialect: impl Into<String>, schema: Value) -> Self {
105        Self {
106            dialect: dialect.into(),
107            schema,
108        }
109    }
110}
111
112#[derive(
113    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
114)]
115#[serde(transparent)]
116pub struct SchemaDialect(String);
117
118impl SchemaDialect {
119    pub const OPENAI_TOOL_PARAMETERS: &'static str = "openai_tool_parameters";
120    pub const OPENAI_STRICT_TOOL_PARAMETERS: &'static str = "openai_strict_tool_parameters";
121    pub const OPENAI_STRUCTURED_OUTPUT: &'static str = "openai_structured_output";
122    pub const ANTHROPIC_TOOL_INPUT: &'static str = "anthropic_tool_input";
123    pub const ANTHROPIC_OUTPUT_CONFIG_JSON_SCHEMA: &'static str =
124        "anthropic_output_config_json_schema";
125    pub const BEDROCK_CLAUDE_OUTPUT_CONFIG_JSON_SCHEMA: &'static str =
126        "bedrock_claude_output_config_json_schema";
127    pub const GOOGLE_SCHEMA: &'static str = "google_schema";
128    pub const JSON_PROMPT_SCHEMA: &'static str = "json_prompt_schema";
129
130    pub fn new(value: impl Into<String>) -> Self {
131        Self(value.into())
132    }
133
134    pub fn as_str(&self) -> &str {
135        &self.0
136    }
137
138    pub fn openai_tool_parameters() -> Self {
139        Self::new(Self::OPENAI_TOOL_PARAMETERS)
140    }
141
142    pub fn openai_strict_tool_parameters() -> Self {
143        Self::new(Self::OPENAI_STRICT_TOOL_PARAMETERS)
144    }
145
146    pub fn openai_structured_output() -> Self {
147        Self::new(Self::OPENAI_STRUCTURED_OUTPUT)
148    }
149
150    pub fn anthropic_tool_input() -> Self {
151        Self::new(Self::ANTHROPIC_TOOL_INPUT)
152    }
153
154    pub fn anthropic_output_config_json_schema() -> Self {
155        Self::new(Self::ANTHROPIC_OUTPUT_CONFIG_JSON_SCHEMA)
156    }
157
158    pub fn bedrock_claude_output_config_json_schema() -> Self {
159        Self::new(Self::BEDROCK_CLAUDE_OUTPUT_CONFIG_JSON_SCHEMA)
160    }
161
162    pub fn google_schema() -> Self {
163        Self::new(Self::GOOGLE_SCHEMA)
164    }
165}
166
167impl From<&str> for SchemaDialect {
168    fn from(value: &str) -> Self {
169        Self::new(value)
170    }
171}
172
173impl From<String> for SchemaDialect {
174    fn from(value: String) -> Self {
175        Self::new(value)
176    }
177}
178
/// The role a schema plays in a provider request.
///
/// Serialized in `snake_case`. Each variant selects the same-named dialect
/// list in `ProviderSchemaCapabilities::dialects_for`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SchemaPurpose {
    /// Selects `ProviderSchemaCapabilities::tool_input`.
    ToolInput,
    /// Selects `ProviderSchemaCapabilities::tool_output`.
    ToolOutput,
    /// Selects `ProviderSchemaCapabilities::structured_output`.
    StructuredOutput,
    /// Selects `ProviderSchemaCapabilities::prompt_schema`.
    PromptSchema,
}
187
188#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
189pub struct ProviderSchemaCapabilities {
190    #[serde(default, skip_serializing_if = "Vec::is_empty")]
191    pub tool_input: Vec<SchemaDialect>,
192    #[serde(default, skip_serializing_if = "Vec::is_empty")]
193    pub tool_output: Vec<SchemaDialect>,
194    #[serde(default, skip_serializing_if = "Vec::is_empty")]
195    pub structured_output: Vec<SchemaDialect>,
196    #[serde(default, skip_serializing_if = "Vec::is_empty")]
197    pub prompt_schema: Vec<SchemaDialect>,
198}
199
200impl ProviderSchemaCapabilities {
201    pub fn openai(strict_tools: bool) -> Self {
202        Self {
203            tool_input: vec![if strict_tools {
204                SchemaDialect::openai_strict_tool_parameters()
205            } else {
206                SchemaDialect::openai_tool_parameters()
207            }],
208            structured_output: vec![SchemaDialect::openai_structured_output()],
209            ..Default::default()
210        }
211    }
212
213    pub fn anthropic() -> Self {
214        Self {
215            tool_input: vec![SchemaDialect::anthropic_tool_input()],
216            structured_output: vec![SchemaDialect::anthropic_output_config_json_schema()],
217            ..Default::default()
218        }
219    }
220
221    pub fn bedrock_claude() -> Self {
222        Self {
223            tool_input: vec![SchemaDialect::anthropic_tool_input()],
224            structured_output: vec![SchemaDialect::bedrock_claude_output_config_json_schema()],
225            ..Default::default()
226        }
227    }
228
229    pub fn google() -> Self {
230        Self {
231            tool_input: vec![SchemaDialect::google_schema()],
232            structured_output: vec![SchemaDialect::google_schema()],
233            ..Default::default()
234        }
235    }
236
237    pub fn dialects_for(&self, purpose: SchemaPurpose) -> &[SchemaDialect] {
238        match purpose {
239            SchemaPurpose::ToolInput => &self.tool_input,
240            SchemaPurpose::ToolOutput => &self.tool_output,
241            SchemaPurpose::StructuredOutput => &self.structured_output,
242            SchemaPurpose::PromptSchema => &self.prompt_schema,
243        }
244    }
245}
246
/// Input to `resolve_schema`: who is asking, for what purpose, and which
/// dialects may be used.
#[derive(Clone, Debug)]
pub struct SchemaResolutionRequest<'a> {
    /// Provider name; copied into `SchemaResolutionError::provider` on failure.
    pub provider: &'a str,
    /// Purpose the schema is being resolved for.
    pub purpose: SchemaPurpose,
    /// Candidate dialects; `resolve_schema` tries them in slice order and
    /// fails immediately when the slice is empty.
    pub dialects: &'a [SchemaDialect],
}
253
/// Successful result of `resolve_schema`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResolvedSchema {
    /// Schema to send to the provider.
    pub schema: Value,
    /// Dialect the schema was resolved for.
    pub dialect: SchemaDialect,
    /// Messages collected during resolution, including those from dialects
    /// that were tried and rejected before this one.
    pub diagnostics: Vec<String>,
}
260
261#[derive(Clone, Debug, PartialEq, Eq)]
262pub struct SchemaResolutionError {
263    pub provider: String,
264    pub purpose: SchemaPurpose,
265    pub dialect: Option<SchemaDialect>,
266    pub diagnostics: Vec<String>,
267    first_diagnostic: String,
268}
269
270impl std::fmt::Display for SchemaResolutionError {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        write!(
273            f,
274            "{} schema resolution for {:?} failed: {}",
275            self.provider, self.purpose, self.first_diagnostic
276        )
277    }
278}
279
280impl std::error::Error for SchemaResolutionError {}
281
282impl SchemaResolutionError {
283    fn new(
284        provider: impl Into<String>,
285        purpose: SchemaPurpose,
286        dialect: Option<SchemaDialect>,
287        diagnostics: Vec<String>,
288    ) -> Self {
289        let first_diagnostic = diagnostics
290            .first()
291            .cloned()
292            .unwrap_or_else(|| "schema resolution failed".to_string());
293        Self {
294            provider: provider.into(),
295            purpose,
296            dialect,
297            diagnostics,
298            first_diagnostic,
299        }
300    }
301
302    pub fn first_diagnostic(&self) -> &str {
303        &self.first_diagnostic
304    }
305}
306
307pub fn resolve_schema(
308    contract: &SchemaContract,
309    request: SchemaResolutionRequest<'_>,
310) -> Result<ResolvedSchema, SchemaResolutionError> {
311    if request.dialects.is_empty() {
312        return Err(SchemaResolutionError::new(
313            request.provider,
314            request.purpose,
315            None,
316            vec!["provider declared no schema dialects for this purpose".to_string()],
317        ));
318    }
319
320    let mut diagnostics = Vec::new();
321    for dialect in request.dialects {
322        if let Some(override_schema) = contract
323            .projection
324            .overrides
325            .iter()
326            .find(|projection| projection.dialect == dialect.as_str())
327        {
328            let diagnostics = diagnostics;
329            return Ok(ResolvedSchema {
330                schema: override_schema.schema.clone(),
331                dialect: dialect.clone(),
332                diagnostics,
333            });
334        }
335
336        match contract.projection.mode {
337            ProjectionMode::ExplicitOnly => diagnostics.push(format!(
338                "{}: no explicit projection override for {}",
339                dialect.as_str(),
340                format_purpose(request.purpose)
341            )),
342            ProjectionMode::Exact => match project_for_dialect(&contract.canonical, dialect) {
343                Ok(projection)
344                    if projection.schema == contract.canonical
345                        && projection.diagnostics.is_empty() =>
346                {
347                    let diagnostics = diagnostics;
348                    return Ok(ResolvedSchema {
349                        schema: contract.canonical.clone(),
350                        dialect: dialect.clone(),
351                        diagnostics,
352                    });
353                }
354                Ok(projection) => diagnostics.push(format!(
355                    "{}: canonical schema is not exact for provider dialect ({})",
356                    dialect.as_str(),
357                    projection.diagnostics.join("; ")
358                )),
359                Err(err) => diagnostics.extend(err.diagnostics),
360            },
361            ProjectionMode::Auto => match project_for_dialect(&contract.canonical, dialect) {
362                Ok(projection) => {
363                    diagnostics.extend(projection.diagnostics);
364                    return Ok(ResolvedSchema {
365                        schema: projection.schema,
366                        dialect: dialect.clone(),
367                        diagnostics,
368                    });
369                }
370                Err(err) => diagnostics.extend(err.diagnostics),
371            },
372        }
373    }
374
375    Err(SchemaResolutionError::new(
376        request.provider,
377        request.purpose,
378        request.dialects.last().cloned(),
379        diagnostics,
380    ))
381}
382
383fn format_purpose(purpose: SchemaPurpose) -> &'static str {
384    match purpose {
385        SchemaPurpose::ToolInput => "tool input",
386        SchemaPurpose::ToolOutput => "tool output",
387        SchemaPurpose::StructuredOutput => "structured output",
388        SchemaPurpose::PromptSchema => "prompt schema",
389    }
390}
391
/// Which OpenAI schema rules `Projector` applies.
///
/// Serialized in `snake_case`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
enum OpenAiSchemaProfile {
    /// Tool parameters without strict-object rewriting.
    ToolParameters,
    /// Tool parameters with strict-object rewriting
    /// (see `Projector::requires_strict_objects`).
    StrictToolParameters,
    /// Structured output: strict-object rewriting, and a root-level `anyOf`
    /// is rejected.
    StructuredOutput,
}
399
/// A schema rewritten for a provider dialect.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SchemaProjection {
    /// The projected schema.
    pub schema: Value,
    /// Notes describing the non-fatal rewrites that were applied.
    pub diagnostics: Vec<String>,
}
405
406#[derive(Clone, Debug, PartialEq, Eq)]
407struct SchemaProjectionError {
408    profile: OpenAiSchemaProfile,
409    diagnostics: Vec<String>,
410    first_diagnostic: String,
411}
412
413impl std::fmt::Display for SchemaProjectionError {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        write!(
416            f,
417            "OpenAI schema projection for {:?} failed: {}",
418            self.profile, self.first_diagnostic
419        )
420    }
421}
422
423impl std::error::Error for SchemaProjectionError {}
424
425impl SchemaProjectionError {
426    fn new(profile: OpenAiSchemaProfile, diagnostics: Vec<String>) -> Self {
427        let first_diagnostic = diagnostics
428            .first()
429            .cloned()
430            .unwrap_or_else(|| "schema projection failed".to_string());
431        Self {
432            profile,
433            diagnostics,
434            first_diagnostic,
435        }
436    }
437}
438
439fn project_schema(
440    schema: &Value,
441    profile: OpenAiSchemaProfile,
442) -> Result<SchemaProjection, SchemaProjectionError> {
443    Projector::new(profile).project(schema)
444}
445
446fn project_tool_parameters(schema: &Value) -> Result<SchemaProjection, SchemaProjectionError> {
447    project_schema(schema, OpenAiSchemaProfile::ToolParameters)
448}
449
450fn project_strict_tool_parameters(
451    schema: &Value,
452) -> Result<SchemaProjection, SchemaProjectionError> {
453    project_schema(schema, OpenAiSchemaProfile::StrictToolParameters)
454}
455
456fn project_structured_output(schema: &Value) -> Result<SchemaProjection, SchemaProjectionError> {
457    project_schema(schema, OpenAiSchemaProfile::StructuredOutput)
458}
459
460pub fn project_for_dialect(
461    schema: &Value,
462    dialect: &SchemaDialect,
463) -> Result<SchemaProjection, SchemaResolutionError> {
464    match dialect.as_str() {
465        SchemaDialect::OPENAI_TOOL_PARAMETERS => project_tool_parameters(schema).map_err(|err| {
466            SchemaResolutionError::new(
467                "schema",
468                SchemaPurpose::ToolInput,
469                Some(dialect.clone()),
470                err.diagnostics,
471            )
472        }),
473        SchemaDialect::OPENAI_STRICT_TOOL_PARAMETERS => project_strict_tool_parameters(schema)
474            .map_err(|err| {
475                SchemaResolutionError::new(
476                    "schema",
477                    SchemaPurpose::ToolInput,
478                    Some(dialect.clone()),
479                    err.diagnostics,
480                )
481            }),
482        SchemaDialect::OPENAI_STRUCTURED_OUTPUT => {
483            project_structured_output(schema).map_err(|err| {
484                SchemaResolutionError::new(
485                    "schema",
486                    SchemaPurpose::StructuredOutput,
487                    Some(dialect.clone()),
488                    err.diagnostics,
489                )
490            })
491        }
492        SchemaDialect::ANTHROPIC_TOOL_INPUT
493        | SchemaDialect::ANTHROPIC_OUTPUT_CONFIG_JSON_SCHEMA
494        | SchemaDialect::BEDROCK_CLAUDE_OUTPUT_CONFIG_JSON_SCHEMA => {
495            project_anthropic_bedrock_schema(schema, dialect)
496        }
497        SchemaDialect::GOOGLE_SCHEMA | SchemaDialect::JSON_PROMPT_SCHEMA => Ok(SchemaProjection {
498            schema: schema.clone(),
499            diagnostics: Vec::new(),
500        }),
501        other => Err(SchemaResolutionError::new(
502            "schema",
503            SchemaPurpose::StructuredOutput,
504            Some(dialect.clone()),
505            vec![format!("unsupported schema dialect {other}")],
506        )),
507    }
508}
509
510pub fn project_anthropic_bedrock_schema(
511    schema: &Value,
512    dialect: &SchemaDialect,
513) -> Result<SchemaProjection, SchemaResolutionError> {
514    let mut projected = schema.clone();
515    let mut sanitizer = AnthropicBedrockSanitizer {
516        diagnostics: Vec::new(),
517        errors: Vec::new(),
518    };
519    sanitizer.sanitize_value(&mut projected, Path::root());
520    sanitizer.ensure_object_root(&mut projected);
521    if sanitizer.errors.is_empty() {
522        Ok(SchemaProjection {
523            schema: projected,
524            diagnostics: sanitizer.diagnostics,
525        })
526    } else {
527        Err(SchemaResolutionError::new(
528            "schema",
529            SchemaPurpose::StructuredOutput,
530            Some(dialect.clone()),
531            sanitizer.errors,
532        ))
533    }
534}
535
536struct Projector {
537    profile: OpenAiSchemaProfile,
538    diagnostics: Vec<String>,
539    errors: Vec<String>,
540}
541
542impl Projector {
543    fn new(profile: OpenAiSchemaProfile) -> Self {
544        Self {
545            profile,
546            diagnostics: Vec::new(),
547            errors: Vec::new(),
548        }
549    }
550
551    fn project(mut self, schema: &Value) -> Result<SchemaProjection, SchemaProjectionError> {
552        let mut projected = schema.clone();
553        self.project_value(&mut projected, Path::root(), true);
554        self.ensure_object_root(&mut projected);
555
556        if self.errors.is_empty() {
557            Ok(SchemaProjection {
558                schema: projected,
559                diagnostics: self.diagnostics,
560            })
561        } else {
562            Err(SchemaProjectionError::new(self.profile, self.errors))
563        }
564    }
565
566    fn project_value(&mut self, value: &mut Value, path: Path, is_root: bool) {
567        let Some(obj) = value.as_object_mut() else {
568            self.errors
569                .push(format!("{path}: schema must be a JSON object"));
570            return;
571        };
572
573        self.flatten_single_all_of(obj, &path);
574        self.convert_const(obj, &path);
575        self.infer_type(obj, &path, is_root);
576
577        if is_root
578            && obj.contains_key("anyOf")
579            && self.profile == OpenAiSchemaProfile::StructuredOutput
580        {
581            self.errors.push(format!(
582                "{path}: OpenAI structured outputs do not allow root anyOf"
583            ));
584        }
585
586        self.reject_unsupported_keywords(obj, &path);
587
588        if schema_type_contains(obj, "object") {
589            self.project_object(obj, &path);
590        }
591
592        if let Some(items) = obj.get_mut("items") {
593            self.project_value(items, path.child("items"), false);
594        }
595
596        for key in ["$defs", "definitions"] {
597            if let Some(defs) = obj.get_mut(key).and_then(Value::as_object_mut) {
598                for (name, schema) in defs {
599                    self.project_value(schema, path.child(key).child(name), false);
600                }
601            }
602        }
603
604        if let Some(any_of) = obj.get_mut("anyOf").and_then(Value::as_array_mut) {
605            for (idx, schema) in any_of.iter_mut().enumerate() {
606                self.project_value(schema, path.child("anyOf").index(idx), false);
607            }
608        }
609    }
610
611    fn ensure_object_root(&mut self, value: &mut Value) {
612        let Some(obj) = value.as_object_mut() else {
613            return;
614        };
615        if !schema_type_contains(obj, "object") {
616            self.errors
617                .push("$: root schema must be an object schema".to_string());
618            return;
619        }
620        obj.entry("properties")
621            .or_insert_with(|| Value::Object(Map::new()));
622        if !obj.get("properties").is_some_and(Value::is_object) {
623            self.errors
624                .push("$: object schema properties must be an object".to_string());
625        }
626    }
627
628    fn project_object(&mut self, obj: &mut Map<String, Value>, path: &Path) {
629        let originally_required = required_set(obj);
630        match obj.get_mut("properties") {
631            Some(Value::Object(properties)) => {
632                let property_names = properties.keys().cloned().collect::<Vec<_>>();
633                for (name, schema) in properties.iter_mut() {
634                    let optional = !originally_required.iter().any(|required| required == name);
635                    self.project_value(schema, path.child("properties").child(name), false);
636                    if optional && self.requires_strict_objects() {
637                        make_nullable(schema);
638                    }
639                }
640                if self.requires_strict_objects() {
641                    obj.insert(
642                        "required".to_string(),
643                        Value::Array(property_names.into_iter().map(Value::String).collect()),
644                    );
645                }
646            }
647            Some(_) => self.errors.push(format!(
648                "{path}: object schema properties must be an object"
649            )),
650            None => {
651                obj.insert("properties".to_string(), Value::Object(Map::new()));
652                self.diagnostics
653                    .push(format!("{path}: inserted missing object properties"));
654                if self.requires_strict_objects() {
655                    obj.insert("required".to_string(), Value::Array(Vec::new()));
656                }
657            }
658        }
659
660        if self.requires_strict_objects() {
661            obj.insert("additionalProperties".to_string(), Value::Bool(false));
662        }
663    }
664
665    fn requires_strict_objects(&self) -> bool {
666        matches!(
667            self.profile,
668            OpenAiSchemaProfile::StrictToolParameters | OpenAiSchemaProfile::StructuredOutput
669        )
670    }
671
672    fn flatten_single_all_of(&mut self, obj: &mut Map<String, Value>, path: &Path) {
673        let Some(all_of) = obj.remove("allOf") else {
674            return;
675        };
676        let Value::Array(mut branches) = all_of else {
677            obj.insert("allOf".to_string(), all_of);
678            return;
679        };
680        if branches.len() != 1 {
681            obj.insert("allOf".to_string(), Value::Array(branches));
682            return;
683        }
684
685        let branch = branches.pop().expect("single allOf branch");
686        let Value::Object(branch_obj) = branch else {
687            obj.insert("allOf".to_string(), Value::Array(vec![branch]));
688            self.errors.push(format!(
689                "{path}: single-branch allOf must contain an object schema"
690            ));
691            return;
692        };
693
694        let conflicts = branch_obj
695            .iter()
696            .filter_map(|(key, value)| {
697                obj.get(key)
698                    .filter(|existing| *existing != value)
699                    .map(|_| key.clone())
700            })
701            .collect::<Vec<_>>();
702        if !conflicts.is_empty() {
703            obj.insert(
704                "allOf".to_string(),
705                Value::Array(vec![Value::Object(branch_obj)]),
706            );
707            self.errors.push(format!(
708                "{path}: single-branch allOf conflicts with sibling schema keys: {}",
709                conflicts.join(", ")
710            ));
711            return;
712        }
713
714        for (key, value) in branch_obj {
715            obj.entry(key).or_insert(value);
716        }
717        self.diagnostics
718            .push(format!("{path}: flattened single-branch allOf"));
719    }
720
721    fn convert_const(&mut self, obj: &mut Map<String, Value>, path: &Path) {
722        if let Some(value) = obj.remove("const") {
723            obj.entry("enum".to_string())
724                .or_insert_with(|| Value::Array(vec![value]));
725            self.diagnostics
726                .push(format!("{path}: converted const to single-value enum"));
727        }
728    }
729
730    fn infer_type(&mut self, obj: &mut Map<String, Value>, path: &Path, is_root: bool) {
731        if obj.contains_key("type") {
732            return;
733        }
734        let inferred = if is_root
735            || obj.contains_key("properties")
736            || obj.contains_key("required")
737            || obj.contains_key("additionalProperties")
738        {
739            Some("object")
740        } else if obj.contains_key("items") {
741            Some("array")
742        } else if let Some(enum_values) = obj.get("enum").and_then(Value::as_array) {
743            infer_enum_type(enum_values)
744        } else {
745            None
746        };
747        if let Some(inferred) = inferred {
748            obj.insert("type".to_string(), Value::String(inferred.to_string()));
749            self.diagnostics
750                .push(format!("{path}: inferred missing type as {inferred}"));
751        }
752    }
753
754    fn reject_unsupported_keywords(&mut self, obj: &Map<String, Value>, path: &Path) {
755        let unsupported = [
756            "allOf",
757            "oneOf",
758            "not",
759            "dependentRequired",
760            "dependentSchemas",
761            "if",
762            "then",
763            "else",
764            "patternProperties",
765        ];
766        for key in unsupported {
767            if obj.contains_key(key) {
768                self.errors
769                    .push(format!("{path}: unsupported JSON Schema keyword `{key}`"));
770            }
771        }
772    }
773}
774
775struct AnthropicBedrockSanitizer {
776    diagnostics: Vec<String>,
777    errors: Vec<String>,
778}
779
780impl AnthropicBedrockSanitizer {
781    fn sanitize_value(&mut self, value: &mut Value, path: Path) {
782        let Some(obj) = value.as_object_mut() else {
783            self.errors
784                .push(format!("{path}: schema must be a JSON object"));
785            return;
786        };
787
788        self.convert_const(obj, &path);
789        self.infer_rootless_type(obj, &path);
790        self.strip_validation_keywords(obj, &path);
791
792        if let Some(properties) = obj.get_mut("properties") {
793            let Some(properties) = properties.as_object_mut() else {
794                self.errors
795                    .push(format!("{path}.properties: properties must be an object"));
796                return;
797            };
798            for (name, property) in properties {
799                self.sanitize_value(property, path.child("properties").child(name));
800            }
801        }
802
803        if let Some(items) = obj.get_mut("items") {
804            match items {
805                Value::Array(values) => {
806                    for (idx, item) in values.iter_mut().enumerate() {
807                        self.sanitize_value(item, path.child("items").index(idx));
808                    }
809                }
810                _ => self.sanitize_value(items, path.child("items")),
811            }
812        }
813
814        for key in ["$defs", "definitions"] {
815            if let Some(defs) = obj.get_mut(key).and_then(Value::as_object_mut) {
816                for (name, schema) in defs {
817                    self.sanitize_value(schema, path.child(key).child(name));
818                }
819            }
820        }
821
822        for key in ["anyOf", "oneOf", "allOf"] {
823            if let Some(values) = obj.get_mut(key).and_then(Value::as_array_mut) {
824                for (idx, schema) in values.iter_mut().enumerate() {
825                    self.sanitize_value(schema, path.child(key).index(idx));
826                }
827            }
828        }
829    }
830
831    fn ensure_object_root(&mut self, value: &mut Value) {
832        let Some(obj) = value.as_object_mut() else {
833            return;
834        };
835        obj.entry("type".to_string())
836            .or_insert_with(|| Value::String("object".to_string()));
837        if !schema_type_contains(obj, "object") {
838            self.errors
839                .push("$: root schema must be an object schema".to_string());
840            return;
841        }
842        obj.entry("properties".to_string())
843            .or_insert_with(|| Value::Object(Map::new()));
844    }
845
846    fn convert_const(&mut self, obj: &mut Map<String, Value>, path: &Path) {
847        if let Some(value) = obj.remove("const") {
848            obj.entry("enum".to_string())
849                .or_insert_with(|| Value::Array(vec![value]));
850            self.diagnostics
851                .push(format!("{path}: converted const to single-value enum"));
852        }
853    }
854
855    fn infer_rootless_type(&mut self, obj: &mut Map<String, Value>, path: &Path) {
856        if obj.contains_key("type") {
857            return;
858        }
859        let inferred = if obj.contains_key("properties")
860            || obj.contains_key("required")
861            || obj.contains_key("additionalProperties")
862        {
863            Some("object")
864        } else if obj.contains_key("items") {
865            Some("array")
866        } else if let Some(enum_values) = obj.get("enum").and_then(Value::as_array) {
867            infer_enum_type(enum_values)
868        } else {
869            None
870        };
871        if let Some(inferred) = inferred {
872            obj.insert("type".to_string(), Value::String(inferred.to_string()));
873            self.diagnostics
874                .push(format!("{path}: inferred missing type as {inferred}"));
875        }
876    }
877
878    fn strip_validation_keywords(&mut self, obj: &mut Map<String, Value>, path: &Path) {
879        let strip_keys = [
880            "minItems",
881            "maxItems",
882            "uniqueItems",
883            "minLength",
884            "maxLength",
885            "pattern",
886            "format",
887            "minimum",
888            "maximum",
889            "exclusiveMinimum",
890            "exclusiveMaximum",
891            "multipleOf",
892        ];
893        let mut stripped = Vec::new();
894        for key in strip_keys {
895            if let Some(value) = obj.remove(key) {
896                stripped.push(format!("{key}={}", compact_value(&value)));
897                self.diagnostics.push(format!(
898                    "{path}: removed unsupported provider schema keyword `{key}`"
899                ));
900            }
901        }
902        if !stripped.is_empty() {
903            append_description_constraint(
904                obj,
905                &format!("Provider compatibility note: {}", stripped.join(", ")),
906            );
907        }
908    }
909}
910
911fn append_description_constraint(obj: &mut Map<String, Value>, note: &str) {
912    match obj.get_mut("description") {
913        Some(Value::String(description)) if !description.trim().is_empty() => {
914            if !description.ends_with('.') {
915                description.push('.');
916            }
917            description.push(' ');
918            description.push_str(note);
919        }
920        _ => {
921            obj.insert("description".to_string(), Value::String(note.to_string()));
922        }
923    }
924}
925
926fn compact_value(value: &Value) -> String {
927    match value {
928        Value::String(value) => format!("{value:?}"),
929        _ => value.to_string(),
930    }
931}
932
/// Textual location of a schema node, used to prefix diagnostics.
///
/// The root is `$`; children are appended as `.segment` and array positions
/// as `[index]`.
#[derive(Clone, Debug)]
struct Path(String);

impl Path {
    /// The path of the schema root: `$`.
    fn root() -> Self {
        Self(String::from("$"))
    }

    /// Returns a new path with `.segment` appended; `self` is not modified.
    fn child(&self, segment: &str) -> Self {
        let Self(parent) = self;
        let mut joined = String::with_capacity(parent.len() + 1 + segment.len());
        joined.push_str(parent);
        joined.push('.');
        joined.push_str(segment);
        Self(joined)
    }

    /// Returns a new path with `[index]` appended; `self` is not modified.
    fn index(&self, index: usize) -> Self {
        let Self(parent) = self;
        Self(format!("{parent}[{index}]"))
    }
}

impl std::fmt::Display for Path {
    /// Writes the stored path text verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
955
956fn schema_type_contains(obj: &Map<String, Value>, expected: &str) -> bool {
957    match obj.get("type") {
958        Some(Value::String(value)) => value == expected,
959        Some(Value::Array(values)) => values.iter().any(|value| value.as_str() == Some(expected)),
960        _ => false,
961    }
962}
963
964fn required_set(obj: &Map<String, Value>) -> Vec<String> {
965    obj.get("required")
966        .and_then(Value::as_array)
967        .map(|values| {
968            values
969                .iter()
970                .filter_map(Value::as_str)
971                .map(str::to_string)
972                .collect()
973        })
974        .unwrap_or_default()
975}
976
977fn infer_enum_type(values: &[Value]) -> Option<&'static str> {
978    let mut inferred = None;
979    for value in values {
980        let value_type = match value {
981            Value::String(_) => "string",
982            Value::Bool(_) => "boolean",
983            Value::Number(number) if number.is_i64() || number.is_u64() => "integer",
984            Value::Number(_) => "number",
985            Value::Null => continue,
986            _ => return None,
987        };
988        match inferred {
989            Some(existing) if existing != value_type => return None,
990            Some(_) => {}
991            None => inferred = Some(value_type),
992        }
993    }
994    inferred
995}
996
997fn make_nullable(schema: &mut Value) {
998    let Some(obj) = schema.as_object_mut() else {
999        return;
1000    };
1001    if let Some(any_of) = obj.get_mut("anyOf").and_then(Value::as_array_mut) {
1002        if !any_of.iter().any(is_null_schema) {
1003            any_of.push(json!({ "type": "null" }));
1004        }
1005        return;
1006    }
1007    match obj.get_mut("type") {
1008        Some(Value::String(value)) if value != "null" => {
1009            let original = value.clone();
1010            obj.insert(
1011                "type".to_string(),
1012                Value::Array(vec![
1013                    Value::String(original),
1014                    Value::String("null".to_string()),
1015                ]),
1016            );
1017        }
1018        Some(Value::Array(values))
1019            if !values.iter().any(|value| value.as_str() == Some("null")) =>
1020        {
1021            values.push(Value::String("null".to_string()));
1022        }
1023        None => {
1024            let original = Value::Object(obj.clone());
1025            obj.clear();
1026            obj.insert(
1027                "anyOf".to_string(),
1028                Value::Array(vec![original, json!({ "type": "null" })]),
1029            );
1030        }
1031        _ => {}
1032    }
1033}
1034
1035fn is_null_schema(value: &Value) -> bool {
1036    value
1037        .as_object()
1038        .is_some_and(|obj| schema_type_contains(obj, "null"))
1039}
1040
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the `required` entries of `schema` in sorted order, so that
    /// assertions do not depend on the order the projection emitted them in.
    ///
    /// Panics if `required` is not an array of strings.
    fn required_names(schema: &Value) -> Vec<String> {
        let entries = schema["required"].as_array().unwrap();
        let mut names = Vec::with_capacity(entries.len());
        for entry in entries {
            names.push(entry.as_str().unwrap().to_string());
        }
        names.sort();
        names
    }

    /// Projecting leaves the input value as it was and fills in `properties`.
    #[test]
    fn projection_does_not_mutate_canonical_schema() {
        let canonical = json!({ "type": "object" });
        let wire = project_tool_parameters(&canonical).unwrap();

        assert_eq!(canonical, json!({ "type": "object" }));
        assert_eq!(wire.schema["properties"], json!({}));
    }

    /// An empty root schema is repaired into an object with empty properties.
    #[test]
    fn tool_parameters_repairs_empty_root_object() {
        let empty = json!({});
        let repaired = project_tool_parameters(&empty).unwrap();
        let wire = &repaired.schema;

        assert_eq!(wire["type"], json!("object"));
        assert_eq!(wire["properties"], json!({}));

        let noted = repaired
            .diagnostics
            .iter()
            .any(|line| line.contains("inferred"));
        assert!(noted);
    }

    /// A typeless root gets `type: object`, and `const` is rewritten to `enum`.
    #[test]
    fn tool_parameters_repairs_missing_properties_missing_type_and_const() {
        let canonical = json!({
            "properties": {
                "mode": { "const": "fast" }
            }
        });
        let repaired = project_tool_parameters(&canonical).unwrap();
        let mode = &repaired.schema["properties"]["mode"];

        assert_eq!(repaired.schema["type"], json!("object"));
        assert_eq!(mode["enum"], json!(["fast"]));
        assert_eq!(mode.get("const"), None);
    }

    /// `items` implies an array and an integer-only `enum` implies `integer`.
    #[test]
    fn tool_parameters_infers_array_and_enum_types() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "tags": { "items": { "type": "string" } },
                "level": { "enum": [1, 2, 3] }
            }
        });
        let repaired = project_tool_parameters(&canonical).unwrap();
        let props = &repaired.schema["properties"];

        assert_eq!(props["tags"]["type"], json!("array"));
        assert_eq!(props["level"]["type"], json!("integer"));
    }

    /// Strict projection requires every property, makes the formerly optional
    /// one nullable, and closes the object.
    #[test]
    fn strict_projection_requires_optional_nullable_properties() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "required_name": { "type": "string" },
                "optional_count": { "type": "integer" }
            },
            "required": ["required_name"]
        });
        let strict = project_strict_tool_parameters(&canonical).unwrap();
        let wire = &strict.schema;

        assert_eq!(required_names(wire), ["optional_count", "required_name"]);
        assert_eq!(
            wire["properties"]["optional_count"]["type"],
            json!(["integer", "null"])
        );
        assert_eq!(wire["additionalProperties"], json!(false));
    }

    /// Properties that were already required keep their non-nullable type.
    #[test]
    fn strict_projection_preserves_required_nonnullable_fields() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });
        let strict = project_strict_tool_parameters(&canonical).unwrap();
        let props = &strict.schema["properties"];

        assert_eq!(props["name"]["type"], json!("string"));
        assert_eq!(props["age"]["type"], json!("integer"));
    }

    /// A type union that already contains `null` is not extended again.
    #[test]
    fn strict_projection_does_not_duplicate_existing_nullable_type() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "name": { "type": ["string", "null"] }
            }
        });
        let strict = project_strict_tool_parameters(&canonical).unwrap();
        let name_type = &strict.schema["properties"]["name"]["type"];

        assert_eq!(*name_type, json!(["string", "null"]));
    }

    /// An optional `anyOf` property gains a trailing null branch.
    #[test]
    fn strict_projection_adds_null_branch_to_optional_any_of() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "value": {
                    "anyOf": [
                        { "type": "string" },
                        { "type": "integer" }
                    ]
                }
            }
        });
        let strict = project_strict_tool_parameters(&canonical).unwrap();
        let branches = &strict.schema["properties"]["value"]["anyOf"];

        assert_eq!(branches[2], json!({ "type": "null" }));
    }

    /// A single-branch `allOf` is merged into its parent before the property
    /// is made nullable, and the merge is reported as a diagnostic.
    #[test]
    fn strict_projection_flattens_single_branch_all_of_before_nullable() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "limit": {
                    "description": "Maximum number of results.",
                    "allOf": [
                        {
                            "type": "integer",
                            "minimum": 1,
                            "maximum": 100
                        }
                    ]
                }
            }
        });
        let strict = project_strict_tool_parameters(&canonical).unwrap();
        let limit = &strict.schema["properties"]["limit"];

        assert_eq!(limit.get("allOf"), None);
        assert_eq!(limit["description"], "Maximum number of results.");
        assert_eq!(limit["type"], json!(["integer", "null"]));
        assert_eq!(limit["minimum"], 1);
        assert_eq!(limit["maximum"], 100);

        let reported = strict
            .diagnostics
            .iter()
            .any(|line| line.contains("flattened single-branch allOf"));
        assert!(reported);
    }

    /// Structured output requires all properties and forbids extras.
    #[test]
    fn structured_output_enforces_strict_objects() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "value": { "type": "string" }
            }
        });
        let output = project_structured_output(&canonical).unwrap();
        let wire = &output.schema;

        assert_eq!(wire["required"], json!(["value"]));
        assert_eq!(wire["additionalProperties"], json!(false));
    }

    /// The strict rules also apply to objects nested under array `items`.
    #[test]
    fn structured_output_recurses_into_nested_objects_and_arrays() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "title": { "type": "string" },
                            "score": { "type": "number" }
                        },
                        "required": ["title"]
                    }
                }
            }
        });
        let output = project_structured_output(&canonical).unwrap();
        let element = &output.schema["properties"]["items"]["items"];

        assert_eq!(element["additionalProperties"], json!(false));
        assert_eq!(required_names(element), ["score", "title"]);
        assert_eq!(
            element["properties"]["score"]["type"],
            json!(["number", "null"])
        );
    }

    /// Schemas under both `$defs` and `definitions` are projected as well.
    #[test]
    fn structured_output_recurses_into_defs_and_definitions() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "item": { "$ref": "#/$defs/Item" },
                "legacy": { "$ref": "#/definitions/Legacy" }
            },
            "$defs": {
                "Item": {
                    "type": "object",
                    "properties": { "id": { "type": "string" } }
                }
            },
            "definitions": {
                "Legacy": {
                    "type": "object",
                    "properties": { "flag": { "type": "boolean" } }
                }
            }
        });
        let output = project_structured_output(&canonical).unwrap();
        let item = &output.schema["$defs"]["Item"];
        let legacy = &output.schema["definitions"]["Legacy"];

        assert_eq!(item["additionalProperties"], json!(false));
        assert_eq!(item["required"], json!(["id"]));
        assert_eq!(legacy["additionalProperties"], json!(false));
        assert_eq!(legacy["required"], json!(["flag"]));
    }

    /// Nested `anyOf` is accepted: object branches are made strict and the
    /// optional property receives a null branch.
    #[test]
    fn structured_output_allows_nested_any_of_and_projects_branches() {
        let canonical = json!({
            "type": "object",
            "properties": {
                "value": {
                    "anyOf": [
                        { "type": "string" },
                        {
                            "type": "object",
                            "properties": { "count": { "type": "integer" } }
                        }
                    ]
                }
            }
        });
        let output = project_structured_output(&canonical).unwrap();
        let branches = &output.schema["properties"]["value"]["anyOf"];

        assert_eq!(branches[1]["required"], json!(["count"]));
        assert_eq!(branches[1]["additionalProperties"], json!(false));
        assert_eq!(branches[2], json!({ "type": "null" }));
    }

    /// A scalar root schema is rejected for structured output.
    #[test]
    fn structured_output_rejects_root_scalar() {
        let scalar = json!({ "type": "string" });
        let err = project_structured_output(&scalar).unwrap_err();

        let reported = err
            .diagnostics
            .iter()
            .any(|line| line.contains("root schema must be an object"));
        assert!(reported);
    }

    /// `anyOf` at the root is rejected for structured output.
    #[test]
    fn structured_output_rejects_root_any_of() {
        let union_root = json!({
            "anyOf": [
                { "type": "object", "properties": {} },
                { "type": "object", "properties": {} }
            ]
        });
        let err = project_structured_output(&union_root).unwrap_err();

        let reported = err
            .diagnostics
            .iter()
            .any(|line| line.contains("root anyOf"));
        assert!(reported);
    }

    /// Multi-branch `allOf` and `patternProperties` are each reported.
    #[test]
    fn projection_rejects_unsupported_lossy_keywords() {
        let lossy = json!({
            "type": "object",
            "properties": {},
            "allOf": [
                { "type": "object", "properties": {} },
                { "type": "object", "properties": {} }
            ],
            "patternProperties": {}
        });
        let err = project_structured_output(&lossy).unwrap_err();

        for keyword in ["allOf", "patternProperties"] {
            let reported = err.diagnostics.iter().any(|line| line.contains(keyword));
            assert!(reported);
        }
    }

    /// `properties` that is not an object is an error.
    #[test]
    fn projection_rejects_non_object_properties() {
        let malformed = json!({
            "type": "object",
            "properties": []
        });
        let err = project_tool_parameters(&malformed).unwrap_err();

        let reported = err
            .diagnostics
            .iter()
            .any(|line| line.contains("properties must be an object"));
        assert!(reported);
    }

    /// In the default mode an override registered for the requested dialect
    /// is returned instead of a projection of the canonical schema.
    #[test]
    fn resolver_auto_prefers_explicit_override_for_matching_dialect() {
        let canonical = json!({
            "type": "object",
            "properties": { "raw": { "const": "x" } }
        });
        let handwritten = json!({
            "type": "object",
            "properties": { "raw": { "type": "string", "enum": ["x"] } }
        });
        let contract = SchemaContract::new(canonical)
            .with_override(SchemaDialect::OPENAI_TOOL_PARAMETERS, handwritten);

        let dialects = [SchemaDialect::openai_tool_parameters()];
        let request = SchemaResolutionRequest {
            provider: "test",
            purpose: SchemaPurpose::ToolInput,
            dialects: &dialects,
        };
        let resolved = resolve_schema(&contract, request).unwrap();

        assert_eq!(
            resolved.schema["properties"]["raw"],
            json!({ "type": "string", "enum": ["x"] })
        );
    }

    /// `ExplicitOnly` contracts fail when no override matches the dialect.
    #[test]
    fn resolver_explicit_only_fails_without_matching_override() {
        let contract = SchemaContract::new(json!({
            "type": "object",
            "properties": {}
        }))
        .with_projection(SchemaProjectionPolicy {
            mode: ProjectionMode::ExplicitOnly,
            ..SchemaProjectionPolicy::default()
        });

        let dialects = [SchemaDialect::openai_structured_output()];
        let request = SchemaResolutionRequest {
            provider: "test",
            purpose: SchemaPurpose::StructuredOutput,
            dialects: &dialects,
        };
        let err = resolve_schema(&contract, request).unwrap_err();

        let reported = err
            .diagnostics
            .iter()
            .any(|line| line.contains("no explicit projection override"));
        assert!(reported);
    }

    /// The Bedrock projection drops `minItems`/`maxItems` from the wire
    /// schema, keeps them in the canonical schema, and mentions the removed
    /// constraint in the property description.
    #[test]
    fn bedrock_projection_strips_array_constraints_from_wire_schema_only() {
        let contract = SchemaContract::new(json!({
            "type": "object",
            "required": ["ranked"],
            "properties": {
                "ranked": {
                    "type": "array",
                    "minItems": 3,
                    "maxItems": 3,
                    "items": { "type": "string" }
                }
            }
        }));

        let dialects = [SchemaDialect::bedrock_claude_output_config_json_schema()];
        let request = SchemaResolutionRequest {
            provider: "test",
            purpose: SchemaPurpose::StructuredOutput,
            dialects: &dialects,
        };
        let resolved = resolve_schema(&contract, request).unwrap();
        let ranked = &resolved.schema["properties"]["ranked"];

        assert_eq!(ranked.get("minItems"), None);
        assert_eq!(ranked.get("maxItems"), None);
        assert_eq!(contract.canonical["properties"]["ranked"]["minItems"], 3);

        let note = ranked["description"].as_str().unwrap_or_default();
        assert!(note.contains("minItems=3"));
    }
}