use crate::{Result, Slot};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Connection and generation settings for an AI provider.
///
/// Only `api_key` and `model` are required; the remaining fields are
/// optional overrides set via the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key used directly, unless `api_key_url` is set (see `resolve_api_key`).
    pub api_key: String,
    /// Model identifier sent to the provider (e.g. set from `AETHER_MODEL`).
    pub model: String,
    /// Optional base URL override for the provider endpoint.
    pub base_url: Option<String>,
    /// Optional cap on generated tokens per request.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; clamped to [0.0, 2.0] by the builder.
    pub temperature: Option<f32>,
    /// Optional request timeout in seconds.
    pub timeout_seconds: Option<u64>,
    /// Optional URL to fetch the API key from at resolve time, taking
    /// precedence over `api_key`.
    pub api_key_url: Option<String>,
}
impl ProviderConfig {
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
model: model.into(),
base_url: None,
max_tokens: None,
temperature: None,
timeout_seconds: None,
api_key_url: None,
}
}
pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
self.api_key_url = Some(url.into());
self
}
pub async fn resolve_api_key(&self) -> Result<String> {
if let Some(ref url) = self.api_key_url {
let resp = reqwest::get(url)
.await
.map_err(|e| crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e)))?;
let key = resp
.text()
.await
.map_err(|e| crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e)))?;
Ok(key.trim().to_string())
} else {
Ok(self.api_key.clone())
}
}
pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = Some(url.into());
self
}
pub fn with_max_tokens(mut self, tokens: u32) -> Self {
self.max_tokens = Some(tokens);
self
}
pub fn with_temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp.clamp(0.0, 2.0));
self
}
pub fn with_timeout(mut self, seconds: u64) -> Self {
self.timeout_seconds = Some(seconds);
self
}
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("AETHER_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.map_err(|_| {
crate::AetherError::ConfigError(
"AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
)
})?;
let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-5.2-thinking".to_string());
let mut config = Self::new(api_key, model);
if let Ok(url) = std::env::var("AETHER_BASE_URL") {
config = config.with_base_url(url);
}
Ok(config)
}
}
/// A single code-generation request handed to an `AiProvider`.
#[derive(Debug, Clone)]
pub struct GenerationRequest {
    /// The slot to generate code for; its `name` keys mock lookups.
    pub slot: Slot,
    /// Optional extra context supplied to the provider.
    pub context: Option<String>,
    /// Optional system prompt override.
    pub system_prompt: Option<String>,
    /// Optional per-request model override.
    pub model: Option<String>,
    /// Optional per-request token cap.
    pub max_tokens: Option<u32>,
}
use futures::stream::BoxStream;
/// The complete (non-streaming) result of a generation request.
#[derive(Debug, Clone)]
pub struct GenerationResponse {
    /// The generated code.
    pub code: String,
    /// Token usage reported by the provider, when available.
    pub tokens_used: Option<u32>,
    /// Provider-specific extra data, when available.
    pub metadata: Option<serde_json::Value>,
}
/// One incremental chunk of a streaming generation.
#[derive(Debug, Clone)]
pub struct StreamResponse {
    /// The text fragment to append to previously received deltas.
    pub delta: String,
    /// Provider-specific extra data, when available.
    pub metadata: Option<serde_json::Value>,
}
/// Common interface for AI code-generation backends.
///
/// Implementors must supply `name` and `generate`; the streaming, batch,
/// and health-check methods have usable defaults.
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Short provider identifier (e.g. `"mock"`).
    fn name(&self) -> &str;

    /// Generates code for one request.
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;

    /// Streams generation deltas. The default implementation yields a
    /// single `ProviderError` for providers without streaming support.
    fn generate_stream(
        &self,
        _request: GenerationRequest,
    ) -> BoxStream<'static, Result<StreamResponse>> {
        let provider = self.name().to_string();
        Box::pin(async_stream::stream! {
            yield Err(crate::AetherError::ProviderError(format!(
                "Streaming not implemented for provider: {}",
                provider
            )));
        })
    }

    /// Generates responses for several requests, sequentially, stopping
    /// at the first error.
    async fn generate_batch(
        &self,
        requests: Vec<GenerationRequest>,
    ) -> Result<Vec<GenerationResponse>> {
        let mut results = Vec::with_capacity(requests.len());
        for req in requests {
            results.push(self.generate(req).await?);
        }
        Ok(results)
    }

    /// Liveness probe; the default always reports healthy.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}
/// Blanket impl so an `Arc`-wrapped provider is itself a provider.
///
/// Every trait method is forwarded — including the defaulted
/// `generate_batch` and `health_check`. Without those forwards, a
/// provider that overrides them would silently fall back to the trait
/// defaults whenever it is used through `Arc`.
#[async_trait]
impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Arc<T> {
    fn name(&self) -> &str {
        (**self).name()
    }
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
        (**self).generate(request).await
    }
    fn generate_stream(
        &self,
        request: GenerationRequest,
    ) -> BoxStream<'static, Result<StreamResponse>> {
        (**self).generate_stream(request)
    }
    async fn generate_batch(
        &self,
        requests: Vec<GenerationRequest>,
    ) -> Result<Vec<GenerationResponse>> {
        (**self).generate_batch(requests).await
    }
    async fn health_check(&self) -> Result<bool> {
        (**self).health_check().await
    }
}
/// Blanket impl so a boxed provider is itself a provider.
///
/// Every trait method is forwarded — including the defaulted
/// `generate_batch` and `health_check`. Without those forwards, a
/// provider that overrides them would silently fall back to the trait
/// defaults whenever it is used through `Box`.
#[async_trait]
impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Box<T> {
    fn name(&self) -> &str {
        (**self).name()
    }
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
        (**self).generate(request).await
    }
    fn generate_stream(
        &self,
        request: GenerationRequest,
    ) -> BoxStream<'static, Result<StreamResponse>> {
        (**self).generate_stream(request)
    }
    async fn generate_batch(
        &self,
        requests: Vec<GenerationRequest>,
    ) -> Result<Vec<GenerationResponse>> {
        (**self).generate_batch(requests).await
    }
    async fn health_check(&self) -> Result<bool> {
        (**self).health_check().await
    }
}
/// Test double for `AiProvider` that serves canned responses.
#[derive(Debug, Default)]
pub struct MockProvider {
    /// Canned code keyed by slot name; unknown slots get a placeholder comment.
    pub responses: std::collections::HashMap<String, String>,
}
impl MockProvider {
    /// Creates a mock with no canned responses registered.
    pub fn new() -> Self {
        Self {
            responses: std::collections::HashMap::new(),
        }
    }

    /// Builder: registers the canned `code` returned for slot name `slot`.
    pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
        let (key, value) = (slot.into(), code.into());
        self.responses.insert(key, value);
        self
    }
}
#[async_trait]
impl AiProvider for MockProvider {
    fn name(&self) -> &str {
        "mock"
    }

    /// Returns the canned response for the slot name, or a placeholder
    /// comment for unregistered slots.
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
        let code = self
            .responses
            .get(&request.slot.name)
            .cloned()
            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
        Ok(GenerationResponse {
            code,
            tokens_used: Some(10),
            metadata: None,
        })
    }

    /// Streams the same code `generate` would return, one whitespace-
    /// delimited chunk at a time.
    fn generate_stream(
        &self,
        request: GenerationRequest,
    ) -> BoxStream<'static, Result<StreamResponse>> {
        let code = self
            .responses
            .get(&request.slot.name)
            .cloned()
            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
        // Split *inclusive* of each whitespace char so that concatenating
        // the deltas reproduces `code` byte-for-byte. The previous
        // split_whitespace + " " approach collapsed newlines/tabs into
        // single spaces and appended a trailing space, so the streamed
        // result disagreed with generate().
        let chunks: Vec<String> = code
            .split_inclusive(char::is_whitespace)
            .map(str::to_string)
            .collect();
        let stream = async_stream::stream! {
            for chunk in chunks {
                yield Ok(StreamResponse {
                    delta: chunk,
                    metadata: None,
                });
            }
        };
        Box::pin(stream)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A canned response registered for a slot name should be returned
    /// verbatim by `generate`.
    #[tokio::test]
    async fn test_mock_provider() {
        let mock = MockProvider::new().with_response("button", "<button>Click me</button>");
        let req = GenerationRequest {
            slot: Slot::new("button", "Create a button"),
            context: None,
            system_prompt: None,
            model: None,
            max_tokens: None,
        };
        let got = mock
            .generate(req)
            .await
            .expect("mock generate should not fail");
        assert_eq!(got.code, "<button>Click me</button>");
    }
}