// llm/providers/anthropic/provider.rs
1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_anthropic_stream;
3use super::types::{Request, Thinking};
4use crate::provider::{
5    LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window,
6};
7use crate::{Context, LlmError, ReasoningEffort, Result};
8use async_stream;
9use eventsource_stream::Eventsource;
10use futures::StreamExt;
11use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
12use reqwest::{Client, header};
13use std::env;
14use std::time::Duration;
15use tracing::debug;
16
/// Streaming LLM provider backed by the Anthropic Messages API.
///
/// Configured via the builder-style `with_*` methods; `Clone` is derived so
/// each response stream can own its own copy (see `stream_response`).
#[derive(Clone)]
pub struct AnthropicProvider {
    // Shared HTTP client used for all API calls.
    client: Client,
    // Model identifier sent in the request body (e.g. "claude-sonnet-4-5-20250929").
    model: String,
    // API base URL override; "https://api.anthropic.com" is used when None.
    base_url: Option<String>,
    // Sampling temperature; omitted from the request when None.
    temperature: Option<f32>,
    // Upper bound on generated output tokens per response.
    max_tokens: u32,
    // Explicit API key; falls back to the ANTHROPIC_API_KEY env var when None.
    api_key: Option<String>,
}
26
impl AnthropicProvider {
    /// Creates a provider with the default model, base URL, and a
    /// 16 384-token output limit.
    ///
    /// `api_key` may be `None`; the key is then resolved from the
    /// `ANTHROPIC_API_KEY` environment variable when headers are built.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be constructed.
    pub fn new(api_key: Option<String>) -> Result<Self> {
        let client = build_client()?;

        Ok(Self {
            client,
            model: "claude-sonnet-4-5-20250929".to_string(),
            base_url: Some("https://api.anthropic.com".to_string()),
            temperature: None,
            max_tokens: 16_384,
            api_key,
        })
    }

    /// Replaces the model identifier (builder style).
    pub fn with_model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self
    }

    /// Replaces the API base URL (builder style); useful for proxies and tests.
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    /// Sets the sampling temperature (builder style).
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of output tokens (builder style).
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Translates a provider-agnostic `Context` into an Anthropic Messages
    /// API request body (streaming enabled, automatic prompt caching on).
    ///
    /// When the context requests reasoning effort, extended "thinking" is
    /// enabled and the request is adjusted to satisfy Anthropic's API
    /// constraints: temperature is cleared and `max_tokens` is bumped so it
    /// strictly exceeds the thinking budget.
    ///
    /// # Errors
    /// Propagates failures from message or tool mapping.
    pub(crate) fn build_request(&self, context: &Context) -> Result<Request> {
        let (system_prompt, messages) = map_messages(context.messages())?;
        // Only attach a tools array when the context actually declares tools.
        let tools = if context.tools().is_empty() {
            None
        } else {
            Some(map_tools(context.tools())?)
        };

        let mut request = Request::new(self.model.clone(), messages)
            .with_max_tokens(self.max_tokens)
            .with_stream(true)
            .with_auto_caching();

        if let Some(temp) = self.temperature {
            request = request.with_temperature(temp);
        }

        if let Some(system) = system_prompt {
            request = request.with_system_cached(system);
        }

        if let Some(tools) = tools {
            request = request.with_tools(tools);
        }

        if let Some(effort) = context.reasoning_effort() {
            let budget_tokens = effort_to_budget_tokens(effort);
            request = request.with_thinking(Thinking::new(budget_tokens));
            // Anthropic requires temperature to be unset when thinking is enabled
            request.temperature = None;
            // max_tokens must be > budget_tokens
            if request.max_tokens <= budget_tokens {
                request.max_tokens = budget_tokens + 1024;
            }
        }

        debug!("Built Anthropic request for model: {}", request.model);
        Ok(request)
    }

    /// Resolves the API key: an explicitly configured key wins, otherwise the
    /// `ANTHROPIC_API_KEY` environment variable is consulted.
    ///
    /// # Errors
    /// Returns `LlmError::MissingApiKey` when neither source yields a key.
    fn get_api_key(&self) -> Result<String> {
        if let Some(key) = &self.api_key {
            return Ok(key.clone());
        }

        if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
            return Ok(api_key);
        }

        Err(LlmError::MissingApiKey(
            "No Anthropic credentials found. Set ANTHROPIC_API_KEY environment variable."
                .to_string(),
        ))
    }

    /// Builds the request headers: pinned API version, JSON content type, and
    /// the `x-api-key` credential.
    ///
    /// # Errors
    /// Fails when no API key is available or the key contains bytes that are
    /// invalid in an HTTP header value.
    fn build_headers(&self) -> Result<HeaderMap> {
        let mut headers = HeaderMap::new();
        headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let api_key = self.get_api_key()?;
        headers.insert("x-api-key", HeaderValue::from_str(&api_key)?);
        Ok(headers)
    }

    /// POSTs the request to `/v1/messages` and returns a stream of raw SSE
    /// `data` payloads (JSON strings), one item per server-sent event.
    ///
    /// # Errors
    /// Fails on transport errors (`LlmError::ApiRequest`) or a non-2xx API
    /// response (`LlmError::ApiError`, including the response body when it
    /// can be read). Mid-stream SSE failures surface as `LlmError::IoError`
    /// items in the returned stream.
    async fn send_request(
        &self,
        request: Request,
        headers: header::HeaderMap,
    ) -> Result<impl futures::Stream<Item = Result<String>>> {
        let base_url = self
            .base_url
            .as_deref()
            .unwrap_or("https://api.anthropic.com");
        let url = format!("{base_url}/v1/messages");

        debug!("Sending request to Anthropic API: {url}");
        debug!(
            "Anthropic request body: {}",
            serde_json::to_string(&request).unwrap_or_else(|_| "<failed to serialize>".to_string())
        );

        // Headers are logged with credentials redacted (see format_headers).
        debug!("Anthropic request headers: {}", format_headers(&headers));
        let response = self
            .client
            .post(&url)
            .headers(headers)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::ApiRequest(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::ApiError(format!(
                "Anthropic API request failed with status {status}: {error_text}"
            )));
        }

        // Decode the byte stream as server-sent events and forward each
        // event's data payload for downstream JSON parsing.
        let event_stream = response.bytes_stream().eventsource();
        let processed_stream = event_stream.filter_map(|result| {
            std::future::ready(match result {
                Ok(event) => {
                    let data = event.data;
                    // NOTE(review): Anthropic signals completion via typed SSE
                    // events rather than an OpenAI-style "[DONE]" sentinel;
                    // this check is defensive and drops such a sentinel if one
                    // ever appears.
                    if data == "[DONE]" {
                        None
                    } else {
                        Some(Ok(data))
                    }
                }
                Err(e) => Some(Err(LlmError::IoError(e.to_string()))),
            })
        });

        Ok(processed_stream)
    }
}
181
182impl ProviderFactory for AnthropicProvider {
183    fn from_env() -> Result<Self> {
184        Self::new(None)
185    }
186
187    fn with_model(self, model: &str) -> Self {
188        self.with_model(model)
189    }
190}
191
impl StreamingModelProvider for AnthropicProvider {
    /// Parses the configured model into the crate-wide model type, if the
    /// "anthropic:<model>" string is recognized.
    fn model(&self) -> Option<crate::LlmModel> {
        format!("anthropic:{}", self.model).parse().ok()
    }

    /// Looks up the context window size for the configured model.
    fn context_window(&self) -> Option<u32> {
        get_context_window("anthropic", &self.model)
    }

    /// Streams a response for `context`.
    ///
    /// The provider and context are cloned so the returned stream owns
    /// everything it needs. Setup failures (header construction, request
    /// mapping, transport) are yielded as a single error item and the
    /// stream terminates.
    // NOTE(review): the `'a` lifetime parameter is unused in this body —
    // presumably required by the trait's method signature; confirm against
    // the `StreamingModelProvider` trait definition.
    fn stream_response<'a>(&self, context: &Context) -> LlmResponseStream {
        let provider = self.clone();
        let context = context.clone();

        Box::pin(async_stream::stream! {
            // Resolve credentials first; on failure, emit one error and stop.
            let headers = match provider.build_headers() {
                Ok(result) => result,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let request = match provider.build_request(&context) {
                Ok(req) => req,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let stream = match provider.send_request(request, headers).await {
                Ok(stream) => stream,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            // Forward every parsed item from the Anthropic SSE stream.
            let mut anthropic_stream = Box::pin(process_anthropic_stream(stream));
            while let Some(result) = anthropic_stream.next().await {
                yield result;
            }
        })
    }

    /// Human-readable provider label including the model, e.g.
    /// "Anthropic (claude-sonnet-4-5-20250929)".
    fn display_name(&self) -> String {
        format!("Anthropic ({})", self.model)
    }
}
241
242fn build_client() -> Result<Client> {
243    Client::builder()
244        .timeout(Duration::from_secs(60))
245        .build()
246        .map_err(|e| LlmError::HttpClientCreation(e.to_string()))
247}
248
249fn effort_to_budget_tokens(effort: ReasoningEffort) -> u32 {
250    match effort {
251        ReasoningEffort::Low => 1024,
252        ReasoningEffort::Medium => 4096,
253        ReasoningEffort::High | ReasoningEffort::Xhigh => 10240,
254    }
255}
256
/// Reports whether a header's value must be hidden in debug output.
///
/// Matches the well-known credential headers exactly (case-insensitively),
/// plus any header whose lowercased name mentions "secret" or "token".
fn should_redact_header(name: &str) -> bool {
    let lowered = name.to_ascii_lowercase();
    matches!(lowered.as_str(), "authorization" | "x-api-key")
        || ["secret", "token"]
            .iter()
            .any(|needle| lowered.contains(needle))
}
264
265fn format_headers(headers: &header::HeaderMap) -> String {
266    let mut parts = Vec::new();
267    for (name, value) in headers {
268        let name_str = name.as_str();
269        let value_str = if should_redact_header(name_str) {
270            "<redacted>".to_string()
271        } else {
272            value.to_str().unwrap_or("<non-utf8>").to_string()
273        };
274        parts.push(format!("{name_str}={value_str}"));
275    }
276    parts.join(", ")
277}
278
// Unit tests for request construction, header handling, and log redaction.
// No network calls are made; only pure builder/formatting paths are covered.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ChatMessage;
    use crate::ContentBlock;
    use crate::ToolDefinition;
    use crate::providers::anthropic::types::{SystemContent, SystemContentBlock};
    use crate::types::IsoString;
    use reqwest::header::AUTHORIZATION;

    // Shared fixture: explicit key, pinned model, temperature 0.7, and a
    // small 1000-token output cap so thinking tests can observe the bump.
    fn create_test_provider() -> AnthropicProvider {
        AnthropicProvider::new(Some("test-api-key".to_string()))
            .unwrap()
            .with_model("claude-sonnet-4-5-20250929")
            .with_temperature(0.7)
            .with_max_tokens(1000)
    }

    #[test]
    fn test_provider_creation() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string()));
        assert!(provider.is_ok());
    }

    // The key must be sent as `x-api-key` — never as a bearer Authorization
    // header — and no beta-feature header should be present by default.
    #[test]
    fn build_headers_uses_api_key() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap();
        let headers = provider.build_headers().expect("headers");
        assert_eq!(
            headers
                .get("x-api-key")
                .and_then(|value| value.to_str().ok()),
            Some("test-api-key")
        );
        assert!(headers.get(AUTHORIZATION).is_none());
        assert!(headers.get("anthropic-beta").is_none());
    }

    // A plain user message maps to a single-message streaming request with
    // no tools attached.
    #[test]
    fn test_build_request_simple() {
        let provider = create_test_provider();

        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Hello")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        let request = provider.build_request(&context).unwrap();
        assert_eq!(request.model, "claude-sonnet-4-5-20250929");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.messages.len(), 1);
        assert!(request.tools.is_none());
        assert!(request.stream);
    }

    // The system message is lifted out of the message list into the
    // request's `system` field, and tool definitions are mapped through.
    #[test]
    fn test_build_request_with_system_and_tools() {
        let provider = create_test_provider();

        let context = Context::new(
            vec![
                ChatMessage::System {
                    content: "You are helpful".to_string(),
                    timestamp: IsoString::now(),
                },
                ChatMessage::User {
                    content: vec![ContentBlock::text("Hello")],
                    timestamp: IsoString::now(),
                },
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search for information".to_string(),
                parameters: r#"{"type": "object", "properties": {"query": {"type": "string"}}}"#
                    .to_string(),
                server: None,
            }],
        );

        let request = provider.build_request(&context).unwrap();
        if let Some(system) = &request.system {
            match system {
                SystemContent::Blocks(blocks) => {
                    assert_eq!(blocks.len(), 1);
                    let SystemContentBlock::Text { text, .. } = &blocks[0];
                    assert_eq!(text, "You are helpful");
                }
                SystemContent::Text(_) => panic!("Expected blocks system content"),
            }
        } else {
            panic!("Expected system prompt");
        }
        // Only the user message remains in the message list.
        assert_eq!(request.messages.len(), 1);
        assert!(request.tools.is_some());
        assert_eq!(request.tools.unwrap().len(), 1);
    }

    #[test]
    fn test_build_request_with_caching() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap(); // Caching is enabled by default

        let context = Context::new(
            vec![
                ChatMessage::System {
                    content: "Hello".to_string(),
                    timestamp: IsoString::now(),
                },
                ChatMessage::User {
                    content: vec![ContentBlock::text("Hello")],
                    timestamp: IsoString::now(),
                },
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search for information".to_string(),
                parameters: r#"{"type": "object", "properties": {"query": {"type": "string"}}}"#
                    .to_string(),
                server: None,
            }],
        );

        let request = provider.build_request(&context).unwrap();

        // With caching enabled, system prompt should be cached
        if let Some(system) = &request.system {
            match system {
                SystemContent::Blocks(blocks) => {
                    assert_eq!(blocks.len(), 1);
                    let SystemContentBlock::Text {
                        text,
                        cache_control,
                    } = &blocks[0];
                    assert_eq!(text, "Hello");
                    assert!(cache_control.is_some());
                }
                SystemContent::Text(_) => panic!("Expected blocks system content for caching"),
            }
        } else {
            panic!("Expected system prompt");
        }

        assert!(request.tools.is_some());

        // Top-level cache_control enables automatic caching
        assert!(request.cache_control.is_some());
    }

    // Reasoning effort turns on extended thinking, clears temperature, and
    // guarantees max_tokens exceeds the thinking budget.
    #[test]
    fn test_build_request_with_reasoning_effort() {
        let provider = create_test_provider();

        let mut context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Think hard")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );
        context.set_reasoning_effort(Some(crate::ReasoningEffort::High));

        let request = provider.build_request(&context).unwrap();
        let thinking = request.thinking.expect("thinking should be set");
        assert_eq!(thinking.thinking_type, "enabled");
        assert_eq!(thinking.budget_tokens, 10240);
        // Temperature must be unset when thinking is enabled
        assert!(request.temperature.is_none());
        // max_tokens must exceed budget_tokens
        assert!(request.max_tokens > thinking.budget_tokens);
    }

    #[test]
    fn test_build_request_without_reasoning_effort_has_no_thinking() {
        let provider = create_test_provider();
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Hello")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        let request = provider.build_request(&context).unwrap();
        assert!(request.thinking.is_none());
    }

    // A configured max_tokens below the thinking budget must be raised so
    // the request satisfies Anthropic's max_tokens > budget_tokens rule.
    #[test]
    fn test_build_request_thinking_bumps_max_tokens_if_needed() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string()))
            .unwrap()
            .with_max_tokens(500);

        let mut context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Hi")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );
        context.set_reasoning_effort(Some(crate::ReasoningEffort::Low));

        let request = provider.build_request(&context).unwrap();
        let thinking = request.thinking.as_ref().unwrap();
        assert!(
            request.max_tokens > thinking.budget_tokens,
            "max_tokens ({}) should exceed budget_tokens ({})",
            request.max_tokens,
            thinking.budget_tokens
        );
    }

    #[test]
    fn test_anthropic_provider_display_name() {
        let provider = create_test_provider();
        assert_eq!(
            provider.display_name(),
            "Anthropic (claude-sonnet-4-5-20250929)"
        );
    }

    // The default model (set in `new`) should appear in the display name.
    #[test]
    fn test_anthropic_provider_display_name_default() {
        let provider = AnthropicProvider::new(Some("test-api-key".to_string())).unwrap();
        assert_eq!(
            provider.display_name(),
            "Anthropic (claude-sonnet-4-5-20250929)"
        );
    }

    #[test]
    fn format_headers_redacts_x_api_key() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", HeaderValue::from_static("sk-secret-123"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let formatted = format_headers(&headers);
        assert!(formatted.contains("x-api-key=<redacted>"));
        assert!(formatted.contains("content-type=application/json"));
        assert!(!formatted.contains("sk-secret-123"));
    }

    #[test]
    fn format_headers_redacts_authorization() {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer token123"));

        let formatted = format_headers(&headers);
        assert!(formatted.contains("authorization=<redacted>"));
        assert!(!formatted.contains("token123"));
    }

    // Substring matching: any header name containing "secret" or "token"
    // is redacted, while benign headers pass through untouched.
    #[test]
    fn format_headers_redacts_secret_and_token_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("x-client-secret", HeaderValue::from_static("mysecret"));
        headers.insert("x-auth-token", HeaderValue::from_static("mytoken"));
        headers.insert("accept", HeaderValue::from_static("text/plain"));

        let formatted = format_headers(&headers);
        assert!(formatted.contains("x-client-secret=<redacted>"));
        assert!(formatted.contains("x-auth-token=<redacted>"));
        assert!(formatted.contains("accept=text/plain"));
        assert!(!formatted.contains("mysecret"));
        assert!(!formatted.contains("mytoken"));
    }
}
546}