Skip to main content

synth_claw/providers/
anthropic.rs

1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4
5use super::{GenerationRequest, GenerationResponse, LLMProvider, ModelPricing};
6use crate::{Error, Result};
7
8const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
9const ANTHROPIC_VERSION: &str = "2023-06-01";
10
11pub struct AnthropicProvider {
12    client: Client,
13    api_key: String,
14    model: String,
15    default_temperature: Option<f32>,
16    default_max_tokens: Option<u32>,
17    pricing: ModelPricing,
18}
19
20#[derive(Serialize)]
21struct AnthropicRequest {
22    model: String,
23    max_tokens: u32,
24    messages: Vec<Message>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    system: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    temperature: Option<f32>,
29}
30
31#[derive(Serialize)]
32struct Message {
33    role: String,
34    content: String,
35}
36
37#[derive(Deserialize)]
38struct AnthropicResponse {
39    content: Vec<ContentBlock>,
40    usage: Usage,
41}
42
43#[derive(Deserialize)]
44struct ContentBlock {
45    text: String,
46}
47
48#[derive(Deserialize)]
49struct Usage {
50    input_tokens: u32,
51    output_tokens: u32,
52}
53
54#[derive(Deserialize)]
55struct AnthropicError {
56    error: ErrorDetail,
57}
58
59#[derive(Deserialize)]
60struct ErrorDetail {
61    message: String,
62}
63
64impl AnthropicProvider {
65    pub fn new(
66        model: String,
67        api_key: Option<String>,
68        temperature: Option<f32>,
69        max_tokens: Option<u32>,
70    ) -> Result<Self> {
71        let api_key = api_key
72            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
73            .ok_or_else(|| {
74                Error::Provider("ANTHROPIC_API_KEY not set and no api_key provided".into())
75            })?;
76
77        let client = Client::new();
78        let pricing = get_model_pricing(&model);
79
80        Ok(Self {
81            client,
82            api_key,
83            model,
84            default_temperature: temperature,
85            default_max_tokens: max_tokens,
86            pricing,
87        })
88    }
89}
90
91#[async_trait]
92impl LLMProvider for AnthropicProvider {
93    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
94        let temperature = request.temperature.or(self.default_temperature);
95        let max_tokens = request.max_tokens.or(self.default_max_tokens).unwrap_or(4096);
96
97        let anthropic_request = AnthropicRequest {
98            model: self.model.clone(),
99            max_tokens,
100            messages: vec![Message {
101                role: "user".to_string(),
102                content: request.prompt,
103            }],
104            system: request.system_prompt,
105            temperature,
106        };
107
108        let response = self
109            .client
110            .post(ANTHROPIC_API_URL)
111            .header("x-api-key", &self.api_key)
112            .header("anthropic-version", ANTHROPIC_VERSION)
113            .header("content-type", "application/json")
114            .json(&anthropic_request)
115            .send()
116            .await
117            .map_err(|e| Error::Provider(format!("Request failed: {}", e)))?;
118
119        let status = response.status();
120        let body = response
121            .text()
122            .await
123            .map_err(|e| Error::Provider(format!("Failed to read response: {}", e)))?;
124
125        if !status.is_success() {
126            let error: AnthropicError = serde_json::from_str(&body)
127                .map_err(|_| Error::Provider(format!("API error ({}): {}", status, body)))?;
128            return Err(Error::Provider(format!(
129                "Anthropic API error: {}",
130                error.error.message
131            )));
132        }
133
134        let response: AnthropicResponse = serde_json::from_str(&body)
135            .map_err(|e| Error::Provider(format!("Failed to parse response: {}", e)))?;
136
137        let content = response
138            .content
139            .first()
140            .map(|c| c.text.clone())
141            .ok_or_else(|| Error::Provider("No response content".into()))?;
142
143        Ok(GenerationResponse {
144            content,
145            input_tokens: response.usage.input_tokens,
146            output_tokens: response.usage.output_tokens,
147        })
148    }
149
150    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
151        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.pricing.input_per_million;
152        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.pricing.output_per_million;
153        input_cost + output_cost
154    }
155
156    fn name(&self) -> &str {
157        "anthropic"
158    }
159
160    fn model(&self) -> &str {
161        &self.model
162    }
163}
164
165fn get_model_pricing(model: &str) -> ModelPricing {
166    match model {
167        m if m.contains("claude-haiku-4-5-20251001") || m.contains("claude-4.6-sonnet") => ModelPricing {
168            input_per_million: 3.00,
169            output_per_million: 15.00,
170        },
171        m if m.contains("claude-3-5-haiku") || m.contains("claude-3.5-haiku") => ModelPricing {
172            input_per_million: 0.80,
173            output_per_million: 4.00,
174        },
175        m if m.contains("claude-3-opus") => ModelPricing {
176            input_per_million: 15.00,
177            output_per_million: 75.00,
178        },
179        m if m.contains("claude-haiku-4-5-20251001") => ModelPricing {
180            input_per_million: 3.00,
181            output_per_million: 15.00,
182        },
183        m if m.contains("claude-3-haiku") => ModelPricing {
184            input_per_million: 0.25,
185            output_per_million: 1.25,
186        },
187        _ => ModelPricing {
188            input_per_million: 3.00,
189            output_per_million: 15.00,
190        },
191    }
192}
193
194#[cfg(test)]
195mod tests {
196    use super::*;
197
198    #[test]
199    fn test_model_pricing_sonnet() {
200        let pricing = get_model_pricing("claude-haiku-4-5-20251001");
201        assert_eq!(pricing.input_per_million, 3.00);
202        assert_eq!(pricing.output_per_million, 15.00);
203    }
204
205    #[test]
206    fn test_model_pricing_haiku() {
207        let pricing = get_model_pricing("claude-3-5-haiku-20241022");
208        assert_eq!(pricing.input_per_million, 0.80);
209        assert_eq!(pricing.output_per_million, 4.00);
210    }
211
212    #[test]
213    fn test_model_pricing_opus() {
214        let pricing = get_model_pricing("claude-3-opus-20240229");
215        assert_eq!(pricing.input_per_million, 15.00);
216        assert_eq!(pricing.output_per_million, 75.00);
217    }
218
219    #[test]
220    fn test_cost_estimation() {
221        let provider = AnthropicProvider::new(
222            "claude-haiku-4-5-20251001".to_string(),
223            Some("test-key".to_string()),
224            None,
225            None,
226        )
227        .unwrap();
228
229        // 1M input + 1M output tokens
230        let cost = provider.estimate_cost(1_000_000, 1_000_000);
231        assert!((cost - 18.00).abs() < 0.001); // $3 + $15 = $18
232    }
233
234    #[test]
235    fn test_provider_name_and_model() {
236        let provider = AnthropicProvider::new(
237            "claude-haiku-4-5-20251001".to_string(),
238            Some("test-key".to_string()),
239            Some(0.7),
240            Some(2000),
241        )
242        .unwrap();
243
244        assert_eq!(provider.name(), "anthropic");
245        assert_eq!(provider.model(), "claude-haiku-4-5-20251001");
246    }
247
248    #[test]
249    fn test_missing_api_key_error() {
250        unsafe { std::env::remove_var("ANTHROPIC_API_KEY") };
251        let result = AnthropicProvider::new(
252            "claude-haiku-4-5-20251001".to_string(),
253            None,
254            None,
255            None,
256        );
257        assert!(result.is_err());
258    }
259}