use crate::{Error, Result};
/// Turns a piece of text into an embedding vector.
///
/// The `Send + Sync` bound means a single instance can be shared between
/// threads (e.g. behind an `Arc`).
pub trait Embedder: Send + Sync {
/// Embeds `text`.
///
/// Returns `Ok(Some(vector))` on success. Returns `Ok(None)` when the
/// implementation produces no vector (as [`NoopEmbedder`] always does).
/// Returns `Err` when embedding fails.
fn embed(&self, text: &str) -> Result<Option<Vec<f32>>>;
}
/// An [`Embedder`] that never produces a vector.
///
/// Handy as a stand-in when embeddings are disabled: every call succeeds and
/// yields `None`, whatever the input.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopEmbedder;

impl Embedder for NoopEmbedder {
    fn embed(&self, _: &str) -> Result<Option<Vec<f32>>> {
        let nothing: Option<Vec<f32>> = None;
        Ok(nothing)
    }
}
/// An [`Embedder`] that fetches vectors from a remote HTTP service.
///
/// Each call POSTs `{"input": <text>}` as JSON to `endpoint`. It expects a
/// successful reply whose JSON body has an `embedding` array of floats.
///
/// The underlying HTTP client is built on first use and then reused, so
/// repeated calls share one connection pool.
#[derive(Clone)]
pub struct HttpEmbedder {
    endpoint: String,
    bearer_token: Option<String>,
    timeout_ms: u64,
    // Lazily-built client. Building a reqwest client per request would discard
    // its connection pool (and, for the blocking client, its background runtime)
    // on every call.
    client: std::sync::OnceLock<reqwest::blocking::Client>,
}

impl HttpEmbedder {
    /// Creates an embedder for `endpoint` with no authentication and a
    /// 30 000 ms request timeout.
    pub fn new(endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            bearer_token: None,
            timeout_ms: 30_000,
            client: std::sync::OnceLock::new(),
        }
    }

    /// Sends `token` as an `Authorization: Bearer` header on every request.
    pub fn with_bearer(mut self, token: impl Into<String>) -> Self {
        self.bearer_token = Some(token.into());
        self
    }

    /// Sets the request timeout in milliseconds.
    pub fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        // The timeout is baked into the client at build time, so discard any
        // client that was built with the previous value.
        self.client = std::sync::OnceLock::new();
        self
    }

    /// Returns the shared HTTP client, building it on first use.
    ///
    /// # Errors
    /// Returns a query-execution error if the client cannot be built.
    fn http_client(&self) -> Result<&reqwest::blocking::Client> {
        if let Some(client) = self.client.get() {
            return Ok(client);
        }
        let built = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_millis(self.timeout_ms))
            .build()
            .map_err(|e| Error::query_execution(format!("embedder client: {e}")))?;
        // If another thread filled the cell first, its client is kept and ours
        // is dropped; both were built from identical settings.
        Ok(self.client.get_or_init(|| built))
    }
}

impl std::fmt::Debug for HttpEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the bearer token: Debug output routinely ends up in logs.
        f.debug_struct("HttpEmbedder")
            .field("endpoint", &self.endpoint)
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("timeout_ms", &self.timeout_ms)
            .finish_non_exhaustive()
    }
}

impl Embedder for HttpEmbedder {
    /// POSTs `text` to the configured endpoint and returns the vector from the
    /// reply. This implementation always yields `Some` on success.
    ///
    /// # Errors
    /// Returns a query-execution error when:
    /// - the client cannot be built;
    /// - the request fails (including timeout);
    /// - the service answers with a non-2xx status;
    /// - the body cannot be parsed as `{"embedding": [...]}`.
    fn embed(&self, text: &str) -> Result<Option<Vec<f32>>> {
        let body = serde_json::json!({ "input": text });
        let mut req = self.http_client()?.post(&self.endpoint).json(&body);
        if let Some(tok) = &self.bearer_token {
            req = req.bearer_auth(tok);
        }
        let resp = req
            .send()
            .map_err(|e| Error::query_execution(format!("embedder request: {e}")))?;
        if !resp.status().is_success() {
            return Err(Error::query_execution(format!(
                "embedder returned HTTP {}",
                resp.status()
            )));
        }
        let parsed: EmbeddingResponse = resp
            .json()
            .map_err(|e| Error::query_execution(format!("embedder response: {e}")))?;
        Ok(Some(parsed.embedding))
    }
}
/// JSON body expected from the embedding service on success, read by
/// `HttpEmbedder::embed`: `{"embedding": [f32, ...]}`.
///
/// No `deny_unknown_fields` attribute is set, so serde ignores any extra fields
/// in the reply.
#[derive(serde::Deserialize)]
struct EmbeddingResponse {
/// The embedding vector for the submitted `input`.
embedding: Vec<f32>,
}
#[cfg(feature = "code-embed")]
pub use fastembed_impl::FastEmbedder;
#[cfg(feature = "code-embed")]
mod fastembed_impl {
    //! In-process embeddings backed by the `fastembed` crate.

    use super::{Embedder, Error, Result};
    use std::sync::Mutex;

    /// An [`Embedder`] that runs a `fastembed` text-embedding model locally.
    ///
    /// The model sits behind a [`Mutex`] that is held for the whole of each
    /// `embed` call, so concurrent calls on one instance run one at a time.
    pub struct FastEmbedder {
        inner: Mutex<fastembed::TextEmbedding>,
    }

    impl FastEmbedder {
        /// Loads the `BGESmallENV15` model.
        ///
        /// # Errors
        /// Same as [`FastEmbedder::with_model`].
        pub fn try_default() -> Result<Self> {
            let default_model = fastembed::EmbeddingModel::BGESmallENV15;
            Self::with_model(default_model)
        }

        /// Loads `model` using options built by `fastembed::InitOptions::new(model)`.
        ///
        /// # Errors
        /// Returns a query-execution error if `fastembed` fails to initialise the model.
        pub fn with_model(model: fastembed::EmbeddingModel) -> Result<Self> {
            match fastembed::TextEmbedding::try_new(fastembed::InitOptions::new(model)) {
                Ok(embedding) => Ok(Self {
                    inner: Mutex::new(embedding),
                }),
                Err(e) => Err(Error::query_execution(format!("fastembed init: {e}"))),
            }
        }
    }

    impl std::fmt::Debug for FastEmbedder {
        // The wrapped model is not printed; only the type name is shown.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("FastEmbedder").finish()
        }
    }

    impl Embedder for FastEmbedder {
        /// Embeds `text` as a one-element batch.
        ///
        /// # Errors
        /// Fails if the mutex is poisoned or if `fastembed` reports an error.
        fn embed(&self, text: &str) -> Result<Option<Vec<f32>>> {
            let model = match self.inner.lock() {
                Ok(guard) => guard,
                Err(poisoned) => {
                    return Err(Error::query_execution(format!(
                        "fastembed lock: {poisoned}"
                    )));
                }
            };
            let batch = vec![text.to_string()];
            let mut vectors = model
                .embed(batch, None)
                .map_err(|e| Error::query_execution(format!("fastembed embed: {e}")))?;
            // One input in, so at most one vector out; `None` if nothing came back.
            Ok(vectors.pop())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The no-op embedder succeeds but never yields a vector.
    #[test]
    fn noop_returns_none() {
        let embedding = NoopEmbedder
            .embed("anything")
            .expect("NoopEmbedder never fails");
        assert!(embedding.is_none());
    }
}