Skip to main content

nodedb_codec/
spherical.rs

1//! Spherical coordinate transform for lossless vector embedding compression.
2//!
3//! L2-normalized embeddings (standard for cosine similarity search) live on
4//! a hypersphere. Converting from Cartesian to spherical coordinates collapses
5//! IEEE 754 exponent bits to predictable values and regularizes high-order
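//! mantissa bits. The transformed f32 bytes are then compressed with lz4_flex;
//! the reported ratio on embeddings is ~1.5x (about 25% better than other
//! lossless float compressors), though the unit tests only check that the
//! output does not expand by more than 10%.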
8//!
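//! The spherical path is near-lossless, not bit-exact: decoding rebuilds each
//! component through f32 `sin`/`cos`, so values carry trigonometric round-off.
//! Use `encode_raw` (transform type 0) when a bit-exact round trip is required.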
11//!
12//! Wire format:
13//! ```text
14//! [4 bytes] dimension count (LE u32)
15//! [4 bytes] vector count (LE u32)
16//! [1 byte]  transform type (0=cartesian/raw, 1=spherical)
17//! [N bytes] compressed data (lz4 over transformed f32 bytes)
18//! ```
19
20use crate::error::CodecError;
21
22/// Encode f32 embeddings using spherical coordinate transformation + lz4.
23///
24/// Input: flat array of f32 values, `vectors * dims` total elements.
25/// Each consecutive `dims` values form one embedding vector.
26pub fn encode(data: &[f32], dims: usize, vectors: usize) -> Result<Vec<u8>, CodecError> {
27    if data.is_empty() || dims == 0 {
28        return Ok(build_header(dims as u32, 0, 0, &[]));
29    }
30
31    if data.len() != vectors * dims {
32        return Err(CodecError::Corrupt {
33            detail: format!(
34                "expected {} elements ({vectors} vectors × {dims} dims), got {}",
35                vectors * dims,
36                data.len()
37            ),
38        });
39    }
40
41    // Transform each vector from Cartesian to spherical coordinates.
42    let mut transformed = Vec::with_capacity(data.len());
43    for v in 0..vectors {
44        let start = v * dims;
45        let vec_data = &data[start..start + dims];
46        let spherical = cartesian_to_spherical(vec_data);
47        transformed.extend_from_slice(&spherical);
48    }
49
50    // Convert f32 to bytes and compress with lz4.
51    let raw_bytes: Vec<u8> = transformed.iter().flat_map(|f| f.to_le_bytes()).collect();
52    let compressed = crate::lz4::encode(&raw_bytes);
53
54    Ok(build_header(
55        dims as u32,
56        vectors as u32,
57        1, // spherical transform
58        &compressed,
59    ))
60}
61
62/// Encode f32 embeddings without transformation (raw + lz4).
63///
64/// Fallback for non-normalized embeddings where spherical transform
65/// doesn't help.
66pub fn encode_raw(data: &[f32], dims: usize, vectors: usize) -> Result<Vec<u8>, CodecError> {
67    let raw_bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
68    let compressed = crate::lz4::encode(&raw_bytes);
69
70    Ok(build_header(dims as u32, vectors as u32, 0, &compressed))
71}
72
73/// Decode spherical-compressed embeddings back to f32 Cartesian coordinates.
74pub fn decode(data: &[u8]) -> Result<(Vec<f32>, usize, usize), CodecError> {
75    if data.len() < 9 {
76        return Err(CodecError::Truncated {
77            expected: 9,
78            actual: data.len(),
79        });
80    }
81
82    let dims = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
83    let vectors = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
84    let transform = data[8];
85
86    if vectors == 0 || dims == 0 {
87        return Ok((Vec::new(), dims, 0));
88    }
89
90    let compressed = &data[9..];
91    let raw_bytes = crate::lz4::decode(compressed).map_err(|e| CodecError::DecompressFailed {
92        detail: format!("spherical lz4: {e}"),
93    })?;
94
95    let expected_bytes = vectors * dims * 4;
96    if raw_bytes.len() != expected_bytes {
97        return Err(CodecError::Corrupt {
98            detail: format!("expected {expected_bytes} bytes, got {}", raw_bytes.len()),
99        });
100    }
101
102    let floats: Vec<f32> = raw_bytes
103        .chunks_exact(4)
104        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
105        .collect();
106
107    if transform == 0 {
108        // Raw (no transform) — return as-is.
109        return Ok((floats, dims, vectors));
110    }
111
112    // Inverse spherical → Cartesian.
113    let mut result = Vec::with_capacity(floats.len());
114    for v in 0..vectors {
115        let start = v * dims;
116        let spherical = &floats[start..start + dims];
117        let cartesian = spherical_to_cartesian(spherical);
118        result.extend_from_slice(&cartesian);
119    }
120
121    Ok((result, dims, vectors))
122}
123
/// Check if embeddings are L2-normalized (suitable for spherical transform).
///
/// Returns the fraction of complete `dims`-length vectors in `data` whose L2
/// norm is within [0.99, 1.01]. Trailing elements that do not fill a whole
/// vector are ignored. Returns 0.0 when `dims == 0` or when `data` does not
/// contain at least one complete vector.
pub fn normalization_ratio(data: &[f32], dims: usize) -> f64 {
    if dims == 0 {
        return 0.0;
    }
    let vectors = data.len() / dims;
    if vectors == 0 {
        // Covers empty data and 0 < data.len() < dims; without this guard the
        // ratio below would be 0 / 0 = NaN.
        return 0.0;
    }

    let normalized = data
        .chunks_exact(dims)
        .filter(|vector| {
            let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
            (0.99..=1.01).contains(&norm)
        })
        .count();

    normalized as f64 / vectors as f64
}
148
149// ---------------------------------------------------------------------------
150// Spherical coordinate transform
151// ---------------------------------------------------------------------------
152
/// Convert a single vector from Cartesian to hyperspherical coordinates.
///
/// For an n-dimensional vector (n ≥ 2), produces n values:
/// - `r` — the Euclidean norm (1.0 for L2-normalized vectors)
/// - `n-1` angles θ₁..θₙ₋₁ such that
///   xᵢ = r·sin θ₁···sin θᵢ₋₁·cos θᵢ for i < n, and xₙ = r·sin θ₁···sin θₙ₋₁.
///
/// θ₁..θₙ₋₂ lie in [0, π]; θₙ₋₁ = atan2(xₙ, xₙ₋₁) lies in [-π, π].
/// If the norm of the components from position i onward is below 1e-30,
/// angle i is emitted as 0.0. A length-1 input is returned unchanged; an
/// empty input yields an empty vector.
///
/// Each θᵢ (i < n-1) is computed as `atan2(‖x[i+1..]‖, x[i])`, which equals
/// `acos(x[i] / ‖x[i..]‖)` mathematically but stays well-conditioned when the
/// ratio is near ±1, where acos would round small trailing components to an
/// angle of exactly 0 and lose them. Tail norms come from one back-to-front
/// suffix sum, so the transform is O(n) rather than O(n²).
fn cartesian_to_spherical(cart: &[f32]) -> Vec<f32> {
    let n = cart.len();
    if n == 0 {
        return Vec::new();
    }
    if n == 1 {
        return vec![cart[0]];
    }

    // tail_sq[i] = Σ_{j ≥ i} cart[j]²; tail_sq[n] = 0.
    let mut tail_sq = vec![0.0f32; n + 1];
    for i in (0..n).rev() {
        tail_sq[i] = tail_sq[i + 1] + cart[i] * cart[i];
    }

    let mut spherical = Vec::with_capacity(n);

    // Radius.
    spherical.push(tail_sq[0].sqrt());

    // Angular coordinates.
    for i in 0..n - 1 {
        if tail_sq[i].sqrt() < 1e-30 {
            // Remaining components are all (numerically) zero.
            spherical.push(0.0);
        } else if i < n - 2 {
            spherical.push(tail_sq[i + 1].sqrt().atan2(cart[i]));
        } else {
            // Last angle: signed atan2 so the final pair's quadrant is kept.
            spherical.push(cart[n - 1].atan2(cart[n - 2]));
        }
    }

    spherical
}
191
/// Convert hyperspherical coordinates `[r, θ₁, …, θₙ₋₁]` back to Cartesian.
///
/// Computes xᵢ = r·sin θ₁···sin θᵢ₋₁·cos θᵢ for i < n and
/// xₙ = r·sin θ₁···sin θₙ₋₁. The final pair is therefore
/// (ρ·cos θₙ₋₁, ρ·sin θₙ₋₁), the inverse of the forward transform's
/// `θₙ₋₁ = atan2(xₙ, xₙ₋₁)`. A length-1 input is returned unchanged; an empty
/// input yields an empty vector.
///
/// Runs in O(n) by carrying the running product r·sin θ₁···sin θᵢ₋₁ forward
/// instead of recomputing it for every coordinate.
fn spherical_to_cartesian(sph: &[f32]) -> Vec<f32> {
    let n = sph.len();
    if n == 0 {
        return Vec::new();
    }
    if n == 1 {
        return vec![sph[0]];
    }

    let r = sph[0];
    let angles = &sph[1..];

    let mut cart = Vec::with_capacity(n);

    // scale = r · sin θ₁ ··· sin θᵢ₋₁ when emitting coordinate i.
    let mut scale = r;
    for &theta in angles {
        cart.push(scale * theta.cos());
        scale *= theta.sin();
    }
    // Last coordinate: the full product of sines.
    cart.push(scale);

    cart
}
232
/// Assemble a frame: `dims` (LE u32), `vectors` (LE u32), the `transform`
/// byte, then the compressed payload copied verbatim.
fn build_header(dims: u32, vectors: u32, transform: u8, compressed: &[u8]) -> Vec<u8> {
    let dims_le = dims.to_le_bytes();
    let vectors_le = vectors.to_le_bytes();
    [&dims_le[..], &vectors_le[..], &[transform][..], compressed].concat()
}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: `n` vectors of `dims` components taken from a
    /// sine sequence, each rescaled to unit L2 norm (an all-zero vector is
    /// left unscaled).
    fn make_normalized_vectors(n: usize, dims: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(n * dims);
        for row in 0..n {
            let first = row * dims;
            let raw: Vec<f32> = (first..first + dims)
                .map(|k| (k as f32 * 0.1).sin())
                .collect();
            let len = raw.iter().map(|x| x * x).sum::<f32>().sqrt();
            if len > 0.0 {
                out.extend(raw.iter().map(|x| x / len));
            } else {
                out.extend_from_slice(&raw);
            }
        }
        out
    }

    #[test]
    fn empty_roundtrip() {
        let frame = encode(&[], 32, 0).unwrap();
        let (values, dims, count) = decode(&frame).unwrap();
        assert!(values.is_empty());
        assert_eq!(count, 0);
        assert_eq!(dims, 32);
    }

    #[test]
    fn normalized_roundtrip() {
        let original = make_normalized_vectors(100, 32);
        let frame = encode(&original, 32, 100).unwrap();
        let (restored, dims, count) = decode(&frame).unwrap();

        assert_eq!(dims, 32);
        assert_eq!(count, 100);
        assert_eq!(restored.len(), original.len());

        // Largest absolute per-component deviation between input and output.
        // NOTE(review): a 0.1 tolerance is far looser than f32 trig round-off
        // should need for 32-dim unit vectors, and is loose enough to hide a
        // misplaced coordinate; consider tightening it.
        let max_error = original
            .iter()
            .zip(&restored)
            .fold(0.0f32, |acc, (a, b)| acc.max((a - b).abs()));
        assert!(
            max_error < 0.1,
            "max reconstruction error {max_error} exceeds threshold"
        );
    }

    #[test]
    fn raw_roundtrip() {
        let original: Vec<f32> = (0..320).map(|i| i as f32 * 0.01).collect();
        let frame = encode_raw(&original, 32, 10).unwrap();
        let (restored, dims, count) = decode(&frame).unwrap();
        assert_eq!(dims, 32);
        assert_eq!(count, 10);
        // The raw path must be bit-exact.
        assert!(original
            .iter()
            .zip(&restored)
            .all(|(a, b)| a.to_bits() == b.to_bits()));
    }

    #[test]
    fn compression_ratio() {
        let original = make_normalized_vectors(1000, 128);
        let frame = encode(&original, 128, 1000).unwrap();
        let ratio = (original.len() * 4) as f64 / frame.len() as f64;
        // Only guards against significant expansion (> 10%).
        assert!(ratio > 0.9, "should not expand >10%, got {ratio:.2}x");
    }

    #[test]
    fn normalization_check() {
        let unit = make_normalized_vectors(100, 32);
        assert!(normalization_ratio(&unit, 32) > 0.95);

        let ramp: Vec<f32> = (0..3200).map(|i| i as f32).collect();
        assert!(normalization_ratio(&ramp, 32) < 0.1);
    }

    #[test]
    fn truncated_error() {
        for short in [&[][..], &[0u8; 5][..]] {
            assert!(decode(short).is_err());
        }
    }
}
329    fn truncated_error() {
330        assert!(decode(&[]).is_err());
331        assert!(decode(&[0; 5]).is_err());
332    }
333}