helios_engine/llm.rs

use crate::chat::ChatMessage;
use crate::config::{LLMConfig, LocalConfig};
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use tokio::task;

// Convert llama.cpp errors into HeliosError so they propagate through `Result`.
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    Local(LocalConfig),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
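
// Illustrative shape of a single streaming line parsed into `StreamChunk` by
// `chat_stream` below (OpenAI-style SSE; field values abbreviated, not from a real run):
//
//     data: {"id":"...","object":"chat.completion.chunk","created":0,"model":"...",
//            "choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}
//
// The stream terminates with a final `data: [DONE]` line.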

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
}

pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}
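
// Minimal usage sketch (illustrative only; assumes an `LLMConfig` value and chat
// `messages` constructed elsewhere in the crate, plus a Tokio runtime):
//
//     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
//     let reply = client.chat(messages, None).await?;
//     println!("{}", reply.content);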

/// HTTP client for OpenAI-compatible chat completion endpoints.
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

/// Temporarily redirect stdout and stderr to /dev/null so llama.cpp's verbose
/// logging does not clutter the terminal during model loading and inference.
/// Returns the duplicated original descriptors so they can be restored later.
/// (Unix-only: relies on raw file descriptors and libc's `dup`/`dup2`.)
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing (`File::open` would be read-only, which would
    // make writes through the redirected descriptors fail)
    let dev_null = File::create("/dev/null").expect("Failed to open /dev/null");

    // Duplicate the current stdout and stderr file descriptors so they can be restored
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Redirect stdout and stderr to /dev/null
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1); // stdout
        libc::dup2(dev_null.as_raw_fd(), 2); // stderr
    }

    (stdout_backup, stderr_backup)
}

/// Restore stdout and stderr from the descriptors saved by `suppress_output`.
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1); // restore stdout
        libc::dup2(stderr_backup, 2); // restore stderr
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
}

impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Suppress verbose output during model loading in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Initialize llama backend
        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        // Download model from HuggingFace if needed
        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Load the model
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99); // Use GPU if available

        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)
            .map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        // Restore output
        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
        })
    }

    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        // Check if the model is already in the HuggingFace cache
        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            // Model found in cache - no output needed in offline mode
            return Ok(cached_path);
        }

        // Model not found in cache: download it with huggingface-cli, keeping the
        // download quiet but capturing stderr so failures can be reported below.
        let output = Command::new("huggingface-cli")
            .args(&[
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null()) // Suppress progress output
            .stderr(std::process::Stdio::piped()) // Capture stderr for error reporting
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Locate the HuggingFace cache directory
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home).join(".cache").join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Convert repo name to HuggingFace cache format,
        // e.g., "unsloth/Qwen3-0.6B-GGUF" -> "models--unsloth--Qwen3-0.6B-GGUF"
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Check in the snapshots directory (newer cache format)
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Check in the blobs directory (alternative cache format)
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // Blobs are addressed by content hash, which is more complex to resolve,
            // so this check is skipped for now; the snapshots path above covers most cases.
        }

        None
    }
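
    // For reference, a cached model found by the lookup above typically resolves to a
    // path of this form (illustrative, following the HuggingFace hub cache layout):
    //
    //     ~/.cache/huggingface/hub/models--unsloth--Qwen3-0.6B-GGUF/snapshots/<revision>/<model_file>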

    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        // Use the Qwen (ChatML-style) chat template format
        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // For tool messages, include them as assistant responses
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Start the assistant's response
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}
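
// For a system + user exchange, `format_messages` produces a ChatML-style prompt
// like the following (example content only):
//
//     <|im_start|>system
//     You are a helpful assistant.
//     <|im_end|>
//     <|im_start|>user
//     Hello!
//     <|im_end|>
//     <|im_start|>assistant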

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: Some(true),
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}
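
// Streaming usage sketch (illustrative; assumes a configured `RemoteLLMClient`
// and `messages` built elsewhere):
//
//     let reply = client
//         .chat_stream(messages, None, |chunk| print!("{}", chunk))
//         .await?;
//     println!("\nfull reply: {}", reply.content);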

#[async_trait]
impl LLMProvider for LocalLLMProvider {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);
        let model = Arc::clone(&self.model);

        // Suppress output during inference in offline mode
        let (stdout_backup, stderr_backup) = suppress_output();

        // Run inference in a blocking task
        let result = task::spawn_blocking(move || {
            // Initialize backend
            let backend = LlamaBackend::init().map_err(|e| {
                HeliosError::LLMError(format!("Failed to initialize backend: {:?}", e))
            })?;

            // Create context
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the prompt
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Create a batch for the prompt. Only the last prompt token needs logits,
            // since that is the position the first new token is sampled from.
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            let last_index = tokens.len() - 1;
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = i == last_index;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            // Decode the prompt
            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Generate response tokens with greedy decoding
            let mut generated_text = String::new();
            let max_new_tokens = 128; // Limit response length
            let mut next_pos = tokens.len() as i32; // Start after the prompt tokens

            for _ in 0..max_new_tokens {
                // Pick the highest-probability token from the logits of the last decoded position
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.total_cmp(b.1))
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                // Check for end of sequence
                if token == context.model.token_eos() {
                    break;
                }

                // Convert token back to text
                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // Skip invalid tokens
                }

                // Create a new batch with just this token
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                // Decode the new token
                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await;

        // Restore output once inference has finished, successfully or not, so stdout
        // and stderr are not left redirected to /dev/null.
        restore_output(stdout_backup, stderr_backup);

        let result = result
            .map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,     // TODO: Calculate actual token count
                completion_tokens: 0, // TODO: Calculate actual token count
                total_tokens: 0,      // TODO: Calculate actual token count
            },
        };

        Ok(response)
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let (model_name, temperature, max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: Some(temperature),
            max_tokens: Some(max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        // For local models, streaming is not yet implemented, so fall back to regular chat
        match &self.provider_type {
            LLMProviderType::Remote(config) => {
                let remote_client = RemoteLLMClient::new(config.clone());
                remote_client.chat_stream(messages, tools, on_chunk).await
            }
            LLMProviderType::Local(_) => {
                // For now, local models don't support streaming, so we call the callback
                // once with the full response content to maintain compatibility
                let response = self.chat(messages, tools).await?;
                on_chunk(&response.content);
                Ok(response)
            }
        }
    }
}
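
// Test sketch (illustrative assumptions: `serde_json` is already a dependency of this
// crate, and no repo named "helios-tests/nonexistent-repo" exists in the local
// HuggingFace cache). These are sanity checks, not a full test suite.
#[cfg(test)]
mod tests {
    use super::*;

    // Optional request fields marked with `skip_serializing_if` should be omitted
    // from the serialized JSON body.
    #[test]
    fn llm_request_omits_unset_optional_fields() {
        let request = LLMRequest {
            model: "test-model".to_string(),
            messages: Vec::new(),
            temperature: None,
            max_tokens: None,
            tools: None,
            tool_choice: None,
            stream: None,
        };
        let json = serde_json::to_string(&request).expect("request should serialize");
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("tools"));
        assert!(!json.contains("stream"));
    }

    // A repository that is not present in the local HuggingFace cache should not be found.
    #[test]
    fn missing_repo_is_not_found_in_cache() {
        let found =
            LocalLLMProvider::find_model_in_cache("helios-tests/nonexistent-repo", "model.gguf");
        assert!(found.is_none());
    }
}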