pub mod chunker;
pub mod cleaner;
pub mod string_matcher;
#[cfg(feature = "sbert")]
pub mod semantic_matcher;
#[cfg(feature = "sbert-onnx")]
pub mod onnx_embedder;
#[cfg(feature = "classifier")]
pub mod classifier;
#[cfg(any(feature = "llm", feature = "burn-llm"))]
pub mod llm_evaluator;
#[cfg(feature = "burn-llm")]
pub mod burn_evaluator;
#[cfg(feature = "burn-llm")]
pub(crate) mod burn_model;
#[cfg(feature = "phash")]
pub mod phash_matcher;
#[cfg(feature = "sbert")]
#[derive(Clone, Copy)]
/// Which wire format the embedding endpoint speaks; selects how the JSON
/// response is parsed in `HttpEmbedder::embed`.
pub(crate) enum EmbeddingApi {
    /// OpenAI-compatible: embedding read from `data[0].embedding`.
    OpenAi,
    /// Ollama-compatible: embedding read from `embeddings[0]`.
    Ollama,
}
#[cfg(feature = "sbert")]
/// Blocking HTTP client for a remote embedding service, with an in-memory
/// cache so repeated texts are embedded only once per process.
pub(crate) struct HttpEmbedder {
    // Full URL of the embeddings endpoint (e.g. an OpenAI-style `/v1/embeddings`).
    endpoint: String,
    // Model name sent in the request body's "model" field.
    model: String,
    // Determines which response shape to parse (OpenAI vs. Ollama).
    api: EmbeddingApi,
    // Reused blocking client, built once with connect/read timeouts.
    client: reqwest::blocking::Client,
    // text -> embedding cache; Mutex makes `embed(&self)` safe to call
    // from multiple threads sharing one embedder.
    cache: std::sync::Mutex<std::collections::HashMap<String, Vec<f32>>>,
}
#[cfg(feature = "sbert")]
impl HttpEmbedder {
    /// Maximum time to establish a connection to the embedding endpoint.
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
    /// Maximum time for the whole request/response round trip.
    const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    /// Builds an embedder that parses OpenAI-style embedding responses.
    pub fn openai(endpoint: impl Into<String>, model: impl Into<String>) -> Self {
        Self::build(endpoint, model, EmbeddingApi::OpenAi)
    }

    /// Builds an embedder that parses Ollama-style embedding responses.
    pub fn ollama(endpoint: impl Into<String>, model: impl Into<String>) -> Self {
        Self::build(endpoint, model, EmbeddingApi::Ollama)
    }

    /// Shared constructor: configures a blocking HTTP client with the
    /// connect/read timeouts above and an empty embedding cache.
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be constructed; this is
    /// treated as a deployment-level bug rather than a recoverable error.
    fn build(
        endpoint: impl Into<String>,
        model: impl Into<String>,
        api: EmbeddingApi,
    ) -> Self {
        let client = reqwest::blocking::Client::builder()
            .connect_timeout(Self::CONNECT_TIMEOUT)
            .timeout(Self::READ_TIMEOUT)
            .build()
            .expect("failed to build HTTP client");
        Self {
            endpoint: endpoint.into(),
            model: model.into(),
            api,
            client,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Returns the embedding vector for `text`, consulting the in-memory
    /// cache before making a network request.
    ///
    /// Empty input short-circuits to an empty vector with no network traffic
    /// (and is never cached).
    ///
    /// # Errors
    /// Returns a human-readable message when the request fails, the server
    /// responds with a non-success HTTP status, or the response JSON does not
    /// match the expected shape for the configured API.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
        if text.is_empty() {
            return Ok(vec![]);
        }
        // Fast path: serve from cache. Scoped so the lock is released before
        // the (slow) network call below.
        {
            let cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        // Both OpenAI and Ollama accept a {"model": ..., "input": ...} body.
        let body = serde_json::json!({
            "model": self.model,
            "input": text
        });
        let resp = self
            .client
            .post(&self.endpoint)
            .json(&body)
            .send()
            .map_err(|e| e.to_string())?
            // Surface 4xx/5xx as an explicit HTTP error instead of letting an
            // error body fall through to a misleading "unexpected response"
            // parse failure below.
            .error_for_status()
            .map_err(|e| e.to_string())?;
        let json: serde_json::Value = resp
            .json()
            .map_err(|e| e.to_string())?;
        let embedding: Vec<f32> = match self.api {
            // OpenAI shape: {"data": [{"embedding": [f, ...]}]}
            EmbeddingApi::OpenAi => json
                .get("data")
                .and_then(|v| v.get(0))
                .and_then(|v| v.get("embedding"))
                .and_then(|v| v.as_array())
                .ok_or("unexpected response: missing data[0].embedding")?
                .iter()
                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
                .collect(),
            // Ollama shape: {"embeddings": [[f, ...]]}
            EmbeddingApi::Ollama => json
                .get("embeddings")
                .and_then(|v| v.get(0))
                .and_then(|v| v.as_array())
                .ok_or("unexpected response: missing embeddings[0]")?
                .iter()
                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
                .collect(),
        };
        // Populate the cache; lock scope kept minimal.
        {
            let mut cache = self.cache.lock().unwrap();
            cache.insert(text.to_owned(), embedding.clone());
        }
        Ok(embedding)
    }

    /// Test-only: number of entries currently in the embedding cache.
    #[cfg(test)]
    pub fn cache_len(&self) -> usize {
        self.cache.lock().unwrap().len()
    }
}
#[cfg(all(test, feature = "sbert"))]
mod tests {
    use super::HttpEmbedder;

    #[test]
    fn http_embedder_has_timeouts_configured() {
        use std::time::Duration;

        // The timeout constants are pinned so a refactor cannot silently
        // drop or loosen them.
        assert_eq!(HttpEmbedder::CONNECT_TIMEOUT, Duration::from_secs(10));
        assert_eq!(HttpEmbedder::READ_TIMEOUT, Duration::from_secs(30));

        // Empty input short-circuits before any network traffic, so this is
        // safe to run offline.
        let embedder =
            HttpEmbedder::openai("http://localhost:1234/v1/embeddings", "text-embedding-3-small");
        assert!(embedder.embed("").unwrap().is_empty());
    }

    #[test]
    fn http_embedder_caches_results() {
        let embedder = HttpEmbedder::ollama("http://localhost:11434/api/embed", "all-minilm");
        assert_eq!(embedder.cache_len(), 0);

        // The empty-string fast path must not populate the cache.
        let _ = embedder.embed("");
        assert_eq!(embedder.cache_len(), 0, "empty text should not be cached");
    }

    #[test]
    fn shared_embedder_used_by_openai_matcher() {
        use crate::engine::semantic_matcher::{OpenAiEmbeddingMatcher, SemanticMatcher};

        let matcher =
            OpenAiEmbeddingMatcher::new("http://localhost:1234/v1/embeddings", "text-embedding-3-small");
        assert!(matcher.embed("").unwrap().is_empty());
    }

    #[test]
    fn shared_embedder_used_by_ollama_matcher() {
        use crate::engine::semantic_matcher::{OllamaEmbeddingMatcher, SemanticMatcher};

        let matcher = OllamaEmbeddingMatcher::new("http://localhost:11434/api/embed", "all-minilm");
        assert!(matcher.embed("").unwrap().is_empty());
    }

    #[cfg(feature = "classifier")]
    #[test]
    fn shared_embedder_used_by_classifier() {
        use crate::engine::classifier::{OpenAiEmbeddingClassifier, TextClassifier};

        let classifier =
            OpenAiEmbeddingClassifier::new("http://localhost:1234/v1/embeddings", "text-embedding-3-small");
        assert_eq!(classifier.score("", "text").unwrap(), 0.0);
    }
}