helios_engine/llm.rs

//! # LLM Module
//!
//! This module provides functionality for interacting with Large Language Models (LLMs).
//! It supports both remote, OpenAI-compatible providers and local models run via `llama.cpp`.
//! The `LLMClient` type provides a unified interface for both kinds of provider.
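//!
//! ## Example
//!
//! A minimal sketch of driving a remote provider through [`LLMClient`]. The public
//! module paths and the way an `LLMConfig` is obtained are assumptions here; adapt
//! them to your crate layout and configuration loading.
//!
//! ```ignore
//! use helios_engine::chat::{ChatMessage, Role};
//! use helios_engine::config::LLMConfig;
//! use helios_engine::llm::{LLMClient, LLMProviderType};
//!
//! async fn example(config: LLMConfig) -> helios_engine::error::Result<()> {
//!     // Build a client against the OpenAI-compatible endpoint described by `config`.
//!     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//!
//!     // Send a single user message with no tools and the configured defaults.
//!     let messages = vec![ChatMessage {
//!         role: Role::User,
//!         content: "Hello!".to_string(),
//!         name: None,
//!         tool_calls: None,
//!         tool_call_id: None,
//!     }];
//!     let reply = client.chat(messages, None, None, None, None).await?;
//!     println!("{}", reply.content);
//!     Ok(())
//! }
//! ```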

use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

// Convert llama.cpp errors into `HeliosError` so they can be propagated with `?`.
#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

/// The type of LLM provider to use.
#[derive(Clone)]
pub enum LLMProviderType {
    /// A remote LLM provider, such as OpenAI.
    Remote(LLMConfig),
    /// A local LLM provider, using `llama.cpp`.
    #[cfg(feature = "local")]
    Local(LocalConfig),
}

/// A request to an LLM.
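///
/// Optional fields that are `None` are omitted from the serialized JSON body
/// (see the `skip_serializing_if` attributes on the fields below).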
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// The model to use for the request.
    pub model: String,
    /// The messages to send to the model.
    pub messages: Vec<ChatMessage>,
    /// The temperature to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// The tools to make available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// The tool choice to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

/// A chunk of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The ID of the chunk.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the chunk.
    pub model: String,
    /// The choices in the chunk.
    pub choices: Vec<StreamChoice>,
}

/// A choice in a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The index of the choice.
    pub index: u32,
    /// The delta of the choice.
    pub delta: Delta,
    /// The reason the stream finished.
    pub finish_reason: Option<String>,
}

/// A tool call in a streamed delta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaToolCall {
    /// The index of the tool call.
    pub index: u32,
    /// The ID of the tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The function call information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunctionCall>,
}

/// A function call in a streamed delta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaFunctionCall {
    /// The name of the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The arguments for the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

/// The delta of a streamed choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// The role of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The content of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The tool calls in the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

/// A response from an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The ID of the response.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the response.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<Choice>,
    /// The usage statistics for the response.
    pub usage: Usage,
}

/// A choice in an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// The index of the choice.
    pub index: u32,
    /// The message of the choice.
    pub message: ChatMessage,
    /// The reason the generation finished.
    pub finish_reason: Option<String>,
}

/// The usage statistics for an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// The number of tokens in the completion.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// A trait for LLM providers.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response from the LLM.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    /// Returns the provider as an `Any` type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// A client for interacting with an LLM.
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    /// Creates a new `LLMClient`.
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    /// Returns the type of the LLM provider.
    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

/// A client for interacting with a remote LLM.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    /// Creates a new `RemoteLLMClient`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration of the client.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Suppresses stdout and stderr, returning backup descriptors for [`restore_output`].
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing (`File::open` would open it read-only, making
    // writes to the redirected descriptors fail).
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restores stdout and stderr.
#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Suppresses stderr, returning a backup descriptor for [`restore_stderr`].
#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    // As above, /dev/null must be opened for writing.
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Restores stderr.
#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

/// A provider for a local LLM.
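///
/// The model file is resolved from the local Hugging Face cache (or downloaded via
/// `huggingface-cli`), loaded once with `llama.cpp`, and shared across requests; each
/// request creates a fresh inference context and decodes greedily.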
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Creates a new `LocalLLMProvider`.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    /// Downloads a model from Hugging Face.
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache - suppress download output in offline mode

        // Use huggingface_hub to download the model (suppress output)
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress stdout
            .stderr(std::process::Stdio::null()) // Suppress stderr
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    /// Finds a model in the Hugging Face cache.
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }

    /// Formats a list of messages into a single string.
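    ///
    /// Messages are rendered with the Qwen chat template: each turn becomes
    /// `<|im_start|>{role}\n{content}\n<|im_end|>\n`, and the prompt ends with an
    /// opened `<|im_start|>assistant\n` block for the model to complete.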
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    /// Sends a chat request to the remote LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the remote LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut tool_calls = Vec::new();
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                                if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                    for delta_tool_call in delta_tool_calls {
                                        // Find or create the tool call at this index
                                        while tool_calls.len() <= delta_tool_call.index as usize {
                                            tool_calls.push(None);
                                        }
                                        let tool_call_slot =
                                            &mut tool_calls[delta_tool_call.index as usize];

                                        if tool_call_slot.is_none() {
                                            *tool_call_slot = Some(crate::chat::ToolCall {
                                                id: String::new(),
                                                call_type: "function".to_string(),
                                                function: crate::chat::FunctionCall {
                                                    name: String::new(),
                                                    arguments: String::new(),
                                                },
                                            });
                                        }

                                        if let Some(tool_call) = tool_call_slot.as_mut() {
                                            if let Some(id) = &delta_tool_call.id {
                                                tool_call.id = id.clone();
                                            }
                                            if let Some(function) = &delta_tool_call.function {
                                                if let Some(name) = &function.name {
                                                    tool_call.function.name = name.clone();
                                                }
                                                if let Some(args) = &function.arguments {
                                                    // Argument fragments are streamed across
                                                    // chunks; append rather than overwrite.
                                                    tool_call.function.arguments.push_str(args);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        let final_tool_calls = tool_calls.into_iter().flatten().collect::<Vec<_>>();
        let tool_calls_option = if final_tool_calls.is_empty() {
            None
        } else {
            Some(final_tool_calls)
        };

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: tool_calls_option,
            tool_call_id: None,
        })
    }
}

#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true; // Compute logits for all tokens (they accumulate)
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens
            let mut generated_text = String::new();
            let max_new_tokens = 512; // Increased limit for better responses
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Get logits from the last decoded position (get_logits returns logits for the last token)
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await;

        // Restore output before handling errors so stdout/stderr are not left
        // redirected if the blocking task panicked or returned an error.
        restore_output(stdout_backup, stderr_backup);

        let result = result
            .map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Sends a streaming chat request to the local LLM.
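    ///
    /// Generation runs in a blocking task and streams each decoded token back over an
    /// unbounded channel, so `on_chunk` is invoked as tokens are produced.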
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel; stop if receiver is dropped
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    /// Sends a chat request to the LLM.
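    ///
    /// A minimal non-streaming example (placeholder arguments):
    ///
    /// ```ignore
    /// let reply = client.chat(messages, None, None, None, None).await?;
    /// println!("{}", reply.content);
    /// ```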
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the LLM.
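    ///
    /// The `on_chunk` callback receives each streamed text fragment as it arrives.
    /// A minimal sketch (placeholder arguments):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_stream(messages, None, None, None, None, |chunk| print!("{}", chunk))
    ///     .await?;
    /// ```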
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}
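// A minimal sketch of a test module, assuming only the types defined in this file.
// It checks that optional `LLMRequest` fields are dropped from the serialized JSON
// body via the `skip_serializing_if` attributes above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn llm_request_omits_unset_optional_fields() {
        let request = LLMRequest {
            model: "test-model".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            tools: None,
            tool_choice: None,
            stream: None,
            stop: None,
        };

        let json = serde_json::to_string(&request).expect("LLMRequest should serialize");
        assert!(json.contains("\"model\""));
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("tool_choice"));
        assert!(!json.contains("stream"));
        assert!(!json.contains("stop"));
    }
}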