mempal 0.3.1

Project memory for coding agents. Single binary, hybrid search, knowledge graph.
Documentation
use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{EmbedError, Embedder, Result};

/// Embedder backed by a remote HTTP embeddings endpoint.
///
/// Holds only configuration: where to send requests, which model to ask for,
/// and how long each returned vector is expected to be.
#[derive(Debug, Clone)]
pub struct ApiEmbedder {
    /// URL that embedding requests are POSTed to.
    endpoint: String,
    /// Model name to request; `None` lets each backend path pick its own default.
    model: Option<String>,
    /// Expected length of every embedding vector returned by the endpoint.
    dimensions: usize,
}

impl ApiEmbedder {
    /// Builds an embedder from its endpoint URL, optional model name and
    /// expected vector length. No validation or network access happens here.
    pub fn new(endpoint: String, model: Option<String>, dimensions: usize) -> Self {
        ApiEmbedder {
            endpoint,
            model,
            dimensions,
        }
    }

    /// Returns the configured endpoint URL.
    pub fn endpoint(&self) -> &str {
        self.endpoint.as_str()
    }

    /// Returns the configured model name, or `None` when no model was set.
    pub fn model(&self) -> Option<&str> {
        let configured = &self.model;
        configured.as_deref()
    }
}

#[async_trait::async_trait]
impl Embedder for ApiEmbedder {
    /// Embeds every entry of `texts`, choosing the wire protocol from the
    /// configured endpoint URL.
    ///
    /// An empty `texts` slice short-circuits to an empty result without any
    /// network traffic.
    async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Endpoints recognised as Ollama-style get one request per text;
        // everything else is treated as an OpenAI-compatible batch endpoint.
        if is_ollama_embeddings_endpoint(self.endpoint()) {
            self.embed_ollama_compatible(texts).await
        } else {
            self.embed_openai_compatible(texts).await
        }
    }

    /// Returns the vector length this embedder was configured with.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Returns the fixed identifier of this embedder implementation.
    fn name(&self) -> &str {
        "api"
    }
}

impl ApiEmbedder {
    /// Embeds `texts` with a single batched POST to an OpenAI-style endpoint.
    ///
    /// The JSON body carries the configured model (falling back to
    /// `text-embedding-3-small` when none is set) and the whole batch as
    /// `input`. The `embedding` of every item in the response's `data` array
    /// is returned in response order.
    ///
    /// # Errors
    ///
    /// - `EmbedError::HttpRequest` if the request could not be sent.
    /// - `EmbedError::HttpStatus` if the server replied with an error status.
    /// - `EmbedError::DecodeResponse` if the body is not the expected JSON.
    /// - Any error from `validate_vectors` (no vectors, or a vector whose
    ///   length differs from the configured dimensions).
    /// - `EmbedError::InvalidResponse` if the number of returned vectors does
    ///   not match the number of input texts.
    async fn embed_openai_compatible(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Owned copy so each error-mapping closure can attach the endpoint.
        let endpoint = self.endpoint().to_string();
        let response = reqwest::Client::new()
            .post(self.endpoint())
            .json(&OpenAiEmbeddingsRequest {
                model: self.model().unwrap_or("text-embedding-3-small"),
                input: texts,
            })
            .send()
            .await
            .map_err(|source| EmbedError::HttpRequest {
                endpoint: endpoint.clone(),
                source,
            })?
            .error_for_status()
            .map_err(|source| EmbedError::HttpStatus {
                endpoint: endpoint.clone(),
                source,
            })?
            .json::<OpenAiEmbeddingsResponse>()
            .await
            .map_err(|source| EmbedError::DecodeResponse { endpoint, source })?;

        let vectors = response
            .data
            .into_iter()
            .map(|item| item.embedding)
            .collect::<Vec<_>>();
        validate_vectors(&vectors, self.dimensions())?;

        // A well-formed reply has exactly one embedding per input. Without
        // this check a short (or long) `data` array would be handed back and
        // silently misalign vectors with the texts they belong to.
        if vectors.len() != texts.len() {
            return Err(EmbedError::InvalidResponse(format!(
                "embedding response contained {} vectors for {} inputs",
                vectors.len(),
                texts.len()
            )));
        }

        Ok(vectors)
    }

    /// Embeds `texts` against an Ollama-style endpoint, one POST per text.
    ///
    /// Each request sends the configured model (falling back to
    /// `nomic-embed-text` when none is set) and a single `prompt`. The reply
    /// is read as generic JSON and its `embedding` array is converted element
    /// by element to `f32`. Requests are issued sequentially, and the first
    /// failure aborts the whole batch.
    ///
    /// # Errors
    ///
    /// - `EmbedError::HttpRequest` if a request could not be sent.
    /// - `EmbedError::HttpStatus` if the server replied with an error status.
    /// - `EmbedError::DecodeResponse` if a body is not valid JSON.
    /// - `EmbedError::InvalidResponse` if a reply has no `embedding` array or
    ///   the array holds a non-numeric element.
    /// - Any error from `validate_vectors` (no vectors, or a vector whose
    ///   length differs from the configured dimensions).
    async fn embed_ollama_compatible(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // One client is shared by every request in this batch.
        let client = reqwest::Client::new();
        let mut vectors = Vec::with_capacity(texts.len());
        // Owned copy so each error-mapping closure can attach the endpoint.
        let endpoint = self.endpoint().to_string();

        for text in texts {
            let response = client
                .post(self.endpoint())
                .json(&OllamaEmbeddingsRequest {
                    model: self.model().unwrap_or("nomic-embed-text"),
                    prompt: text,
                })
                .send()
                .await
                .map_err(|source| EmbedError::HttpRequest {
                    endpoint: endpoint.clone(),
                    source,
                })?
                .error_for_status()
                .map_err(|source| EmbedError::HttpStatus {
                    endpoint: endpoint.clone(),
                    source,
                })?
                .json::<Value>()
                .await
                .map_err(|source| EmbedError::DecodeResponse {
                    endpoint: endpoint.clone(),
                    source,
                })?;

            // The body is inspected as untyped JSON: only the `embedding`
            // key is required, anything else in the reply is ignored.
            let embedding = response
                .get("embedding")
                .and_then(Value::as_array)
                .ok_or_else(|| {
                    EmbedError::InvalidResponse(
                        "embedding response missing embedding array".to_string(),
                    )
                })?
                .iter()
                .map(json_number_to_f32)
                .collect::<Result<Vec<_>>>()?;
            vectors.push(embedding);
        }

        validate_vectors(&vectors, self.dimensions())?;
        Ok(vectors)
    }
}

/// JSON request body sent to an OpenAI-compatible embeddings endpoint.
///
/// Borrows both fields so building a request copies no text.
#[derive(Debug, Serialize)]
struct OpenAiEmbeddingsRequest<'a> {
    /// Name of the embedding model to use.
    model: &'a str,
    /// Batch of texts to embed in a single request.
    input: &'a [&'a str],
}

/// Subset of an OpenAI-compatible embeddings response that this module reads.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingsResponse {
    /// Returned embedding items, consumed in the order they appear.
    data: Vec<OpenAiEmbeddingItem>,
}

/// One entry of the `data` array in an OpenAI-compatible embeddings response.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingItem {
    /// The embedding vector for one input.
    embedding: Vec<f32>,
}

/// JSON request body sent to an Ollama-style embeddings endpoint.
///
/// Carries exactly one text; the caller issues one request per input.
#[derive(Debug, Serialize)]
struct OllamaEmbeddingsRequest<'a> {
    /// Name of the embedding model to use.
    model: &'a str,
    /// The single text to embed.
    prompt: &'a str,
}

/// Reports whether `endpoint` should be treated as an Ollama-style embeddings
/// URL.
///
/// Any number of trailing `/` characters is ignored; the remaining string
/// must end with the path `/api/embeddings`.
fn is_ollama_embeddings_endpoint(endpoint: &str) -> bool {
    const OLLAMA_PATH_SUFFIX: &str = "/api/embeddings";

    let without_trailing_slashes = endpoint.trim_end_matches('/');
    without_trailing_slashes.ends_with(OLLAMA_PATH_SUFFIX)
}

/// Checks that `vectors` is non-empty and that every vector has exactly
/// `expected_dimensions` elements.
///
/// # Errors
///
/// - `EmbedError::EmptyVectors` when `vectors` contains no entries.
/// - `EmbedError::InvalidDimensions` carrying the expected length and the
///   length of the first vector that does not match it.
fn validate_vectors(vectors: &[Vec<f32>], expected_dimensions: usize) -> Result<()> {
    if vectors.is_empty() {
        return Err(EmbedError::EmptyVectors);
    }

    // Scan in order so the reported length belongs to the first offender.
    for vector in vectors {
        let actual = vector.len();
        if actual != expected_dimensions {
            return Err(EmbedError::InvalidDimensions {
                expected: expected_dimensions,
                actual,
            });
        }
    }

    Ok(())
}

/// Converts one JSON value from an embedding array into an `f32`.
///
/// The value is read through `Value::as_f64` and then narrowed with an `as`
/// cast, so precision beyond `f32` is dropped.
///
/// # Errors
///
/// Returns `EmbedError::InvalidResponse` when `as_f64` yields nothing for the
/// value, i.e. it is not a number representable as `f64`.
fn json_number_to_f32(value: &Value) -> Result<f32> {
    match value.as_f64() {
        Some(number) => Ok(number as f32),
        None => Err(EmbedError::InvalidResponse(
            "embedding element was not numeric".to_string(),
        )),
    }
}