use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
#[cfg(feature = "local")]
use {
crate::config::LocalConfig,
llama_cpp_2::{
context::params::LlamaContextParams,
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
token::LlamaToken,
},
std::{fs::File, os::fd::AsRawFd, sync::Arc},
tokio::task,
};
#[cfg(feature = "candle")]
use crate::candle_provider::CandleLLMProvider;
#[cfg(feature = "local")]
impl From<llama_cpp_2::LLamaCppError> for HeliosError {
fn from(err: llama_cpp_2::LLamaCppError) -> Self {
HeliosError::LlamaCppError(format!("{:?}", err))
}
}
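/// Selects the backend an [`LLMClient`] uses: a remote OpenAI-compatible API,
/// an in-process llama.cpp model, or a Candle model (the latter two behind
/// their cargo features).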
#[derive(Clone)]
pub enum LLMProviderType {
Remote(LLMConfig),
#[cfg(feature = "local")]
Local(LocalConfig),
#[cfg(feature = "candle")]
Candle(crate::config::CandleConfig),
}
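/// OpenAI-style chat completion request body; `None` fields are omitted from
/// the serialized JSON.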
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
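/// One `data:` frame of an OpenAI-style SSE stream; its deltas carry
/// incremental role, content, and tool-call fragments.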
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
pub index: u32,
pub delta: Delta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaToolCall {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<DeltaFunctionCall>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaFunctionCall {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<DeltaToolCall>>,
}
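/// OpenAI-style chat completion response.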
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
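/// Backend-agnostic generation interface. `as_any` allows downcasting to the
/// concrete provider so backend-specific streaming entry points can be reached
/// (see [`LLMClient::chat_stream`]).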
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
fn as_any(&self) -> &dyn std::any::Any;
}
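/// Facade that owns a boxed [`LLMProvider`] chosen from the given
/// [`LLMProviderType`].
///
/// A minimal usage sketch (not compiled; the exact `LLMConfig` fields live in
/// `crate::config` and may differ):
///
/// ```ignore
/// let client = LLMClient::new(LLMProviderType::Remote(config)).await?;
/// let reply = client.chat(messages, None, None, None, None).await?;
/// println!("{}", reply.content);
/// ```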
pub struct LLMClient {
provider: Box<dyn LLMProvider + Send + Sync>,
provider_type: LLMProviderType,
}
impl LLMClient {
pub async fn new(provider_type: LLMProviderType) -> Result<Self> {
let provider: Box<dyn LLMProvider + Send + Sync> = match &provider_type {
LLMProviderType::Remote(config) => Box::new(RemoteLLMClient::new(config.clone())),
#[cfg(feature = "local")]
LLMProviderType::Local(config) => {
Box::new(LocalLLMProvider::new(config.clone()).await?)
}
#[cfg(feature = "candle")]
LLMProviderType::Candle(config) => {
Box::new(CandleLLMProvider::new(config.clone()).await?)
}
};
Ok(Self {
provider,
provider_type,
})
}
pub fn provider_type(&self) -> &LLMProviderType {
&self.provider_type
}
}
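/// HTTP client for an OpenAI-compatible endpoint at
/// `{base_url}/chat/completions`.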
pub struct RemoteLLMClient {
config: LLMConfig,
client: Client,
}
impl RemoteLLMClient {
pub fn new(config: LLMConfig) -> Self {
Self {
config,
client: Client::new(),
}
}
pub fn config(&self) -> &LLMConfig {
&self.config
}
}
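/// Redirects stdout and stderr to `/dev/null`, returning backups of the
/// original descriptors for [`restore_output`]. Used here to keep llama.cpp's
/// native logging quiet during backend init and model loading; Unix-only
/// (relies on `libc::dup`/`dup2`).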
#[cfg(feature = "local")]
fn suppress_output() -> (i32, i32) {
    // Open /dev/null for writing so redirected output is discarded; a
    // read-only descriptor would make subsequent writes fail with EBADF.
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");
    // Back up the current stdout/stderr so they can be restored later.
    let stdout_backup = unsafe { libc::dup(1) };
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 1);
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    (stdout_backup, stderr_backup)
}
#[cfg(feature = "local")]
fn restore_output(stdout_backup: i32, stderr_backup: i32) {
unsafe {
        libc::dup2(stdout_backup, 1);
        libc::dup2(stderr_backup, 2);
        libc::close(stdout_backup);
        libc::close(stderr_backup);
}
}
#[cfg(feature = "local")]
fn suppress_stderr() -> i32 {
    // Open /dev/null for writing; see `suppress_output`.
    let dev_null = File::options()
        .write(true)
        .open("/dev/null")
        .expect("Failed to open /dev/null");
    let stderr_backup = unsafe { libc::dup(2) };
    unsafe {
        libc::dup2(dev_null.as_raw_fd(), 2);
    }
    stderr_backup
}
#[cfg(feature = "local")]
fn restore_stderr(stderr_backup: i32) {
unsafe {
libc::dup2(stderr_backup, 2);
libc::close(stderr_backup);
}
}
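/// In-process GGUF inference via `llama-cpp-2`; the loaded model and backend
/// are shared across requests behind `Arc`s.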
#[cfg(feature = "local")]
pub struct LocalLLMProvider {
model: Arc<LlamaModel>,
backend: Arc<LlamaBackend>,
}
#[cfg(feature = "local")]
impl LocalLLMProvider {
pub async fn new(config: LocalConfig) -> Result<Self> {
let (stdout_backup, stderr_backup) = suppress_output();
let backend = LlamaBackend::init().map_err(|e| {
restore_output(stdout_backup, stderr_backup);
HeliosError::LLMError(format!("Failed to initialize llama backend: {:?}", e))
})?;
let model_path = Self::download_model(&config).await.map_err(|e| {
restore_output(stdout_backup, stderr_backup);
e
})?;
let model_params = LlamaModelParams::default().with_n_gpu_layers(99);
let model =
LlamaModel::load_from_file(&backend, &model_path, &model_params).map_err(|e| {
restore_output(stdout_backup, stderr_backup);
HeliosError::LLMError(format!("Failed to load model: {:?}", e))
})?;
restore_output(stdout_backup, stderr_backup);
Ok(Self {
model: Arc::new(model),
backend: Arc::new(backend),
})
}
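    /// Resolves a local path to the configured GGUF file: the Hugging Face
    /// cache is checked first, otherwise the file is downloaded into
    /// `.cache/models` with the external `huggingface-cli` tool (which must be
    /// on `PATH`).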
async fn download_model(config: &LocalConfig) -> Result<std::path::PathBuf> {
use std::process::Command;
if let Some(cached_path) =
Self::find_model_in_cache(&config.huggingface_repo, &config.model_file)
{
return Ok(cached_path);
}
let output = Command::new("huggingface-cli")
.args([
"download",
&config.huggingface_repo,
&config.model_file,
"--local-dir",
".cache/models",
"--local-dir-use-symlinks",
"False",
])
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .output()
.map_err(|e| HeliosError::LLMError(format!("Failed to run huggingface-cli: {}", e)))?;
if !output.status.success() {
return Err(HeliosError::LLMError(format!(
"Failed to download model: {}",
String::from_utf8_lossy(&output.stderr)
)));
}
let model_path = std::path::PathBuf::from(".cache/models").join(&config.model_file);
if !model_path.exists() {
return Err(HeliosError::LLMError(format!(
"Model file not found after download: {}",
model_path.display()
)));
}
Ok(model_path)
}
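    /// Searches the Hugging Face hub cache (`$HF_HOME` or
    /// `~/.cache/huggingface`) for `model_file`, resolving snapshot symlinks to
    /// their blob targets.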
fn find_model_in_cache(repo: &str, model_file: &str) -> Option<std::path::PathBuf> {
let cache_dir = std::env::var("HF_HOME")
.map(std::path::PathBuf::from)
.unwrap_or_else(|_| {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
std::path::PathBuf::from(home)
.join(".cache")
.join("huggingface")
});
let hub_dir = cache_dir.join("hub");
let cache_repo_name = format!("models--{}", repo.replace("/", "--"));
let repo_dir = hub_dir.join(&cache_repo_name);
if !repo_dir.exists() {
return None;
}
let snapshots_dir = repo_dir.join("snapshots");
if snapshots_dir.exists() {
if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
for entry in entries.flatten() {
if let Ok(snapshot_path) = entry.path().join(model_file).canonicalize() {
if snapshot_path.exists() {
return Some(snapshot_path);
}
}
}
}
}
        // The blobs directory stores content-addressed files without their
        // original names, so only the snapshots directory is searched.
None
}
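    /// Renders the conversation into a ChatML-style prompt
    /// (`<|im_start|>role ... <|im_end|>`), ending with an open assistant turn
    /// for the model to complete.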
fn format_messages(&self, messages: &[ChatMessage]) -> String {
let mut formatted = String::new();
for message in messages {
match message.role {
crate::chat::Role::System => {
formatted.push_str("<|im_start|>system\n");
formatted.push_str(&message.content);
formatted.push_str("\n<|im_end|>\n");
}
crate::chat::Role::User => {
formatted.push_str("<|im_start|>user\n");
formatted.push_str(&message.content);
formatted.push_str("\n<|im_end|>\n");
}
crate::chat::Role::Assistant => {
formatted.push_str("<|im_start|>assistant\n");
formatted.push_str(&message.content);
formatted.push_str("\n<|im_end|>\n");
}
crate::chat::Role::Tool => {
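                    // Tool results are rendered as assistant turns; this
                    // formatter does not emit a separate tool role.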
formatted.push_str("<|im_start|>assistant\n");
formatted.push_str(&message.content);
formatted.push_str("\n<|im_end|>\n");
}
}
}
formatted.push_str("<|im_start|>assistant\n");
formatted
}
}
#[async_trait]
impl LLMProvider for RemoteLLMClient {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
let url = format!("{}/chat/completions", self.config.base_url);
let mut request_builder = self
.client
.post(&url)
.header("Content-Type", "application/json");
if !self.config.base_url.contains("10.")
&& !self.config.base_url.contains("localhost")
&& !self.config.base_url.contains("127.0.0.1")
{
request_builder =
request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
}
let response = request_builder.json(&request).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(HeliosError::LLMError(format!(
"LLM API request failed with status {}: {}",
status, error_text
)));
}
let llm_response: LLMResponse = response.json().await?;
Ok(llm_response)
}
}
impl RemoteLLMClient {
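    /// Non-streaming chat completion; unspecified sampling parameters fall
    /// back to the configured defaults, and only the first choice is returned.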
pub async fn chat(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<ToolDefinition>>,
temperature: Option<f32>,
max_tokens: Option<u32>,
stop: Option<Vec<String>>,
) -> Result<ChatMessage> {
let request = LLMRequest {
model: self.config.model_name.clone(),
messages,
temperature: temperature.or(Some(self.config.temperature)),
max_tokens: max_tokens.or(Some(self.config.max_tokens)),
tools: tools.clone(),
tool_choice: if tools.is_some() {
Some("auto".to_string())
} else {
None
},
stream: None,
stop,
};
let response = self.generate(request).await?;
response
.choices
.into_iter()
.next()
.map(|choice| choice.message)
.ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
}
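    /// Streaming chat completion over server-sent events: `on_chunk` is called
    /// for every content delta, while role, content, and tool-call fragments
    /// are reassembled into the final [`ChatMessage`].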
pub async fn chat_stream<F>(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<ToolDefinition>>,
temperature: Option<f32>,
max_tokens: Option<u32>,
stop: Option<Vec<String>>,
mut on_chunk: F,
) -> Result<ChatMessage>
where
F: FnMut(&str) + Send,
{
let request = LLMRequest {
model: self.config.model_name.clone(),
messages,
temperature: temperature.or(Some(self.config.temperature)),
max_tokens: max_tokens.or(Some(self.config.max_tokens)),
tools: tools.clone(),
tool_choice: if tools.is_some() {
Some("auto".to_string())
} else {
None
},
stream: Some(true),
stop,
};
let url = format!("{}/chat/completions", self.config.base_url);
let mut request_builder = self
.client
.post(&url)
.header("Content-Type", "application/json");
if !self.config.base_url.contains("10.")
&& !self.config.base_url.contains("localhost")
&& !self.config.base_url.contains("127.0.0.1")
{
request_builder =
request_builder.header("Authorization", format!("Bearer {}", self.config.api_key));
}
let response = request_builder.json(&request).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(HeliosError::LLMError(format!(
"LLM API request failed with status {}: {}",
status, error_text
)));
}
let mut stream = response.bytes_stream();
let mut full_content = String::new();
let mut role = None;
let mut tool_calls = Vec::new();
let mut buffer = String::new();
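        // SSE data can arrive split across HTTP chunks, so bytes accumulate in
        // `buffer` and are processed one complete line at a time.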
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
let chunk_str = String::from_utf8_lossy(&chunk);
buffer.push_str(&chunk_str);
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer = buffer[line_end + 1..].to_string();
if line.is_empty() || line == "data: [DONE]" {
continue;
}
if let Some(data) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(data) {
Ok(stream_chunk) => {
if let Some(choice) = stream_chunk.choices.first() {
if let Some(r) = &choice.delta.role {
role = Some(r.clone());
}
if let Some(content) = &choice.delta.content {
full_content.push_str(content);
on_chunk(content);
}
if let Some(delta_tool_calls) = &choice.delta.tool_calls {
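                            // Tool calls stream incrementally: id and name
                            // arrive once, arguments accumulate across deltas,
                            // all keyed by `index`.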
for delta_tool_call in delta_tool_calls {
while tool_calls.len() <= delta_tool_call.index as usize {
tool_calls.push(None);
}
let tool_call_slot =
&mut tool_calls[delta_tool_call.index as usize];
if tool_call_slot.is_none() {
*tool_call_slot = Some(crate::chat::ToolCall {
id: String::new(),
call_type: "function".to_string(),
function: crate::chat::FunctionCall {
name: String::new(),
arguments: String::new(),
},
});
}
if let Some(tool_call) = tool_call_slot.as_mut() {
if let Some(id) = &delta_tool_call.id {
tool_call.id = id.clone();
}
if let Some(function) = &delta_tool_call.function {
if let Some(name) = &function.name {
tool_call.function.name = name.clone();
}
if let Some(args) = &function.arguments {
tool_call.function.arguments.push_str(args);
}
}
}
}
}
}
}
Err(e) => {
tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
}
}
}
}
}
let final_tool_calls = tool_calls.into_iter().flatten().collect::<Vec<_>>();
let tool_calls_option = if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
};
Ok(ChatMessage {
role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
content: full_content,
name: None,
tool_calls: tool_calls_option,
tool_call_id: None,
})
}
}
#[cfg(feature = "local")]
#[async_trait]
impl LLMProvider for LocalLLMProvider {
fn as_any(&self) -> &dyn std::any::Any {
self
}
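    // Greedy (argmax) decoding on a blocking thread: build a fresh context,
    // feed the prompt, then sample token-by-token until EOS or a fixed
    // 512-token budget. The request's temperature/max_tokens/stop fields are
    // not currently applied here.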
async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
let prompt = self.format_messages(&request.messages);
let (stdout_backup, stderr_backup) = suppress_output();
let model = Arc::clone(&self.model);
let backend = Arc::clone(&self.backend);
let result = task::spawn_blocking(move || {
use std::num::NonZeroU32;
let ctx_params =
LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));
let mut context = model
.new_context(&backend, ctx_params)
.map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;
let tokens = context
.model
.str_to_token(&prompt, AddBos::Always)
.map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;
let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
for (i, &token) in tokens.iter().enumerate() {
                let compute_logits = true;
                prompt_batch
.add(token, i as i32, &[0], compute_logits)
.map_err(|e| {
HeliosError::LLMError(format!(
"Failed to add prompt token to batch: {:?}",
e
))
})?;
}
context
.decode(&mut prompt_batch)
.map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;
let mut generated_text = String::new();
            let max_new_tokens = 512;
            let mut next_pos = tokens.len() as i32;
for _ in 0..max_new_tokens {
let logits = context.get_logits();
let token_idx = logits
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(idx, _)| idx)
.unwrap_or_else(|| {
let eos = context.model.token_eos();
eos.0 as usize
});
let token = LlamaToken(token_idx as i32);
if token == context.model.token_eos() {
break;
}
                // Append the token's text when it decodes; even when it does
                // not, fall through so the token is still fed back below and
                // decoding advances (a bare `continue` would stall on the same
                // argmax token).
                if let Ok(text) = context.model.token_to_str(token, Special::Plaintext) {
                    generated_text.push_str(&text);
                }
let mut gen_batch = LlamaBatch::new(1, 1);
gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
HeliosError::LLMError(format!(
"Failed to add generated token to batch: {:?}",
e
))
})?;
context.decode(&mut gen_batch).map_err(|e| {
HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
})?;
next_pos += 1;
}
Ok::<String, HeliosError>(generated_text)
})
        .await;
        // Restore stdout/stderr before propagating errors so logging is not
        // left suppressed (and the backup fds not leaked) if generation fails.
        restore_output(stdout_backup, stderr_backup);
        let result = result
            .map_err(|e| HeliosError::LLMError(format!("Task failed: {}", e)))??;
let response = LLMResponse {
id: format!("local-{}", chrono::Utc::now().timestamp()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: "local-model".to_string(),
choices: vec![Choice {
index: 0,
message: ChatMessage {
role: crate::chat::Role::Assistant,
content: result,
name: None,
tool_calls: None,
tool_call_id: None,
},
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
                // Token accounting is not tracked for the local provider.
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
};
Ok(response)
}
}
#[cfg(feature = "local")]
impl LocalLLMProvider {
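    /// Greedy decoding like `generate`, but each decoded piece is forwarded
    /// through an unbounded channel so `on_chunk` fires as tokens are produced.
    /// The `_temperature`, `_max_tokens`, and `_stop` arguments are accepted
    /// for interface symmetry and not yet used.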
async fn chat_stream_local<F>(
&self,
messages: Vec<ChatMessage>,
_temperature: Option<f32>,
_max_tokens: Option<u32>,
_stop: Option<Vec<String>>,
mut on_chunk: F,
) -> Result<ChatMessage>
where
F: FnMut(&str) + Send,
{
let prompt = self.format_messages(&messages);
let stderr_backup = suppress_stderr();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let model = Arc::clone(&self.model);
let backend = Arc::clone(&self.backend);
let generation_task = task::spawn_blocking(move || {
use std::num::NonZeroU32;
let ctx_params =
LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));
let mut context = model
.new_context(&backend, ctx_params)
.map_err(|e| HeliosError::LLMError(format!("Failed to create context: {:?}", e)))?;
let tokens = context
.model
.str_to_token(&prompt, AddBos::Always)
.map_err(|e| HeliosError::LLMError(format!("Tokenization failed: {:?}", e)))?;
let mut prompt_batch = LlamaBatch::new(tokens.len(), 1);
for (i, &token) in tokens.iter().enumerate() {
let compute_logits = true;
prompt_batch
.add(token, i as i32, &[0], compute_logits)
.map_err(|e| {
HeliosError::LLMError(format!(
"Failed to add prompt token to batch: {:?}",
e
))
})?;
}
context
.decode(&mut prompt_batch)
.map_err(|e| HeliosError::LLMError(format!("Failed to decode prompt: {:?}", e)))?;
let mut generated_text = String::new();
let max_new_tokens = 512;
let mut next_pos = tokens.len() as i32;
for _ in 0..max_new_tokens {
let logits = context.get_logits();
let token_idx = logits
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(idx, _)| idx)
.unwrap_or_else(|| {
let eos = context.model.token_eos();
eos.0 as usize
});
let token = LlamaToken(token_idx as i32);
if token == context.model.token_eos() {
break;
}
                // Decode the token to text when possible; undecodable tokens
                // are skipped for output but still fed back below so that
                // generation keeps advancing.
                if let Ok(text) = context.model.token_to_str(token, Special::Plaintext) {
                    generated_text.push_str(&text);
                    // The receiver may have been dropped (caller stopped
                    // listening); stop generating in that case.
                    if tx.send(text).is_err() {
                        break;
                    }
                }
let mut gen_batch = LlamaBatch::new(1, 1);
gen_batch.add(token, next_pos, &[0], true).map_err(|e| {
HeliosError::LLMError(format!(
"Failed to add generated token to batch: {:?}",
e
))
})?;
context.decode(&mut gen_batch).map_err(|e| {
HeliosError::LLMError(format!("Failed to decode token: {:?}", e))
})?;
next_pos += 1;
}
Ok::<String, HeliosError>(generated_text)
});
while let Some(token) = rx.recv().await {
on_chunk(&token);
}
let result = match generation_task.await {
Ok(Ok(text)) => text,
Ok(Err(e)) => {
restore_stderr(stderr_backup);
return Err(e);
}
Err(e) => {
restore_stderr(stderr_backup);
return Err(HeliosError::LLMError(format!("Task failed: {}", e)));
}
};
restore_stderr(stderr_backup);
Ok(ChatMessage {
role: crate::chat::Role::Assistant,
content: result,
name: None,
tool_calls: None,
tool_call_id: None,
})
}
}
#[async_trait]
impl LLMProvider for LLMClient {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
self.provider.generate(request).await
}
}
impl LLMClient {
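    /// Non-streaming chat through the active provider, using that provider's
    /// configured model name, temperature, and max_tokens as defaults.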
pub async fn chat(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<ToolDefinition>>,
temperature: Option<f32>,
max_tokens: Option<u32>,
stop: Option<Vec<String>>,
) -> Result<ChatMessage> {
let (model_name, default_temperature, default_max_tokens) = match &self.provider_type {
LLMProviderType::Remote(config) => (
config.model_name.clone(),
config.temperature,
config.max_tokens,
),
#[cfg(feature = "local")]
LLMProviderType::Local(config) => (
"local-model".to_string(),
config.temperature,
config.max_tokens,
),
#[cfg(feature = "candle")]
LLMProviderType::Candle(config) => (
config.huggingface_repo.clone(),
config.temperature,
config.max_tokens,
),
};
let request = LLMRequest {
model: model_name,
messages,
temperature: temperature.or(Some(default_temperature)),
max_tokens: max_tokens.or(Some(default_max_tokens)),
tools: tools.clone(),
tool_choice: if tools.is_some() {
Some("auto".to_string())
} else {
None
},
stream: None,
stop,
};
let response = self.generate(request).await?;
response
.choices
.into_iter()
.next()
.map(|choice| choice.message)
.ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
}
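    /// Streams a reply through the active provider: remote providers use SSE,
    /// the local provider streams tokens as they are generated, and the Candle
    /// provider falls back to a single full-response chunk.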
pub async fn chat_stream<F>(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<ToolDefinition>>,
temperature: Option<f32>,
max_tokens: Option<u32>,
stop: Option<Vec<String>>,
on_chunk: F,
) -> Result<ChatMessage>
where
F: FnMut(&str) + Send,
{
match &self.provider_type {
LLMProviderType::Remote(_) => {
if let Some(provider) = self.provider.as_any().downcast_ref::<RemoteLLMClient>() {
provider
.chat_stream(messages, tools, temperature, max_tokens, stop, on_chunk)
.await
} else {
Err(HeliosError::AgentError("Provider type mismatch".into()))
}
}
#[cfg(feature = "local")]
LLMProviderType::Local(_) => {
if let Some(provider) = self.provider.as_any().downcast_ref::<LocalLLMProvider>() {
provider
.chat_stream_local(messages, temperature, max_tokens, stop, on_chunk)
.await
} else {
Err(HeliosError::AgentError("Provider type mismatch".into()))
}
}
#[cfg(feature = "candle")]
LLMProviderType::Candle(config) => {
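                // The Candle provider has no incremental streaming; generate
                // the full response and deliver it as one chunk.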
let (model_name, default_temperature, default_max_tokens) = (
config.huggingface_repo.clone(),
config.temperature,
config.max_tokens,
);
let request = LLMRequest {
model: model_name,
messages,
temperature: temperature.or(Some(default_temperature)),
max_tokens: max_tokens.or(Some(default_max_tokens)),
tools: tools.clone(),
tool_choice: if tools.is_some() {
Some("auto".to_string())
} else {
None
},
stream: None,
stop,
};
let response = self.provider.generate(request).await?;
                // `on_chunk` is an `FnMut`, so rebind it mutably before calling.
                let mut on_chunk = on_chunk;
                if let Some(choice) = response.choices.first() {
                    on_chunk(&choice.message.content);
                }
response
.choices
.into_iter()
.next()
.map(|choice| choice.message)
.ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
}
}
}
}