// lnmp_embedding/delta.rs

use crate::vector::{EmbeddingType, Vector};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::io::Cursor;

6/// Represents a single change in a vector
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct DeltaChange {
9    pub index: u16,
10    pub delta: f32,
11}
12
13/// Represents a delta update for a vector
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct VectorDelta {
16    pub base_id: u16,
17    pub changes: Vec<DeltaChange>,
18}
19
20impl VectorDelta {
21    /// Create a new VectorDelta
22    pub fn new(base_id: u16, changes: Vec<DeltaChange>) -> Self {
23        Self { base_id, changes }
24    }
25
26    /// Compute delta between two vectors
27    /// Only supports F32 embeddings currently
28    pub fn from_vectors(old: &Vector, new: &Vector, base_id: u16) -> Result<Self, String> {
29        if old.dtype != new.dtype {
30            return Err("Type mismatch between vectors".to_string());
31        }
32        if old.dim != new.dim {
33            return Err("Dimension mismatch between vectors".to_string());
34        }
35        if old.dtype != EmbeddingType::F32 {
36            return Err("Delta only supported for F32 embeddings".to_string());
37        }
38
39        let old_values = old.as_f32()?;
40        let new_values = new.as_f32()?;
41
42        let mut changes = Vec::new();
43        for (i, (old_val, new_val)) in old_values.iter().zip(new_values.iter()).enumerate() {
44            if (old_val - new_val).abs() > f32::EPSILON {
45                changes.push(DeltaChange {
46                    index: i as u16,
47                    delta: new_val - old_val,
48                });
49            }
50        }
51
52        Ok(Self { base_id, changes })
53    }
54
55    /// Apply delta to a base vector to produce updated vector
56    pub fn apply(&self, base: &Vector) -> Result<Vector, String> {
57        if base.dtype != EmbeddingType::F32 {
58            return Err("Delta application only supported for F32 embeddings".to_string());
59        }
60
61        let mut values = base.as_f32()?;
62
63        // Apply each change
64        for change in &self.changes {
65            let idx = change.index as usize;
66            if idx >= values.len() {
67                return Err(format!("Invalid index {} in delta", idx));
68            }
69            values[idx] += change.delta;
70        }
71
72        Ok(Vector::from_f32(values))
73    }
74
75    /// Encode delta to binary format
76    /// Format: base_id (u16) | change_count (u16) | [(index: u16, delta: f32), ...]
77    pub fn encode(&self) -> Result<Vec<u8>, std::io::Error> {
78        let mut buf = Vec::new();
79
80        buf.write_u16::<LittleEndian>(self.base_id)?;
81        buf.write_u16::<LittleEndian>(self.changes.len() as u16)?;
82
83        for change in &self.changes {
84            buf.write_u16::<LittleEndian>(change.index)?;
85            buf.write_f32::<LittleEndian>(change.delta)?;
86        }
87
88        Ok(buf)
89    }
90
91    /// Decode delta from binary format
92    pub fn decode(data: &[u8]) -> Result<Self, std::io::Error> {
93        let mut rdr = Cursor::new(data);
94
95        let base_id = rdr.read_u16::<LittleEndian>()?;
96        let change_count = rdr.read_u16::<LittleEndian>()?;
97
98        let mut changes = Vec::with_capacity(change_count as usize);
99        for _ in 0..change_count {
100            let index = rdr.read_u16::<LittleEndian>()?;
101            let delta = rdr.read_f32::<LittleEndian>()?;
102            changes.push(DeltaChange { index, delta });
103        }
104
105        Ok(Self { base_id, changes })
106    }
107
108    /// Get the size of the encoded delta in bytes
109    pub fn encoded_size(&self) -> usize {
110        4 + (self.changes.len() * 6) // header (4) + changes (6 bytes each)
111    }
112
113    /// Calculate change ratio (percentage of values changed)
114    pub fn change_ratio(&self, total_dim: u16) -> f32 {
115        self.changes.len() as f32 / total_dim as f32
116    }
117
118    /// Check if delta is worth using vs full vector
119    /// Returns true if delta is smaller than full vector encoding
120    pub fn is_beneficial(&self, full_vector_size: usize) -> bool {
121        self.encoded_size() < full_vector_size
122    }
123}
124
/// Strategy for deciding between full and delta encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateStrategy {
    /// Always send the full vector.
    AlwaysFull,
    /// Always send a delta (when one is available).
    AlwaysDelta,
    /// Decide automatically from the change ratio: a delta is used when the
    /// fraction of changed components is strictly below `threshold / 100`.
    Adaptive {
        /// Threshold as a percentage (0-100); the default strategy uses 30,
        /// i.e. a change ratio of 0.3.
        threshold: u8,
    },
}

137impl Default for UpdateStrategy {
138    fn default() -> Self {
139        UpdateStrategy::Adaptive { threshold: 30 }
140    }
141}
142
143impl UpdateStrategy {
144    /// Decide whether to use delta based on the strategy
145    pub fn should_use_delta(&self, delta: &VectorDelta, vector_dim: u16) -> bool {
146        match self {
147            UpdateStrategy::AlwaysFull => false,
148            UpdateStrategy::AlwaysDelta => true,
149            UpdateStrategy::Adaptive { threshold } => {
150                let change_ratio = delta.change_ratio(vector_dim);
151                change_ratio < (*threshold as f32 / 100.0)
152            }
153        }
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `DeltaChange` in tests.
    fn change(index: u16, delta: f32) -> DeltaChange {
        DeltaChange { index, delta }
    }

    /// `new` stores the base id and the changes it is given.
    #[test]
    fn test_delta_creation() {
        let delta = VectorDelta::new(1001, vec![change(0, 0.1), change(5, -0.2)]);
        assert_eq!(delta.base_id, 1001);
        assert_eq!(delta.changes.len(), 2);
    }

    /// Only the components that differ are recorded, with `new - old` as delta.
    #[test]
    fn test_delta_from_vectors() {
        let old = Vector::from_f32(vec![1.0, 2.0, 3.0, 4.0]);
        let new = Vector::from_f32(vec![1.0, 2.5, 3.0, 3.5]);

        let delta = VectorDelta::from_vectors(&old, &new, 100).unwrap();

        assert_eq!(delta.base_id, 100);
        assert_eq!(delta.changes.len(), 2);
        assert_eq!(delta.changes[0].index, 1);
        assert!((delta.changes[0].delta - 0.5).abs() < f32::EPSILON);
        assert_eq!(delta.changes[1].index, 3);
        assert!((delta.changes[1].delta - (-0.5)).abs() < f32::EPSILON);
    }

    /// Applying a delta adds each change to the matching component only.
    #[test]
    fn test_delta_apply() {
        let base = Vector::from_f32(vec![1.0, 2.0, 3.0, 4.0]);
        let delta = VectorDelta::new(100, vec![change(1, 0.5), change(3, -0.5)]);

        let result = delta.apply(&base).unwrap();
        let values = result.as_f32().unwrap();

        assert_eq!(values.len(), 4);
        let expected = [1.0_f32, 2.5, 3.0, 3.5];
        for (actual, wanted) in values.iter().zip(expected.iter()) {
            assert!((actual - wanted).abs() < f32::EPSILON);
        }
    }

    /// The binary form has a 4-byte header plus 6 bytes per change and decodes
    /// back to the same delta.
    #[test]
    fn test_delta_encode_decode() {
        let delta = VectorDelta::new(999, vec![change(10, 0.123), change(20, -0.456)]);

        let encoded = delta.encode().unwrap();
        assert_eq!(encoded.len(), 4 + 2 * 6); // header + 2 changes

        let decoded = VectorDelta::decode(&encoded).unwrap();
        assert_eq!(decoded.base_id, delta.base_id);
        assert_eq!(decoded.changes.len(), delta.changes.len());
        assert_eq!(decoded.changes[0].index, delta.changes[0].index);
        assert!((decoded.changes[0].delta - delta.changes[0].delta).abs() < 0.0001);
        assert_eq!(decoded.changes[1].index, delta.changes[1].index);
        assert!((decoded.changes[1].delta - delta.changes[1].delta).abs() < 0.0001);
    }

    /// compute -> encode -> decode -> apply reproduces the new vector.
    #[test]
    fn test_delta_roundtrip() {
        let old = Vector::from_f32(vec![0.1; 1536]);
        let mut new_data = vec![0.1; 1536];
        // Nudge 15 of the 1536 components (about 1%): indices 0, 100, ..., 1400.
        for value in new_data.iter_mut().step_by(100).take(15) {
            *value += 0.01;
        }
        let new = Vector::from_f32(new_data);

        let delta = VectorDelta::from_vectors(&old, &new, 1).unwrap();
        let encoded = delta.encode().unwrap();
        let decoded = VectorDelta::decode(&encoded).unwrap();
        let reconstructed = decoded.apply(&old).unwrap();

        assert_eq!(new, reconstructed);
    }

    /// The default (adaptive, 30%) strategy picks the delta for a sparse
    /// update and the full vector once roughly a third of it changed.
    #[test]
    fn test_update_strategy() {
        let small_delta = VectorDelta::new(1, vec![change(0, 0.1)]);
        let large_delta = VectorDelta::new(1, (0..500).map(|i| change(i, 0.1)).collect());

        let strategy = UpdateStrategy::default();

        assert!(strategy.should_use_delta(&small_delta, 1536));
        assert!(!strategy.should_use_delta(&large_delta, 1536));
    }
}