
limit_llm/local_provider.rs

use crate::error::LlmError;
use crate::openai_provider::OpenAiProvider;
use crate::providers::{LlmProvider, ProviderResponseChunk};
use crate::types::{Message, Tool};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;

/// Local LLM provider for Ollama, LM Studio, vLLM, and other OpenAI-compatible local servers.
///
/// This provider wraps `OpenAiProvider` with sensible defaults for local inference:
/// - No API key required (a placeholder value is sent instead)
/// - Default localhost URL (Ollama on port 11434)
/// - Caller-supplied timeout, so slower local inference can be accommodated
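///
/// # Examples
///
/// A minimal construction sketch using the preset constructors. The import path
/// assumes the crate is named `limit_llm` and re-exports `LocalProvider` at the
/// root; adjust it to match the actual module layout.
///
/// ```ignore
/// use limit_llm::LocalProvider;
///
/// // Ollama on its default port (11434)
/// let ollama = LocalProvider::ollama("llama3.2", 4096, 120);
///
/// // LM Studio on its default port (1234)
/// let lmstudio = LocalProvider::lmstudio("local-model", 4096, 120);
/// ```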
#[derive(Clone)]
pub struct LocalProvider {
    openai: OpenAiProvider,
}

impl LocalProvider {
    /// Default Ollama endpoint
    pub const DEFAULT_OLLAMA_URL: &'static str = "http://localhost:11434/v1/chat/completions";

    /// Default LM Studio endpoint
    pub const DEFAULT_LMSTUDIO_URL: &'static str = "http://localhost:1234/v1/chat/completions";

    /// Default vLLM endpoint
    pub const DEFAULT_VLLM_URL: &'static str = "http://localhost:8000/v1/chat/completions";

    /// Create a new local provider.
    ///
    /// # Arguments
    /// * `base_url` - Optional custom endpoint. Defaults to the Ollama endpoint.
    /// * `model` - Model name to use (required)
    /// * `max_tokens` - Maximum output tokens
    /// * `timeout` - Request timeout in seconds (120 is a reasonable value for slower local inference)
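    ///
    /// # Examples
    ///
    /// A rough sketch of the `base_url` fallback (the custom endpoint below is
    /// hypothetical; crate paths are assumed as in the struct-level example):
    ///
    /// ```ignore
    /// // `None` falls back to DEFAULT_OLLAMA_URL.
    /// let default_url = LocalProvider::new(None, "llama3.2", 4096, 120);
    ///
    /// // Any other OpenAI-compatible server can be targeted explicitly.
    /// let custom = LocalProvider::new(
    ///     Some("http://localhost:8080/v1/chat/completions"),
    ///     "my-model",
    ///     8192,
    ///     60,
    /// );
    /// ```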
    pub fn new(base_url: Option<&str>, model: &str, max_tokens: u32, timeout: u64) -> Self {
        let url = base_url.unwrap_or(Self::DEFAULT_OLLAMA_URL);
        // Local servers typically don't require auth, use placeholder
        let api_key = "local".to_string();

        Self {
            openai: OpenAiProvider::new(api_key, Some(url), model, max_tokens, timeout),
        }
    }

    /// Create a provider configured for Ollama
    pub fn ollama(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(Self::DEFAULT_OLLAMA_URL), model, max_tokens, timeout)
    }

    /// Create a provider configured for LM Studio
    pub fn lmstudio(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(Self::DEFAULT_LMSTUDIO_URL), model, max_tokens, timeout)
    }

    /// Create a provider configured for vLLM
    pub fn vllm(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(Self::DEFAULT_VLLM_URL), model, max_tokens, timeout)
    }
}

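/// Streaming works the same way as for [`OpenAiProvider`]: `send` returns a
/// pinned, boxed stream of `ProviderResponseChunk` results. A rough consumption
/// sketch follows; the crate paths and the `messages` construction are
/// assumptions and should be adapted to the rest of the crate.
///
/// ```ignore
/// use futures::StreamExt;
/// use limit_llm::{LlmProvider, LocalProvider, Message};
///
/// let provider = LocalProvider::ollama("llama3.2", 4096, 120);
/// let messages: Vec<Message> = vec![/* build the conversation here */];
///
/// let mut stream = provider.send(messages, vec![]).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     // Handle each chunk as it arrives (e.g. print streamed text).
/// }
/// ```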
#[async_trait]
impl LlmProvider for LocalProvider {
    #[allow(clippy::type_complexity)]
    async fn send(
        &self,
        messages: Vec<Message>,
        tools: Vec<Tool>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
        LlmError,
    > {
        self.openai.send(messages, tools).await
    }

    fn provider_name(&self) -> &str {
        "local"
    }

    fn model_name(&self) -> &str {
        self.openai.model_name()
    }

    fn clone_box(&self) -> Box<dyn LlmProvider> {
        Box::new(self.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_provider_creation() {
        let provider = LocalProvider::new(None, "llama3.2", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "llama3.2");
    }

    #[test]
    fn test_local_provider_custom_url() {
        let provider = LocalProvider::new(
            Some("http://custom:8080/v1/chat/completions"),
            "custom-model",
            8192,
            60,
        );
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "custom-model");
    }

    #[test]
    fn test_ollama_preset() {
        let provider = LocalProvider::ollama("llama3.2", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "llama3.2");
    }

    #[test]
    fn test_lmstudio_preset() {
        let provider = LocalProvider::lmstudio("local-model", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "local-model");
    }

    #[test]
    fn test_vllm_preset() {
        let provider = LocalProvider::vllm("meta-llama/Llama-3.2-3B", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "meta-llama/Llama-3.2-3B");
    }

    #[test]
    fn test_local_provider_clone() {
        let provider = LocalProvider::new(None, "test-model", 4096, 120);
        let cloned = provider.clone_box();
        assert_eq!(cloned.provider_name(), "local");
        assert_eq!(cloned.model_name(), "test-model");
    }
}