Skip to main content

modo/embed/
openai.rs

1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::OpenAIConfig;
10use super::convert::to_f32_blob;
11
/// Base URL used when `OpenAIConfig::base_url` is not set. `/v1/embeddings` is
/// appended to it when building the request URL.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
13
14struct Inner {
15    client: reqwest::Client,
16    api_key: String,
17    model: String,
18    dimensions: usize,
19    base_url: String,
20}
21
22/// OpenAI embedding provider.
23///
24/// Calls `POST {base_url}/v1/embeddings` and returns a little-endian f32 blob.
25///
26/// # Example
27///
28/// ```rust,ignore
29/// let client = reqwest::Client::new();
30/// let provider = OpenAIEmbedding::new(client, &config)?;
31/// let embedder = EmbeddingProvider::new(provider);
32/// ```
33pub struct OpenAIEmbedding(Arc<Inner>);
34
35impl Clone for OpenAIEmbedding {
36    fn clone(&self) -> Self {
37        Self(Arc::clone(&self.0))
38    }
39}
40
41impl OpenAIEmbedding {
42    /// Create from config. Validates config at construction.
43    ///
44    /// # Errors
45    ///
46    /// Returns `Error::bad_request` if config validation fails.
47    pub fn new(client: reqwest::Client, config: &OpenAIConfig) -> Result<Self> {
48        config.validate()?;
49        let base_url = config
50            .base_url
51            .as_deref()
52            .unwrap_or(DEFAULT_BASE_URL)
53            .trim_end_matches('/')
54            .to_owned();
55        Ok(Self(Arc::new(Inner {
56            client,
57            api_key: config.api_key.clone(),
58            model: config.model.clone(),
59            dimensions: config.dimensions,
60            base_url,
61        })))
62    }
63}
64
65impl EmbeddingBackend for OpenAIEmbedding {
66    fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
67        let input = input.to_owned();
68        Box::pin(async move {
69            let url = format!("{}/v1/embeddings", self.0.base_url);
70            let body = Request {
71                input: &input,
72                model: &self.0.model,
73                dimensions: self.0.dimensions,
74            };
75
76            let resp = self
77                .0
78                .client
79                .post(&url)
80                .bearer_auth(&self.0.api_key)
81                .json(&body)
82                .send()
83                .await
84                .map_err(|e| Error::internal("openai embeddings request failed").chain(e))?;
85
86            if !resp.status().is_success() {
87                let status = resp.status();
88                let text = resp.text().await.unwrap_or_default();
89                return Err(Error::internal(format!(
90                    "openai embedding error: {status}: {text}"
91                )));
92            }
93
94            let parsed: Response = resp.json().await.map_err(|e| {
95                Error::internal("failed to parse openai embedding response").chain(e)
96            })?;
97
98            let values = parsed
99                .data
100                .into_iter()
101                .next()
102                .ok_or_else(|| Error::internal("openai returned empty embedding data"))?
103                .embedding;
104
105            Ok(to_f32_blob(&values))
106        })
107    }
108
109    fn dimensions(&self) -> usize {
110        self.0.dimensions
111    }
112
113    fn model_name(&self) -> &str {
114        &self.0.model
115    }
116}
117
118#[derive(Serialize)]
119struct Request<'a> {
120    input: &'a str,
121    model: &'a str,
122    dimensions: usize,
123}
124
125#[derive(Deserialize)]
126struct Response {
127    data: Vec<EmbeddingData>,
128}
129
130#[derive(Deserialize)]
131struct EmbeddingData {
132    embedding: Vec<f32>,
133}