helios_engine/llm.rs

//! # LLM Module
//!
//! This module provides the functionality for interacting with Large Language Models (LLMs).
//! It supports both remote LLMs (like OpenAI) and local LLMs (via `llama.cpp`).
//! The `LLMClient` provides a unified interface for both types of providers.
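//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a doc-test).
//! It assumes the crate is exposed as `helios_engine` and that the `chat`,
//! `config`, `error`, and `llm` modules are reachable at these paths, which may
//! differ in your build:
//!
//! ```ignore
//! use helios_engine::chat::{ChatMessage, Role};
//! use helios_engine::config::LLMConfig;
//! use helios_engine::llm::{LLMClient, LLMProviderType};
//!
//! async fn example(config: LLMConfig) -> helios_engine::error::Result<()> {
//!     // Build a client backed by a remote, OpenAI-compatible endpoint.
//!     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//!
//!     let messages = vec![ChatMessage {
//!         role: Role::User,
//!         content: "Hello!".to_string(),
//!         name: None,
//!         tool_calls: None,
//!         tool_call_id: None,
//!     }];
//!
//!     // Non-streaming request; `None` falls back to the configured
//!     // temperature and max_tokens.
//!     let reply = client.chat(messages.clone(), None, None, None, None).await?;
//!     println!("{}", reply.content);
//!
//!     // Streaming request: the closure receives each content delta as it arrives.
//!     client
//!         .chat_stream(messages, None, None, None, None, |chunk| print!("{chunk}"))
//!         .await?;
//!     Ok(())
//! }
//! ```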

use crate::chat::ChatMessage;
use crate::config::{LLMConfig, LocalConfig};
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use tokio::task;

// Add From trait for LLamaCppError to convert to HeliosError
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

/// The type of LLM provider to use.
#[derive(Clone)]
pub enum LLMProviderType {
    /// A remote LLM provider, such as OpenAI.
    Remote(LLMConfig),
    /// A local LLM provider, using `llama.cpp`.
    Local(LocalConfig),
}

/// A request to an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// The model to use for the request.
    pub model: String,
    /// The messages to send to the model.
    pub messages: Vec<ChatMessage>,
    /// The temperature to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// The tools to make available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// The tool choice to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

/// A chunk of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The ID of the chunk.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the chunk.
    pub model: String,
    /// The choices in the chunk.
    pub choices: Vec<StreamChoice>,
}

/// A choice in a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The index of the choice.
    pub index: u32,
    /// The delta of the choice.
    pub delta: Delta,
    /// The reason the stream finished.
    pub finish_reason: Option<String>,
}

/// The delta of a streamed choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// The role of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The content of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
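
// For reference, `StreamChunk`/`StreamChoice`/`Delta` are shaped to deserialize
// OpenAI-style SSE events. A representative `data:` line (wrapped here for
// readability; real events arrive on a single line) might look like:
//
//   data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,
//          "model":"some-model","choices":[{"index":0,"delta":{"content":"Hel"},
//          "finish_reason":null}]}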

/// A response from an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The ID of the response.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the response.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<Choice>,
    /// The usage statistics for the response.
    pub usage: Usage,
}

/// A choice in an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// The index of the choice.
    pub index: u32,
    /// The message of the choice.
    pub message: ChatMessage,
    /// The reason the generation finished.
    pub finish_reason: Option<String>,
}

/// The usage statistics for an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// The number of tokens in the completion.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// A trait for LLM providers.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response from the LLM.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    /// Returns the provider as an `Any` type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// A client for interacting with an LLM.
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    /// Creates a new `LLMClient`.
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    /// Returns the type of the LLM provider.
    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

/// A client for interacting with a remote LLM.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    /// Creates a new `RemoteLLMClient`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration of the client.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Suppresses stdout and stderr.
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing (a read-only handle cannot absorb the
    // redirected streams).
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restores stdout and stderr.
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Suppresses stderr.
fn suppress_stderr() -> i32 {
    // Open /dev/null for writing, as in `suppress_output`.
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Restores stderr.
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

/// A provider for a local LLM.
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

impl LocalLLMProvider {
    /// Creates a new `LocalLLMProvider`.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    /// Downloads a model from Hugging Face.
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache - suppress download output in offline mode

        // Use the huggingface-cli tool to download the model (suppress console output)
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress stdout
            .stderr(std::process::Stdio::piped()) // Capture stderr so it can be reported on failure
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    /// Finds a model in the Hugging Face cache.
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }
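
    // For reference, the snapshot layout probed above typically looks like
    // (illustrative; the revision directory name is a commit hash):
    //
    //   ~/.cache/huggingface/hub/models--unsloth--Qwen3-0.6B-GGUF/
    //       snapshots/<revision>/<model_file>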

    /// Formats a list of messages into a single string.
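    ///
    /// The output follows the Qwen/ChatML template. For a system + user exchange
    /// (illustrative message contents), the rendered prompt looks like:
    ///
    /// ```text
    /// <|im_start|>system
    /// You are a helpful assistant.
    /// <|im_end|>
    /// <|im_start|>user
    /// Hello!
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```
    ///
    /// The trailing `<|im_start|>assistant` leaves the prompt open for the
    /// model's reply.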
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    /// Sends a chat request to the remote LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the remote LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true; // Request logits for every prompt token
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens
            let mut generated_text = String::new();
            let max_new_tokens = 512; // Cap on the number of generated tokens
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Get logits from the last decoded position (get_logits returns logits for the last token)
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await
        .map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Task failed: {}", e))
        })?
        .map_err(|e| {
            // Also restore output if the inference closure itself returned an error.
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Restore output after inference completes
        restore_output(stdout_backup, stderr_backup);

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

impl LocalLLMProvider {
    /// Sends a streaming chat request to the local LLM.
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel; stop if receiver is dropped
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    /// Sends a chat request to the LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}