1use std::collections::HashMap;
6use std::sync::{Mutex, OnceLock};
7
8use regex::Regex;
9
10use crate::error::{DbError, ValidationError};
11use crate::record::RowValue;
12use crate::schema::{Constraint, FieldDef, Type};
13
14fn err(path: &[String], msg: impl Into<String>) -> DbError {
15 DbError::Validation(ValidationError {
16 path: path.to_vec(),
17 message: msg.into(),
18 })
19}
20
21fn regex_cache() -> &'static Mutex<HashMap<String, Regex>> {
22 static CACHE: OnceLock<Mutex<HashMap<String, Regex>>> = OnceLock::new();
23 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
24}
25
26fn compiled_regex(pattern: &str, path: &[String]) -> Result<Regex, DbError> {
27 if let Ok(g) = regex_cache().lock() {
28 if let Some(re) = g.get(pattern) {
29 return Ok(re.clone());
30 }
31 }
32 let re = Regex::new(pattern).map_err(|e| {
33 DbError::Validation(ValidationError {
34 path: path.to_vec(),
35 message: format!("invalid regex in schema: {e}"),
36 })
37 })?;
38 if let Ok(mut g) = regex_cache().lock() {
39 g.entry(pattern.to_string()).or_insert_with(|| re.clone());
40 }
41 Ok(re)
42}
43
44pub fn ensure_pk_type_primitive(ty: &Type) -> Result<(), DbError> {
46 match ty {
47 Type::Bool
48 | Type::Int64
49 | Type::Uint64
50 | Type::Float64
51 | Type::String
52 | Type::Bytes
53 | Type::Uuid
54 | Type::Timestamp => Ok(()),
55 Type::Optional(_) | Type::List(_) | Type::Object(_) | Type::Enum(_) => {
56 Err(DbError::Validation(ValidationError {
57 path: vec![],
58 message:
59 "primary key field must use a primitive type (not optional/list/object/enum)"
60 .into(),
61 }))
62 }
63 }
64}
65
66pub fn allows_absent_root(ty: &Type) -> bool {
68 matches!(ty, Type::Optional(_))
69}
70
71pub fn validate_value(
73 path: &mut Vec<String>,
74 ty: &Type,
75 constraints: &[Constraint],
76 v: &RowValue,
77) -> Result<(), DbError> {
78 match ty {
79 Type::Optional(inner) => {
80 if matches!(v, RowValue::None) {
81 return Ok(());
82 }
83 validate_value(path, inner, &[], v)?;
84 apply_constraints(path, ty, constraints, v)
85 }
86 Type::Bool => {
87 let RowValue::Bool(_) = v else {
88 return Err(err(path, "expected bool"));
89 };
90 apply_constraints(path, ty, constraints, v)
91 }
92 Type::Int64 => {
93 let RowValue::Int64(_) = v else {
94 return Err(err(path, "expected int64"));
95 };
96 apply_constraints(path, ty, constraints, v)
97 }
98 Type::Uint64 => {
99 let RowValue::Uint64(_) = v else {
100 return Err(err(path, "expected uint64"));
101 };
102 apply_constraints(path, ty, constraints, v)
103 }
104 Type::Float64 => {
105 let RowValue::Float64(_) = v else {
106 return Err(err(path, "expected float64"));
107 };
108 apply_constraints(path, ty, constraints, v)
109 }
110 Type::String => {
111 let RowValue::String(_) = v else {
112 return Err(err(path, "expected string"));
113 };
114 apply_constraints(path, ty, constraints, v)
115 }
116 Type::Bytes => {
117 let RowValue::Bytes(_) = v else {
118 return Err(err(path, "expected bytes"));
119 };
120 apply_constraints(path, ty, constraints, v)
121 }
122 Type::Uuid => {
123 let RowValue::Uuid(_) = v else {
124 return Err(err(path, "expected uuid"));
125 };
126 apply_constraints(path, ty, constraints, v)
127 }
128 Type::Timestamp => {
129 let RowValue::Timestamp(_) = v else {
130 return Err(err(path, "expected timestamp"));
131 };
132 apply_constraints(path, ty, constraints, v)
133 }
134 Type::List(inner) => {
135 let RowValue::List(items) = v else {
136 return Err(err(path, "expected list"));
137 };
138 for (i, item) in items.iter().enumerate() {
139 path.push(format!("{i}"));
140 validate_value(path, inner, &[], item)?;
141 path.pop();
142 }
143 apply_constraints(path, ty, constraints, v)
144 }
145 Type::Object(fields) => {
146 let RowValue::Object(m) = v else {
147 return Err(err(path, "expected object"));
148 };
149 for sub in fields {
150 let key = sub.path.0[0].to_string();
151 let absent_ok = allows_absent_root(&sub.ty);
152 let none = RowValue::None;
153 let child: &RowValue = match m.get(&key) {
154 None if absent_ok => &none,
155 None => {
156 path.push(key.clone());
157 return Err(err(path, "missing object field"));
158 }
159 Some(x) => x,
160 };
161 path.push(key);
162 validate_value(path, &sub.ty, &sub.constraints, child)?;
163 path.pop();
164 }
165 for k in m.keys() {
166 if !fields.iter().any(|f| f.path.0[0].as_ref() == k.as_str()) {
167 path.push(k.clone());
168 return Err(err(path, "unknown field in object"));
169 }
170 }
171 apply_constraints(path, ty, constraints, v)
172 }
173 Type::Enum(variants) => {
174 let RowValue::String(s) = v else {
175 return Err(err(path, "expected string (enum)"));
176 };
177 if !variants.iter().any(|x| x == s) {
178 return Err(err(
179 path,
180 format!("enum value must be one of {:?}", variants),
181 ));
182 }
183 apply_constraints(path, ty, constraints, v)
184 }
185 }
186}
187
188fn must_int64(path: &[String], v: &RowValue, requirement: &'static str) -> Result<i64, DbError> {
189 let RowValue::Int64(n) = v else {
190 return Err(err(path, requirement));
191 };
192 Ok(*n)
193}
194
195fn must_uint64(path: &[String], v: &RowValue, requirement: &'static str) -> Result<u64, DbError> {
196 let RowValue::Uint64(n) = v else {
197 return Err(err(path, requirement));
198 };
199 Ok(*n)
200}
201
202fn must_f64(path: &[String], v: &RowValue, requirement: &'static str) -> Result<f64, DbError> {
203 let RowValue::Float64(n) = v else {
204 return Err(err(path, requirement));
205 };
206 Ok(*n)
207}
208
209fn constrain_min_i64(path: &[String], n: i64, min: i64) -> Result<(), DbError> {
210 if n < min {
211 Err(err(path, format!("value {n} is below minimum {min}")))
212 } else {
213 Ok(())
214 }
215}
216
217fn constrain_max_i64(path: &[String], n: i64, max: i64) -> Result<(), DbError> {
218 if n > max {
219 Err(err(path, format!("value {n} is above maximum {max}")))
220 } else {
221 Ok(())
222 }
223}
224
225fn constrain_min_u64(path: &[String], n: u64, min: u64) -> Result<(), DbError> {
226 if n < min {
227 Err(err(path, format!("value {n} is below minimum {min}")))
228 } else {
229 Ok(())
230 }
231}
232
233fn constrain_max_u64(path: &[String], n: u64, max: u64) -> Result<(), DbError> {
234 if n > max {
235 Err(err(path, format!("value {n} is above maximum {max}")))
236 } else {
237 Ok(())
238 }
239}
240
241fn constrain_min_f64(path: &[String], n: f64, min: f64) -> Result<(), DbError> {
242 if n < min {
243 Err(err(path, format!("value {n} is below minimum {min}")))
244 } else {
245 Ok(())
246 }
247}
248
249fn constrain_max_f64(path: &[String], n: f64, max: f64) -> Result<(), DbError> {
250 if n > max {
251 Err(err(path, format!("value {n} is above maximum {max}")))
252 } else {
253 Ok(())
254 }
255}
256
257fn constrain_min_byte_len(
258 path: &[String],
259 len: usize,
260 min: u64,
261 kind: &str,
262) -> Result<(), DbError> {
263 if (len as u64) < min {
264 Err(err(
265 path,
266 format!("{kind} length {len} is below minimum {min}"),
267 ))
268 } else {
269 Ok(())
270 }
271}
272
273fn constrain_max_byte_len(
274 path: &[String],
275 len: usize,
276 max: u64,
277 kind: &str,
278) -> Result<(), DbError> {
279 if (len as u64) > max {
280 Err(err(
281 path,
282 format!("{kind} length {len} is above maximum {max}"),
283 ))
284 } else {
285 Ok(())
286 }
287}
288
289fn apply_constraints(
290 path: &[String],
291 _ty: &Type,
292 constraints: &[Constraint],
293 v: &RowValue,
294) -> Result<(), DbError> {
295 for c in constraints {
296 match c {
297 Constraint::MinI64(min) => {
298 let n = must_int64(path, v, "MinI64 constraint requires int64")?;
299 constrain_min_i64(path, n, *min)?;
300 }
301 Constraint::MaxI64(max) => {
302 let n = must_int64(path, v, "MaxI64 constraint requires int64")?;
303 constrain_max_i64(path, n, *max)?;
304 }
305 Constraint::MinU64(min) => {
306 let n = must_uint64(path, v, "MinU64 constraint requires uint64")?;
307 constrain_min_u64(path, n, *min)?;
308 }
309 Constraint::MaxU64(max) => {
310 let n = must_uint64(path, v, "MaxU64 constraint requires uint64")?;
311 constrain_max_u64(path, n, *max)?;
312 }
313 Constraint::MinF64(min) => {
314 let n = must_f64(path, v, "MinF64 constraint requires float64")?;
315 constrain_min_f64(path, n, *min)?;
316 }
317 Constraint::MaxF64(max) => {
318 let n = must_f64(path, v, "MaxF64 constraint requires float64")?;
319 constrain_max_f64(path, n, *max)?;
320 }
321 Constraint::MinLength(min) => match v {
322 RowValue::String(s) => constrain_min_byte_len(path, s.len(), *min, "string")?,
323 RowValue::Bytes(b) => constrain_min_byte_len(path, b.len(), *min, "bytes")?,
324 RowValue::List(items) => constrain_min_byte_len(path, items.len(), *min, "list")?,
325 _ => return Err(err(path, "MinLength applies to string, bytes, or list")),
326 },
327 Constraint::MaxLength(max) => match v {
328 RowValue::String(s) => constrain_max_byte_len(path, s.len(), *max, "string")?,
329 RowValue::Bytes(b) => constrain_max_byte_len(path, b.len(), *max, "bytes")?,
330 RowValue::List(items) => constrain_max_byte_len(path, items.len(), *max, "list")?,
331 _ => return Err(err(path, "MaxLength applies to string, bytes, or list")),
332 },
333 Constraint::Regex(pattern) => {
334 let RowValue::String(s) = v else {
335 return Err(err(path, "Regex constraint requires string"));
336 };
337 let re = compiled_regex(pattern, path)?;
338 if !re.is_match(s) {
339 return Err(err(path, "string does not match regex"));
340 }
341 }
342 Constraint::Email => {
343 let RowValue::String(s) = v else {
344 return Err(err(path, "Email constraint requires string"));
345 };
346 if !s.contains('@') || !s.contains('.') {
347 return Err(err(path, "string is not a valid email shape"));
348 }
349 }
350 Constraint::Url => {
351 let RowValue::String(s) = v else {
352 return Err(err(path, "Url constraint requires string"));
353 };
354 if !s.starts_with("http://") && !s.starts_with("https://") {
355 return Err(err(path, "string must be an http(s) URL"));
356 }
357 }
358 Constraint::NonEmpty => match v {
359 RowValue::String(s) if s.is_empty() => {
360 return Err(err(path, "string must be non-empty"));
361 }
362 RowValue::Bytes(b) if b.is_empty() => {
363 return Err(err(path, "bytes must be non-empty"));
364 }
365 RowValue::List(items) if items.is_empty() => {
366 return Err(err(path, "list must be non-empty"));
367 }
368 RowValue::String(_) | RowValue::Bytes(_) | RowValue::List(_) => {}
369 _ => return Err(err(path, "NonEmpty applies to string, bytes, or list")),
370 },
371 }
372 }
373
374 Ok(())
375}
376
377pub fn validate_top_level_row(
380 fields: &[FieldDef],
381 pk_name: &str,
382 row: &std::collections::BTreeMap<String, RowValue>,
383) -> Result<(), DbError> {
384 for k in row.keys() {
385 if !fields
386 .iter()
387 .any(|f| f.path.0.len() == 1 && f.path.0[0].as_ref() == k.as_str())
388 {
389 return Err(DbError::Validation(ValidationError {
390 path: vec![k.clone()],
391 message: "unknown field".into(),
392 }));
393 }
394 }
395
396 for def in fields {
397 let name = def.path.0[0].to_string();
398 if name == pk_name {
399 continue;
400 }
401 let absent_ok = allows_absent_root(&def.ty);
402 let none = RowValue::None;
403 let v: &RowValue = match row.get(&name) {
404 None if absent_ok => &none,
405 None => {
406 return Err(DbError::Validation(ValidationError {
407 path: vec![name.clone()],
408 message: "missing field".into(),
409 }));
410 }
411 Some(x) => x,
412 };
413 if matches!(v, RowValue::None) && !absent_ok {
414 return Err(DbError::Validation(ValidationError {
415 path: vec![name.clone()],
416 message: "unexpected null for required field".into(),
417 }));
418 }
419 let mut path = vec![name.clone()];
420 validate_value(&mut path, &def.ty, &def.constraints, v)?;
421 }
422 Ok(())
423}
424
425pub fn validate_multiseg_row(
427 fields: &[FieldDef],
428 pk_name: &str,
429 row: &std::collections::BTreeMap<String, RowValue>,
430) -> Result<(), DbError> {
431 crate::db::validate_unknown_fields_for_multiseg_schema(fields, pk_name, row)?;
432 for def in fields {
433 if def.path.0.len() == 1 && def.path.0[0] == pk_name {
434 continue;
435 }
436 let mut path: Vec<String> = def.path.0.iter().map(|s| s.as_ref().to_string()).collect();
437 let absent_ok = allows_absent_root(&def.ty);
438 let v = match crate::db::row_value_at_path(row, &def.path.0) {
439 Some(x) => x,
440 None if absent_ok => RowValue::None,
441 None => {
442 return Err(DbError::Schema(
443 crate::error::SchemaError::RowMissingField {
444 name: path.join("."),
445 },
446 ));
447 }
448 };
449 if matches!(v, RowValue::None) && !absent_ok {
450 return Err(DbError::Validation(ValidationError {
451 path: path.clone(),
452 message: "unexpected null for required field".into(),
453 }));
454 }
455 validate_value(&mut path, &def.ty, &def.constraints, &v)?;
456 }
457 Ok(())
458}
459
#[cfg(test)]
mod constraint_helper_cover_tests {
    use super::*;
    use crate::error::DbError;

    /// True when `e` is a validation error whose message mentions "above maximum".
    fn reports_above_max(e: &DbError) -> bool {
        matches!(e, DbError::Validation(v) if v.message.contains("above maximum"))
    }

    /// Every helper returns `Ok` for a value that lies within its bound.
    #[test]
    fn constrain_helpers_accept_in_range_values() {
        let path: Vec<String> = vec!["z".to_string()];

        constrain_min_i64(&path, 5, 1).expect("5 satisfies a minimum of 1");
        constrain_max_i64(&path, 1, 10).expect("1 satisfies a maximum of 10");
        constrain_min_u64(&path, 5, 1).expect("5 satisfies a minimum of 1");
        constrain_max_u64(&path, 1, 10).expect("1 satisfies a maximum of 10");
        constrain_min_f64(&path, 5.0, 1.0).expect("5.0 satisfies a minimum of 1.0");
        constrain_max_f64(&path, 1.0, 10.0).expect("1.0 satisfies a maximum of 10.0");

        constrain_min_byte_len(&path, "abcde".len(), 1, "string")
            .expect("five bytes satisfy a minimum of 1");
        constrain_max_byte_len(&path, "ab".len(), 10, "string")
            .expect("two bytes satisfy a maximum of 10");
        constrain_min_byte_len(&path, [1u8, 2, 3].len(), 2, "bytes")
            .expect("three bytes satisfy a minimum of 2");
        constrain_max_byte_len(&path, [1u8].len(), 4, "bytes")
            .expect("one byte satisfies a maximum of 4");
    }

    /// The numeric max helpers report an "above maximum" validation error at the given path.
    #[test]
    fn constrain_max_numeric_helpers_surface_above_max_messages() {
        let path: Vec<String> = vec!["x".to_string()];

        let e = constrain_max_i64(&path, 3, 1).expect_err("3 exceeds a maximum of 1");
        assert!(reports_above_max(&e));
        assert!(matches!(&e, DbError::Validation(v) if v.path == path));

        let e = constrain_max_u64(&path, 5, 1).expect_err("5 exceeds a maximum of 1");
        assert!(reports_above_max(&e));

        let e = constrain_max_f64(&path, 3.5, 1.25).expect_err("3.5 exceeds a maximum of 1.25");
        assert!(reports_above_max(&e));
    }

    /// The length max helper reports "above maximum" for both strings and bytes.
    #[test]
    fn constrain_max_byte_len_string_and_bytes_surface_above_max() {
        let path: Vec<String> = vec!["f".to_string()];

        let e = constrain_max_byte_len(&path, "ab".len(), 1, "string")
            .expect_err("two bytes exceed a maximum of 1");
        assert!(reports_above_max(&e));

        let e = constrain_max_byte_len(&path, [1u8, 2].len(), 1, "bytes")
            .expect_err("two bytes exceed a maximum of 1");
        assert!(reports_above_max(&e));
    }
}