better_fetch/
schema_validate.rs1use std::collections::HashMap;
4
5use http::Method;
6use indexmap::IndexMap;
7use jsonschema::{Draft, Validator};
8use schemars::schema::RootSchema;
9use serde_json::Value;
10
11use crate::error::Error;
12use crate::response::Response;
13use crate::schema::SchemaRegistry;
14use crate::url_build::QueryValue;
15use crate::Result;
16
/// Bundle of the values needed to look up a response schema for a streamed
/// response: a shared registry handle plus the route and HTTP method.
///
/// Only compiled when the `schema-validate` feature is enabled.
// NOTE(review): the struct is only declared here; how the streaming code
// consumes it is not visible from this file section.
#[cfg(feature = "schema-validate")]
#[derive(Clone)]
pub(crate) struct StreamResponseSchemaCtx {
    /// Shared, reference-counted handle to the schema registry.
    pub registry: std::sync::Arc<SchemaRegistry>,
    /// Route path of the request (presumably the same key passed as `path`
    /// to the registry lookups — confirm against the caller).
    pub route_path: String,
    /// HTTP method of the request.
    pub method: Method,
}
25
/// Validates the body of `response` against the response schema registered
/// for `path` / `method`, if there is one.
///
/// Validation is skipped (returning `Ok(())`) when any of the following holds:
/// the registry is not strict, the response is not a success, no response
/// schema is registered for the route, or the body is empty.
///
/// # Errors
///
/// Returns [`Error::SchemaValidation`] with phase `"response"` when the body
/// cannot be parsed as JSON, and propagates whatever [`validate_response`]
/// returns for the parsed value.
#[cfg(feature = "schema-validate")]
pub(crate) fn validate_response_if_registered(
    registry: &SchemaRegistry,
    path: &str,
    method: &Method,
    response: &Response,
) -> Result<()> {
    // Evaluated left to right with short-circuiting, so the schema lookup
    // only happens for successful responses in strict mode.
    let should_validate = registry.is_strict()
        && response.is_success()
        && registry.response_schema(path, method).is_some();
    if !should_validate {
        return Ok(());
    }

    let bytes = response.bytes();
    if bytes.is_empty() {
        // Nothing to parse; an empty body is treated as "nothing to validate".
        return Ok(());
    }

    match serde_json::from_slice::<Value>(bytes) {
        Ok(value) => validate_response(registry, path, method, &value),
        Err(e) => Err(Error::SchemaValidation {
            phase: "response",
            message: format!("response body is not JSON: {e}"),
        }),
    }
}
50
51pub fn validate_request(
55 registry: &SchemaRegistry,
56 path: &str,
57 method: &Method,
58 body: &Value,
59) -> Result<()> {
60 if !registry.is_strict() {
61 return Ok(());
62 }
63 let Some(schema) = registry.request_schema(path, method) else {
64 return Ok(());
65 };
66 validate_value(schema, body, "request")
67}
68
69pub fn validate_response(
73 registry: &SchemaRegistry,
74 path: &str,
75 method: &Method,
76 body: &Value,
77) -> Result<()> {
78 if !registry.is_strict() {
79 return Ok(());
80 }
81 let Some(schema) = registry.response_schema(path, method) else {
82 return Ok(());
83 };
84 validate_value(schema, body, "response")
85}
86
87pub fn validate_params(
91 registry: &SchemaRegistry,
92 path: &str,
93 method: &Method,
94 params: &HashMap<String, String>,
95) -> Result<()> {
96 if !registry.is_strict() {
97 return Ok(());
98 }
99 let Some(schema) = registry.params_schema(path, method) else {
100 return Ok(());
101 };
102 validate_value(schema, ¶ms_to_json(params), "params")
103}
104
105pub fn validate_query(
109 registry: &SchemaRegistry,
110 path: &str,
111 method: &Method,
112 query: &IndexMap<String, QueryValue>,
113) -> Result<()> {
114 if !registry.is_strict() {
115 return Ok(());
116 }
117 let Some(schema) = registry.query_schema(path, method) else {
118 return Ok(());
119 };
120 validate_value(schema, &query_to_json(query), "query")
121}
122
123pub(crate) fn wire_scalar_to_json(s: &str) -> Value {
125 match s {
126 "true" => Value::Bool(true),
127 "false" => Value::Bool(false),
128 _ => {
129 if let Ok(n) = s.parse::<i64>() {
130 return Value::Number(n.into());
131 }
132 if let Ok(n) = s.parse::<u64>() {
133 return Value::Number(n.into());
134 }
135 if let Ok(n) = s.parse::<f64>() {
136 if let Some(num) = serde_json::Number::from_f64(n) {
137 return Value::Number(num);
138 }
139 }
140 Value::String(s.to_owned())
141 }
142 }
143}
144
145fn params_to_json(params: &HashMap<String, String>) -> Value {
146 let mut map = serde_json::Map::new();
147 for (key, value) in params {
148 map.insert(key.clone(), wire_scalar_to_json(value));
149 }
150 Value::Object(map)
151}
152
153fn query_to_json(query: &IndexMap<String, QueryValue>) -> Value {
154 let mut map = serde_json::Map::new();
155 for (key, value) in query {
156 let json_value = match value {
157 QueryValue::Scalar(s) => wire_scalar_to_json(s),
158 QueryValue::Array(values) => {
159 Value::Array(values.iter().map(|v| wire_scalar_to_json(v)).collect())
160 }
161 };
162 map.insert(key.clone(), json_value);
163 }
164 Value::Object(map)
165}
166
167fn validate_value(schema: &RootSchema, value: &Value, phase: &'static str) -> Result<()> {
168 let validator = Validator::options()
169 .with_draft(Draft::Draft7)
170 .build(&serde_json::to_value(schema).map_err(|e| Error::Config(e.to_string()))?)
171 .map_err(|e| Error::Config(e.to_string()))?;
172 validator
173 .validate(value)
174 .map_err(|error| Error::SchemaValidation {
175 phase,
176 message: error.to_string(),
177 })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Integers, booleans and plain text each map to their JSON counterpart.
    #[test]
    fn wire_scalar_coerces_numbers_and_bools() {
        let number = wire_scalar_to_json("42");
        let boolean = wire_scalar_to_json("true");
        let text = wire_scalar_to_json("hello");

        assert_eq!(number, json!(42));
        assert_eq!(boolean, json!(true));
        assert_eq!(text, json!("hello"));
    }

    /// Scalar entries stay scalars; array entries are coerced element-wise.
    #[test]
    fn query_json_scalar_and_array() {
        let query: IndexMap<String, QueryValue> = [
            ("tag".into(), QueryValue::Scalar("a".into())),
            (
                "ids".into(),
                QueryValue::Array(vec!["1".into(), "2".into()]),
            ),
        ]
        .into_iter()
        .collect();

        let object = query_to_json(&query);

        assert_eq!(object["tag"], json!("a"));
        assert_eq!(object["ids"], json!([1, 2]));
    }

    /// A numeric-looking path parameter becomes a JSON number.
    #[test]
    fn params_json_coerces_id() {
        let params: HashMap<String, String> = HashMap::from([("id".into(), "7".into())]);

        let object = params_to_json(&params);

        assert_eq!(object["id"], json!(7));
    }
}