helios_engine/
llm.rs

//! # LLM Module
//!
//! This module provides functionality for interacting with Large Language Models (LLMs).
//! It supports both remote, OpenAI-compatible APIs (such as OpenAI itself) and local
//! models run via `llama.cpp`. The [`LLMClient`] type provides a unified interface over
//! both kinds of provider.
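//!
//! A minimal usage sketch is shown below. It is marked `ignore` because the exact
//! crate paths and the way an `LLMConfig` is built depend on the rest of the crate
//! and are assumed here:
//!
//! ```ignore
//! use helios_engine::chat::{ChatMessage, Role};
//! use helios_engine::config::LLMConfig;
//! use helios_engine::llm::{LLMClient, LLMProviderType};
//!
//! async fn ask(config: LLMConfig) -> helios_engine::error::Result<String> {
//!     // Build a client backed by a remote, OpenAI-compatible endpoint.
//!     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//!
//!     // Send a single user message with no tools; temperature and max_tokens
//!     // fall back to the values in the config.
//!     let reply = client
//!         .chat(
//!             vec![ChatMessage {
//!                 role: Role::User,
//!                 content: "Hello!".to_string(),
//!                 name: None,
//!                 tool_calls: None,
//!                 tool_call_id: None,
//!             }],
//!             None, // tools
//!             None, // temperature
//!             None, // max_tokens
//!             None, // stop sequences
//!         )
//!         .await?;
//!     Ok(reply.content)
//! }
//! ```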

use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

// Add From trait for LLamaCppError to convert to HeliosError
#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

/// The type of LLM provider to use.
#[derive(Clone)]
pub enum LLMProviderType {
    /// A remote LLM provider, such as OpenAI.
    Remote(LLMConfig),
    /// A local LLM provider, using `llama.cpp`.
    #[cfg(feature = "local")]
    Local(LocalConfig),
}

/// A request to an LLM.
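///
/// A construction sketch with illustrative values; the struct is serialized as the
/// JSON body of an OpenAI-style `/chat/completions` request:
///
/// ```ignore
/// let request = LLMRequest {
///     model: "gpt-4o-mini".to_string(), // illustrative model name
///     messages: vec![],                 // normally filled with ChatMessage values
///     temperature: Some(0.7),
///     max_tokens: Some(256),
///     tools: None,
///     tool_choice: None,                // set to "auto" by the client when tools are present
///     stream: None,
///     stop: None,
/// };
/// ```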
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// The model to use for the request.
    pub model: String,
    /// The messages to send to the model.
    pub messages: Vec<ChatMessage>,
    /// The temperature to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// The tools to make available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// The tool choice to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

/// A chunk of a streamed response.
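///
/// Chunks are parsed from the `data:` lines of a server-sent-events response.
/// An illustrative (made-up) payload:
///
/// ```text
/// data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}]}
/// ```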
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The ID of the chunk.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the chunk.
    pub model: String,
    /// The choices in the chunk.
    pub choices: Vec<StreamChoice>,
}

/// A choice in a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The index of the choice.
    pub index: u32,
    /// The delta of the choice.
    pub delta: Delta,
    /// The reason the stream finished.
    pub finish_reason: Option<String>,
}

/// A tool call in a streamed delta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaToolCall {
    /// The index of the tool call.
    pub index: u32,
    /// The ID of the tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The function call information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunctionCall>,
}

/// A function call in a streamed delta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaFunctionCall {
    /// The name of the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The arguments for the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

/// The delta of a streamed choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// The role of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The content of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The tool calls in the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

/// A response from an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The ID of the response.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the response.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<Choice>,
    /// The usage statistics for the response.
    pub usage: Usage,
}

/// A choice in an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// The index of the choice.
    pub index: u32,
    /// The message of the choice.
    pub message: ChatMessage,
    /// The reason the generation finished.
    pub finish_reason: Option<String>,
}

/// The usage statistics for an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// The number of tokens in the completion.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// A trait for LLM providers.
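///
/// An illustrative implementation skeleton for a custom provider (the type name and
/// error message are placeholders, and the example is not compiled):
///
/// ```ignore
/// use async_trait::async_trait;
///
/// struct EchoProvider;
///
/// #[async_trait]
/// impl LLMProvider for EchoProvider {
///     async fn generate(&self, _request: LLMRequest) -> Result<LLMResponse> {
///         // A real provider would call a model here.
///         Err(HeliosError::LLMError("not implemented".to_string()))
///     }
///
///     fn as_any(&self) -> &dyn std::any::Any {
///         self
///     }
/// }
/// ```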
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response from the LLM.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    /// Returns the provider as `&dyn Any`, so callers can downcast to the concrete provider type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// A client for interacting with an LLM.
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    /// Creates a new `LLMClient`.
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    /// Returns the type of the LLM provider.
    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

/// A client for interacting with a remote LLM.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    /// Creates a new `RemoteLLMClient`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration of the client.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Suppresses stdout and stderr.
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing (a read-only handle would make the redirected writes fail)
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restores stdout and stderr.
#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Suppresses stderr.
#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Restores stderr.
#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

/// A provider for a local LLM.
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Creates a new `LocalLLMProvider`.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    /// Downloads a model from Hugging Face.
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache - suppress download output in offline mode

        // Use huggingface_hub to download the model (suppress output)
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Discard stdout
            // stderr is captured by `.output()` so a failed download can be reported below
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    /// Finds a model in the Hugging Face cache.
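    ///
    /// The lookup mirrors the standard `huggingface_hub` cache layout, e.g. for the
    /// repo `unsloth/Qwen3-0.6B-GGUF` (illustrative name taken from the comment below):
    ///
    /// ```text
    /// $HF_HOME/hub/models--unsloth--Qwen3-0.6B-GGUF/snapshots/<revision>/<model_file>
    /// ```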
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }

    /// Formats a list of messages into a single string.
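    ///
    /// The output follows the Qwen/ChatML template; with illustrative system and user
    /// messages it looks like this (a trailing assistant header is always appended):
    ///
    /// ```text
    /// <|im_start|>system
    /// You are a helpful assistant.
    /// <|im_end|>
    /// <|im_start|>user
    /// Hello!
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```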
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        // Only add authorization header if not a local vLLM instance
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    /// Sends a chat request to the remote LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the remote LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        // Only add authorization header if not a local vLLM instance
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut tool_calls = Vec::new();
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                                if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                    for delta_tool_call in delta_tool_calls {
                                        // Find or create the tool call at this index
                                        while tool_calls.len() <= delta_tool_call.index as usize {
                                            tool_calls.push(None);
                                        }
                                        let tool_call_slot =
                                            &mut tool_calls[delta_tool_call.index as usize];

                                        if tool_call_slot.is_none() {
                                            *tool_call_slot = Some(crate::chat::ToolCall {
                                                id: String::new(),
                                                call_type: "function".to_string(),
                                                function: crate::chat::FunctionCall {
                                                    name: String::new(),
                                                    arguments: String::new(),
                                                },
                                            });
                                        }

                                        if let Some(tool_call) = tool_call_slot.as_mut() {
                                            if let Some(id) = &delta_tool_call.id {
                                                tool_call.id = id.clone();
                                            }
                                            if let Some(function) = &delta_tool_call.function {
                                                if let Some(name) = &function.name {
                                                    tool_call.function.name = name.clone();
                                                }
                                                if let Some(args) = &function.arguments {
                                                    tool_call.function.arguments.push_str(args);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        let final_tool_calls = tool_calls.into_iter().flatten().collect::<Vec<_>>();
        let tool_calls_option = if final_tool_calls.is_empty() {
            None
        } else {
            Some(final_tool_calls)
        };

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: tool_calls_option,
            tool_call_id: None,
        })
    }
}

#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true; // Request logits for every prompt position; only the last position's logits are used below
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens
            let mut generated_text = String::new();
            let max_new_tokens = 512; // Increased limit for better responses
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Get logits from the last decoded position (get_logits returns logits for the last token)
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await
        .map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Task failed: {}", e))
        })??;

        // Restore output after inference completes
        restore_output(stdout_backup, stderr_backup);

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Sends a streaming chat request to the local LLM.
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel; stop if receiver is dropped
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    /// Sends a chat request to the LLM.
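    ///
    /// A call sketch (client and message construction as in the module-level example;
    /// the temperature and token limit shown are arbitrary):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat(messages, None, Some(0.2), Some(512), None)
    ///     .await?;
    /// println!("{}", reply.content);
    /// ```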
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the LLM.
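    ///
    /// The `on_chunk` callback is invoked with each content fragment as it arrives, and
    /// the fully assembled message is returned at the end. A call sketch (setup assumed
    /// as in the module-level example):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_stream(messages, None, None, None, None, |chunk| {
    ///         print!("{}", chunk); // stream fragments to stdout as they arrive
    ///     })
    ///     .await?;
    /// ```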
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}
