1use crate::{Result, Slot};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
/// Connection and generation settings for an AI provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key sent with requests; used directly when `api_key_url` is `None`.
    pub api_key: String,

    /// Model identifier passed to the provider (defaults via `from_env`).
    pub model: String,

    /// Optional override of the provider's base URL.
    pub base_url: Option<String>,

    /// Optional cap on tokens generated per request.
    pub max_tokens: Option<u32>,

    /// Optional sampling temperature; `with_temperature` clamps it to 0.0–2.0.
    pub temperature: Option<f32>,

    /// Optional request timeout in seconds.
    pub timeout_seconds: Option<u64>,

    /// Optional URL from which `resolve_api_key` fetches the key at runtime;
    /// when set, it takes precedence over `api_key`.
    pub api_key_url: Option<String>,
}
33
34impl ProviderConfig {
35 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
37 Self {
38 api_key: api_key.into(),
39 model: model.into(),
40 base_url: None,
41 max_tokens: None,
42 temperature: None,
43 timeout_seconds: None,
44 api_key_url: None,
45 }
46 }
47
48 pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
50 self.api_key_url = Some(url.into());
51 self
52 }
53
54 pub async fn resolve_api_key(&self) -> Result<String> {
56 if let Some(ref url) = self.api_key_url {
57 let resp = reqwest::get(url)
58 .await
59 .map_err(|e| crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e)))?;
60
61 let key = resp
62 .text()
63 .await
64 .map_err(|e| crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e)))?;
65
66 Ok(key.trim().to_string())
67 } else {
68 Ok(self.api_key.clone())
69 }
70 }
71
72 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
74 self.base_url = Some(url.into());
75 self
76 }
77
78 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
80 self.max_tokens = Some(tokens);
81 self
82 }
83
84 pub fn with_temperature(mut self, temp: f32) -> Self {
86 self.temperature = Some(temp.clamp(0.0, 2.0));
87 self
88 }
89
90 pub fn with_timeout(mut self, seconds: u64) -> Self {
92 self.timeout_seconds = Some(seconds);
93 self
94 }
95
96 pub fn from_env() -> Result<Self> {
103 let api_key = std::env::var("AETHER_API_KEY")
104 .or_else(|_| std::env::var("OPENAI_API_KEY"))
105 .map_err(|_| {
106 crate::AetherError::ConfigError(
107 "AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
108 )
109 })?;
110
111 let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-5.2-thinking".to_string());
112
113 let mut config = Self::new(api_key, model);
114
115 if let Ok(url) = std::env::var("AETHER_BASE_URL") {
116 config = config.with_base_url(url);
117 }
118
119 Ok(config)
120 }
121}
122
/// A single code-generation request handed to an [`AiProvider`].
#[derive(Debug, Clone)]
pub struct GenerationRequest {
    /// The slot to generate code for; providers key behavior off `slot.name`.
    pub slot: Slot,

    /// Optional surrounding context supplied to the provider.
    pub context: Option<String>,

    /// Optional system prompt; semantics are provider-specific.
    pub system_prompt: Option<String>,
}
135
136use futures::stream::BoxStream;
137
/// The complete result of a non-streaming generation.
#[derive(Debug, Clone)]
pub struct GenerationResponse {
    /// The generated code.
    pub code: String,

    /// Tokens consumed by the request, when the provider reports it.
    pub tokens_used: Option<u32>,

    /// Optional provider-specific extra data.
    pub metadata: Option<serde_json::Value>,
}
150
/// One incremental chunk of a streaming generation.
#[derive(Debug, Clone)]
pub struct StreamResponse {
    /// The text fragment to append to previously received deltas.
    pub delta: String,

    /// Optional provider-specific extra data for this chunk.
    pub metadata: Option<serde_json::Value>,
}
160
161#[async_trait]
165pub trait AiProvider: Send + Sync {
166 fn name(&self) -> &str;
168
169 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;
179
180 fn generate_stream(
190 &self,
191 _request: GenerationRequest,
192 ) -> BoxStream<'static, Result<StreamResponse>> {
193 let name = self.name().to_string();
194 Box::pin(async_stream::stream! {
195 yield Err(crate::AetherError::ProviderError(format!(
196 "Streaming not implemented for provider: {}",
197 name
198 )));
199 })
200 }
201
202 async fn generate_batch(
206 &self,
207 requests: Vec<GenerationRequest>,
208 ) -> Result<Vec<GenerationResponse>> {
209 let mut responses = Vec::with_capacity(requests.len());
210 for request in requests {
211 responses.push(self.generate(request).await?);
212 }
213 Ok(responses)
214 }
215
216 async fn health_check(&self) -> Result<bool> {
218 Ok(true)
219 }
220}
221
/// Test double that serves canned responses keyed by slot name.
#[derive(Debug, Default)]
pub struct MockProvider {
    /// Map from slot name to the code `generate` returns for that slot.
    pub responses: std::collections::HashMap<String, String>,
}
228
229impl MockProvider {
230 pub fn new() -> Self {
232 Self::default()
233 }
234
235 pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
237 self.responses.insert(slot.into(), code.into());
238 self
239 }
240}
241
242#[async_trait]
243impl AiProvider for MockProvider {
244 fn name(&self) -> &str {
245 "mock"
246 }
247
248 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
249 let code = self
250 .responses
251 .get(&request.slot.name)
252 .cloned()
253 .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
254
255 Ok(GenerationResponse {
256 code,
257 tokens_used: Some(10),
258 metadata: None,
259 })
260 }
261
262 fn generate_stream(
263 &self,
264 request: GenerationRequest,
265 ) -> BoxStream<'static, Result<StreamResponse>> {
266 let code = self
267 .responses
268 .get(&request.slot.name)
269 .cloned()
270 .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
271
272 use futures::StreamExt;
273 let words: Vec<String> = code.split_whitespace().map(|s| format!("{} ", s)).collect();
274
275 let stream = async_stream::stream! {
276 for word in words {
277 yield Ok(StreamResponse {
278 delta: word,
279 metadata: None,
280 });
281 }
282 };
283
284 Box::pin(stream)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered slot response is returned verbatim by `generate`.
    #[tokio::test]
    async fn test_mock_provider() {
        let mock = MockProvider::new().with_response("button", "<button>Click me</button>");

        let req = GenerationRequest {
            slot: Slot::new("button", "Create a button"),
            context: None,
            system_prompt: None,
        };

        let result = mock.generate(req).await.unwrap();
        assert_eq!(result.code, "<button>Click me</button>");
    }
}
306}