1use std::fmt;
4use std::str::FromStr;
5
6use serde::{Deserialize, Serialize};
7
8use crate::value::Value;
9
/// Logical type of a column.
///
/// With the serde attributes below, values are serialized in adjacently
/// tagged form: the variant name is stored under the `"type"` key and any
/// variant payload (only `Vector` carries one) under `"params"`.
///
/// The textual spelling of each variant is defined by the `Display` impl and
/// the accepted aliases by the `FromStr` impl in this file.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[serde(tag = "type", content = "params")]
pub enum ColumnType {
    /// Displayed as `BIGINT`; `fixed_size` reports 8 bytes.
    Int64,
    /// Displayed as `FLOAT64`; `fixed_size` reports 8 bytes.
    Float64,
    /// Displayed as `TEXT`; variable length.
    String,
    /// Displayed as `BOOL`; `fixed_size` reports 1 byte.
    Bool,
    /// Displayed as `BYTES`; variable length.
    Bytes,
    /// Displayed as `TIMESTAMP`; `fixed_size` reports 8 bytes.
    Timestamp,
    /// Displayed as `DECIMAL`; `fixed_size` reports 16 bytes.
    Decimal,
    /// Displayed as `GEOMETRY`; variable length.
    Geometry,
    /// Displayed as `VECTOR(dim)`; the payload is the dimension.
    /// `fixed_size` reports `dim * 4` bytes (presumably 4-byte float
    /// elements — TODO confirm). Parsing rejects a dimension of zero.
    Vector(u32),
    /// Displayed as `UUID`; `fixed_size` reports 16 bytes.
    Uuid,
    /// Displayed as `JSON`; variable length. `accepts` allows any value.
    Json,
    /// Displayed as `ULID`; `fixed_size` reports 16 bytes.
    Ulid,
    /// Displayed as `DURATION`; `fixed_size` reports 8 bytes.
    Duration,
    /// Displayed as `ARRAY`; variable length.
    Array,
    /// Displayed as `SET`; variable length.
    Set,
    /// Displayed as `REGEX`; variable length.
    Regex,
    /// Displayed as `RANGE`; variable length.
    Range,
    /// Displayed as `RECORD`; variable length.
    Record,
}
53
54impl ColumnType {
55 pub fn fixed_size(&self) -> Option<usize> {
57 match self {
58 Self::Int64 | Self::Float64 | Self::Timestamp | Self::Duration => Some(8),
59 Self::Bool => Some(1),
60 Self::Decimal | Self::Uuid | Self::Ulid => Some(16),
61 Self::Vector(dim) => Some(*dim as usize * 4),
62 Self::String
63 | Self::Bytes
64 | Self::Geometry
65 | Self::Json
66 | Self::Array
67 | Self::Set
68 | Self::Regex
69 | Self::Range
70 | Self::Record => None,
71 }
72 }
73
74 pub fn is_variable_length(&self) -> bool {
76 self.fixed_size().is_none()
77 }
78
79 pub fn accepts(&self, value: &Value) -> bool {
85 matches!(
86 (self, value),
87 (Self::Int64, Value::Integer(_))
88 | (Self::Float64, Value::Float(_) | Value::Integer(_))
89 | (Self::String, Value::String(_))
90 | (Self::Bool, Value::Bool(_))
91 | (Self::Bytes, Value::Bytes(_))
92 | (
93 Self::Timestamp,
94 Value::DateTime(_) | Value::Integer(_) | Value::String(_)
95 )
96 | (
97 Self::Decimal,
98 Value::Decimal(_) | Value::String(_) | Value::Float(_) | Value::Integer(_)
99 )
100 | (Self::Geometry, Value::Geometry(_) | Value::String(_))
101 | (Self::Vector(_), Value::Array(_) | Value::Bytes(_))
102 | (Self::Uuid, Value::Uuid(_) | Value::String(_))
103 | (Self::Ulid, Value::Ulid(_) | Value::String(_))
104 | (
105 Self::Duration,
106 Value::Duration(_) | Value::Integer(_) | Value::String(_)
107 )
108 | (Self::Array, Value::Array(_))
109 | (Self::Set, Value::Set(_) | Value::Array(_))
110 | (Self::Regex, Value::Regex(_) | Value::String(_))
111 | (Self::Range, Value::Range { .. })
112 | (Self::Record, Value::Record { .. } | Value::String(_))
113 | (Self::Json, _)
114 | (_, Value::Null)
115 )
116 }
117}
118
119impl fmt::Display for ColumnType {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 match self {
122 Self::Int64 => f.write_str("BIGINT"),
123 Self::Float64 => f.write_str("FLOAT64"),
124 Self::String => f.write_str("TEXT"),
125 Self::Bool => f.write_str("BOOL"),
126 Self::Bytes => f.write_str("BYTES"),
127 Self::Timestamp => f.write_str("TIMESTAMP"),
128 Self::Decimal => f.write_str("DECIMAL"),
129 Self::Geometry => f.write_str("GEOMETRY"),
130 Self::Vector(dim) => write!(f, "VECTOR({dim})"),
131 Self::Uuid => f.write_str("UUID"),
132 Self::Json => f.write_str("JSON"),
133 Self::Ulid => f.write_str("ULID"),
134 Self::Duration => f.write_str("DURATION"),
135 Self::Array => f.write_str("ARRAY"),
136 Self::Set => f.write_str("SET"),
137 Self::Regex => f.write_str("REGEX"),
138 Self::Range => f.write_str("RANGE"),
139 Self::Record => f.write_str("RECORD"),
140 }
141 }
142}
143
144impl FromStr for ColumnType {
145 type Err = ColumnTypeParseError;
146
147 fn from_str(s: &str) -> Result<Self, Self::Err> {
148 let upper = s.trim().to_uppercase();
149
150 if upper.starts_with("VECTOR") {
152 let inner = upper
153 .trim_start_matches("VECTOR")
154 .trim()
155 .trim_start_matches('(')
156 .trim_end_matches(')')
157 .trim();
158 if inner.is_empty() {
159 return Err(ColumnTypeParseError::InvalidVectorDim("empty".into()));
160 }
161 let dim: u32 = inner
162 .parse()
163 .map_err(|_| ColumnTypeParseError::InvalidVectorDim(inner.into()))?;
164 if dim == 0 {
165 return Err(ColumnTypeParseError::InvalidVectorDim("0".into()));
166 }
167 return Ok(Self::Vector(dim));
168 }
169
170 match upper.as_str() {
171 "BIGINT" | "INT64" | "INTEGER" | "INT" => Ok(Self::Int64),
172 "FLOAT64" | "DOUBLE" | "REAL" | "FLOAT" => Ok(Self::Float64),
173 "TEXT" | "STRING" | "VARCHAR" => Ok(Self::String),
174 "BOOL" | "BOOLEAN" => Ok(Self::Bool),
175 "BYTES" | "BYTEA" | "BLOB" => Ok(Self::Bytes),
176 "TIMESTAMP" | "TIMESTAMPTZ" => Ok(Self::Timestamp),
177 "DECIMAL" | "NUMERIC" => Ok(Self::Decimal),
178 "GEOMETRY" => Ok(Self::Geometry),
179 "UUID" => Ok(Self::Uuid),
180 "JSON" | "JSONB" => Ok(Self::Json),
181 "ULID" => Ok(Self::Ulid),
182 "DURATION" => Ok(Self::Duration),
183 "ARRAY" => Ok(Self::Array),
184 "SET" => Ok(Self::Set),
185 "REGEX" => Ok(Self::Regex),
186 "RANGE" => Ok(Self::Range),
187 "RECORD" => Ok(Self::Record),
188 "DATETIME" => Err(ColumnTypeParseError::UseTimestamp),
189 other => Err(ColumnTypeParseError::Unknown(other.to_string())),
190 }
191 }
192}
193
/// Error produced when a string cannot be parsed into a [`ColumnType`].
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ColumnTypeParseError {
    /// The input matched no known type name or alias. Carries the trimmed,
    /// upper-cased input.
    #[error("unknown column type: '{0}'")]
    Unknown(String),
    /// The input was `DATETIME`, which is rejected with a hint to use
    /// `TIMESTAMP` instead.
    #[error("'DATETIME' is not a valid type — use 'TIMESTAMP' instead")]
    UseTimestamp,
    /// A `VECTOR` type had an unusable dimension. Carries the rejected
    /// dimension text, or the literal `empty` when none was given.
    #[error("invalid VECTOR dimension: '{0}' (must be a positive integer)")]
    InvalidVectorDim(String),
}
204
/// Extra per-column flags stored in [`ColumnDef::modifiers`].
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so existing values
/// must keep their numbers. NOTE(review): `#[msgpack(c_enum)]` presumably
/// makes zerompk encode the variant as that integer — confirm against the
/// zerompk documentation before renumbering or reordering.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[msgpack(c_enum)]
#[repr(u8)]
pub enum ColumnModifier {
    /// Discriminant 0. Queried through `ColumnDef::is_time_key`.
    TimeKey = 0,
    /// Discriminant 1. Queried through `ColumnDef::is_spatial_index`.
    SpatialIndex = 1,
}
230
/// Definition of a single column: its name, type, constraints and metadata.
///
/// The serde attributes keep the serialized form compact and tolerant of
/// missing fields: `modifiers`, `generated_expr` and `generated_deps` are
/// omitted when empty and default to empty when absent; `added_at_version`
/// falls back to `default_added_at_version()` when absent.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct ColumnDef {
    /// Column name, printed first by the `Display` impl.
    pub name: String,
    /// Logical type of the column.
    pub column_type: ColumnType,
    /// `false` is rendered as `NOT NULL` by `Display`.
    pub nullable: bool,
    /// Default expression kept as text; rendered as `DEFAULT <expr>`.
    pub default: Option<String>,
    /// `true` is rendered as `PRIMARY KEY`; `with_primary_key` also clears
    /// `nullable`.
    pub primary_key: bool,
    /// Extra flags for the column; skipped during serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub modifiers: Vec<ColumnModifier>,
    /// Presumably the expression of a generated (computed) column — TODO
    /// confirm with the code that evaluates it. Skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generated_expr: Option<String>,
    /// Presumably the names of the columns `generated_expr` reads — TODO
    /// confirm. Skipped during serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub generated_deps: Vec<String>,
    /// Presumably the schema version in which the column was introduced —
    /// TODO confirm. Deserializes to 1 when the field is missing.
    #[serde(default = "default_added_at_version")]
    pub added_at_version: u16,
}
265
/// Serde fallback for `ColumnDef::added_at_version` when the field is missing
/// from the serialized input; matches the value the constructors assign.
fn default_added_at_version() -> u16 {
    const FIRST_VERSION: u16 = 1;
    FIRST_VERSION
}
269
270impl ColumnDef {
271 pub fn required(name: impl Into<String>, column_type: ColumnType) -> Self {
272 Self {
273 name: name.into(),
274 column_type,
275 nullable: false,
276 default: None,
277 primary_key: false,
278 modifiers: Vec::new(),
279 generated_expr: None,
280 generated_deps: Vec::new(),
281 added_at_version: 1,
282 }
283 }
284
285 pub fn nullable(name: impl Into<String>, column_type: ColumnType) -> Self {
286 Self {
287 name: name.into(),
288 column_type,
289 nullable: true,
290 default: None,
291 primary_key: false,
292 modifiers: Vec::new(),
293 generated_expr: None,
294 generated_deps: Vec::new(),
295 added_at_version: 1,
296 }
297 }
298
299 pub fn with_primary_key(mut self) -> Self {
300 self.primary_key = true;
301 self.nullable = false;
302 self
303 }
304
305 pub fn is_time_key(&self) -> bool {
307 self.modifiers.contains(&ColumnModifier::TimeKey)
308 }
309
310 pub fn is_spatial_index(&self) -> bool {
312 self.modifiers.contains(&ColumnModifier::SpatialIndex)
313 }
314
315 pub fn with_default(mut self, expr: impl Into<String>) -> Self {
316 self.default = Some(expr.into());
317 self
318 }
319}
320
321impl fmt::Display for ColumnDef {
322 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323 write!(f, "{} {}", self.name, self.column_type)?;
324 if !self.nullable {
325 write!(f, " NOT NULL")?;
326 }
327 if self.primary_key {
328 write!(f, " PRIMARY KEY")?;
329 }
330 if let Some(ref d) = self.default {
331 write!(f, " DEFAULT {d}")?;
332 }
333 Ok(())
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical keywords (the ones `Display` emits) parse to their variant.
    #[test]
    fn parse_canonical() {
        assert_eq!("BIGINT".parse::<ColumnType>().unwrap(), ColumnType::Int64);
        assert_eq!(
            "FLOAT64".parse::<ColumnType>().unwrap(),
            ColumnType::Float64
        );
        assert_eq!("TEXT".parse::<ColumnType>().unwrap(), ColumnType::String);
        assert_eq!("BOOL".parse::<ColumnType>().unwrap(), ColumnType::Bool);
        assert_eq!(
            "TIMESTAMP".parse::<ColumnType>().unwrap(),
            ColumnType::Timestamp
        );
        assert_eq!(
            "GEOMETRY".parse::<ColumnType>().unwrap(),
            ColumnType::Geometry
        );
        assert_eq!("UUID".parse::<ColumnType>().unwrap(), ColumnType::Uuid);
    }

    // Parsing trims the input and ignores case.
    #[test]
    fn parse_ignores_case_and_whitespace() {
        assert_eq!(
            "  bigint ".parse::<ColumnType>().unwrap(),
            ColumnType::Int64
        );
        assert_eq!("jsonb".parse::<ColumnType>().unwrap(), ColumnType::Json);
    }

    // DATETIME gets its own error so the message can suggest TIMESTAMP.
    #[test]
    fn parse_datetime_points_to_timestamp() {
        assert_eq!(
            "DATETIME".parse::<ColumnType>(),
            Err(ColumnTypeParseError::UseTimestamp)
        );
    }

    #[test]
    fn parse_vector() {
        assert_eq!(
            "VECTOR(768)".parse::<ColumnType>().unwrap(),
            ColumnType::Vector(768)
        );
        assert!("VECTOR(0)".parse::<ColumnType>().is_err());
        // A missing or non-numeric dimension is rejected as well.
        assert!("VECTOR".parse::<ColumnType>().is_err());
        assert!("VECTOR()".parse::<ColumnType>().is_err());
        assert!("VECTOR(abc)".parse::<ColumnType>().is_err());
    }

    // Every variant must survive Display -> FromStr, not just a sample.
    #[test]
    fn display_roundtrip() {
        for ct in [
            ColumnType::Int64,
            ColumnType::Float64,
            ColumnType::String,
            ColumnType::Bool,
            ColumnType::Bytes,
            ColumnType::Timestamp,
            ColumnType::Decimal,
            ColumnType::Geometry,
            ColumnType::Vector(768),
            ColumnType::Uuid,
            ColumnType::Json,
            ColumnType::Ulid,
            ColumnType::Duration,
            ColumnType::Array,
            ColumnType::Set,
            ColumnType::Regex,
            ColumnType::Range,
            ColumnType::Record,
        ] {
            let s = ct.to_string();
            let parsed: ColumnType = s.parse().unwrap();
            assert_eq!(parsed, ct);
        }
    }

    #[test]
    fn accepts_native_values() {
        assert!(ColumnType::Int64.accepts(&Value::Integer(42)));
        assert!(ColumnType::Float64.accepts(&Value::Float(42.0)));
        assert!(ColumnType::Float64.accepts(&Value::Integer(42)));
        assert!(ColumnType::String.accepts(&Value::String("x".into())));
        assert!(ColumnType::Bool.accepts(&Value::Bool(true)));
        assert!(ColumnType::Bytes.accepts(&Value::Bytes(vec![1])));
        assert!(
            ColumnType::Uuid.accepts(&Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()))
        );
        assert!(ColumnType::Decimal.accepts(&Value::Decimal(rust_decimal::Decimal::ZERO)));

        // NULL passes the type check for any column type.
        assert!(ColumnType::Int64.accepts(&Value::Null));

        // Mismatched types are rejected.
        assert!(!ColumnType::Int64.accepts(&Value::String("x".into())));
        assert!(!ColumnType::Bool.accepts(&Value::Integer(1)));
    }

    #[test]
    fn accepts_coercion_sources() {
        assert!(ColumnType::Timestamp.accepts(&Value::String("2024-01-01".into())));
        assert!(ColumnType::Timestamp.accepts(&Value::Integer(1_700_000_000)));
        assert!(ColumnType::Uuid.accepts(&Value::String(
            "550e8400-e29b-41d4-a716-446655440000".into()
        )));
        assert!(ColumnType::Decimal.accepts(&Value::String("99.95".into())));
        assert!(ColumnType::Decimal.accepts(&Value::Float(99.95)));
        assert!(ColumnType::Geometry.accepts(&Value::String("POINT(0 0)".into())));
    }

    #[test]
    fn column_def_display() {
        let col = ColumnDef::required("id", ColumnType::Int64).with_primary_key();
        assert_eq!(col.to_string(), "id BIGINT NOT NULL PRIMARY KEY");
    }
}