1use std::fmt;
4use std::str::FromStr;
5
6use serde::{Deserialize, Serialize};
7
8use crate::value::Value;
9
/// Logical type of a table column.
///
/// Serialized with serde as an adjacently tagged enum (`tag = "type"`,
/// `content = "params"`) and also derives zerompk MessagePack encoding.
/// The canonical textual form is produced by `Display` and parsed back by
/// `FromStr`.
///
/// NOTE(review): variant declaration order is probably part of the binary
/// encoding (serde variant indices / zerompk) — append new variants at the end
/// and never reorder existing ones.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[serde(tag = "type", content = "params")]
pub enum ColumnType {
    /// `BIGINT`; `fixed_size` reports 8 bytes.
    Int64,
    /// `FLOAT64`; `fixed_size` reports 8 bytes. Also accepts integer values.
    Float64,
    /// `TEXT`; variable length.
    String,
    /// `BOOL`; `fixed_size` reports 1 byte.
    Bool,
    /// `BYTES`; variable length.
    Bytes,
    /// `TIMESTAMP`; `fixed_size` reports 8 bytes. Also accepts integer and
    /// string values as coercion sources.
    Timestamp,
    /// `DECIMAL`; `fixed_size` reports 16 bytes.
    Decimal,
    /// `GEOMETRY`; variable length. Also accepts string values.
    Geometry,
    /// `VECTOR(dim)`; the `u32` is the dimension and `fixed_size` reports
    /// `dim * 4` bytes.
    Vector(u32),
    /// `UUID`; `fixed_size` reports 16 bytes.
    Uuid,
    /// `JSON`; variable length. `accepts` allows any value for this type.
    Json,
    /// `ULID`; `fixed_size` reports 16 bytes.
    Ulid,
    /// `DURATION`; `fixed_size` reports 8 bytes.
    Duration,
    /// `ARRAY`; variable length.
    Array,
    /// `SET`; variable length. Also accepts array values.
    Set,
    /// `REGEX`; variable length.
    Regex,
    /// `RANGE`; variable length.
    Range,
    /// `RECORD`; variable length.
    Record,
}
53
54impl ColumnType {
55 pub fn fixed_size(&self) -> Option<usize> {
57 match self {
58 Self::Int64 | Self::Float64 | Self::Timestamp | Self::Duration => Some(8),
59 Self::Bool => Some(1),
60 Self::Decimal | Self::Uuid | Self::Ulid => Some(16),
61 Self::Vector(dim) => Some(*dim as usize * 4),
62 Self::String
63 | Self::Bytes
64 | Self::Geometry
65 | Self::Json
66 | Self::Array
67 | Self::Set
68 | Self::Regex
69 | Self::Range
70 | Self::Record => None,
71 }
72 }
73
74 pub fn is_variable_length(&self) -> bool {
76 self.fixed_size().is_none()
77 }
78
79 pub fn accepts(&self, value: &Value) -> bool {
85 matches!(
86 (self, value),
87 (Self::Int64, Value::Integer(_))
88 | (Self::Float64, Value::Float(_) | Value::Integer(_))
89 | (Self::String, Value::String(_))
90 | (Self::Bool, Value::Bool(_))
91 | (Self::Bytes, Value::Bytes(_))
92 | (
93 Self::Timestamp,
94 Value::DateTime(_) | Value::Integer(_) | Value::String(_)
95 )
96 | (
97 Self::Decimal,
98 Value::Decimal(_) | Value::String(_) | Value::Float(_) | Value::Integer(_)
99 )
100 | (Self::Geometry, Value::Geometry(_) | Value::String(_))
101 | (Self::Vector(_), Value::Array(_) | Value::Bytes(_))
102 | (Self::Uuid, Value::Uuid(_) | Value::String(_))
103 | (Self::Ulid, Value::Ulid(_) | Value::String(_))
104 | (
105 Self::Duration,
106 Value::Duration(_) | Value::Integer(_) | Value::String(_)
107 )
108 | (Self::Array, Value::Array(_))
109 | (Self::Set, Value::Set(_) | Value::Array(_))
110 | (Self::Regex, Value::Regex(_) | Value::String(_))
111 | (Self::Range, Value::Range { .. })
112 | (Self::Record, Value::Record { .. } | Value::String(_))
113 | (Self::Json, _)
114 | (_, Value::Null)
115 )
116 }
117}
118
119impl fmt::Display for ColumnType {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 match self {
122 Self::Int64 => f.write_str("BIGINT"),
123 Self::Float64 => f.write_str("FLOAT64"),
124 Self::String => f.write_str("TEXT"),
125 Self::Bool => f.write_str("BOOL"),
126 Self::Bytes => f.write_str("BYTES"),
127 Self::Timestamp => f.write_str("TIMESTAMP"),
128 Self::Decimal => f.write_str("DECIMAL"),
129 Self::Geometry => f.write_str("GEOMETRY"),
130 Self::Vector(dim) => write!(f, "VECTOR({dim})"),
131 Self::Uuid => f.write_str("UUID"),
132 Self::Json => f.write_str("JSON"),
133 Self::Ulid => f.write_str("ULID"),
134 Self::Duration => f.write_str("DURATION"),
135 Self::Array => f.write_str("ARRAY"),
136 Self::Set => f.write_str("SET"),
137 Self::Regex => f.write_str("REGEX"),
138 Self::Range => f.write_str("RANGE"),
139 Self::Record => f.write_str("RECORD"),
140 }
141 }
142}
143
144impl FromStr for ColumnType {
145 type Err = ColumnTypeParseError;
146
147 fn from_str(s: &str) -> Result<Self, Self::Err> {
148 let upper = s.trim().to_uppercase();
149
150 if upper.starts_with("VECTOR") {
152 let inner = upper
153 .trim_start_matches("VECTOR")
154 .trim()
155 .trim_start_matches('(')
156 .trim_end_matches(')')
157 .trim();
158 if inner.is_empty() {
159 return Err(ColumnTypeParseError::InvalidVectorDim("empty".into()));
160 }
161 let dim: u32 = inner
162 .parse()
163 .map_err(|_| ColumnTypeParseError::InvalidVectorDim(inner.into()))?;
164 if dim == 0 {
165 return Err(ColumnTypeParseError::InvalidVectorDim("0".into()));
166 }
167 return Ok(Self::Vector(dim));
168 }
169
170 match upper.as_str() {
171 "BIGINT" | "INT64" | "INTEGER" | "INT" => Ok(Self::Int64),
172 "FLOAT64" | "DOUBLE" | "REAL" | "FLOAT" => Ok(Self::Float64),
173 "TEXT" | "STRING" | "VARCHAR" => Ok(Self::String),
174 "BOOL" | "BOOLEAN" => Ok(Self::Bool),
175 "BYTES" | "BYTEA" | "BLOB" => Ok(Self::Bytes),
176 "TIMESTAMP" | "TIMESTAMPTZ" => Ok(Self::Timestamp),
177 "DECIMAL" | "NUMERIC" => Ok(Self::Decimal),
178 "GEOMETRY" => Ok(Self::Geometry),
179 "UUID" => Ok(Self::Uuid),
180 "JSON" | "JSONB" => Ok(Self::Json),
181 "ULID" => Ok(Self::Ulid),
182 "DURATION" => Ok(Self::Duration),
183 "ARRAY" => Ok(Self::Array),
184 "SET" => Ok(Self::Set),
185 "REGEX" => Ok(Self::Regex),
186 "RANGE" => Ok(Self::Range),
187 "RECORD" => Ok(Self::Record),
188 "DATETIME" => Err(ColumnTypeParseError::UseTimestamp),
189 other => Err(ColumnTypeParseError::Unknown(other.to_string())),
190 }
191 }
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
196pub enum ColumnTypeParseError {
197 #[error("unknown column type: '{0}'")]
198 Unknown(String),
199 #[error("'DATETIME' is not a valid type — use 'TIMESTAMP' instead")]
200 UseTimestamp,
201 #[error("invalid VECTOR dimension: '{0}' (must be a positive integer)")]
202 InvalidVectorDim(String),
203}
204
/// Optional per-column modifier flags stored in [`ColumnDef::modifiers`].
///
/// NOTE(review): `#[msgpack(c_enum)]` together with `#[repr(u8)]` and explicit
/// discriminants suggests the MessagePack encoding uses these numeric values,
/// and serde's derived encoding may use declaration order — never renumber
/// or reorder existing variants.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[msgpack(c_enum)]
#[repr(u8)]
pub enum ColumnModifier {
    /// Marks the column as the table's time key (queried by
    /// `ColumnDef::is_time_key`).
    TimeKey = 0,
    /// Marks the column as spatially indexed (queried by
    /// `ColumnDef::is_spatial_index`).
    SpatialIndex = 1,
}
230
/// Definition of a single table column: name, type, constraints and modifiers.
///
/// NOTE(review): field order is probably significant for the derived
/// MessagePack / non-self-describing serde encodings — do not reorder fields.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Logical type of the column.
    pub column_type: ColumnType,
    /// Whether the column may hold NULL; forced to `false` by
    /// `with_primary_key`.
    pub nullable: bool,
    /// Default value expression kept as raw text; rendered verbatim after
    /// `DEFAULT` by `Display`.
    pub default: Option<String>,
    /// Whether this column is (part of) the primary key.
    pub primary_key: bool,
    // `default` lets payloads without this field deserialize to an empty list;
    // empty lists are omitted when serializing.
    /// Modifier flags such as time-key or spatial-index.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub modifiers: Vec<ColumnModifier>,
    /// Presumably the expression text for a generated column — TODO confirm
    /// against the code that populates it. Omitted from serde output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generated_expr: Option<String>,
    /// Presumably the names of columns `generated_expr` depends on — TODO
    /// confirm. Omitted from serde output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub generated_deps: Vec<String>,
}
259
260impl ColumnDef {
261 pub fn required(name: impl Into<String>, column_type: ColumnType) -> Self {
262 Self {
263 name: name.into(),
264 column_type,
265 nullable: false,
266 default: None,
267 primary_key: false,
268 modifiers: Vec::new(),
269 generated_expr: None,
270 generated_deps: Vec::new(),
271 }
272 }
273
274 pub fn nullable(name: impl Into<String>, column_type: ColumnType) -> Self {
275 Self {
276 name: name.into(),
277 column_type,
278 nullable: true,
279 default: None,
280 primary_key: false,
281 modifiers: Vec::new(),
282 generated_expr: None,
283 generated_deps: Vec::new(),
284 }
285 }
286
287 pub fn with_primary_key(mut self) -> Self {
288 self.primary_key = true;
289 self.nullable = false;
290 self
291 }
292
293 pub fn is_time_key(&self) -> bool {
295 self.modifiers.contains(&ColumnModifier::TimeKey)
296 }
297
298 pub fn is_spatial_index(&self) -> bool {
300 self.modifiers.contains(&ColumnModifier::SpatialIndex)
301 }
302
303 pub fn with_default(mut self, expr: impl Into<String>) -> Self {
304 self.default = Some(expr.into());
305 self
306 }
307}
308
309impl fmt::Display for ColumnDef {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 write!(f, "{} {}", self.name, self.column_type)?;
312 if !self.nullable {
313 write!(f, " NOT NULL")?;
314 }
315 if self.primary_key {
316 write!(f, " PRIMARY KEY")?;
317 }
318 if let Some(ref d) = self.default {
319 write!(f, " DEFAULT {d}")?;
320 }
321 Ok(())
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_canonical() {
        assert_eq!("BIGINT".parse::<ColumnType>().unwrap(), ColumnType::Int64);
        assert_eq!(
            "FLOAT64".parse::<ColumnType>().unwrap(),
            ColumnType::Float64
        );
        assert_eq!("TEXT".parse::<ColumnType>().unwrap(), ColumnType::String);
        assert_eq!("BOOL".parse::<ColumnType>().unwrap(), ColumnType::Bool);
        assert_eq!(
            "TIMESTAMP".parse::<ColumnType>().unwrap(),
            ColumnType::Timestamp
        );
        assert_eq!(
            "GEOMETRY".parse::<ColumnType>().unwrap(),
            ColumnType::Geometry
        );
        assert_eq!("UUID".parse::<ColumnType>().unwrap(), ColumnType::Uuid);
    }

    #[test]
    fn parse_aliases_case_and_whitespace_insensitive() {
        assert_eq!("int".parse::<ColumnType>().unwrap(), ColumnType::Int64);
        assert_eq!("Double".parse::<ColumnType>().unwrap(), ColumnType::Float64);
        assert_eq!(
            "  varchar ".parse::<ColumnType>().unwrap(),
            ColumnType::String
        );
        assert_eq!("blob".parse::<ColumnType>().unwrap(), ColumnType::Bytes);
        assert_eq!("jsonb".parse::<ColumnType>().unwrap(), ColumnType::Json);
    }

    #[test]
    fn parse_errors() {
        assert_eq!(
            "DATETIME".parse::<ColumnType>(),
            Err(ColumnTypeParseError::UseTimestamp)
        );
        assert_eq!(
            "nope".parse::<ColumnType>(),
            Err(ColumnTypeParseError::Unknown("NOPE".into()))
        );
    }

    #[test]
    fn parse_vector() {
        assert_eq!(
            "VECTOR(768)".parse::<ColumnType>().unwrap(),
            ColumnType::Vector(768)
        );
        assert_eq!(
            "vector ( 3 )".parse::<ColumnType>().unwrap(),
            ColumnType::Vector(3)
        );
        assert!("VECTOR(0)".parse::<ColumnType>().is_err());
        assert_eq!(
            "VECTOR".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("empty".into()))
        );
        assert_eq!(
            "VECTOR(abc)".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("ABC".into()))
        );
    }

    #[test]
    fn display_roundtrip() {
        for ct in [
            ColumnType::Int64,
            ColumnType::Float64,
            ColumnType::String,
            ColumnType::Bool,
            ColumnType::Bytes,
            ColumnType::Timestamp,
            ColumnType::Decimal,
            ColumnType::Geometry,
            ColumnType::Vector(768),
            ColumnType::Uuid,
            ColumnType::Json,
            ColumnType::Ulid,
            ColumnType::Duration,
            ColumnType::Array,
            ColumnType::Set,
            ColumnType::Regex,
            ColumnType::Range,
            ColumnType::Record,
        ] {
            let s = ct.to_string();
            let parsed: ColumnType = s.parse().unwrap();
            assert_eq!(parsed, ct);
        }
    }

    #[test]
    fn fixed_size_and_variable_length() {
        assert_eq!(ColumnType::Int64.fixed_size(), Some(8));
        assert_eq!(ColumnType::Bool.fixed_size(), Some(1));
        assert_eq!(ColumnType::Uuid.fixed_size(), Some(16));
        assert_eq!(ColumnType::Vector(768).fixed_size(), Some(3072));
        assert!(ColumnType::String.is_variable_length());
        assert!(ColumnType::Json.is_variable_length());
        assert!(!ColumnType::Timestamp.is_variable_length());
    }

    #[test]
    fn accepts_native_values() {
        assert!(ColumnType::Int64.accepts(&Value::Integer(42)));
        assert!(ColumnType::Float64.accepts(&Value::Float(42.0)));
        assert!(ColumnType::Float64.accepts(&Value::Integer(42)));
        assert!(ColumnType::String.accepts(&Value::String("x".into())));
        assert!(ColumnType::Bool.accepts(&Value::Bool(true)));
        assert!(ColumnType::Bytes.accepts(&Value::Bytes(vec![1])));
        assert!(
            ColumnType::Uuid.accepts(&Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()))
        );
        assert!(ColumnType::Decimal.accepts(&Value::Decimal(rust_decimal::Decimal::ZERO)));

        // NULL is accepted by every type; nullability is enforced elsewhere.
        assert!(ColumnType::Int64.accepts(&Value::Null));

        // JSON accepts any value.
        assert!(ColumnType::Json.accepts(&Value::Integer(1)));

        assert!(!ColumnType::Int64.accepts(&Value::String("x".into())));
        assert!(!ColumnType::Bool.accepts(&Value::Integer(1)));
    }

    #[test]
    fn accepts_coercion_sources() {
        assert!(ColumnType::Timestamp.accepts(&Value::String("2024-01-01".into())));
        assert!(ColumnType::Timestamp.accepts(&Value::Integer(1_700_000_000)));
        assert!(ColumnType::Uuid.accepts(&Value::String(
            "550e8400-e29b-41d4-a716-446655440000".into()
        )));
        assert!(ColumnType::Decimal.accepts(&Value::String("99.95".into())));
        assert!(ColumnType::Decimal.accepts(&Value::Float(99.95)));
        assert!(ColumnType::Geometry.accepts(&Value::String("POINT(0 0)".into())));
    }

    #[test]
    fn column_def_display() {
        let col = ColumnDef::required("id", ColumnType::Int64).with_primary_key();
        assert_eq!(col.to_string(), "id BIGINT NOT NULL PRIMARY KEY");

        let col = ColumnDef::nullable("label", ColumnType::String).with_default("'x'");
        assert_eq!(col.to_string(), "label TEXT DEFAULT 'x'");
    }

    #[test]
    fn column_def_builders() {
        let col = ColumnDef::nullable("id", ColumnType::Int64).with_primary_key();
        assert!(col.primary_key);
        assert!(!col.nullable);
        assert!(!col.is_time_key());
        assert!(!col.is_spatial_index());
    }
}