1use crate::client::{CompletionResult, LlmClient, TokenUsage};
4use crate::error::LlmError;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
/// Base URL of the public Anthropic API; used by `AnthropicHaikuClient::from_env`.
const DEFAULT_ENDPOINT: &str = "https://api.anthropic.com";
/// Model identifier that every client is constructed with.
const DEFAULT_MODEL: &str = "claude-haiku-4-5-20251001";
/// Pause before the single retry of a failed request (half a second).
const RETRY_DELAY: Duration = Duration::from_micros(500_000);
12
13#[derive(Debug, Clone)]
15pub struct AnthropicHaikuClient {
16 api_key: String,
17 endpoint: String,
18 model: String,
19 http: reqwest::Client,
20}
21
22impl AnthropicHaikuClient {
23 pub fn from_env() -> Result<Self, LlmError> {
25 let key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| LlmError::NoApiKey)?;
26 Ok(Self::with_endpoint(key, DEFAULT_ENDPOINT))
27 }
28
29 pub fn with_endpoint(api_key: impl Into<String>, endpoint: impl Into<String>) -> Self {
31 Self {
32 api_key: api_key.into(),
33 endpoint: endpoint.into(),
34 model: DEFAULT_MODEL.to_string(),
35 http: reqwest::Client::new(),
36 }
37 }
38}
39
/// JSON body POSTed to `{endpoint}/v1/messages` by `complete`.
///
/// Field order and names here determine the serialized request.
#[derive(Serialize)]
struct MessagesRequest<'a> {
    /// Model identifier; `complete` fills it from the client's `model` field.
    model: &'a str,
    /// Forwarded unchanged from the `max_tokens` argument of `complete`.
    max_tokens: u32,
    /// Conversation turns; `complete` always sends exactly one user message.
    messages: Vec<Message<'a>>,
}
46
/// A single conversation turn inside [`MessagesRequest`].
#[derive(Serialize)]
struct Message<'a> {
    /// Speaker role; `complete` always sends `"user"`.
    role: &'a str,
    /// The message text (the caller's prompt).
    content: &'a str,
}
52
/// The subset of a successful Messages API response that this client reads.
///
/// Other top-level fields (e.g. `id`, `role`, `model` in the test fixture) are
/// ignored during deserialization.
#[derive(Deserialize)]
struct MessagesResponse {
    /// Content blocks in response order; only `"text"` blocks are used.
    content: Vec<ContentBlock>,
    /// Token accounting. `None` when absent, which `complete` maps to
    /// `TokenUsage::default()`.
    #[serde(default)]
    usage: Option<UsageReport>,
}
59
/// One element of [`MessagesResponse::content`].
#[derive(Deserialize)]
struct ContentBlock {
    /// The block's JSON `"type"` field; `complete` keeps only blocks where this is `"text"`.
    #[serde(rename = "type")]
    kind: String,
    /// Block text. Optional because non-text blocks presumably carry no `text` field.
    text: Option<String>,
}
66
/// Token counts reported in a response's `usage` object.
///
/// Each count defaults to 0 when its key is missing. Other keys are ignored.
#[derive(Deserialize)]
struct UsageReport {
    /// Mapped to `TokenUsage::input`.
    #[serde(default)]
    input_tokens: u32,
    /// Mapped to `TokenUsage::output`.
    #[serde(default)]
    output_tokens: u32,
}
74
75#[async_trait]
76impl LlmClient for AnthropicHaikuClient {
77 async fn complete(&self, prompt: &str, max_tokens: u32) -> Result<CompletionResult, LlmError> {
78 let url = format!("{}/v1/messages", self.endpoint);
79 let body = MessagesRequest {
80 model: &self.model,
81 max_tokens,
82 messages: vec![Message {
83 role: "user",
84 content: prompt,
85 }],
86 };
87
88 for attempt in 0..=1 {
89 let result = self
90 .http
91 .post(&url)
92 .header("x-api-key", &self.api_key)
93 .header("anthropic-version", "2023-06-01")
94 .json(&body)
95 .send()
96 .await;
97
98 match result {
99 Ok(resp) => {
100 let status = resp.status();
101 if status.is_success() {
102 let raw = resp.text().await?;
103 let parsed: MessagesResponse = serde_json::from_str(&raw)?;
104 let text = parsed
105 .content
106 .into_iter()
107 .filter(|b| b.kind == "text")
108 .filter_map(|b| b.text)
109 .collect::<Vec<_>>()
110 .join("\n");
111 let usage = parsed
112 .usage
113 .map(|u| TokenUsage {
114 input: u.input_tokens,
115 output: u.output_tokens,
116 })
117 .unwrap_or_default();
118 return Ok(CompletionResult { text, usage });
119 }
120 let retryable = status.as_u16() == 429 || status.is_server_error();
121 if retryable && attempt == 0 {
122 tokio::time::sleep(RETRY_DELAY).await;
123 continue;
124 }
125 let body = resp.text().await.unwrap_or_default();
126 let snippet = if body.len() > 512 {
127 &body[..512]
128 } else {
129 &body
130 };
131 return Err(LlmError::UnexpectedStatus {
132 status: status.as_u16(),
133 body: snippet.to_string(),
134 });
135 }
136 Err(e) if attempt == 0 && e.is_connect() => {
137 tokio::time::sleep(RETRY_DELAY).await;
138 continue;
139 }
140 Err(e) => return Err(LlmError::Http(e)),
141 }
142 }
143 unreachable!("retry loop exits via return")
144 }
145
146 fn model_id(&self) -> &str {
147 &self.model
148 }
149}
150
151#[cfg(test)]
152mod tests {
153 use super::*;
154 use wiremock::matchers::{header, method, path};
155 use wiremock::{Mock, MockServer, ResponseTemplate};
156
157 const SAMPLE_RESPONSE: &str = r#"{
158 "id": "msg_01",
159 "type": "message",
160 "role": "assistant",
161 "model": "claude-haiku-4-5-20251001",
162 "content": [{"type":"text","text":"Hello from mock Haiku"}],
163 "usage": {"input_tokens": 12, "output_tokens": 5}
164 }"#;
165
166 #[tokio::test]
167 async fn happy_path_parses_text_and_usage() {
168 let server = MockServer::start().await;
169 Mock::given(method("POST"))
170 .and(path("/v1/messages"))
171 .and(header("x-api-key", "test-key"))
172 .and(header("anthropic-version", "2023-06-01"))
173 .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
174 .mount(&server)
175 .await;
176
177 let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
178 let got = client.complete("hi", 16).await.unwrap();
179 assert_eq!(got.text, "Hello from mock Haiku");
180 assert_eq!(got.usage.input, 12);
181 assert_eq!(got.usage.output, 5);
182 }
183
184 #[tokio::test]
185 async fn retries_once_on_5xx_then_succeeds() {
186 let server = MockServer::start().await;
187 Mock::given(method("POST"))
188 .and(path("/v1/messages"))
189 .respond_with(ResponseTemplate::new(503))
190 .up_to_n_times(1)
191 .mount(&server)
192 .await;
193 Mock::given(method("POST"))
194 .and(path("/v1/messages"))
195 .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
196 .mount(&server)
197 .await;
198
199 let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
200 let got = client.complete("hi", 16).await.unwrap();
201 assert_eq!(got.text, "Hello from mock Haiku");
202 }
203
204 #[tokio::test]
205 async fn gives_up_after_second_5xx() {
206 let server = MockServer::start().await;
207 Mock::given(method("POST"))
208 .and(path("/v1/messages"))
209 .respond_with(ResponseTemplate::new(503))
210 .mount(&server)
211 .await;
212
213 let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
214 let err = client.complete("hi", 16).await.unwrap_err();
215 assert!(matches!(
216 err,
217 LlmError::UnexpectedStatus { status: 503, .. }
218 ));
219 }
220
221 #[tokio::test]
222 async fn does_not_retry_on_4xx_except_429() {
223 let server = MockServer::start().await;
224 Mock::given(method("POST"))
225 .and(path("/v1/messages"))
226 .respond_with(ResponseTemplate::new(400).set_body_string(r#"{"error":"bad"}"#))
227 .expect(1)
228 .mount(&server)
229 .await;
230
231 let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
232 let err = client.complete("hi", 16).await.unwrap_err();
233 assert!(matches!(
234 err,
235 LlmError::UnexpectedStatus { status: 400, .. }
236 ));
237 }
238}