use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

#[cfg(feature = "candle")]
use crate::candle_provider::CandleLLMProvider;

#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

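/// Selects the backend an [`LLMClient`] talks to: a remote OpenAI-compatible
/// HTTP API, or an in-process model when the matching feature is enabled.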
#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    #[cfg(feature = "local")]
    Local(LocalConfig),
    #[cfg(feature = "candle")]
    Candle(crate::config::CandleConfig),
}

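/// Request body for an OpenAI-compatible chat-completions call; `None` fields
/// are omitted from the serialized JSON.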
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

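/// One parsed `data:` payload from a streaming (SSE) chat-completions response.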
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaToolCall {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunctionCall>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaFunctionCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

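/// Full (non-streaming) chat-completions response.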
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

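/// Common interface for all backends; `as_any` lets callers downcast to a
/// concrete provider for backend-specific calls.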
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    fn as_any(&self) -> &dyn std::any::Any;
}

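/// Client facade that owns the configured provider.
///
/// A minimal usage sketch; `config` is an assumed `LLMConfig` value and is not
/// defined in this module:
///
/// ```ignore
/// let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
/// let user_message = ChatMessage {
///     role: crate::chat::Role::User,
///     content: "Hello!".to_string(),
///     name: None,
///     tool_calls: None,
///     tool_call_id: None,
/// };
/// let reply = client.chat(vec![user_message], None, None, None, None).await?;
/// println!("{}", reply.content);
/// ```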
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
            #[cfg(feature = "candle")]
            LLMProviderType::Candle(config) => {
                Box::new(CandleLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

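/// Provider backed by a remote OpenAI-compatible HTTP API.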
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

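/// Redirects stdout and stderr to /dev/null, returning backups of the original
/// descriptors for a later `restore_output` call.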
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing so the standard streams can be redirected into it.
    let dev_null = std::fs::OpenOptions::new()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");

    // Duplicate the original descriptors so they can be restored later.
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1);
        libc::dup2(dev_null.as_raw_fd(), 2);
    }

    (stdout_backup, stderr_backup)
}

#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1);
        libc::dup2(stderr_backup, 2);
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    let dev_null = std::fs::OpenOptions::new()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

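/// Provider that runs a GGUF model in-process through llama.cpp.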
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Silence llama.cpp's startup logging.
        let (stdout_backup, stderr_backup) = suppress_output();

        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Offload as many layers as possible to the GPU; llama.cpp caps the
        // value at the number of layers the model actually has.
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99);
        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

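    /// Resolves the model file, preferring the local Hugging Face cache and
    /// otherwise downloading it with `huggingface-cli` into `.cache/models`.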
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            return Ok(cached_path);
        }

        // stdout is discarded; stderr is captured by `output()` so failures
        // can be reported below.
        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null())
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

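    /// Looks for the model file in the Hugging Face hub cache
    /// (`$HF_HOME` or `~/.cache/huggingface`).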
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Hub cache layout: hub/models--{org}--{name}/snapshots/{revision}/{file}
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    // Snapshot entries are usually symlinks into blobs/; canonicalize resolves them.
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        // Blobs are content-addressed, so the requested file name cannot be
        // mapped to a blob without the snapshot metadata above.
        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // Not handled; fall through and report a cache miss.
        }

        None
    }

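    /// Renders the conversation with ChatML-style `<|im_start|>`/`<|im_end|>`
    /// markers, leaving the prompt open on an assistant turn.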
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // Tool results are folded into an assistant turn for this template.
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Leave the prompt open so the model continues as the assistant.
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        // Skip the Authorization header for local or LAN endpoints.
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
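    /// Sends a non-streaming chat request and returns the first choice's message.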
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

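    /// Sends a streaming chat request, calling `on_chunk` for every content
    /// delta while accumulating the full message and any tool calls.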
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let mut request_builder = self
            .client
            .post(&url)
            .header("Content-Type", "application/json");

        // Skip the Authorization header for local or LAN endpoints.
        if !self.config.base_url.contains("10.")
            && !self.config.base_url.contains("localhost")
            && !self.config.base_url.contains("127.0.0.1")
        {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
        }

        let response = request_builder.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut tool_calls = Vec::new();
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete SSE lines; partial lines stay in the buffer.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                                if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                    for delta_tool_call in delta_tool_calls {
                                        // Grow the list so the delta's index is addressable.
                                        while tool_calls.len() <= delta_tool_call.index as usize {
                                            tool_calls.push(None);
                                        }
                                        let tool_call_slot =
                                            &mut tool_calls[delta_tool_call.index as usize];

                                        if tool_call_slot.is_none() {
                                            *tool_call_slot = Some(crate::chat::ToolCall {
                                                id: String::new(),
                                                call_type: "function".to_string(),
                                                function: crate::chat::FunctionCall {
                                                    name: String::new(),
                                                    arguments: String::new(),
                                                },
                                            });
                                        }

                                        if let Some(tool_call) = tool_call_slot.as_mut() {
                                            if let Some(id) = &delta_tool_call.id {
                                                tool_call.id = id.clone();
                                            }
                                            if let Some(function) = &delta_tool_call.function {
                                                if let Some(name) = &function.name {
                                                    tool_call.function.name = name.clone();
                                                }
                                                if let Some(args) = &function.arguments {
                                                    tool_call.function.arguments.push_str(args);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        let final_tool_calls = tool_calls.into_iter().flatten().collect::<Vec<_>>();
        let tool_calls_option = if final_tool_calls.is_empty() {
            None
        } else {
            Some(final_tool_calls)
        };

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: tool_calls_option,
            tool_call_id: None,
        })
    }
}

#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        // Silence llama.cpp logging for the duration of the call.
        let (stdout_backup, stderr_backup) = suppress_output();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            use std::num::NonZeroU32;

            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                // Only the last token's logits are needed, but requesting them
                // for every token keeps the loop simple.
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            // Greedy decoding: pick the highest-logit token until EOS or the cap.
            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    // Skip tokens that do not decode to valid text.
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await
        .map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Task failed: {}", e))
        })??;

        restore_output(stdout_backup, stderr_backup);

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                // Token accounting is not tracked for the local backend.
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(response)
    }
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
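    /// Runs generation on a blocking thread and forwards tokens through a
    /// channel so `on_chunk` can be invoked on the async side.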
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        let stderr_backup = suppress_stderr();

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            use std::num::NonZeroU32;

            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        // Stop generating if the receiver has been dropped.
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Forward tokens to the caller as they arrive.
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
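    /// Builds a request from the active provider's defaults and returns the
    /// first choice of the response.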
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "candle")]
            LLMProviderType::Candle(config) => (
                config.huggingface_repo.clone(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools: tools.clone(),
            tool_choice: if tools.is_some() {
                Some("auto".to_string())
            } else {
                None
            },
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

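    /// Streams when the backend supports it; the Candle backend generates the
    /// full response and delivers it as a single chunk.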
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "candle")]
            LLMProviderType::Candle(config) => {
                let (model_name, default_temperature, default_max_tokens) = (
                    config.huggingface_repo.clone(),
                    config.temperature,
                    config.max_tokens,
                );

                let request = LLMRequest {
                    model: model_name,
                    messages,
                    temperature: temperature.or(Some(default_temperature)),
                    max_tokens: max_tokens.or(Some(default_max_tokens)),
                    tools: tools.clone(),
                    tool_choice: if tools.is_some() {
                        Some("auto".to_string())
                    } else {
                        None
                    },
                    stream: None,
                    stop,
                };

                let response = self.provider.generate(request).await?;
                // Candle does not stream; deliver the whole completion as one chunk.
                let mut on_chunk = on_chunk;
                if let Some(choice) = response.choices.first() {
                    on_chunk(&choice.message.content);
                }
                response
                    .choices
                    .into_iter()
                    .next()
                    .map(|choice| choice.message)
                    .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
            }
        }
    }
}