Skip to main content

rivet/types/
override_type.rs

//! Parser for column type override strings (roadmap §8).
//!
//! Users write short type strings in `rivet.yaml`:
//! ```yaml
//! exports:
//!   - name: payments
//!     columns:
//!       amount: decimal(18,2)
//!       created_at: timestamp_tz
//!       payload: json
//! ```
//!
//! [`parse_type_str`] converts each string to the canonical [`RivetType`],
//! failing at config-load time (before any export runs) if the string is
//! invalid.
15
16use super::{RivetType, TimeUnit};
17use crate::error::Result;
18
19/// Parse a user-supplied column type string into a [`RivetType`].
20///
21/// Case-insensitive. Whitespace around delimiters is trimmed.
22/// Returns an error with an actionable message if the string is not
23/// a recognised type — the error is surfaced before the export starts
24/// so the user can fix their `rivet.yaml`.
25pub fn parse_type_str(s: &str) -> Result<RivetType> {
26    let normalised = s.trim().to_ascii_lowercase();
27    let normalised = normalised.as_str();
28
29    if let Some(inner) = normalised
30        .strip_prefix("decimal(")
31        .and_then(|r| r.strip_suffix(')'))
32    {
33        return parse_decimal_params(s, inner);
34    }
35    if let Some(inner) = normalised
36        .strip_prefix("numeric(")
37        .and_then(|r| r.strip_suffix(')'))
38    {
39        return parse_decimal_params(s, inner);
40    }
41
42    match normalised {
43        "bool" | "boolean" => Ok(RivetType::Bool),
44        "int2" | "smallint" | "int16" => Ok(RivetType::Int16),
45        "int4" | "int" | "integer" | "int32" => Ok(RivetType::Int32),
46        "int8" | "bigint" | "int64" => Ok(RivetType::Int64),
47        "float4" | "real" | "float32" => Ok(RivetType::Float32),
48        "float8" | "double" | "double precision" | "float64" => Ok(RivetType::Float64),
49        "text" | "varchar" | "string" | "char" | "bpchar" | "name" => Ok(RivetType::String),
50        "binary" | "bytea" | "blob" | "varbinary" => Ok(RivetType::Binary),
51        "date" => Ok(RivetType::Date),
52        "json" | "jsonb" => Ok(RivetType::Json),
53        "uuid" => Ok(RivetType::Uuid),
54        "timestamp" | "timestamp without time zone" => Ok(RivetType::Timestamp {
55            unit: TimeUnit::Microsecond,
56            timezone: None,
57        }),
58        "timestamp_tz" | "timestamptz" | "timestamp with time zone" | "timestamp_utc" => {
59            Ok(RivetType::Timestamp {
60                unit: TimeUnit::Microsecond,
61                timezone: Some("UTC".into()),
62            })
63        }
64        _ => anyhow::bail!(
65            "column override: unrecognised type '{}'. \
66             Supported: bool, int2/int4/int8, float4/float8, decimal(p,s), \
67             date, timestamp, timestamp_tz, text, binary, json, uuid",
68            s
69        ),
70    }
71}
72
73fn parse_decimal_params(original: &str, inner: &str) -> Result<RivetType> {
74    let mut parts = inner.splitn(2, ',');
75    let precision_str = parts.next().ok_or_else(|| {
76        anyhow::anyhow!(
77            "column override: expected decimal(precision,scale) in '{}'",
78            original
79        )
80    })?;
81    let scale_str = parts.next().ok_or_else(|| {
82        anyhow::anyhow!(
83            "column override: missing scale in '{}' — use decimal(precision,scale)",
84            original
85        )
86    })?;
87
88    let precision: u8 = precision_str.trim().parse().map_err(|_| {
89        anyhow::anyhow!(
90            "column override: precision '{}' is not a valid integer (0–76) in '{}'",
91            precision_str.trim(),
92            original
93        )
94    })?;
95    let scale: i8 = scale_str.trim().parse().map_err(|_| {
96        anyhow::anyhow!(
97            "column override: scale '{}' is not a valid integer (-128..127) in '{}'",
98            scale_str.trim(),
99            original
100        )
101    })?;
102
103    if precision == 0 || precision > 76 {
104        anyhow::bail!(
105            "column override: precision {} is out of range (1..=76) in '{}'",
106            precision,
107            original
108        );
109    }
110    if scale > precision as i8 {
111        anyhow::bail!(
112            "column override: scale {} exceeds precision {} in '{}'",
113            scale,
114            precision,
115            original
116        );
117    }
118
119    Ok(RivetType::Decimal { precision, scale })
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decimal_parses_precision_and_scale() {
        assert_eq!(
            parse_type_str("decimal(18,2)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
        assert_eq!(
            parse_type_str("decimal(38,9)").unwrap(),
            RivetType::Decimal {
                precision: 38,
                scale: 9
            }
        );
    }

    #[test]
    fn decimal_with_spaces_around_comma() {
        assert_eq!(
            parse_type_str("decimal(18, 2)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
    }

    #[test]
    fn decimal_allows_negative_scale() {
        assert_eq!(
            parse_type_str("decimal(5,-2)").unwrap(),
            RivetType::Decimal {
                precision: 5,
                scale: -2
            }
        );
    }

    #[test]
    fn decimal_boundaries_are_accepted() {
        assert_eq!(
            parse_type_str("decimal(1,0)").unwrap(),
            RivetType::Decimal {
                precision: 1,
                scale: 0
            }
        );
        assert_eq!(
            parse_type_str("decimal(76,0)").unwrap(),
            RivetType::Decimal {
                precision: 76,
                scale: 0
            }
        );
        // scale == precision is the largest allowed scale
        assert_eq!(
            parse_type_str("decimal(5,5)").unwrap(),
            RivetType::Decimal {
                precision: 5,
                scale: 5
            }
        );
    }

    #[test]
    fn numeric_is_alias_for_decimal() {
        assert_eq!(
            parse_type_str("numeric(18,2)").unwrap(),
            parse_type_str("decimal(18,2)").unwrap()
        );
    }

    #[test]
    fn scale_exceeding_precision_is_rejected() {
        assert!(parse_type_str("decimal(2,5)").is_err());
    }

    #[test]
    fn precision_out_of_range_is_rejected() {
        assert!(parse_type_str("decimal(0,0)").is_err());
        assert!(parse_type_str("decimal(77,0)").is_err());
    }

    #[test]
    fn decimal_without_params_is_rejected() {
        // bare "decimal" without (p,s) must fail — unbounded numeric is not safe
        assert!(parse_type_str("decimal").is_err());
        assert!(parse_type_str("numeric").is_err());
    }

    #[test]
    fn malformed_decimal_params_are_rejected() {
        assert!(parse_type_str("decimal()").is_err());
        assert!(parse_type_str("decimal(18)").is_err());
        assert!(parse_type_str("decimal(18,2,3)").is_err());
        assert!(parse_type_str("decimal(18,2").is_err());
        assert!(parse_type_str("decimal(abc,2)").is_err());
        assert!(parse_type_str("decimal(18,x)").is_err());
    }

    #[test]
    fn timestamp_variants() {
        assert_eq!(
            parse_type_str("timestamp").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: None
            }
        );
        assert_eq!(
            parse_type_str("timestamp_tz").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
        assert_eq!(
            parse_type_str("timestamptz").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn multi_word_aliases() {
        assert_eq!(
            parse_type_str("double precision").unwrap(),
            RivetType::Float64
        );
        assert_eq!(
            parse_type_str("timestamp without time zone").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: None
            }
        );
        assert_eq!(
            parse_type_str("Timestamp With Time Zone").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn primitive_types() {
        assert_eq!(parse_type_str("bool").unwrap(), RivetType::Bool);
        assert_eq!(parse_type_str("bigint").unwrap(), RivetType::Int64);
        assert_eq!(parse_type_str("json").unwrap(), RivetType::Json);
        assert_eq!(parse_type_str("uuid").unwrap(), RivetType::Uuid);
    }

    #[test]
    fn surrounding_whitespace_is_ignored() {
        assert_eq!(parse_type_str("  bool\t").unwrap(), RivetType::Bool);
        assert_eq!(parse_type_str("\nuuid ").unwrap(), RivetType::Uuid);
    }

    #[test]
    fn case_insensitive() {
        assert_eq!(
            parse_type_str("DECIMAL(18,2)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
        assert_eq!(parse_type_str("BOOL").unwrap(), RivetType::Bool);
        assert_eq!(
            parse_type_str("TIMESTAMP_TZ").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn unrecognised_type_returns_actionable_error() {
        let err = parse_type_str("geometry").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("geometry"),
            "error should name the bad type: {msg}"
        );
        assert!(
            msg.contains("decimal(p,s)"),
            "error should list alternatives: {msg}"
        );
    }
}