use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[cfg(feature = "local")]
use {
    crate::config::LocalConfig,
    llama_cpp_2::{
        context::params::LlamaContextParams,
        llama_backend::LlamaBackend,
        llama_batch::LlamaBatch,
        model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
        token::LlamaToken,
    },
    std::{fs::File, os::fd::AsRawFd, sync::Arc},
    tokio::task,
};

#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
    fn from(err: llama_cpp_2::LLamaCppError) -> Self {
        HeliosError::LlamaCppError(format!("{:?}", err))
    }
}

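/// Backend selection for [`LLMClient`]: either a remote OpenAI-compatible
/// HTTP API or an in-process llama.cpp model (available behind the `local`
/// feature).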
#[derive(Clone)]
pub enum LLMProviderType {
    Remote(LLMConfig),
    #[cfg(feature = "local")]
    Local(LocalConfig),
}

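/// Chat-completion request in the OpenAI wire format; optional fields are
/// omitted from the serialized JSON when unset.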
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

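/// Common interface implemented by every backend. `as_any` lets callers
/// downcast to a concrete provider when they need backend-specific methods
/// (e.g. streaming).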
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
    fn as_any(&self) -> &dyn std::any::Any;
}

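/// Facade that owns the concrete provider selected at construction time.
///
/// A minimal usage sketch; `config` and `messages` are assumed to come from
/// the application's own configuration and conversation state, which are not
/// shown here:
///
/// ```ignore
/// let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
/// let reply = client.chat(messages, None, None, None, None).await?;
/// println!("{}", reply.content);
/// ```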
pub struct LLMClient {
    provider: Box<dyn LLMProvider + Send + Sync>,
    provider_type: LLMProviderType,
}

impl LLMClient {
    pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
        let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
            LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => {
                Box::new(LocalLLMProvider::new(config.clone()).await?)
            }
        };

        Ok(Self {
            provider,
            provider_type,
        })
    }

    pub fn provider_type(&self) -> &LLMProviderType {
        &self.provider_type
    }
}

pub struct RemoteLLMClient {
    config: LLMConfig,
    client: Client,
}

impl RemoteLLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

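// llama.cpp logs directly to the C-level stdout/stderr. The helpers below
// temporarily redirect file descriptors 1 and 2 to /dev/null (Unix only) and
// restore them afterwards, presumably so backend logging does not interleave
// with the application's own output.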
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    let dev_null = File::open("/dev/null").expect("Failed to open /dev/null");

    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };

    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1);
        libc::dup2(dev_null.as_raw_fd(), 2);
    }

    (stdout_backup, stderr_backup)
}

#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
    unsafe {
        libc::dup2(stdout_backup, 1);
        libc::dup2(stderr_backup, 2);
        libc::close(stdout_backup);
        libc::close(stderr_backup);
    }
}

#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    let dev_null = File::open("/dev/null").expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}

#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
    unsafe {
        libc::dup2(stderr_backup, 2);
        libc::close(stderr_backup);
    }
}

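/// In-process inference backend built on llama.cpp (`llama_cpp_2`). The model
/// and backend are wrapped in `Arc` so they can be moved into blocking
/// generation tasks.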
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
    model: Arc<LlamaModel>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
    pub async fn new(config: LocalConfig) -> Result<Self> {
        let (stdout_backup, stderr_backup) = suppress_output();

        let backend = LlamaBackend::init().map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
        })?;

        let model_path = Self::download_model(&config).await.map_err(|e| {
            restore_output(stdout_backup, stderr_backup);
            e
        })?;

        // Ask llama.cpp to offload up to 99 layers to the GPU, which is
        // effectively "all layers" for typical models.
        let model_params = LlamaModelParams::default().with_n_gpu_layers(99);
        let model =
            LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
                restore_output(stdout_backup, stderr_backup);
                HeliosError::LLMError(format!("Failed to load model: {:?}", e))
            })?;

        restore_output(stdout_backup, stderr_backup);

        Ok(Self {
            model: Arc::new(model),
            backend: Arc::new(backend),
        })
    }

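    /// Resolves the GGUF model file, preferring the local Hugging Face cache
    /// and otherwise shelling out to `huggingface-cli download`, which places
    /// the file under `.cache/models`.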
    async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
        use std::process::Command;

        if let Some(cached_path) =
            Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
        {
            return Ok(cached_path);
        }

        let output = Command::new("huggingface-cli")
            .args([
                "download",
                &config.huggingface_repo,
                &config.model_file,
                "--local-dir",
                ".cache/models",
                "--local-dir-use-symlinks",
                "False",
            ])
            .stdout(std::process::Stdio::null())
            // Leave stderr piped (the default for `output()`) so a failure can
            // be reported below.
            .output()
            .map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;

        if !output.status.success() {
            return Err(HeliosError::LLMError(format!(
                "Failed to download model: {}",
                String::from_utf8_lossy(&output.stderr)
            )));
        }

        let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
        if !model_path.exists() {
            return Err(HeliosError::LLMError(format!(
                "Model file not found after download: {}",
                model_path.display()
            )));
        }

        Ok(model_path)
    }

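    /// Looks for `model_file` in the standard Hugging Face hub cache
    /// (`$HF_HOME` or `~/.cache/huggingface`), where repositories live under
    /// `hub/models--{org}--{name}/snapshots/<revision>/`.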
    fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
        let cache_dir = std::env::var("HF_HOME")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
                std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("huggingface")
            });

        let hub_dir = cache_dir.join("hub");

        let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
        let repo_dir = hub_dir.join(&cache_repo_name);

        if !repo_dir.exists() {
            return None;
        }

        let snapshots_dir = repo_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
                        if snapshot_path.exists() {
                            return Some(snapshot_path);
                        }
                    }
                }
            }
        }

        let blobs_dir = repo_dir.join("blobs");
        if blobs_dir.exists() {
            // Blobs are stored under content hashes, so the model file cannot
            // be resolved here by name alone; treat this as a cache miss.
        }

        None
    }

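    /// Renders the conversation into a ChatML-style prompt and leaves the
    /// final assistant turn open for the model to complete, e.g.:
    ///
    /// ```text
    /// <|im_start|>system
    /// You are a helpful assistant.
    /// <|im_end|>
    /// <|im_start|>user
    /// Hello!
    /// <|im_end|>
    /// <|im_start|>assistant
    /// ```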
    fn format_messages(&self, messages: &[ChatMessage]) -> String {
        let mut formatted = String::new();

        for message in messages {
            match message.role {
                crate::chat::Role::System => {
                    formatted.push_str("<|im_start|>system\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::User => {
                    formatted.push_str("<|im_start|>user\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Assistant => {
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
                crate::chat::Role::Tool => {
                    // Tool results are folded into assistant turns in the prompt.
                    formatted.push_str("<|im_start|>assistant\n");
                    formatted.push_str(&message.content);
                    formatted.push_str("\n<|im_end|>\n");
                }
            }
        }

        // Leave an open assistant turn for the model to complete.
        formatted.push_str("<|im_start|>assistant\n");

        formatted
    }
}

#[async_trait]
impl LLMProvider for RemoteLLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl RemoteLLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

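    /// Streams a chat completion over server-sent events (`stream: true`),
    /// calling `on_chunk` for every content delta and returning the fully
    /// assembled message.
    ///
    /// Sketch of the calling convention (hypothetical `messages` value):
    ///
    /// ```ignore
    /// let reply = remote_client
    ///     .chat_stream(messages, None, None, None, None, |chunk| print!("{chunk}"))
    ///     .await?;
    /// ```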
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: temperature.or(Some(self.config.temperature)),
            max_tokens: max_tokens.or(Some(self.config.max_tokens)),
            tools,
            tool_choice: None,
            stream: Some(true),
            stop,
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete SSE lines; a partial line stays in the buffer
            // until the next network chunk arrives.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let prompt = self.format_messages(&request.messages);

        let (stdout_backup, stderr_backup) = suppress_output();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation = task::spawn_blocking(move || {
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                // Greedy decoding: pick the token with the highest logit.
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                    }
                    Err(_) => continue, // skip tokens that fail to decode
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        })
        .await;

        // Always restore stdout/stderr before propagating any error, otherwise
        // the process would keep writing to /dev/null.
        restore_output(stdout_backup, stderr_backup);

        let result = match generation {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => return Err(e),
            Err(e) => return Err(HeliosError::LLMError(format!("Task failed: {}", e))),
        };

        let response = LLMResponse {
            id: format!("local-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "local-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::Assistant,
                    content: result,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                // Token accounting is not tracked for the local backend.
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(response)
    }
}

#[cfg(feature = "local")]
impl LocalLLMProvider {
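    /// Local counterpart of [`RemoteLLMClient::chat_stream`]: generation runs
    /// on a blocking task and each decoded token is forwarded through an
    /// unbounded channel so `on_chunk` can be called as text is produced.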
    async fn chat_stream_local<F>(
        &self,
        messages: Vec<ChatMessage>,
        _temperature: Option<f32>,
        _max_tokens: Option<u32>,
        _stop: Option<Vec<String>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let prompt = self.format_messages(&messages);

        let stderr_backup = suppress_stderr();

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        let model = Arc::clone(&self.model);
        let backend = Arc::clone(&self.backend);
        let generation_task = task::spawn_blocking(move || {
            use std::num::NonZeroU32;
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut context = model
                .new_context(&backend, ctx_params)
                .map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;

            let tokens = context
                .model
                .str_to_token(&prompt, AddBos::Always)
                .map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;

            let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
            for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
                    .add(token, i as i32, &[0], compute_logits)
                    .map_err(|e| {
                        HeliosError::LLMError(format!(
                            "Failed to add prompt token to batch: {:?}",
                            e
                        ))
                    })?;
            }

            context
                .decode(&mut prompt_batch)
                .map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;

            let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;

            for _ in 0..max_new_tokens {
                // Greedy decoding: pick the token with the highest logit.
                let logits = context.get_logits();

                let token_idx = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or_else(|| {
                        let eos = context.model.token_eos();
                        eos.0 as usize
                    });
                let token = LlamaToken(token_idx as i32);

                if token == context.model.token_eos() {
                    break;
                }

                match context.model.token_to_str(token, Special::Plaintext) {
                    Ok(text) => {
                        generated_text.push_str(&text);
                        if tx.send(text).is_err() {
                            // The receiver side was dropped; stop generating.
                            break;
                        }
                    }
                    Err(_) => continue,
                }

                let mut gen_batch = LlamaBatch::new(1, 1);
                gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
                    HeliosError::LLMError(format!(
                        "Failed to add generated token to batch: {:?}",
                        e
                    ))
                })?;

                context.decode(&mut gen_batch).map_err(|e| {
                    HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
                })?;

                next_pos += 1;
            }

            Ok::<String, HeliosError>(generated_text)
        });

        // Forward tokens to the caller as they arrive; the loop ends when the
        // generation task drops its sender.
        while let Some(token) = rx.recv().await {
            on_chunk(&token);
        }

        let result = match generation_task.await {
            Ok(Ok(text)) => text,
            Ok(Err(e)) => {
                restore_stderr(stderr_backup);
                return Err(e);
            }
            Err(e) => {
                restore_stderr(stderr_backup);
                return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
            }
        };

        restore_stderr(stderr_backup);

        Ok(ChatMessage {
            role: crate::chat::Role::Assistant,
            content: result,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        self.provider.generate(request).await
    }
}

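// High-level chat helpers: `chat` fills in per-provider defaults before
// delegating to `generate`, while `chat_stream` downcasts to the concrete
// backend because streaming is not part of the `LLMProvider` trait.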
impl LLMClient {
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
    ) -> Result<ChatMessage> {
        let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
            LLMProviderType::Remote(config) => (
                config.model_name.clone(),
                config.temperature,
                config.max_tokens,
            ),
            #[cfg(feature = "local")]
            LLMProviderType::Local(config) => (
                "local-model".to_string(),
                config.temperature,
                config.max_tokens,
            ),
        };

        let request = LLMRequest {
            model: model_name,
            messages,
            temperature: temperature.or(Some(default_temperature)),
            max_tokens: max_tokens.or(Some(default_max_tokens)),
            tools,
            tool_choice: None,
            stream: None,
            stop,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        stop: Option<Vec<String>>,
        on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        match &self.provider_type {
            LLMProviderType::Remote(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
                    provider
                        .chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
            #[cfg(feature = "local")]
            LLMProviderType::Local(_) => {
                if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
                    provider
                        .chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
                        .await
                } else {
                    Err(HeliosError::AgentError("Provider type mismatch".into()))
                }
            }
        }
    }
}
980