Skip to main content

aether_ai/
anthropic.rs

//! Anthropic Claude provider implementation.
2
3use aether_core::{
4    AetherError, AiProvider, ProviderConfig, Result,
5    provider::{GenerationRequest, GenerationResponse},
6    SlotKind,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use tracing::{debug, instrument};
12
/// Default Messages API endpoint, used when the provider config has no `base_url` override.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` header of every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
15
16/// Anthropic Claude provider for code generation.
17#[derive(Debug, Clone)]
18pub struct AnthropicProvider {
19    client: Client,
20    config: ProviderConfig,
21}
22
23/// Anthropic message request.
24#[derive(Debug, Serialize)]
25struct MessageRequest {
26    model: String,
27    max_tokens: u32,
28    system: Option<String>,
29    messages: Vec<Message>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    temperature: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    stream: Option<bool>,
34}
35
36/// Anthropic streaming response event (minimal)
37#[derive(Debug, Deserialize)]
38#[serde(tag = "type")]
39enum StreamEvent {
40    #[serde(rename = "content_block_delta")]
41    ContentBlockDelta {
42        delta: TextDelta,
43    },
44    #[serde(other)]
45    Unknown,
46}
47
48#[derive(Debug, Deserialize)]
49struct TextDelta {
50    text: String,
51}
52
53#[derive(Debug, Serialize, Deserialize)]
54struct Message {
55    role: String,
56    content: String,
57}
58
59/// Anthropic message response.
60#[derive(Debug, Deserialize)]
61struct MessageResponse {
62    content: Vec<ContentBlock>,
63    usage: Usage,
64}
65
66#[derive(Debug, Deserialize)]
67struct ContentBlock {
68    text: String,
69}
70
71#[derive(Debug, Deserialize)]
72struct Usage {
73    input_tokens: u32,
74    output_tokens: u32,
75}
76
77impl AnthropicProvider {
78    /// Create a new Anthropic provider.
79    pub fn new(config: ProviderConfig) -> Result<Self> {
80        let timeout = config.timeout_seconds.unwrap_or(60);
81        let client = Client::builder()
82            .timeout(std::time::Duration::from_secs(timeout))
83            .build()
84            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
85
86        Ok(Self { client, config })
87    }
88
89    /// Create a provider from environment variables.
90    ///
91    /// Reads `ANTHROPIC_API_KEY`.
92    pub fn from_env() -> Result<Self> {
93        let api_key = std::env::var("ANTHROPIC_API_KEY")
94            .map_err(|_| AetherError::ConfigError("ANTHROPIC_API_KEY not set".to_string()))?;
95
96        let model = std::env::var("ANTHROPIC_MODEL")
97            .unwrap_or_else(|_| "claude-opus-4-5".to_string());
98
99        let config = ProviderConfig::new(api_key, model);
100        Self::new(config)
101    }
102
103    /// Create a provider from environment with a specific model.
104    pub fn from_env_with_model(model: &str) -> Result<Self> {
105        let api_key = std::env::var("ANTHROPIC_API_KEY")
106            .map_err(|_| AetherError::ConfigError("ANTHROPIC_API_KEY not set".to_string()))?;
107
108        let config = ProviderConfig::new(api_key, model);
109        Self::new(config)
110    }
111
112    /// Build the system prompt for code generation.
113    fn build_system_prompt(&self, kind: &SlotKind, context: Option<&str>) -> String {
114        let base = "You are a code generation assistant. Generate only the requested code without explanations or markdown code blocks. Output raw code only.";
115
116        let kind_specific = match kind {
117            SlotKind::Html => "\nGenerate valid HTML5 markup.",
118            SlotKind::Css => "\nGenerate valid CSS styles.",
119            SlotKind::JavaScript => "\nGenerate valid JavaScript code.",
120            SlotKind::Function => "\nGenerate a complete function definition.",
121            SlotKind::Class => "\nGenerate a complete class/struct definition.",
122            SlotKind::Component => "\nGenerate a complete component with HTML, CSS, and JavaScript as needed.",
123            _ => "",
124        };
125
126        let context_part = context
127            .filter(|c| !c.is_empty())
128            .map(|c| format!("\n\nContext:\n{}", c))
129            .unwrap_or_default();
130
131        format!("{}{}{}", base, kind_specific, context_part)
132    }
133}
134
135use aether_core::provider::StreamResponse;
136use futures::stream::{BoxStream, StreamExt};
137
138#[async_trait]
139impl AiProvider for AnthropicProvider {
140    fn name(&self) -> &str {
141        "anthropic"
142    }
143
144    #[instrument(skip(self, request), fields(slot = %request.slot.name))]
145    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
146        debug!("Generating code with Anthropic for slot: {}", request.slot.name);
147
148        let api_key = self.config.resolve_api_key().await?;
149
150        let system = Some(request.system_prompt.unwrap_or_else(|| {
151            self.build_system_prompt(&request.slot.kind, request.context.as_deref())
152        }));
153
154        let messages = vec![Message {
155            role: "user".to_string(),
156            content: request.slot.prompt.clone(),
157        }];
158
159        let temperature = request.slot.temperature.or(self.config.temperature);
160        let api_request = MessageRequest {
161            model: request.model.clone().unwrap_or_else(|| self.config.model.clone()),
162            max_tokens: request.max_tokens.or(self.config.max_tokens).unwrap_or(4096),
163            system,
164            messages,
165            temperature,
166            stream: None,
167        };
168
169        let url = self.config.base_url.as_deref().unwrap_or(ANTHROPIC_API_URL);
170
171        let response = self
172            .client
173            .post(url)
174            .header("x-api-key", &api_key)
175            .header("anthropic-version", ANTHROPIC_VERSION)
176            .header("Content-Type", "application/json")
177            .json(&api_request)
178            .send()
179            .await
180            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
181
182        if !response.status().is_success() {
183            let status = response.status();
184            let body = response.text().await.unwrap_or_default();
185            return Err(AetherError::ProviderError(format!(
186                "API error {}: {}",
187                status, body
188            )));
189        }
190
191        let msg_response: MessageResponse = response
192            .json()
193            .await
194            .map_err(|e| AetherError::ProviderError(e.to_string()))?;
195
196        let code = msg_response
197            .content
198            .first()
199            .map(|c| c.text.clone())
200            .unwrap_or_default();
201
202        // Strip markdown code blocks if present
203        let code = strip_code_blocks(&code);
204
205        Ok(GenerationResponse {
206            code,
207            tokens_used: Some(msg_response.usage.input_tokens + msg_response.usage.output_tokens),
208            metadata: None,
209        })
210    }
211
212    fn generate_stream(
213        &self,
214        request: GenerationRequest,
215    ) -> BoxStream<'static, Result<StreamResponse>> {
216        let client = self.client.clone();
217        let config = self.config.clone();
218        let system = Some(request.system_prompt.unwrap_or_else(|| {
219            self.build_system_prompt(&request.slot.kind, request.context.as_deref())
220        }));
221        let user_prompt = request.slot.prompt.clone();
222        let url = config.base_url.as_deref().unwrap_or(ANTHROPIC_API_URL).to_string();
223
224        let temperature = request.slot.temperature.or(config.temperature);
225        let api_request = MessageRequest {
226            model: request.model.clone().unwrap_or_else(|| config.model.clone()),
227            max_tokens: request.max_tokens.or(config.max_tokens).unwrap_or(4096),
228            system,
229            messages: vec![Message {
230                role: "user".to_string(),
231                content: user_prompt,
232            }],
233            temperature,
234            stream: Some(true),
235        };
236
237        let stream = async_stream::stream! {
238            let api_key = match config.resolve_api_key().await {
239                Ok(k) => k,
240                Err(e) => {
241                    yield Err(e);
242                    return;
243                }
244            };
245
246            let response = client
247                .post(&url)
248                .header("x-api-key", &api_key)
249                .header("anthropic-version", ANTHROPIC_VERSION)
250                .header("Content-Type", "application/json")
251                .json(&api_request)
252                .send()
253                .await
254                .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
255
256            let response = match response {
257                Ok(r) => r,
258                Err(e) => {
259                    yield Err(e);
260                    return;
261                }
262            };
263
264            if !response.status().is_success() {
265                let status = response.status();
266                let body = response.text().await.unwrap_or_default();
267                yield Err(aether_core::AetherError::ProviderError(format!(
268                    "API error {}: {}",
269                    status, body
270                )));
271                return;
272            }
273
274            let mut stream = response.bytes_stream();
275            
276            while let Some(chunk_result) = stream.next().await {
277                let chunk = match chunk_result {
278                    Ok(c) => c,
279                    Err(e) => {
280                        yield Err(aether_core::AetherError::NetworkError(e.to_string()));
281                        break;
282                    }
283                };
284
285                let text = String::from_utf8_lossy(&chunk);
286                for line in text.lines() {
287                    let line = line.trim();
288                    if line.is_empty() { continue; }
289                    
290                    if let Some(event_data) = line.strip_prefix("data: ") {
291                        if let Ok(event) = serde_json::from_str::<StreamEvent>(event_data) {
292                            if let StreamEvent::ContentBlockDelta { delta } = event {
293                                yield Ok(StreamResponse {
294                                    delta: delta.text,
295                                    metadata: None,
296                                });
297                            }
298                        }
299                    }
300                }
301            }
302        };
303
304        Box::pin(stream)
305    }
306}
307
/// Strip a surrounding markdown code fence from generated code.
///
/// The input is trimmed first. If it both starts and ends with a triple-backtick fence, two
/// things are removed:
/// - the opening fence line, including any language tag such as `rust`;
/// - the closing fence.
///
/// When the closing fence is glued to the last line of code (e.g. ```` foo()``` ````), the
/// text before it is kept; previously that whole line was dropped. Unfenced input and
/// single-line input are returned trimmed but otherwise unchanged.
fn strip_code_blocks(code: &str) -> String {
    let code = code.trim();
    if !(code.starts_with("```") && code.ends_with("```")) {
        return code.to_string();
    }

    let lines: Vec<&str> = code.lines().collect();
    match lines.as_slice() {
        [_opening, body @ .., closing] => {
            let mut kept: Vec<&str> = body.to_vec();
            // `code` ends with the fence, so the last line does too. Whatever precedes the
            // fence on that line is real content unless it is blank.
            let tail = closing.strip_suffix("```").unwrap_or(*closing);
            if !tail.trim().is_empty() {
                kept.push(tail);
            }
            kept.join("\n")
        }
        // A lone line such as "```" has no separate opening/closing fence lines to remove.
        _ => code.to_string(),
    }
}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// An HTML slot should steer the model towards HTML5 markup.
    #[test]
    fn test_system_prompt() {
        let provider =
            AnthropicProvider::new(ProviderConfig::new("test-key", "claude-3-sonnet-20240229"))
                .unwrap();

        let html_prompt = provider.build_system_prompt(&SlotKind::Html, None);
        assert!(html_prompt.contains("HTML5"));
    }
}