Skip to main content

haki_llm/
provider.rs

1//! LLM provider abstraction — direct HTTP clients for full control.
2//! Uses reqwest rather than rig trait objects to avoid generic type maze.
3//! rig-core is used for RAG/vector/tool features in later phases.
4
5use anyhow::Context;
6use haki_config::ProviderConfig;
7use serde::{Deserialize, Serialize};
8
9// ─── Public types ─────────────────────────────────────────────────────────────
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14    User,
15    Assistant,
16    System,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Message {
21    pub role: Role,
22    pub content: String,
23}
24
25impl Message {
26    pub fn user(content: impl Into<String>) -> Self {
27        Self { role: Role::User, content: content.into() }
28    }
29    pub fn assistant(content: impl Into<String>) -> Self {
30        Self { role: Role::Assistant, content: content.into() }
31    }
32    pub fn system(content: impl Into<String>) -> Self {
33        Self { role: Role::System, content: content.into() }
34    }
35}
36
/// Token accounting reported by a provider for one completion.
///
/// Counters the provider does not report are left at `0`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub input_tokens: u64,
    /// Tokens generated in the reply.
    pub output_tokens: u64,
    /// Prompt tokens read from the provider's prompt cache
    /// (Anthropic `cache_read_input_tokens`; always 0 for OpenAI here).
    pub cache_read_tokens: u64,
    /// Prompt tokens written to the provider's prompt cache
    /// (Anthropic `cache_creation_input_tokens`; always 0 for OpenAI here).
    pub cache_write_tokens: u64,
}
44
45#[derive(Debug, Clone)]
46pub struct CompletionRequest {
47    pub model: String,
48    pub system: Option<String>,
49    pub messages: Vec<Message>,
50    pub max_tokens: u32,
51}
52
53#[derive(Debug, Clone)]
54pub struct CompletionResponse {
55    pub content: String,
56    pub usage: TokenUsage,
57}
58
59// ─── Provider enum ────────────────────────────────────────────────────────────
60
/// A configured LLM backend.
///
/// `Debug` is implemented by hand so that API keys are never printed:
/// a derived impl would write the secret verbatim into any `{:?}` log line
/// or error message that includes the provider.
#[derive(Clone)]
pub enum LlmProvider {
    Anthropic { api_key: String, base_url: String },
    OpenAi { api_key: String, base_url: String },
    /// Fixed-response mock for tests — never makes network calls.
    Mock { response: String },
}

impl std::fmt::Debug for LlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder shown instead of the real key.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Anthropic { base_url, .. } => f
                .debug_struct("Anthropic")
                .field("api_key", &REDACTED)
                .field("base_url", base_url)
                .finish(),
            Self::OpenAi { base_url, .. } => f
                .debug_struct("OpenAi")
                .field("api_key", &REDACTED)
                .field("base_url", base_url)
                .finish(),
            Self::Mock { response } => f.debug_struct("Mock").field("response", response).finish(),
        }
    }
}
68
69impl LlmProvider {
70    pub fn from_config(cfg: &ProviderConfig) -> anyhow::Result<Self> {
71        let api_key = cfg
72            .api_key
73            .clone()
74            .or_else(|| std::env::var(Self::env_key_name(&cfg.name)).ok())
75            .with_context(|| {
76                format!(
77                    "No API key for '{}'. Set {} or haki-config provider.api_key.",
78                    cfg.name,
79                    Self::env_key_name(&cfg.name)
80                )
81            })?;
82
83        match cfg.name.to_lowercase().as_str() {
84            "anthropic" => Ok(Self::Anthropic {
85                api_key,
86                base_url: cfg
87                    .base_url
88                    .clone()
89                    .unwrap_or_else(|| "https://api.anthropic.com".into()),
90            }),
91            "openai" => Ok(Self::OpenAi {
92                api_key,
93                base_url: cfg
94                    .base_url
95                    .clone()
96                    .unwrap_or_else(|| "https://api.openai.com".into()),
97            }),
98            other => anyhow::bail!(
99                "Unknown provider '{}'. Supported: anthropic, openai",
100                other
101            ),
102        }
103    }
104
105    pub fn provider_name(&self) -> &str {
106        match self {
107            Self::Anthropic { .. } => "anthropic",
108            Self::OpenAi { .. } => "openai",
109            Self::Mock { .. } => "mock",
110        }
111    }
112
113    pub fn default_model(&self) -> &str {
114        match self {
115            Self::Anthropic { .. } => "claude-sonnet-4-5",
116            Self::OpenAi { .. } => "gpt-4o",
117            Self::Mock { .. } => "mock-model",
118        }
119    }
120
121    /// Create a mock provider that always returns `response` without making
122    /// any network calls. Use in tests only.
123    pub fn mock(response: impl Into<String>) -> Self {
124        Self::Mock { response: response.into() }
125    }
126
127    pub async fn complete(&self, req: CompletionRequest) -> anyhow::Result<CompletionResponse> {
128        match self {
129            Self::Anthropic { api_key, base_url } => {
130                anthropic_complete(api_key, base_url, req).await
131            }
132            Self::OpenAi { api_key, base_url } => openai_complete(api_key, base_url, req).await,
133            Self::Mock { response } => Ok(CompletionResponse {
134                content: response.clone(),
135                usage: TokenUsage::default(),
136            }),
137        }
138    }
139
140    fn env_key_name(provider: &str) -> String {
141        match provider.to_lowercase().as_str() {
142            "anthropic" => "ANTHROPIC_API_KEY".into(),
143            "openai" => "OPENAI_API_KEY".into(),
144            other => format!("{}_API_KEY", other.to_uppercase()),
145        }
146    }
147}
148
149// ─── Anthropic HTTP client ─────────────────────────────────────────────────────
150
151async fn anthropic_complete(
152    api_key: &str,
153    base_url: &str,
154    req: CompletionRequest,
155) -> anyhow::Result<CompletionResponse> {
156    #[derive(Serialize)]
157    struct AnthropicRequest<'a> {
158        model: &'a str,
159        max_tokens: u32,
160        #[serde(skip_serializing_if = "Option::is_none")]
161        system: Option<&'a str>,
162        messages: Vec<serde_json::Value>,
163    }
164
165    let system = req.system.as_deref().or_else(|| {
166        req.messages.iter().find(|m| m.role == Role::System).map(|m| m.content.as_str())
167    });
168
169    let messages: Vec<serde_json::Value> = req
170        .messages
171        .iter()
172        .filter(|m| m.role != Role::System)
173        .map(|m| {
174            serde_json::json!({
175                "role": match m.role { Role::User => "user", _ => "assistant" },
176                "content": m.content,
177            })
178        })
179        .collect();
180
181    let body = AnthropicRequest {
182        model: &req.model,
183        max_tokens: req.max_tokens,
184        system,
185        messages,
186    };
187
188    let client = reqwest::Client::new();
189    let resp = client
190        .post(format!("{}/v1/messages", base_url.trim_end_matches('/')))
191        .header("x-api-key", api_key)
192        .header("anthropic-version", "2023-06-01")
193        .header("content-type", "application/json")
194        .json(&body)
195        .send()
196        .await
197        .context("Anthropic HTTP request failed")?;
198
199    let status = resp.status();
200    let text = resp.text().await?;
201
202    if !status.is_success() {
203        anyhow::bail!("Anthropic API error {}: {}", status, text);
204    }
205
206    let val: serde_json::Value = serde_json::from_str(&text)?;
207    let content = val["content"]
208        .as_array()
209        .and_then(|blocks| blocks.iter().find(|b| b["type"] == "text"))
210        .and_then(|b| b["text"].as_str())
211        .unwrap_or("")
212        .to_string();
213
214    let usage = TokenUsage {
215        input_tokens: val["usage"]["input_tokens"].as_u64().unwrap_or(0),
216        output_tokens: val["usage"]["output_tokens"].as_u64().unwrap_or(0),
217        cache_read_tokens: val["usage"]["cache_read_input_tokens"].as_u64().unwrap_or(0),
218        cache_write_tokens: val["usage"]["cache_creation_input_tokens"].as_u64().unwrap_or(0),
219    };
220
221    Ok(CompletionResponse { content, usage })
222}
223
224// ─── OpenAI HTTP client ────────────────────────────────────────────────────────
225
226async fn openai_complete(
227    api_key: &str,
228    base_url: &str,
229    req: CompletionRequest,
230) -> anyhow::Result<CompletionResponse> {
231    let mut messages: Vec<serde_json::Value> = Vec::new();
232
233    // Prepend system message if present
234    if let Some(sys) = &req.system {
235        messages.push(serde_json::json!({ "role": "system", "content": sys }));
236    }
237
238    for m in &req.messages {
239        let role = match m.role {
240            Role::System => continue,
241            Role::User => "user",
242            Role::Assistant => "assistant",
243        };
244        messages.push(serde_json::json!({ "role": role, "content": m.content }));
245    }
246
247    let body = serde_json::json!({
248        "model": req.model,
249        "max_tokens": req.max_tokens,
250        "messages": messages,
251    });
252
253    let client = reqwest::Client::new();
254    let resp = client
255        .post(format!("{}/v1/chat/completions", base_url.trim_end_matches('/')))
256        .bearer_auth(api_key)
257        .header("content-type", "application/json")
258        .json(&body)
259        .send()
260        .await
261        .context("OpenAI HTTP request failed")?;
262
263    let status = resp.status();
264    let text = resp.text().await?;
265
266    if !status.is_success() {
267        anyhow::bail!("OpenAI API error {}: {}", status, text);
268    }
269
270    let val: serde_json::Value = serde_json::from_str(&text)?;
271    let content = val["choices"][0]["message"]["content"]
272        .as_str()
273        .unwrap_or("")
274        .to_string();
275
276    let usage = TokenUsage {
277        input_tokens: val["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
278        output_tokens: val["usage"]["completion_tokens"].as_u64().unwrap_or(0),
279        cache_read_tokens: 0,
280        cache_write_tokens: 0,
281    };
282
283    Ok(CompletionResponse { content, usage })
284}
285
286// ─── Tests ────────────────────────────────────────────────────────────────────
287
#[cfg(test)]
mod tests {
    use super::*;
    use haki_config::ProviderConfig;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or writes process environment variables.
    ///
    /// The test harness runs tests on parallel threads and the environment is
    /// process-global. Without this lock, `from_config_reads_env_key` and
    /// `from_config_missing_key_is_err` can observe each other's
    /// `set_var`/`remove_var` calls and fail intermittently.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`]. The lock guards no data, so poisoning left by
    /// a panicked test carries no meaning and is recovered from.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    fn anthropic_cfg(key: Option<&str>) -> ProviderConfig {
        ProviderConfig {
            name: "anthropic".into(),
            api_key: key.map(Into::into),
            base_url: None,
        }
    }

    #[test]
    fn from_config_uses_explicit_key() {
        let p = LlmProvider::from_config(&anthropic_cfg(Some("sk-test"))).unwrap();
        assert_eq!(p.provider_name(), "anthropic");
        assert_eq!(p.default_model(), "claude-sonnet-4-5");
    }

    #[test]
    fn from_config_reads_env_key() {
        let _env = env_guard();
        std::env::set_var("ANTHROPIC_API_KEY", "env-key");
        let result = LlmProvider::from_config(&anthropic_cfg(None));
        // Clean up before asserting so a failing assertion cannot leak the variable.
        std::env::remove_var("ANTHROPIC_API_KEY");

        let p = result.unwrap();
        assert_eq!(p.provider_name(), "anthropic");
        assert!(matches!(p, LlmProvider::Anthropic { ref api_key, .. } if api_key == "env-key"));
    }

    #[test]
    fn from_config_missing_key_is_err() {
        let _env = env_guard();
        std::env::remove_var("ANTHROPIC_API_KEY");
        assert!(LlmProvider::from_config(&anthropic_cfg(None)).is_err());
    }

    #[test]
    fn from_config_unknown_provider_is_err() {
        let cfg =
            ProviderConfig { name: "groq".into(), api_key: Some("k".into()), base_url: None };
        assert!(LlmProvider::from_config(&cfg).is_err());
    }

    #[test]
    fn from_config_openai_defaults_and_case_insensitive_name() {
        let cfg =
            ProviderConfig { name: "OpenAI".into(), api_key: Some("k".into()), base_url: None };
        let p = LlmProvider::from_config(&cfg).unwrap();
        assert_eq!(p.provider_name(), "openai");
        assert_eq!(p.default_model(), "gpt-4o");
        assert!(matches!(
            p,
            LlmProvider::OpenAi { ref base_url, .. } if base_url == "https://api.openai.com"
        ));
    }

    #[test]
    fn from_config_respects_base_url_override() {
        let cfg = ProviderConfig {
            name: "anthropic".into(),
            api_key: Some("k".into()),
            base_url: Some("http://localhost:8080".into()),
        };
        let p = LlmProvider::from_config(&cfg).unwrap();
        assert!(matches!(
            p,
            LlmProvider::Anthropic { ref base_url, .. } if base_url == "http://localhost:8080"
        ));
    }

    #[test]
    fn mock_provider_metadata() {
        let p = LlmProvider::mock("canned");
        assert_eq!(p.provider_name(), "mock");
        assert_eq!(p.default_model(), "mock-model");
        assert!(matches!(p, LlmProvider::Mock { ref response } if response == "canned"));
    }

    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content, "hello");

        let m = Message::assistant("hi there");
        assert_eq!(m.role, Role::Assistant);
        assert_eq!(m.content, "hi there");

        let m = Message::system("be helpful");
        assert_eq!(m.role, Role::System);
    }
}
343