helios_engine/llm.rs

//! # LLM Module
//!
//! This module provides the functionality for interacting with Large Language Models (LLMs).
//! It supports both remote LLMs (like OpenAI) and local LLMs (via `llama.cpp`).
//! The `LLMClient` provides a unified interface for both types of providers.
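//!
//! A minimal usage sketch for the remote provider (the module paths and the way
//! `LLMConfig` is obtained are assumptions, so the example is marked `ignore` and
//! is not compiled as a doctest):
//!
//! ```ignore
//! use helios_engine::chat::{ChatMessage, Role};
//! use helios_engine::config::LLMConfig;
//! use helios_engine::llm::{LLMClient, LLMProviderType};
//!
//! async fn demo(config: LLMConfig) -> helios_engine::error::Result<()> {
//!     // Build a client backed by a remote, OpenAI-compatible endpoint.
//!     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//!
//!     // Send a single-turn chat request: no tools, default sampling, no stop sequences.
//!     let reply = client
//!         .chat(
//!             vec![ChatMessage {
//!                 role: Role::User,
//!                 content: "Hello!".to_string(),
//!                 name: None,
//!                 tool_calls: None,
//!                 tool_call_id: None,
//!             }],
//!             None, // tools
//!             None, // temperature (falls back to the configured default)
//!             None, // max_tokens (falls back to the configured default)
//!             None, // stop sequences
//!         )
//!         .await?;
//!     println!("{}", reply.content);
//!     Ok(())
//! }
//! ```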

use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

// Convert `llama_cpp_2` errors into `HeliosError` so they can be propagated with `?`.
#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

/// The type of LLM provider to use.
#[derive(Clone)]
pub enum LLMProviderType {
    /// A remote LLM provider, such as OpenAI.
    Remote(LLMConfig),
    /// A local LLM provider, using `llama.cpp`.
    #[cfg(feature = "local")]
    Local(LocalConfig),
}

/// A request to an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// The model to use for the request.
    pub model: String,
    /// The messages to send to the model.
    pub messages: Vec<ChatMessage>,
    /// The temperature to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// The tools to make available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// The tool choice to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

/// A chunk of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The ID of the chunk.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the chunk.
    pub model: String,
    /// The choices in the chunk.
    pub choices: Vec<StreamChoice>,
}

/// A choice in a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The index of the choice.
    pub index: u32,
    /// The delta of the choice.
    pub delta: Delta,
    /// The reason the stream finished.
    pub finish_reason: Option<String>,
}

/// The delta of a streamed choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// The role of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The content of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

/// A response from an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The ID of the response.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the response.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<Choice>,
    /// The usage statistics for the response.
    pub usage: Usage,
}

/// A choice in an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// The index of the choice.
    pub index: u32,
    /// The message of the choice.
    pub message: ChatMessage,
    /// The reason the generation finished.
    pub finish_reason: Option<String>,
}

/// The usage statistics for an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// The number of tokens in the completion.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// A trait for LLM providers.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response from the LLM.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    /// Returns the provider as an `Any` type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// A client for interacting with an LLM.
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    /// Creates a new `LLMClient`.
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    /// Returns the type of the LLM provider.
    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

/// A client for interacting with a remote LLM.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    /// Creates a new `RemoteLLMClient`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration of the client.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Suppresses stdout and stderr by redirecting them to `/dev/null`.
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing (`File::open` would give a read-only handle,
    // so writes through the redirected descriptors would fail)
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restores stdout and stderr.
#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Suppresses stderr by redirecting it to `/dev/null`.
#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    // Open /dev/null for writing, as in `suppress_output`
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Restores stderr.
#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

/// A provider for a local LLM.
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Creates a new `LocalLLMProvider`.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    /// Downloads a model from Hugging Face.
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache - suppress download output in offline mode

        // Use huggingface_hub to download the model (suppress output)
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress stdout
            .stderr(std::process::Stdio::null()) // Suppress stderr
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    /// Finds a model in the Hugging Face cache.
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }

    /// Formats a list of messages into a single string.
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}
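
// For reference, `LocalLLMProvider::format_messages` above renders a system + user
// exchange into a Qwen/ChatML-style prompt roughly like this (the trailing
// `assistant` header is left open so the model continues with its reply):
//
//   <|im_start|>system
//   You are a helpful assistant.
//   <|im_end|>
//   <|im_start|>user
//   Hello!
//   <|im_end|>
//   <|im_start|>assistant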

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    /// Sends a chat request to the remote LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the remote LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}
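
// A usage sketch for the streaming path (assumes `client` is a configured
// `RemoteLLMClient` and `messages` is a `Vec<ChatMessage>`): each SSE delta is
// forwarded to the closure as it arrives, and the assembled message is returned
// once the stream ends.
//
//     let reply = client
//         .chat_stream(messages, None, None, None, None, |chunk| print!("{}", chunk))
//         .await?;
//     println!("\nfull reply: {}", reply.content);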

#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true; // Request logits for every prompt token (only the final position is sampled)
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens
            let mut generated_text = String::new();
            let max_new_tokens = 512; // Increased limit for better responses
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Get logits from the last decoded position (get_logits returns logits for the last token)
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await;

        // Restore output before handling errors so stdout and stderr are not left
        // redirected to /dev/null when inference fails.
        restore_output(stdout_backup, stderr_backup);

        let result = result
            .map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Sends a streaming chat request to the local LLM.
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel; stop if receiver is dropped
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    /// Sends a chat request to the LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}

// Test module
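// A minimal sketch of what this module could contain: it only checks that unset
// optional fields of `LLMRequest` are omitted from the serialized JSON (assumes
// `serde_json` is available here, as it already is for the streaming parser above).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn llm_request_omits_unset_optional_fields() {
        let request = LLMRequest {
            model: "test-model".to_string(),
            messages: vec![ChatMessage {
                role: crate::chat::Role::User,
                content: "Hello!".to_string(),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: None,
            max_tokens: None,
            tools: None,
            tool_choice: None,
            stream: None,
            stop: None,
        };

        let json = serde_json::to_value(&request).expect("request should serialize");
        // `skip_serializing_if = "Option::is_none"` should drop every unset field.
        assert!(json.get("temperature").is_none());
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("stream").is_none());
        assert_eq!(json["model"], "test-model");
    }
}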