1use std::cmp::Ordering;
2use std::convert::TryFrom;
3
4use crate::ast::ddl::VectorMetric;
5use crate::planner::ResolvedType;
6use serde::{Deserialize, Serialize};
7
8use super::error::{Result, StorageError};
9
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum SqlValue {
13 Null,
14 Integer(i32),
15 BigInt(i64),
16 Float(f32),
17 Double(f64),
18 Text(String),
19 Blob(Vec<u8>),
20 Boolean(bool),
21 Timestamp(i64), Vector(Vec<f32>),
23}
24
25impl SqlValue {
26 pub fn type_tag(&self) -> u8 {
28 match self {
29 SqlValue::Null => 0x00,
30 SqlValue::Integer(_) => 0x01,
31 SqlValue::BigInt(_) => 0x02,
32 SqlValue::Float(_) => 0x03,
33 SqlValue::Double(_) => 0x04,
34 SqlValue::Text(_) => 0x05,
35 SqlValue::Blob(_) => 0x06,
36 SqlValue::Boolean(_) => 0x07,
37 SqlValue::Timestamp(_) => 0x08,
38 SqlValue::Vector(_) => 0x09,
39 }
40 }
41
42 pub fn is_null(&self) -> bool {
44 matches!(self, SqlValue::Null)
45 }
46
47 pub fn type_name(&self) -> &'static str {
49 match self {
50 SqlValue::Null => "Null",
51 SqlValue::Integer(_) => "Integer",
52 SqlValue::BigInt(_) => "BigInt",
53 SqlValue::Float(_) => "Float",
54 SqlValue::Double(_) => "Double",
55 SqlValue::Text(_) => "Text",
56 SqlValue::Blob(_) => "Blob",
57 SqlValue::Boolean(_) => "Boolean",
58 SqlValue::Timestamp(_) => "Timestamp",
59 SqlValue::Vector(_) => "Vector",
60 }
61 }
62
63 pub fn resolved_type(&self) -> ResolvedType {
65 match self {
66 SqlValue::Null => ResolvedType::Null,
67 SqlValue::Integer(_) => ResolvedType::Integer,
68 SqlValue::BigInt(_) => ResolvedType::BigInt,
69 SqlValue::Float(_) => ResolvedType::Float,
70 SqlValue::Double(_) => ResolvedType::Double,
71 SqlValue::Text(_) => ResolvedType::Text,
72 SqlValue::Blob(_) => ResolvedType::Blob,
73 SqlValue::Boolean(_) => ResolvedType::Boolean,
74 SqlValue::Timestamp(_) => ResolvedType::Timestamp,
75 SqlValue::Vector(v) => ResolvedType::Vector {
76 dimension: v.len() as u32,
77 metric: VectorMetric::Cosine,
78 },
79 }
80 }
81}
82
83impl PartialOrd for SqlValue {
84 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
85 use SqlValue::*;
86 match (self, other) {
87 (Null, _) | (_, Null) => None,
88 (Integer(a), Integer(b)) => Some(a.cmp(b)),
89 (BigInt(a), BigInt(b)) => Some(a.cmp(b)),
90 (Float(a), Float(b)) => a.partial_cmp(b),
91 (Double(a), Double(b)) => a.partial_cmp(b),
92 (Text(a), Text(b)) => Some(a.cmp(b)),
93 (Blob(a), Blob(b)) => Some(a.cmp(b)),
94 (Boolean(a), Boolean(b)) => Some(a.cmp(b)),
95 (Timestamp(a), Timestamp(b)) => Some(a.cmp(b)),
96 (Vector(_), Vector(_)) => None,
98 _ => None,
99 }
100 }
101}
102
/// Implements `TryFrom<SqlValue>` for `$target`.
///
/// The conversion succeeds only when the value is `SqlValue::$variant`, converting
/// its payload with `Into`. Any other variant yields `StorageError::TypeMismatch`,
/// whose `expected` is the variant name and whose `actual` is `type_name()`.
macro_rules! impl_from_sqlvalue_try {
    ($target:ty, $variant:ident) => {
        impl TryFrom<SqlValue> for $target {
            type Error = StorageError;

            fn try_from(value: SqlValue) -> Result<Self> {
                match value {
                    SqlValue::$variant(inner) => Ok(inner.into()),
                    mismatched => Err(StorageError::TypeMismatch {
                        expected: stringify!($variant).to_string(),
                        actual: mismatched.type_name().to_string(),
                    }),
                }
            }
        }
    };
}
120
/// Implements `From<$source> for SqlValue`.
///
/// The input is converted into the payload type of `SqlValue::$variant` through
/// `From`, so borrowed or narrower inputs (e.g. `&str` into `String`) are accepted.
macro_rules! impl_from_primitive {
    ($source:ty, $variant:ident) => {
        impl From<$source> for SqlValue {
            fn from(primitive: $source) -> Self {
                SqlValue::$variant(From::from(primitive))
            }
        }
    };
}
130
131impl_from_primitive!(i32, Integer);
132impl_from_primitive!(i64, BigInt);
133impl_from_primitive!(f32, Float);
134impl_from_primitive!(f64, Double);
135impl_from_primitive!(bool, Boolean);
136impl_from_primitive!(String, Text);
137impl_from_primitive!(&str, Text);
138impl_from_primitive!(Vec<u8>, Blob);
139impl From<&[u8]> for SqlValue {
140 fn from(value: &[u8]) -> Self {
141 SqlValue::Blob(value.to_vec())
142 }
143}
144impl From<Vec<f32>> for SqlValue {
145 fn from(value: Vec<f32>) -> Self {
146 SqlValue::Vector(value)
147 }
148}
149
150impl_from_sqlvalue_try!(i32, Integer);
151impl_from_sqlvalue_try!(f32, Float);
152impl_from_sqlvalue_try!(f64, Double);
153impl_from_sqlvalue_try!(bool, Boolean);
154impl_from_sqlvalue_try!(String, Text);
155impl_from_sqlvalue_try!(Vec<u8>, Blob);
156
157impl TryFrom<SqlValue> for i64 {
158 type Error = StorageError;
159 fn try_from(value: SqlValue) -> Result<Self> {
160 match value {
161 SqlValue::BigInt(v) | SqlValue::Timestamp(v) => Ok(v),
162 other => Err(StorageError::TypeMismatch {
163 expected: "BigInt/Timestamp".to_string(),
164 actual: other.type_name().to_string(),
165 }),
166 }
167 }
168}
169
170impl TryFrom<SqlValue> for Vec<f32> {
171 type Error = StorageError;
172 fn try_from(value: SqlValue) -> Result<Self> {
173 if let SqlValue::Vector(v) = value {
174 Ok(v)
175 } else {
176 Err(StorageError::TypeMismatch {
177 expected: "Vector".to_string(),
178 actual: value.type_name().to_string(),
179 })
180 }
181 }
182}
183
184impl TryFrom<&SqlValue> for ResolvedType {
185 type Error = StorageError;
186 fn try_from(value: &SqlValue) -> Result<Self> {
187 Ok(value.resolved_type())
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn null_comparisons_return_none() {
        assert!(SqlValue::Null.partial_cmp(&SqlValue::Null).is_none());
        assert!(SqlValue::Null.partial_cmp(&SqlValue::Integer(1)).is_none());
    }

    #[test]
    fn heterogeneous_comparisons_return_none() {
        assert!(SqlValue::Integer(1)
            .partial_cmp(&SqlValue::Double(1.0))
            .is_none());
        assert!(SqlValue::Text("a".into())
            .partial_cmp(&SqlValue::Blob(vec![0x61]))
            .is_none());
    }

    #[test]
    fn same_type_comparisons_work() {
        assert_eq!(
            SqlValue::Integer(1).partial_cmp(&SqlValue::Integer(2)),
            Some(Ordering::Less)
        );
        assert_eq!(
            SqlValue::Text("b".into()).partial_cmp(&SqlValue::Text("a".into())),
            Some(Ordering::Greater)
        );
        assert_eq!(
            SqlValue::Boolean(false).partial_cmp(&SqlValue::Boolean(true)),
            Some(Ordering::Less)
        );
    }

    #[test]
    fn timestamp_converts_to_i64() {
        assert_eq!(i64::try_from(SqlValue::Timestamp(-5)).unwrap(), -5);
    }

    #[test]
    fn type_mismatch_reports_expected_and_actual_names() {
        match i32::try_from(SqlValue::Text("x".into())) {
            Err(StorageError::TypeMismatch { expected, actual }) => {
                assert_eq!(expected, "Integer");
                assert_eq!(actual, "Text");
            }
            _ => panic!("expected TypeMismatch for Text -> i32"),
        }
        match i64::try_from(SqlValue::Null) {
            Err(StorageError::TypeMismatch { expected, actual }) => {
                assert_eq!(expected, "BigInt/Timestamp");
                assert_eq!(actual, "Null");
            }
            _ => panic!("expected TypeMismatch for Null -> i64"),
        }
    }

    proptest! {
        #[test]
        fn integer_roundtrip(v in any::<i32>()) {
            let sql: SqlValue = v.into();
            let back = i32::try_from(sql.clone()).unwrap();
            prop_assert_eq!(back, v);
            prop_assert_eq!(sql.partial_cmp(&SqlValue::Integer(v)), Some(Ordering::Equal));
        }

        #[test]
        fn bigint_roundtrip(v in any::<i64>()) {
            let sql: SqlValue = v.into();
            let back = i64::try_from(sql.clone()).unwrap();
            prop_assert_eq!(back, v);
            prop_assert_eq!(sql.partial_cmp(&SqlValue::BigInt(v)), Some(Ordering::Equal));
        }

        // Only NaN is excluded (it never equals itself); infinities and signed zeros
        // must round-trip and compare Equal like any other float.
        #[test]
        fn float_roundtrip_non_nan(v in any::<f32>().prop_filter("no NaN", |f| !f.is_nan())) {
            let sql: SqlValue = SqlValue::Float(v);
            let back = f32::try_from(sql.clone()).unwrap();
            prop_assert_eq!(back, v);
            prop_assert_eq!(sql.partial_cmp(&SqlValue::Float(v)), Some(Ordering::Equal));
        }

        #[test]
        fn double_roundtrip_non_nan(v in any::<f64>().prop_filter("no NaN", |f| !f.is_nan())) {
            let sql: SqlValue = v.into();
            let back = f64::try_from(sql.clone()).unwrap();
            prop_assert_eq!(back, v);
            prop_assert_eq!(sql.partial_cmp(&SqlValue::Double(v)), Some(Ordering::Equal));
        }

        #[test]
        fn text_roundtrip(s in ".*") {
            let sql: SqlValue = SqlValue::Text(s.clone());
            let back = String::try_from(sql.clone()).unwrap();
            prop_assert_eq!(back, s.clone());
            prop_assert_eq!(sql.partial_cmp(&SqlValue::Text(s)), Some(Ordering::Equal));
        }

        #[test]
        fn blob_roundtrip(data in proptest::collection::vec(any::<u8>(), 0..64)) {
            let sql: SqlValue = SqlValue::Blob(data.clone());
            let back = Vec::<u8>::try_from(sql.clone()).unwrap();
            prop_assert_eq!(back, data.clone());
            prop_assert_eq!(sql.partial_cmp(&SqlValue::Blob(data)), Some(Ordering::Equal));
        }

        #[test]
        fn vector_roundtrip(data in proptest::collection::vec(
            any::<f32>().prop_filter("no NaN", |f| !f.is_nan()),
            0..16,
        )) {
            let sql: SqlValue = data.clone().into();
            let back = Vec::<f32>::try_from(sql).unwrap();
            prop_assert_eq!(back, data);
        }
    }
}