Skip to main content

reddb_server/storage/schema/
parametric.rs

1//! Parametric type validators — Phase 4 partial drop.
2//!
//! The full Phase 3 plan calls for `DataType::Varchar { max_len }`
4//! and `DataType::Decimal { precision, scale }` baked into the
5//! `DataType` enum. That migration cascades into hundreds of
6//! pattern-match sites and requires cargo verification to land
7//! safely. Until that session runs, this module ships the
8//! **validators** as standalone functions so the rest of the
9//! codebase can use them today via `coerce::coerce_via_catalog`
10//! without touching `DataType`.
11//!
12//! Once the enum migration lands, these functions become the
13//! body of the cast-catalog entries for the parametric variants.
14//!
15//! ## Coverage
16//!
17//! - `validate_varchar(s, max_len)` — Postgres-strict (rejects
18//!   strings longer than `max_len`). Configurable via
19//!   `VarcharMode::Truncate` for SQL Server-style behavior.
20//! - `validate_decimal(value, precision, scale)` — verifies the
21//!   value fits in `precision` total digits with `scale` digits
22//!   after the decimal point.
23//! - `parse_varchar_modifier`, `parse_decimal_modifier` — pull
24//!   `(n)` / `(p, s)` out of the legacy `SqlTypeName` modifiers
25//!   so callers don't reinvent the parsing.
26
27use super::types::{SqlTypeName, TypeModifier, Value};
28
/// What `validate_varchar` does when text is longer than the declared length.
///
/// reddb defaults to the Postgres-strict `Reject`; `Truncate` gives the
/// SQL Server-style behavior of silently shortening the value instead.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarcharMode {
    /// Fail with `ParametricError::VarcharOverflow`.
    Reject,
    /// Keep only the leading `max_len` characters.
    Truncate,
}
36
/// Failure modes of the parametric (VARCHAR / DECIMAL) validators and
/// modifier parsers in this module.
#[derive(Clone, Debug)]
pub enum ParametricError {
    /// Text is longer than the declared `VARCHAR(max)`; `actual` is its
    /// length in characters.
    VarcharOverflow { actual: usize, max: u32 },
    /// The number needs more digits than the declared DECIMAL precision.
    DecimalPrecisionOverflow { precision: u8, actual_digits: usize },
    /// The number has more fractional digits than the declared scale,
    /// so storing it would require rounding.
    DecimalScaleOverflow {
        scale: u8,
        actual_fraction_digits: usize,
    },
    /// The input (carried verbatim) is not a decimal literal.
    NotADecimal(String),
    /// The `SqlTypeName` modifiers are not a valid `VARCHAR(n)` or
    /// `DECIMAL(p, s)` shape; the string explains why.
    BadModifier(String),
}

impl std::fmt::Display for ParametricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ParametricError::VarcharOverflow { actual, max } => {
                write!(f, "string of length {actual} exceeds VARCHAR({max})")
            }
            ParametricError::DecimalPrecisionOverflow {
                precision,
                actual_digits,
            } => write!(
                f,
                "decimal with {actual_digits} digits exceeds DECIMAL precision {precision}"
            ),
            ParametricError::DecimalScaleOverflow {
                scale,
                actual_fraction_digits,
            } => write!(
                f,
                "decimal with {actual_fraction_digits} fractional digits exceeds DECIMAL scale {scale}"
            ),
            ParametricError::NotADecimal(ref input) => {
                write!(f, "`{input}` is not a valid decimal literal")
            }
            ParametricError::BadModifier(ref reason) => {
                write!(f, "bad parametric modifier: {reason}")
            }
        }
    }
}

impl std::error::Error for ParametricError {}
88
89/// Validate a `Value::Text` against a declared VARCHAR length
90/// limit. Returns the value unchanged on success, or a coerced
91/// truncated copy when `mode == Truncate`.
92pub fn validate_varchar(
93    value: &Value,
94    max_len: u32,
95    mode: VarcharMode,
96) -> Result<Value, ParametricError> {
97    let s = match value {
98        Value::Text(s) => s,
99        // Non-text values are coerced to text via display first
100        // then re-checked. Caller can avoid the round-trip by
101        // pre-coercing.
102        other => {
103            return validate_varchar(&Value::text(other.display_string()), max_len, mode);
104        }
105    };
106    let len = s.chars().count();
107    if (len as u32) <= max_len {
108        return Ok(value.clone());
109    }
110    match mode {
111        VarcharMode::Reject => Err(ParametricError::VarcharOverflow {
112            actual: len,
113            max: max_len,
114        }),
115        VarcharMode::Truncate => {
116            let truncated: String = s.chars().take(max_len as usize).collect();
117            Ok(Value::text(truncated))
118        }
119    }
120}
121
122/// Validate a `Value::Decimal` against declared (precision, scale).
123/// `precision` is the maximum total digit count (both sides of the
124/// decimal point); `scale` is the maximum fractional digit count.
125///
126/// reddb's `Value::Decimal` stores a fixed-point i64 with implicit
127/// scale = 4 (4 digits of fraction). We round-trip via
128/// `display_string()` to count the actual digits, which is the
129/// most correct path until `DataType::Decimal { p, s }` lands and
130/// we can carry the scale on the value itself.
131pub fn validate_decimal(value: &Value, precision: u8, scale: u8) -> Result<Value, ParametricError> {
132    let s = value.display_string();
133    let trimmed = s.trim();
134    let body = trimmed.strip_prefix('-').unwrap_or(trimmed);
135    let (whole, frac) = match body.split_once('.') {
136        Some((w, f)) => (w, f),
137        None => (body, ""),
138    };
139    if whole.is_empty() && frac.is_empty() {
140        return Err(ParametricError::NotADecimal(s));
141    }
142    if !whole.bytes().all(|b| b.is_ascii_digit()) || !frac.bytes().all(|b| b.is_ascii_digit()) {
143        return Err(ParametricError::NotADecimal(s));
144    }
145    let total_digits = whole.len() + frac.len();
146    let frac_digits = frac.len();
147    if total_digits > precision as usize {
148        return Err(ParametricError::DecimalPrecisionOverflow {
149            precision,
150            actual_digits: total_digits,
151        });
152    }
153    if frac_digits > scale as usize {
154        return Err(ParametricError::DecimalScaleOverflow {
155            scale,
156            actual_fraction_digits: frac_digits,
157        });
158    }
159    Ok(value.clone())
160}
161
162/// Pull `(n)` out of `VARCHAR(n)`'s SqlTypeName modifiers. Returns
163/// `Err` if the modifier list is empty, has more than one element,
164/// or the single element isn't a `Number`.
165pub fn parse_varchar_modifier(sql_type: &SqlTypeName) -> Result<u32, ParametricError> {
166    if sql_type.modifiers.is_empty() {
167        // VARCHAR with no length is valid in Postgres (treated as
168        // unbounded text). Return a sentinel max so callers know
169        // not to enforce a length.
170        return Ok(u32::MAX);
171    }
172    if sql_type.modifiers.len() > 1 {
173        return Err(ParametricError::BadModifier(format!(
174            "VARCHAR expects 1 modifier, got {}",
175            sql_type.modifiers.len()
176        )));
177    }
178    match &sql_type.modifiers[0] {
179        TypeModifier::Number(n) => Ok(*n),
180        other => Err(ParametricError::BadModifier(format!(
181            "VARCHAR length must be a number, got {other:?}"
182        ))),
183    }
184}
185
186/// Pull `(p, s)` out of `DECIMAL(p, s)`'s SqlTypeName modifiers.
187/// Returns `(precision, scale)` on success.
188pub fn parse_decimal_modifier(sql_type: &SqlTypeName) -> Result<(u8, u8), ParametricError> {
189    let mods = &sql_type.modifiers;
190    if mods.is_empty() {
191        // DECIMAL with no params defaults to (38, 0) per SQL standard.
192        return Ok((38, 0));
193    }
194    if mods.len() > 2 {
195        return Err(ParametricError::BadModifier(format!(
196            "DECIMAL expects (p) or (p,s), got {} modifiers",
197            mods.len()
198        )));
199    }
200    let precision = match &mods[0] {
201        TypeModifier::Number(n) => u8::try_from(*n).map_err(|_| {
202            ParametricError::BadModifier(format!("DECIMAL precision {n} out of u8 range"))
203        })?,
204        other => {
205            return Err(ParametricError::BadModifier(format!(
206                "DECIMAL precision must be a number, got {other:?}"
207            )))
208        }
209    };
210    let scale = if let Some(s_mod) = mods.get(1) {
211        match s_mod {
212            TypeModifier::Number(n) => u8::try_from(*n).map_err(|_| {
213                ParametricError::BadModifier(format!("DECIMAL scale {n} out of u8 range"))
214            })?,
215            other => {
216                return Err(ParametricError::BadModifier(format!(
217                    "DECIMAL scale must be a number, got {other:?}"
218                )))
219            }
220        }
221    } else {
222        0
223    };
224    if scale > precision {
225        return Err(ParametricError::BadModifier(format!(
226            "DECIMAL scale {scale} cannot exceed precision {precision}"
227        )));
228    }
229    Ok((precision, scale))
230}