omni_search 0.2.0

A unified Rust SDK for multimodal embedding and similarity search.
Documentation
use crate::embedding::Embedding;
use crate::error::Error;

/// An item paired with the similarity score computed for it.
///
/// [`top_k`] produces these, one per retained candidate.
#[derive(Clone, Debug, PartialEq)]
pub struct Scored<T> {
    /// The caller-supplied value that was scored.
    pub item: T,
    /// The similarity score assigned to `item`.
    pub score: f32,
}

/// Computes the cosine similarity between two equal-length vectors.
///
/// The dot product and both squared norms are accumulated in `f64` in a single
/// pass, so vectors whose components are very large or very small (but finite)
/// are scored correctly instead of overflowing to infinity or underflowing to
/// zero in `f32`. The returned value is always within `[-1.0, 1.0]`.
///
/// # Errors
///
/// * [`Error::DimensionMismatch`] when `lhs` and `rhs` differ in length.
/// * An invalid-configuration error when the vectors are empty, when either
///   vector contains a NaN or infinite component, or when either vector has
///   (effectively) zero magnitude.
pub fn cosine_similarity(lhs: &[f32], rhs: &[f32]) -> Result<f32, Error> {
    if lhs.len() != rhs.len() {
        return Err(Error::DimensionMismatch {
            expected: lhs.len(),
            actual: rhs.len(),
        });
    }
    if lhs.is_empty() {
        return Err(Error::invalid_config("cannot score empty vectors"));
    }

    // The square of any finite f32 fits comfortably in f64, so widening before
    // multiplying removes the overflow/underflow that f32 accumulation suffers.
    let (dot, lhs_squared, rhs_squared) = lhs.iter().zip(rhs).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, lhs_squared, rhs_squared), (&left, &right)| {
            let left = f64::from(left);
            let right = f64::from(right);
            (
                dot + left * right,
                lhs_squared + left * left,
                rhs_squared + right * right,
            )
        },
    );
    let lhs_norm = lhs_squared.sqrt();
    let rhs_norm = rhs_squared.sqrt();

    // With f64 accumulation a non-finite norm can only come from a NaN or
    // infinite input component; refuse it rather than return a NaN score.
    if !lhs_norm.is_finite() || !rhs_norm.is_finite() {
        return Err(Error::invalid_config(
            "cannot score vectors with non-finite values",
        ));
    }

    let min_norm = f64::from(f32::MIN_POSITIVE);
    if lhs_norm <= min_norm || rhs_norm <= min_norm {
        return Err(Error::invalid_config(
            "cannot score vectors with zero magnitude",
        ));
    }

    // Rounding can push the ratio a hair outside the mathematical range.
    let cosine = (dot / (lhs_norm * rhs_norm)).clamp(-1.0, 1.0);
    // Narrowing is lossless in range: the value is already within [-1, 1].
    Ok(cosine as f32)
}

pub fn score_embeddings(lhs: &Embedding, rhs: &Embedding) -> Result<f32, Error> {
    if lhs.model_family != rhs.model_family || lhs.model_id != rhs.model_id {
        return Err(Error::EmbeddingModelMismatch {
            left_model_id: lhs.model_id.clone(),
            right_model_id: rhs.model_id.clone(),
        });
    }
    cosine_similarity(lhs.as_slice(), rhs.as_slice())
}

/// Ranks `candidates` by cosine similarity to `query` and keeps the best `k`.
///
/// Each candidate is an `(item, vector)` pair; the returned [`Scored`] entries
/// are ordered from highest to lowest score using `f32::total_cmp`. The sort
/// is stable, so candidates with equal scores keep their input order. When
/// `k` is zero the function returns an empty list without scoring anything.
///
/// # Errors
///
/// Returns the first error produced by [`cosine_similarity`] for any
/// candidate.
pub fn top_k<T, V>(
    query: &[f32],
    candidates: impl IntoIterator<Item = (T, V)>,
    k: usize,
) -> Result<Vec<Scored<T>>, Error>
where
    V: AsRef<[f32]>,
{
    let mut ranked = Vec::new();
    if k == 0 {
        return Ok(ranked);
    }

    for (item, values) in candidates {
        let score = cosine_similarity(query, values.as_ref())?;
        ranked.push(Scored { item, score });
    }

    // Descending order: reverse the ascending total order on scores.
    ranked.sort_by(|left, right| left.score.total_cmp(&right.score).reverse());
    ranked.truncate(k);
    Ok(ranked)
}

#[cfg(test)]
mod tests {
    use crate::bundle::ModelInfo;
    use crate::config::{ModelFamily, ModelSourceKind};
    use crate::embedding::Embedding;
    use crate::score::{cosine_similarity, score_embeddings, top_k};

    /// A unit vector scored against itself has a similarity of one.
    #[test]
    fn scores_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0];
        let similarity = cosine_similarity(&unit_x, &unit_x).unwrap();
        assert_eq!(similarity, 1.0);
    }

    /// Embeddings whose model ids differ must not be scored together.
    #[test]
    fn checks_embedding_model_identity() {
        let base_info = ModelInfo {
            model_family: ModelFamily::FgClip,
            source_kind: ModelSourceKind::LocalBundleDir,
            model_id: "fgclip".to_owned(),
            model_revision: "1".to_owned(),
            embedding_dim: 2,
            context_length: 64,
            normalize_output: true,
        };
        let renamed_info = ModelInfo {
            model_id: "other".to_owned(),
            ..base_info.clone()
        };

        let left = Embedding::from_vec(&base_info, vec![1.0, 0.0]).unwrap();
        let right = Embedding::from_vec(&renamed_info, vec![1.0, 0.0]).unwrap();

        let outcome = score_embeddings(&left, &right);
        assert!(outcome.is_err());
    }

    /// With `k = 1`, only the candidate closest to the query is returned.
    #[test]
    fn sorts_top_k_descending() {
        let query = [1.0_f32, 0.0];
        let candidates = [("a", [1.0_f32, 0.0]), ("b", [0.0, 1.0])];

        let ranked = top_k(&query, candidates, 1).unwrap();

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].item, "a");
    }
}