//! LLM provider layer: a remote OpenAI-compatible HTTP client and a local
//! llama.cpp-backed provider, unified behind the `LLMProvider` trait.

use crate::chat::ChatMessage;
use crate::config::{LLMConfig, LocalConfig};
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use tokio::task;

impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

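/// Selects the backend an `LLMClient` talks to: a remote OpenAI-compatible
/// API or a local model run in-process through llama.cpp.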
#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    Local(LocalConfig),
}

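/// OpenAI-style chat completion request body. Optional fields are omitted
/// from the serialized JSON when they are `None`.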
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

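/// One server-sent-event chunk of a streaming chat completion response.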
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

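/// Complete (non-streaming) chat completion response in the OpenAI format.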
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

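/// Common interface over all providers. `as_any` exposes the concrete type so
/// `LLMClient` can downcast to reach provider-specific streaming methods.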
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    fn as_any(&self) -> &dyn std::any::Any;
}

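/// Facade that owns the configured provider and remembers which variant it is.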
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

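/// HTTP client for an OpenAI-compatible `/chat/completions` endpoint.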
pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

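// llama.cpp writes log and progress output straight to stdout/stderr. The
// helpers below temporarily redirect those file descriptors to /dev/null and
// restore them afterwards (Unix-only: they rely on libc dup/dup2).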
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing so redirected writes succeed.
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");

    // Back up the current stdout/stderr descriptors so they can be restored.
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1);
        libc::dup2(dev_null.as_raw_fd(), 2);
    }

    (stdout_backup, stderr_backup)
}

fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1);
        libc::dup2(stderr_backup, 2);
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

fn suppress_stderr() -> i32 {
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

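/// Runs a GGUF model in-process through llama.cpp. Model and backend are kept
/// behind `Arc` so blocking generation tasks can take owned clones.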
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        // Silence llama.cpp's backend and model-loading chatter.
        let (stdout_backup, stderr_backup) = suppress_output();

        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Ask llama.cpp to offload up to 99 layers to the GPU (effectively all
        // layers for typical models); it falls back to CPU when no GPU is available.
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99);
        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

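    /// Resolves the model file locally, preferring the Hugging Face cache and
    /// otherwise shelling out to `huggingface-cli download`.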
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            return Ok(cached_path);
        }

        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

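    /// Looks for `model_file` inside the Hugging Face hub cache
    /// (`$HF_HOME` or `~/.cache/huggingface`) without downloading anything.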
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        // Hub cache layout: hub/models--{owner}--{name}/snapshots/<revision>/<file>
        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // Blobs are stored under content hashes; without reading the repo
            // index there is no way to map `model_file` to a blob, so fall through.
        }

        None
    }

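    /// Renders the conversation as a ChatML-style prompt
    /// (`<|im_start|>role ... <|im_end|>`) and leaves an open assistant turn
    /// for the model to complete.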
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // Tool results are emitted as an assistant turn in this prompt format.
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

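    /// Streams a chat completion over server-sent events, calling `on_chunk`
    /// with each content delta and returning the assembled message at the end.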
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

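// Local generation runs llama.cpp on a blocking thread and uses plain greedy
// decoding (highest-logit token each step), capped at 512 new tokens.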
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        let (stdout_backup, stderr_backup) = suppress_output();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let result = task::spawn_blocking(move || {
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                // Greedy decoding: pick the token with the highest logit.
                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    // Skip tokens that cannot be converted to text.
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await;

        // Restore stdout/stderr before handling any error so a failed
        // generation does not leave the process silenced.
        restore_output(stdout_backup, stderr_backup);

        let result =
            result.map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                // Token usage is not tracked for local generation.
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(response)
    }
}

impl LocalLLMProvider {
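    /// Local streaming path: generates on a blocking thread with the same
    /// greedy loop as `generate`, forwarding each token through an unbounded
    /// channel so `on_chunk` sees text as soon as it is produced.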
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        let stderr_backup = suppress_stderr();

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        if tx.send(text).is_err() {
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

impl LLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

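    /// Streams a response by downcasting (via `as_any`) to the concrete
    /// provider and calling its streaming implementation.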
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}