helios_engine/llm.rs

use crate::chat::ChatMessage;
use crate::config::{LLMConfig, LocalConfig};
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use tokio::task;

// Convert llama.cpp errors into HeliosError.
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    Local(LocalConfig),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    fn as_any(&self) -> &dyn std::any::Any;
}

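/// Client wrapper that dispatches to either a remote (HTTP) or a local (llama.cpp)
/// provider, selected by `LLMProviderType`.
///
/// Illustrative usage (a sketch only; assumes `llm`, `config`, and `error` are public
/// modules of this crate and that `config` describes a reachable remote endpoint):
///
/// ```ignore
/// use helios_engine::config::LLMConfig;
/// use helios_engine::llm::{LLMClient, LLMProviderType};
///
/// async fn example(config: LLMConfig) -> helios_engine::error::Result<()> {
///     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
///     // Build a Vec<ChatMessage> for the conversation here.
///     let reply = client.chat(vec![], None).await?;
///     println!("{}", reply.content);
///     Ok(())
/// }
/// ```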
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

// HTTP client for OpenAI-compatible chat completion endpoints (formerly `LLMClient`).
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Helper function to suppress stdout and stderr during model loading
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing so redirected writes are discarded
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");

    // Duplicate current stdout and stderr file descriptors
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Helper function to restore stdout and stderr
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

/// Helper function to suppress only stderr (used to hide llama.cpp context logs while preserving stdout streaming)
fn suppress_stderr() -> i32 {
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

/// Helper function to restore only stderr
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)
            .map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if model is already in HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache: download it with huggingface-cli, suppressing its output
        let output = Command::new("huggingface-cli")
            .args(&[
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress stdout
            .stderr(std::process::Stdio::null()) // Suppress stderr
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Check HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home).join(".cache").join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // For blobs, we need to find the blob file by hash
            // This is more complex, so for now we'll skip this check
            // The snapshots approach should cover most cases
        }

        None
    }

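    /// Renders the conversation with the ChatML-style template used by Qwen GGUF models,
    /// ending with an open assistant turn for the model to complete. For a single user
    /// message "Hello", the rendered prompt looks like:
    ///
    /// ```text
    /// <|im_start|>user
    /// Hello
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```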
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use Qwen chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: Some(true),
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true; // Request logits for every prompt token
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens
            let mut generated_text = String::new();
            let max_new_tokens = 512; // Cap on generated tokens per response
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Get logits from the last decoded position (get_logits returns logits for the last token)
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    },
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await
        .map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Task failed: {}", e))
        })?
        .map_err(|e| {
            // Also restore output before propagating inference errors
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Restore output after inference completes
        restore_output(stdout_backup, stderr_backup);

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

impl LocalLLMProvider {
    /// Streams tokens from the local model as they are generated.
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        // Suppress only stderr so llama.cpp context logs are hidden but stdout streaming remains visible
        let stderr_backup = suppress_stderr();

        // Create a channel for streaming tokens
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        // Spawn blocking task for generation
        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            // Create a fresh context per request (model/back-end are reused across calls)
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create batch for prompt
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with streaming
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Send token through channel
                        let _ = tx.send(text);
                    },
                    Err(_) => continue,
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Receive and process tokens as they arrive
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        // Wait for generation to complete and get the result
        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        // Restore stderr after generation completes
        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let (model_name, temperature, max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: Some(temperature),
            max_tokens: Some(max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

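    /// Streams the assistant reply, invoking `on_chunk` with each text fragment as it
    /// arrives, and returns the fully assembled message. `tools` are forwarded only to
    /// remote providers; the local streaming path ignores them.
    ///
    /// Illustrative call (a sketch; assumes an existing `client` and `messages`):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_stream(messages, None, |chunk| print!("{}", chunk))
    ///     .await?;
    /// println!();
    /// ```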
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider.chat_stream(messages, tools, on_chunk).await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider.chat_stream_local(messages, on_chunk).await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}

// Test module
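// A minimal sketch exercising only the serde behavior of the request/stream types
// defined above; it needs no network access and loads no model.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn llm_request_omits_unset_optional_fields() {
        let request = LLMRequest {
            model: "test-model".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            tools: None,
            tool_choice: None,
            stream: None,
        };
        let json = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(json["model"], "test-model");
        // `skip_serializing_if` should drop the unset optional fields entirely.
        assert!(json.get("temperature").is_none());
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("stream").is_none());
    }

    #[test]
    fn stream_chunk_parses_content_delta() {
        // Mirrors the shape of the `data: ` payloads handled in `chat_stream`.
        let data = r#"{"id":"1","object":"chat.completion.chunk","created":0,"model":"m","choices":[{"index":0,"delta":{"role":null,"content":"hi"},"finish_reason":null}]}"#;
        let chunk: StreamChunk = serde_json::from_str(data).expect("chunk should parse");
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("hi"));
        assert!(chunk.choices[0].finish_reason.is_none());
    }
}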