// evolve_llm/anthropic.rs
1//! Anthropic Messages API client, Haiku-only.
2
3use crate::client::{CompletionResult, LlmClient, TokenUsage};
4use crate::error::LlmError;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
// Production Anthropic API base URL; tests swap it out via `with_endpoint`.
const DEFAULT_ENDPOINT: &str = "https://api.anthropic.com";
// Pinned Haiku model snapshot sent with every request.
const DEFAULT_MODEL: &str = "claude-haiku-4-5-20251001";
// Fixed pause before the single retry on transient failures (429/5xx/connect).
const RETRY_DELAY: Duration = Duration::from_millis(500);
12
/// Minimal client for Anthropic's Messages API, wired for Haiku.
#[derive(Debug, Clone)]
pub struct AnthropicHaikuClient {
    api_key: String,       // sent verbatim as the `x-api-key` header
    endpoint: String,      // base URL; `/v1/messages` is appended per request
    model: String,         // model id serialized into each request body
    http: reqwest::Client, // shared connection pool; cheap to clone
}
21
22impl AnthropicHaikuClient {
23    /// Build from the `ANTHROPIC_API_KEY` env var, using the production endpoint.
24    pub fn from_env() -> Result<Self, LlmError> {
25        let key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| LlmError::NoApiKey)?;
26        Ok(Self::with_endpoint(key, DEFAULT_ENDPOINT))
27    }
28
29    /// Construct with a caller-supplied endpoint (used by cassette tests).
30    pub fn with_endpoint(api_key: impl Into<String>, endpoint: impl Into<String>) -> Self {
31        Self {
32            api_key: api_key.into(),
33            endpoint: endpoint.into(),
34            model: DEFAULT_MODEL.to_string(),
35            http: reqwest::Client::new(),
36        }
37    }
38}
39
/// JSON body for `POST /v1/messages`; borrows from the caller to avoid
/// per-request copies of the model id and prompt.
#[derive(Serialize)]
struct MessagesRequest<'a> {
    model: &'a str,
    max_tokens: u32,
    messages: Vec<Message<'a>>,
}
46
/// A single chat turn; this client only ever sends one `"user"` message.
#[derive(Serialize)]
struct Message<'a> {
    role: &'a str,
    content: &'a str,
}
52
/// Subset of the Messages API response this client consumes; all other
/// response fields are ignored by serde.
#[derive(Deserialize)]
struct MessagesResponse {
    content: Vec<ContentBlock>,
    // `usage` may be missing from the payload; treated as zero tokens downstream.
    #[serde(default)]
    usage: Option<UsageReport>,
}
59
/// One entry of the response `content` array.
#[derive(Deserialize)]
struct ContentBlock {
    // Discriminator value from the wire (`"text"`, etc.); renamed because
    // `type` is a Rust keyword.
    #[serde(rename = "type")]
    kind: String,
    // Present only on text blocks; `None` for other block kinds.
    text: Option<String>,
}
66
/// Token accounting as reported by the API; fields default to 0 when absent.
#[derive(Deserialize)]
struct UsageReport {
    #[serde(default)]
    input_tokens: u32,
    #[serde(default)]
    output_tokens: u32,
}
74
75#[async_trait]
76impl LlmClient for AnthropicHaikuClient {
77    async fn complete(&self, prompt: &str, max_tokens: u32) -> Result<CompletionResult, LlmError> {
78        let url = format!("{}/v1/messages", self.endpoint);
79        let body = MessagesRequest {
80            model: &self.model,
81            max_tokens,
82            messages: vec![Message {
83                role: "user",
84                content: prompt,
85            }],
86        };
87
88        for attempt in 0..=1 {
89            let result = self
90                .http
91                .post(&url)
92                .header("x-api-key", &self.api_key)
93                .header("anthropic-version", "2023-06-01")
94                .json(&body)
95                .send()
96                .await;
97
98            match result {
99                Ok(resp) => {
100                    let status = resp.status();
101                    if status.is_success() {
102                        let raw = resp.text().await?;
103                        let parsed: MessagesResponse = serde_json::from_str(&raw)?;
104                        let text = parsed
105                            .content
106                            .into_iter()
107                            .filter(|b| b.kind == "text")
108                            .filter_map(|b| b.text)
109                            .collect::<Vec<_>>()
110                            .join("\n");
111                        let usage = parsed
112                            .usage
113                            .map(|u| TokenUsage {
114                                input: u.input_tokens,
115                                output: u.output_tokens,
116                            })
117                            .unwrap_or_default();
118                        return Ok(CompletionResult { text, usage });
119                    }
120                    let retryable = status.as_u16() == 429 || status.is_server_error();
121                    if retryable && attempt == 0 {
122                        tokio::time::sleep(RETRY_DELAY).await;
123                        continue;
124                    }
125                    let body = resp.text().await.unwrap_or_default();
126                    let snippet = if body.len() > 512 {
127                        &body[..512]
128                    } else {
129                        &body
130                    };
131                    return Err(LlmError::UnexpectedStatus {
132                        status: status.as_u16(),
133                        body: snippet.to_string(),
134                    });
135                }
136                Err(e) if attempt == 0 && e.is_connect() => {
137                    tokio::time::sleep(RETRY_DELAY).await;
138                    continue;
139                }
140                Err(e) => return Err(LlmError::Http(e)),
141            }
142        }
143        unreachable!("retry loop exits via return")
144    }
145
146    fn model_id(&self) -> &str {
147        &self.model
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Canonical successful Messages API payload: one text block plus usage,
    // mirroring the fields the client actually deserializes.
    const SAMPLE_RESPONSE: &str = r#"{
        "id": "msg_01",
        "type": "message",
        "role": "assistant",
        "model": "claude-haiku-4-5-20251001",
        "content": [{"type":"text","text":"Hello from mock Haiku"}],
        "usage": {"input_tokens": 12, "output_tokens": 5}
    }"#;

    // A 200 with a well-formed body yields the joined text and exact usage.
    // The header matchers also pin the auth and version headers on the wire.
    #[tokio::test]
    async fn happy_path_parses_text_and_usage() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header("x-api-key", "test-key"))
            .and(header("anthropic-version", "2023-06-01"))
            .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
            .mount(&server)
            .await;

        let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
        let got = client.complete("hi", 16).await.unwrap();
        assert_eq!(got.text, "Hello from mock Haiku");
        assert_eq!(got.usage.input, 12);
        assert_eq!(got.usage.output, 5);
    }

    // First request hits a 503, the retry succeeds. Mount order matters:
    // the `up_to_n_times(1)` 503 mock must be registered before the 200 mock
    // so it consumes only the first request.
    #[tokio::test]
    async fn retries_once_on_5xx_then_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(503))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
            .mount(&server)
            .await;

        let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
        let got = client.complete("hi", 16).await.unwrap();
        assert_eq!(got.text, "Hello from mock Haiku");
    }

    // Both attempts see a 503; the client surfaces the status after one retry.
    #[tokio::test]
    async fn gives_up_after_second_5xx() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&server)
            .await;

        let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
        let err = client.complete("hi", 16).await.unwrap_err();
        assert!(matches!(
            err,
            LlmError::UnexpectedStatus { status: 503, .. }
        ));
    }

    // A 400 is not retryable: `.expect(1)` verifies exactly one request is
    // sent and the error carries the status straight back.
    #[tokio::test]
    async fn does_not_retry_on_4xx_except_429() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(400).set_body_string(r#"{"error":"bad"}"#))
            .expect(1)
            .mount(&server)
            .await;

        let client = AnthropicHaikuClient::with_endpoint("test-key", server.uri());
        let err = client.complete("hi", 16).await.unwrap_err();
        assert!(matches!(
            err,
            LlmError::UnexpectedStatus { status: 400, .. }
        ));
    }
}