use crate::chat::ChatMessage;
use crate::config::{LLMConfig, LocalConfig};
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use tokio::task;

impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

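/// Selects which backend the client talks to: a remote OpenAI-compatible
/// HTTP API or a local GGUF model driven through llama.cpp.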
#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    Local(LocalConfig),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

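/// Common interface implemented by the remote and the local backend.
///
/// `generate` takes a fully built [`LLMRequest`] and returns an
/// OpenAI-style [`LLMResponse`]; callers that only need the first choice
/// can use the higher-level [`LLMClient::chat`] helper instead.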
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
}

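/// Entry point used by the rest of the crate: owns whichever provider was
/// configured and forwards requests to it.
///
/// A minimal usage sketch (the `helios` crate path and the way an
/// `LLMConfig` value is obtained are assumptions; adapt them to the real
/// crate name and config loading):
///
/// ```ignore
/// use helios::config::LLMConfig;
/// use helios::llm::{LLMClient, LLMProviderType};
///
/// async fn demo(config: LLMConfig) -> helios::error::Result<()> {
///     let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
///     let reply = client.chat(vec![/* ChatMessage values */], None).await?;
///     println!("{}", reply.content);
///     Ok(())
/// }
/// ```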
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

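/// Client for an OpenAI-compatible `/chat/completions` endpoint,
/// authenticated with a bearer token taken from [`LLMConfig`].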
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

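// llama.cpp logs loading progress directly to the process's stdout/stderr.
// These two helpers silence that output by duplicating file descriptors 1
// and 2, pointing them at /dev/null, and later restoring the saved copies.
// They rely on `libc` and are therefore Unix-only.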
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing so redirected stdout/stderr writes succeed.
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");

    // Save copies of the current stdout/stderr so they can be restored later.
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    // Point stdout and stderr at /dev/null.
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1);
        libc::dup2(dev_null.as_raw_fd(), 2);
    }

    (stdout_backup, stderr_backup)
}

fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1);
        libc::dup2(stderr_backup, 2);
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

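/// Runs a GGUF model in-process via `llama_cpp_2`.
///
/// The model file is resolved (from the Hugging Face cache or a fresh
/// download) once at construction time and shared behind an [`Arc`] so
/// generation can be moved onto a blocking thread.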
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
}

impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        let (stdout_backup, stderr_backup) = suppress_output();

        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Request GPU offload for up to 99 layers, i.e. effectively the whole model.
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99);
        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)
            .map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
        })
    }

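    /// Resolves the GGUF file on disk: first by searching the local
    /// Hugging Face cache, otherwise by shelling out to
    /// `huggingface-cli download`, which must be installed and on `PATH`.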
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            return Ok(cached_path);
        }

        let output = Command::new("huggingface-cli")
            .args(&[
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        // Default Hugging Face cache location, overridable via HF_HOME.
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home).join(".cache").join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Repos are cached as `models--{owner}--{name}`.
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        // Each snapshot directory may contain (a symlink to) the model file.
        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // The `blobs/` directory stores content-addressed files (named by
        // hash), so the model file cannot be located there by name; only the
        // snapshot lookup above is used.
        None
    }

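    /// Renders the conversation into the ChatML prompt format
    /// (`<|im_start|>` / `<|im_end|>` markers), e.g.:
    ///
    /// ```text
    /// <|im_start|>system
    /// You are a helpful assistant.
    /// <|im_end|>
    /// <|im_start|>user
    /// Hello!
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```
    ///
    /// Tool messages are folded into the assistant role, and the prompt
    /// ends with an open assistant tag so the model continues from there.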
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // Tool results are presented to the model as assistant turns.
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Open an assistant turn for the model to complete.
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

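    /// Streams a chat completion over server-sent events.
    ///
    /// The endpoint is called with `stream: true` and replies with
    /// `data: {json}` lines terminated by `data: [DONE]`. Each parsed
    /// chunk's `delta.content` is forwarded to `on_chunk` as it arrives
    /// and accumulated into the final [`ChatMessage`].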
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: Some(true),
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process every complete line currently buffered.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

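// Generation for the local backend: the prompt is tokenized and decoded in a
// single batch, then tokens are sampled greedily (argmax over the logits) one
// at a time up to a fixed budget of new tokens, on a blocking thread so the
// async runtime is not stalled.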
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);
        let model = Arc::clone(&self.model);

        let (stdout_backup, stderr_backup) = suppress_output();

        let result = task::spawn_blocking(move || {
            let backend = LlamaBackend::init().map_err(|e| {
                HeliosError::LLMError(format!("Failed to initialize backend: {:?}", e))
            })?;

            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            // Tokenize the full ChatML prompt, prepending the BOS token.
            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            // Feed the whole prompt as one batch, requesting logits for each position.
            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Greedy decoding: take the highest-logit token until EOS or the budget runs out.
            let mut generated_text = String::new();
            let max_new_tokens = 128;
            let mut next_pos = tokens.len() as i32;
            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // skip tokens that cannot be rendered as text
                }

                // Decode the newly sampled token so the next iteration has fresh logits.
                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await
        .map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Task failed: {}", e))
        })??;

        restore_output(stdout_backup, stderr_backup);

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                // Token accounting is not tracked for the local backend.
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(response)
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let (model_name, temperature, max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: Some(temperature),
            max_tokens: Some(max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

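    /// Streams a reply, token by token where the backend supports it.
    ///
    /// Remote providers stream true server-sent events; the local provider
    /// falls back to a single blocking generation and delivers the whole
    /// reply to `on_chunk` in one call.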
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(config) => {
                let remote_client = RemoteLLMClient::new(config.clone());
                remote_client.chat_stream(messages, tools, on_chunk).await
            }
            LLMProviderType::Local(_) => {
                // Local inference has no streaming path; emit the full reply at once.
                let response = self.chat(messages, tools).await?;
                on_chunk(&response.content);
                Ok(response)
            }
        }
    }
}