1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4
5use super::{GenerationRequest, GenerationResponse, LLMProvider, ModelPricing};
6use crate::{Error, Result};
7
/// Endpoint of the Anthropic Messages API that every generation request is POSTed to.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` header of every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
10
11pub struct AnthropicProvider {
12 client: Client,
13 api_key: String,
14 model: String,
15 default_temperature: Option<f32>,
16 default_max_tokens: Option<u32>,
17 pricing: ModelPricing,
18}
19
20#[derive(Serialize)]
21struct AnthropicRequest {
22 model: String,
23 max_tokens: u32,
24 messages: Vec<Message>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 system: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 temperature: Option<f32>,
29}
30
31#[derive(Serialize)]
32struct Message {
33 role: String,
34 content: String,
35}
36
37#[derive(Deserialize)]
38struct AnthropicResponse {
39 content: Vec<ContentBlock>,
40 usage: Usage,
41}
42
43#[derive(Deserialize)]
44struct ContentBlock {
45 text: String,
46}
47
48#[derive(Deserialize)]
49struct Usage {
50 input_tokens: u32,
51 output_tokens: u32,
52}
53
54#[derive(Deserialize)]
55struct AnthropicError {
56 error: ErrorDetail,
57}
58
59#[derive(Deserialize)]
60struct ErrorDetail {
61 message: String,
62}
63
64impl AnthropicProvider {
65 pub fn new(
66 model: String,
67 api_key: Option<String>,
68 temperature: Option<f32>,
69 max_tokens: Option<u32>,
70 ) -> Result<Self> {
71 let api_key = api_key
72 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
73 .ok_or_else(|| {
74 Error::Provider("ANTHROPIC_API_KEY not set and no api_key provided".into())
75 })?;
76
77 let client = Client::new();
78 let pricing = get_model_pricing(&model);
79
80 Ok(Self {
81 client,
82 api_key,
83 model,
84 default_temperature: temperature,
85 default_max_tokens: max_tokens,
86 pricing,
87 })
88 }
89}
90
91#[async_trait]
92impl LLMProvider for AnthropicProvider {
93 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
94 let temperature = request.temperature.or(self.default_temperature);
95 let max_tokens = request.max_tokens.or(self.default_max_tokens).unwrap_or(4096);
96
97 let anthropic_request = AnthropicRequest {
98 model: self.model.clone(),
99 max_tokens,
100 messages: vec![Message {
101 role: "user".to_string(),
102 content: request.prompt,
103 }],
104 system: request.system_prompt,
105 temperature,
106 };
107
108 let response = self
109 .client
110 .post(ANTHROPIC_API_URL)
111 .header("x-api-key", &self.api_key)
112 .header("anthropic-version", ANTHROPIC_VERSION)
113 .header("content-type", "application/json")
114 .json(&anthropic_request)
115 .send()
116 .await
117 .map_err(|e| Error::Provider(format!("Request failed: {}", e)))?;
118
119 let status = response.status();
120 let body = response
121 .text()
122 .await
123 .map_err(|e| Error::Provider(format!("Failed to read response: {}", e)))?;
124
125 if !status.is_success() {
126 let error: AnthropicError = serde_json::from_str(&body)
127 .map_err(|_| Error::Provider(format!("API error ({}): {}", status, body)))?;
128 return Err(Error::Provider(format!(
129 "Anthropic API error: {}",
130 error.error.message
131 )));
132 }
133
134 let response: AnthropicResponse = serde_json::from_str(&body)
135 .map_err(|e| Error::Provider(format!("Failed to parse response: {}", e)))?;
136
137 let content = response
138 .content
139 .first()
140 .map(|c| c.text.clone())
141 .ok_or_else(|| Error::Provider("No response content".into()))?;
142
143 Ok(GenerationResponse {
144 content,
145 input_tokens: response.usage.input_tokens,
146 output_tokens: response.usage.output_tokens,
147 })
148 }
149
150 fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
151 let input_cost = (input_tokens as f64 / 1_000_000.0) * self.pricing.input_per_million;
152 let output_cost = (output_tokens as f64 / 1_000_000.0) * self.pricing.output_per_million;
153 input_cost + output_cost
154 }
155
156 fn name(&self) -> &str {
157 "anthropic"
158 }
159
160 fn model(&self) -> &str {
161 &self.model
162 }
163}
164
165fn get_model_pricing(model: &str) -> ModelPricing {
166 match model {
167 m if m.contains("claude-haiku-4-5-20251001") || m.contains("claude-4.6-sonnet") => ModelPricing {
168 input_per_million: 3.00,
169 output_per_million: 15.00,
170 },
171 m if m.contains("claude-3-5-haiku") || m.contains("claude-3.5-haiku") => ModelPricing {
172 input_per_million: 0.80,
173 output_per_million: 4.00,
174 },
175 m if m.contains("claude-3-opus") => ModelPricing {
176 input_per_million: 15.00,
177 output_per_million: 75.00,
178 },
179 m if m.contains("claude-haiku-4-5-20251001") => ModelPricing {
180 input_per_million: 3.00,
181 output_per_million: 15.00,
182 },
183 m if m.contains("claude-3-haiku") => ModelPricing {
184 input_per_million: 0.25,
185 output_per_million: 1.25,
186 },
187 _ => ModelPricing {
188 input_per_million: 3.00,
189 output_per_million: 15.00,
190 },
191 }
192}
193
194#[cfg(test)]
195mod tests {
196 use super::*;
197
198 #[test]
199 fn test_model_pricing_sonnet() {
200 let pricing = get_model_pricing("claude-haiku-4-5-20251001");
201 assert_eq!(pricing.input_per_million, 3.00);
202 assert_eq!(pricing.output_per_million, 15.00);
203 }
204
205 #[test]
206 fn test_model_pricing_haiku() {
207 let pricing = get_model_pricing("claude-3-5-haiku-20241022");
208 assert_eq!(pricing.input_per_million, 0.80);
209 assert_eq!(pricing.output_per_million, 4.00);
210 }
211
212 #[test]
213 fn test_model_pricing_opus() {
214 let pricing = get_model_pricing("claude-3-opus-20240229");
215 assert_eq!(pricing.input_per_million, 15.00);
216 assert_eq!(pricing.output_per_million, 75.00);
217 }
218
219 #[test]
220 fn test_cost_estimation() {
221 let provider = AnthropicProvider::new(
222 "claude-haiku-4-5-20251001".to_string(),
223 Some("test-key".to_string()),
224 None,
225 None,
226 )
227 .unwrap();
228
229 let cost = provider.estimate_cost(1_000_000, 1_000_000);
231 assert!((cost - 18.00).abs() < 0.001); }
233
234 #[test]
235 fn test_provider_name_and_model() {
236 let provider = AnthropicProvider::new(
237 "claude-haiku-4-5-20251001".to_string(),
238 Some("test-key".to_string()),
239 Some(0.7),
240 Some(2000),
241 )
242 .unwrap();
243
244 assert_eq!(provider.name(), "anthropic");
245 assert_eq!(provider.model(), "claude-haiku-4-5-20251001");
246 }
247
248 #[test]
249 fn test_missing_api_key_error() {
250 unsafe { std::env::remove_var("ANTHROPIC_API_KEY") };
251 let result = AnthropicProvider::new(
252 "claude-haiku-4-5-20251001".to_string(),
253 None,
254 None,
255 None,
256 );
257 assert!(result.is_err());
258 }
259}