Skip to main content

ai_lib_rust/embeddings/
client.rs

1//! Embedding client for generating embeddings.
2
3use crate::{Error, ErrorContext, Result};
4use super::types::{Embedding, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage};
5
6pub struct EmbeddingClient {
7    http_client: reqwest::Client,
8    model: String,
9    base_url: String,
10    api_key: String,
11    dimensions: Option<usize>,
12    max_batch_size: usize,
13}
14
15impl EmbeddingClient {
16    pub fn builder() -> EmbeddingClientBuilder { EmbeddingClientBuilder::new() }
17
18    pub async fn embed(&self, text: &str) -> Result<EmbeddingResponse> {
19        let request = EmbeddingRequest::single(&self.model, text);
20        self.execute(request).await
21    }
22
23    pub async fn embed_batch(&self, texts: &[impl AsRef<str>]) -> Result<EmbeddingResponse> {
24        let texts: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
25        if texts.len() <= self.max_batch_size {
26            return self.execute(EmbeddingRequest::batch(&self.model, texts)).await;
27        }
28        let mut all_embeddings: Vec<Embedding> = Vec::new();
29        let mut total_usage = EmbeddingUsage::default();
30        for (batch_idx, chunk) in texts.chunks(self.max_batch_size).enumerate() {
31            let response = self.execute(EmbeddingRequest::batch(&self.model, chunk.to_vec())).await?;
32            let offset = batch_idx * self.max_batch_size;
33            for mut emb in response.embeddings { emb.index += offset; all_embeddings.push(emb); }
34            total_usage.add(&response.usage);
35        }
36        Ok(EmbeddingResponse::new(all_embeddings, self.model.clone(), total_usage))
37    }
38
39    async fn execute(&self, mut request: EmbeddingRequest) -> Result<EmbeddingResponse> {
40        if let Some(dims) = self.dimensions { request = request.with_dimensions(dims); }
41        let endpoint = format!("{}/v1/embeddings", self.base_url);
42        let response = self.http_client.post(&endpoint).bearer_auth(&self.api_key).header("Content-Type", "application/json").json(&request).send().await
43            .map_err(|e| Error::network_with_context(format!("Embedding request failed: {}", e), ErrorContext::new().with_source("embeddings")))?;
44        let status = response.status();
45        let body = response.text().await.map_err(|e| Error::network_with_context(format!("Failed to read response: {}", e), ErrorContext::new()))?;
46        if !status.is_success() { return Err(Error::api_with_context(format!("Embedding API error ({}): {}", status, body), ErrorContext::new())); }
47        let json: serde_json::Value = serde_json::from_str(&body)?;
48        EmbeddingResponse::from_openai_format(&json)
49    }
50
51    pub fn model(&self) -> &str { &self.model }
52}
53
/// Step-by-step configuration for an `EmbeddingClient`.
///
/// Unset optional values are resolved when the client is built.
pub struct EmbeddingClientBuilder {
    /// Embedding model name; must be provided before building.
    model: Option<String>,
    /// Explicit API key; when absent the build step consults the environment.
    api_key: Option<String>,
    /// Server base URL; when absent a default is used at build time.
    base_url: Option<String>,
    /// Optional output dimensionality forwarded to every request.
    dimensions: Option<usize>,
    /// Upper bound on inputs sent in one request.
    max_batch_size: usize,
    /// HTTP request timeout, in whole seconds.
    timeout_secs: u64,
}
62
63impl EmbeddingClientBuilder {
64    pub fn new() -> Self { Self { model: None, api_key: None, base_url: None, dimensions: None, max_batch_size: 100, timeout_secs: 60 } }
65    pub fn model(mut self, model: impl Into<String>) -> Self { self.model = Some(model.into()); self }
66    pub fn api_key(mut self, api_key: impl Into<String>) -> Self { self.api_key = Some(api_key.into()); self }
67    pub fn base_url(mut self, url: impl Into<String>) -> Self { self.base_url = Some(url.into()); self }
68    pub fn dimensions(mut self, dimensions: usize) -> Self { self.dimensions = Some(dimensions); self }
69
70    pub async fn build(self) -> Result<EmbeddingClient> {
71        let model = self.model.ok_or_else(|| Error::configuration("Model must be specified"))?;
72        let api_key = self.api_key.or_else(|| std::env::var("OPENAI_API_KEY").ok()).ok_or_else(|| Error::configuration("API key required"))?;
73        let base_url = self.base_url.unwrap_or_else(|| "https://api.openai.com".to_string());
74        let http_client = reqwest::Client::builder().timeout(std::time::Duration::from_secs(self.timeout_secs)).build()
75            .map_err(|e| Error::configuration(format!("Failed to create HTTP client: {}", e)))?;
76        Ok(EmbeddingClient { http_client, model, base_url, api_key, dimensions: self.dimensions, max_batch_size: self.max_batch_size })
77    }
78}
79
80impl Default for EmbeddingClientBuilder { fn default() -> Self { Self::new() } }