// noether_engine/llm/openai.rs

//! OpenAI API provider (also works with any OpenAI-compatible API).
//!
//! Auth: `OPENAI_API_KEY` environment variable.
//!
//! Compatible with OpenAI, Ollama, Together AI, and any other service
//! that implements the OpenAI chat/completions and embeddings endpoints.
//!
//! Override the base URL with `OPENAI_API_BASE` for self-hosted or
//! third-party OpenAI-compatible services.

use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
use crate::llm::{LlmConfig, LlmError, LlmProvider, Message, Role};
use serde_json::{json, Value};

/// Official OpenAI endpoint; overridable at runtime via `OPENAI_API_BASE`.
const DEFAULT_API_BASE: &str = "https://api.openai.com/v1";
/// Embedding model used when `OPENAI_EMBEDDING_MODEL` is not set.
const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-3-small";
/// Vector width of `text-embedding-3-small`.
/// NOTE(review): `dimensions()` reports this value even when
/// `OPENAI_EMBEDDING_MODEL` selects a model with a different width — confirm
/// callers tolerate that mismatch.
const DEFAULT_EMBEDDING_DIMENSIONS: usize = 1536;

// ── LLM provider ────────────────────────────────────────────────────────────

/// Calls `{base}/chat/completions` with an OpenAI-compatible API.
///
/// Supports all OpenAI chat models and any compatible endpoint:
/// - `gpt-4o-mini` — fast and cheap (default)
/// - `gpt-4o` — most capable
/// - Any model exposed by an OpenAI-compatible API
///
/// Set `OPENAI_API_KEY` to your API key.
/// Override model with `OPENAI_MODEL`.
/// Override base URL with `OPENAI_API_BASE` (e.g. `http://localhost:11434/v1` for Ollama).
pub struct OpenAiProvider {
    // Bearer token sent on every request.
    api_key: String,
    // Base URL without a trailing slash, e.g. "https://api.openai.com/v1".
    api_base: String,
    // Blocking HTTP client, shared across calls (connection reuse).
    client: reqwest::blocking::Client,
}

37impl OpenAiProvider {
38    pub fn new(api_key: impl Into<String>, api_base: impl Into<String>) -> Self {
39        let client = reqwest::blocking::Client::builder()
40            .timeout(std::time::Duration::from_secs(120))
41            .connect_timeout(std::time::Duration::from_secs(15))
42            .build()
43            .expect("failed to build reqwest client");
44        Self {
45            api_key: api_key.into(),
46            api_base: api_base.into(),
47            client,
48        }
49    }
50
51    /// Construct from environment. Returns `Err` if `OPENAI_API_KEY` is not set.
52    pub fn from_env() -> Result<Self, String> {
53        let key =
54            std::env::var("OPENAI_API_KEY").map_err(|_| "OPENAI_API_KEY is not set".to_string())?;
55        let base =
56            std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| DEFAULT_API_BASE.to_string());
57        Ok(Self::new(key, base))
58    }
59}
60
61impl LlmProvider for OpenAiProvider {
62    fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
63        let url = format!("{}/chat/completions", self.api_base);
64
65        let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| config.model.clone());
66
67        let msgs: Vec<Value> = messages
68            .iter()
69            .map(|m| {
70                let role = match m.role {
71                    Role::System => "system",
72                    Role::User => "user",
73                    Role::Assistant => "assistant",
74                };
75                json!({"role": role, "content": m.content})
76            })
77            .collect();
78
79        let body = json!({
80            "model": model,
81            "messages": msgs,
82            "max_tokens": config.max_tokens,
83            "temperature": config.temperature,
84            "stream": false,
85        });
86
87        let resp = self
88            .client
89            .post(&url)
90            .bearer_auth(&self.api_key)
91            .json(&body)
92            .send()
93            .map_err(|e| LlmError::Http(e.to_string()))?;
94
95        let status = resp.status();
96        let text = resp.text().map_err(|e| LlmError::Http(e.to_string()))?;
97
98        if !status.is_success() {
99            return Err(LlmError::Provider(format!(
100                "OpenAI API HTTP {status}: {text}"
101            )));
102        }
103
104        let json: Value =
105            serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
106
107        json["choices"][0]["message"]["content"]
108            .as_str()
109            .map(|s| s.to_string())
110            .ok_or_else(|| LlmError::Parse(format!("unexpected OpenAI response shape: {json}")))
111    }
112}
113
// ── Embedding provider ───────────────────────────────────────────────────────

/// Calls `{base}/embeddings` using the OpenAI embeddings API.
///
/// - Default model: `text-embedding-3-small` (1536 dimensions)
/// - Compatible with any OpenAI-compatible embeddings endpoint
pub struct OpenAiEmbeddingProvider {
    // Bearer token sent on every request.
    api_key: String,
    // Base URL without a trailing slash, e.g. "https://api.openai.com/v1".
    api_base: String,
    // Resolved once at construction: OPENAI_EMBEDDING_MODEL or the default.
    model: String,
    // Blocking HTTP client, shared across calls (connection reuse).
    client: reqwest::blocking::Client,
}

127impl OpenAiEmbeddingProvider {
128    pub fn new(api_key: impl Into<String>, api_base: impl Into<String>) -> Self {
129        let client = reqwest::blocking::Client::builder()
130            .timeout(std::time::Duration::from_secs(30))
131            .connect_timeout(std::time::Duration::from_secs(15))
132            .build()
133            .expect("failed to build reqwest client");
134        Self {
135            api_key: api_key.into(),
136            api_base: api_base.into(),
137            model: std::env::var("OPENAI_EMBEDDING_MODEL")
138                .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.into()),
139            client,
140        }
141    }
142
143    /// Construct from environment. Returns `Err` if `OPENAI_API_KEY` is not set.
144    pub fn from_env() -> Result<Self, String> {
145        let key =
146            std::env::var("OPENAI_API_KEY").map_err(|_| "OPENAI_API_KEY is not set".to_string())?;
147        let base =
148            std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| DEFAULT_API_BASE.to_string());
149        Ok(Self::new(key, base))
150    }
151}
152
153impl EmbeddingProvider for OpenAiEmbeddingProvider {
154    fn dimensions(&self) -> usize {
155        DEFAULT_EMBEDDING_DIMENSIONS
156    }
157
158    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
159        let mut batch = self.embed_batch(&[text])?;
160        batch
161            .pop()
162            .ok_or_else(|| EmbeddingError::Provider("empty response".into()))
163    }
164
165    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
166        if texts.is_empty() {
167            return Ok(Vec::new());
168        }
169
170        let url = format!("{}/embeddings", self.api_base);
171        let body = json!({
172            "model": self.model,
173            "input": texts,
174            "encoding_format": "float",
175        });
176
177        let resp = self
178            .client
179            .post(&url)
180            .bearer_auth(&self.api_key)
181            .json(&body)
182            .send()
183            .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
184
185        let status = resp.status();
186        let text = resp
187            .text()
188            .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
189
190        if !status.is_success() {
191            return Err(EmbeddingError::Provider(format!(
192                "OpenAI embeddings HTTP {status}: {text}"
193            )));
194        }
195
196        let json: Value =
197            serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
198
199        let mut items: Vec<(usize, Embedding)> = json["data"]
200            .as_array()
201            .ok_or_else(|| EmbeddingError::Provider("missing 'data' field".into()))?
202            .iter()
203            .map(|item| {
204                let index = item["index"].as_u64().unwrap_or(0) as usize;
205                let vec: Embedding = item["embedding"]
206                    .as_array()
207                    .unwrap_or(&vec![])
208                    .iter()
209                    .filter_map(|v| v.as_f64().map(|f| f as f32))
210                    .collect();
211                (index, vec)
212            })
213            .collect();
214
215        items.sort_by_key(|(idx, _)| *idx);
216        Ok(items.into_iter().map(|(_, v)| v).collect())
217    }
218}
219
// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// `from_env` must fail for both providers when no API key is present.
    ///
    /// NOTE(review): mutating process-global env vars is racy when tests run
    /// in parallel (the default) — another test reading `OPENAI_API_KEY`
    /// concurrently may observe the temporary removal. Consider an env-var
    /// lock or `--test-threads=1`. (`set_var`/`remove_var` are also `unsafe`
    /// as of Rust edition 2024 for this reason.)
    #[test]
    fn from_env_errors_without_key() {
        // Save and restore so other tests in this process are unaffected.
        let saved = std::env::var("OPENAI_API_KEY").ok();
        std::env::remove_var("OPENAI_API_KEY");
        assert!(OpenAiProvider::from_env().is_err());
        assert!(OpenAiEmbeddingProvider::from_env().is_err());
        if let Some(k) = saved {
            std::env::set_var("OPENAI_API_KEY", k);
        }
    }

    /// The constructor stores the default base URL verbatim (no normalization).
    #[test]
    fn default_base_url() {
        let provider = OpenAiProvider::new("test-key", DEFAULT_API_BASE);
        assert_eq!(provider.api_base, "https://api.openai.com/v1");
    }

    /// A custom base URL (e.g. a local Ollama endpoint) is stored verbatim.
    #[test]
    fn custom_base_url() {
        let provider = OpenAiProvider::new("test-key", "http://localhost:11434/v1");
        assert_eq!(provider.api_base, "http://localhost:11434/v1");
    }
}