1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
7pub enum VectorType {
8 #[default]
10 Float32,
11 Float16,
13 BFloat16,
15 Int8,
17 UInt8,
19 Bit1,
21}
22
23impl std::fmt::Display for VectorType {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 VectorType::Float32 => write!(f, "FLOAT32"),
27 VectorType::Float16 => write!(f, "FLOAT16"),
28 VectorType::BFloat16 => write!(f, "BFLOAT16"),
29 VectorType::Int8 => write!(f, "INT8"),
30 VectorType::UInt8 => write!(f, "UINT8"),
31 VectorType::Bit1 => write!(f, "1BIT"),
32 }
33 }
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38pub enum DistanceMetric {
39 L2,
41 SquaredL2,
43 L1,
45 #[default]
47 Cosine,
48 DotProduct,
50 Hamming,
52}
53
54impl std::fmt::Display for DistanceMetric {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 DistanceMetric::L2 => write!(f, "l2"),
58 DistanceMetric::SquaredL2 => write!(f, "squared_l2"),
59 DistanceMetric::L1 => write!(f, "l1"),
60 DistanceMetric::Cosine => write!(f, "cosine"),
61 DistanceMetric::DotProduct => write!(f, "dot"),
62 DistanceMetric::Hamming => write!(f, "hamming"),
63 }
64 }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct VectorConfig {
70 pub vector_type: VectorType,
72 pub dimension: usize,
74 pub distance_metric: DistanceMetric,
76}
77
78impl VectorConfig {
79 pub fn new(vector_type: VectorType, dimension: usize, distance_metric: DistanceMetric) -> Self {
81 Self {
82 vector_type,
83 dimension,
84 distance_metric,
85 }
86 }
87
88 pub(crate) fn to_config_string(&self) -> String {
90 format!("type={},dimension={}", self.vector_type, self.dimension)
91 }
92}
93
94impl Default for VectorConfig {
95 fn default() -> Self {
96 Self {
97 vector_type: VectorType::Float32,
98 dimension: 384, distance_metric: DistanceMetric::Cosine,
100 }
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Vector {
107 pub data: Vec<f32>,
109 pub dimension: usize,
111}
112
113impl Vector {
114 pub fn new(data: Vec<f32>) -> Self {
116 let dimension = data.len();
117 Self { data, dimension }
118 }
119
120 pub fn to_blob(&self) -> Vec<u8> {
124 self.data.iter().flat_map(|f| f.to_le_bytes()).collect()
125 }
126
127 pub fn from_blob(blob: &[u8]) -> Result<Self, String> {
129 if !blob.len().is_multiple_of(4) {
130 return Err(format!(
131 "Invalid blob length {} for f32 vector (must be multiple of 4)",
132 blob.len()
133 ));
134 }
135
136 let data = blob
137 .chunks_exact(4)
138 .map(|chunk| {
139 let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
140 f32::from_le_bytes(bytes)
141 })
142 .collect::<Vec<_>>();
143
144 Ok(Self::new(data))
145 }
146
147 pub fn to_json(&self) -> String {
149 serde_json::to_string(&self.data).unwrap()
150 }
151
152 pub fn dimension(&self) -> usize {
154 self.dimension
155 }
156
157 pub fn as_slice(&self) -> &[f32] {
159 &self.data
160 }
161}
162
163impl From<Vec<f32>> for Vector {
164 fn from(data: Vec<f32>) -> Self {
165 Self::new(data)
166 }
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct SearchResult {
172 pub rowid: i64,
174 pub distance: f32,
179}
180
181impl SearchResult {
182 pub fn new(rowid: i64, distance: f32) -> Self {
184 Self { rowid, distance }
185 }
186
187 pub fn similarity(&self, metric: DistanceMetric) -> f32 {
192 match metric {
193 DistanceMetric::Cosine | DistanceMetric::DotProduct => 1.0 - self.distance,
194 _ => self.distance,
195 }
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding to a blob and decoding it again keeps the dimension and the
    /// component values (within a small tolerance).
    #[test]
    fn test_vector_blob_roundtrip() {
        let source = Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
        let restored = Vector::from_blob(&source.to_blob()).unwrap();

        assert_eq!(source.dimension, restored.dimension);
        let all_close = source
            .data
            .iter()
            .zip(&restored.data)
            .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-6);
        assert!(all_close);
    }

    /// JSON output is a compact array of the components.
    #[test]
    fn test_vector_json() {
        let rendered = Vector::new(vec![1.0, 2.0, 3.0]).to_json();
        assert_eq!(rendered, "[1.0,2.0,3.0]");
    }

    /// The config string contains the type and dimension only.
    #[test]
    fn test_config_string() {
        let rendered =
            VectorConfig::new(VectorType::Float32, 384, DistanceMetric::Cosine).to_config_string();
        assert_eq!(rendered, "type=FLOAT32,dimension=384");
    }

    /// A blob whose length is not a multiple of 4 is rejected.
    #[test]
    fn test_invalid_blob() {
        let truncated = [0u8, 1, 2];
        assert!(Vector::from_blob(&truncated).is_err());
    }
}