1use crate::vector::{EmbeddingType, Vector};
2use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
3use serde::{Deserialize, Serialize};
4use std::io::Cursor;
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct DeltaChange {
9 pub index: u16,
10 pub delta: f32,
11}
12
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct VectorDelta {
16 pub base_id: u16,
17 pub changes: Vec<DeltaChange>,
18}
19
20impl VectorDelta {
21 pub fn new(base_id: u16, changes: Vec<DeltaChange>) -> Self {
23 Self { base_id, changes }
24 }
25
26 pub fn from_vectors(old: &Vector, new: &Vector, base_id: u16) -> Result<Self, String> {
29 if old.dtype != new.dtype {
30 return Err("Type mismatch between vectors".to_string());
31 }
32 if old.dim != new.dim {
33 return Err("Dimension mismatch between vectors".to_string());
34 }
35 if old.dtype != EmbeddingType::F32 {
36 return Err("Delta only supported for F32 embeddings".to_string());
37 }
38
39 let old_values = old.as_f32()?;
40 let new_values = new.as_f32()?;
41
42 let mut changes = Vec::new();
43 for (i, (old_val, new_val)) in old_values.iter().zip(new_values.iter()).enumerate() {
44 if (old_val - new_val).abs() > f32::EPSILON {
45 changes.push(DeltaChange {
46 index: i as u16,
47 delta: new_val - old_val,
48 });
49 }
50 }
51
52 Ok(Self { base_id, changes })
53 }
54
55 pub fn apply(&self, base: &Vector) -> Result<Vector, String> {
57 if base.dtype != EmbeddingType::F32 {
58 return Err("Delta application only supported for F32 embeddings".to_string());
59 }
60
61 let mut values = base.as_f32()?;
62
63 for change in &self.changes {
65 let idx = change.index as usize;
66 if idx >= values.len() {
67 return Err(format!("Invalid index {} in delta", idx));
68 }
69 values[idx] += change.delta;
70 }
71
72 Ok(Vector::from_f32(values))
73 }
74
75 pub fn encode(&self) -> Result<Vec<u8>, std::io::Error> {
78 let mut buf = Vec::new();
79
80 buf.write_u16::<LittleEndian>(self.base_id)?;
81 buf.write_u16::<LittleEndian>(self.changes.len() as u16)?;
82
83 for change in &self.changes {
84 buf.write_u16::<LittleEndian>(change.index)?;
85 buf.write_f32::<LittleEndian>(change.delta)?;
86 }
87
88 Ok(buf)
89 }
90
91 pub fn decode(data: &[u8]) -> Result<Self, std::io::Error> {
93 let mut rdr = Cursor::new(data);
94
95 let base_id = rdr.read_u16::<LittleEndian>()?;
96 let change_count = rdr.read_u16::<LittleEndian>()?;
97
98 let mut changes = Vec::with_capacity(change_count as usize);
99 for _ in 0..change_count {
100 let index = rdr.read_u16::<LittleEndian>()?;
101 let delta = rdr.read_f32::<LittleEndian>()?;
102 changes.push(DeltaChange { index, delta });
103 }
104
105 Ok(Self { base_id, changes })
106 }
107
108 pub fn encoded_size(&self) -> usize {
110 4 + (self.changes.len() * 6) }
112
113 pub fn change_ratio(&self, total_dim: u16) -> f32 {
115 self.changes.len() as f32 / total_dim as f32
116 }
117
118 pub fn is_beneficial(&self, full_vector_size: usize) -> bool {
121 self.encoded_size() < full_vector_size
122 }
123}
124
/// Policy for choosing between sending a full vector and sending a delta.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UpdateStrategy {
    /// Never use a delta.
    AlwaysFull,
    /// Always use a delta.
    AlwaysDelta,
    /// Use a delta only while the share of changed components stays strictly
    /// below `threshold` percent of the vector dimension.
    Adaptive {
        /// Percentage of changed components, compared as `threshold / 100`.
        threshold: u8,
    },
}
136
137impl Default for UpdateStrategy {
138 fn default() -> Self {
139 UpdateStrategy::Adaptive { threshold: 30 }
140 }
141}
142
143impl UpdateStrategy {
144 pub fn should_use_delta(&self, delta: &VectorDelta, vector_dim: u16) -> bool {
146 match self {
147 UpdateStrategy::AlwaysFull => false,
148 UpdateStrategy::AlwaysDelta => true,
149 UpdateStrategy::Adaptive { threshold } => {
150 let change_ratio = delta.change_ratio(vector_dim);
151 change_ratio < (*threshold as f32 / 100.0)
152 }
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a single component change.
    fn change(index: u16, delta: f32) -> DeltaChange {
        DeltaChange { index, delta }
    }

    /// Returns `true` when `actual` is strictly within `tolerance` of `expected`.
    fn close(actual: f32, expected: f32, tolerance: f32) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn test_delta_creation() {
        let delta = VectorDelta::new(1001, vec![change(0, 0.1), change(5, -0.2)]);

        assert_eq!(delta.base_id, 1001);
        assert_eq!(delta.changes.len(), 2);
    }

    #[test]
    fn test_delta_from_vectors() {
        let old = Vector::from_f32(vec![1.0, 2.0, 3.0, 4.0]);
        let new = Vector::from_f32(vec![1.0, 2.5, 3.0, 3.5]);

        let delta = VectorDelta::from_vectors(&old, &new, 100).unwrap();

        assert_eq!(delta.base_id, 100);
        assert_eq!(delta.changes.len(), 2);

        // Only positions 1 and 3 differ between the two vectors.
        let first = &delta.changes[0];
        assert_eq!(first.index, 1);
        assert!(close(first.delta, 0.5, f32::EPSILON));

        let second = &delta.changes[1];
        assert_eq!(second.index, 3);
        assert!(close(second.delta, -0.5, f32::EPSILON));
    }

    #[test]
    fn test_delta_apply() {
        let base = Vector::from_f32(vec![1.0, 2.0, 3.0, 4.0]);
        let delta = VectorDelta::new(100, vec![change(1, 0.5), change(3, -0.5)]);

        let result = delta.apply(&base).unwrap();
        let values = result.as_f32().unwrap();

        let expected = [1.0_f32, 2.5, 3.0, 3.5];
        assert_eq!(values.len(), 4);
        for (got, want) in values.iter().zip(expected.iter()) {
            assert!(close(*got, *want, f32::EPSILON));
        }
    }

    #[test]
    fn test_delta_encode_decode() {
        let delta = VectorDelta::new(999, vec![change(10, 0.123), change(20, -0.456)]);

        let encoded = delta.encode().unwrap();
        // 4-byte header plus two 6-byte change records.
        assert_eq!(encoded.len(), 4 + 2 * 6);

        let decoded = VectorDelta::decode(&encoded).unwrap();
        assert_eq!(decoded.base_id, delta.base_id);
        assert_eq!(decoded.changes.len(), delta.changes.len());
        assert_eq!(decoded.changes[0].index, delta.changes[0].index);
        assert!(close(decoded.changes[0].delta, delta.changes[0].delta, 0.0001));
    }

    #[test]
    fn test_delta_roundtrip() {
        let old = Vector::from_f32(vec![0.1; 1536]);

        // Nudge 15 components, one every 100 positions starting at index 0.
        let mut new_data = vec![0.1; 1536];
        for slot in new_data.iter_mut().step_by(100).take(15) {
            *slot += 0.01;
        }
        let new = Vector::from_f32(new_data);

        let encoded = VectorDelta::from_vectors(&old, &new, 1)
            .unwrap()
            .encode()
            .unwrap();
        let reconstructed = VectorDelta::decode(&encoded)
            .unwrap()
            .apply(&old)
            .unwrap();

        assert_eq!(new, reconstructed);
    }

    #[test]
    fn test_update_strategy() {
        let small_delta = VectorDelta::new(1, vec![change(0, 0.1)]);
        let large_delta = VectorDelta::new(1, (0..500).map(|i| change(i, 0.1)).collect());

        let strategy = UpdateStrategy::default();

        // 1 of 1536 is far below the default threshold; 500 of 1536 is above it.
        assert!(strategy.should_use_delta(&small_delta, 1536));
        assert!(!strategy.should_use_delta(&large_delta, 1536));
    }
}