1use super::{RivetType, TimeUnit};
17use crate::error::Result;
18
19pub fn parse_type_str(s: &str) -> Result<RivetType> {
26 let normalised = s.trim().to_ascii_lowercase();
27 let normalised = normalised.as_str();
28
29 if let Some(inner) = normalised
30 .strip_prefix("decimal(")
31 .and_then(|r| r.strip_suffix(')'))
32 {
33 return parse_decimal_params(s, inner);
34 }
35 if let Some(inner) = normalised
36 .strip_prefix("numeric(")
37 .and_then(|r| r.strip_suffix(')'))
38 {
39 return parse_decimal_params(s, inner);
40 }
41
42 match normalised {
43 "bool" | "boolean" => Ok(RivetType::Bool),
44 "int2" | "smallint" | "int16" => Ok(RivetType::Int16),
45 "int4" | "int" | "integer" | "int32" => Ok(RivetType::Int32),
46 "int8" | "bigint" | "int64" => Ok(RivetType::Int64),
47 "float4" | "real" | "float32" => Ok(RivetType::Float32),
48 "float8" | "double" | "double precision" | "float64" => Ok(RivetType::Float64),
49 "text" | "varchar" | "string" | "char" | "bpchar" | "name" => Ok(RivetType::String),
50 "binary" | "bytea" | "blob" | "varbinary" => Ok(RivetType::Binary),
51 "date" => Ok(RivetType::Date),
52 "json" | "jsonb" => Ok(RivetType::Json),
53 "uuid" => Ok(RivetType::Uuid),
54 "timestamp" | "timestamp without time zone" => Ok(RivetType::Timestamp {
55 unit: TimeUnit::Microsecond,
56 timezone: None,
57 }),
58 "timestamp_ns" => Ok(RivetType::Timestamp {
66 unit: TimeUnit::Nanosecond,
67 timezone: None,
68 }),
69 "timestamp_tz" | "timestamptz" | "timestamp with time zone" | "timestamp_utc" => {
70 Ok(RivetType::Timestamp {
71 unit: TimeUnit::Microsecond,
72 timezone: Some("UTC".into()),
73 })
74 }
75 "timestamp_tz_ns" | "timestamptz_ns" => Ok(RivetType::Timestamp {
76 unit: TimeUnit::Nanosecond,
77 timezone: Some("UTC".into()),
78 }),
79 _ => anyhow::bail!(
80 "column override: unrecognised type '{}'. \
81 Supported: bool, int2/int4/int8, float4/float8, decimal(p,s), \
82 date, timestamp, timestamp_ns, timestamp_tz, timestamp_tz_ns, \
83 text, binary, json, uuid",
84 s
85 ),
86 }
87}
88
89fn parse_decimal_params(original: &str, inner: &str) -> Result<RivetType> {
90 let mut parts = inner.splitn(2, ',');
91 let precision_str = parts.next().ok_or_else(|| {
92 anyhow::anyhow!(
93 "column override: expected decimal(precision,scale) in '{}'",
94 original
95 )
96 })?;
97 let scale_str = parts.next().ok_or_else(|| {
98 anyhow::anyhow!(
99 "column override: missing scale in '{}' — use decimal(precision,scale)",
100 original
101 )
102 })?;
103
104 let precision: u8 = precision_str.trim().parse().map_err(|_| {
105 anyhow::anyhow!(
106 "column override: precision '{}' is not a valid integer (0–76) in '{}'",
107 precision_str.trim(),
108 original
109 )
110 })?;
111 let scale: i8 = scale_str.trim().parse().map_err(|_| {
112 anyhow::anyhow!(
113 "column override: scale '{}' is not a valid integer (-128..127) in '{}'",
114 scale_str.trim(),
115 original
116 )
117 })?;
118
119 if precision == 0 || precision > 76 {
120 anyhow::bail!(
121 "column override: precision {} is out of range (1..=76) in '{}'",
122 precision,
123 original
124 );
125 }
126 if scale > precision as i8 {
127 anyhow::bail!(
128 "column override: scale {} exceeds precision {} in '{}'",
129 scale,
130 precision,
131 original
132 );
133 }
134
135 Ok(RivetType::Decimal { precision, scale })
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the expected decimal variant.
    fn dec(precision: u8, scale: i8) -> RivetType {
        RivetType::Decimal { precision, scale }
    }

    #[test]
    fn decimal_parses_precision_and_scale() {
        assert_eq!(parse_type_str("decimal(18,2)").unwrap(), dec(18, 2));
        assert_eq!(parse_type_str("decimal(38,9)").unwrap(), dec(38, 9));
    }

    #[test]
    fn decimal_with_spaces_around_comma() {
        assert_eq!(parse_type_str("decimal(18, 2)").unwrap(), dec(18, 2));
    }

    #[test]
    fn decimal_allows_negative_scale() {
        assert_eq!(parse_type_str("decimal(5,-2)").unwrap(), dec(5, -2));
    }

    #[test]
    fn numeric_is_alias_for_decimal() {
        assert_eq!(
            parse_type_str("numeric(18,2)").unwrap(),
            parse_type_str("decimal(18,2)").unwrap()
        );
    }

    #[test]
    fn scale_exceeding_precision_is_rejected() {
        assert!(parse_type_str("decimal(2,5)").is_err());
    }

    #[test]
    fn scale_equal_to_precision_is_accepted() {
        assert_eq!(parse_type_str("decimal(5,5)").unwrap(), dec(5, 5));
    }

    #[test]
    fn precision_out_of_range_is_rejected() {
        assert!(parse_type_str("decimal(0,0)").is_err());
        assert!(parse_type_str("decimal(77,0)").is_err());
    }

    #[test]
    fn max_precision_is_accepted() {
        assert_eq!(parse_type_str("decimal(76,0)").unwrap(), dec(76, 0));
    }

    #[test]
    fn decimal_without_params_is_rejected() {
        assert!(parse_type_str("decimal").is_err());
        assert!(parse_type_str("numeric").is_err());
    }

    #[test]
    fn decimal_without_scale_is_rejected() {
        let msg = parse_type_str("decimal(18)").unwrap_err().to_string();
        assert!(
            msg.contains("missing scale"),
            "error should say the scale is missing: {msg}"
        );
    }

    #[test]
    fn decimal_with_malformed_params_is_rejected() {
        assert!(parse_type_str("decimal()").is_err());
        assert!(parse_type_str("decimal(a,2)").is_err());
        assert!(parse_type_str("decimal(18,b)").is_err());
        assert!(parse_type_str("decimal(18,2,3)").is_err());
    }

    #[test]
    fn scale_outside_i8_range_is_rejected() {
        assert!(parse_type_str("decimal(10,200)").is_err());
        assert!(parse_type_str("decimal(10,-200)").is_err());
    }

    #[test]
    fn timestamp_variants() {
        assert_eq!(
            parse_type_str("timestamp").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: None
            }
        );
        assert_eq!(
            parse_type_str("timestamp_tz").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
        assert_eq!(
            parse_type_str("timestamptz").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn timestamp_nanosecond_opt_in_variants() {
        assert_eq!(
            parse_type_str("timestamp_ns").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Nanosecond,
                timezone: None
            }
        );
        assert_eq!(
            parse_type_str("timestamp_tz_ns").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Nanosecond,
                timezone: Some("UTC".into())
            }
        );
        assert_eq!(
            parse_type_str("timestamptz_ns").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Nanosecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn primitive_types() {
        assert_eq!(parse_type_str("bool").unwrap(), RivetType::Bool);
        assert_eq!(parse_type_str("bigint").unwrap(), RivetType::Int64);
        assert_eq!(parse_type_str("json").unwrap(), RivetType::Json);
        assert_eq!(parse_type_str("uuid").unwrap(), RivetType::Uuid);
    }

    #[test]
    fn surrounding_whitespace_is_ignored() {
        assert_eq!(parse_type_str("  bool  ").unwrap(), RivetType::Bool);
        assert_eq!(parse_type_str(" decimal(18,2) ").unwrap(), dec(18, 2));
    }

    #[test]
    fn case_insensitive() {
        assert_eq!(parse_type_str("DECIMAL(18,2)").unwrap(), dec(18, 2));
        assert_eq!(parse_type_str("BOOL").unwrap(), RivetType::Bool);
        assert_eq!(
            parse_type_str("TIMESTAMP_TZ").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn unrecognised_type_returns_actionable_error() {
        let err = parse_type_str("geometry").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("geometry"),
            "error should name the bad type: {msg}"
        );
        assert!(
            msg.contains("decimal(p,s)"),
            "error should list alternatives: {msg}"
        );
    }
}