1use std::collections::HashMap;
6use std::sync::{Mutex, OnceLock};
7
8use regex::Regex;
9
10use crate::error::{DbError, ValidationError};
11use crate::file_format::MAX_REGEX_PATTERN_LEN;
12use crate::record::{RowValue, ScalarValue};
13use crate::schema::{Constraint, FieldDef, Type};
14
15fn err(path: &[String], msg: impl Into<String>) -> DbError {
16 DbError::Validation(ValidationError {
17 path: path.to_vec(),
18 message: msg.into(),
19 })
20}
21
22fn regex_cache() -> &'static Mutex<HashMap<String, Regex>> {
23 static CACHE: OnceLock<Mutex<HashMap<String, Regex>>> = OnceLock::new();
24 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
25}
26
27fn reject_risky_regex_pattern(pattern: &str) -> Result<(), DbError> {
28 if pattern.len() > MAX_REGEX_PATTERN_LEN {
29 return Err(DbError::Validation(ValidationError {
30 path: vec![],
31 message: format!(
32 "regex pattern length {} exceeds maximum {MAX_REGEX_PATTERN_LEN}",
33 pattern.len()
34 ),
35 }));
36 }
37 let mut depth = 0u32;
39 let mut prev_quant = false;
40 for ch in pattern.chars() {
41 match ch {
42 '(' => depth = depth.saturating_add(1),
43 ')' => depth = depth.saturating_sub(1),
44 '+' | '*' | '?' => {
45 if prev_quant {
46 return Err(DbError::Validation(ValidationError {
47 path: vec![],
48 message: "regex pattern contains nested quantifiers".into(),
49 }));
50 }
51 prev_quant = true;
52 }
53 _ => prev_quant = false,
54 }
55 }
56 if depth != 0 {
57 return Err(DbError::Validation(ValidationError {
58 path: vec![],
59 message: "regex pattern has unbalanced parentheses".into(),
60 }));
61 }
62 Ok(())
63}
64
65pub fn validate_constraints_at_registration(constraints: &[Constraint]) -> Result<(), DbError> {
67 for c in constraints {
68 if let Constraint::Regex(pattern) = c {
69 reject_risky_regex_pattern(pattern)?;
70 }
71 }
72 Ok(())
73}
74
75fn compiled_regex(pattern: &str, path: &[String]) -> Result<Regex, DbError> {
76 reject_risky_regex_pattern(pattern)?;
77 if let Ok(g) = regex_cache().lock() {
78 if let Some(re) = g.get(pattern) {
79 return Ok(re.clone());
80 }
81 }
82 let re = Regex::new(pattern).map_err(|e| {
83 DbError::Validation(ValidationError {
84 path: path.to_vec(),
85 message: format!("invalid regex in schema: {e}"),
86 })
87 })?;
88 if let Ok(mut g) = regex_cache().lock() {
89 if g.len() >= 256 {
90 g.clear();
91 }
92 g.entry(pattern.to_string()).or_insert_with(|| re.clone());
93 }
94 Ok(re)
95}
96
97pub fn ensure_pk_type_primitive(ty: &Type) -> Result<(), DbError> {
99 match ty {
100 Type::Bool
101 | Type::Int64
102 | Type::Uint64
103 | Type::Float64
104 | Type::String
105 | Type::Bytes
106 | Type::Uuid
107 | Type::Timestamp => Ok(()),
108 Type::Optional(_) | Type::List(_) | Type::Object(_) | Type::Enum(_) => {
109 Err(DbError::Validation(ValidationError {
110 path: vec![],
111 message:
112 "primary key field must use a primitive type (not optional/list/object/enum)"
113 .into(),
114 }))
115 }
116 }
117}
118
119pub fn ensure_pk_scalar_finite(pk: &ScalarValue) -> Result<(), DbError> {
121 if let ScalarValue::Float64(v) = pk {
122 if !v.is_finite() {
123 return Err(DbError::Validation(ValidationError {
124 path: vec![],
125 message: "primary key float must be finite (not NaN or infinity)".into(),
126 }));
127 }
128 }
129 Ok(())
130}
131
132pub fn allows_absent_root(ty: &Type) -> bool {
134 matches!(ty, Type::Optional(_))
135}
136
137pub fn validate_value(
139 path: &mut Vec<String>,
140 ty: &Type,
141 constraints: &[Constraint],
142 v: &RowValue,
143) -> Result<(), DbError> {
144 match ty {
145 Type::Optional(inner) => {
146 if matches!(v, RowValue::None) {
147 return Ok(());
148 }
149 validate_value(path, inner, &[], v)?;
150 apply_constraints(path, ty, constraints, v)
151 }
152 Type::Bool => {
153 let RowValue::Bool(_) = v else {
154 return Err(err(path, "expected bool"));
155 };
156 apply_constraints(path, ty, constraints, v)
157 }
158 Type::Int64 => {
159 let RowValue::Int64(_) = v else {
160 return Err(err(path, "expected int64"));
161 };
162 apply_constraints(path, ty, constraints, v)
163 }
164 Type::Uint64 => {
165 let RowValue::Uint64(_) = v else {
166 return Err(err(path, "expected uint64"));
167 };
168 apply_constraints(path, ty, constraints, v)
169 }
170 Type::Float64 => {
171 let RowValue::Float64(_) = v else {
172 return Err(err(path, "expected float64"));
173 };
174 apply_constraints(path, ty, constraints, v)
175 }
176 Type::String => {
177 let RowValue::String(_) = v else {
178 return Err(err(path, "expected string"));
179 };
180 apply_constraints(path, ty, constraints, v)
181 }
182 Type::Bytes => {
183 let RowValue::Bytes(_) = v else {
184 return Err(err(path, "expected bytes"));
185 };
186 apply_constraints(path, ty, constraints, v)
187 }
188 Type::Uuid => {
189 let RowValue::Uuid(_) = v else {
190 return Err(err(path, "expected uuid"));
191 };
192 apply_constraints(path, ty, constraints, v)
193 }
194 Type::Timestamp => {
195 let RowValue::Timestamp(_) = v else {
196 return Err(err(path, "expected timestamp"));
197 };
198 apply_constraints(path, ty, constraints, v)
199 }
200 Type::List(inner) => {
201 let RowValue::List(items) = v else {
202 return Err(err(path, "expected list"));
203 };
204 for (i, item) in items.iter().enumerate() {
205 path.push(format!("{i}"));
206 validate_value(path, inner, &[], item)?;
207 path.pop();
208 }
209 apply_constraints(path, ty, constraints, v)
210 }
211 Type::Object(fields) => {
212 let RowValue::Object(m) = v else {
213 return Err(err(path, "expected object"));
214 };
215 for sub in fields {
216 let key = sub.path.0[0].to_string();
217 let absent_ok = allows_absent_root(&sub.ty);
218 let none = RowValue::None;
219 let child: &RowValue = match m.get(&key) {
220 None if absent_ok => &none,
221 None => {
222 path.push(key.clone());
223 return Err(err(path, "missing object field"));
224 }
225 Some(x) => x,
226 };
227 path.push(key);
228 validate_value(path, &sub.ty, &sub.constraints, child)?;
229 path.pop();
230 }
231 for k in m.keys() {
232 if !fields.iter().any(|f| f.path.0[0].as_ref() == k.as_str()) {
233 path.push(k.clone());
234 return Err(err(path, "unknown field in object"));
235 }
236 }
237 apply_constraints(path, ty, constraints, v)
238 }
239 Type::Enum(variants) => {
240 let RowValue::String(s) = v else {
241 return Err(err(path, "expected string (enum)"));
242 };
243 if !variants.iter().any(|x| x == s) {
244 return Err(err(
245 path,
246 format!("enum value must be one of {:?}", variants),
247 ));
248 }
249 apply_constraints(path, ty, constraints, v)
250 }
251 }
252}
253
254fn must_int64(path: &[String], v: &RowValue, requirement: &'static str) -> Result<i64, DbError> {
255 let RowValue::Int64(n) = v else {
256 return Err(err(path, requirement));
257 };
258 Ok(*n)
259}
260
261fn must_uint64(path: &[String], v: &RowValue, requirement: &'static str) -> Result<u64, DbError> {
262 let RowValue::Uint64(n) = v else {
263 return Err(err(path, requirement));
264 };
265 Ok(*n)
266}
267
268fn must_f64(path: &[String], v: &RowValue, requirement: &'static str) -> Result<f64, DbError> {
269 let RowValue::Float64(n) = v else {
270 return Err(err(path, requirement));
271 };
272 Ok(*n)
273}
274
275fn constrain_min_i64(path: &[String], n: i64, min: i64) -> Result<(), DbError> {
276 if n < min {
277 Err(err(path, format!("value {n} is below minimum {min}")))
278 } else {
279 Ok(())
280 }
281}
282
283fn constrain_max_i64(path: &[String], n: i64, max: i64) -> Result<(), DbError> {
284 if n > max {
285 Err(err(path, format!("value {n} is above maximum {max}")))
286 } else {
287 Ok(())
288 }
289}
290
291fn constrain_min_u64(path: &[String], n: u64, min: u64) -> Result<(), DbError> {
292 if n < min {
293 Err(err(path, format!("value {n} is below minimum {min}")))
294 } else {
295 Ok(())
296 }
297}
298
299fn constrain_max_u64(path: &[String], n: u64, max: u64) -> Result<(), DbError> {
300 if n > max {
301 Err(err(path, format!("value {n} is above maximum {max}")))
302 } else {
303 Ok(())
304 }
305}
306
307fn constrain_min_f64(path: &[String], n: f64, min: f64) -> Result<(), DbError> {
308 if n < min {
309 Err(err(path, format!("value {n} is below minimum {min}")))
310 } else {
311 Ok(())
312 }
313}
314
315fn constrain_max_f64(path: &[String], n: f64, max: f64) -> Result<(), DbError> {
316 if n > max {
317 Err(err(path, format!("value {n} is above maximum {max}")))
318 } else {
319 Ok(())
320 }
321}
322
323fn constrain_min_byte_len(
324 path: &[String],
325 len: usize,
326 min: u64,
327 kind: &str,
328) -> Result<(), DbError> {
329 if (len as u64) < min {
330 Err(err(
331 path,
332 format!("{kind} length {len} is below minimum {min}"),
333 ))
334 } else {
335 Ok(())
336 }
337}
338
339fn constrain_max_byte_len(
340 path: &[String],
341 len: usize,
342 max: u64,
343 kind: &str,
344) -> Result<(), DbError> {
345 if (len as u64) > max {
346 Err(err(
347 path,
348 format!("{kind} length {len} is above maximum {max}"),
349 ))
350 } else {
351 Ok(())
352 }
353}
354
355fn apply_constraints(
356 path: &[String],
357 _ty: &Type,
358 constraints: &[Constraint],
359 v: &RowValue,
360) -> Result<(), DbError> {
361 for c in constraints {
362 match c {
363 Constraint::MinI64(min) => {
364 let n = must_int64(path, v, "MinI64 constraint requires int64")?;
365 constrain_min_i64(path, n, *min)?;
366 }
367 Constraint::MaxI64(max) => {
368 let n = must_int64(path, v, "MaxI64 constraint requires int64")?;
369 constrain_max_i64(path, n, *max)?;
370 }
371 Constraint::MinU64(min) => {
372 let n = must_uint64(path, v, "MinU64 constraint requires uint64")?;
373 constrain_min_u64(path, n, *min)?;
374 }
375 Constraint::MaxU64(max) => {
376 let n = must_uint64(path, v, "MaxU64 constraint requires uint64")?;
377 constrain_max_u64(path, n, *max)?;
378 }
379 Constraint::MinF64(min) => {
380 let n = must_f64(path, v, "MinF64 constraint requires float64")?;
381 constrain_min_f64(path, n, *min)?;
382 }
383 Constraint::MaxF64(max) => {
384 let n = must_f64(path, v, "MaxF64 constraint requires float64")?;
385 constrain_max_f64(path, n, *max)?;
386 }
387 Constraint::MinLength(min) => match v {
388 RowValue::String(s) => constrain_min_byte_len(path, s.len(), *min, "string")?,
389 RowValue::Bytes(b) => constrain_min_byte_len(path, b.len(), *min, "bytes")?,
390 RowValue::List(items) => constrain_min_byte_len(path, items.len(), *min, "list")?,
391 _ => return Err(err(path, "MinLength applies to string, bytes, or list")),
392 },
393 Constraint::MaxLength(max) => match v {
394 RowValue::String(s) => constrain_max_byte_len(path, s.len(), *max, "string")?,
395 RowValue::Bytes(b) => constrain_max_byte_len(path, b.len(), *max, "bytes")?,
396 RowValue::List(items) => constrain_max_byte_len(path, items.len(), *max, "list")?,
397 _ => return Err(err(path, "MaxLength applies to string, bytes, or list")),
398 },
399 Constraint::Regex(pattern) => {
400 let RowValue::String(s) = v else {
401 return Err(err(path, "Regex constraint requires string"));
402 };
403 let re = compiled_regex(pattern, path)?;
404 if !re.is_match(s) {
405 return Err(err(path, "string does not match regex"));
406 }
407 }
408 Constraint::Email => {
409 let RowValue::String(s) = v else {
410 return Err(err(path, "Email constraint requires string"));
411 };
412 if !s.contains('@') || !s.contains('.') {
413 return Err(err(path, "string is not a valid email shape"));
414 }
415 }
416 Constraint::Url => {
417 let RowValue::String(s) = v else {
418 return Err(err(path, "Url constraint requires string"));
419 };
420 if !s.starts_with("http://") && !s.starts_with("https://") {
421 return Err(err(path, "string must be an http(s) URL"));
422 }
423 }
424 Constraint::NonEmpty => match v {
425 RowValue::String(s) if s.is_empty() => {
426 return Err(err(path, "string must be non-empty"));
427 }
428 RowValue::Bytes(b) if b.is_empty() => {
429 return Err(err(path, "bytes must be non-empty"));
430 }
431 RowValue::List(items) if items.is_empty() => {
432 return Err(err(path, "list must be non-empty"));
433 }
434 RowValue::String(_) | RowValue::Bytes(_) | RowValue::List(_) => {}
435 _ => return Err(err(path, "NonEmpty applies to string, bytes, or list")),
436 },
437 }
438 }
439
440 Ok(())
441}
442
443pub fn validate_top_level_row(
446 fields: &[FieldDef],
447 pk_name: &str,
448 row: &std::collections::BTreeMap<String, RowValue>,
449) -> Result<(), DbError> {
450 for k in row.keys() {
451 if !fields
452 .iter()
453 .any(|f| f.path.0.len() == 1 && f.path.0[0].as_ref() == k.as_str())
454 {
455 return Err(DbError::Validation(ValidationError {
456 path: vec![k.clone()],
457 message: "unknown field".into(),
458 }));
459 }
460 }
461
462 for def in fields {
463 let name = def.path.0[0].to_string();
464 if name == pk_name {
465 continue;
466 }
467 let absent_ok = allows_absent_root(&def.ty);
468 let none = RowValue::None;
469 let v: &RowValue = match row.get(&name) {
470 None if absent_ok => &none,
471 None => {
472 return Err(DbError::Validation(ValidationError {
473 path: vec![name.clone()],
474 message: "missing field".into(),
475 }));
476 }
477 Some(x) => x,
478 };
479 if matches!(v, RowValue::None) && !absent_ok {
480 return Err(DbError::Validation(ValidationError {
481 path: vec![name.clone()],
482 message: "unexpected null for required field".into(),
483 }));
484 }
485 let mut path = vec![name.clone()];
486 validate_value(&mut path, &def.ty, &def.constraints, v)?;
487 }
488 Ok(())
489}
490
491pub fn validate_multiseg_row(
493 fields: &[FieldDef],
494 pk_name: &str,
495 row: &std::collections::BTreeMap<String, RowValue>,
496) -> Result<(), DbError> {
497 crate::db::validate_unknown_fields_for_multiseg_schema(fields, pk_name, row)?;
498 for def in fields {
499 if def.path.0.len() == 1 && def.path.0[0] == pk_name {
500 continue;
501 }
502 let mut path: Vec<String> = def.path.0.iter().map(|s| s.as_ref().to_string()).collect();
503 let absent_ok = allows_absent_root(&def.ty);
504 let v = match crate::db::row_value_at_path(row, &def.path.0) {
505 Some(x) => x,
506 None if absent_ok => RowValue::None,
507 None => {
508 return Err(DbError::Schema(
509 crate::error::SchemaError::RowMissingField {
510 name: path.join("."),
511 },
512 ));
513 }
514 };
515 if matches!(v, RowValue::None) && !absent_ok {
516 return Err(DbError::Validation(ValidationError {
517 path: path.clone(),
518 message: "unexpected null for required field".into(),
519 }));
520 }
521 validate_value(&mut path, &def.ty, &def.constraints, &v)?;
522 }
523 Ok(())
524}
525
526#[cfg(test)]
527mod constraint_helper_cover_tests {
528 use super::*;
529 use crate::error::DbError;
530
531 #[test]
532 fn constrain_helpers_accept_in_range_values() {
533 let path = vec!["z".into()];
534 constrain_min_i64(&path, 5, 1).unwrap();
535 constrain_max_i64(&path, 1, 10).unwrap();
536 constrain_min_u64(&path, 5, 1).unwrap();
537 constrain_max_u64(&path, 1, 10).unwrap();
538 constrain_min_f64(&path, 5.0, 1.0).unwrap();
539 constrain_max_f64(&path, 1.0, 10.0).unwrap();
540 constrain_min_byte_len(&path, "abcde".len(), 1, "string").unwrap();
541 constrain_max_byte_len(&path, "ab".len(), 10, "string").unwrap();
542 constrain_min_byte_len(&path, vec![1u8, 2, 3].len(), 2, "bytes").unwrap();
543 constrain_max_byte_len(&path, vec![1u8].len(), 4, "bytes").unwrap();
544 }
545
546 #[test]
547 fn constrain_max_numeric_helpers_surface_above_max_messages() {
548 let path = vec!["x".into()];
549
550 let e = constrain_max_i64(&path, 3, 1).unwrap_err();
551 assert!(matches!(
552 &e,
553 DbError::Validation(v) if v.path == path && v.message.contains("above maximum"),
554 ));
555
556 let e = constrain_max_u64(&path, 5, 1).unwrap_err();
557 assert!(matches!(
558 &e,
559 DbError::Validation(v) if v.message.contains("above maximum"),
560 ));
561
562 let e = constrain_max_f64(&path, 3.5, 1.25).unwrap_err();
563 assert!(matches!(
564 &e,
565 DbError::Validation(v) if v.message.contains("above maximum"),
566 ));
567 }
568
569 #[test]
570 fn constrain_max_byte_len_string_and_bytes_surface_above_max() {
571 let path = vec!["f".into()];
572
573 let e = constrain_max_byte_len(&path, "ab".len(), 1, "string").unwrap_err();
574 assert!(matches!(
575 &e,
576 DbError::Validation(v) if v.message.contains("above maximum"),
577 ));
578
579 let e = constrain_max_byte_len(&path, vec![1u8, 2].len(), 1, "bytes").unwrap_err();
580 assert!(matches!(
581 &e,
582 DbError::Validation(v) if v.message.contains("above maximum"),
583 ));
584 }
585}