llm/providers/anthropic/provider.rs

use super::mappers::{map_messages, map_tools};
use super::streaming::process_anthropic_stream;
use super::types::{Request, Thinking};
use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
use crate::{Context, LlmError, ReasoningEffort, Result};
use async_stream;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
use reqwest::{Client, header};
use std::env;
use std::time::Duration;
use tracing::debug;

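/// Streaming provider for Anthropic's Messages API (`POST /v1/messages`).
///
/// A minimal usage sketch; the builder methods are defined below, while the
/// surrounding import paths and error handling are assumed rather than taken
/// from this crate:
///
/// ```ignore
/// let provider = AnthropicProvider::new(Some("sk-ant-...".to_string()))?
///     .with_model("claude-sonnet-4-5-20250929")
///     .with_temperature(0.2)
///     .with_max_tokens(8192);
/// let mut stream = provider.stream_response(&context);
/// ```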
#[derive(Clone)]
pub struct AnthropicProvider {
    client: Client,
    model: String,
    base_url: Option<String>,
    temperature: Option<f32>,
    max_tokens: u32,
    api_key: Option<String>,
}

impl AnthropicProvider {
    pub fn new(api_key: Option<String>) -> Result<Self> {
        let client = build_client()?;

        Ok(Self {
            client,
            model: "claude-sonnet-4-5-20250929".to_string(),
            base_url: Some("https://api.anthropic.com".to_string()),
            temperature: None,
            max_tokens: 16_384,
            api_key,
        })
    }

    pub fn with_model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self
    }

    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

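    /// Builds an Anthropic `Request` from the context: streaming and automatic
    /// prompt caching are always on, the system prompt (if any) is attached as
    /// a cached system block, and a reasoning effort is translated into a
    /// thinking budget, which unsets `temperature` and bumps `max_tokens`
    /// above the budget as the API requires.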
    pub(crate) fn build_request(&self, context: &Context) -> Result<Request> {
        let (system_prompt, messages) = map_messages(context.messages())?;
        let tools = if context.tools().is_empty() { None } else { Some(map_tools(context.tools())?) };

        let mut request = Request::new(self.model.clone(), messages)
            .with_max_tokens(self.max_tokens)
            .with_stream(true)
            .with_auto_caching();

        if let Some(temp) = self.temperature {
            request = request.with_temperature(temp);
        }

        if let Some(system) = system_prompt {
            request = request.with_system_cached(system);
        }

        if let Some(tools) = tools {
            request = request.with_tools(tools);
        }

        if let Some(effort) = context.reasoning_effort() {
            let budget_tokens = effort_to_budget_tokens(effort);
            request = request.with_thinking(Thinking::new(budget_tokens));
            // Anthropic requires temperature to be unset when thinking is enabled
            request.temperature = None;
            // max_tokens must be > budget_tokens
            if request.max_tokens <= budget_tokens {
                request.max_tokens = budget_tokens + 1024;
            }
        }

        debug!("Built Anthropic request for model: {}", request.model);
        Ok(request)
    }

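    /// Resolves the API key: an explicitly supplied key takes precedence,
    /// falling back to the `ANTHROPIC_API_KEY` environment variable.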
    fn get_api_key(&self) -> Result<String> {
        if let Some(key) = &self.api_key {
            return Ok(key.clone());
        }

        if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
            return Ok(api_key);
        }

        Err(LlmError::MissingApiKey(
            "No Anthropic credentials found. Set the ANTHROPIC_API_KEY environment variable.".to_string(),
        ))
    }

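    /// Builds the request headers: the pinned `anthropic-version`, a JSON
    /// content type, and the `x-api-key` credential.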
    fn build_headers(&self) -> Result<HeaderMap> {
        let mut headers = HeaderMap::new();
        headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let api_key = self.get_api_key()?;
        headers.insert("x-api-key", HeaderValue::from_str(&api_key)?);
        Ok(headers)
    }

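    /// POSTs the request to `{base_url}/v1/messages` and adapts the SSE
    /// response into a stream of raw JSON event payloads, dropping any
    /// terminal `[DONE]` marker.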
    async fn send_request(
        &self,
        request: Request,
        headers: header::HeaderMap,
    ) -> Result<impl futures::Stream<Item = Result<String>>> {
        let base_url = self.base_url.as_deref().unwrap_or("https://api.anthropic.com");
        let url = format!("{base_url}/v1/messages");

        debug!("Sending request to Anthropic API: {url}");
        debug!(
            "Anthropic request body: {}",
            serde_json::to_string(&request).unwrap_or_else(|_| "<failed to serialize>".to_string())
        );

        debug!("Anthropic request headers: {}", format_headers(&headers));
        let response = self
            .client
            .post(&url)
            .headers(headers)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::ApiRequest(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::ApiError(format!("Anthropic API request failed with status {status}: {error_text}")));
        }

        let event_stream = response.bytes_stream().eventsource();
        let processed_stream = event_stream.filter_map(|result| {
            std::future::ready(match result {
                Ok(event) => {
                    let data = event.data;
                    if data == "[DONE]" { None } else { Some(Ok(data)) }
                }
                Err(e) => Some(Err(LlmError::IoError(e.to_string()))),
            })
        });

        Ok(processed_stream)
    }
}

impl ProviderFactory for AnthropicProvider {
    async fn from_env() -> Result<Self> {
        Self::new(None)
    }

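    // Resolves to the inherent `with_model`: inherent methods take precedence
    // over trait methods during method resolution, so this does not recurse.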
    fn with_model(self, model: &str) -> Self {
        self.with_model(model)
    }
}

impl StreamingModelProvider for AnthropicProvider {
    fn model(&self) -> Option<crate::LlmModel> {
        format!("anthropic:{}", self.model).parse().ok()
    }

    fn context_window(&self) -> Option<u32> {
        get_context_window("anthropic", &self.model)
    }

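    /// Streams the model's response. Because the trait signature is
    /// infallible, header and request construction errors are yielded as the
    /// first (and only) item of the stream instead of being returned early.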
    fn stream_response<'a>(&self, context: &Context) -> LlmResponseStream {
        let provider = self.clone();
        let context = context.clone();

        Box::pin(async_stream::stream! {
            let headers = match provider.build_headers() {
                Ok(result) => result,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let request = match provider.build_request(&context) {
                Ok(req) => req,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let stream = match provider.send_request(request, headers).await {
                Ok(stream) => stream,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let mut anthropic_stream = Box::pin(process_anthropic_stream(stream));
            while let Some(result) = anthropic_stream.next().await {
                yield result;
            }
        })
    }

    fn display_name(&self) -> String {
        format!("Anthropic ({})", self.model)
    }
}

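/// Builds the shared HTTP client. Note that reqwest's `timeout` covers the
/// whole request, including reading the response body, so streamed responses
/// must finish within 60 seconds.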
fn build_client() -> Result<Client> {
    Client::builder().timeout(Duration::from_secs(60)).build().map_err(|e| LlmError::HttpClientCreation(e.to_string()))
}

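/// Maps a reasoning effort level to an Anthropic thinking token budget
/// (Low: 1024, Medium: 4096, High and Xhigh: 10240).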
fn effort_to_budget_tokens(effort: ReasoningEffort) -> u32 {
    match effort {
        ReasoningEffort::Low => 1024,
        ReasoningEffort::Medium => 4096,
        ReasoningEffort::High | ReasoningEffort::Xhigh => 10240,
    }
}

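/// True for header names whose values must never be logged: `authorization`,
/// `x-api-key`, and any name containing `secret` or `token`.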
fn should_redact_header(name: &str) -> bool {
    let lower = name.to_ascii_lowercase();
    lower == "authorization" || lower == "x-api-key" || lower.contains("secret") || lower.contains("token")
}

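/// Renders a header map as comma-separated `name=value` pairs for debug
/// logging, substituting `<redacted>` for sensitive values.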
fn format_headers(headers: &header::HeaderMap) -> String {
    let mut parts = Vec::new();
    for (name, value) in headers {
        let name_str = name.as_str();
        let value_str = if should_redact_header(name_str) {
            "<redacted>".to_string()
        } else {
            value.to_str().unwrap_or("<non-utf8>").to_string()
        };
        parts.push(format!("{name_str}={value_str}"));
    }
    parts.join(", ")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ChatMessage;
    use crate::ContentBlock;
    use crate::ToolDefinition;
    use crate::providers::anthropic::types::{SystemContent, SystemContentBlock};
    use crate::types::IsoString;
    use reqwest::header::AUTHORIZATION;

    fn create_test_provider() -> AnthropicProvider {
        AnthropicProvider::new(Some("test-api-key".to_string()))
            .unwrap()
            .with_model("claude-sonnet-4-5-20250929")
            .with_temperature(0.7)
            .with_max_tokens(1000)
    }

    #[test]
    fn test_provider_creation() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string()));
        assert!(provider.is_ok());
    }

    #[test]
    fn build_headers_uses_api_key() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap();
        let headers = provider.build_headers().expect("headers");
        assert_eq!(headers.get("x-api-key").and_then(|value| value.to_str().ok()), Some("test-api-key"));
        assert!(headers.get(AUTHORIZATION).is_none());
        assert!(headers.get("anthropic-beta").is_none());
    }

    #[test]
    fn test_build_request_simple() {
        let provider = create_test_provider();

        let context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() }],
            vec![],
        );

        let request = provider.build_request(&context).unwrap();
        assert_eq!(request.model, "claude-sonnet-4-5-20250929");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.messages.len(), 1);
        assert!(request.tools.is_none());
        assert!(request.stream);
    }

    #[test]
    fn test_build_request_with_system_and_tools() {
        let provider = create_test_provider();

        let context = Context::new(
            vec![
                ChatMessage::System { content: "You are helpful".to_string(), timestamp: IsoString::now() },
                ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() },
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search for information".to_string(),
                parameters: r#"{"type": "object", "properties": {"query": {"type": "string"}}}"#.to_string(),
                server: None,
            }],
        );

        let request = provider.build_request(&context).unwrap();
        if let Some(system) = &request.system {
            match system {
                SystemContent::Blocks(blocks) => {
                    assert_eq!(blocks.len(), 1);
                    let SystemContentBlock::Text { text, .. } = &blocks[0];
                    assert_eq!(text, "You are helpful");
                }
                SystemContent::Text(_) => panic!("Expected blocks system content"),
            }
        } else {
            panic!("Expected system prompt");
        }
        assert_eq!(request.messages.len(), 1);
        assert!(request.tools.is_some());
        assert_eq!(request.tools.unwrap().len(), 1);
    }

    #[test]
    fn test_build_request_with_caching() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap(); // Caching is enabled by default

        let context = Context::new(
            vec![
                ChatMessage::System { content: "Hello".to_string(), timestamp: IsoString::now() },
                ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() },
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search for information".to_string(),
                parameters: r#"{"type": "object", "properties": {"query": {"type": "string"}}}"#.to_string(),
                server: None,
            }],
        );

        let request = provider.build_request(&context).unwrap();

        // With caching enabled, system prompt should be cached
        if let Some(system) = &request.system {
            match system {
                SystemContent::Blocks(blocks) => {
                    assert_eq!(blocks.len(), 1);
                    let SystemContentBlock::Text { text, cache_control } = &blocks[0];
                    assert_eq!(text, "Hello");
                    assert!(cache_control.is_some());
                }
                SystemContent::Text(_) => panic!("Expected blocks system content for caching"),
            }
        } else {
            panic!("Expected system prompt");
        }

        assert!(request.tools.is_some());

        // Top-level cache_control enables automatic caching
        assert!(request.cache_control.is_some());
    }

    #[test]
    fn test_build_request_with_reasoning_effort() {
        let provider = create_test_provider();

        let mut context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Think hard")], timestamp: IsoString::now() }],
            vec![],
        );
        context.set_reasoning_effort(Some(crate::ReasoningEffort::High));

        let request = provider.build_request(&context).unwrap();
        let thinking = request.thinking.expect("thinking should be set");
        assert_eq!(thinking.thinking_type, "enabled");
        assert_eq!(thinking.budget_tokens, 10240);
        // Temperature must be unset when thinking is enabled
        assert!(request.temperature.is_none());
        // max_tokens must exceed budget_tokens
        assert!(request.max_tokens > thinking.budget_tokens);
    }

    #[test]
    fn test_build_request_without_reasoning_effort_has_no_thinking() {
        let provider = create_test_provider();
        let context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() }],
            vec![],
        );

        let request = provider.build_request(&context).unwrap();
        assert!(request.thinking.is_none());
    }

    #[test]
    fn test_build_request_thinking_bumps_max_tokens_if_needed() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap().with_max_tokens(500);

        let mut context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Hi")], timestamp: IsoString::now() }],
            vec![],
        );
        context.set_reasoning_effort(Some(crate::ReasoningEffort::Low));

        let request = provider.build_request(&context).unwrap();
        let thinking = request.thinking.as_ref().unwrap();
        assert!(
            request.max_tokens > thinking.budget_tokens,
            "max_tokens ({}) should exceed budget_tokens ({})",
            request.max_tokens,
            thinking.budget_tokens
        );
    }

    #[test]
    fn test_anthropic_provider_display_name() {
        let provider = create_test_provider();
        assert_eq!(provider.display_name(), "Anthropic (claude-sonnet-4-5-20250929)");
    }

    #[test]
    fn test_anthropic_provider_display_name_default() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap();
        assert_eq!(provider.display_name(), "Anthropic (claude-sonnet-4-5-20250929)");
    }

    #[test]
    fn format_headers_redacts_x_api_key() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", HeaderValue::from_static("sk-secret-123"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let formatted = format_headers(&headers);
        assert!(formatted.contains("x-api-key=<redacted>"));
        assert!(formatted.contains("content-type=application/json"));
        assert!(!formatted.contains("sk-secret-123"));
    }

    #[test]
    fn format_headers_redacts_authorization() {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer token123"));

        let formatted = format_headers(&headers);
        assert!(formatted.contains("authorization=<redacted>"));
        assert!(!formatted.contains("token123"));
    }

    #[test]
    fn format_headers_redacts_secret_and_token_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("x-client-secret", HeaderValue::from_static("mysecret"));
        headers.insert("x-auth-token", HeaderValue::from_static("mytoken"));
        headers.insert("accept", HeaderValue::from_static("text/plain"));

        let formatted = format_headers(&headers);
        assert!(formatted.contains("x-client-secret=<redacted>"));
        assert!(formatted.contains("x-auth-token=<redacted>"));
        assert!(formatted.contains("accept=text/plain"));
        assert!(!formatted.contains("mysecret"));
        assert!(!formatted.contains("mytoken"));
    }
}