use crate::error::{Result, SQuaJLError};
/// Similarity measure stored in a [`SQuaJLConfig`].
///
/// NOTE(review): the scoring code that interprets these variants is not in this
/// file; the per-variant descriptions below follow the variant names.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    /// Cosine similarity. This is the `Default` variant.
    #[default]
    Cosine,
    /// Dot-product similarity.
    Dot,
}
/// Parameters for a SQuaJL sketch.
///
/// Start from [`SQuaJLConfig::new`] or `Default`, adjust fields with the
/// `with_*` builder methods, and check the result with
/// [`SQuaJLConfig::validate`]. All fields are public and nothing here enforces
/// the constraints noted below; only `validate` checks them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SQuaJLConfig {
    /// Input dimensionality (per the field name — confirm against the sketcher).
    /// Must be non-zero. Default: 384.
    pub input_dim: usize,
    /// Sketch dimensionality (per the field name). Must be non-zero. Default: 96.
    pub sketch_dim: usize,
    /// Bit width, presumably per quantized sketch value — confirm. Must lie in
    /// `1..=8`. Default: 4.
    pub bits: u8,
    /// Number of hashes per input (meaning defined by the sketching code).
    /// Must be non-zero. Default: 4.
    pub hashes_per_input: u8,
    /// Clipping threshold. Must be finite and strictly positive. Default: 3.0.
    pub clip: f32,
    /// Seed value, presumably for the sketch's pseudo-random hashing — confirm.
    /// Default: `0x5EED_CAFE_1234_5678`.
    pub seed: u64,
    /// Similarity metric. Default: [`SimilarityMetric::Cosine`].
    pub metric: SimilarityMetric,
    /// Lower bound of a log2-scale norm range. Must be finite and strictly less
    /// than `norm_log2_max`. Default: -16.0.
    pub norm_log2_min: f32,
    /// Upper bound of the log2-scale norm range. Must be finite and strictly
    /// greater than `norm_log2_min`. Default: 16.0.
    pub norm_log2_max: f32,
}
impl Default for SQuaJLConfig {
    /// Builds the baseline configuration.
    ///
    /// The values are `input_dim = 384`, `sketch_dim = 96`, `bits = 4` and
    /// `hashes_per_input = 4`. The remaining fields are `clip = 3.0`, a fixed
    /// `seed`, the cosine metric, and a log2 norm range of `[-16.0, 16.0]`.
    /// Every value satisfies the checks in `SQuaJLConfig::validate`.
    fn default() -> Self {
        let (norm_log2_min, norm_log2_max) = (-16.0_f32, 16.0_f32);
        let seed: u64 = 0x5EED_CAFE_1234_5678;
        let metric = SimilarityMetric::Cosine;

        Self {
            metric,
            seed,
            norm_log2_min,
            norm_log2_max,
            clip: 3.0,
            hashes_per_input: 4,
            bits: 4,
            sketch_dim: 96,
            input_dim: 384,
        }
    }
}
impl SQuaJLConfig {
    // The `with_*` builders take `self` by value and return the updated copy, so
    // ignoring their result silently discards the change; `#[must_use]` turns
    // that mistake into a compiler warning.

    /// Creates a configuration with the given `input_dim`; every other field
    /// takes its value from `SQuaJLConfig::default()`.
    ///
    /// The result is not validated; call [`SQuaJLConfig::validate`] before use.
    #[must_use]
    pub fn new(input_dim: usize) -> Self {
        Self {
            input_dim,
            ..Self::default()
        }
    }

    /// Returns the configuration with `sketch_dim` replaced (unchecked until
    /// [`SQuaJLConfig::validate`]).
    #[must_use]
    pub fn with_sketch_dim(mut self, sketch_dim: usize) -> Self {
        self.sketch_dim = sketch_dim;
        self
    }

    /// Returns the configuration with `bits` replaced. `validate` accepts only
    /// `1..=8`.
    #[must_use]
    pub fn with_bits(mut self, bits: u8) -> Self {
        self.bits = bits;
        self
    }

    /// Returns the configuration with `hashes_per_input` replaced.
    #[must_use]
    pub fn with_hashes_per_input(mut self, hashes_per_input: u8) -> Self {
        self.hashes_per_input = hashes_per_input;
        self
    }

    /// Returns the configuration with `clip` replaced. `validate` requires a
    /// finite, strictly positive value.
    #[must_use]
    pub fn with_clip(mut self, clip: f32) -> Self {
        self.clip = clip;
        self
    }

    /// Returns the configuration with `seed` replaced.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Returns the configuration with `metric` replaced.
    #[must_use]
    pub fn with_metric(mut self, metric: SimilarityMetric) -> Self {
        self.metric = metric;
        self
    }

    /// Returns the configuration with `norm_log2_min = min` and
    /// `norm_log2_max = max`. `validate` requires both to be finite with
    /// `min < max`.
    #[must_use]
    pub fn with_norm_log2_range(mut self, min: f32, max: f32) -> Self {
        self.norm_log2_min = min;
        self.norm_log2_max = max;
        self
    }

    /// Checks that every field lies within its allowed range.
    ///
    /// # Errors
    ///
    /// Returns `SQuaJLError::InvalidConfig` describing the first violated
    /// constraint, checked in this order:
    ///
    /// - `input_dim` is zero;
    /// - `sketch_dim` is zero;
    /// - `bits` is outside `1..=8`;
    /// - `hashes_per_input` is zero;
    /// - `clip` is NaN, infinite, or not strictly positive;
    /// - `norm_log2_min` or `norm_log2_max` is NaN or infinite;
    /// - `norm_log2_min` is not strictly less than `norm_log2_max`.
    pub fn validate(&self) -> Result<()> {
        let invalid = |reason: &str| -> Result<()> {
            Err(SQuaJLError::InvalidConfig(reason.to_owned()))
        };

        if self.input_dim == 0 {
            return invalid("input_dim must be greater than zero");
        }
        if self.sketch_dim == 0 {
            return invalid("sketch_dim must be greater than zero");
        }
        if !(1..=8).contains(&self.bits) {
            return invalid("bits must be between 1 and 8");
        }
        if self.hashes_per_input == 0 {
            return invalid("hashes_per_input must be greater than zero");
        }
        // `is_finite` also rejects NaN, for which `clip <= 0.0` alone is false.
        if !self.clip.is_finite() || self.clip <= 0.0 {
            return invalid("clip must be finite and greater than zero");
        }
        // Report non-finite bounds on their own: the ordering message below
        // would misdescribe inputs such as `(NaN, 1.0)` or `(0.0, inf)`.
        if !self.norm_log2_min.is_finite() || !self.norm_log2_max.is_finite() {
            return invalid("norm_log2_min and norm_log2_max must be finite");
        }
        if self.norm_log2_min >= self.norm_log2_max {
            return invalid("norm_log2_min must be smaller than norm_log2_max");
        }
        Ok(())
    }
}