use reqwest::RequestBuilder;
use super::EmbeddingFunction;
/// Error type returned by [`OllamaEmbeddingFunction`] operations.
#[derive(Debug, thiserror::Error)]
pub enum OllamaEmbeddingError {
    /// A failure reported by `reqwest`.
    ///
    /// In this file that covers sending the request, the status check done by
    /// `error_for_status`, and decoding the response body as JSON.
    #[error("request failed: {0}")]
    Reqwest(#[from] reqwest::Error),
}
/// Embedding function that obtains embeddings from an Ollama server over HTTP.
pub struct OllamaEmbeddingFunction {
    /// HTTP client used for every request issued by this instance.
    client: reqwest::Client,
    /// Base URL of the server; `/api/embed` is appended to it when building requests.
    host: String,
    /// Model name sent in the `model` field of each embed request.
    model: String,
}
impl OllamaEmbeddingFunction {
    /// Creates an embedding function for `host` and `model`, then verifies it
    /// by running [`Self::heartbeat`] before handing the instance back.
    ///
    /// # Errors
    ///
    /// Returns the heartbeat's error if that probe request fails; in that case
    /// no instance is returned.
    pub async fn new(
        host: impl Into<String>,
        model: impl Into<String>,
    ) -> Result<Self, OllamaEmbeddingError> {
        let embedder = Self {
            client: reqwest::Client::new(),
            host: host.into(),
            model: model.into(),
        };
        // Only hand the instance out once the probe request has succeeded.
        embedder.heartbeat().await.map(|()| embedder)
    }

    /// Probes the server by embedding the literal string `"heartbeat"` and
    /// discarding the resulting vectors.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying `embed_strs` call produces.
    pub async fn heartbeat(&self) -> Result<(), OllamaEmbeddingError> {
        let probe = ["heartbeat"];
        self.embed_strs(&probe).await.map(|_embeddings| ())
    }
}
#[async_trait::async_trait]
impl EmbeddingFunction for OllamaEmbeddingFunction {
    type Embedding = Vec<f32>;
    type Error = OllamaEmbeddingError;

    /// Embeds every string in `batches` with a single embed request and
    /// returns the `embeddings` field of the decoded response.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent, if the server answers with an
    /// error status (checked via `error_for_status`), or if the body cannot be
    /// decoded as an [`EmbedResponse`].
    async fn embed_strs(&self, batches: &[&str]) -> Result<Vec<Vec<f32>>, Self::Error> {
        let payload = EmbedRequest {
            model: &self.model,
            input: batches,
        };

        let http_response = payload.make_request(self).send().await?;
        // Turn an error status into an error before attempting to parse the body.
        let checked = http_response.error_for_status()?;
        let decoded: EmbedResponse = checked.json().await?;

        Ok(decoded.embeddings)
    }
}
/// JSON payload serialized into the body of an embed request.
///
/// Both fields borrow from the caller, so building a request copies no strings.
#[derive(Clone, Debug, serde::Serialize)]
pub struct EmbedRequest<'a> {
    /// Name of the model that should produce the embeddings.
    pub model: &'a str,
    /// Texts to embed, serialized as a JSON array of strings.
    pub input: &'a [&'a str],
}
impl EmbedRequest<'_> {
    /// Builds, without sending, a `POST {host}/api/embed` request whose JSON
    /// body is this payload.
    ///
    /// The request is created on `ef`'s HTTP client and targets `ef`'s host.
    /// Trailing `/` characters on the host are ignored, so a host such as
    /// `http://localhost:11434/` yields `.../api/embed` rather than
    /// `...//api/embed`.
    pub fn make_request(&self, ef: &OllamaEmbeddingFunction) -> RequestBuilder {
        // Normalize the base URL so the joined path never contains a double slash.
        let base = ef.host.trim_end_matches('/');
        ef.client.post(format!("{}/api/embed", base)).json(self)
    }
}
/// JSON body decoded from a successful embed response.
///
/// `model` and `embeddings` must be present for decoding to succeed; the
/// remaining fields are optional.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct EmbedResponse {
    /// Model name reported in the response.
    pub model: String,
    /// Embedding vectors carried by the response.
    // NOTE(review): presumably one vector per input string, in input order; confirm against the Ollama API.
    pub embeddings: Vec<Vec<f32>>,
    /// Optional `total_duration` metric from the response.
    // NOTE(review): presumably nanoseconds, sent as an integer; confirm against the Ollama API.
    pub total_duration: Option<f64>,
    /// Optional `load_duration` metric from the response.
    // NOTE(review): presumably nanoseconds, sent as an integer; confirm against the Ollama API.
    pub load_duration: Option<f64>,
    /// Optional `prompt_eval_count` metric from the response.
    // NOTE(review): presumably a token count, sent as an integer; confirm against the Ollama API.
    pub prompt_eval_count: Option<f64>,
}