1use std::collections::VecDeque;
2use std::sync::Mutex;
3
/// A completion request submitted to the router on behalf of an agent.
#[derive(Clone, Debug)]
pub struct LlmRequest {
    /// Prompt text passed to the selected provider.
    pub prompt: String,
    /// Optional completion-token limit; when `None`, the router's provider
    /// stub in this file falls back to 64.
    pub max_tokens: Option<u32>,
    /// Optional model name.
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// whether a real provider call is expected to honour it.
    pub model: Option<String>,
}
10
/// The result of a successfully routed request.
#[derive(Clone, Debug)]
pub struct LlmResponse {
    /// Text produced by the provider.
    pub text: String,
    /// Name of the provider that produced `text`.
    pub provider: String,
    /// Token count attributed to the prompt.
    pub prompt_tokens: u32,
    /// Token count attributed to the completion.
    pub completion_tokens: u32,
}
18
/// Static description of one LLM back-end the router may use.
#[derive(Clone, Debug)]
pub struct LlmProvider {
    /// Identifier used in log lines and echoed in `LlmResponse::provider`.
    pub name: String,
    /// Endpoint string for the back-end (the tests in this file use URLs).
    /// NOTE(review): the routing code visible here never reads it.
    pub endpoint: String,
    /// Ordering key: the routers sort ascending, so lower values are tried first.
    pub priority: u32,
    /// Marks the provider as local; only local providers are kept for
    /// `local_only` requests and by the air-gap filter.
    pub is_local: bool,
}
26
/// Availability of a provider as reported to the router.
///
/// The enum carries no data, so it is `Copy` and can be compared, hashed and
/// passed by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderStatus {
    /// The provider may be called.
    Available,
    /// The provider is known to be unusable; the router skips it.
    Unavailable,
    /// Availability has not been determined; the router skips it as well.
    Unknown,
}
33
34#[derive(Debug, Clone)]
35pub struct LlmProviderStatus {
36 pub provider: LlmProvider,
37 pub status: ProviderStatus,
38}
39
/// Errors reported by the router.
///
/// Both variants hold only `String` data, so the type is fully `Eq` and can be
/// compared directly in assertions or used where total equality is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RouterError {
    /// No candidate provider produced a response.
    NoProviderAvailable,
    /// A specific provider failed.
    ///
    /// NOTE(review): the routing code visible in this file logs provider
    /// failures and moves on to the next candidate; it does not construct this
    /// variant itself.
    ProviderFailed {
        /// Name of the provider that failed.
        provider: String,
        /// Human-readable failure description.
        reason: String,
    },
}

impl std::fmt::Display for RouterError {
    /// Writes the user-facing message for the error.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoProviderAvailable => formatter.write_str("no LLM provider available"),
            Self::ProviderFailed { provider, reason } => {
                write!(formatter, "provider '{provider}' failed: {reason}")
            }
        }
    }
}

// No underlying source error is carried, so the trait's defaults are sufficient.
impl std::error::Error for RouterError {}
58
59pub struct LlmRouter {
60 providers: Vec<LlmProvider>,
61 queued: Mutex<VecDeque<(u32, LlmRequest)>>,
62}
63
64impl LlmRouter {
65 pub fn new(mut providers: Vec<LlmProvider>) -> Self {
66 providers.sort_by_key(|p| p.priority);
67 Self { providers, queued: Mutex::new(VecDeque::new()) }
68 }
69
70 pub fn route_request(&self, agent_pid: u32, request: LlmRequest, local_only: bool) -> Result<LlmResponse, RouterError> {
71 let candidates: Vec<&LlmProvider> = self.providers.iter()
72 .filter(|p| !local_only || p.is_local)
73 .collect();
74
75 for provider in &candidates {
76 match self.check_availability(provider) {
77 ProviderStatus::Available => {
78 match self.call_provider(provider, &request) {
79 Ok(resp) => return Ok(resp),
80 Err(reason) => {
81 eprintln!("LLM fallback: provider '{}' failed for agent {}: {}; trying next", provider.name, agent_pid, reason);
82 }
83 }
84 }
85 status => {
86 eprintln!("LLM fallback: provider '{}' status {:?} for agent {}; trying next", provider.name, status, agent_pid);
87 }
88 }
89 }
90
91 eprintln!("No LLM provider available for agent {}; request queued", agent_pid);
92 self.queued.lock().unwrap().push_back((agent_pid, request));
93 Err(RouterError::NoProviderAvailable)
94 }
95
96 pub fn check_availability(&self, _provider: &LlmProvider) -> ProviderStatus {
97 ProviderStatus::Available
99 }
100
101 pub fn get_providers(&self) -> Vec<LlmProviderStatus> {
102 self.providers.iter().map(|p| LlmProviderStatus {
103 status: self.check_availability(p),
104 provider: p.clone(),
105 }).collect()
106 }
107
108 pub fn queued_count(&self) -> usize {
109 self.queued.lock().unwrap().len()
110 }
111
112 pub fn filter_for_air_gap(providers: &[LlmProvider], air_gapped: bool) -> Vec<&LlmProvider> {
113 if !air_gapped {
114 return providers.iter().collect();
115 }
116 providers.iter().filter(|p| p.is_local).collect()
117 }
118
119 fn call_provider(&self, provider: &LlmProvider, request: &LlmRequest) -> Result<LlmResponse, String> {
120 let prompt_tokens = (request.prompt.split_whitespace().count() as u32).max(1);
121 let completion_tokens = request.max_tokens.unwrap_or(64);
122 Ok(LlmResponse {
123 text: format!("[{}] response to: {}", provider.name, request.prompt),
124 provider: provider.name.clone(),
125 prompt_tokens,
126 completion_tokens,
127 })
128 }
129}
130
#[cfg(test)]
pub mod test_support {
    use super::*;
    use std::collections::HashSet;

    /// Test double for the router whose per-provider availability can be
    /// switched off by name.
    pub struct MockRouter {
        /// Providers in ascending `priority` order.
        providers: Vec<LlmProvider>,
        /// Names of providers that must be treated as unavailable.
        unavailable: HashSet<String>,
        /// Requests that found no usable provider, appended at the back.
        queued: Mutex<VecDeque<(u32, LlmRequest)>>,
    }

    impl MockRouter {
        /// Creates a mock over `providers` (sorted by ascending priority) with
        /// every provider available and an empty queue.
        pub fn new(mut providers: Vec<LlmProvider>) -> Self {
            providers.sort_by_key(|entry| entry.priority);
            MockRouter {
                providers,
                unavailable: HashSet::new(),
                queued: Mutex::new(VecDeque::new()),
            }
        }

        /// Flags the provider called `name` as unavailable for later calls.
        pub fn mark_unavailable(&mut self, name: &str) {
            self.unavailable.insert(name.to_owned());
        }

        /// Answers with the first candidate that is not flagged unavailable.
        ///
        /// Skipped candidates are logged to stderr. When nothing is usable the
        /// request is queued and `RouterError::NoProviderAvailable` is returned.
        pub fn route_request(&self, agent_pid: u32, request: LlmRequest, local_only: bool) -> Result<LlmResponse, RouterError> {
            let first_usable = self
                .providers
                .iter()
                .filter(|candidate| candidate.is_local || !local_only)
                .find(|candidate| {
                    let is_down = self.unavailable.contains(&candidate.name);
                    if is_down {
                        eprintln!("LLM fallback: provider '{}' unavailable for agent {}; trying next", candidate.name, agent_pid);
                    }
                    !is_down
                });

            if let Some(chosen) = first_usable {
                let word_count = request.prompt.split_whitespace().count() as u32;
                return Ok(LlmResponse {
                    text: format!("[{}] response to: {}", chosen.name, request.prompt),
                    provider: chosen.name.clone(),
                    prompt_tokens: word_count.max(1),
                    completion_tokens: request.max_tokens.unwrap_or(64),
                });
            }

            eprintln!("No LLM provider available for agent {}; request queued", agent_pid);
            self.queued.lock().unwrap().push_back((agent_pid, request));
            Err(RouterError::NoProviderAvailable)
        }

        /// Number of requests currently queued.
        pub fn queued_count(&self) -> usize {
            let backlog = self.queued.lock().unwrap();
            backlog.len()
        }

        /// Lists every provider with `Unavailable` if it was flagged via
        /// `mark_unavailable`, otherwise `Available`.
        pub fn get_providers(&self) -> Vec<LlmProviderStatus> {
            let mut report = Vec::with_capacity(self.providers.len());
            for entry in &self.providers {
                let status = if self.unavailable.contains(&entry.name) {
                    ProviderStatus::Unavailable
                } else {
                    ProviderStatus::Available
                };
                report.push(LlmProviderStatus { provider: entry.clone(), status });
            }
            report
        }
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;
    use test_support::MockRouter;

    /// Builds one provider entry for the fixtures below.
    fn provider(name: &str, endpoint: &str, priority: u32, is_local: bool) -> LlmProvider {
        LlmProvider { name: name.into(), endpoint: endpoint.into(), priority, is_local }
    }

    /// Cloud provider fixture (priority 1, not local).
    fn openai() -> LlmProvider {
        provider("openai", "https://api.openai.com/v1", 1, false)
    }

    /// Local provider fixture (priority 2, local).
    fn ollama() -> LlmProvider {
        provider("ollama", "http://localhost:11434", 2, true)
    }

    /// Default fixture: one cloud and one local provider, already in priority order.
    fn make_providers() -> Vec<LlmProvider> {
        vec![openai(), ollama()]
    }

    /// Request carrying only a prompt.
    fn req(prompt: &str) -> LlmRequest {
        LlmRequest { prompt: prompt.to_string(), max_tokens: None, model: None }
    }

    /// Mock router in which both default providers are flagged unavailable.
    fn mock_with_everything_down() -> MockRouter {
        let mut mock = MockRouter::new(make_providers());
        mock.mark_unavailable("openai");
        mock.mark_unavailable("ollama");
        mock
    }

    #[test]
    fn routes_to_highest_priority_provider() {
        let mock = MockRouter::new(make_providers());
        let answer = mock.route_request(1, req("hello"), false).unwrap();
        assert_eq!(answer.provider, "openai");
    }

    #[test]
    fn provider_list_sorted_by_priority() {
        // Deliberately supplied out of priority order.
        let shuffled = vec![ollama(), openai()];
        let mock = MockRouter::new(shuffled);
        let answer = mock.route_request(1, req("test"), false).unwrap();
        assert_eq!(answer.provider, "openai");
    }

    #[test]
    fn falls_back_to_next_provider_on_failure() {
        let mut mock = MockRouter::new(make_providers());
        mock.mark_unavailable("openai");
        let answer = mock.route_request(1, req("hello"), false).unwrap();
        assert_eq!(answer.provider, "ollama");
    }

    #[test]
    fn response_contains_prompt_text() {
        let mock = MockRouter::new(make_providers());
        let answer = mock.route_request(42, req("what is rust"), false).unwrap();
        assert!(answer.text.contains("what is rust"));
    }

    #[test]
    fn local_only_skips_cloud_providers() {
        let mock = MockRouter::new(make_providers());
        let answer = mock.route_request(1, req("private"), true).unwrap();
        assert_eq!(answer.provider, "ollama");
    }

    #[test]
    fn local_only_fails_when_no_local_provider() {
        let mock = MockRouter::new(vec![openai()]);
        let failure = mock.route_request(1, req("private"), true).unwrap_err();
        assert_eq!(failure, RouterError::NoProviderAvailable);
    }

    #[test]
    fn queues_request_when_no_provider_available() {
        let mock = mock_with_everything_down();
        let failure = mock.route_request(7, req("queue me"), false).unwrap_err();
        assert_eq!(failure, RouterError::NoProviderAvailable);
        assert_eq!(mock.queued_count(), 1);
    }

    #[test]
    fn multiple_failed_requests_all_queued() {
        let mock = mock_with_everything_down();
        (0..3).for_each(|_| {
            let _ = mock.route_request(1, req("q"), false);
        });
        assert_eq!(mock.queued_count(), 3);
    }

    #[test]
    fn get_providers_reflects_availability() {
        let mut mock = MockRouter::new(make_providers());
        mock.mark_unavailable("openai");
        let report = mock.get_providers();
        let status_of = |wanted: &str| {
            report
                .iter()
                .find(|entry| entry.provider.name == wanted)
                .unwrap()
                .status
                .clone()
        };
        assert_eq!(status_of("openai"), ProviderStatus::Unavailable);
        assert_eq!(status_of("ollama"), ProviderStatus::Available);
    }

    #[test]
    fn llm_router_routes_successfully() {
        let router = LlmRouter::new(make_providers());
        let answer = router.route_request(1, req("hello"), false).unwrap();
        assert!(!answer.provider.is_empty());
        assert!(!answer.text.is_empty());
    }

    #[test]
    fn llm_router_local_only_returns_local_provider() {
        let router = LlmRouter::new(make_providers());
        let answer = router.route_request(1, req("private"), true).unwrap();
        assert_eq!(answer.provider, "ollama");
    }

    #[test]
    fn filter_for_air_gap_returns_only_local_when_enabled() {
        let all = make_providers();
        let kept = LlmRouter::filter_for_air_gap(&all, true);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].name, "ollama");
    }

    #[test]
    fn filter_for_air_gap_returns_all_when_disabled() {
        let all = make_providers();
        let kept = LlmRouter::filter_for_air_gap(&all, false);
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn filter_for_air_gap_empty_when_no_local_providers() {
        let cloud_only = vec![
            openai(),
            provider("anthropic", "https://api.anthropic.com", 2, false),
        ];
        let kept = LlmRouter::filter_for_air_gap(&cloud_only, true);
        assert!(kept.is_empty());
    }
}