use crate::model::{Match, Query};
/// Default number of ranked matches to return.
///
/// NOTE(review): not referenced in this file; presumably the default `limit` passed to
/// [`rank`] by callers — confirm at call sites.
pub const DEFAULT_LIMIT: usize = 50;
/// Cosine similarity between `a` and `b`.
///
/// For finite inputs the result lies in `[-1, 1]` (up to rounding). If either vector has
/// zero Euclidean norm the similarity is defined as `0.0` rather than dividing by zero.
///
/// When the lengths differ, the dot product only covers the common prefix while each norm
/// spans its whole vector. This is equivalent to zero-padding the shorter vector.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
/// Score `matches` against `query_emb`, keep those with non-negative similarity, and
/// return at most `limit` of them in descending similarity order.
///
/// `match_embs[i]` is the embedding for `matches[i]`. Pairing is positional via `zip`, so
/// a match without a corresponding embedding keeps whatever `similarity` it already had.
///
/// Negative and NaN scores are discarded *before* truncating. Otherwise such a score
/// could occupy one of the top `limit` slots and then be removed, returning fewer results
/// than are actually available. NaN compares as `Equal` in the sort and can land anywhere.
fn score_sort_limit(
    query_emb: &[f32],
    mut matches: Vec<Match>,
    match_embs: &[Vec<f32>],
    limit: usize,
) -> Vec<Match> {
    for (m, emb) in matches.iter_mut().zip(match_embs) {
        m.similarity = cosine(query_emb, emb);
    }
    // `NaN >= 0.0` is false, so this drops NaN scores too. After it, `partial_cmp`
    // below is total and the fallback is never taken.
    matches.retain(|m| m.similarity >= 0.0);
    matches.sort_by(|a, b| {
        b.similarity
            .partial_cmp(&a.similarity)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    matches.truncate(limit);
    matches
}
/// Holds a loaded fastembed [`fastembed::TextEmbedding`] model used to embed a query
/// and candidate [`Match`]es, which are then ranked by cosine similarity.
pub struct Ranker {
model: fastembed::TextEmbedding,
}
/// Location of the embedding-model cache: `<user cache dir>/patent/fastembed`.
///
/// Returns `None` when `dirs::cache_dir` cannot determine a user cache directory.
fn model_cache_dir() -> Option<std::path::PathBuf> {
    let base = dirs::cache_dir()?;
    Some(base.join("patent").join("fastembed"))
}
/// Heuristically report whether the embedding model is already present on disk.
///
/// Returns `true` when the model cache directory exists and contains at least one entry.
/// Returns `false` when it is missing, unreadable, or empty.
///
/// If no cache directory can be determined at all, this returns `true`.
/// NOTE(review): in that case `Ranker::new` leaves fastembed's default cache location in
/// place, which this function does not inspect. `true` therefore really means "unknown"
/// here; confirm callers treat it that way.
pub fn model_is_cached() -> bool {
    let dir = match model_cache_dir() {
        Some(dir) => dir,
        None => return true,
    };
    match dir.read_dir() {
        // Any entry at all counts; the model files themselves are not validated.
        Ok(mut entries) => entries.next().is_some(),
        Err(_) => false,
    }
}
impl Ranker {
pub fn new() -> crate::Result<Self> {
let mut opts = fastembed::InitOptions::new(fastembed::EmbeddingModel::AllMiniLML6V2)
.with_show_download_progress(true);
if let Some(dir) = model_cache_dir() {
let _ = std::fs::create_dir_all(&dir);
opts = opts.with_cache_dir(dir);
}
let model = fastembed::TextEmbedding::try_new(opts)
.map_err(|e| crate::Error::Embedding(e.to_string()))?;
Ok(Self { model })
}
pub fn embed_query(&mut self, idea: &str) -> crate::Result<Vec<f32>> {
let embs = self
.model
.embed(vec![idea], None)
.map_err(|e| crate::Error::Embedding(e.to_string()))?;
Ok(embs.into_iter().next().unwrap_or_default())
}
pub fn rank_with(
&mut self,
query_emb: &[f32],
matches: Vec<Match>,
limit: usize,
) -> crate::Result<Vec<Match>> {
if matches.is_empty() {
return Ok(vec![]);
}
let texts: Vec<String> = matches
.iter()
.map(|m| {
if m.description.is_empty() {
m.name.clone()
} else {
format!("{}: {}", m.name, m.description)
}
})
.collect();
let descriptions: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
let match_embs = self
.model
.embed(descriptions, None)
.map_err(|e| crate::Error::Embedding(e.to_string()))?;
Ok(score_sort_limit(query_emb, matches, &match_embs, limit))
}
}
/// Rank `matches` by semantic similarity to `query.idea` using a freshly loaded [`Ranker`].
///
/// Returns at most `limit` matches, highest similarity first. Matches with negative
/// similarity are omitted. When there is nothing to return (no matches, or
/// `limit == 0`), it returns immediately without loading the model.
///
/// Each call constructs a new `Ranker`, re-initialising the model. To rank several
/// batches, build one `Ranker` and call [`Ranker::embed_query`] and [`Ranker::rank_with`].
///
/// # Errors
/// Propagates errors from model initialisation, query embedding, and match ranking.
pub fn rank(query: &Query, matches: Vec<Match>, limit: usize) -> crate::Result<Vec<Match>> {
    if matches.is_empty() || limit == 0 {
        return Ok(Vec::new());
    }
    let mut ranker = Ranker::new()?;
    let query_emb = ranker.embed_query(&query.idea)?;
    ranker.rank_with(&query_emb, matches, limit)
}
#[cfg(test)]
mod tests {
// `cosine` and `score_sort_limit` tests are pure: they use hand-written vectors.
// The `rank_*` tests (except the empty-input one) call `rank`, which constructs a real
// `Ranker` (fastembed AllMiniLM-L6-v2) and so need the model files.
// NOTE(review): on a machine without a cached model these presumably need network
// access; consider `#[ignore]` for offline CI.
use super::*;
use crate::model::Source;
#[test]
fn cosine_identical_is_one() {
let v = [1.0, 2.0, 3.0];
assert!((cosine(&v, &v) - 1.0).abs() < 1e-6);
}
#[test]
fn cosine_orthogonal_is_zero() {
assert!((cosine(&[1.0, 0.0], &[0.0, 1.0])).abs() < 1e-6);
}
#[test]
fn cosine_zero_vector_is_zero() {
assert_eq!(cosine(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
}
// Builds a crates.io `Match` with a placeholder URL, no popularity and zero similarity.
fn test_match(name: &str, desc: &str) -> Match {
Match {
name: name.to_string(),
source: Source::CratesIo,
url: format!("https://example.com/{name}"),
description: desc.to_string(),
popularity: None,
similarity: 0.0,
}
}
#[test]
fn ssl_empty_input() {
let result = score_sort_limit(&[1.0, 0.0], vec![], &[], 10);
assert!(result.is_empty());
}
#[test]
fn ssl_fills_similarity() {
let q = vec![1.0, 0.0, 0.0];
let matches = vec![test_match("a", "something")];
let embs = vec![vec![0.8, 0.1, 0.0]];
let result = score_sort_limit(&q, matches, &embs, 10);
assert!(result[0].similarity > 0.0);
}
#[test]
fn ssl_sorts_descending() {
let q = vec![1.0, 0.0];
let matches = vec![test_match("low", ""), test_match("high", "")];
// First embedding ("low") points mostly away from the query axis; the second
// ("high") points mostly along it, so input order is the reverse of the expected order.
let embs = vec![
vec![0.1, 0.9], vec![0.9, 0.1], ];
let result = score_sort_limit(&q, matches, &embs, 10);
assert_eq!(result[0].name, "high");
assert_eq!(result[1].name, "low");
assert!(result[0].similarity > result[1].similarity);
}
#[test]
fn ssl_truncates_to_limit() {
let q = vec![1.0, 0.0];
let matches = vec![
test_match("a", ""),
test_match("b", ""),
test_match("c", ""),
];
// All three similarities are non-negative, so only the limit reduces the count.
let embs = vec![vec![1.0, 0.0], vec![0.5, 0.5], vec![0.0, 1.0]];
let result = score_sort_limit(&q, matches, &embs, 2);
assert_eq!(result.len(), 2);
}
#[test]
fn ssl_fewer_than_limit_returns_all() {
let q = vec![1.0, 0.0];
let matches = vec![test_match("only", "")];
let embs = vec![vec![1.0, 0.0]];
let result = score_sort_limit(&q, matches, &embs, 10);
assert_eq!(result.len(), 1);
}
#[test]
fn ssl_preserves_match_fields() {
let q = vec![1.0, 0.0];
let mut m = test_match("foo", "bar");
m.popularity = Some(42);
let embs = vec![vec![0.9, 0.1]];
let result = score_sort_limit(&q, vec![m], &embs, 10);
assert_eq!(result[0].name, "foo");
assert_eq!(result[0].description, "bar");
assert_eq!(result[0].popularity, Some(42));
}
// Query used by the model-backed tests; `rank` only reads `idea`, not `keywords`.
fn test_query() -> Query {
Query {
idea: "a fast async runtime for Rust".to_string(),
keywords: vec!["async".to_string(), "runtime".to_string()],
}
}
// `rank` returns before constructing a `Ranker` when `matches` is empty,
// so this test does not need the model.
#[test]
fn rank_empty_matches_returns_empty() {
let result = rank(&test_query(), vec![], 10).unwrap();
assert!(result.is_empty());
}
#[test]
fn rank_fills_positive_similarity_for_related_content() {
let matches = vec![test_match(
"tokio",
"An event-driven async runtime for Rust",
)];
let result = rank(&test_query(), matches, 10).unwrap();
assert!(
result[0].similarity > 0.0,
"related content must have positive similarity"
);
}
#[test]
fn rank_orders_relevant_above_irrelevant() {
let matches = vec![
test_match("recipes", "A collection of baking recipes and kitchen tips"),
test_match(
"tokio",
"An event-driven non-blocking I/O platform for async Rust",
),
];
let result = rank(&test_query(), matches, 10).unwrap();
assert_eq!(result[0].name, "tokio");
}
#[test]
fn rank_respects_limit() {
let matches = vec![
test_match("a", "async runtime alpha"),
test_match("b", "async runtime beta"),
test_match("c", "async runtime gamma"),
test_match("d", "async runtime delta"),
];
let result = rank(&test_query(), matches, 2).unwrap();
assert_eq!(result.len(), 2);
}
#[test]
fn rank_returns_sorted_descending() {
let matches = vec![
test_match("recipes", "Baking sourdough bread at home"),
test_match("smol", "A small async runtime"),
test_match("tokio", "An async runtime for Rust applications"),
];
let result = rank(&test_query(), matches, 10).unwrap();
// Checks every adjacent pair, so it holds for however many matches survive filtering.
for pair in result.windows(2) {
assert!(
pair[0].similarity >= pair[1].similarity,
"{} ({}) should be >= {} ({})",
pair[0].name,
pair[0].similarity,
pair[1].name,
pair[1].similarity,
);
}
}
}