Skip to main content

ferro_ai/classifier/
anthropic.rs

1use crate::error::Error;
2use async_trait::async_trait;
3
4use super::{ClassificationProvider, ClassifierConfig};
5
6/// Anthropic API-based classification provider.
7///
8/// Uses the Anthropic Messages API with `output_config.format.type = "json_schema"`
9/// for guaranteed schema-compliant JSON output.
10///
11/// # Authentication
12///
13/// Requires an `ANTHROPIC_API_KEY` environment variable or an explicit API key
14/// passed to [`AnthropicProvider::new`].
15pub struct AnthropicProvider {
16    client: reqwest::Client,
17    api_key: String,
18}
19
20impl AnthropicProvider {
21    /// Create a new provider with an explicit API key.
22    ///
23    /// The internal `reqwest::Client` uses a 60-second timeout.
24    pub fn new(api_key: String) -> Self {
25        let client = reqwest::Client::builder()
26            .timeout(std::time::Duration::from_secs(60))
27            .build()
28            .expect("failed to build reqwest client");
29        Self { client, api_key }
30    }
31
32    /// Create a provider reading the API key from `ANTHROPIC_API_KEY`.
33    pub fn from_env() -> Result<Self, Error> {
34        let api_key = std::env::var("ANTHROPIC_API_KEY")
35            .map_err(|_| Error::Config("ANTHROPIC_API_KEY not set".to_string()))?;
36        Ok(Self::new(api_key))
37    }
38
39    /// Build the request body for the Anthropic Messages API.
40    ///
41    /// Uses `output_config.format.type = "json_schema"` for structured output.
42    /// The system prompt is cached with `cache_control.type = "ephemeral"` to
43    /// reduce token costs on repeated calls with the same system prompt.
44    pub(crate) fn build_request_body(
45        system_prompt: &str,
46        user_prompt: &str,
47        schema: &serde_json::Value,
48        config: &ClassifierConfig,
49    ) -> serde_json::Value {
50        serde_json::json!({
51            "model": config.model,
52            "max_tokens": config.max_tokens,
53            "system": [{
54                "type": "text",
55                "text": system_prompt,
56                "cache_control": {"type": "ephemeral"}
57            }],
58            "messages": [{"role": "user", "content": user_prompt}],
59            "output_config": {
60                "format": {
61                    "type": "json_schema",
62                    "schema": schema
63                }
64            }
65        })
66    }
67}
68
/// Returns `true` for HTTP status codes that indicate a permanent failure.
///
/// Permanent errors (400, 401, 403, 404, 413, 422) should not be retried,
/// because resending the identical request fails the same way. 413 is the
/// Anthropic API's `request_too_large` error: an oversized request can never
/// succeed on retry.
pub(crate) fn is_permanent_error(status: u16) -> bool {
    matches!(status, 400 | 401 | 403 | 404 | 413 | 422)
}
75
/// Returns `true` for HTTP status codes that indicate a transient failure.
///
/// Transient errors (429, 500, 502, 503, 504, 529) are safe to retry after a
/// delay. 429 is rate limiting and 529 is the Anthropic API's
/// `overloaded_error`. 502/504 are gateway failures from infrastructure in
/// front of the API and are conventionally retried like 500/503.
pub(crate) fn is_transient_error(status: u16) -> bool {
    matches!(status, 429 | 500 | 502 | 503 | 504 | 529)
}
82
83#[async_trait]
84impl ClassificationProvider for AnthropicProvider {
85    async fn classify_raw(
86        &self,
87        system_prompt: &str,
88        user_prompt: &str,
89        schema: &serde_json::Value,
90        config: &ClassifierConfig,
91    ) -> Result<serde_json::Value, Error> {
92        let body = Self::build_request_body(system_prompt, user_prompt, schema, config);
93
94        let response = self
95            .client
96            .post("https://api.anthropic.com/v1/messages")
97            .header("x-api-key", &self.api_key)
98            .header("anthropic-version", "2023-06-01")
99            .header("content-type", "application/json")
100            .json(&body)
101            .send()
102            .await
103            .map_err(|e| {
104                if e.is_timeout() {
105                    Error::Timeout
106                } else {
107                    Error::Provider(format!("request failed: {e}"))
108                }
109            })?;
110
111        let status = response.status().as_u16();
112
113        if is_permanent_error(status) {
114            let text = response.text().await.unwrap_or_default();
115            return Err(Error::Provider(format!("{status} {text}")));
116        }
117
118        if is_transient_error(status) {
119            let text = response.text().await.unwrap_or_default();
120            return Err(Error::Provider(format!("{status} {text}")));
121        }
122
123        if !response.status().is_success() {
124            let text = response.text().await.unwrap_or_default();
125            return Err(Error::Provider(format!("{status} {text}")));
126        }
127
128        let json: serde_json::Value = response
129            .json()
130            .await
131            .map_err(|e| Error::Deserialization(e.to_string()))?;
132
133        // Extract content[0].text from Anthropic response envelope
134        let text = json["content"]
135            .as_array()
136            .and_then(|arr| arr.first())
137            .and_then(|item| item["text"].as_str())
138            .ok_or_else(|| {
139                Error::Deserialization(format!("unexpected response structure: {json}"))
140            })?;
141
142        serde_json::from_str(text).map_err(|e| Error::Deserialization(e.to_string()))
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Status codes that must be classified as permanent failures.
    #[test]
    fn test_is_permanent_error() {
        let permanent = [400, 401, 403, 404, 422];
        let other = [200, 429, 500, 503, 529];

        for code in permanent {
            assert!(is_permanent_error(code), "expected {code} to be permanent");
        }
        for code in other {
            assert!(!is_permanent_error(code), "expected {code} not to be permanent");
        }
    }

    /// Status codes that must be classified as transient failures.
    #[test]
    fn test_is_transient_error() {
        let transient = [429, 500, 503, 529];
        let other = [200, 400, 401, 422];

        for code in transient {
            assert!(is_transient_error(code), "expected {code} to be transient");
        }
        for code in other {
            assert!(!is_transient_error(code), "expected {code} not to be transient");
        }
    }

    /// The request body carries the default config, the structured-output
    /// format, the cached system prompt and the single user message.
    #[test]
    fn test_build_request_body_contains_output_config() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "category": {"type": "string"}
            }
        });
        let body = AnthropicProvider::build_request_body(
            "You classify intents.",
            "Hello world",
            &schema,
            &ClassifierConfig::default(),
        );

        // Model and token budget come from ClassifierConfig::default().
        assert_eq!(body["model"], "claude-sonnet-4-6");
        assert_eq!(body["max_tokens"], 1024);

        // Structured output is requested via output_config.format.
        let output_format = &body["output_config"]["format"];
        assert_eq!(output_format["type"], "json_schema");
        assert_eq!(output_format["schema"], schema);

        // The system prompt is a single cacheable text block.
        let system_block = &body["system"][0];
        assert_eq!(system_block["type"], "text");
        assert_eq!(system_block["text"], "You classify intents.");
        assert_eq!(system_block["cache_control"]["type"], "ephemeral");

        // The user prompt is the first (and only) message.
        let first_message = &body["messages"][0];
        assert_eq!(first_message["role"], "user");
        assert_eq!(first_message["content"], "Hello world");
    }

    /// Non-default model and max_tokens are copied into the body unchanged.
    #[test]
    fn test_build_request_body_uses_config_model() {
        let custom = ClassifierConfig {
            model: "claude-opus-4-6".to_string(),
            max_tokens: 2048,
            ..Default::default()
        };
        let empty_schema = serde_json::json!({});

        let body = AnthropicProvider::build_request_body("system", "user", &empty_schema, &custom);

        assert_eq!(body["model"], "claude-opus-4-6");
        assert_eq!(body["max_tokens"], 2048);
    }
}