Skip to main content

rivet/types/
override_type.rs

1//! Parser for column type override strings (roadmap §8).
2//!
3//! Users write short type strings in `rivet.yaml`:
4//! ```yaml
5//! exports:
6//!   - name: payments
7//!     columns:
8//!       amount: decimal(18,2)
9//!       created_at: timestamp_tz
10//!       payload: json
11//! ```
12//! [`parse_type_str`] converts each string to the canonical [`RivetType`],
13//! failing at config-load time (before any export runs) if the string is
14//! invalid.
15
16use super::{RivetType, TimeUnit};
17use crate::error::Result;
18
19/// Parse a user-supplied column type string into a [`RivetType`].
20///
21/// Case-insensitive. Whitespace around delimiters is trimmed.
22/// Returns an error with an actionable message if the string is not
23/// a recognised type — the error is surfaced before the export starts
24/// so the user can fix their `rivet.yaml`.
25pub fn parse_type_str(s: &str) -> Result<RivetType> {
26    let normalised = s.trim().to_ascii_lowercase();
27    let normalised = normalised.as_str();
28
29    if let Some(inner) = normalised
30        .strip_prefix("decimal(")
31        .and_then(|r| r.strip_suffix(')'))
32    {
33        return parse_decimal_params(s, inner);
34    }
35    if let Some(inner) = normalised
36        .strip_prefix("numeric(")
37        .and_then(|r| r.strip_suffix(')'))
38    {
39        return parse_decimal_params(s, inner);
40    }
41
42    match normalised {
43        "bool" | "boolean" => Ok(RivetType::Bool),
44        "int2" | "smallint" | "int16" => Ok(RivetType::Int16),
45        "int4" | "int" | "integer" | "int32" => Ok(RivetType::Int32),
46        "int8" | "bigint" | "int64" => Ok(RivetType::Int64),
47        "float4" | "real" | "float32" => Ok(RivetType::Float32),
48        "float8" | "double" | "double precision" | "float64" => Ok(RivetType::Float64),
49        "text" | "varchar" | "string" | "char" | "bpchar" | "name" => Ok(RivetType::String),
50        "binary" | "bytea" | "blob" | "varbinary" => Ok(RivetType::Binary),
51        "date" => Ok(RivetType::Date),
52        "json" | "jsonb" => Ok(RivetType::Json),
53        "uuid" => Ok(RivetType::Uuid),
54        "timestamp" | "timestamp without time zone" => Ok(RivetType::Timestamp {
55            unit: TimeUnit::Microsecond,
56            timezone: None,
57        }),
58        // Nanosecond opt-in: preserves a source's sub-microsecond fractional
59        // seconds (e.g. SQL Server `datetime2(7)`'s 100 ns tick) that the default
60        // microsecond mapping truncates. Arrow nanosecond timestamps are i64 ns,
61        // so the representable range is 1677-09-21 .. 2262-04-11 — values outside
62        // it export as NULL. Default to `timestamp` (microsecond, full range)
63        // unless you specifically need the extra precision and your data is in
64        // range. See docs/type-mapping.md.
65        "timestamp_ns" => Ok(RivetType::Timestamp {
66            unit: TimeUnit::Nanosecond,
67            timezone: None,
68        }),
69        "timestamp_tz" | "timestamptz" | "timestamp with time zone" | "timestamp_utc" => {
70            Ok(RivetType::Timestamp {
71                unit: TimeUnit::Microsecond,
72                timezone: Some("UTC".into()),
73            })
74        }
75        "timestamp_tz_ns" | "timestamptz_ns" => Ok(RivetType::Timestamp {
76            unit: TimeUnit::Nanosecond,
77            timezone: Some("UTC".into()),
78        }),
79        _ => anyhow::bail!(
80            "column override: unrecognised type '{}'. \
81             Supported: bool, int2/int4/int8, float4/float8, decimal(p,s), \
82             date, timestamp, timestamp_ns, timestamp_tz, timestamp_tz_ns, \
83             text, binary, json, uuid",
84            s
85        ),
86    }
87}
88
89fn parse_decimal_params(original: &str, inner: &str) -> Result<RivetType> {
90    let mut parts = inner.splitn(2, ',');
91    let precision_str = parts.next().ok_or_else(|| {
92        anyhow::anyhow!(
93            "column override: expected decimal(precision,scale) in '{}'",
94            original
95        )
96    })?;
97    let scale_str = parts.next().ok_or_else(|| {
98        anyhow::anyhow!(
99            "column override: missing scale in '{}' — use decimal(precision,scale)",
100            original
101        )
102    })?;
103
104    let precision: u8 = precision_str.trim().parse().map_err(|_| {
105        anyhow::anyhow!(
106            "column override: precision '{}' is not a valid integer (0–76) in '{}'",
107            precision_str.trim(),
108            original
109        )
110    })?;
111    let scale: i8 = scale_str.trim().parse().map_err(|_| {
112        anyhow::anyhow!(
113            "column override: scale '{}' is not a valid integer (-128..127) in '{}'",
114            scale_str.trim(),
115            original
116        )
117    })?;
118
119    if precision == 0 || precision > 76 {
120        anyhow::bail!(
121            "column override: precision {} is out of range (1..=76) in '{}'",
122            precision,
123            original
124        );
125    }
126    if scale > precision as i8 {
127        anyhow::bail!(
128            "column override: scale {} exceeds precision {} in '{}'",
129            scale,
130            precision,
131            original
132        );
133    }
134
135    Ok(RivetType::Decimal { precision, scale })
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, panicking (at the caller's location) if it is rejected.
    #[track_caller]
    fn parsed(input: &str) -> RivetType {
        parse_type_str(input).unwrap()
    }

    /// Expected value for a decimal override.
    fn decimal(precision: u8, scale: i8) -> RivetType {
        RivetType::Decimal { precision, scale }
    }

    /// Expected value for a timestamp override; `utc` selects the `"UTC"` tag.
    fn timestamp(unit: TimeUnit, utc: bool) -> RivetType {
        RivetType::Timestamp {
            unit,
            timezone: if utc { Some("UTC".into()) } else { None },
        }
    }

    #[test]
    fn decimal_parses_precision_and_scale() {
        assert_eq!(parsed("decimal(18,2)"), decimal(18, 2));
        assert_eq!(parsed("decimal(38,9)"), decimal(38, 9));
    }

    #[test]
    fn decimal_with_spaces_around_comma() {
        assert_eq!(parsed("decimal(18, 2)"), decimal(18, 2));
    }

    #[test]
    fn decimal_allows_negative_scale() {
        assert_eq!(parsed("decimal(5,-2)"), decimal(5, -2));
    }

    #[test]
    fn numeric_is_alias_for_decimal() {
        assert_eq!(parsed("numeric(18,2)"), parsed("decimal(18,2)"));
    }

    #[test]
    fn scale_exceeding_precision_is_rejected() {
        assert!(parse_type_str("decimal(2,5)").is_err());
    }

    #[test]
    fn precision_out_of_range_is_rejected() {
        for input in ["decimal(0,0)", "decimal(77,0)"] {
            assert!(parse_type_str(input).is_err(), "{input} should be rejected");
        }
    }

    #[test]
    fn decimal_without_params_is_rejected() {
        // A bare name without (p,s) must fail: unbounded numeric is not safe.
        for input in ["decimal", "numeric"] {
            assert!(parse_type_str(input).is_err(), "{input} should be rejected");
        }
    }

    #[test]
    fn timestamp_variants() {
        assert_eq!(parsed("timestamp"), timestamp(TimeUnit::Microsecond, false));
        assert_eq!(parsed("timestamp_tz"), timestamp(TimeUnit::Microsecond, true));
        assert_eq!(parsed("timestamptz"), timestamp(TimeUnit::Microsecond, true));
    }

    #[test]
    fn timestamp_nanosecond_opt_in_variants() {
        // Nanoseconds are opt-in for sub-microsecond sources (e.g. SQL Server
        // datetime2(7)); plain `timestamp` stays microsecond.
        assert_eq!(parsed("timestamp_ns"), timestamp(TimeUnit::Nanosecond, false));
        assert_eq!(parsed("timestamp_tz_ns"), timestamp(TimeUnit::Nanosecond, true));
        assert_eq!(parsed("timestamptz_ns"), timestamp(TimeUnit::Nanosecond, true));
    }

    #[test]
    fn primitive_types() {
        assert_eq!(parsed("bool"), RivetType::Bool);
        assert_eq!(parsed("bigint"), RivetType::Int64);
        assert_eq!(parsed("json"), RivetType::Json);
        assert_eq!(parsed("uuid"), RivetType::Uuid);
    }

    #[test]
    fn case_insensitive() {
        assert_eq!(parsed("DECIMAL(18,2)"), decimal(18, 2));
        assert_eq!(parsed("BOOL"), RivetType::Bool);
        assert_eq!(parsed("TIMESTAMP_TZ"), timestamp(TimeUnit::Microsecond, true));
    }

    #[test]
    fn unrecognised_type_returns_actionable_error() {
        let msg = parse_type_str("geometry").unwrap_err().to_string();
        assert!(
            msg.contains("geometry"),
            "error should name the bad type: {msg}"
        );
        assert!(
            msg.contains("decimal(p,s)"),
            "error should list alternatives: {msg}"
        );
    }
}