helios_engine/
llm.rs

//! # LLM Module
//!
//! This module provides the functionality for interacting with Large Language Models (LLMs).
//! It supports both remote LLMs (like OpenAI) and local LLMs (via `llama.cpp`).
//! The `LLMClient` provides a unified interface for both types of providers.
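//!
//! ## Example
//!
//! A minimal usage sketch (assuming the crate is named `helios_engine`, the
//! `chat`, `config`, `error`, and `llm` modules are public, and the `LLMConfig`
//! value comes from your own configuration loading):
//!
//! ```ignore
//! use helios_engine::chat::{ChatMessage, Role};
//! use helios_engine::config::LLMConfig;
//! use helios_engine::llm::{LLMClient, LLMProviderType};
//!
//! async fn ask(config: LLMConfig) -> helios_engine::error::Result<String> {
//!     // Build a client backed by a remote, OpenAI-compatible endpoint.
//!     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//!
//!     let messages = vec![ChatMessage {
//!         role: Role::User,
//!         content: "Hello!".to_string(),
//!         name: None,
//!         tool_calls: None,
//!         tool_call_id: None,
//!     }];
//!
//!     // No tools; the first choice's message content is returned.
//!     let reply = client.chat(messages, None).await?;
//!     Ok(reply.content)
//! }
//! ```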

use crate::chat::ChatMessage;
use crate::config::{LLMConfig, LocalConfig};
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use tokio::task;

// Add From trait for LLamaCppError to convert to HeliosError
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

/// The type of LLM provider to use.
#[derive(Clone)]
pub enum LLMProviderType {
    /// A remote LLM provider, such as OpenAI.
    Remote(LLMConfig),
    /// A local LLM provider, using `llama.cpp`.
    Local(LocalConfig),
}

/// A request to an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// The model to use for the request.
    pub model: String,
    /// The messages to send to the model.
    pub messages: Vec<ChatMessage>,
    /// The temperature to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// The tools to make available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// The tool choice to use for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

/// A chunk of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The ID of the chunk.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the chunk.
    pub model: String,
    /// The choices in the chunk.
    pub choices: Vec<StreamChoice>,
}

/// A choice in a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The index of the choice.
    pub index: u32,
    /// The delta of the choice.
    pub delta: Delta,
    /// The reason the stream finished.
    pub finish_reason: Option<String>,
}

/// The delta of a streamed choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// The role of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The content of the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

/// A response from an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// The ID of the response.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model that generated the response.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<Choice>,
    /// The usage statistics for the response.
    pub usage: Usage,
}

/// A choice in an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// The index of the choice.
    pub index: u32,
    /// The message of the choice.
    pub message: ChatMessage,
    /// The reason the generation finished.
    pub finish_reason: Option<String>,
}

/// The usage statistics for an LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// The number of tokens in the completion.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// A trait for LLM providers.
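///
/// Implementations must be `Send + Sync` so they can sit behind the boxed
/// provider inside `LLMClient`. A minimal sketch of a custom provider (e.g.
/// for tests), built only from the types defined in this module; it is
/// illustrative rather than a shipped implementation:
///
/// ```ignore
/// struct MockProvider;
///
/// #[async_trait]
/// impl LLMProvider for MockProvider {
///     fn as_any(&self) -> &dyn std::any::Any {
///         self
///     }
///
///     async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
///         // Return a canned single-choice response echoing the requested model.
///         Ok(LLMResponse {
///             id: "mock-0".to_string(),
///             object: "chat.completion".to_string(),
///             created: 0,
///             model: request.model,
///             choices: vec![Choice {
///                 index: 0,
///                 message: ChatMessage {
///                     role: crate::chat::Role::Assistant,
///                     content: "stub reply".to_string(),
///                     name: None,
///                     tool_calls: None,
///                     tool_call_id: None,
///                 },
///                 finish_reason: Some("stop".to_string()),
///             }],
///             usage: Usage {
///                 prompt_tokens: 0,
///                 completion_tokens: 0,
///                 total_tokens: 0,
///             },
///         })
///     }
/// }
/// ```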
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response from the LLM.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    /// Returns the provider as an `Any` type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// A client for interacting with an LLM.
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    /// Creates a new `LLMClient`.
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    /// Returns the type of the LLM provider.
    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

/// A client for interacting with a remote LLM.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    /// Creates a new `RemoteLLMClient`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration of the client.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Suppresses stdout and stderr.
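///
/// Returns the duplicated original stdout and stderr file descriptors so they
/// can be handed back to `restore_output`. A sketch of the pairing used in
/// this module around noisy `llama.cpp` calls:
///
/// ```ignore
/// let (stdout_backup, stderr_backup) = suppress_output();
/// // ... run code that would otherwise write noise to stdout/stderr ...
/// restore_output(stdout_backup, stderr_backup);
/// ```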
fn suppress_output() -> (i32, i32) {
    // Open /dev/null with write access so redirected writes succeed
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restores stdout and stderr.
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Suppresses stderr.
fn suppress_stderr() -> i32 {
    // Open /dev/null with write access so redirected writes succeed
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Restores stderr.
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

/// A provider for a local LLM.
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

impl LocalLLMProvider {
    /// Creates a new `LocalLLMProvider`.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    /// Downloads a model from Hugging Face.
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache - suppress download output in offline mode

        // Use huggingface_hub to download the model (suppress output)
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress stdout
            .stderr(std::process::Stdio::null()) // Suppress stderr
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    /// Finds a model in the Hugging Face cache.
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }

    /// Formats a list of messages into a single string.
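    ///
    /// Messages are rendered with the Qwen ChatML template used below; for
    /// example, a system + user exchange becomes:
    ///
    /// ```text
    /// <|im_start|>system
    /// You are a helpful assistant.
    /// <|im_end|>
    /// <|im_start|>user
    /// Hello!
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```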
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}
421
422#[async_trait]
423impl LLMProvider for RemoteLLMClient {
424    fn as_any(&self) -> &dyn std::any::Any {
425        self
426    }
427
428    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
429        let url = format!("{}/chat/completions", self.config.base_url);
430
431        let response = self
432            .client
433            .post(&url)
434            .header("Authorization", format!("Bearer {}", self.config.api_key))
435            .header("Content-Type", "application/json")
436            .json(&request)
437            .send()
438            .await?;
439
440        if !response.status().is_success() {
441            let status = response.status();
442            let error_text = response
443                .text()
444                .await
445                .unwrap_or_else(|_| "Unknown error".to_string());
446            return Err(HeliosError::LLMError(format!(
447                "LLM API request failed with status {}: {}",
448                status, error_text
449            )));
450        }
451
452        let llm_response: LLMResponse = response.json().await?;
453        Ok(llm_response)
454    }
455}
456
457impl RemoteLLMClient {
458    /// Sends a chat request to the remote LLM.
459    pub async fn chat(
460        &self,
461        messages: Vec<ChatMessage>,
462        tools: Option<Vec<ToolDefinition>>,
463    ) -> Result<ChatMessage> {
464        let request = LLMRequest {
465            model: self.config.model_name.clone(),
466            messages,
467            temperature: Some(self.config.temperature),
468            max_tokens: Some(self.config.max_tokens),
469            tools,
470            tool_choice: None,
471            stream: None,
472        };
473
474        let response = self.generate(request).await?;
475
476        response
477            .choices
478            .into_iter()
479            .next()
480            .map(|choice| choice.message)
481            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
482    }
483
484    /// Sends a streaming chat request to the remote LLM.
485    pub async fn chat_stream<F>(
486        &self,
487        messages: Vec<ChatMessage>,
488        tools: Option<Vec<ToolDefinition>>,
489        mut on_chunk: F,
490    ) -> Result<ChatMessage>
491    where
492        F: FnMut(&str) + Send,
493    {
494        let request = LLMRequest {
495            model: self.config.model_name.clone(),
496            messages,
497            temperature: Some(self.config.temperature),
498            max_tokens: Some(self.config.max_tokens),
499            tools,
500            tool_choice: None,
501            stream: Some(true),
502        };
503
504        let url = format!("{}/chat/completions", self.config.base_url);
505
506        let response = self
507            .client
508            .post(&url)
509            .header("Authorization", format!("Bearer {}", self.config.api_key))
510            .header("Content-Type", "application/json")
511            .json(&request)
512            .send()
513            .await?;
514
515        if !response.status().is_success() {
516            let status = response.status();
517            let error_text = response
518                .text()
519                .await
520                .unwrap_or_else(|_| "Unknown error".to_string());
521            return Err(HeliosError::LLMError(format!(
522                "LLM API request failed with status {}: {}",
523                status, error_text
524            )));
525        }
526
527        let mut stream = response.bytes_stream();
528        let mut full_content = String::new();
529        let mut role = None;
530        let mut buffer = String::new();
531
532        while let Some(chunk_result) = stream.next().await {
533            let chunk = chunk_result?;
534            let chunk_str = String::from_utf8_lossy(&chunk);
535            buffer.push_str(&chunk_str);
536
537            // Process complete lines
538            while let Some(line_end) = buffer.find('\n') {
539                let line = buffer[..line_end].trim().to_string();
540                buffer = buffer[line_end + 1..].to_string();
541
542                if line.is_empty() || line == "data: [DONE]" {
543                    continue;
544                }
545
546                if let Some(data) = line.strip_prefix("data: ") {
547                    match serde_json::from_str::<StreamChunk>(data) {
548                        Ok(stream_chunk) => {
549                            if let Some(choice) = stream_chunk.choices.first() {
550                                if let Some(r) = &choice.delta.role {
551                                    role = Some(r.clone());
552                                }
553                                if let Some(content) = &choice.delta.content {
554                                    full_content.push_str(content);
555                                    on_chunk(content);
556                                }
557                            }
558                        }
559                        Err(e) => {
560                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
561                        }
562                    }
563                }
564            }
565        }
566
567        Ok(ChatMessage {
568            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
569            content: full_content,
570            name: None,
571            tool_calls: None,
572            tool_call_id: None,
573        })
574    }
575}
576
577#[async_trait]
578impl LLMProvider for LocalLLMProvider {
579    fn as_any(&self) -> &dyn std::any::Any {
580        self
581    }
582
583    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
584        let prompt = self.format_messages(&request.messages);
585
586        // Suppress output during inference in offline mode
587        let (stdout_backup, stderr_backup) = suppress_output();
588
589        // Run inference in a blocking task
590        let model = Arc::clone(&self.model);
591        let backend = Arc::clone(&self.backend);
592        let result = task::spawn_blocking(move || {
593            // Create a fresh context per request (model/back-end are reused across calls)
594            use std::num::NonZeroU32;
595            let ctx_params =
596                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));
597
598            let mut context = model
599                .new_context(&backend, ctx_params)
600                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;
601
602            // Tokenize the prompt
603            let tokens = context
604                .model
605                .str_to_token(&prompt, AddBos::Always)
606                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;
607
608            // Create batch for prompt
609            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
610            for (i, &token) in tokens.iter().enumerate() {
611                let compute_logits = true; // Compute logits for all tokens (they accumulate)
612                prompt_batch
613                    .add(token, i as i32, &[0], compute_logits)
614                    .map_err(|e| {
615                        HeliosError::LLMError(format!(
616                            "Failed to add prompt token to batch: {:?}",
617                            e
618                        ))
619                    })?;
620            }
621
622            // Decode the prompt
623            context
624                .decode(&mut prompt_batch)
625                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;
626
627            // Generate response tokens
628            let mut generated_text = String::new();
629            let max_new_tokens = 512; // Increased limit for better responses
630            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens
631
632            for _ in 0..max_new_tokens {
633                // Get logits from the last decoded position (get_logits returns logits for the last token)
634                let logits = context.get_logits();
635
636                let token_idx = logits
637                    .iter()
638                    .enumerate()
639                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
640                    .map(|(idx, _)| idx)
641                    .unwrap_or_else(|| {
642                        let eos = context.model.token_eos();
643                        eos.0 as usize
644                    });
645                let token = LlamaToken(token_idx as i32);
646
647                // Check for end of sequence
648                if token == context.model.token_eos() {
649                    break;
650                }
651
652                // Convert token back to text
653                match context.model.token_to_str(token, Special::Plaintext) {
654                    Ok(text) => {
655                        generated_text.push_str(&text);
656                    }
657                    Err(_) => continue, // Skip invalid tokens
658                }
659
660                // Create a new batch with just this token
661                let mut gen_batch = LlamaBatch::new(1, 1);
662                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
663                    HeliosError::LLMError(format!(
664                        "Failed to add generated token to batch: {:?}",
665                        e
666                    ))
667                })?;
668
669                // Decode the new token
670                context.decode(&mut gen_batch).map_err(|e| {
671                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
672                })?;
673
674                next_pos += 1;
675            }
676
677            Ok::<String, HeliosError>(generated_text)
678        })
        .await;

        // Restore output after inference completes (and before surfacing any error)
        restore_output(stdout_backup, stderr_backup);

        let result =
            result.map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

impl LocalLLMProvider {
    /// Sends a streaming chat request to the local LLM.
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel; stop if receiver is dropped
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    /// Sends a chat request to the LLM.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let (model_name, temperature, max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: Some(temperature),
            max_tokens: Some(max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request to the LLM.
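    ///
    /// The callback receives each content chunk as it arrives. A minimal usage
    /// sketch (assuming an already-constructed `client: LLMClient` and a
    /// `messages: Vec<ChatMessage>`):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_stream(messages, None, |chunk| print!("{}", chunk))
    ///     .await?;
    /// println!();
    /// println!("full reply: {}", reply.content);
    /// ```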
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider.chat_stream(messages, tools, on_chunk).await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider.chat_stream_local(messages, on_chunk).await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}
