brainos_cortex/llm.rs

//! LLM client — hybrid provider with trait-based adapter.
//!
//! `LlmProvider` trait with multiple implementations:
//! - `OllamaProvider` — local Ollama server
//! - `OpenAiProvider` — OpenAI-compatible APIs
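//!
//! A minimal usage sketch (assuming this module is exposed as
//! `brainos_cortex::llm` and a provider is reachable):
//!
//! ```no_run
//! # async fn demo() -> Result<(), brainos_cortex::llm::LlmError> {
//! use brainos_cortex::llm::{create_provider, Message, ProviderConfig, Role};
//!
//! let provider = create_provider(&ProviderConfig::default());
//! let reply = provider
//!     .generate(&[Message { role: Role::User, content: "Hello!".into() }])
//!     .await?;
//! println!("{}", reply.content);
//! # Ok(())
//! # }
//! ```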

use std::pin::Pin;

use futures::Stream;
use serde::{Deserialize, Serialize};
use thiserror::Error;

// ─── Errors ─────────────────────────────────────────────────────────────────

/// Errors from the LLM layer.
#[derive(Debug, Error)]
pub enum LlmError {
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),

    #[error("API error: {status} - {message}")]
    Api { status: u16, message: String },

    #[error("Stream error: {0}")]
    Stream(String),

    #[error("Invalid response format: {0}")]
    InvalidFormat(String),

    #[error("Provider not available: {0}")]
    ProviderUnavailable(String),

    #[error("Rate limited")]
    RateLimited,

    #[error("Timeout")]
    Timeout,
}

// ─── Types ──────────────────────────────────────────────────────────────────

/// A message in the conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: String,
}

/// Message roles.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}

/// LLM response chunk (for streaming).
#[derive(Debug, Clone)]
pub struct ResponseChunk {
    pub content: String,
    pub is_done: bool,
}

/// Complete LLM response.
#[derive(Debug, Clone)]
pub struct Response {
    pub content: String,
    pub usage: Option<Usage>,
}

/// Token usage statistics.
#[derive(Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

// ─── Provider Trait ─────────────────────────────────────────────────────────

/// Trait for LLM providers.
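///
/// A streaming-consumption sketch (assumes the module path
/// `brainos_cortex::llm` and a reachable provider behind `provider`):
///
/// ```no_run
/// # async fn demo(provider: &dyn brainos_cortex::llm::LlmProvider) -> Result<(), brainos_cortex::llm::LlmError> {
/// use futures::StreamExt;
///
/// use brainos_cortex::llm::{Message, Role};
///
/// let messages = [Message { role: Role::User, content: "Hi".into() }];
/// let mut stream = provider.generate_stream(&messages).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     print!("{}", chunk.content);
///     if chunk.is_done {
///         break;
///     }
/// }
/// # Ok(())
/// # }
/// ```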
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Generate a complete response (non-streaming).
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError>;

    /// Generate a streaming response.
    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError>;

    /// Check if the provider is available.
    async fn health_check(&self) -> bool;

    /// Get the provider name.
    fn name(&self) -> &str;
}

// ─── Ollama Provider ────────────────────────────────────────────────────────

/// Ollama API request body.
#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
    options: Option<OllamaOptions>,
}

#[derive(Serialize, Deserialize)]
struct OllamaMessage {
    role: String,
    content: String,
}
// Ollama expects snake_case option names (`num_predict`), so no field renaming.
#[derive(Serialize)]
struct OllamaOptions {
    temperature: f64,
    num_predict: i32,
}

/// Ollama API response (works for both streaming and non-streaming).
#[derive(Deserialize)]
struct OllamaResponse {
    message: Option<OllamaMessage>,
    done: bool,
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    #[serde(default)]
    eval_count: Option<u32>,
}

/// Ollama LLM provider.
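///
/// Construction sketch (the endpoint and model name are illustrative):
///
/// ```no_run
/// # use brainos_cortex::llm::OllamaProvider;
/// let provider = OllamaProvider::new("http://localhost:11434", "llama3:8b", 0.7, 4096)
///     .expect("HTTP client init failed");
/// ```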
pub struct OllamaProvider {
    client: reqwest::Client,
    base_url: String,
    model: String,
    temperature: f64,
    max_tokens: i32,
}

impl OllamaProvider {
    /// Create a new Ollama provider.
    pub fn new(
        base_url: &str,
        model: &str,
        temperature: f64,
        max_tokens: i32,
    ) -> Result<Self, LlmError> {
        // Ollama may need to load a large model on first call — allow up to 5 min
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(|e| {
                LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
            })?;

        Ok(Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.to_string(),
            temperature,
            max_tokens,
        })
    }

    /// Create with default config. Panics only if TLS initialisation fails (extremely rare).
    pub fn default_config() -> Self {
        Self::new("http://localhost:11434", "qwen2.5-coder:7b", 0.7, 4096)
            .expect("Failed to initialise default Ollama HTTP client")
    }

    fn convert_messages(messages: &[Message]) -> Vec<OllamaMessage> {
        messages
            .iter()
            .map(|m| OllamaMessage {
                role: match m.role {
                    Role::System => "system".to_string(),
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                },
                content: m.content.clone(),
            })
            .collect()
    }
}

#[async_trait::async_trait]
impl LlmProvider for OllamaProvider {
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
        let url = format!("{}/api/chat", self.base_url);
        let request = OllamaRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: false,
            options: Some(OllamaOptions {
                temperature: self.temperature,
                num_predict: self.max_tokens,
            }),
        };

        let resp = self.client.post(&url).json(&request).send().await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let data: OllamaResponse = resp.json().await?;

        let content = data.message.map(|m| m.content).unwrap_or_default();

        Ok(Response {
            content,
            usage: Some(Usage {
                prompt_tokens: data.prompt_eval_count.unwrap_or(0),
                completion_tokens: data.eval_count.unwrap_or(0),
                total_tokens: data.prompt_eval_count.unwrap_or(0) + data.eval_count.unwrap_or(0),
            }),
        })
    }

    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
        use futures::stream::try_unfold;

        let url = format!("{}/api/chat", self.base_url);
        let request = OllamaRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: true,
            options: Some(OllamaOptions {
                temperature: self.temperature,
                num_predict: self.max_tokens,
            }),
        };

        let resp = self.client.post(&url).json(&request).send().await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let byte_stream = resp.bytes_stream();

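        // Ollama streams newline-delimited JSON, one object per line, e.g.
        //   {"message":{"role":"assistant","content":"Hel"},"done":false}
        // with a final object carrying "done": true and the eval counts.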
        // State: (byte_stream, leftover buffer for incomplete lines)
        let stream = try_unfold(
            (Box::pin(byte_stream), String::new()),
            |(mut byte_stream, mut buf)| async move {
                use futures::TryStreamExt;

                loop {
                    // Try to extract a complete line from the buffer
                    if let Some(newline_pos) = buf.find('\n') {
                        let line: String = buf[..newline_pos].to_string();
                        buf = buf[newline_pos + 1..].to_string();

                        let line = line.trim();
                        if line.is_empty() {
                            continue;
                        }
                        match serde_json::from_str::<OllamaResponse>(line) {
                            Ok(data) => {
                                let content = data.message.map(|m| m.content).unwrap_or_default();
                                let chunk = ResponseChunk {
                                    content,
                                    is_done: data.done,
                                };
                                return Ok(Some((chunk, (byte_stream, buf))));
                            }
                            Err(e) => {
                                return Err(LlmError::InvalidFormat(format!(
                                    "Failed to parse streaming response: {e}"
                                )));
                            }
                        }
                    }

                    // Need more data from the network
                    match byte_stream.try_next().await {
                        Ok(Some(bytes)) => {
                            buf.push_str(&String::from_utf8_lossy(&bytes));
                        }
                        Ok(None) => {
                            // Stream ended — parse any remaining data in buffer
                            let remaining = buf.trim();
                            if !remaining.is_empty() {
                                if let Ok(data) = serde_json::from_str::<OllamaResponse>(remaining)
                                {
                                    let content =
                                        data.message.map(|m| m.content).unwrap_or_default();
                                    return Ok(Some((
                                        ResponseChunk {
                                            content,
                                            is_done: true,
                                        },
                                        (byte_stream, String::new()),
                                    )));
                                }
                            }
                            return Ok(None);
                        }
                        Err(e) => return Err(LlmError::Http(e)),
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> bool {
        let url = format!("{}/api/tags", self.base_url);
        match self.client.get(&url).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }

    fn name(&self) -> &str {
        "ollama"
    }
}

// ─── OpenAI-Compatible Provider ─────────────────────────────────────────────

/// OpenAI API request body.
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    temperature: f64,
    // Omit the field entirely rather than sending `"max_tokens": null`,
    // which some OpenAI-compatible servers reject.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
    stream: bool,
}

#[derive(Serialize, Deserialize)]
struct OpenAiMessage {
    role: String,
    content: String,
}

/// OpenAI API response.
#[derive(Deserialize)]
struct OpenAiResponse {
    choices: Vec<OpenAiChoice>,
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

/// Streaming chunk from OpenAI SSE (delta instead of message).
#[derive(Deserialize)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
}

#[derive(Deserialize)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
}

#[derive(Deserialize)]
struct OpenAiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

/// OpenAI-compatible provider (works with OpenAI, OpenRouter, etc.)
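///
/// Construction sketch (key and model names are placeholders):
///
/// ```no_run
/// # use brainos_cortex::llm::OpenAiProvider;
/// let direct = OpenAiProvider::openai("sk-placeholder", "gpt-4");
/// let routed = OpenAiProvider::openrouter("sk-placeholder", "anthropic/claude-3-opus");
/// ```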
pub struct OpenAiProvider {
    client: reqwest::Client,
    base_url: String,
    api_key: Option<String>,
    model: String,
    temperature: f64,
    max_tokens: Option<i32>,
}

impl OpenAiProvider {
    /// Create a new OpenAI-compatible provider.
    pub fn new(
        base_url: &str,
        api_key: Option<&str>,
        model: &str,
        temperature: f64,
        max_tokens: Option<i32>,
    ) -> Result<Self, LlmError> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(|e| {
                LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
            })?;

        Ok(Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.map(|s| s.to_string()),
            model: model.to_string(),
            temperature,
            max_tokens,
        })
    }

    /// Create for OpenAI API.
    pub fn openai(api_key: &str, model: &str) -> Self {
        Self::new(
            "https://api.openai.com/v1",
            Some(api_key),
            model,
            0.7,
            Some(4096),
        )
        .expect("Failed to initialise OpenAI HTTP client")
    }

    /// Create for OpenRouter.
    pub fn openrouter(api_key: &str, model: &str) -> Self {
        Self::new(
            "https://openrouter.ai/api/v1",
            Some(api_key),
            model,
            0.7,
            Some(4096),
        )
        .expect("Failed to initialise OpenRouter HTTP client")
    }

    fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
        messages
            .iter()
            .map(|m| OpenAiMessage {
                role: match m.role {
                    Role::System => "system".to_string(),
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                },
                content: m.content.clone(),
            })
            .collect()
    }

    fn build_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        let mut builder = builder;
        if let Some(key) = &self.api_key {
            builder = builder.header("Authorization", format!("Bearer {}", key));
        }
        builder
    }
}

#[async_trait::async_trait]
impl LlmProvider for OpenAiProvider {
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
        let url = format!("{}/chat/completions", self.base_url);
        let request = OpenAiRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: false,
        };

        let resp = self
            .build_request(self.client.post(&url))
            .json(&request)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let data: OpenAiResponse = resp.json().await?;
        let content = data
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();

        Ok(Response {
            content,
            usage: data.usage.map(|u| Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
        })
    }

    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
        use futures::stream::try_unfold;

        let url = format!("{}/chat/completions", self.base_url);
        let request = OpenAiRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: true,
        };

        let resp = self
            .build_request(self.client.post(&url))
            .json(&request)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let byte_stream = resp.bytes_stream();

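        // The response arrives as Server-Sent Events, one "data:" line per chunk:
        //   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
        //   data: [DONE]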
        // State: (byte_stream, leftover buffer for incomplete SSE lines)
        let stream = try_unfold(
            (Box::pin(byte_stream), String::new()),
            |(mut byte_stream, mut buf)| async move {
                use futures::TryStreamExt;

                loop {
                    // Try to extract a complete line from the buffer
                    if let Some(newline_pos) = buf.find('\n') {
                        let line: String = buf[..newline_pos].to_string();
                        buf = buf[newline_pos + 1..].to_string();

                        let line = line.trim();
                        if line.is_empty() {
                            continue;
                        }

                        // SSE format: "data: {...}" or "data: [DONE]"
                        if let Some(data) = line.strip_prefix("data: ") {
                            let data = data.trim();
                            if data == "[DONE]" {
                                return Ok(None);
                            }

                            match serde_json::from_str::<OpenAiStreamResponse>(data) {
                                Ok(resp) => {
                                    if let Some(choice) = resp.choices.first() {
                                        let content =
                                            choice.delta.content.clone().unwrap_or_default();
                                        let is_done = choice.finish_reason.is_some();
                                        let chunk = ResponseChunk { content, is_done };
                                        return Ok(Some((chunk, (byte_stream, buf))));
                                    }
                                    // Chunk with no choices at all (e.g. a trailing
                                    // usage-only frame); skip it.
                                    continue;
                                }
                                Err(e) => {
                                    return Err(LlmError::InvalidFormat(format!(
                                        "Failed to parse streaming response: {e}"
                                    )));
                                }
                            }
                        }
                        // Skip non-data SSE lines (e.g. "event:", comments)
                        continue;
                    }

                    // Need more data from the network
                    match byte_stream.try_next().await {
                        Ok(Some(bytes)) => {
                            buf.push_str(&String::from_utf8_lossy(&bytes));
                        }
                        Ok(None) => return Ok(None),
                        Err(e) => return Err(LlmError::Http(e)),
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> bool {
        let url = format!("{}/models", self.base_url);
        match self.build_request(self.client.get(&url)).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }

    fn name(&self) -> &str {
        "openai"
    }
}

// ─── Provider Factory ───────────────────────────────────────────────────────

/// Configuration for LLM provider selection.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    pub provider: String,
    pub base_url: String,
    pub api_key: Option<String>,
    pub model: String,
    pub temperature: f64,
    pub max_tokens: i32,
}

impl Default for ProviderConfig {
    fn default() -> Self {
        Self {
            provider: "ollama".to_string(),
            base_url: "http://localhost:11434".to_string(),
            api_key: None,
            model: "qwen2.5-coder:7b".to_string(),
            temperature: 0.7,
            max_tokens: 4096,
        }
    }
}

/// Create an LLM provider from configuration.
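///
/// Sketch of a non-default selection (field values are illustrative):
///
/// ```no_run
/// # use brainos_cortex::llm::{create_provider, ProviderConfig};
/// let config = ProviderConfig {
///     provider: "openai".to_string(),
///     base_url: "https://api.openai.com/v1".to_string(),
///     api_key: std::env::var("OPENAI_API_KEY").ok(),
///     model: "gpt-4".to_string(),
///     ..ProviderConfig::default()
/// };
/// let provider = create_provider(&config);
/// ```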
pub fn create_provider(config: &ProviderConfig) -> Box<dyn LlmProvider> {
    match config.provider.as_str() {
        "ollama" => Box::new(
            OllamaProvider::new(
                &config.base_url,
                &config.model,
                config.temperature,
                config.max_tokens,
            )
            // Construction only fails if the HTTP client (TLS) cannot be built;
            // falling back to the default config would hit the same failure,
            // so surface it clearly instead.
            .expect("Failed to initialise Ollama HTTP client"),
        ),
        "openai" => Box::new(
            OpenAiProvider::new(
                &config.base_url,
                config.api_key.as_deref(),
                &config.model,
                config.temperature,
                Some(config.max_tokens),
            )
            // TLS initialisation failure is unrecoverable — surface clearly.
            .expect("Failed to initialise OpenAI HTTP client"),
        ),
        other => {
            tracing::warn!(provider = %other, "Unknown provider, using default Ollama config");
            Box::new(OllamaProvider::default_config())
        }
    }
}

// ─── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config_default() {
        let config = ProviderConfig::default();
        assert_eq!(config.provider, "ollama");
        assert_eq!(config.model, "qwen2.5-coder:7b");
    }

    #[test]
    fn test_ollama_provider_creation() {
        let provider = OllamaProvider::new("http://localhost:11434", "llama3:8b", 0.5, 2048)
            .expect("OllamaProvider::new should not fail in test");
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider = OpenAiProvider::openai("test-key", "gpt-4");
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_openrouter_provider_creation() {
        let provider = OpenAiProvider::openrouter("test-key", "anthropic/claude-3-opus");
        assert_eq!(provider.name(), "openai");
    }
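
    // `Role` must serialize to the lowercase strings both wire formats expect
    // (per the `rename_all = "lowercase"` attribute on the enum).
    #[test]
    fn test_role_serialization() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::Assistant).unwrap(), "\"assistant\"");
    }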
}