// mur_core/model/provider.rs
//! Model provider — Anthropic API client for AI-powered workflow steps.

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

/// Response from a model completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Concatenated text of all text content blocks returned by the API.
    pub content: String,
    /// Model identifier as reported back by the API.
    pub model: String,
    /// Prompt tokens consumed, as reported in the API `usage` object.
    pub input_tokens: u32,
    /// Completion tokens generated, as reported in the API `usage` object.
    pub output_tokens: u32,
    /// Estimated cost in USD, computed locally by `estimate_cost`.
    pub cost: f64,
}

/// Anthropic Messages API provider.
///
/// Holds the API key and a reusable HTTP client for calls to
/// `https://api.anthropic.com/v1/messages`.
pub struct AnthropicProvider {
    // Secret key, sent on each request via the `x-api-key` header.
    api_key: String,
    // Reused across requests rather than rebuilt per call.
    client: reqwest::Client,
}

/// Request body for the Anthropic Messages API (`POST /v1/messages`).
#[derive(Serialize)]
struct AnthropicRequest {
    // Model identifier, e.g. "claude-sonnet-4-20250514".
    model: String,
    // Upper bound on generated tokens; `complete` substitutes 4096 when
    // the caller passes 0.
    max_tokens: u32,
    // Omitted from the JSON entirely when `None`, so the API default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Conversation turns; `complete` sends a single "user" message.
    messages: Vec<Message>,
}

/// A single chat turn in an Anthropic request.
#[derive(Serialize)]
struct Message {
    // Role string; `complete` always sends "user".
    role: String,
    content: String,
}

/// Successful response body from the Messages API (only the fields we read).
#[derive(Deserialize)]
struct AnthropicResponse {
    // Content blocks; `complete` concatenates the text of each block.
    content: Vec<ContentBlock>,
    // Model identifier reported by the API, echoed into `ModelResponse`.
    model: String,
    // Token accounting used for cost estimation.
    usage: Usage,
}

/// Error envelope the API returns on non-success HTTP statuses.
#[derive(Deserialize)]
struct AnthropicErrorResponse {
    error: Option<AnthropicErrorDetail>,
}

/// Structured detail of an API error, used to produce targeted messages.
#[derive(Deserialize)]
struct AnthropicErrorDetail {
    // The JSON field is "type", a reserved word in Rust, hence the rename.
    #[serde(rename = "type")]
    error_type: Option<String>,
    message: Option<String>,
}

/// One block of response content. Blocks without a `text` field
/// deserialize as `None` and are skipped when assembling the reply.
#[derive(Deserialize)]
struct ContentBlock {
    text: Option<String>,
}

/// Token accounting reported by the API alongside each response.
#[derive(Deserialize)]
struct Usage {
    input_tokens: u32,
    output_tokens: u32,
}

67impl AnthropicProvider {
68    /// Create a new provider. Reads API key from env if not provided.
69    pub fn new(api_key: Option<String>) -> Result<Self> {
70        let key = api_key
71            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
72            .context("ANTHROPIC_API_KEY not set")?;
73
74        Ok(Self {
75            api_key: key,
76            client: reqwest::Client::new(),
77        })
78    }
79
80    /// Send a completion request with temperature and max_tokens.
81    pub async fn complete(
82        &self,
83        prompt: &str,
84        model: &str,
85        temperature: f64,
86        max_tokens: u32,
87    ) -> Result<ModelResponse> {
88        let request = AnthropicRequest {
89            model: model.to_string(),
90            max_tokens: if max_tokens > 0 { max_tokens } else { 4096 },
91            temperature: Some(temperature),
92            messages: vec![Message {
93                role: "user".into(),
94                content: prompt.to_string(),
95            }],
96        };
97
98        let response = self
99            .client
100            .post("https://api.anthropic.com/v1/messages")
101            .header("x-api-key", &self.api_key)
102            .header("anthropic-version", "2023-06-01")
103            .header("content-type", "application/json")
104            .json(&request)
105            .send()
106            .await
107            .context("Failed to connect to Anthropic API")?;
108
109        if !response.status().is_success() {
110            let status = response.status();
111            let body = response.text().await.unwrap_or_default();
112
113            // Parse structured error for better messages
114            if let Ok(err) = serde_json::from_str::<AnthropicErrorResponse>(&body) {
115                if let Some(detail) = err.error {
116                    let error_type = detail.error_type.as_deref().unwrap_or("unknown");
117                    let message = detail.message.as_deref().unwrap_or("Unknown error");
118                    match error_type {
119                        "authentication_error" => {
120                            anyhow::bail!(
121                                "Anthropic authentication failed: {}. Check your ANTHROPIC_API_KEY",
122                                message
123                            );
124                        }
125                        "rate_limit_error" => {
126                            anyhow::bail!(
127                                "Anthropic rate limit exceeded: {}. Retry after a moment",
128                                message
129                            );
130                        }
131                        "overloaded_error" => {
132                            anyhow::bail!(
133                                "Anthropic API overloaded: {}. Retry after a moment",
134                                message
135                            );
136                        }
137                        _ => {
138                            anyhow::bail!(
139                                "Anthropic API error {} ({}): {}",
140                                status,
141                                error_type,
142                                message
143                            );
144                        }
145                    }
146                }
147            }
148
149            anyhow::bail!("Anthropic API error {}: {}", status, body);
150        }
151
152        let body: AnthropicResponse = response
153            .json()
154            .await
155            .context("Parsing Anthropic response")?;
156
157        let content = body
158            .content
159            .into_iter()
160            .filter_map(|b| b.text)
161            .collect::<Vec<_>>()
162            .join("");
163
164        let cost = estimate_cost(&body.model, body.usage.input_tokens, body.usage.output_tokens);
165
166        Ok(ModelResponse {
167            content,
168            model: body.model,
169            input_tokens: body.usage.input_tokens,
170            output_tokens: body.usage.output_tokens,
171            cost,
172        })
173    }
174}
175
/// Estimate the USD cost of a request from per-million-token pricing.
///
/// Model families are matched by substring; unrecognized models are
/// billed at sonnet rates.
fn estimate_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    // (input, output) prices in USD per 1M tokens.
    let (input_price, output_price) = if model.contains("opus") {
        (15.0, 75.0)
    } else if model.contains("sonnet") {
        (3.0, 15.0)
    } else if model.contains("haiku") {
        (0.25, 1.25)
    } else {
        // Default to sonnet pricing for unknown models.
        (3.0, 15.0)
    };

    let input_cost = input_tokens as f64 / 1_000_000.0 * input_price;
    let output_cost = output_tokens as f64 / 1_000_000.0 * output_price;
    input_cost + output_cost
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Pins exact expected costs instead of the previous loose bounds
    /// (`> 0.0`, `< 0.1`), so a pricing-table regression is caught.
    #[test]
    fn test_cost_estimation() {
        // 1000 input + 500 output on sonnet: 1000/1M*3.0 + 500/1M*15.0.
        let cost = estimate_cost("claude-sonnet-4-20250514", 1000, 500);
        assert!((cost - 0.0105).abs() < 1e-12);

        // Opus is the most expensive tier for the same usage.
        let opus_cost = estimate_cost("claude-opus-4-20250514", 1000, 500);
        assert!((opus_cost - 0.0525).abs() < 1e-12);
        assert!(opus_cost > cost);

        // Haiku is the cheapest tier.
        let haiku_cost = estimate_cost("claude-3-5-haiku", 1000, 500);
        assert!((haiku_cost - 0.000875).abs() < 1e-12);
        assert!(haiku_cost < cost);

        // Unknown models fall back to sonnet pricing.
        assert_eq!(estimate_cost("mystery-model", 1000, 500), cost);

        // Zero usage costs nothing.
        assert_eq!(estimate_cost("claude-sonnet-4-20250514", 0, 0), 0.0);
    }
}