ai_lib_rust/embeddings/client.rs

use crate::{Error, ErrorContext, Result};

use super::types::{Embedding, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage};

/// Client for OpenAI-compatible `/v1/embeddings` endpoints.
pub struct EmbeddingClient {
    http_client: reqwest::Client,
    model: String,
    base_url: String,
    api_key: String,
    dimensions: Option<usize>,
    max_batch_size: usize,
}

impl EmbeddingClient {
    /// Returns a builder for configuring and constructing an `EmbeddingClient`.
    pub fn builder() -> EmbeddingClientBuilder { EmbeddingClientBuilder::new() }

    /// Embeds a single input string.
    pub async fn embed(&self, text: &str) -> Result<EmbeddingResponse> {
        let request = EmbeddingRequest::single(&self.model, text);
        self.execute(request).await
    }

    /// Embeds a batch of inputs, transparently splitting it into chunks of at
    /// most `max_batch_size` and stitching the results back together.
    pub async fn embed_batch(&self, texts: &[impl AsRef<str>]) -> Result<EmbeddingResponse> {
        let texts: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
        if texts.len() <= self.max_batch_size {
            return self.execute(EmbeddingRequest::batch(&self.model, texts)).await;
        }
        let mut all_embeddings: Vec<Embedding> = Vec::new();
        let mut total_usage = EmbeddingUsage::default();
        for (batch_idx, chunk) in texts.chunks(self.max_batch_size).enumerate() {
            let response = self.execute(EmbeddingRequest::batch(&self.model, chunk.to_vec())).await?;
            // Indices come back relative to each chunk; shift them so they match
            // positions in the original input slice.
            let offset = batch_idx * self.max_batch_size;
            for mut emb in response.embeddings {
                emb.index += offset;
                all_embeddings.push(emb);
            }
            total_usage.add(&response.usage);
        }
        Ok(EmbeddingResponse::new(all_embeddings, self.model.clone(), total_usage))
    }
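
    // Illustration of the chunking above: with the builder default
    // `max_batch_size = 100`, a call with 250 inputs issues three requests of
    // 100, 100, and 50 items. The code assumes each response indexes its
    // embeddings from 0 within its own chunk, so the offsets 0, 100, and 200
    // remap them to 0..=249, matching positions in the caller's slice, while
    // `EmbeddingUsage::add` accumulates token usage across all three calls.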

    async fn execute(&self, mut request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        if let Some(dims) = self.dimensions {
            request = request.with_dimensions(dims);
        }
        let endpoint = format!("{}/v1/embeddings", self.base_url);
        let response = self.http_client
            .post(&endpoint)
            .bearer_auth(&self.api_key)
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| Error::network_with_context(format!("Embedding request failed: {}", e), ErrorContext::new().with_source("embeddings")))?;
        let status = response.status();
        let body = response.text().await
            .map_err(|e| Error::network_with_context(format!("Failed to read response: {}", e), ErrorContext::new()))?;
        if !status.is_success() {
            return Err(Error::api_with_context(format!("Embedding API error ({}): {}", status, body), ErrorContext::new()));
        }
        let json: serde_json::Value = serde_json::from_str(&body)?;
        EmbeddingResponse::from_openai_format(&json)
    }
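
    // `from_openai_format` is assumed to parse an OpenAI-style response body,
    // which for `/v1/embeddings` looks roughly like this (a sketch; exact
    // fields may vary by provider):
    //
    //     {
    //       "object": "list",
    //       "data": [{ "object": "embedding", "index": 0, "embedding": [0.0023, -0.0091, ...] }],
    //       "model": "text-embedding-3-small",
    //       "usage": { "prompt_tokens": 5, "total_tokens": 5 }
    //     }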

    /// Returns the model name this client was configured with.
    pub fn model(&self) -> &str { &self.model }
}

/// Builder for [`EmbeddingClient`].
pub struct EmbeddingClientBuilder {
    model: Option<String>,
    api_key: Option<String>,
    base_url: Option<String>,
    dimensions: Option<usize>,
    max_batch_size: usize,
    timeout_secs: u64,
}

impl EmbeddingClientBuilder {
    pub fn new() -> Self {
        Self { model: None, api_key: None, base_url: None, dimensions: None, max_batch_size: 100, timeout_secs: 60 }
    }

    pub fn model(mut self, model: impl Into<String>) -> Self { self.model = Some(model.into()); self }
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self { self.api_key = Some(api_key.into()); self }
    pub fn base_url(mut self, url: impl Into<String>) -> Self { self.base_url = Some(url.into()); self }
    pub fn dimensions(mut self, dimensions: usize) -> Self { self.dimensions = Some(dimensions); self }

    pub async fn build(self) -> Result<EmbeddingClient> {
        let model = self.model.ok_or_else(|| Error::configuration("Model must be specified"))?;
        // Falls back to the OPENAI_API_KEY environment variable when no key is set explicitly.
        let api_key = self.api_key
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .ok_or_else(|| Error::configuration("API key required"))?;
        let base_url = self.base_url.unwrap_or_else(|| "https://api.openai.com".to_string());
        let http_client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(self.timeout_secs))
            .build()
            .map_err(|e| Error::configuration(format!("Failed to create HTTP client: {}", e)))?;
        Ok(EmbeddingClient { http_client, model, base_url, api_key, dimensions: self.dimensions, max_batch_size: self.max_batch_size })
    }
}

impl Default for EmbeddingClientBuilder { fn default() -> Self { Self::new() } }
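
// Usage sketch (illustration only, under the defaults above: OpenAI base URL,
// OPENAI_API_KEY fallback, batch size 100, 60 s timeout). It assumes the
// `embeddings` field of `EmbeddingResponse` is visible to the caller, as it is
// to this module, and would run inside an async fn returning `crate::Result<()>`:
//
//     let client = EmbeddingClient::builder()
//         .model("text-embedding-3-small")
//         .dimensions(256)
//         .build()
//         .await?;
//     let response = client.embed_batch(&["first text", "second text"]).await?;
//     assert_eq!(response.embeddings.len(), 2);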