Skip to main content

synaptic_openai/
embeddings.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::{Embeddings, SynapticError};
6use synaptic_models::{ProviderBackend, ProviderRequest};
7
/// Connection settings for an OpenAI-style embeddings endpoint.
///
/// Build one with [`OpenAiEmbeddingsConfig::new`] and refine it with the
/// chainable `with_*` methods.
pub struct OpenAiEmbeddingsConfig {
    /// API key supplied by the caller.
    pub api_key: String,
    /// Name of the embedding model to request.
    pub model: String,
    /// Base URL of the API, without the endpoint path.
    pub base_url: String,
}

impl OpenAiEmbeddingsConfig {
    /// Model selected when `with_model` is never called.
    const DEFAULT_MODEL: &'static str = "text-embedding-3-small";
    /// Base URL selected when `with_base_url` is never called.
    const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";

    /// Creates a configuration for `api_key` using the default model and
    /// the default base URL.
    pub fn new(api_key: impl Into<String>) -> Self {
        let api_key = api_key.into();
        Self {
            api_key,
            model: String::from(Self::DEFAULT_MODEL),
            base_url: String::from(Self::DEFAULT_BASE_URL),
        }
    }

    /// Returns the configuration with its model replaced by `model`.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Returns the configuration with its base URL replaced by `base_url`.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }
}
33
34pub struct OpenAiEmbeddings {
35    config: OpenAiEmbeddingsConfig,
36    backend: Arc<dyn ProviderBackend>,
37}
38
39impl OpenAiEmbeddings {
40    pub fn new(config: OpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
41        Self { config, backend }
42    }
43
44    fn build_request(&self, input: Vec<String>) -> ProviderRequest {
45        ProviderRequest {
46            url: format!("{}/embeddings", self.config.base_url),
47            headers: vec![
48                (
49                    "Authorization".to_string(),
50                    format!("Bearer {}", self.config.api_key),
51                ),
52                ("Content-Type".to_string(), "application/json".to_string()),
53            ],
54            body: json!({
55                "model": self.config.model,
56                "input": input,
57            }),
58        }
59    }
60
61    fn parse_response(&self, body: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
62        parse_embeddings_response(body)
63    }
64}
65
66pub(crate) fn parse_embeddings_response(
67    body: &serde_json::Value,
68) -> Result<Vec<Vec<f32>>, SynapticError> {
69    let data = body
70        .get("data")
71        .and_then(|d| d.as_array())
72        .ok_or_else(|| SynapticError::Embedding("missing 'data' field in response".to_string()))?;
73
74    let mut embeddings = Vec::with_capacity(data.len());
75    for item in data {
76        let embedding = item
77            .get("embedding")
78            .and_then(|e| e.as_array())
79            .ok_or_else(|| SynapticError::Embedding("missing 'embedding' field".to_string()))?
80            .iter()
81            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
82            .collect();
83        embeddings.push(embedding);
84    }
85
86    Ok(embeddings)
87}
88
89#[async_trait]
90impl Embeddings for OpenAiEmbeddings {
91    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
92        let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
93        let request = self.build_request(input);
94        let response = self.backend.send(request).await?;
95
96        if response.status != 200 {
97            return Err(SynapticError::Embedding(format!(
98                "OpenAI API error ({}): {}",
99                response.status, response.body
100            )));
101        }
102
103        self.parse_response(&response.body)
104    }
105
106    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
107        let mut results = self.embed_documents(&[text]).await?;
108        results
109            .pop()
110            .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
111    }
112}