use anyhow::{anyhow, Result};
use async_trait::async_trait;
use std::{
    collections::HashMap,
    path::PathBuf,
    time::{Duration, Instant},
};

use super::{
    config::ProviderConfig,
    providers::LLMProvider,
    types::{LLMRequest, LLMResponse, LLMResponseChunk, LLMResponseStream, Usage},
};
/// Provider backed by a model file on the local filesystem. The generation
/// methods below are stubs that echo the prompt rather than running real
/// inference.
pub struct LocalModelProvider {
    config: ProviderConfig,
    model_path: PathBuf,
}

impl LocalModelProvider {
    /// Builds a provider from `config`, reusing `base_url` as the model path
    /// and verifying that the file exists.
    pub fn new(config: ProviderConfig) -> Result<Self> {
        let model_path = PathBuf::from(
            config
                .base_url
                .as_ref()
                .ok_or_else(|| anyhow!("Model path not specified for local provider"))?,
        );

        if !model_path.exists() {
            return Err(anyhow!("Model file does not exist: {:?}", model_path));
        }

        Ok(Self { config, model_path })
    }
}
#[async_trait]
impl LLMProvider for LocalModelProvider {
    async fn generate(&self, model: &str, request: &LLMRequest) -> Result<LLMResponse> {
        // Flatten the system prompt and conversation into a single prompt string.
        let prompt = format!(
            "{}\n\n{}",
            request.system_prompt.as_deref().unwrap_or(""),
            request
                .messages
                .iter()
                .map(|m| format!("{:?}: {}", m.role, m.content))
                .collect::<Vec<_>>()
                .join("\n")
        );

        // Placeholder completion: echo the first 50 characters of the prompt.
        let content = format!(
            "Local model response to: {}... (Model: {})",
            &prompt.chars().take(50).collect::<String>(),
            model
        );

        // Approximate token counts by whitespace-separated words.
        let prompt_tokens = prompt.split_whitespace().count();
        let completion_tokens = content.split_whitespace().count();

        Ok(LLMResponse {
            content,
            model_used: model.to_string(),
            provider_used: "local".to_string(),
            usage: Usage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
                cost: 0.0, // Local inference has no per-token cost.
            },
            latency: Duration::from_millis(100), // Simulated fixed latency.
            quality_score: Some(0.7),
            metadata: HashMap::new(),
        })
    }
    fn get_available_models(&self) -> Vec<String> {
        vec![
            "llama-2-7b".to_string(),
            "llama-2-13b".to_string(),
            "mistral-7b".to_string(),
            "codellama-7b".to_string(),
        ]
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn get_provider_name(&self) -> &str {
        "local"
    }

    fn estimate_cost(&self, _model: &str, _input_tokens: usize, _output_tokens: usize) -> f64 {
        // Running a local model incurs no per-request API cost.
        0.0
    }
    async fn generate_stream(
        &self,
        model: &str,
        request: &LLMRequest,
    ) -> Result<LLMResponseStream> {
        // Generate the full response first, then replay it as a word-chunked stream.
        let response = self.generate(model, request).await?;

        let words: Vec<String> = response
            .content
            .split_whitespace()
            .map(String::from)
            .collect();
        let chunk_size = 3;

        let model_name = model.to_string();
        let provider_name = "local".to_string();
        let started_at = Instant::now();

        // Pre-build every chunk; the final chunk is flagged so consumers know
        // when the stream is complete.
        let chunks: Vec<Result<LLMResponseChunk>> = words
            .chunks(chunk_size)
            .enumerate()
            .map(|(index, chunk)| {
                let total_words = words.len();
                let is_final = (index + 1) * chunk_size >= total_words;

                Ok(LLMResponseChunk {
                    // Re-insert the trailing space stripped by split_whitespace,
                    // except on the final chunk.
                    content: chunk.join(" ") + if !is_final { " " } else { "" },
                    is_final,
                    chunk_index: index,
                    model_used: model_name.clone(),
                    provider_used: provider_name.clone(),
                    latency: started_at.elapsed(),
                    metadata: HashMap::new(),
                })
            })
            .collect();

        let stream = futures_util::stream::iter(chunks);

        Ok(LLMResponseStream {
            stream: Box::pin(stream),
            model_used: model.to_string(),
            provider_used: "local".to_string(),
            started_at,
        })
    }
}
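
// A minimal usage sketch (not part of the original module): it exercises the
// constructor's missing-file check. It assumes `ProviderConfig` implements
// `Default`; if it does not, construct the config with its real fields instead.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_missing_model_file() {
        // For this provider, `base_url` doubles as the local model path.
        let config = ProviderConfig {
            base_url: Some("/nonexistent/model.gguf".to_string()),
            ..Default::default()
        };
        assert!(LocalModelProvider::new(config).is_err());
    }
}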