neomemx 0.1.2

A high-performance memory library for AI agents with semantic search
Documentation
//! Jina embeddings provider: an async HTTP client for the Jina embeddings API (free API keys are available).

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error};

use super::base::EmbeddingBase;
use crate::config::JinaEmbedderConfig;
use crate::error::{NeomemxError, Result};

/// Jina embedding client with cached configuration.
///
/// All values are resolved once in [`JinaEmbedding::new`] and reused for every
/// request, including the underlying HTTP client.
pub struct JinaEmbedding {
    /// API key sent as the `Bearer` token on every request.
    api_key: String,
    /// Full endpoint URL that requests are POSTed to.
    url: String,
    /// Model name placed in the `model` field of each request body.
    model: String,
    /// Dimensionality reported by `embedding_dims()`; copied from the config
    /// and not checked against the vectors the API actually returns.
    embedding_dims: usize,
    /// HTTP client built once in `new` and shared by all requests.
    client: Client,
}

/// JSON request body sent to the embeddings endpoint.
///
/// Borrows from the caller so building a request does not copy the texts.
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
    /// Name of the embedding model to use.
    model: &'a str,
    /// Texts to embed, one embedding expected per entry.
    input: Vec<&'a str>,
}

/// Successful response body; only the `data` field is declared and read here.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// One entry per embedded input.
    data: Vec<EmbeddingData>,
}

/// A single entry of the response's `data` array.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector for one input text.
    embedding: Vec<f32>,
}

/// Error response body, parsed on a best-effort basis for non-success statuses.
///
/// Both fields are optional; the caller prefers `detail` and falls back to
/// `error`. If the body does not fit this shape (for example a non-string
/// `detail`), parsing fails and the raw body is reported instead.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Error description, used first when present.
    detail: Option<String>,
    /// Alternative error description, used when `detail` is absent.
    error: Option<String>,
}

impl JinaEmbedding {
    /// Creates a new Jina embedding client from `config`.
    ///
    /// The API key is resolved once through `config.get_api_key()`, the endpoint
    /// URL is normalized so that it ends in `embeddings`, and a pooled HTTP client
    /// is built up front so every request reuses it.
    ///
    /// # Errors
    ///
    /// Returns [`NeomemxError::EmbeddingError`] when no API key is available or
    /// when the HTTP client cannot be constructed.
    pub fn new(config: JinaEmbedderConfig) -> Result<Self> {
        let api_key = config.get_api_key().ok_or_else(|| {
            NeomemxError::EmbeddingError(
                "Jina API key not found. Get your FREE key at: https://jina.ai/embeddings/ \
                 Then set JINA_API_KEY environment variable."
                    .to_string(),
            )
        })?;

        // Requests go to `<base>/embeddings`. The segment is appended unless the
        // configured base URL already ends with it, so callers may supply either
        // the API root or the full endpoint; trailing slashes are ignored.
        let url = {
            let base = config.base_url.trim_end_matches('/');
            if base.ends_with("embeddings") {
                base.to_string()
            } else {
                format!("{}/embeddings", base)
            }
        };

        // NOTE(review): `no_proxy()` also turns off system/environment proxy
        // settings — confirm that is intended for all deployments.
        let client = Client::builder()
            .pool_max_idle_per_host(16)
            .pool_idle_timeout(std::time::Duration::from_secs(90))
            .tcp_keepalive(std::time::Duration::from_secs(60))
            .timeout(std::time::Duration::from_secs(30))
            .no_proxy()
            .build()
            .map_err(|e| {
                NeomemxError::EmbeddingError(format!("Failed to create HTTP client: {}", e))
            })?;

        Ok(Self {
            api_key,
            url,
            model: config.model,
            embedding_dims: config.embedding_dims,
            client,
        })
    }

    /// Embeds `texts` with a single POST to the configured endpoint.
    ///
    /// Every input is trimmed before it is sent. An empty slice returns an empty
    /// result without issuing a request. Embeddings are returned in the order
    /// the response lists them, and their count must match the number of inputs.
    ///
    /// # Errors
    ///
    /// - Transport failures and body-read failures are propagated via `?`.
    /// - A non-success HTTP status yields [`NeomemxError::EmbeddingError`] with
    ///   the server's `detail`/`error` text, the raw body, or the HTTP status
    ///   (in that order of preference).
    /// - An unparsable success body or a count mismatch also yields
    ///   [`NeomemxError::EmbeddingError`].
    async fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let request = EmbeddingRequest {
            model: &self.model,
            input: texts.iter().map(|t| t.trim()).collect(),
        };

        debug!("Jina batch embedding: {} texts", texts.len());

        let response = self
            .client
            .post(&self.url)
            .header("Content-Type", "application/json")
            // `bearer_auth` emits the same `Authorization: Bearer <key>` header
            // but flags the value as sensitive, keeping the key out of debug
            // output of the request headers.
            .bearer_auth(&self.api_key)
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        let body = response.text().await?;

        if !status.is_success() {
            // Prefer the structured message; otherwise report the raw body.
            // `body` is moved rather than cloned because this branch returns.
            let error_msg = serde_json::from_str::<ErrorResponse>(&body)
                .ok()
                .and_then(|e| e.detail.or(e.error))
                .filter(|msg| !msg.trim().is_empty())
                .unwrap_or(body);
            // An empty body would otherwise produce a message with no content.
            let error_msg = if error_msg.trim().is_empty() {
                format!("HTTP {}", status)
            } else {
                error_msg
            };
            error!("Jina API error: {}", error_msg);
            return Err(NeomemxError::EmbeddingError(format!(
                "Jina API error: {}",
                error_msg
            )));
        }

        let parsed: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
            NeomemxError::EmbeddingError(format!("Failed to parse Jina response: {}", e))
        })?;

        // Guard against partial responses so callers can pair inputs with
        // outputs by position.
        if parsed.data.len() != texts.len() {
            return Err(NeomemxError::EmbeddingError(format!(
                "Jina returned {} embeddings but expected {}",
                parsed.data.len(),
                texts.len()
            )));
        }

        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
    }
}

#[async_trait]
impl EmbeddingBase for JinaEmbedding {
    /// Embeds one text by sending it as a single-element batch.
    ///
    /// Fails with [`NeomemxError::EmbeddingError`] if the batch call fails or
    /// comes back without any embedding.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_texts(&[text]).await?;
        match embeddings.into_iter().next_back() {
            Some(vector) => Ok(vector),
            None => Err(NeomemxError::EmbeddingError(
                "No embeddings in Jina response".to_string(),
            )),
        }
    }

    /// Embeds every text in `texts`, returning one vector per input.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let embeddings = self.embed_texts(texts).await?;
        Ok(embeddings)
    }

    /// Returns the embedding dimensionality taken from the configuration.
    fn embedding_dims(&self) -> usize {
        let Self { embedding_dims, .. } = self;
        *embedding_dims
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration pairs the v2 English base model with 768 dims.
    #[test]
    fn test_default_config() {
        let defaults = JinaEmbedderConfig::default();
        assert_eq!(
            (defaults.model.as_str(), defaults.embedding_dims),
            ("jina-embeddings-v2-base-en", 768)
        );
    }

    /// Overriding only the model name leaves the dimension setting at 768.
    #[test]
    fn test_custom_model() {
        let custom = JinaEmbedderConfig::with_model("jina-embeddings-v3");
        assert_eq!(
            (custom.model.as_str(), custom.embedding_dims),
            ("jina-embeddings-v3", 768)
        );
    }
}