Skip to main content

better_fetch/
schema_validate.rs

//! Runtime JSON Schema validation against a [`SchemaRegistry`](crate::schema::SchemaRegistry).
2
3use std::collections::HashMap;
4
5use http::Method;
6use indexmap::IndexMap;
7use jsonschema::{Draft, Validator};
8use schemars::schema::RootSchema;
9use serde_json::Value;
10
11use crate::error::Error;
12use crate::response::Response;
13use crate::schema::SchemaRegistry;
14use crate::url_build::QueryValue;
15use crate::Result;
16
/// Context for validating a streamed response after [`StreamingResponse::collect`](crate::StreamingResponse::collect).
///
/// Holds the registry handle and the lookup key (`route_path` + `method`) so the response can be
/// checked once its body has been fully collected.
#[cfg(feature = "schema-validate")]
#[derive(Clone)]
pub(crate) struct StreamResponseSchemaCtx {
    /// Shared schema registry, reference-counted so the context can be cloned cheaply.
    pub registry: std::sync::Arc<SchemaRegistry>,
    /// Route path; presumably the `path` key used for registry lookups (confirm at call site).
    pub route_path: String,
    /// HTTP method; paired with `route_path` when looking up the response schema.
    pub method: Method,
}
25
/// Validates a buffered [`Response`] when strict mode and a response schema are registered.
///
/// Returns `Ok(())` without validating when:
/// - the registry is not strict,
/// - the response status is not a success,
/// - no response schema is registered for `path` + `method`, or
/// - the body is empty.
///
/// # Errors
///
/// - [`Error::SchemaValidation`] with phase `"response"` if the body is not valid JSON or does
///   not satisfy the registered schema.
/// - [`Error::Config`] if the registered schema cannot be compiled.
#[cfg(feature = "schema-validate")]
pub(crate) fn validate_response_if_registered(
    registry: &SchemaRegistry,
    path: &str,
    method: &Method,
    response: &Response,
) -> Result<()> {
    if !registry.is_strict() || !response.is_success() {
        return Ok(());
    }
    // Look the schema up once and validate against it directly. Going through
    // `validate_response` would re-check strictness and repeat this lookup.
    let Some(schema) = registry.response_schema(path, method) else {
        return Ok(());
    };
    let bytes = response.bytes();
    if bytes.is_empty() {
        return Ok(());
    }
    let value: Value = serde_json::from_slice(bytes).map_err(|e| Error::SchemaValidation {
        phase: "response",
        message: format!("response body is not JSON: {e}"),
    })?;
    validate_value(schema, &value, "response")
}
50
51/// Validates a JSON request body against the registered request schema for `path` + `method`.
52///
53/// No-op when the registry is not [strict](SchemaRegistry::is_strict).
54pub fn validate_request(
55    registry: &SchemaRegistry,
56    path: &str,
57    method: &Method,
58    body: &Value,
59) -> Result<()> {
60    if !registry.is_strict() {
61        return Ok(());
62    }
63    let Some(schema) = registry.request_schema(path, method) else {
64        return Ok(());
65    };
66    validate_value(schema, body, "request")
67}
68
69/// Validates a JSON response body against the registered response schema for `path` + `method`.
70///
71/// No-op when the registry is not strict.
72pub fn validate_response(
73    registry: &SchemaRegistry,
74    path: &str,
75    method: &Method,
76    body: &Value,
77) -> Result<()> {
78    if !registry.is_strict() {
79        return Ok(());
80    }
81    let Some(schema) = registry.response_schema(path, method) else {
82        return Ok(());
83    };
84    validate_value(schema, body, "response")
85}
86
87/// Validates path parameters (as a JSON object) when a params schema is registered.
88///
89/// Wire values are coerced from strings (numbers, booleans) before validation. No-op when not strict.
90pub fn validate_params(
91    registry: &SchemaRegistry,
92    path: &str,
93    method: &Method,
94    params: &HashMap<String, String>,
95) -> Result<()> {
96    if !registry.is_strict() {
97        return Ok(());
98    }
99    let Some(schema) = registry.params_schema(path, method) else {
100        return Ok(());
101    };
102    validate_value(schema, &params_to_json(params), "params")
103}
104
105/// Validates query parameters (as a JSON object) when a query schema is registered.
106///
107/// Wire values are coerced from strings (numbers, booleans) before validation. No-op when not strict.
108pub fn validate_query(
109    registry: &SchemaRegistry,
110    path: &str,
111    method: &Method,
112    query: &IndexMap<String, QueryValue>,
113) -> Result<()> {
114    if !registry.is_strict() {
115        return Ok(());
116    }
117    let Some(schema) = registry.query_schema(path, method) else {
118        return Ok(());
119    };
120    validate_value(schema, &query_to_json(query), "query")
121}
122
123/// Coerces a single query/path wire string into a JSON value for schema validation.
124pub(crate) fn wire_scalar_to_json(s: &str) -> Value {
125    match s {
126        "true" => Value::Bool(true),
127        "false" => Value::Bool(false),
128        _ => {
129            if let Ok(n) = s.parse::<i64>() {
130                return Value::Number(n.into());
131            }
132            if let Ok(n) = s.parse::<u64>() {
133                return Value::Number(n.into());
134            }
135            if let Ok(n) = s.parse::<f64>() {
136                if let Some(num) = serde_json::Number::from_f64(n) {
137                    return Value::Number(num);
138                }
139            }
140            Value::String(s.to_owned())
141        }
142    }
143}
144
145fn params_to_json(params: &HashMap<String, String>) -> Value {
146    let mut map = serde_json::Map::new();
147    for (key, value) in params {
148        map.insert(key.clone(), wire_scalar_to_json(value));
149    }
150    Value::Object(map)
151}
152
153fn query_to_json(query: &IndexMap<String, QueryValue>) -> Value {
154    let mut map = serde_json::Map::new();
155    for (key, value) in query {
156        let json_value = match value {
157            QueryValue::Scalar(s) => wire_scalar_to_json(s),
158            QueryValue::Array(values) => {
159                Value::Array(values.iter().map(|v| wire_scalar_to_json(v)).collect())
160            }
161        };
162        map.insert(key.clone(), json_value);
163    }
164    Value::Object(map)
165}
166
167fn validate_value(schema: &RootSchema, value: &Value, phase: &'static str) -> Result<()> {
168    let validator = Validator::options()
169        .with_draft(Draft::Draft7)
170        .build(&serde_json::to_value(schema).map_err(|e| Error::Config(e.to_string()))?)
171        .map_err(|e| Error::Config(e.to_string()))?;
172    validator
173        .validate(value)
174        .map_err(|error| Error::SchemaValidation {
175            phase,
176            message: error.to_string(),
177        })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn wire_scalar_coerces_numbers_and_bools() {
        assert_eq!(wire_scalar_to_json("42"), json!(42));
        assert_eq!(wire_scalar_to_json("true"), json!(true));
        assert_eq!(wire_scalar_to_json("hello"), json!("hello"));
    }

    #[test]
    fn wire_scalar_covers_false_negative_and_large_integers() {
        assert_eq!(wire_scalar_to_json("false"), json!(false));
        assert_eq!(wire_scalar_to_json("-3"), json!(-3));
        // Above i64::MAX: must take the u64 branch rather than degrade to a float.
        assert_eq!(wire_scalar_to_json("18446744073709551615"), json!(u64::MAX));
    }

    #[test]
    fn wire_scalar_coerces_finite_floats_only() {
        assert_eq!(wire_scalar_to_json("1.5"), json!(1.5));
        // Non-finite floats are not representable as JSON numbers, so the raw text is kept.
        assert_eq!(wire_scalar_to_json("NaN"), json!("NaN"));
        assert_eq!(wire_scalar_to_json("inf"), json!("inf"));
    }

    #[test]
    fn wire_scalar_is_case_sensitive_and_keeps_empty_string() {
        assert_eq!(wire_scalar_to_json("True"), json!("True"));
        assert_eq!(wire_scalar_to_json(""), json!(""));
    }

    #[test]
    fn query_json_scalar_and_array() {
        let mut q = IndexMap::new();
        q.insert("tag".into(), QueryValue::Scalar("a".into()));
        q.insert(
            "ids".into(),
            QueryValue::Array(vec!["1".into(), "2".into()]),
        );
        let v = query_to_json(&q);
        assert_eq!(v["tag"], json!("a"));
        assert_eq!(v["ids"], json!([1, 2]));
    }

    #[test]
    fn query_json_handles_empty_inputs() {
        let empty: IndexMap<String, QueryValue> = IndexMap::new();
        assert_eq!(query_to_json(&empty), json!({}));

        let mut q = IndexMap::new();
        q.insert("ids".into(), QueryValue::Array(Vec::new()));
        assert_eq!(query_to_json(&q), json!({ "ids": [] }));
    }

    #[test]
    fn params_json_coerces_id() {
        let mut p = HashMap::new();
        p.insert("id".into(), "7".into());
        assert_eq!(params_to_json(&p)["id"], json!(7));
    }

    #[test]
    fn params_json_keeps_non_numeric_strings() {
        let mut p = HashMap::new();
        p.insert("id".into(), "7".into());
        p.insert("slug".into(), "post-1".into());
        assert_eq!(params_to_json(&p), json!({ "id": 7, "slug": "post-1" }));
    }
}