limit_llm/local_provider.rs

use crate::error::LlmError;
use crate::openai_provider::OpenAiProvider;
use crate::providers::{LlmProvider, ProviderResponseChunk};
use crate::types::{Message, Tool};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;

/// Provider for local OpenAI-compatible servers such as Ollama, LM Studio,
/// and vLLM. Internally it delegates to [`OpenAiProvider`], configured with a
/// placeholder API key.
#[derive(Clone)]
pub struct LocalProvider {
    openai: OpenAiProvider,
}

impl LocalProvider {
    /// Default Ollama OpenAI-compatible endpoint.
    pub const DEFAULT_OLLAMA_URL: &'static str = "http://localhost:11434/v1/chat/completions";

    /// Default LM Studio OpenAI-compatible endpoint.
    pub const DEFAULT_LMSTUDIO_URL: &'static str = "http://localhost:1234/v1/chat/completions";

    /// Default vLLM OpenAI-compatible endpoint.
    pub const DEFAULT_VLLM_URL: &'static str = "http://localhost:8000/v1/chat/completions";

    /// Creates a provider for the given endpoint, falling back to the Ollama
    /// default when `base_url` is `None`.
    pub fn new(base_url: Option<&str>, model: &str, max_tokens: u32, timeout: u64) -> Self {
        let url = base_url.unwrap_or(Self::DEFAULT_OLLAMA_URL);
        // Local servers generally ignore the API key, but the wrapped OpenAI
        // client still expects one, so pass a placeholder.
        let api_key = "local".to_string();

        Self {
            openai: OpenAiProvider::new(api_key, Some(url), model, max_tokens, timeout),
        }
    }

    /// Convenience constructor preset for Ollama.
    pub fn ollama(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(Self::DEFAULT_OLLAMA_URL), model, max_tokens, timeout)
    }

    /// Convenience constructor preset for LM Studio.
    pub fn lmstudio(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(Self::DEFAULT_LMSTUDIO_URL), model, max_tokens, timeout)
    }

    /// Convenience constructor preset for vLLM.
    pub fn vllm(model: &str, max_tokens: u32, timeout: u64) -> Self {
        Self::new(Some(Self::DEFAULT_VLLM_URL), model, max_tokens, timeout)
    }
}
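
// Illustrative only: a minimal selection sketch, not part of this crate's
// API. The `backend` tag and the function name are assumptions; the sketch
// just shows how the presets and the custom-URL path are meant to combine.
// Kept `dead_code` so it has no effect on the build.
#[allow(dead_code)]
fn pick_local_provider(backend: &str, model: &str) -> LocalProvider {
    match backend {
        "ollama" => LocalProvider::ollama(model, 4096, 120),
        "lmstudio" => LocalProvider::lmstudio(model, 4096, 120),
        "vllm" => LocalProvider::vllm(model, 4096, 120),
        // Anything else is treated as a custom OpenAI-compatible URL.
        url => LocalProvider::new(Some(url), model, 4096, 120),
    }
}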

#[async_trait]
impl LlmProvider for LocalProvider {
    #[allow(clippy::type_complexity)]
    async fn send(
        &self,
        messages: Vec<Message>,
        tools: Vec<Tool>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
        LlmError,
    > {
        // Delegate directly to the wrapped OpenAI-compatible client.
        self.openai.send(messages, tools).await
    }

    fn provider_name(&self) -> &str {
        "local"
    }

    fn model_name(&self) -> &str {
        self.openai.model_name()
    }

    fn clone_box(&self) -> Box<dyn LlmProvider> {
        Box::new(self.clone())
    }
}
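
// Illustrative only: a hedged sketch of how a caller might drain the stream
// returned by `send`. The function itself is an assumption, not part of this
// crate; it relies only on `futures::StreamExt` from the `futures` crate.
#[allow(dead_code)]
async fn drain_stream(provider: &LocalProvider) -> Result<u32, LlmError> {
    use futures::StreamExt;

    // Empty message/tool lists keep the sketch self-contained; a real caller
    // would pass actual `Message` and `Tool` values.
    let mut stream = provider.send(Vec::new(), Vec::new()).await?;
    let mut chunks = 0u32;
    while let Some(chunk) = stream.next().await {
        // Each item is a `Result<ProviderResponseChunk, LlmError>`.
        let _chunk: ProviderResponseChunk = chunk?;
        chunks += 1;
    }
    Ok(chunks)
}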

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_provider_creation() {
        let provider = LocalProvider::new(None, "llama3.2", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "llama3.2");
    }

    #[test]
    fn test_local_provider_custom_url() {
        let provider = LocalProvider::new(
            Some("http://custom:8080/v1/chat/completions"),
            "custom-model",
            8192,
            60,
        );
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "custom-model");
    }

    #[test]
    fn test_ollama_preset() {
        let provider = LocalProvider::ollama("llama3.2", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "llama3.2");
    }

    #[test]
    fn test_lmstudio_preset() {
        let provider = LocalProvider::lmstudio("local-model", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "local-model");
    }

    #[test]
    fn test_vllm_preset() {
        let provider = LocalProvider::vllm("meta-llama/Llama-3.2-3B", 4096, 120);
        assert_eq!(provider.provider_name(), "local");
        assert_eq!(provider.model_name(), "meta-llama/Llama-3.2-3B");
    }

    #[test]
    fn test_local_provider_clone() {
        let provider = LocalProvider::new(None, "test-model", 4096, 120);
        let cloned = provider.clone_box();
        assert_eq!(cloned.provider_name(), "local");
        assert_eq!(cloned.model_name(), "test-model");
    }
}