Skip to main content

synaptic_openai/
embeddings.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::{Embeddings, SynapticError};
6use synaptic_models::{ProviderBackend, ProviderRequest};
7
/// Connection settings for [`OpenAiEmbeddings`].
///
/// `Debug` is implemented by hand (not derived) so the API key is never printed
/// by `{:?}` formatting, e.g. in logs or panic messages.
#[derive(Clone)]
pub struct OpenAiEmbeddingsConfig {
    /// Secret credential, sent as `Authorization: Bearer <api_key>`.
    pub api_key: String,
    /// Embedding model name; [`OpenAiEmbeddingsConfig::new`] sets `text-embedding-3-small`.
    pub model: String,
    /// API root without the `/embeddings` suffix; [`OpenAiEmbeddingsConfig::new`] sets
    /// `https://api.openai.com/v1`.
    pub base_url: String,
}

impl OpenAiEmbeddingsConfig {
    /// Creates a config with the given API key, the `text-embedding-3-small` model and the
    /// public OpenAI base URL.
    #[must_use]
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    /// Returns this config with `model` replaced.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Returns this config with `base_url` replaced (e.g. for an OpenAI-compatible proxy).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}

impl std::fmt::Debug for OpenAiEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the secret; model and base URL are safe and useful for diagnostics.
        f.debug_struct("OpenAiEmbeddingsConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
33
34pub struct OpenAiEmbeddings {
35    config: OpenAiEmbeddingsConfig,
36    backend: Arc<dyn ProviderBackend>,
37}
38
39impl OpenAiEmbeddings {
40    pub fn new(config: OpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
41        Self { config, backend }
42    }
43
44    fn build_request(&self, input: Vec<String>) -> ProviderRequest {
45        ProviderRequest {
46            url: format!("{}/embeddings", self.config.base_url),
47            headers: vec![
48                (
49                    "Authorization".to_string(),
50                    format!("Bearer {}", self.config.api_key),
51                ),
52                ("Content-Type".to_string(), "application/json".to_string()),
53            ],
54            body: json!({
55                "model": self.config.model,
56                "input": input,
57            }),
58        }
59    }
60
61    fn parse_response(&self, body: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
62        let data = body.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
63            SynapticError::Embedding("missing 'data' field in response".to_string())
64        })?;
65
66        let mut embeddings = Vec::with_capacity(data.len());
67        for item in data {
68            let embedding = item
69                .get("embedding")
70                .and_then(|e| e.as_array())
71                .ok_or_else(|| SynapticError::Embedding("missing 'embedding' field".to_string()))?
72                .iter()
73                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
74                .collect();
75            embeddings.push(embedding);
76        }
77
78        Ok(embeddings)
79    }
80}
81
82#[async_trait]
83impl Embeddings for OpenAiEmbeddings {
84    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
85        let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
86        let request = self.build_request(input);
87        let response = self.backend.send(request).await?;
88
89        if response.status != 200 {
90            return Err(SynapticError::Embedding(format!(
91                "OpenAI API error ({}): {}",
92                response.status, response.body
93            )));
94        }
95
96        self.parse_response(&response.body)
97    }
98
99    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
100        let mut results = self.embed_documents(&[text]).await?;
101        results
102            .pop()
103            .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
104    }
105}