Skip to main content

azoth_vector/
types.rs

//! Vector types and configuration used with sqlite-vector.
2
3use serde::{Deserialize, Serialize};
4
5/// Vector data types supported by sqlite-vector
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
7pub enum VectorType {
8    /// 32-bit floating point (default)
9    #[default]
10    Float32,
11    /// 16-bit floating point
12    Float16,
13    /// Brain floating point 16-bit
14    BFloat16,
15    /// 8-bit signed integer
16    Int8,
17    /// 8-bit unsigned integer
18    UInt8,
19    /// 1-bit binary vector
20    Bit1,
21}
22
23impl std::fmt::Display for VectorType {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            VectorType::Float32 => write!(f, "FLOAT32"),
27            VectorType::Float16 => write!(f, "FLOAT16"),
28            VectorType::BFloat16 => write!(f, "BFLOAT16"),
29            VectorType::Int8 => write!(f, "INT8"),
30            VectorType::UInt8 => write!(f, "UINT8"),
31            VectorType::Bit1 => write!(f, "1BIT"),
32        }
33    }
34}
35
36/// Distance metrics for similarity search
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38pub enum DistanceMetric {
39    /// Euclidean distance (L2 norm)
40    L2,
41    /// Squared Euclidean distance (faster, same ranking as L2)
42    SquaredL2,
43    /// Manhattan distance (L1 norm)
44    L1,
45    /// Cosine similarity (angle between vectors)
46    #[default]
47    Cosine,
48    /// Dot product (for normalized vectors)
49    DotProduct,
50    /// Hamming distance (for 1-bit vectors)
51    Hamming,
52}
53
54impl std::fmt::Display for DistanceMetric {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        match self {
57            DistanceMetric::L2 => write!(f, "l2"),
58            DistanceMetric::SquaredL2 => write!(f, "squared_l2"),
59            DistanceMetric::L1 => write!(f, "l1"),
60            DistanceMetric::Cosine => write!(f, "cosine"),
61            DistanceMetric::DotProduct => write!(f, "dot"),
62            DistanceMetric::Hamming => write!(f, "hamming"),
63        }
64    }
65}
66
67/// Configuration for a vector column
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct VectorConfig {
70    /// Vector data type
71    pub vector_type: VectorType,
72    /// Vector dimension (number of elements)
73    pub dimension: usize,
74    /// Distance metric for similarity search
75    pub distance_metric: DistanceMetric,
76}
77
78impl VectorConfig {
79    /// Create a new vector configuration
80    pub fn new(vector_type: VectorType, dimension: usize, distance_metric: DistanceMetric) -> Self {
81        Self {
82            vector_type,
83            dimension,
84            distance_metric,
85        }
86    }
87
88    /// Convert to sqlite-vector config string
89    pub(crate) fn to_config_string(&self) -> String {
90        format!("type={},dimension={}", self.vector_type, self.dimension)
91    }
92}
93
94impl Default for VectorConfig {
95    fn default() -> Self {
96        Self {
97            vector_type: VectorType::Float32,
98            dimension: 384, // Common dimension for sentence transformers
99            distance_metric: DistanceMetric::Cosine,
100        }
101    }
102}
103
104/// A vector embedding
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Vector {
107    /// Vector data as f32 values
108    pub data: Vec<f32>,
109    /// Vector dimension
110    pub dimension: usize,
111}
112
113impl Vector {
114    /// Create a new vector from f32 data
115    pub fn new(data: Vec<f32>) -> Self {
116        let dimension = data.len();
117        Self { data, dimension }
118    }
119
120    /// Serialize to BLOB for SQLite storage
121    ///
122    /// Stores as little-endian f32 values
123    pub fn to_blob(&self) -> Vec<u8> {
124        self.data.iter().flat_map(|f| f.to_le_bytes()).collect()
125    }
126
127    /// Deserialize from BLOB
128    pub fn from_blob(blob: &[u8]) -> Result<Self, String> {
129        if !blob.len().is_multiple_of(4) {
130            return Err(format!(
131                "Invalid blob length {} for f32 vector (must be multiple of 4)",
132                blob.len()
133            ));
134        }
135
136        let data = blob
137            .chunks_exact(4)
138            .map(|chunk| {
139                let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
140                f32::from_le_bytes(bytes)
141            })
142            .collect::<Vec<_>>();
143
144        Ok(Self::new(data))
145    }
146
147    /// Convert to JSON array (for queries)
148    pub fn to_json(&self) -> String {
149        serde_json::to_string(&self.data).unwrap()
150    }
151
152    /// Get vector dimension
153    pub fn dimension(&self) -> usize {
154        self.dimension
155    }
156
157    /// Get vector data as slice
158    pub fn as_slice(&self) -> &[f32] {
159        &self.data
160    }
161}
162
163impl From<Vec<f32>> for Vector {
164    fn from(data: Vec<f32>) -> Self {
165        Self::new(data)
166    }
167}
168
169/// Search result with distance/similarity score
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct SearchResult {
172    /// Row ID of the result
173    pub rowid: i64,
174    /// Distance/similarity score
175    ///
176    /// For L2/L1/Hamming: lower is more similar
177    /// For Cosine/Dot: higher is more similar (but sqlite-vector returns distance)
178    pub distance: f32,
179}
180
181impl SearchResult {
182    /// Create a new search result
183    pub fn new(rowid: i64, distance: f32) -> Self {
184        Self { rowid, distance }
185    }
186
187    /// Convert distance to similarity score for cosine/dot metrics
188    ///
189    /// sqlite-vector returns distance, but cosine similarity is more intuitive
190    /// as a score where 1.0 = identical, 0.0 = orthogonal
191    pub fn similarity(&self, metric: DistanceMetric) -> f32 {
192        match metric {
193            DistanceMetric::Cosine | DistanceMetric::DotProduct => 1.0 - self.distance,
194            _ => self.distance,
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_blob_roundtrip() {
        let source = Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
        let restored = Vector::from_blob(&source.to_blob()).unwrap();

        assert_eq!(source.dimension, restored.dimension);
        let all_close = source
            .data
            .iter()
            .zip(&restored.data)
            .all(|(expected, actual)| (expected - actual).abs() < 1e-6);
        assert!(all_close);
    }

    #[test]
    fn test_vector_json() {
        let json = Vector::new(vec![1.0, 2.0, 3.0]).to_json();
        assert_eq!(json, "[1.0,2.0,3.0]");
    }

    #[test]
    fn test_config_string() {
        let cfg = VectorConfig::new(VectorType::Float32, 384, DistanceMetric::Cosine);
        assert_eq!(cfg.to_config_string(), "type=FLOAT32,dimension=384");
    }

    #[test]
    fn test_invalid_blob() {
        // Three bytes cannot hold a whole number of 4-byte f32 values.
        assert!(Vector::from_blob(&[0u8, 1, 2]).is_err());
    }
}