helios_engine/llm.rs

//! # LLM Module
//!
//! This module provides the functionality for interacting with Large Language Models (LLMs).
//! It supports both remote LLMs (like OpenAI) and local LLMs (via `llama.cpp`).
//! The `LLMClient` provides a unified interface for both types of providers.
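//!
//! ## Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compile-tested): it assumes the
//! crate is imported as `helios_engine`, that the module paths below mirror this crate's
//! layout, and that an `LLMConfig` has already been loaded from your configuration.
//!
//! ```ignore
//! use helios_engine::chat::{ChatMessage, Role};
//! use helios_engine::llm::{LLMClient, LLMProviderType};
//!
//! async fn demo(config: helios_engine::config::LLMConfig) -> helios_engine::error::Result<()> {
//!     // Build a client backed by a remote, OpenAI-compatible endpoint.
//!     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//!
//!     let messages = vec![ChatMessage {
//!         role: Role::User,
//!         content: "Hello!".to_string(),
//!         name: None,
//!         tool_calls: None,
//!         tool_call_id: None,
//!     }];
//!
//!     // No tools, default temperature/max_tokens, no stop sequences.
//!     let reply = client.chat(messages, None, None, None, None).await?;
//!     println!("{}", reply.content);
//!     Ok(())
//! }
//! ```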

use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

#[cfg(feature = "candle")]
use crate::candle_provider::CandleLLMProvider;

// Add From trait for LLamaCppError to convert to HeliosError
#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

/// The type of LLM provider to use.
#[derive(Clone)]
pub enum LLMProviderType {
    /// A remote LLM provider, such as OpenAI.
    Remote(LLMConfig),
    /// A local LLM provider, using `llama.cpp`.
    #[cfg(feature = "local")]
    Local(LocalConfig),
    /// A local LLM provider, using Candle.
    #[cfg(feature = "candle")]
    Candle(crate::config::CandleConfig),
}

/// A request to an LLM.
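///
/// Most callers go through [`LLMClient::chat`]; building a request by hand is mainly useful
/// when calling [`LLMProvider::generate`] directly. A sketch (not compile-tested), reusing
/// the message shape from the module-level example:
///
/// ```ignore
/// let request = LLMRequest {
///     model: "gpt-4o-mini".to_string(), // any model name your endpoint accepts
///     messages,                         // Vec<ChatMessage> built by the caller
///     temperature: Some(0.7),
///     max_tokens: Some(256),
///     tools: None,
///     tool_choice: None,
///     stream: None,
///     stop: None,
/// };
/// let response = client.generate(request).await?;
/// ```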
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// The model to use for the request.
    pub model: String,
    /// The messages to send to the model.
    pub messages: Vec<ChatMessage>,
    /// The temperature to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// The tools to make available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// The tool choice to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

/// A chunk of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The ID of the chunk.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the chunk.
    pub model: String,
    /// The choices in the chunk.
    pub choices: Vec<StreamChoice>,
}

/// A choice in a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The index of the choice.
    pub index: u32,
    /// The delta of the choice.
    pub delta: Delta,
    /// The reason the stream finished.
    pub finish_reason: Option<String>,
}

/// A tool call in a streamed delta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaToolCall {
    /// The index of the tool call.
    pub index: u32,
    /// The ID of the tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The function call information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunctionCall>,
}

/// A function call in a streamed delta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaFunctionCall {
    /// The name of the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The arguments for the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

/// The delta of a streamed choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// The role of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The content of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The tool calls in the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

/// A response from an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The ID of the response.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the response.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<Choice>,
    /// The usage statistics for the response.
    pub usage: Usage,
}

/// A choice in an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// The index of the choice.
    pub index: u32,
    /// The message of the choice.
    pub message: ChatMessage,
    /// The reason the generation finished.
    pub finish_reason: Option<String>,
}

/// The usage statistics for an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// The number of tokens in the completion.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// A trait for LLM providers.
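///
/// A back-end only needs to produce an [`LLMResponse`] from an [`LLMRequest`]. A bare-bones
/// sketch of an implementation (not compile-tested; `EchoProvider` is a made-up example type):
///
/// ```ignore
/// struct EchoProvider;
///
/// #[async_trait]
/// impl LLMProvider for EchoProvider {
///     async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
///         // Echo the last message back as a single assistant choice.
///         let content = request
///             .messages
///             .last()
///             .map(|m| m.content.clone())
///             .unwrap_or_default();
///         Ok(LLMResponse {
///             id: "echo".to_string(),
///             object: "chat.completion".to_string(),
///             created: 0,
///             model: request.model,
///             choices: vec![Choice {
///                 index: 0,
///                 message: ChatMessage {
///                     role: crate::chat::Role::Assistant,
///                     content,
///                     name: None,
///                     tool_calls: None,
///                     tool_call_id: None,
///                 },
///                 finish_reason: Some("stop".to_string()),
///             }],
///             usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
///         })
///     }
///
///     fn as_any(&self) -> &dyn std::any::Any {
///         self
///     }
/// }
/// ```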
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response from the LLM.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    /// Returns the provider as an `Any` type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// A client for interacting with an LLM.
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    /// Creates a new `LLMClient`.
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
            #[cfg(feature = "candle")]
            LLMProviderType::Candle(config) => {
                Box::new(CandleLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    /// Returns the type of the LLM provider.
    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

/// A client for interacting with a remote LLM.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    /// Creates a new `RemoteLLMClient`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration of the client.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Suppresses stdout and stderr.
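///
/// Returns the saved file descriptors for the original stdout and stderr; pass them to
/// [`restore_output`] once the noisy section has finished. A typical pairing, as used
/// elsewhere in this module (sketch, not compile-tested):
///
/// ```ignore
/// let (stdout_backup, stderr_backup) = suppress_output();
/// // ... call into llama.cpp, which logs heavily ...
/// restore_output(stdout_backup, stderr_backup);
/// ```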
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing (a read-only descriptor would make the redirected writes fail)
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null for writing");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restores stdout and stderr.
#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Suppresses stderr.
#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null for writing");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Restores stderr.
#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

/// A provider for a local LLM.
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Creates a new `LocalLLMProvider`.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    /// Downloads a model from Hugging Face.
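    ///
    /// This shells out to the `huggingface-cli` binary, so it must be available on `PATH`.
    /// The invocation below is roughly equivalent to what this function runs (the repo and
    /// file names come from the `LocalConfig`):
    ///
    /// ```text
    /// huggingface-cli download <huggingface_repo> <model_file> \
    ///     --local-dir .cache/models --local-dir-use-symlinks False
    /// ```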
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache - suppress download output in offline mode

        // Use huggingface_hub to download the model (suppress output)
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress stdout
            .stderr(std::process::Stdio::piped()) // Capture stderr so failures can be reported
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    /// Finds a model in the Hugging Face cache.
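    ///
    /// Resolves paths of the form used by the Hugging Face hub cache, for example
    /// (illustrative, using the repo name from the comment below):
    ///
    /// ```text
    /// ~/.cache/huggingface/hub/models--unsloth--Qwen3-0.6B-GGUF/snapshots/<revision>/<model_file>
    /// ```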
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }

    /// Formats a list of messages into a single string.
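    ///
    /// Messages are rendered with the Qwen/ChatML chat template; tool messages are folded in
    /// as assistant turns. For a single user message the output looks like:
    ///
    /// ```text
    /// <|im_start|>user
    /// Hello!
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```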
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        // Only add authorization header if not a local vLLM instance
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    /// Sends a chat request to the remote LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the remote LLM.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        // Only add authorization header if not a local vLLM instance
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut tool_calls = Vec::new();
        let mut buffer = String::new();

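        // The body arrives as OpenAI-style server-sent events: one JSON chunk per
        // "data: " line, terminated by "data: [DONE]". Illustrative shape of one line
        // (field values are made up):
        //   data: {"id":"...","object":"chat.completion.chunk","created":0,"model":"...",
        //          "choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}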
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                                if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                    for delta_tool_call in delta_tool_calls {
                                        // Find or create the tool call at this index
                                        while tool_calls.len() <= delta_tool_call.index as usize {
                                            tool_calls.push(None);
                                        }
                                        let tool_call_slot =
                                            &mut tool_calls[delta_tool_call.index as usize];

                                        if tool_call_slot.is_none() {
                                            *tool_call_slot = Some(crate::chat::ToolCall {
                                                id: String::new(),
                                                call_type: "function".to_string(),
                                                function: crate::chat::FunctionCall {
                                                    name: String::new(),
                                                    arguments: String::new(),
                                                },
                                            });
                                        }

                                        if let Some(tool_call) = tool_call_slot.as_mut() {
                                            if let Some(id) = &delta_tool_call.id {
                                                tool_call.id = id.clone();
                                            }
                                            if let Some(function) = &delta_tool_call.function {
                                                if let Some(name) = &function.name {
                                                    tool_call.function.name = name.clone();
                                                }
                                                if let Some(args) = &function.arguments {
                                                    tool_call.function.arguments.push_str(args);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        let final_tool_calls = tool_calls.into_iter().flatten().collect::<Vec<_>>();
        let tool_calls_option = if final_tool_calls.is_empty() {
            None
        } else {
            Some(final_tool_calls)
        };

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: tool_calls_option,
            tool_call_id: None,
        })
    }
}

#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true; // Compute logits for all tokens (they accumulate)
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens
            let mut generated_text = String::new();
            let max_new_tokens = 512; // Increased limit for better responses
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Get logits from the last decoded position (get_logits returns logits for the last token)
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await;

        // Restore output as soon as inference has finished, whether it succeeded or failed
        restore_output(stdout_backup, stderr_backup);

        let result =
            result.map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    /// Sends a streaming chat request to the local LLM.
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel; stop if receiver is dropped
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    /// Sends a chat request to the LLM.
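    ///
    /// A sketch (not compile-tested), assuming `client` was built with [`LLMClient::new`],
    /// `messages` is a `Vec<ChatMessage>`, and `tool_definitions` is a hypothetical
    /// `Vec<ToolDefinition>` the caller has prepared:
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat(
    ///         messages,
    ///         Some(tool_definitions), // advertising tools sets tool_choice to "auto"
    ///         Some(0.2),              // override the configured temperature
    ///         Some(512),              // cap the completion length
    ///         None,                   // no extra stop sequences
    ///     )
    ///     .await?;
    /// println!("{}", reply.content);
    /// ```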
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "candle")]
            LLMProviderType::Candle(config) => (
                config.huggingface_repo.clone(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the LLM.
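    ///
    /// `on_chunk` is called with each piece of content as it arrives; for the Candle
    /// back-end it is currently invoked once with the full response. A sketch (not
    /// compile-tested), printing chunks as they stream in:
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_stream(messages, None, None, None, None, |chunk| {
    ///         print!("{}", chunk); // write partial content to the terminal as it arrives
    ///     })
    ///     .await?;
    /// ```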
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "candle")]
            LLMProviderType::Candle(config) => {
                // For Candle, use non-streaming generate and call on_chunk with full response
                let (model_name, default_temperature, default_max_tokens) = (
                    config.huggingface_repo.clone(),
                    config.temperature,
                    config.max_tokens,
                );

                let request = LLMRequest {
                    model: model_name,
                    messages,
                    temperature: temperature.or(Some(default_temperature)),
                    max_tokens: max_tokens.or(Some(default_max_tokens)),
                    tools: tools.clone(),
                    tool_choice: if tools.is_some() {
                        Some("auto".to_string())
                    } else {
                        None
                    },
                    stream: None,
                    stop,
                };

                // Rebind mutably so the FnMut callback can be invoked below
                let mut on_chunk = on_chunk;
                let response = self.provider.generate(request).await?;
                if let Some(choice) = response.choices.first() {
                    on_chunk(&choice.message.content);
                }
                response
                    .choices
                    .into_iter()
                    .next()
                    .map(|choice| choice.message)
                    .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
            }
        }
    }
}
