1use super::{RivetType, TimeUnit};
17use crate::error::Result;
18
19pub fn parse_type_str(s: &str) -> Result<RivetType> {
26 let normalised = s.trim().to_ascii_lowercase();
27 let normalised = normalised.as_str();
28
29 if let Some(inner) = normalised
30 .strip_prefix("decimal(")
31 .and_then(|r| r.strip_suffix(')'))
32 {
33 return parse_decimal_params(s, inner);
34 }
35 if let Some(inner) = normalised
36 .strip_prefix("numeric(")
37 .and_then(|r| r.strip_suffix(')'))
38 {
39 return parse_decimal_params(s, inner);
40 }
41
42 match normalised {
43 "bool" | "boolean" => Ok(RivetType::Bool),
44 "int2" | "smallint" | "int16" => Ok(RivetType::Int16),
45 "int4" | "int" | "integer" | "int32" => Ok(RivetType::Int32),
46 "int8" | "bigint" | "int64" => Ok(RivetType::Int64),
47 "float4" | "real" | "float32" => Ok(RivetType::Float32),
48 "float8" | "double" | "double precision" | "float64" => Ok(RivetType::Float64),
49 "text" | "varchar" | "string" | "char" | "bpchar" | "name" => Ok(RivetType::String),
50 "binary" | "bytea" | "blob" | "varbinary" => Ok(RivetType::Binary),
51 "date" => Ok(RivetType::Date),
52 "json" | "jsonb" => Ok(RivetType::Json),
53 "uuid" => Ok(RivetType::Uuid),
54 "timestamp" | "timestamp without time zone" => Ok(RivetType::Timestamp {
55 unit: TimeUnit::Microsecond,
56 timezone: None,
57 }),
58 "timestamp_tz" | "timestamptz" | "timestamp with time zone" | "timestamp_utc" => {
59 Ok(RivetType::Timestamp {
60 unit: TimeUnit::Microsecond,
61 timezone: Some("UTC".into()),
62 })
63 }
64 _ => anyhow::bail!(
65 "column override: unrecognised type '{}'. \
66 Supported: bool, int2/int4/int8, float4/float8, decimal(p,s), \
67 date, timestamp, timestamp_tz, text, binary, json, uuid",
68 s
69 ),
70 }
71}
72
73fn parse_decimal_params(original: &str, inner: &str) -> Result<RivetType> {
74 let mut parts = inner.splitn(2, ',');
75 let precision_str = parts.next().ok_or_else(|| {
76 anyhow::anyhow!(
77 "column override: expected decimal(precision,scale) in '{}'",
78 original
79 )
80 })?;
81 let scale_str = parts.next().ok_or_else(|| {
82 anyhow::anyhow!(
83 "column override: missing scale in '{}' — use decimal(precision,scale)",
84 original
85 )
86 })?;
87
88 let precision: u8 = precision_str.trim().parse().map_err(|_| {
89 anyhow::anyhow!(
90 "column override: precision '{}' is not a valid integer (0–76) in '{}'",
91 precision_str.trim(),
92 original
93 )
94 })?;
95 let scale: i8 = scale_str.trim().parse().map_err(|_| {
96 anyhow::anyhow!(
97 "column override: scale '{}' is not a valid integer (-128..127) in '{}'",
98 scale_str.trim(),
99 original
100 )
101 })?;
102
103 if precision == 0 || precision > 76 {
104 anyhow::bail!(
105 "column override: precision {} is out of range (1..=76) in '{}'",
106 precision,
107 original
108 );
109 }
110 if scale > precision as i8 {
111 anyhow::bail!(
112 "column override: scale {} exceeds precision {} in '{}'",
113 scale,
114 precision,
115 original
116 );
117 }
118
119 Ok(RivetType::Decimal { precision, scale })
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decimal_parses_precision_and_scale() {
        assert_eq!(
            parse_type_str("decimal(18,2)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
        assert_eq!(
            parse_type_str("decimal(38,9)").unwrap(),
            RivetType::Decimal {
                precision: 38,
                scale: 9
            }
        );
    }

    #[test]
    fn decimal_with_spaces_around_comma() {
        assert_eq!(
            parse_type_str("decimal(18, 2)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
    }

    #[test]
    fn decimal_allows_negative_scale() {
        assert_eq!(
            parse_type_str("decimal(5,-2)").unwrap(),
            RivetType::Decimal {
                precision: 5,
                scale: -2
            }
        );
    }

    #[test]
    fn decimal_accepts_inclusive_boundaries() {
        assert_eq!(
            parse_type_str("decimal(76,0)").unwrap(),
            RivetType::Decimal {
                precision: 76,
                scale: 0
            }
        );
        assert_eq!(
            parse_type_str("decimal(1,1)").unwrap(),
            RivetType::Decimal {
                precision: 1,
                scale: 1
            }
        );
        // Scale equal to precision is allowed; only strictly greater is rejected.
        assert_eq!(
            parse_type_str("decimal(18,18)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 18
            }
        );
        assert_eq!(
            parse_type_str("decimal(5,-128)").unwrap(),
            RivetType::Decimal {
                precision: 5,
                scale: -128
            }
        );
    }

    #[test]
    fn numeric_is_alias_for_decimal() {
        assert_eq!(
            parse_type_str("numeric(18,2)").unwrap(),
            parse_type_str("decimal(18,2)").unwrap()
        );
    }

    #[test]
    fn scale_exceeding_precision_is_rejected() {
        assert!(parse_type_str("decimal(2,5)").is_err());
    }

    #[test]
    fn precision_out_of_range_is_rejected() {
        assert!(parse_type_str("decimal(0,0)").is_err());
        assert!(parse_type_str("decimal(77,0)").is_err());
    }

    #[test]
    fn decimal_params_wider_than_target_integers_are_rejected() {
        assert!(parse_type_str("decimal(300,0)").is_err());
        assert!(parse_type_str("decimal(18,200)").is_err());
        assert!(parse_type_str("decimal(18,-200)").is_err());
    }

    #[test]
    fn decimal_without_params_is_rejected() {
        assert!(parse_type_str("decimal").is_err());
        assert!(parse_type_str("numeric").is_err());
    }

    #[test]
    fn decimal_missing_scale_is_rejected() {
        assert!(parse_type_str("decimal(18)").is_err());
        assert!(parse_type_str("decimal()").is_err());
        assert!(parse_type_str("numeric(10)").is_err());
    }

    #[test]
    fn decimal_non_integer_params_are_rejected() {
        assert!(parse_type_str("decimal(abc,2)").is_err());
        assert!(parse_type_str("decimal(18,x)").is_err());
        assert!(parse_type_str("decimal(18.5,2)").is_err());
        assert!(parse_type_str("decimal(,2)").is_err());
        assert!(parse_type_str("decimal(18,)").is_err());
    }

    #[test]
    fn decimal_malformed_syntax_is_rejected() {
        assert!(parse_type_str("decimal(18,2,3)").is_err());
        assert!(parse_type_str("decimal(18,2").is_err());
    }

    #[test]
    fn decimal_errors_quote_original_input() {
        let msg = parse_type_str("decimal(2,5)").unwrap_err().to_string();
        assert!(msg.contains("decimal(2,5)"), "error should quote input: {msg}");

        // The user's original casing is preserved in the message.
        let msg = parse_type_str("DECIMAL(0,0)").unwrap_err().to_string();
        assert!(msg.contains("DECIMAL(0,0)"), "error should quote input: {msg}");
    }

    #[test]
    fn timestamp_variants() {
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let utc = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: Some("UTC".into()),
        };
        assert_eq!(parse_type_str("timestamp").unwrap(), naive);
        assert_eq!(
            parse_type_str("timestamp without time zone").unwrap(),
            naive
        );
        assert_eq!(parse_type_str("timestamp_tz").unwrap(), utc);
        assert_eq!(parse_type_str("timestamptz").unwrap(), utc);
        assert_eq!(parse_type_str("timestamp with time zone").unwrap(), utc);
        assert_eq!(parse_type_str("timestamp_utc").unwrap(), utc);
    }

    #[test]
    fn primitive_types() {
        assert_eq!(parse_type_str("bool").unwrap(), RivetType::Bool);
        assert_eq!(parse_type_str("bigint").unwrap(), RivetType::Int64);
        assert_eq!(parse_type_str("json").unwrap(), RivetType::Json);
        assert_eq!(parse_type_str("uuid").unwrap(), RivetType::Uuid);
    }

    #[test]
    fn type_aliases() {
        assert_eq!(parse_type_str("boolean").unwrap(), RivetType::Bool);
        assert_eq!(parse_type_str("int2").unwrap(), RivetType::Int16);
        assert_eq!(parse_type_str("smallint").unwrap(), RivetType::Int16);
        assert_eq!(parse_type_str("int4").unwrap(), RivetType::Int32);
        assert_eq!(parse_type_str("integer").unwrap(), RivetType::Int32);
        assert_eq!(parse_type_str("int8").unwrap(), RivetType::Int64);
        assert_eq!(parse_type_str("real").unwrap(), RivetType::Float32);
        assert_eq!(parse_type_str("float8").unwrap(), RivetType::Float64);
        assert_eq!(
            parse_type_str("double precision").unwrap(),
            RivetType::Float64
        );
        assert_eq!(parse_type_str("varchar").unwrap(), RivetType::String);
        assert_eq!(parse_type_str("bpchar").unwrap(), RivetType::String);
        assert_eq!(parse_type_str("bytea").unwrap(), RivetType::Binary);
        assert_eq!(parse_type_str("blob").unwrap(), RivetType::Binary);
        assert_eq!(parse_type_str("jsonb").unwrap(), RivetType::Json);
        assert_eq!(parse_type_str("date").unwrap(), RivetType::Date);
    }

    #[test]
    fn surrounding_whitespace_is_ignored() {
        assert_eq!(parse_type_str("  bool  ").unwrap(), RivetType::Bool);
        assert_eq!(
            parse_type_str(" decimal(18,2) ").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
    }

    #[test]
    fn case_insensitive() {
        assert_eq!(
            parse_type_str("DECIMAL(18,2)").unwrap(),
            RivetType::Decimal {
                precision: 18,
                scale: 2
            }
        );
        assert_eq!(parse_type_str("BOOL").unwrap(), RivetType::Bool);
        assert_eq!(
            parse_type_str("TIMESTAMP_TZ").unwrap(),
            RivetType::Timestamp {
                unit: TimeUnit::Microsecond,
                timezone: Some("UTC".into())
            }
        );
    }

    #[test]
    fn unrecognised_type_returns_actionable_error() {
        let err = parse_type_str("geometry").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("geometry"),
            "error should name the bad type: {msg}"
        );
        assert!(
            msg.contains("decimal(p,s)"),
            "error should list alternatives: {msg}"
        );
    }
}