mur_core/model/
provider.rs1use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5
/// Result of a single completion request.
///
/// Keep the field order stable: it determines the key order produced by the
/// `Serialize` derive and the field order printed by `Debug`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Text of the reply: every text content block, concatenated in order.
    pub content: String,
    /// Model id reported by the API in its response (not necessarily the id requested).
    pub model: String,
    /// Input-token count from the response's `usage` object.
    pub input_tokens: u32,
    /// Output-token count from the response's `usage` object.
    pub output_tokens: u32,
    /// Estimated cost in USD, derived from the token counts by `estimate_cost`.
    pub cost: f64,
}
15
16pub struct AnthropicProvider {
18 api_key: String,
19 client: reqwest::Client,
20}
21
/// JSON body for `POST /v1/messages`.
///
/// Field order determines the key order of the serialized JSON; keep it stable.
#[derive(Serialize)]
struct AnthropicRequest {
    /// Model id to run, e.g. `claude-sonnet-4-20250514`.
    model: String,
    /// Maximum number of tokens the model may generate.
    max_tokens: u32,
    /// Sampling temperature; the key is left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Conversation turns sent to the model, oldest first.
    messages: Vec<Message>,
}

/// One conversation turn inside an [`AnthropicRequest`].
#[derive(Serialize)]
struct Message {
    /// Speaker of the turn (this provider only sends `"user"`).
    role: String,
    /// Plain-text body of the turn.
    content: String,
}
36
/// Successful response body from `POST /v1/messages`.
///
/// Only the fields this provider reads are declared; serde ignores the rest.
#[derive(Deserialize)]
struct AnthropicResponse {
    /// Content blocks of the reply, in order.
    content: Vec<ContentBlock>,
    /// Model id that actually served the request.
    model: String,
    /// Token accounting for the request.
    usage: Usage,
}

/// Error envelope returned alongside non-success HTTP statuses.
///
/// `error` is optional so a JSON body without it still parses (the caller
/// then falls back to reporting the raw body).
#[derive(Deserialize)]
struct AnthropicErrorResponse {
    error: Option<AnthropicErrorDetail>,
}

/// Category and description of an API error; both optional so partial error
/// bodies still parse.
#[derive(Deserialize)]
struct AnthropicErrorDetail {
    /// Error category, e.g. `authentication_error` or `rate_limit_error`.
    /// Renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    error_type: Option<String>,
    /// Human-readable description of the error.
    message: Option<String>,
}

/// One block of reply content. Blocks that carry no `text` field (non-text
/// block kinds) deserialize with `text: None`.
#[derive(Deserialize)]
struct ContentBlock {
    text: Option<String>,
}

/// Token counts from the response's `usage` object.
#[derive(Deserialize)]
struct Usage {
    input_tokens: u32,
    output_tokens: u32,
}
66
67impl AnthropicProvider {
68 pub fn new(api_key: Option<String>) -> Result<Self> {
70 let key = api_key
71 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
72 .context("ANTHROPIC_API_KEY not set")?;
73
74 Ok(Self {
75 api_key: key,
76 client: reqwest::Client::new(),
77 })
78 }
79
80 pub async fn complete(
82 &self,
83 prompt: &str,
84 model: &str,
85 temperature: f64,
86 max_tokens: u32,
87 ) -> Result<ModelResponse> {
88 let request = AnthropicRequest {
89 model: model.to_string(),
90 max_tokens: if max_tokens > 0 { max_tokens } else { 4096 },
91 temperature: Some(temperature),
92 messages: vec![Message {
93 role: "user".into(),
94 content: prompt.to_string(),
95 }],
96 };
97
98 let response = self
99 .client
100 .post("https://api.anthropic.com/v1/messages")
101 .header("x-api-key", &self.api_key)
102 .header("anthropic-version", "2023-06-01")
103 .header("content-type", "application/json")
104 .json(&request)
105 .send()
106 .await
107 .context("Failed to connect to Anthropic API")?;
108
109 if !response.status().is_success() {
110 let status = response.status();
111 let body = response.text().await.unwrap_or_default();
112
113 if let Ok(err) = serde_json::from_str::<AnthropicErrorResponse>(&body) {
115 if let Some(detail) = err.error {
116 let error_type = detail.error_type.as_deref().unwrap_or("unknown");
117 let message = detail.message.as_deref().unwrap_or("Unknown error");
118 match error_type {
119 "authentication_error" => {
120 anyhow::bail!(
121 "Anthropic authentication failed: {}. Check your ANTHROPIC_API_KEY",
122 message
123 );
124 }
125 "rate_limit_error" => {
126 anyhow::bail!(
127 "Anthropic rate limit exceeded: {}. Retry after a moment",
128 message
129 );
130 }
131 "overloaded_error" => {
132 anyhow::bail!(
133 "Anthropic API overloaded: {}. Retry after a moment",
134 message
135 );
136 }
137 _ => {
138 anyhow::bail!(
139 "Anthropic API error {} ({}): {}",
140 status,
141 error_type,
142 message
143 );
144 }
145 }
146 }
147 }
148
149 anyhow::bail!("Anthropic API error {}: {}", status, body);
150 }
151
152 let body: AnthropicResponse = response
153 .json()
154 .await
155 .context("Parsing Anthropic response")?;
156
157 let content = body
158 .content
159 .into_iter()
160 .filter_map(|b| b.text)
161 .collect::<Vec<_>>()
162 .join("");
163
164 let cost = estimate_cost(&body.model, body.usage.input_tokens, body.usage.output_tokens);
165
166 Ok(ModelResponse {
167 content,
168 model: body.model,
169 input_tokens: body.usage.input_tokens,
170 output_tokens: body.usage.output_tokens,
171 cost,
172 })
173 }
174}
175
/// Estimates the USD cost of a request from its token counts.
///
/// Prices are list prices in dollars per million tokens, chosen by the model
/// family named in the model id.
/// - More specific patterns are checked first (e.g. `opus-4-5` before
///   `opus`, and each Haiku generation before the generic `haiku`).
/// - Ids that match no family are priced like Sonnet.
///
/// This is only an estimate: only plain input/output tokens are priced, so
/// prompt-caching and batch discounts are not reflected. The table mirrors
/// Anthropic's published prices at the time of writing; re-check it against
/// the pricing page when new models ship.
fn estimate_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    let (input_price, output_price) = match model {
        // Opus 4.5 is priced well below earlier Opus models.
        m if m.contains("opus-4-5") => (5.0, 25.0),
        m if m.contains("opus") => (15.0, 75.0),
        m if m.contains("sonnet") => (3.0, 15.0),
        // Haiku prices differ several-fold between generations.
        m if m.contains("claude-3-haiku") => (0.25, 1.25),
        m if m.contains("3-5-haiku") => (0.80, 4.0),
        m if m.contains("haiku") => (1.0, 5.0),
        // Unknown family: assume Sonnet pricing.
        _ => (3.0, 15.0),
    };

    let per_million = 1_000_000.0;
    f64::from(input_tokens) / per_million * input_price
        + f64::from(output_tokens) / per_million * output_price
}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `estimate_cost` returns `expected` dollars, up to float rounding.
    fn assert_cost(model: &str, input_tokens: u32, output_tokens: u32, expected: f64) {
        let actual = estimate_cost(model, input_tokens, output_tokens);
        assert!(
            (actual - expected).abs() < 1e-9,
            "{}: expected {}, got {}",
            model,
            expected,
            actual
        );
    }

    #[test]
    fn test_cost_estimation() {
        let cost = estimate_cost("claude-sonnet-4-20250514", 1000, 500);
        assert!(cost > 0.0);
        assert!(cost < 0.1);

        let opus_cost = estimate_cost("claude-opus-4-20250514", 1000, 500);
        assert!(opus_cost > cost);
    }

    #[test]
    fn test_cost_uses_per_million_token_prices() {
        // 1000 * $3/M + 500 * $15/M
        assert_cost("claude-sonnet-4-20250514", 1000, 500, 0.0105);
        assert_cost("claude-sonnet-4-20250514", 1_000_000, 1_000_000, 18.0);
        assert_cost("claude-opus-4-20250514", 1_000_000, 1_000_000, 90.0);
        assert_cost("claude-3-haiku-20240307", 1_000_000, 1_000_000, 1.5);
    }

    #[test]
    fn test_cost_edge_cases() {
        assert_cost("claude-sonnet-4-20250514", 0, 0, 0.0);
        // Ids naming no known family fall back to Sonnet pricing.
        assert_cost("unknown-model", 1_000_000, 1_000_000, 18.0);
    }
}