// noether_engine/llm/openai.rs
use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
use crate::llm::{LlmConfig, LlmError, LlmProvider, Message, Role};
use serde_json::{json, Value};
14
/// Fallback API base URL used when `OPENAI_API_BASE` is not set.
const DEFAULT_API_BASE: &str = "https://api.openai.com/v1";
/// Fallback embedding model used when `OPENAI_EMBEDDING_MODEL` is not set.
const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-3-small";
/// Vector width of `text-embedding-3-small` embeddings.
const DEFAULT_EMBEDDING_DIMENSIONS: usize = 1536;
18
/// Chat-completion client for an OpenAI-compatible `/chat/completions` endpoint.
///
/// Construct with [`OpenAiProvider::new`] or [`OpenAiProvider::from_env`].
pub struct OpenAiProvider {
    /// Bearer token sent via the `Authorization` header.
    api_key: String,
    /// API base, e.g. `https://api.openai.com/v1` (no trailing slash — paths
    /// are appended as `{api_base}/chat/completions`).
    api_base: String,
    /// Blocking HTTP client, reused across requests (120 s request timeout).
    client: reqwest::blocking::Client,
}
36
37impl OpenAiProvider {
38 pub fn new(api_key: impl Into<String>, api_base: impl Into<String>) -> Self {
39 let client = reqwest::blocking::Client::builder()
40 .timeout(std::time::Duration::from_secs(120))
41 .connect_timeout(std::time::Duration::from_secs(15))
42 .build()
43 .expect("failed to build reqwest client");
44 Self {
45 api_key: api_key.into(),
46 api_base: api_base.into(),
47 client,
48 }
49 }
50
51 pub fn from_env() -> Result<Self, String> {
53 let key =
54 std::env::var("OPENAI_API_KEY").map_err(|_| "OPENAI_API_KEY is not set".to_string())?;
55 let base =
56 std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| DEFAULT_API_BASE.to_string());
57 Ok(Self::new(key, base))
58 }
59}
60
61impl LlmProvider for OpenAiProvider {
62 fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
63 let url = format!("{}/chat/completions", self.api_base);
64
65 let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| config.model.clone());
66
67 let msgs: Vec<Value> = messages
68 .iter()
69 .map(|m| {
70 let role = match m.role {
71 Role::System => "system",
72 Role::User => "user",
73 Role::Assistant => "assistant",
74 };
75 json!({"role": role, "content": m.content})
76 })
77 .collect();
78
79 let body = json!({
80 "model": model,
81 "messages": msgs,
82 "max_tokens": config.max_tokens,
83 "temperature": config.temperature,
84 "stream": false,
85 });
86
87 let resp = self
88 .client
89 .post(&url)
90 .bearer_auth(&self.api_key)
91 .json(&body)
92 .send()
93 .map_err(|e| LlmError::Http(e.to_string()))?;
94
95 let status = resp.status();
96 let text = resp.text().map_err(|e| LlmError::Http(e.to_string()))?;
97
98 if !status.is_success() {
99 return Err(LlmError::Provider(format!(
100 "OpenAI API HTTP {status}: {text}"
101 )));
102 }
103
104 let json: Value =
105 serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
106
107 json["choices"][0]["message"]["content"]
108 .as_str()
109 .map(|s| s.to_string())
110 .ok_or_else(|| LlmError::Parse(format!("unexpected OpenAI response shape: {json}")))
111 }
112}
113
/// Embedding client for an OpenAI-compatible `/embeddings` endpoint.
///
/// Construct with [`OpenAiEmbeddingProvider::new`] or
/// [`OpenAiEmbeddingProvider::from_env`].
pub struct OpenAiEmbeddingProvider {
    /// Bearer token sent via the `Authorization` header.
    api_key: String,
    /// API base, e.g. `https://api.openai.com/v1` (no trailing slash).
    api_base: String,
    /// Embedding model name sent in each request body.
    model: String,
    /// Blocking HTTP client, reused across requests (30 s request timeout).
    client: reqwest::blocking::Client,
}
126
127impl OpenAiEmbeddingProvider {
128 pub fn new(api_key: impl Into<String>, api_base: impl Into<String>) -> Self {
129 let client = reqwest::blocking::Client::builder()
130 .timeout(std::time::Duration::from_secs(30))
131 .connect_timeout(std::time::Duration::from_secs(15))
132 .build()
133 .expect("failed to build reqwest client");
134 Self {
135 api_key: api_key.into(),
136 api_base: api_base.into(),
137 model: std::env::var("OPENAI_EMBEDDING_MODEL")
138 .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.into()),
139 client,
140 }
141 }
142
143 pub fn from_env() -> Result<Self, String> {
145 let key =
146 std::env::var("OPENAI_API_KEY").map_err(|_| "OPENAI_API_KEY is not set".to_string())?;
147 let base =
148 std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| DEFAULT_API_BASE.to_string());
149 Ok(Self::new(key, base))
150 }
151}
152
153impl EmbeddingProvider for OpenAiEmbeddingProvider {
154 fn dimensions(&self) -> usize {
155 DEFAULT_EMBEDDING_DIMENSIONS
156 }
157
158 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
159 let mut batch = self.embed_batch(&[text])?;
160 batch
161 .pop()
162 .ok_or_else(|| EmbeddingError::Provider("empty response".into()))
163 }
164
165 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
166 if texts.is_empty() {
167 return Ok(Vec::new());
168 }
169
170 let url = format!("{}/embeddings", self.api_base);
171 let body = json!({
172 "model": self.model,
173 "input": texts,
174 "encoding_format": "float",
175 });
176
177 let resp = self
178 .client
179 .post(&url)
180 .bearer_auth(&self.api_key)
181 .json(&body)
182 .send()
183 .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
184
185 let status = resp.status();
186 let text = resp
187 .text()
188 .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
189
190 if !status.is_success() {
191 return Err(EmbeddingError::Provider(format!(
192 "OpenAI embeddings HTTP {status}: {text}"
193 )));
194 }
195
196 let json: Value =
197 serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
198
199 let mut items: Vec<(usize, Embedding)> = json["data"]
200 .as_array()
201 .ok_or_else(|| EmbeddingError::Provider("missing 'data' field".into()))?
202 .iter()
203 .map(|item| {
204 let index = item["index"].as_u64().unwrap_or(0) as usize;
205 let vec: Embedding = item["embedding"]
206 .as_array()
207 .unwrap_or(&vec![])
208 .iter()
209 .filter_map(|v| v.as_f64().map(|f| f as f32))
210 .collect();
211 (index, vec)
212 })
213 .collect();
214
215 items.sort_by_key(|(idx, _)| *idx);
216 Ok(items.into_iter().map(|(_, v)| v).collect())
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Restores `OPENAI_API_KEY` on drop.
    ///
    /// Fix: the original test restored the key only on the success path, so
    /// a failing assertion (panic) left the variable removed for the rest of
    /// the test process. A drop guard restores it even on panic.
    struct RestoreKey(Option<String>);

    impl Drop for RestoreKey {
        fn drop(&mut self) {
            if let Some(k) = self.0.take() {
                std::env::set_var("OPENAI_API_KEY", k);
            }
        }
    }

    #[test]
    fn from_env_errors_without_key() {
        // NOTE(review): env mutation is process-wide; this can race with any
        // other test that reads OPENAI_API_KEY concurrently. If such tests
        // are added, serialize them (shared mutex or --test-threads=1).
        let _restore = RestoreKey(std::env::var("OPENAI_API_KEY").ok());
        std::env::remove_var("OPENAI_API_KEY");
        assert!(OpenAiProvider::from_env().is_err());
        assert!(OpenAiEmbeddingProvider::from_env().is_err());
    }

    #[test]
    fn default_base_url() {
        let provider = OpenAiProvider::new("test-key", DEFAULT_API_BASE);
        assert_eq!(provider.api_base, "https://api.openai.com/v1");
    }

    #[test]
    fn custom_base_url() {
        let provider = OpenAiProvider::new("test-key", "http://localhost:11434/v1");
        assert_eq!(provider.api_base, "http://localhost:11434/v1");
    }
}