use crate::error::LlmError;
use crate::openai_provider::OpenAiProvider;
use crate::providers::{LlmProvider, ProviderResponseChunk};
use crate::types::{Message, Tool};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
/// LLM provider for locally hosted, OpenAI-compatible chat servers such as
/// Ollama, LM Studio, or vLLM.
///
/// Every request is delegated to an inner [`OpenAiProvider`] that is pointed
/// at the configured endpoint.
#[derive(Clone)]
pub struct LocalProvider {
    /// The OpenAI-compatible client that all requests are forwarded to.
    openai: OpenAiProvider,
}
impl LocalProvider {
    /// Default chat-completions endpoint used for Ollama. It is also the
    /// fallback when [`LocalProvider::new`] receives no `base_url`.
    pub const DEFAULT_OLLAMA_URL: &'static str = "http://localhost:11434/v1/chat/completions";
    /// Default chat-completions endpoint used for LM Studio.
    pub const DEFAULT_LMSTUDIO_URL: &'static str = "http://localhost:1234/v1/chat/completions";
    /// Default chat-completions endpoint used for vLLM.
    pub const DEFAULT_VLLM_URL: &'static str = "http://localhost:8000/v1/chat/completions";

    /// Builds a provider that talks to an OpenAI-compatible server at `base_url`.
    ///
    /// When `base_url` is `None`, [`Self::DEFAULT_OLLAMA_URL`] is used.
    /// `model`, `max_tokens` and `timeout` are forwarded unchanged to
    /// [`OpenAiProvider::new`].
    // NOTE(review): the unit of `timeout` is defined by `OpenAiProvider::new`
    // (presumably seconds); confirm there.
    pub fn new(base_url: Option<&str>, model: &str, max_tokens: u32, timeout: u64) -> Self {
        let endpoint = match base_url {
            Some(custom) => custom,
            None => Self::DEFAULT_OLLAMA_URL,
        };
        // The OpenAI client takes an API key, so a fixed placeholder is passed.
        // Presumably local servers ignore it; confirm per backend.
        let placeholder_key = String::from("local");
        Self {
            openai: OpenAiProvider::new(placeholder_key, Some(endpoint), model, max_tokens, timeout),
        }
    }

    /// Provider preconfigured for Ollama at [`Self::DEFAULT_OLLAMA_URL`].
    pub fn ollama(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::at_endpoint(Self::DEFAULT_OLLAMA_URL, model, max_tokens, timeout)
    }

    /// Provider preconfigured for LM Studio at [`Self::DEFAULT_LMSTUDIO_URL`].
    pub fn lmstudio(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::at_endpoint(Self::DEFAULT_LMSTUDIO_URL, model, max_tokens, timeout)
    }

    /// Provider preconfigured for vLLM at [`Self::DEFAULT_VLLM_URL`].
    pub fn vllm(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::at_endpoint(Self::DEFAULT_VLLM_URL, model, max_tokens, timeout)
    }

    /// Shared body of the presets: [`Self::new`] with an explicit endpoint.
    fn at_endpoint(url: &str, model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(url), model, max_tokens, timeout)
    }
}
#[async_trait]
impl LlmProvider for LocalProvider {
    /// Always reports `"local"`, whatever backend the endpoint points at.
    fn provider_name(&self) -> &str {
        "local"
    }

    /// Reports the model name held by the wrapped OpenAI-compatible client.
    fn model_name(&self) -> &str {
        let backend = &self.openai;
        backend.model_name()
    }

    /// Forwards the conversation and tools to the wrapped OpenAI-compatible
    /// client and returns its response stream (or error) unchanged.
    // The nested return type is fixed by the `LlmProvider` trait signature,
    // so clippy's complexity warning is silenced rather than "fixed" here.
    #[allow(clippy::type_complexity)]
    async fn send(
        &self,
        messages: Vec<Message>,
        tools: Vec<Tool>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
        LlmError,
    > {
        let backend = &self.openai;
        backend.send(messages, tools).await
    }

    /// Returns a boxed copy of this provider as a trait object.
    fn clone_box(&self) -> Box<dyn LlmProvider> {
        Box::new(Clone::clone(self))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the identity every `LocalProvider` reports: the fixed `"local"`
    /// provider name plus the model it was constructed with.
    fn assert_local_identity(provider: &dyn LlmProvider, expected_model: &str) {
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), expected_model);
    }

    #[test]
    fn test_local_provider_creation() {
        let default_endpoint = LocalProvider::new(None, "llama3.2", 4096, 120);
        assert_local_identity(&default_endpoint, "llama3.2");
    }

    #[test]
    fn test_local_provider_custom_url() {
        let custom_endpoint = LocalProvider::new(
            Some("http://custom:8080/v1/chat/completions"),
            "custom-model",
            8192,
            60,
        );
        assert_local_identity(&custom_endpoint, "custom-model");
    }

    #[test]
    fn test_ollama_preset() {
        assert_local_identity(&LocalProvider::ollama("llama3.2", 4096, 120), "llama3.2");
    }

    #[test]
    fn test_lmstudio_preset() {
        assert_local_identity(&LocalProvider::lmstudio("local-model", 4096, 120), "local-model");
    }

    #[test]
    fn test_vllm_preset() {
        let model = "meta-llama/Llama-3.2-3B";
        assert_local_identity(&LocalProvider::vllm(model, 4096, 120), model);
    }

    #[test]
    fn test_local_provider_clone() {
        let original = LocalProvider::new(None, "test-model", 4096, 120);
        let boxed: Box<dyn LlmProvider> = original.clone_box();
        assert_local_identity(&*boxed, "test-model");
    }
}