1use anyhow::Context;
6use haki_config::ProviderConfig;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14 User,
15 Assistant,
16 System,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Message {
21 pub role: Role,
22 pub content: String,
23}
24
25impl Message {
26 pub fn user(content: impl Into<String>) -> Self {
27 Self { role: Role::User, content: content.into() }
28 }
29 pub fn assistant(content: impl Into<String>) -> Self {
30 Self { role: Role::Assistant, content: content.into() }
31 }
32 pub fn system(content: impl Into<String>) -> Self {
33 Self { role: Role::System, content: content.into() }
34 }
35}
36
/// Token counters reported for one completion call.
///
/// All counters default to zero. The type is comparable (`PartialEq`/`Eq`) so
/// usage values can be checked directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    pub input_tokens: u64,
    /// Output (completion) tokens.
    pub output_tokens: u64,
    /// Tokens read from the provider's cache, when the provider reports them.
    pub cache_read_tokens: u64,
    /// Tokens written to the provider's cache, when the provider reports them.
    pub cache_write_tokens: u64,
}
44
45#[derive(Debug, Clone)]
46pub struct CompletionRequest {
47 pub model: String,
48 pub system: Option<String>,
49 pub messages: Vec<Message>,
50 pub max_tokens: u32,
51}
52
53#[derive(Debug, Clone)]
54pub struct CompletionResponse {
55 pub content: String,
56 pub usage: TokenUsage,
57}
58
/// Backend used to serve completion requests.
///
/// `Debug` is implemented by hand rather than derived so that `api_key` is
/// never written to logs, panic messages, or error chains.
#[derive(Clone)]
pub enum LlmProvider {
    /// Anthropic backend, addressed at `base_url` and authenticated with `api_key`.
    Anthropic { api_key: String, base_url: String },
    /// OpenAI backend, addressed at `base_url` and authenticated with `api_key`.
    OpenAi { api_key: String, base_url: String },
    /// Offline backend that always answers with the canned `response`.
    Mock { response: String },
}

impl std::fmt::Debug for LlmProvider {
    /// Formats like the derived `Debug`, except the API key is replaced by a
    /// fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Anthropic { base_url, .. } => f
                .debug_struct("Anthropic")
                .field("api_key", &REDACTED)
                .field("base_url", base_url)
                .finish(),
            Self::OpenAi { base_url, .. } => f
                .debug_struct("OpenAi")
                .field("api_key", &REDACTED)
                .field("base_url", base_url)
                .finish(),
            Self::Mock { response } => {
                f.debug_struct("Mock").field("response", response).finish()
            }
        }
    }
}
68
69impl LlmProvider {
70 pub fn from_config(cfg: &ProviderConfig) -> anyhow::Result<Self> {
71 let api_key = cfg
72 .api_key
73 .clone()
74 .or_else(|| std::env::var(Self::env_key_name(&cfg.name)).ok())
75 .with_context(|| {
76 format!(
77 "No API key for '{}'. Set {} or haki-config provider.api_key.",
78 cfg.name,
79 Self::env_key_name(&cfg.name)
80 )
81 })?;
82
83 match cfg.name.to_lowercase().as_str() {
84 "anthropic" => Ok(Self::Anthropic {
85 api_key,
86 base_url: cfg
87 .base_url
88 .clone()
89 .unwrap_or_else(|| "https://api.anthropic.com".into()),
90 }),
91 "openai" => Ok(Self::OpenAi {
92 api_key,
93 base_url: cfg
94 .base_url
95 .clone()
96 .unwrap_or_else(|| "https://api.openai.com".into()),
97 }),
98 other => anyhow::bail!(
99 "Unknown provider '{}'. Supported: anthropic, openai",
100 other
101 ),
102 }
103 }
104
105 pub fn provider_name(&self) -> &str {
106 match self {
107 Self::Anthropic { .. } => "anthropic",
108 Self::OpenAi { .. } => "openai",
109 Self::Mock { .. } => "mock",
110 }
111 }
112
113 pub fn default_model(&self) -> &str {
114 match self {
115 Self::Anthropic { .. } => "claude-sonnet-4-5",
116 Self::OpenAi { .. } => "gpt-4o",
117 Self::Mock { .. } => "mock-model",
118 }
119 }
120
121 pub fn mock(response: impl Into<String>) -> Self {
124 Self::Mock { response: response.into() }
125 }
126
127 pub async fn complete(&self, req: CompletionRequest) -> anyhow::Result<CompletionResponse> {
128 match self {
129 Self::Anthropic { api_key, base_url } => {
130 anthropic_complete(api_key, base_url, req).await
131 }
132 Self::OpenAi { api_key, base_url } => openai_complete(api_key, base_url, req).await,
133 Self::Mock { response } => Ok(CompletionResponse {
134 content: response.clone(),
135 usage: TokenUsage::default(),
136 }),
137 }
138 }
139
140 fn env_key_name(provider: &str) -> String {
141 match provider.to_lowercase().as_str() {
142 "anthropic" => "ANTHROPIC_API_KEY".into(),
143 "openai" => "OPENAI_API_KEY".into(),
144 other => format!("{}_API_KEY", other.to_uppercase()),
145 }
146 }
147}
148
149async fn anthropic_complete(
152 api_key: &str,
153 base_url: &str,
154 req: CompletionRequest,
155) -> anyhow::Result<CompletionResponse> {
156 #[derive(Serialize)]
157 struct AnthropicRequest<'a> {
158 model: &'a str,
159 max_tokens: u32,
160 #[serde(skip_serializing_if = "Option::is_none")]
161 system: Option<&'a str>,
162 messages: Vec<serde_json::Value>,
163 }
164
165 let system = req.system.as_deref().or_else(|| {
166 req.messages.iter().find(|m| m.role == Role::System).map(|m| m.content.as_str())
167 });
168
169 let messages: Vec<serde_json::Value> = req
170 .messages
171 .iter()
172 .filter(|m| m.role != Role::System)
173 .map(|m| {
174 serde_json::json!({
175 "role": match m.role { Role::User => "user", _ => "assistant" },
176 "content": m.content,
177 })
178 })
179 .collect();
180
181 let body = AnthropicRequest {
182 model: &req.model,
183 max_tokens: req.max_tokens,
184 system,
185 messages,
186 };
187
188 let client = reqwest::Client::new();
189 let resp = client
190 .post(format!("{}/v1/messages", base_url.trim_end_matches('/')))
191 .header("x-api-key", api_key)
192 .header("anthropic-version", "2023-06-01")
193 .header("content-type", "application/json")
194 .json(&body)
195 .send()
196 .await
197 .context("Anthropic HTTP request failed")?;
198
199 let status = resp.status();
200 let text = resp.text().await?;
201
202 if !status.is_success() {
203 anyhow::bail!("Anthropic API error {}: {}", status, text);
204 }
205
206 let val: serde_json::Value = serde_json::from_str(&text)?;
207 let content = val["content"]
208 .as_array()
209 .and_then(|blocks| blocks.iter().find(|b| b["type"] == "text"))
210 .and_then(|b| b["text"].as_str())
211 .unwrap_or("")
212 .to_string();
213
214 let usage = TokenUsage {
215 input_tokens: val["usage"]["input_tokens"].as_u64().unwrap_or(0),
216 output_tokens: val["usage"]["output_tokens"].as_u64().unwrap_or(0),
217 cache_read_tokens: val["usage"]["cache_read_input_tokens"].as_u64().unwrap_or(0),
218 cache_write_tokens: val["usage"]["cache_creation_input_tokens"].as_u64().unwrap_or(0),
219 };
220
221 Ok(CompletionResponse { content, usage })
222}
223
224async fn openai_complete(
227 api_key: &str,
228 base_url: &str,
229 req: CompletionRequest,
230) -> anyhow::Result<CompletionResponse> {
231 let mut messages: Vec<serde_json::Value> = Vec::new();
232
233 if let Some(sys) = &req.system {
235 messages.push(serde_json::json!({ "role": "system", "content": sys }));
236 }
237
238 for m in &req.messages {
239 let role = match m.role {
240 Role::System => continue,
241 Role::User => "user",
242 Role::Assistant => "assistant",
243 };
244 messages.push(serde_json::json!({ "role": role, "content": m.content }));
245 }
246
247 let body = serde_json::json!({
248 "model": req.model,
249 "max_tokens": req.max_tokens,
250 "messages": messages,
251 });
252
253 let client = reqwest::Client::new();
254 let resp = client
255 .post(format!("{}/v1/chat/completions", base_url.trim_end_matches('/')))
256 .bearer_auth(api_key)
257 .header("content-type", "application/json")
258 .json(&body)
259 .send()
260 .await
261 .context("OpenAI HTTP request failed")?;
262
263 let status = resp.status();
264 let text = resp.text().await?;
265
266 if !status.is_success() {
267 anyhow::bail!("OpenAI API error {}: {}", status, text);
268 }
269
270 let val: serde_json::Value = serde_json::from_str(&text)?;
271 let content = val["choices"][0]["message"]["content"]
272 .as_str()
273 .unwrap_or("")
274 .to_string();
275
276 let usage = TokenUsage {
277 input_tokens: val["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
278 output_tokens: val["usage"]["completion_tokens"].as_u64().unwrap_or(0),
279 cache_read_tokens: 0,
280 cache_write_tokens: 0,
281 };
282
283 Ok(CompletionResponse { content, usage })
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use haki_config::ProviderConfig;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch process-wide environment variables.
    ///
    /// The test harness runs tests on several threads, so without this lock
    /// one test can remove `ANTHROPIC_API_KEY` while another expects it set.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the environment lock, ignoring poisoning so that one failed test
    /// does not cascade into every other env-dependent test.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Anthropic config with the given explicit key and no base URL override.
    fn anthropic_cfg(key: Option<&str>) -> ProviderConfig {
        ProviderConfig {
            name: "anthropic".into(),
            api_key: key.map(Into::into),
            base_url: None,
        }
    }

    #[test]
    fn from_config_uses_explicit_key() {
        let p = LlmProvider::from_config(&anthropic_cfg(Some("sk-test"))).unwrap();
        assert_eq!(p.provider_name(), "anthropic");
        assert_eq!(p.default_model(), "claude-sonnet-4-5");
    }

    #[test]
    fn from_config_reads_env_key() {
        let _env = env_guard();
        std::env::set_var("ANTHROPIC_API_KEY", "env-key");
        let result = LlmProvider::from_config(&anthropic_cfg(None));
        // Clean up before asserting so a failure cannot leave the variable set.
        std::env::remove_var("ANTHROPIC_API_KEY");

        let p = result.unwrap();
        assert_eq!(p.provider_name(), "anthropic");
    }

    #[test]
    fn from_config_missing_key_is_err() {
        let _env = env_guard();
        std::env::remove_var("ANTHROPIC_API_KEY");
        std::env::remove_var("HAKI_PROVIDER__API_KEY");
        let cfg = ProviderConfig { name: "anthropic".into(), api_key: None, base_url: None };
        assert!(LlmProvider::from_config(&cfg).is_err());
    }

    #[test]
    fn from_config_unknown_provider_is_err() {
        let cfg =
            ProviderConfig { name: "groq".into(), api_key: Some("k".into()), base_url: None };
        assert!(LlmProvider::from_config(&cfg).is_err());
    }

    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content, "hello");

        let m = Message::system("be helpful");
        assert_eq!(m.role, Role::System);
    }
}
343