1use std::fmt;
18
19use serde::{Deserialize, Serialize};
20
21use crate::types::Embedding;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
25#[non_exhaustive]
26pub enum DistanceMetric {
27 L2,
33
34 InnerProduct,
43
44 Cosine,
50
51 L1,
57}
58
59impl DistanceMetric {
60 pub fn operator(&self) -> &'static str {
62 match self {
63 Self::L2 => "<->",
64 Self::InnerProduct => "<#>",
65 Self::Cosine => "<=>",
66 Self::L1 => "<+>",
67 }
68 }
69
70 pub fn ops_class(&self) -> &'static str {
72 match self {
73 Self::L2 => "vector_l2_ops",
74 Self::InnerProduct => "vector_ip_ops",
75 Self::Cosine => "vector_cosine_ops",
76 Self::L1 => "vector_l1_ops",
77 }
78 }
79
80 pub fn name(&self) -> &'static str {
82 match self {
83 Self::L2 => "euclidean",
84 Self::InnerProduct => "inner_product",
85 Self::Cosine => "cosine",
86 Self::L1 => "manhattan",
87 }
88 }
89
90 pub fn prefers_normalized(&self) -> bool {
92 matches!(self, Self::InnerProduct | Self::Cosine)
93 }
94}
95
96impl fmt::Display for DistanceMetric {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 write!(f, "{}", self.name())
99 }
100}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
104#[non_exhaustive]
105pub enum BinaryDistanceMetric {
106 Hamming,
110
111 Jaccard,
115}
116
117impl BinaryDistanceMetric {
118 pub fn operator(&self) -> &'static str {
120 match self {
121 Self::Hamming => "<~>",
122 Self::Jaccard => "<%>",
123 }
124 }
125
126 pub fn ops_class(&self) -> &'static str {
128 match self {
129 Self::Hamming => "bit_hamming_ops",
130 Self::Jaccard => "bit_jaccard_ops",
131 }
132 }
133
134 pub fn name(&self) -> &'static str {
136 match self {
137 Self::Hamming => "hamming",
138 Self::Jaccard => "jaccard",
139 }
140 }
141}
142
143impl fmt::Display for BinaryDistanceMetric {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 write!(f, "{}", self.name())
146 }
147}
148
149pub fn distance_sql(column: &str, query_vector: &Embedding, metric: DistanceMetric) -> String {
161 format!(
162 "{} {} {}",
163 column,
164 metric.operator(),
165 query_vector.to_sql_literal()
166 )
167}
168
169pub fn distance_param_sql(column: &str, param: &str, metric: DistanceMetric) -> String {
183 format!("{} {} {}", column, metric.operator(), param)
184}
185
186pub fn order_by_distance(column: &str, query_vector: &Embedding, metric: DistanceMetric) -> String {
191 distance_sql(column, query_vector, metric)
192}
193
194pub fn nearest_neighbor_sql(
204 table: &str,
205 column: &str,
206 metric: DistanceMetric,
207 param_index: usize,
208 limit: usize,
209 extra_columns: &[&str],
210) -> String {
211 let distance_expr = distance_param_sql(column, &format!("${param_index}"), metric);
212
213 let select_cols = if extra_columns.is_empty() {
214 "*".to_string()
215 } else {
216 let mut cols = vec!["*".to_string()];
217 cols.extend(extra_columns.iter().map(|c| (*c).to_string()));
218 cols.join(", ")
219 };
220
221 format!(
222 "SELECT {}, {} AS distance FROM {} ORDER BY distance LIMIT {}",
223 select_cols, distance_expr, table, limit
224 )
225}
226
227pub fn radius_search_sql(
238 table: &str,
239 column: &str,
240 metric: DistanceMetric,
241 param_index: usize,
242 max_distance: f64,
243 limit: Option<usize>,
244) -> String {
245 let param = format!("${param_index}");
246 let distance_expr = distance_param_sql(column, ¶m, metric);
247
248 let limit_clause = limit.map(|l| format!(" LIMIT {l}")).unwrap_or_default();
249
250 format!(
251 "SELECT *, {} AS distance FROM {} WHERE {} < {} ORDER BY distance{}",
252 distance_expr, table, distance_expr, max_distance, limit_clause
253 )
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct SearchParams {
261 pub ivfflat_probes: Option<usize>,
265
266 pub hnsw_ef_search: Option<usize>,
270}
271
272impl SearchParams {
273 pub fn new() -> Self {
275 Self {
276 ivfflat_probes: None,
277 hnsw_ef_search: None,
278 }
279 }
280
281 pub fn probes(mut self, probes: usize) -> Self {
283 self.ivfflat_probes = Some(probes);
284 self
285 }
286
287 pub fn ef_search(mut self, ef: usize) -> Self {
289 self.hnsw_ef_search = Some(ef);
290 self
291 }
292
293 pub fn to_set_sql(&self) -> Vec<String> {
295 let mut statements = Vec::new();
296
297 if let Some(probes) = self.ivfflat_probes {
298 statements.push(format!("SET ivfflat.probes = {probes}"));
299 }
300 if let Some(ef) = self.hnsw_ef_search {
301 statements.push(format!("SET hnsw.ef_search = {ef}"));
302 }
303
304 statements
305 }
306
307 pub fn has_params(&self) -> bool {
309 self.ivfflat_probes.is_some() || self.hnsw_ef_search.is_some()
310 }
311}
312
313impl Default for SearchParams {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sql` contains every fragment in `fragments`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(fragment), "missing {fragment:?} in {sql:?}");
        }
    }

    #[test]
    fn test_distance_metric_operator() {
        let expected = [
            (DistanceMetric::L2, "<->"),
            (DistanceMetric::InnerProduct, "<#>"),
            (DistanceMetric::Cosine, "<=>"),
            (DistanceMetric::L1, "<+>"),
        ];
        for (metric, operator) in expected {
            assert_eq!(metric.operator(), operator);
        }
    }

    #[test]
    fn test_distance_metric_ops_class() {
        let expected = [
            (DistanceMetric::L2, "vector_l2_ops"),
            (DistanceMetric::InnerProduct, "vector_ip_ops"),
            (DistanceMetric::Cosine, "vector_cosine_ops"),
            (DistanceMetric::L1, "vector_l1_ops"),
        ];
        for (metric, ops_class) in expected {
            assert_eq!(metric.ops_class(), ops_class);
        }
    }

    #[test]
    fn test_distance_metric_prefers_normalized() {
        let expected = [
            (DistanceMetric::L2, false),
            (DistanceMetric::InnerProduct, true),
            (DistanceMetric::Cosine, true),
            (DistanceMetric::L1, false),
        ];
        for (metric, prefers) in expected {
            assert_eq!(metric.prefers_normalized(), prefers);
        }
    }

    #[test]
    fn test_binary_distance_metric_operator() {
        let expected = [
            (BinaryDistanceMetric::Hamming, "<~>"),
            (BinaryDistanceMetric::Jaccard, "<%>"),
        ];
        for (metric, operator) in expected {
            assert_eq!(metric.operator(), operator);
        }
    }

    #[test]
    fn test_binary_distance_metric_ops_class() {
        let expected = [
            (BinaryDistanceMetric::Hamming, "bit_hamming_ops"),
            (BinaryDistanceMetric::Jaccard, "bit_jaccard_ops"),
        ];
        for (metric, ops_class) in expected {
            assert_eq!(metric.ops_class(), ops_class);
        }
    }

    #[test]
    fn test_distance_sql() {
        let query = Embedding::new(vec![0.1, 0.2, 0.3]);
        let sql = distance_sql("embedding", &query, DistanceMetric::Cosine);
        assert_contains_all(&sql, &["<=>", "::vector"]);
    }

    #[test]
    fn test_distance_param_sql() {
        let sql = distance_param_sql("embedding", "$1", DistanceMetric::L2);
        assert_eq!(sql, "embedding <-> $1");
    }

    #[test]
    fn test_nearest_neighbor_sql() {
        let sql =
            nearest_neighbor_sql("documents", "embedding", DistanceMetric::Cosine, 1, 10, &[]);
        assert_contains_all(
            &sql,
            &[
                "SELECT *",
                "<=>",
                "$1",
                "LIMIT 10",
                "AS distance",
                "ORDER BY distance",
            ],
        );
    }

    #[test]
    fn test_radius_search_sql() {
        let sql = radius_search_sql(
            "documents",
            "embedding",
            DistanceMetric::L2,
            1,
            0.5,
            Some(100),
        );
        assert_contains_all(&sql, &["<->", "< 0.5", "LIMIT 100"]);
    }

    #[test]
    fn test_radius_search_sql_no_limit() {
        let sql = radius_search_sql("documents", "embedding", DistanceMetric::L2, 1, 1.0, None);
        assert!(!sql.contains("LIMIT"));
    }

    #[test]
    fn test_search_params_probes() {
        let statements = SearchParams::new().probes(10).to_set_sql();
        // Comparing against a one-element array checks both length and content.
        assert_eq!(statements, ["SET ivfflat.probes = 10"]);
    }

    #[test]
    fn test_search_params_ef_search() {
        let statements = SearchParams::new().ef_search(200).to_set_sql();
        assert_eq!(statements, ["SET hnsw.ef_search = 200"]);
    }

    #[test]
    fn test_search_params_both() {
        let params = SearchParams::new().probes(10).ef_search(200);
        assert_eq!(params.to_set_sql().len(), 2);
        assert!(params.has_params());
    }

    #[test]
    fn test_search_params_empty() {
        let params = SearchParams::new();
        assert!(!params.has_params());
        assert!(params.to_set_sql().is_empty());
    }

    #[test]
    fn test_distance_metric_display() {
        assert_eq!(DistanceMetric::L2.to_string(), "euclidean");
        assert_eq!(DistanceMetric::Cosine.to_string(), "cosine");
    }
}