//! openmemory 0.1.1
//!
//! OpenMemory — cognitive memory system for AI applications.
//!
//! OpenAI embedding provider
//!
//! Uses the OpenAI Embeddings API to generate vector embeddings.

use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::{resize_vector, EmbeddingProvider};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// OpenAI embedding provider
/// OpenAI embedding provider
pub struct OpenAIProvider {
    /// Bearer token for the OpenAI API; empty when not configured
    /// (requests are rejected with a config error in that case).
    api_key: String,
    /// Base API URL; `/embeddings` is appended at request time, and a
    /// trailing slash is tolerated.
    base_url: String,
    /// Embedding model identifier sent with each request.
    model: String,
    /// Target vector dimensionality; returned vectors are resized to this.
    dim: usize,
    /// Shared HTTP client, reused across requests for connection pooling.
    client: reqwest::Client,
}

impl OpenAIProvider {
    /// Build a provider from the application [`Config`].
    ///
    /// An unset API key falls back to the empty string (and is rejected
    /// when a request is attempted); an unset model falls back to
    /// `text-embedding-3-small`.
    pub fn new(config: &Config) -> Self {
        let model = match config.openai_model.clone() {
            Some(name) => name,
            None => "text-embedding-3-small".to_string(),
        };

        Self {
            api_key: config.openai_key.clone().unwrap_or_default(),
            base_url: config.openai_base_url.clone(),
            model,
            dim: config.vec_dim,
            client: reqwest::Client::new(),
        }
    }

    /// Resolve the model name to use for a sector.
    ///
    /// Every sector currently shares the single configured model.
    fn model_for_sector(&self, _sector: &Sector) -> &str {
        &self.model
    }
}

/// JSON request body for the OpenAI `/embeddings` endpoint.
#[derive(Serialize)]
struct EmbeddingRequest {
    /// One or more input texts to embed in a single call.
    input: Vec<String>,
    /// Embedding model identifier.
    model: String,
    /// Requested output dimensionality; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}

/// JSON response body from the OpenAI `/embeddings` endpoint.
/// Only the fields this module consumes are deserialized.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per input, in the same order as the request.
    data: Vec<EmbeddingData>,
}

/// A single embedding entry within [`EmbeddingResponse::data`].
#[derive(Deserialize)]
struct EmbeddingData {
    /// The raw embedding vector as returned by the API.
    embedding: Vec<f32>,
}

#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        if self.api_key.is_empty() {
            return Err(Error::config("OpenAI API key is not configured"));
        }

        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
        let model = self.model_for_sector(sector);

        let request = EmbeddingRequest {
            input: vec![text.to_string()],
            model: model.to_string(),
            dimensions: Some(self.dim),
        };

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 {
                return Err(Error::RateLimit {
                    retry_after_secs: 5,
                });
            }

            return Err(Error::embedding(format!(
                "OpenAI API error {}: {}",
                status, body
            )));
        }

        let data: EmbeddingResponse = response.json().await?;

        if data.data.is_empty() {
            return Err(Error::embedding("No embedding returned from OpenAI"));
        }

        let vector = resize_vector(&data.data[0].embedding, self.dim);

        Ok(EmbeddingResult {
            sector: *sector,
            vector: vector.clone(),
            dim: vector.len(),
        })
    }

    async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
        if self.api_key.is_empty() {
            return Err(Error::config("OpenAI API key is not configured"));
        }

        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));

        let input: Vec<String> = texts.iter().map(|(t, _)| t.to_string()).collect();
        let sectors: Vec<Sector> = texts.iter().map(|(_, s)| **s).collect();

        let request = EmbeddingRequest {
            input,
            model: self.model.clone(),
            dimensions: Some(self.dim),
        };

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 {
                return Err(Error::RateLimit {
                    retry_after_secs: 5,
                });
            }

            return Err(Error::embedding(format!(
                "OpenAI API error {}: {}",
                status, body
            )));
        }

        let data: EmbeddingResponse = response.json().await?;

        let results: Vec<EmbeddingResult> = data
            .data
            .into_iter()
            .zip(sectors.into_iter())
            .map(|(emb, sector)| {
                let vector = resize_vector(&emb.embedding, self.dim);
                EmbeddingResult {
                    sector,
                    vector: vector.clone(),
                    dim: vector.len(),
                }
            })
            .collect();

        Ok(results)
    }

    fn dimensions(&self) -> usize {
        self.dim
    }

    fn name(&self) -> &'static str {
        "openai"
    }

    fn supports_batch(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A default config still yields a provider with the expected identity.
    #[test]
    fn test_provider_creation() {
        let provider = OpenAIProvider::new(&Config::default());

        assert_eq!(provider.name(), "openai");
        assert!(provider.supports_batch());
    }

    /// An explicitly configured model name overrides the built-in default.
    #[test]
    fn test_model_name() {
        let mut config = Config::default();
        config.openai_model = Some("text-embedding-ada-002".to_string());

        assert_eq!(
            OpenAIProvider::new(&config).model,
            "text-embedding-ada-002"
        );
    }
}