// citadel_vector/vendored/prism/point.rs

/// Dense `f32` vectors stored back-to-back, plus `u32` attribute metadata per point.
///
/// Every vector occupies exactly `dim` consecutive entries of one shared `Vec<f32>`.
/// Attribute values are kept as `k` separate columns (one `Vec<u32>` per attribute
/// dimension), each indexed by point id.
pub struct PointStore {
    /// Flattened vector data; point `i` occupies `vectors[i * dim..(i + 1) * dim]`.
    pub vectors: Vec<f32>,
    /// Components per vector, i.e. the stride into `vectors`.
    pub dim: usize,
    /// How many points are currently stored.
    pub len: usize,
    /// Attribute columns: `attrs[j][i]` is the value of attribute `j` for point `i`.
    pub attrs: Vec<Vec<u32>>,
}

16impl PointStore {
17    pub fn new(dim: usize, k: usize) -> Self {
18        Self {
19            vectors: Vec::new(),
20            dim,
21            len: 0,
22            attrs: vec![Vec::new(); k],
23        }
24    }
25
26    /// Build from pre-allocated vectors and attributes.
27    pub fn from_parts(vectors: Vec<f32>, dim: usize, attrs: Vec<Vec<u32>>) -> Self {
28        let len = vectors.len() / dim;
29        debug_assert_eq!(vectors.len(), len * dim);
30        for a in &attrs {
31            debug_assert_eq!(a.len(), len);
32        }
33        Self {
34            vectors,
35            dim,
36            len,
37            attrs,
38        }
39    }
40
41    /// Number of attribute dimensions.
42    pub fn k(&self) -> usize {
43        self.attrs.len()
44    }
45
46    /// Get the vector slice for point `id`.
47    #[inline]
48    pub fn vector(&self, id: u32) -> &[f32] {
49        let start = id as usize * self.dim;
50        &self.vectors[start..start + self.dim]
51    }
52
53    /// Get attribute value for point `id` on dimension `j`.
54    #[inline]
55    pub fn attr(&self, id: u32, j: usize) -> u32 {
56        self.attrs[j][id as usize]
57    }
58
59    /// Append a single point. Returns its id.
60    pub fn push(&mut self, vector: &[f32], attr_values: &[u32]) -> u32 {
61        debug_assert_eq!(vector.len(), self.dim);
62        debug_assert_eq!(attr_values.len(), self.attrs.len());
63        let id = self.len as u32;
64        self.vectors.extend_from_slice(vector);
65        for (j, &val) in attr_values.iter().enumerate() {
66            self.attrs[j].push(val);
67        }
68        self.len += 1;
69        id
70    }
71
72    /// Number of distinct values for attribute dimension `j`.
73    pub fn cardinality(&self, j: usize) -> usize {
74        let mut seen = std::collections::HashSet::new();
75        for &v in &self.attrs[j] {
76            seen.insert(v);
77        }
78        seen.len()
79    }
80}
81
82impl Drop for PointStore {
83    fn drop(&mut self) {
84        // These vectors may be DECRYPTED plaintext (citadel-mem's sealed ANN cache);
85        // zero them so they never outlive the region key after crypto-erasure.
86        use zeroize::Zeroize;
87        self.vectors.zeroize();
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_point_store_basic() {
        // Each row: (vector, attribute values); a row's position is its expected id.
        let rows: [([f32; 3], [u32; 2]); 2] = [
            ([1.0, 2.0, 3.0], [0, 1]),
            ([4.0, 5.0, 6.0], [1, 0]),
        ];

        let mut store = PointStore::new(3, 2);
        for (want_id, (coords, values)) in rows.iter().enumerate() {
            assert_eq!(store.push(coords, values), want_id as u32);
        }
        assert_eq!(store.len, 2);

        for (idx, (coords, values)) in rows.iter().enumerate() {
            let id = idx as u32;
            assert_eq!(store.vector(id), coords);
            for (j, &want) in values.iter().enumerate() {
                assert_eq!(store.attr(id, j), want);
            }
        }
    }

    #[test]
    fn test_from_parts() {
        let store = PointStore::from_parts(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            3,
            vec![vec![0, 1], vec![1, 0]],
        );
        assert_eq!((store.len, store.k(), store.cardinality(0)), (2, 2, 2));
    }
}