ai-lib-core 1.0.0

AI-Protocol execution runtime core (protocol, client, pipeline, transport)
//! Client for generating embeddings through an HTTP `/v1/embeddings` endpoint.

use super::types::{Embedding, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage};
use crate::{Error, ErrorContext, Result};

/// Client that requests embeddings from an HTTP embeddings endpoint.
///
/// Instances are created through [`EmbeddingClient::builder`].
pub struct EmbeddingClient {
    /// HTTP client used to send every embedding request.
    http_client: reqwest::Client,
    /// Model identifier placed in each request.
    model: String,
    /// Base URL that the `/v1/embeddings` path is appended to.
    base_url: String,
    /// Key sent as a bearer token with each request.
    api_key: String,
    /// Optional dimension count applied to each request when set.
    dimensions: Option<usize>,
    /// Largest number of inputs sent in one request; longer batches are split.
    max_batch_size: usize,
}

impl EmbeddingClient {
    pub fn builder() -> EmbeddingClientBuilder {
        EmbeddingClientBuilder::new()
    }

    pub async fn embed(&self, text: &str) -> Result<EmbeddingResponse> {
        let request = EmbeddingRequest::single(&self.model, text);
        self.execute(request).await
    }

    pub async fn embed_batch(&self, texts: &[impl AsRef<str>]) -> Result<EmbeddingResponse> {
        let texts: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
        if texts.len() <= self.max_batch_size {
            return self
                .execute(EmbeddingRequest::batch(&self.model, texts))
                .await;
        }
        let mut all_embeddings: Vec<Embedding> = Vec::new();
        let mut total_usage = EmbeddingUsage::default();
        for (batch_idx, chunk) in texts.chunks(self.max_batch_size).enumerate() {
            let response = self
                .execute(EmbeddingRequest::batch(&self.model, chunk.to_vec()))
                .await?;
            let offset = batch_idx * self.max_batch_size;
            for mut emb in response.embeddings {
                emb.index += offset;
                all_embeddings.push(emb);
            }
            total_usage.add(&response.usage);
        }
        Ok(EmbeddingResponse::new(
            all_embeddings,
            self.model.clone(),
            total_usage,
        ))
    }

    async fn execute(&self, mut request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        if let Some(dims) = self.dimensions {
            request = request.with_dimensions(dims);
        }
        let endpoint = format!("{}/v1/embeddings", self.base_url);
        let response = self
            .http_client
            .post(&endpoint)
            .bearer_auth(&self.api_key)
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                Error::network_with_context(
                    format!("Embedding request failed: {}", e),
                    ErrorContext::new().with_source("embeddings"),
                )
            })?;
        let status = response.status();
        let body = response.text().await.map_err(|e| {
            Error::network_with_context(
                format!("Failed to read response: {}", e),
                ErrorContext::new(),
            )
        })?;
        if !status.is_success() {
            return Err(Error::api_with_context(
                format!("Embedding API error ({}): {}", status, body),
                ErrorContext::new(),
            ));
        }
        let json: serde_json::Value = serde_json::from_str(&body)?;
        EmbeddingResponse::from_openai_format(&json)
    }

    pub fn model(&self) -> &str {
        &self.model
    }
}

/// Builder that collects configuration for an [`EmbeddingClient`].
pub struct EmbeddingClientBuilder {
    /// Model identifier; `build` fails if it was never set.
    model: Option<String>,
    /// API key; when unset, `build` falls back to the `OPENAI_API_KEY` environment variable.
    api_key: Option<String>,
    /// Base URL; when unset, `build` uses `https://api.openai.com`.
    base_url: Option<String>,
    /// Optional dimension count forwarded to the client.
    dimensions: Option<usize>,
    /// Maximum number of inputs per request (initialised to 100 by `new`).
    max_batch_size: usize,
    /// HTTP request timeout in seconds (initialised to 60 by `new`).
    timeout_secs: u64,
}

impl EmbeddingClientBuilder {
    pub fn new() -> Self {
        Self {
            model: None,
            api_key: None,
            base_url: None,
            dimensions: None,
            max_batch_size: 100,
            timeout_secs: 60,
        }
    }
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }
    pub fn dimensions(mut self, dimensions: usize) -> Self {
        self.dimensions = Some(dimensions);
        self
    }

    pub async fn build(self) -> Result<EmbeddingClient> {
        let model = self
            .model
            .ok_or_else(|| Error::configuration("Model must be specified"))?;
        let api_key = self
            .api_key
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .ok_or_else(|| Error::configuration("API key required"))?;
        let base_url = self
            .base_url
            .unwrap_or_else(|| "https://api.openai.com".to_string());
        let http_client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(self.timeout_secs))
            .build()
            .map_err(|e| Error::configuration(format!("Failed to create HTTP client: {}", e)))?;
        Ok(EmbeddingClient {
            http_client,
            model,
            base_url,
            api_key,
            dimensions: self.dimensions,
            max_batch_size: self.max_batch_size,
        })
    }
}

impl Default for EmbeddingClientBuilder {
    fn default() -> Self {
        Self::new()
    }
}