1use crate::{Result, Slot};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderConfig {
13 pub api_key: String,
15
16 pub model: String,
18
19 pub base_url: Option<String>,
21
22 pub max_tokens: Option<u32>,
24
25 pub temperature: Option<f32>,
27
28 pub timeout_seconds: Option<u64>,
30
31 pub api_key_url: Option<String>,
33}
34
35impl ProviderConfig {
36 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
38 Self {
39 api_key: api_key.into(),
40 model: model.into(),
41 base_url: None,
42 max_tokens: None,
43 temperature: None,
44 timeout_seconds: None,
45 api_key_url: None,
46 }
47 }
48
49 pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
51 self.api_key_url = Some(url.into());
52 self
53 }
54
55 pub async fn resolve_api_key(&self) -> Result<String> {
57 if let Some(ref url) = self.api_key_url {
58 let resp = reqwest::get(url)
59 .await
60 .map_err(|e| crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e)))?;
61
62 let key = resp
63 .text()
64 .await
65 .map_err(|e| crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e)))?;
66
67 Ok(key.trim().to_string())
68 } else {
69 Ok(self.api_key.clone())
70 }
71 }
72
73 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
75 self.base_url = Some(url.into());
76 self
77 }
78
79 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
81 self.max_tokens = Some(tokens);
82 self
83 }
84
85 pub fn with_temperature(mut self, temp: f32) -> Self {
87 self.temperature = Some(temp.clamp(0.0, 2.0));
88 self
89 }
90
91 pub fn with_timeout(mut self, seconds: u64) -> Self {
93 self.timeout_seconds = Some(seconds);
94 self
95 }
96
97 pub fn from_env() -> Result<Self> {
104 let api_key = std::env::var("AETHER_API_KEY")
105 .or_else(|_| std::env::var("OPENAI_API_KEY"))
106 .map_err(|_| {
107 crate::AetherError::ConfigError(
108 "AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
109 )
110 })?;
111
112 let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-5.2-thinking".to_string());
113
114 let mut config = Self::new(api_key, model);
115
116 if let Ok(url) = std::env::var("AETHER_BASE_URL") {
117 config = config.with_base_url(url);
118 }
119
120 Ok(config)
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct GenerationRequest {
127 pub slot: Slot,
129
130 pub context: Option<String>,
132
133 pub system_prompt: Option<String>,
135
136 pub model: Option<String>,
138
139 pub max_tokens: Option<u32>,
141}
142
143use futures::stream::BoxStream;
144
145#[derive(Debug, Clone)]
147pub struct GenerationResponse {
148 pub code: String,
150
151 pub tokens_used: Option<u32>,
153
154 pub metadata: Option<serde_json::Value>,
156}
157
158#[derive(Debug, Clone)]
160pub struct StreamResponse {
161 pub delta: String,
163
164 pub metadata: Option<serde_json::Value>,
166}
167
168#[async_trait]
172pub trait AiProvider: Send + Sync {
173 fn name(&self) -> &str;
175
176 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;
186
187 fn generate_stream(
197 &self,
198 _request: GenerationRequest,
199 ) -> BoxStream<'static, Result<StreamResponse>> {
200 let name = self.name().to_string();
201 Box::pin(async_stream::stream! {
202 yield Err(crate::AetherError::ProviderError(format!(
203 "Streaming not implemented for provider: {}",
204 name
205 )));
206 })
207 }
208
209 async fn generate_batch(
213 &self,
214 requests: Vec<GenerationRequest>,
215 ) -> Result<Vec<GenerationResponse>> {
216 let mut responses = Vec::with_capacity(requests.len());
217 for request in requests {
218 responses.push(self.generate(request).await?);
219 }
220 Ok(responses)
221 }
222
223 async fn health_check(&self) -> Result<bool> {
225 Ok(true)
226 }
227}
228
229#[async_trait]
230impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Arc<T> {
231 fn name(&self) -> &str {
232 (**self).name()
233 }
234
235 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
236 (**self).generate(request).await
237 }
238
239 fn generate_stream(
240 &self,
241 request: GenerationRequest,
242 ) -> BoxStream<'static, Result<StreamResponse>> {
243 (**self).generate_stream(request)
244 }
245}
246
247#[async_trait]
248impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Box<T> {
249 fn name(&self) -> &str {
250 (**self).name()
251 }
252
253 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
254 (**self).generate(request).await
255 }
256
257 fn generate_stream(
258 &self,
259 request: GenerationRequest,
260 ) -> BoxStream<'static, Result<StreamResponse>> {
261 (**self).generate_stream(request)
262 }
263}
264
/// In-memory provider that answers from a table of canned responses instead
/// of calling a real backend.
#[derive(Debug, Default)]
pub struct MockProvider {
    /// Canned code keyed by slot name. Slots without an entry get a
    /// placeholder comment from the `AiProvider` implementation.
    pub responses: std::collections::HashMap<String, String>,
}
271
272impl MockProvider {
273 pub fn new() -> Self {
275 Self::default()
276 }
277
278 pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
280 self.responses.insert(slot.into(), code.into());
281 self
282 }
283}
284
285#[async_trait]
286impl AiProvider for MockProvider {
287 fn name(&self) -> &str {
288 "mock"
289 }
290
291 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
292 let code = self
293 .responses
294 .get(&request.slot.name)
295 .cloned()
296 .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
297
298 Ok(GenerationResponse {
299 code,
300 tokens_used: Some(10),
301 metadata: None,
302 })
303 }
304
305 fn generate_stream(
306 &self,
307 request: GenerationRequest,
308 ) -> BoxStream<'static, Result<StreamResponse>> {
309 let code = self
310 .responses
311 .get(&request.slot.name)
312 .cloned()
313 .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
314
315 let words: Vec<String> = code.split_whitespace().map(|s| format!("{} ", s)).collect();
316
317 let stream = async_stream::stream! {
318 for word in words {
319 yield Ok(StreamResponse {
320 delta: word,
321 metadata: None,
322 });
323 }
324 };
325
326 Box::pin(stream)
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A response registered for a slot name is returned verbatim by
    /// `generate`.
    #[tokio::test]
    async fn test_mock_provider() {
        let request = GenerationRequest {
            slot: Slot::new("button", "Create a button"),
            context: None,
            system_prompt: None,
            model: None,
            max_tokens: None,
        };
        let provider = MockProvider::new().with_response("button", "<button>Click me</button>");

        let generated = provider.generate(request).await.unwrap();

        assert_eq!(generated.code, "<button>Click me</button>");
    }
}