1use crate::{Result, Slot};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderConfig {
13 pub api_key: String,
15
16 pub model: String,
18
19 pub base_url: Option<String>,
21
22 pub max_tokens: Option<u32>,
24
25 pub temperature: Option<f32>,
27
28 pub timeout_seconds: Option<u64>,
30
31 pub api_key_url: Option<String>,
33}
34
35impl ProviderConfig {
36 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
38 Self {
39 api_key: api_key.into(),
40 model: model.into(),
41 base_url: None,
42 max_tokens: None,
43 temperature: None,
44 timeout_seconds: None,
45 api_key_url: None,
46 }
47 }
48
49 pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
51 self.api_key_url = Some(url.into());
52 self
53 }
54
55 pub async fn resolve_api_key(&self) -> Result<String> {
57 if let Some(ref url) = self.api_key_url {
58 let resp = reqwest::get(url)
59 .await
60 .map_err(|e| crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e)))?;
61
62 let key = resp
63 .text()
64 .await
65 .map_err(|e| crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e)))?;
66
67 Ok(key.trim().to_string())
68 } else {
69 Ok(self.api_key.clone())
70 }
71 }
72
73 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
75 self.base_url = Some(url.into());
76 self
77 }
78
79 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
81 self.max_tokens = Some(tokens);
82 self
83 }
84
85 pub fn with_temperature(mut self, temp: f32) -> Self {
87 self.temperature = Some(temp.clamp(0.0, 2.0));
88 self
89 }
90
91 pub fn with_timeout(mut self, seconds: u64) -> Self {
93 self.timeout_seconds = Some(seconds);
94 self
95 }
96
97 pub fn from_env() -> Result<Self> {
104 let api_key = std::env::var("AETHER_API_KEY")
105 .or_else(|_| std::env::var("OPENAI_API_KEY"))
106 .map_err(|_| {
107 crate::AetherError::ConfigError(
108 "AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
109 )
110 })?;
111
112 let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-5.2-thinking".to_string());
113
114 let mut config = Self::new(api_key, model);
115
116 if let Ok(url) = std::env::var("AETHER_BASE_URL") {
117 config = config.with_base_url(url);
118 }
119
120 Ok(config)
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct GenerationRequest {
127 pub slot: Slot,
129
130 pub context: Option<String>,
132
133 pub system_prompt: Option<String>,
135}
136
137use futures::stream::BoxStream;
138
139#[derive(Debug, Clone)]
141pub struct GenerationResponse {
142 pub code: String,
144
145 pub tokens_used: Option<u32>,
147
148 pub metadata: Option<serde_json::Value>,
150}
151
152#[derive(Debug, Clone)]
154pub struct StreamResponse {
155 pub delta: String,
157
158 pub metadata: Option<serde_json::Value>,
160}
161
162#[async_trait]
166pub trait AiProvider: Send + Sync {
167 fn name(&self) -> &str;
169
170 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;
180
181 fn generate_stream(
191 &self,
192 _request: GenerationRequest,
193 ) -> BoxStream<'static, Result<StreamResponse>> {
194 let name = self.name().to_string();
195 Box::pin(async_stream::stream! {
196 yield Err(crate::AetherError::ProviderError(format!(
197 "Streaming not implemented for provider: {}",
198 name
199 )));
200 })
201 }
202
203 async fn generate_batch(
207 &self,
208 requests: Vec<GenerationRequest>,
209 ) -> Result<Vec<GenerationResponse>> {
210 let mut responses = Vec::with_capacity(requests.len());
211 for request in requests {
212 responses.push(self.generate(request).await?);
213 }
214 Ok(responses)
215 }
216
217 async fn health_check(&self) -> Result<bool> {
219 Ok(true)
220 }
221}
222
223#[async_trait]
224impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Arc<T> {
225 fn name(&self) -> &str {
226 (**self).name()
227 }
228
229 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
230 (**self).generate(request).await
231 }
232
233 fn generate_stream(
234 &self,
235 request: GenerationRequest,
236 ) -> BoxStream<'static, Result<StreamResponse>> {
237 (**self).generate_stream(request)
238 }
239}
240
241#[async_trait]
242impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Box<T> {
243 fn name(&self) -> &str {
244 (**self).name()
245 }
246
247 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
248 (**self).generate(request).await
249 }
250
251 fn generate_stream(
252 &self,
253 request: GenerationRequest,
254 ) -> BoxStream<'static, Result<StreamResponse>> {
255 (**self).generate_stream(request)
256 }
257}
258
/// In-memory provider for tests that returns canned code looked up by slot name.
#[derive(Default, Debug)]
pub struct MockProvider {
    /// Canned code, keyed by slot name.
    pub responses: std::collections::HashMap<String, String>,
}
265
266impl MockProvider {
267 pub fn new() -> Self {
269 Self::default()
270 }
271
272 pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
274 self.responses.insert(slot.into(), code.into());
275 self
276 }
277}
278
279#[async_trait]
280impl AiProvider for MockProvider {
281 fn name(&self) -> &str {
282 "mock"
283 }
284
285 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
286 let code = self
287 .responses
288 .get(&request.slot.name)
289 .cloned()
290 .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
291
292 Ok(GenerationResponse {
293 code,
294 tokens_used: Some(10),
295 metadata: None,
296 })
297 }
298
299 fn generate_stream(
300 &self,
301 request: GenerationRequest,
302 ) -> BoxStream<'static, Result<StreamResponse>> {
303 let code = self
304 .responses
305 .get(&request.slot.name)
306 .cloned()
307 .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
308
309 let words: Vec<String> = code.split_whitespace().map(|s| format!("{} ", s)).collect();
310
311 let stream = async_stream::stream! {
312 for word in words {
313 yield Ok(StreamResponse {
314 delta: word,
315 metadata: None,
316 });
317 }
318 };
319
320 Box::pin(stream)
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// A slot with a registered canned response gets exactly that response back.
    #[tokio::test]
    async fn test_mock_provider() {
        let slot = Slot::new("button", "Create a button");
        let provider =
            MockProvider::new().with_response("button", "<button>Click me</button>");

        let response = provider
            .generate(GenerationRequest {
                slot,
                context: None,
                system_prompt: None,
            })
            .await
            .expect("mock provider generation should succeed");

        assert_eq!(response.code, "<button>Click me</button>");
    }
}