ferro_ai/classifier/
anthropic.rs1use crate::error::Error;
2use async_trait::async_trait;
3
4use super::{ClassificationProvider, ClassifierConfig};
5
6pub struct AnthropicProvider {
16 client: reqwest::Client,
17 api_key: String,
18}
19
20impl AnthropicProvider {
21 pub fn new(api_key: String) -> Self {
25 let client = reqwest::Client::builder()
26 .timeout(std::time::Duration::from_secs(60))
27 .build()
28 .expect("failed to build reqwest client");
29 Self { client, api_key }
30 }
31
32 pub fn from_env() -> Result<Self, Error> {
34 let api_key = std::env::var("ANTHROPIC_API_KEY")
35 .map_err(|_| Error::Config("ANTHROPIC_API_KEY not set".to_string()))?;
36 Ok(Self::new(api_key))
37 }
38
39 pub(crate) fn build_request_body(
45 system_prompt: &str,
46 user_prompt: &str,
47 schema: &serde_json::Value,
48 config: &ClassifierConfig,
49 ) -> serde_json::Value {
50 serde_json::json!({
51 "model": config.model,
52 "max_tokens": config.max_tokens,
53 "system": [{
54 "type": "text",
55 "text": system_prompt,
56 "cache_control": {"type": "ephemeral"}
57 }],
58 "messages": [{"role": "user", "content": user_prompt}],
59 "output_config": {
60 "format": {
61 "type": "json_schema",
62 "schema": schema
63 }
64 }
65 })
66 }
67}
68
/// Returns `true` for HTTP statuses that will not succeed if retried unchanged.
///
/// Covers bad requests (400), authentication failures (401), permission
/// failures (403), missing resources (404), oversized requests (413) and
/// unprocessable entities (422).
pub(crate) fn is_permanent_error(status: u16) -> bool {
    matches!(status, 400 | 401 | 403 | 404 | 413 | 422)
}
75
/// Returns `true` for HTTP statuses that may succeed if the request is retried.
///
/// Covers rate limiting (429), internal server errors (500), gateway
/// failures (502, 504), service unavailable (503) and the API's
/// "overloaded" status (529).
pub(crate) fn is_transient_error(status: u16) -> bool {
    matches!(status, 429 | 500 | 502 | 503 | 504 | 529)
}
82
83#[async_trait]
84impl ClassificationProvider for AnthropicProvider {
85 async fn classify_raw(
86 &self,
87 system_prompt: &str,
88 user_prompt: &str,
89 schema: &serde_json::Value,
90 config: &ClassifierConfig,
91 ) -> Result<serde_json::Value, Error> {
92 let body = Self::build_request_body(system_prompt, user_prompt, schema, config);
93
94 let response = self
95 .client
96 .post("https://api.anthropic.com/v1/messages")
97 .header("x-api-key", &self.api_key)
98 .header("anthropic-version", "2023-06-01")
99 .header("content-type", "application/json")
100 .json(&body)
101 .send()
102 .await
103 .map_err(|e| {
104 if e.is_timeout() {
105 Error::Timeout
106 } else {
107 Error::Provider(format!("request failed: {e}"))
108 }
109 })?;
110
111 let status = response.status().as_u16();
112
113 if is_permanent_error(status) {
114 let text = response.text().await.unwrap_or_default();
115 return Err(Error::Provider(format!("{status} {text}")));
116 }
117
118 if is_transient_error(status) {
119 let text = response.text().await.unwrap_or_default();
120 return Err(Error::Provider(format!("{status} {text}")));
121 }
122
123 if !response.status().is_success() {
124 let text = response.text().await.unwrap_or_default();
125 return Err(Error::Provider(format!("{status} {text}")));
126 }
127
128 let json: serde_json::Value = response
129 .json()
130 .await
131 .map_err(|e| Error::Deserialization(e.to_string()))?;
132
133 let text = json["content"]
135 .as_array()
136 .and_then(|arr| arr.first())
137 .and_then(|item| item["text"].as_str())
138 .ok_or_else(|| {
139 Error::Deserialization(format!("unexpected response structure: {json}"))
140 })?;
141
142 serde_json::from_str(text).map_err(|e| Error::Deserialization(e.to_string()))
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `predicate` accepts every status in `accepted` and rejects
    /// every status in `rejected`.
    fn check_statuses(predicate: fn(u16) -> bool, accepted: &[u16], rejected: &[u16]) {
        for &code in accepted {
            assert!(predicate(code), "expected {code} to be accepted");
        }
        for &code in rejected {
            assert!(!predicate(code), "expected {code} to be rejected");
        }
    }

    #[test]
    fn test_is_permanent_error() {
        check_statuses(
            is_permanent_error,
            &[400, 401, 403, 404, 422],
            &[200, 429, 500, 503, 529],
        );
    }

    #[test]
    fn test_is_transient_error() {
        check_statuses(
            is_transient_error,
            &[429, 500, 503, 529],
            &[200, 400, 401, 422],
        );
    }

    #[test]
    fn test_build_request_body_contains_output_config() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "category": {"type": "string"}
            }
        });

        let body = AnthropicProvider::build_request_body(
            "You classify intents.",
            "Hello world",
            &schema,
            &ClassifierConfig::default(),
        );

        // Model settings come from the default config.
        assert_eq!(body["model"], "claude-sonnet-4-6");
        assert_eq!(body["max_tokens"], 1024);

        // The caller's schema is embedded verbatim as a json_schema format.
        let output_format = &body["output_config"]["format"];
        assert_eq!(output_format["type"], "json_schema");
        assert_eq!(output_format["schema"], schema);

        // The system prompt is a single cacheable text block.
        let system_block = &body["system"][0];
        assert_eq!(system_block["type"], "text");
        assert_eq!(system_block["text"], "You classify intents.");
        assert_eq!(system_block["cache_control"]["type"], "ephemeral");

        // The user prompt is the sole message.
        let first_message = &body["messages"][0];
        assert_eq!(first_message["role"], "user");
        assert_eq!(first_message["content"], "Hello world");
    }

    #[test]
    fn test_build_request_body_uses_config_model() {
        let custom = ClassifierConfig {
            model: "claude-opus-4-6".to_string(),
            max_tokens: 2048,
            ..Default::default()
        };
        let empty_schema = serde_json::json!({});

        let body =
            AnthropicProvider::build_request_body("system", "user", &empty_schema, &custom);

        assert_eq!(body["model"], "claude-opus-4-6");
        assert_eq!(body["max_tokens"], 2048);
    }
}