1use std::fmt;
4use std::str::FromStr;
5
6use serde::{Deserialize, Serialize};
7
8use crate::value::Value;
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(tag = "type", content = "params")]
13pub enum ColumnType {
14 Int64,
15 Float64,
16 String,
17 Bool,
18 Bytes,
19 Timestamp,
20 Decimal,
21 Geometry,
22 Vector(u32),
24 Uuid,
25}
26
27impl ColumnType {
28 pub fn fixed_size(&self) -> Option<usize> {
30 match self {
31 Self::Int64 | Self::Float64 | Self::Timestamp => Some(8),
32 Self::Bool => Some(1),
33 Self::Decimal => Some(16),
34 Self::Uuid => Some(16),
35 Self::Vector(dim) => Some(*dim as usize * 4),
36 Self::String | Self::Bytes | Self::Geometry => None,
37 }
38 }
39
40 pub fn is_variable_length(&self) -> bool {
42 self.fixed_size().is_none()
43 }
44
45 pub fn accepts(&self, value: &Value) -> bool {
51 matches!(
52 (self, value),
53 (Self::Int64, Value::Integer(_))
54 | (Self::Float64, Value::Float(_) | Value::Integer(_))
55 | (Self::String, Value::String(_))
56 | (Self::Bool, Value::Bool(_))
57 | (Self::Bytes, Value::Bytes(_))
58 | (
59 Self::Timestamp,
60 Value::DateTime(_) | Value::Integer(_) | Value::String(_)
61 )
62 | (
63 Self::Decimal,
64 Value::Decimal(_) | Value::String(_) | Value::Float(_) | Value::Integer(_)
65 )
66 | (Self::Geometry, Value::Geometry(_) | Value::String(_))
67 | (Self::Vector(_), Value::Array(_) | Value::Bytes(_))
68 | (Self::Uuid, Value::Uuid(_) | Value::String(_))
69 | (_, Value::Null)
70 )
71 }
72}
73
74impl fmt::Display for ColumnType {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::Int64 => f.write_str("BIGINT"),
78 Self::Float64 => f.write_str("FLOAT64"),
79 Self::String => f.write_str("TEXT"),
80 Self::Bool => f.write_str("BOOL"),
81 Self::Bytes => f.write_str("BYTES"),
82 Self::Timestamp => f.write_str("TIMESTAMP"),
83 Self::Decimal => f.write_str("DECIMAL"),
84 Self::Geometry => f.write_str("GEOMETRY"),
85 Self::Vector(dim) => write!(f, "VECTOR({dim})"),
86 Self::Uuid => f.write_str("UUID"),
87 }
88 }
89}
90
91impl FromStr for ColumnType {
92 type Err = ColumnTypeParseError;
93
94 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 let upper = s.trim().to_uppercase();
96
97 if upper.starts_with("VECTOR") {
99 let inner = upper
100 .trim_start_matches("VECTOR")
101 .trim()
102 .trim_start_matches('(')
103 .trim_end_matches(')')
104 .trim();
105 if inner.is_empty() {
106 return Err(ColumnTypeParseError::InvalidVectorDim("empty".into()));
107 }
108 let dim: u32 = inner
109 .parse()
110 .map_err(|_| ColumnTypeParseError::InvalidVectorDim(inner.into()))?;
111 if dim == 0 {
112 return Err(ColumnTypeParseError::InvalidVectorDim("0".into()));
113 }
114 return Ok(Self::Vector(dim));
115 }
116
117 match upper.as_str() {
118 "BIGINT" | "INT64" | "INTEGER" | "INT" => Ok(Self::Int64),
119 "FLOAT64" | "DOUBLE" | "REAL" | "FLOAT" => Ok(Self::Float64),
120 "TEXT" | "STRING" | "VARCHAR" => Ok(Self::String),
121 "BOOL" | "BOOLEAN" => Ok(Self::Bool),
122 "BYTES" | "BYTEA" | "BLOB" => Ok(Self::Bytes),
123 "TIMESTAMP" | "TIMESTAMPTZ" => Ok(Self::Timestamp),
124 "DECIMAL" | "NUMERIC" => Ok(Self::Decimal),
125 "GEOMETRY" => Ok(Self::Geometry),
126 "UUID" => Ok(Self::Uuid),
127 "DATETIME" => Err(ColumnTypeParseError::UseTimestamp),
128 other => Err(ColumnTypeParseError::Unknown(other.to_string())),
129 }
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
135pub enum ColumnTypeParseError {
136 #[error("unknown column type: '{0}'")]
137 Unknown(String),
138 #[error("'DATETIME' is not a valid type — use 'TIMESTAMP' instead")]
139 UseTimestamp,
140 #[error("invalid VECTOR dimension: '{0}' (must be a positive integer)")]
141 InvalidVectorDim(String),
142}
143
144#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
146pub struct ColumnDef {
147 pub name: String,
148 pub column_type: ColumnType,
149 pub nullable: bool,
150 pub default: Option<String>,
151 pub primary_key: bool,
152}
153
154impl ColumnDef {
155 pub fn required(name: impl Into<String>, column_type: ColumnType) -> Self {
156 Self {
157 name: name.into(),
158 column_type,
159 nullable: false,
160 default: None,
161 primary_key: false,
162 }
163 }
164
165 pub fn nullable(name: impl Into<String>, column_type: ColumnType) -> Self {
166 Self {
167 name: name.into(),
168 column_type,
169 nullable: true,
170 default: None,
171 primary_key: false,
172 }
173 }
174
175 pub fn with_primary_key(mut self) -> Self {
176 self.primary_key = true;
177 self.nullable = false;
178 self
179 }
180
181 pub fn with_default(mut self, expr: impl Into<String>) -> Self {
182 self.default = Some(expr.into());
183 self
184 }
185}
186
187impl fmt::Display for ColumnDef {
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 write!(f, "{} {}", self.name, self.column_type)?;
190 if !self.nullable {
191 write!(f, " NOT NULL")?;
192 }
193 if self.primary_key {
194 write!(f, " PRIMARY KEY")?;
195 }
196 if let Some(ref d) = self.default {
197 write!(f, " DEFAULT {d}")?;
198 }
199 Ok(())
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical keywords (the ones `Display` emits) parse to their variant.
    #[test]
    fn parse_canonical() {
        assert_eq!("BIGINT".parse::<ColumnType>().unwrap(), ColumnType::Int64);
        assert_eq!(
            "FLOAT64".parse::<ColumnType>().unwrap(),
            ColumnType::Float64
        );
        assert_eq!("TEXT".parse::<ColumnType>().unwrap(), ColumnType::String);
        assert_eq!("BOOL".parse::<ColumnType>().unwrap(), ColumnType::Bool);
        assert_eq!("BYTES".parse::<ColumnType>().unwrap(), ColumnType::Bytes);
        assert_eq!(
            "TIMESTAMP".parse::<ColumnType>().unwrap(),
            ColumnType::Timestamp
        );
        assert_eq!(
            "DECIMAL".parse::<ColumnType>().unwrap(),
            ColumnType::Decimal
        );
        assert_eq!(
            "GEOMETRY".parse::<ColumnType>().unwrap(),
            ColumnType::Geometry
        );
        assert_eq!("UUID".parse::<ColumnType>().unwrap(), ColumnType::Uuid);
    }

    /// Aliases are recognised regardless of ASCII case and outer whitespace.
    #[test]
    fn parse_aliases_case_and_whitespace() {
        assert_eq!("int".parse::<ColumnType>().unwrap(), ColumnType::Int64);
        assert_eq!("double".parse::<ColumnType>().unwrap(), ColumnType::Float64);
        assert_eq!(
            "  varchar  ".parse::<ColumnType>().unwrap(),
            ColumnType::String
        );
        assert_eq!("Boolean".parse::<ColumnType>().unwrap(), ColumnType::Bool);
        assert_eq!("blob".parse::<ColumnType>().unwrap(), ColumnType::Bytes);
        assert_eq!(
            "timestamptz".parse::<ColumnType>().unwrap(),
            ColumnType::Timestamp
        );
        assert_eq!(
            "numeric".parse::<ColumnType>().unwrap(),
            ColumnType::Decimal
        );
    }

    #[test]
    fn parse_vector() {
        assert_eq!(
            "VECTOR(768)".parse::<ColumnType>().unwrap(),
            ColumnType::Vector(768)
        );
        assert_eq!(
            "vector( 3 )".parse::<ColumnType>().unwrap(),
            ColumnType::Vector(3)
        );
        assert!("VECTOR(0)".parse::<ColumnType>().is_err());
    }

    /// Malformed vector specifications report the offending dimension text.
    #[test]
    fn parse_vector_errors() {
        assert_eq!(
            "VECTOR".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("empty".into()))
        );
        assert_eq!(
            "VECTOR()".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("empty".into()))
        );
        assert_eq!(
            "VECTOR(0)".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("0".into()))
        );
        assert_eq!(
            "VECTOR(-1)".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("-1".into()))
        );
        assert_eq!(
            "VECTOR(abc)".parse::<ColumnType>(),
            Err(ColumnTypeParseError::InvalidVectorDim("ABC".into()))
        );
        assert!("VECTORS".parse::<ColumnType>().is_err());
    }

    #[test]
    fn parse_rejects_unknown_and_datetime() {
        assert_eq!(
            "DATETIME".parse::<ColumnType>(),
            Err(ColumnTypeParseError::UseTimestamp)
        );
        assert_eq!(
            "datetime".parse::<ColumnType>(),
            Err(ColumnTypeParseError::UseTimestamp)
        );
        assert_eq!(
            "nonsense".parse::<ColumnType>(),
            Err(ColumnTypeParseError::Unknown("NONSENSE".into()))
        );
        assert_eq!(
            "".parse::<ColumnType>(),
            Err(ColumnTypeParseError::Unknown(String::new()))
        );
    }

    /// Every variant must survive a `Display` -> `FromStr` round trip.
    #[test]
    fn display_roundtrip() {
        for ct in [
            ColumnType::Int64,
            ColumnType::Float64,
            ColumnType::String,
            ColumnType::Bool,
            ColumnType::Bytes,
            ColumnType::Timestamp,
            ColumnType::Decimal,
            ColumnType::Geometry,
            ColumnType::Vector(768),
            ColumnType::Uuid,
        ] {
            let s = ct.to_string();
            let parsed: ColumnType = s.parse().unwrap();
            assert_eq!(parsed, ct);
        }
    }

    #[test]
    fn fixed_sizes() {
        assert_eq!(ColumnType::Bool.fixed_size(), Some(1));
        assert_eq!(ColumnType::Int64.fixed_size(), Some(8));
        assert_eq!(ColumnType::Float64.fixed_size(), Some(8));
        assert_eq!(ColumnType::Timestamp.fixed_size(), Some(8));
        assert_eq!(ColumnType::Decimal.fixed_size(), Some(16));
        assert_eq!(ColumnType::Uuid.fixed_size(), Some(16));
        assert_eq!(ColumnType::Vector(4).fixed_size(), Some(16));

        for ct in [ColumnType::String, ColumnType::Bytes, ColumnType::Geometry] {
            assert_eq!(ct.fixed_size(), None);
            assert!(ct.is_variable_length());
        }
        assert!(!ColumnType::Int64.is_variable_length());
    }

    #[test]
    fn accepts_native_values() {
        assert!(ColumnType::Int64.accepts(&Value::Integer(42)));
        assert!(ColumnType::Float64.accepts(&Value::Float(42.0)));
        assert!(ColumnType::Float64.accepts(&Value::Integer(42)));
        assert!(ColumnType::String.accepts(&Value::String("x".into())));
        assert!(ColumnType::Bool.accepts(&Value::Bool(true)));
        assert!(ColumnType::Bytes.accepts(&Value::Bytes(vec![1])));
        assert!(
            ColumnType::Uuid.accepts(&Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()))
        );
        assert!(ColumnType::Decimal.accepts(&Value::Decimal(rust_decimal::Decimal::ZERO)));

        // Null is accepted by every type.
        assert!(ColumnType::Int64.accepts(&Value::Null));

        // Mismatched kinds are rejected.
        assert!(!ColumnType::Int64.accepts(&Value::String("x".into())));
        assert!(!ColumnType::Bool.accepts(&Value::Integer(1)));
    }

    #[test]
    fn accepts_coercion_sources() {
        assert!(ColumnType::Timestamp.accepts(&Value::String("2024-01-01".into())));
        assert!(ColumnType::Timestamp.accepts(&Value::Integer(1_700_000_000)));
        assert!(ColumnType::Uuid.accepts(&Value::String(
            "550e8400-e29b-41d4-a716-446655440000".into()
        )));
        assert!(ColumnType::Decimal.accepts(&Value::String("99.95".into())));
        assert!(ColumnType::Decimal.accepts(&Value::Float(99.95)));
        assert!(ColumnType::Geometry.accepts(&Value::String("POINT(0 0)".into())));
    }

    #[test]
    fn column_def_display() {
        let col = ColumnDef::required("id", ColumnType::Int64).with_primary_key();
        assert_eq!(col.to_string(), "id BIGINT NOT NULL PRIMARY KEY");

        let col = ColumnDef::nullable("note", ColumnType::String).with_default("''");
        assert_eq!(col.to_string(), "note TEXT DEFAULT ''");

        let col = ColumnDef::required("n", ColumnType::Int64).with_default("0");
        assert_eq!(col.to_string(), "n BIGINT NOT NULL DEFAULT 0");
    }

    /// Promoting a nullable column to primary key must clear `nullable`.
    #[test]
    fn primary_key_forces_not_null() {
        let col = ColumnDef::nullable("id", ColumnType::Uuid).with_primary_key();
        assert!(col.primary_key);
        assert!(!col.nullable);
        assert_eq!(col.to_string(), "id UUID NOT NULL PRIMARY KEY");
    }
}