Skip to main content

clark_hash/
config.rs

1use crate::error::{Result, SQuaJLError};
2
/// Which similarity the codec is tuned to approximate at scoring time.
///
/// Defaults to [`SimilarityMetric::Cosine`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SimilarityMetric {
    /// Cosine similarity, approximated from direction alone.
    ///
    /// Only the quantized sketch of the normalized input direction is kept;
    /// this is usually the right fit for modern semantic embeddings.
    #[default]
    Cosine,
    /// Raw inner product, approximated from direction plus magnitude.
    ///
    /// Next to the normalized direction sketch, a tiny quantized norm channel
    /// is stored so the final score can recover scale information.
    Dot,
}
19
20/// Configuration for the stateless sparse-JL quantizer.
21///
22/// The defaults are chosen to be a practical starting point for 384-dimensional
23/// sentence embeddings such as `all-MiniLM-L6-v2`.
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, PartialEq)]
26pub struct SQuaJLConfig {
27    /// Input embedding dimension.
28    pub input_dim: usize,
29    /// Output sketch dimension.
30    ///
31    /// Larger values usually improve recall at the cost of more memory.
32    pub sketch_dim: usize,
33    /// Number of bits per quantized coordinate.
34    ///
35    /// Supported range: `1..=8`.
36    pub bits: u8,
37    /// Number of non-zero sketch updates per input coordinate.
38    ///
39    /// Larger values reduce projection noise but cost more CPU at encode time.
40    pub hashes_per_input: u8,
41    /// Symmetric clip range for the scaled sketch.
42    ///
43    /// Coordinates are clipped to `[-clip, clip]` before scalar quantization.
44    pub clip: f32,
45    /// Global seed used to derive sparse bucket locations and signs.
46    pub seed: u64,
47    /// Similarity objective for scoring.
48    pub metric: SimilarityMetric,
49    /// Lower bound for `log2(norm)` when the norm channel is enabled.
50    pub norm_log2_min: f32,
51    /// Upper bound for `log2(norm)` when the norm channel is enabled.
52    pub norm_log2_max: f32,
53}
54
55impl Default for SQuaJLConfig {
56    fn default() -> Self {
57        Self {
58            input_dim: 384,
59            sketch_dim: 96,
60            bits: 4,
61            hashes_per_input: 4,
62            clip: 3.0,
63            seed: 0x5EED_CAFE_1234_5678,
64            metric: SimilarityMetric::Cosine,
65            norm_log2_min: -16.0,
66            norm_log2_max: 16.0,
67        }
68    }
69}
70
71impl SQuaJLConfig {
72    /// Creates a configuration with the provided input dimension and sensible defaults
73    /// for the remaining fields.
74    pub fn new(input_dim: usize) -> Self {
75        Self {
76            input_dim,
77            ..Self::default()
78        }
79    }
80
81    /// Sets the output sketch dimension.
82    pub fn with_sketch_dim(mut self, sketch_dim: usize) -> Self {
83        self.sketch_dim = sketch_dim;
84        self
85    }
86
87    /// Sets the number of bits per quantized coordinate.
88    pub fn with_bits(mut self, bits: u8) -> Self {
89        self.bits = bits;
90        self
91    }
92
93    /// Sets the number of sparse hash updates performed for each input dimension.
94    pub fn with_hashes_per_input(mut self, hashes_per_input: u8) -> Self {
95        self.hashes_per_input = hashes_per_input;
96        self
97    }
98
99    /// Sets the symmetric clip range used by the scalar quantizer.
100    pub fn with_clip(mut self, clip: f32) -> Self {
101        self.clip = clip;
102        self
103    }
104
105    /// Sets the deterministic seed used by the sparse signed projection.
106    pub fn with_seed(mut self, seed: u64) -> Self {
107        self.seed = seed;
108        self
109    }
110
111    /// Sets the similarity objective.
112    pub fn with_metric(mut self, metric: SimilarityMetric) -> Self {
113        self.metric = metric;
114        self
115    }
116
117    /// Sets the `log2(norm)` range for the optional norm channel.
118    pub fn with_norm_log2_range(mut self, min: f32, max: f32) -> Self {
119        self.norm_log2_min = min;
120        self.norm_log2_max = max;
121        self
122    }
123
124    /// Validates the configuration.
125    pub fn validate(&self) -> Result<()> {
126        if self.input_dim == 0 {
127            return Err(SQuaJLError::InvalidConfig(
128                "input_dim must be greater than zero".to_owned(),
129            ));
130        }
131        if self.sketch_dim == 0 {
132            return Err(SQuaJLError::InvalidConfig(
133                "sketch_dim must be greater than zero".to_owned(),
134            ));
135        }
136        if !(1..=8).contains(&self.bits) {
137            return Err(SQuaJLError::InvalidConfig(
138                "bits must be between 1 and 8".to_owned(),
139            ));
140        }
141        if self.hashes_per_input == 0 {
142            return Err(SQuaJLError::InvalidConfig(
143                "hashes_per_input must be greater than zero".to_owned(),
144            ));
145        }
146        if !self.clip.is_finite() || self.clip <= 0.0 {
147            return Err(SQuaJLError::InvalidConfig(
148                "clip must be finite and greater than zero".to_owned(),
149            ));
150        }
151        if !self.norm_log2_min.is_finite()
152            || !self.norm_log2_max.is_finite()
153            || self.norm_log2_min >= self.norm_log2_max
154        {
155            return Err(SQuaJLError::InvalidConfig(
156                "norm_log2_min must be smaller than norm_log2_max".to_owned(),
157            ));
158        }
159        Ok(())
160    }
161}