use std::collections::HashSet;
use serde_json::Value;
/// A single problem found by [`InputValidator::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Offending field: a top-level property name, a `/`-separated path to a
    /// nested value (the instance path without its leading `/`), or `__root__`
    /// when the problem is with the input as a whole.
    pub field: String,
    /// Human-readable description, e.g. `"Field required"` or the message
    /// produced by the JSON Schema validator.
    pub msg: String,
    /// Error category: `"value_error.missing"` for absent required fields,
    /// `"value_error"` for everything else.
    pub error_type: String,
}
/// Validates JSON inputs against one schema taken from the `components.schemas`
/// section of an OpenAPI document.
pub struct InputValidator {
/// Compiled schema (with local `#/components/schemas/...` refs inlined).
validator: jsonschema::Validator,
/// Names of the schema's top-level `properties`; keys outside this set are
/// removed by `strip_unknown`.
properties: HashSet<String>,
/// Top-level `required` property names, in the order the schema lists them.
required: Vec<String>,
}
impl InputValidator {
    /// Builds a validator for the `Input` schema of an OpenAPI document.
    ///
    /// Shorthand for [`InputValidator::from_openapi_schema_key`] with `key = "Input"`.
    pub fn from_openapi_schema(schema: &Value) -> Option<Self> {
        Self::from_openapi_schema_key(schema, "Input")
    }

    /// Builds a validator for `components.schemas.<key>` of an OpenAPI document.
    ///
    /// Local `#/components/schemas/...` references inside the selected schema are
    /// inlined first so that it can be compiled on its own.
    ///
    /// Returns `None` when `components`, `components.schemas` or the `key` entry is
    /// missing, or when the resolved schema fails to compile (a warning is logged).
    pub fn from_openapi_schema_key(schema: &Value, key: &str) -> Option<Self> {
        let all_schemas = schema.get("components")?.get("schemas")?;
        let input_schema = all_schemas.get(key)?;

        // Declared top-level property names, used by `strip_unknown`.
        let properties: HashSet<String> = input_schema
            .get("properties")
            .and_then(|p| p.as_object())
            .map(|obj| obj.keys().cloned().collect())
            .unwrap_or_default();

        // Top-level required names, in schema order, used to report missing fields.
        let required: Vec<String> = input_schema
            .get("required")
            .and_then(|r| r.as_array())
            .map(|a| {
                a.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default();

        let mut resolved = input_schema.clone();
        inline_refs(&mut resolved, Some(all_schemas));
        let validator = jsonschema::validator_for(&resolved)
            .inspect_err(|e| {
                tracing::warn!(error = %e, "Failed to compile input schema validator");
            })
            .ok()?;

        Some(Self {
            validator,
            properties,
            required,
        })
    }

    /// Number of top-level properties the schema marks as required.
    pub fn required_count(&self) -> usize {
        self.required.len()
    }

    /// Removes top-level keys of `input` that are not declared in the schema's
    /// `properties` and returns the removed key names in iteration order.
    ///
    /// A non-object `input` is left untouched and yields an empty list.
    pub fn strip_unknown(&self, input: &mut Value) -> Vec<String> {
        let Some(obj) = input.as_object_mut() else {
            return Vec::new();
        };
        let mut removed = Vec::new();
        // `retain` is a single pass and, unlike `Map::remove` under serde_json's
        // `preserve_order` feature (which swap-removes), keeps the remaining keys
        // in their original order.
        obj.retain(|key, _| {
            let known = self.properties.contains(key);
            if !known {
                removed.push(key.clone());
            }
            known
        });
        removed
    }

    /// Validates `input` against the compiled schema.
    ///
    /// Keys not declared in the schema are not rejected here; see `strip_unknown`.
    ///
    /// # Errors
    ///
    /// Returns every problem found:
    /// - missing top-level required fields are reported once each, in schema order,
    ///   as `field = <name>`, `msg = "Field required"`,
    ///   `error_type = "value_error.missing"`;
    /// - every other failure, including required fields missing from nested objects,
    ///   is reported with the instance path minus its leading `/` (or `__root__`),
    ///   the validator's message, and `error_type = "value_error"`.
    pub fn validate(&self, input: &Value) -> Result<(), Vec<ValidationError>> {
        // Text of the validator's message for a missing `required` property.
        const REQUIRED_MSG: &str = "is a required property";

        if self.validator.is_valid(input) {
            return Ok(());
        }

        let mut errors = Vec::new();
        // `None` until the first root-level `required` failure is seen; afterwards,
        // whether the top-level `required` list accounted for those failures.
        let mut missing_covered: Option<bool> = None;

        for error in self.validator.iter_errors(input) {
            let msg = error.to_string();
            let path = error.instance_path.to_string();
            let field = path.trim_start_matches('/');

            // Only root-level `required` failures map onto `self.required`; a nested
            // object's missing field has a non-empty path and is reported below.
            if field.is_empty() && msg.contains(REQUIRED_MSG) {
                let covered = match missing_covered {
                    Some(covered) => covered,
                    None => {
                        let missing = self.missing_required(input);
                        let covered = !missing.is_empty();
                        errors.extend(missing);
                        missing_covered = Some(covered);
                        covered
                    }
                };
                // If the top-level list explains nothing (e.g. `required` declared
                // via `allOf`), fall through so the failure is never dropped.
                if covered {
                    continue;
                }
            }

            errors.push(ValidationError {
                field: Self::field_name(field),
                msg,
                error_type: "value_error".to_string(),
            });
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }

    /// "Field required" errors for each top-level required field absent from
    /// `input`, in schema order; a non-object `input` lacks every field.
    fn missing_required(&self, input: &Value) -> Vec<ValidationError> {
        let obj = input.as_object();
        self.required
            .iter()
            .filter(|field| !obj.is_some_and(|o| o.contains_key(field.as_str())))
            .map(|field| ValidationError {
                field: field.clone(),
                msg: "Field required".to_string(),
                error_type: "value_error.missing".to_string(),
            })
            .collect()
    }

    /// Maps a `/`-trimmed instance path to the reported field name (`__root__` for the root).
    fn field_name(path: &str) -> String {
        if path.is_empty() {
            "__root__".to_string()
        } else {
            path.to_string()
        }
    }
}
/// Recursively replaces each object of the form
/// `{"$ref": "#/components/schemas/..."}` inside `value` with a copy of the
/// referenced schema from `all_schemas` (the document's `components.schemas`
/// object), so that `value` can be compiled as a standalone schema.
///
/// When an object's `$ref` resolves, the whole object is replaced by the target,
/// including any keys next to `$ref`. The copied target is then processed in turn.
/// References that do not resolve are left in place and their contents are still
/// walked.
///
/// A reference already being expanded higher up the current path is left as a
/// `$ref` instead of being expanded again. This covers recursive schemas, e.g. a
/// tree node containing children of its own type. Without this guard such schemas
/// recurse until the stack overflows.
fn inline_refs(value: &mut Value, all_schemas: Option<&Value>) {
    let mut expanding = Vec::new();
    inline_refs_acyclic(value, all_schemas, &mut expanding);
}

/// Worker for `inline_refs`; `expanding` holds the `$ref` strings whose targets
/// are being expanded on the path from the root down to `value`.
fn inline_refs_acyclic(
    value: &mut Value,
    all_schemas: Option<&Value>,
    expanding: &mut Vec<String>,
) {
    match value {
        Value::Object(obj) => {
            let target = match obj.get("$ref") {
                Some(Value::String(ref_str)) if !expanding.contains(ref_str) => {
                    resolve_ref(ref_str, all_schemas).map(|schema| (ref_str.clone(), schema))
                }
                _ => None,
            };
            if let Some((ref_str, resolved)) = target {
                *value = resolved;
                expanding.push(ref_str);
                inline_refs_acyclic(value, all_schemas, expanding);
                expanding.pop();
            } else {
                for child in obj.values_mut() {
                    inline_refs_acyclic(child, all_schemas, expanding);
                }
            }
        }
        Value::Array(items) => {
            for item in items.iter_mut() {
                inline_refs_acyclic(item, all_schemas, expanding);
            }
        }
        _ => {}
    }
}
/// Resolves a local component reference such as `#/components/schemas/Color`
/// against `all_schemas` (the document's `components.schemas` object) and
/// returns a clone of the referenced schema.
///
/// The text after `#/components/schemas/` is first looked up as a plain key.
/// If that fails, it is read as a JSON Pointer (RFC 6901). Escaped names (`~1`
/// for `/`, `~0` for `~`) and deeper references such as
/// `#/components/schemas/Foo/properties/bar` therefore resolve as well.
///
/// Returns `None` for references outside `#/components/schemas/`, when
/// `all_schemas` is `None`, or when nothing exists at the target.
fn resolve_ref(ref_str: &str, all_schemas: Option<&Value>) -> Option<Value> {
    let name = ref_str.strip_prefix("#/components/schemas/")?;
    let schemas = all_schemas?;
    schemas
        .get(name)
        .or_else(|| schemas.pointer(&format!("/{name}")))
        .cloned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps `input_schema` as `components.schemas.Input` of a minimal OpenAPI document.
    fn make_schema(input_schema: Value) -> Value {
        let schemas = json!({ "Input": input_schema });
        json!({ "components": { "schemas": schemas } })
    }

    /// Compiles the `Input` schema of `document`, panicking if it does not compile.
    fn compile(document: &Value) -> InputValidator {
        InputValidator::from_openapi_schema(document).unwrap()
    }

    /// Validator for an object whose only declared property is the required string `s`.
    fn required_s_validator() -> InputValidator {
        compile(&make_schema(json!({
            "type": "object",
            "properties": { "s": {"type": "string", "title": "S"} },
            "required": ["s"]
        })))
    }

    #[test]
    fn validates_required_fields() {
        let checker = required_s_validator();
        assert!(checker.validate(&json!({"s": "hello"})).is_ok());

        let errs = checker.validate(&json!({})).unwrap_err();
        assert_eq!(errs.len(), 1);
        let only = &errs[0];
        assert_eq!(only.field, "s");
        assert_eq!(only.msg, "Field required");
    }

    #[test]
    fn allows_additional_properties_in_validate() {
        let outcome = required_s_validator().validate(&json!({"s": "hello", "extra": "bad"}));
        assert!(
            outcome.is_ok(),
            "unknown inputs should not cause validation errors"
        );
    }

    #[test]
    fn strip_unknown_removes_extra_fields() {
        let mut payload = json!({"s": "hello", "guidance_scale": 7.5, "extra": "bad"});
        let removed = required_s_validator().strip_unknown(&mut payload);
        assert_eq!(removed.len(), 2);
        for expected in ["guidance_scale", "extra"] {
            assert!(removed.contains(&expected.to_string()));
        }
        assert_eq!(payload, json!({"s": "hello"}));
    }

    #[test]
    fn strip_unknown_preserves_known_fields() {
        let checker = compile(&make_schema(json!({
            "type": "object",
            "properties": {
                "s": {"type": "string", "title": "S"},
                "n": {"type": "integer"}
            },
            "required": ["s"]
        })));
        let expected = json!({"s": "hello", "n": 42});
        let mut payload = expected.clone();
        assert!(checker.strip_unknown(&mut payload).is_empty());
        assert_eq!(payload, expected);
    }

    #[test]
    fn strip_unknown_returns_empty_for_no_extra_fields() {
        let mut payload = json!({"s": "hello"});
        assert!(required_s_validator().strip_unknown(&mut payload).is_empty());
    }

    #[test]
    fn missing_required_with_extra_fields() {
        let checker = required_s_validator();
        let mut payload = json!({"wrong": "value"});
        assert_eq!(checker.strip_unknown(&mut payload), vec!["wrong".to_string()]);

        let errs = checker.validate(&payload).unwrap_err();
        assert_eq!(errs.len(), 1);
        assert_eq!(errs[0].field, "s");
        assert_eq!(errs[0].msg, "Field required");
    }

    #[test]
    fn validates_types() {
        let checker = compile(&make_schema(json!({
            "type": "object",
            "properties": { "count": {"type": "integer", "title": "Count"} },
            "required": ["count"]
        })));
        assert!(checker.validate(&json!({"count": 5})).is_ok());

        let errs = checker
            .validate(&json!({"count": "not_a_number"}))
            .unwrap_err();
        assert_eq!(errs[0].field, "count");
    }

    #[test]
    fn no_schema_returns_none() {
        let empty = json!({"components": {"schemas": {}}});
        assert!(InputValidator::from_openapi_schema(&empty).is_none());
    }

    #[test]
    fn resolves_ref_for_choices() {
        let color = json!({
            "title": "Color",
            "description": "An enumeration.",
            "enum": ["red", "green", "blue"],
            "type": "string"
        });
        let input = json!({
            "type": "object",
            "properties": {
                "color": {
                    "allOf": [{"$ref": "#/components/schemas/Color"}],
                    "x-order": 0
                }
            },
            "required": ["color"]
        });
        let document = json!({
            "components": { "schemas": { "Input": input, "Color": color } }
        });

        let Some(checker) = InputValidator::from_openapi_schema(&document) else {
            panic!("validator should compile with $ref");
        };
        assert!(checker.validate(&json!({"color": "red"})).is_ok());
        assert!(checker.validate(&json!({"color": "purple"})).is_err());
    }

    #[test]
    fn optional_fields_work() {
        let checker = compile(&make_schema(json!({
            "type": "object",
            "properties": {
                "s": {"type": "string"},
                "n": {"type": "integer"}
            },
            "required": ["s"]
        })));
        for accepted in [json!({"s": "hello"}), json!({"s": "hello", "n": 42})] {
            assert!(checker.validate(&accepted).is_ok());
        }
    }
}