1use crate::error::{Result, SQuaJLError};
2
/// Similarity measure selected through the `metric` field of `SQuaJLConfig`.
///
/// `Cosine` is the variant produced by `SimilarityMetric::default()`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SimilarityMetric {
    /// Cosine similarity; the default choice.
    #[default]
    Cosine,
    /// Dot-product similarity.
    Dot,
}
19
20#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, PartialEq)]
26pub struct SQuaJLConfig {
27 pub input_dim: usize,
29 pub sketch_dim: usize,
33 pub bits: u8,
37 pub hashes_per_input: u8,
41 pub clip: f32,
45 pub seed: u64,
47 pub metric: SimilarityMetric,
49 pub norm_log2_min: f32,
51 pub norm_log2_max: f32,
53}
54
55impl Default for SQuaJLConfig {
56 fn default() -> Self {
57 Self {
58 input_dim: 384,
59 sketch_dim: 96,
60 bits: 4,
61 hashes_per_input: 4,
62 clip: 3.0,
63 seed: 0x5EED_CAFE_1234_5678,
64 metric: SimilarityMetric::Cosine,
65 norm_log2_min: -16.0,
66 norm_log2_max: 16.0,
67 }
68 }
69}
70
71impl SQuaJLConfig {
72 pub fn new(input_dim: usize) -> Self {
75 Self {
76 input_dim,
77 ..Self::default()
78 }
79 }
80
81 pub fn with_sketch_dim(mut self, sketch_dim: usize) -> Self {
83 self.sketch_dim = sketch_dim;
84 self
85 }
86
87 pub fn with_bits(mut self, bits: u8) -> Self {
89 self.bits = bits;
90 self
91 }
92
93 pub fn with_hashes_per_input(mut self, hashes_per_input: u8) -> Self {
95 self.hashes_per_input = hashes_per_input;
96 self
97 }
98
99 pub fn with_clip(mut self, clip: f32) -> Self {
101 self.clip = clip;
102 self
103 }
104
105 pub fn with_seed(mut self, seed: u64) -> Self {
107 self.seed = seed;
108 self
109 }
110
111 pub fn with_metric(mut self, metric: SimilarityMetric) -> Self {
113 self.metric = metric;
114 self
115 }
116
117 pub fn with_norm_log2_range(mut self, min: f32, max: f32) -> Self {
119 self.norm_log2_min = min;
120 self.norm_log2_max = max;
121 self
122 }
123
124 pub fn validate(&self) -> Result<()> {
126 if self.input_dim == 0 {
127 return Err(SQuaJLError::InvalidConfig(
128 "input_dim must be greater than zero".to_owned(),
129 ));
130 }
131 if self.sketch_dim == 0 {
132 return Err(SQuaJLError::InvalidConfig(
133 "sketch_dim must be greater than zero".to_owned(),
134 ));
135 }
136 if !(1..=8).contains(&self.bits) {
137 return Err(SQuaJLError::InvalidConfig(
138 "bits must be between 1 and 8".to_owned(),
139 ));
140 }
141 if self.hashes_per_input == 0 {
142 return Err(SQuaJLError::InvalidConfig(
143 "hashes_per_input must be greater than zero".to_owned(),
144 ));
145 }
146 if !self.clip.is_finite() || self.clip <= 0.0 {
147 return Err(SQuaJLError::InvalidConfig(
148 "clip must be finite and greater than zero".to_owned(),
149 ));
150 }
151 if !self.norm_log2_min.is_finite()
152 || !self.norm_log2_max.is_finite()
153 || self.norm_log2_min >= self.norm_log2_max
154 {
155 return Err(SQuaJLError::InvalidConfig(
156 "norm_log2_min must be smaller than norm_log2_max".to_owned(),
157 ));
158 }
159 Ok(())
160 }
161}