use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

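/// Selects which backend an [`LLMClient`] talks to: an OpenAI-compatible HTTP
/// API, or (with the `local` feature) an in-process llama.cpp model.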
#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    #[cfg(feature = "local")]
    Local(LocalConfig),
}

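/// Chat-completion request body in the OpenAI wire format; optional fields are
/// omitted from the serialized JSON when unset.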
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

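/// One server-sent event from a streaming chat completion; the delta types carry
/// incremental content and partial tool-call fragments.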
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaToolCall {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunctionCall>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaFunctionCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

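/// Non-streaming chat-completion response in the OpenAI wire format.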
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

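/// Common interface implemented by every backend; `as_any` allows downcasting to
/// the concrete provider for backend-specific calls such as streaming.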
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    fn as_any(&self) -> &dyn std::any::Any;
}

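/// Facade over the configured provider; it also keeps the provider type so
/// callers can dispatch to backend-specific code paths.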
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
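    /// Builds a client for the chosen provider. The local variant resolves and
    /// loads the model up front, so construction can take noticeably longer.
    ///
    /// A minimal usage sketch (the config and message values are assumed to be
    /// supplied by the caller and are not shown here):
    ///
    /// ```ignore
    /// let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
    /// let reply = client.chat(messages, None, None, None, None).await?;
    /// println!("{}", reply.content);
    /// ```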
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

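// Helpers to temporarily silence stdout/stderr around native llama.cpp calls,
// which log directly to the process file descriptors. Unix-only: they rely on
// libc::dup/dup2 of fds 1 and 2.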
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    let dev_null = File::open("/dev/null").expect("Failed to open /dev/null");

    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1);
        libc::dup2(dev_null.as_raw_fd(), 2);
    }

    (stdout_backup, stderr_backup)
}

#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1);
        libc::dup2(stderr_backup, 2);
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    let dev_null = File::open("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

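/// Runs a GGUF model in-process through llama.cpp; the model and backend are
/// shared behind `Arc` so generation can move onto blocking worker threads.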
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        let (stdout_backup, stderr_backup) = suppress_output();

        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Offload up to 99 layers to the GPU (effectively all layers for typical models).
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99);

        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

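    /// Resolves the model file, preferring an existing Hugging Face cache entry
    /// and otherwise shelling out to `huggingface-cli download`.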
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            return Ok(cached_path);
        }

        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

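    /// Looks for the file in the Hugging Face hub cache (`$HF_HOME` or
    /// `~/.cache/huggingface`), where repos live under
    /// `hub/models--{owner}--{name}/snapshots/<revision>/`.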
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // Blobs are stored under content hashes; without the snapshot index we
            // cannot map `model_file` to a blob here, so fall through and let the
            // caller download it.
        }

        None
    }

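    /// Renders the conversation as a ChatML-style prompt using
    /// `<|im_start|>`/`<|im_end|>` markers; tool results are emitted as assistant
    /// turns, and a trailing assistant header cues the model to reply.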
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

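        // Only attach the bearer token for non-local endpoints; localhost and
        // private-network URLs (matched loosely by substring) skip it.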
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

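    /// Streams a chat completion, invoking `on_chunk` for each content delta and
    /// assembling any tool calls from their streamed fragments.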
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut tool_calls = Vec::new();
        let mut buffer = String::new();

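        // The body is a server-sent-events stream: buffer the bytes, split on
        // newlines, and parse each `data: {...}` line as a StreamChunk until
        // `data: [DONE]` arrives.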
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                                if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                    for delta_tool_call in delta_tool_calls {
                                        while tool_calls.len() <= delta_tool_call.index as usize {
                                            tool_calls.push(None);
                                        }
                                        let tool_call_slot =
                                            &mut tool_calls[delta_tool_call.index as usize];

                                        if tool_call_slot.is_none() {
                                            *tool_call_slot = Some(crate::chat::ToolCall {
                                                id: String::new(),
                                                call_type: "function".to_string(),
                                                function: crate::chat::FunctionCall {
                                                    name: String::new(),
                                                    arguments: String::new(),
                                                },
                                            });
                                        }

                                        if let Some(tool_call) = tool_call_slot.as_mut() {
                                            if let Some(id) = &delta_tool_call.id {
                                                tool_call.id = id.clone();
                                            }
                                            if let Some(function) = &delta_tool_call.function {
                                                if let Some(name) = &function.name {
                                                    tool_call.function.name = name.clone();
                                                }
                                                if let Some(args) = &function.arguments {
                                                    tool_call.function.arguments.push_str(args);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        let final_tool_calls = tool_calls.into_iter().flatten().collect::<Vec<_>>();
        let tool_calls_option = if final_tool_calls.is_empty() {
            None
        } else {
            Some(final_tool_calls)
        };

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: tool_calls_option,
            tool_call_id: None,
        })
    }
}

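// Local generation runs the blocking llama.cpp calls on a worker thread via
// `spawn_blocking`, using plain greedy (argmax) sampling with a 2048-token
// context and a 512-token generation budget.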
#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        let (stdout_backup, stderr_backup) = suppress_output();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            use std::num::NonZeroU32;

            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

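            // Greedy decoding: take the argmax token, append its text, feed it
            // back into the context, and stop at end-of-sequence or the budget.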
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await
        .map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Task failed: {}", e))
        })??;

        restore_output(stdout_backup, stderr_backup);

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            // Token accounting is not tracked for the local backend.
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(response)
    }
}

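// Streaming variant of local generation: tokens are sent over an unbounded
// channel from the blocking worker so the async caller can surface them as they
// are produced.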
#[cfg(feature = "local")]
impl LocalLLMProvider {
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        let stderr_backup = suppress_stderr();

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            use std::num::NonZeroU32;

            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
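    /// Sends a non-streaming chat request, filling in the model, temperature,
    /// and max_tokens from the provider's configuration when not supplied.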
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

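    /// Streams a reply through the provider-specific streaming path by
    /// downcasting to the concrete provider; `on_chunk` receives each fragment.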
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}