Skip to main content

openhawk_core/
llm_router.rs

1use std::collections::VecDeque;
2use std::sync::Mutex;
3
/// A single completion request handed to the router on behalf of an agent.
///
/// `Default` produces an empty prompt with no token budget and no model, so
/// callers can write `LlmRequest { prompt, ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LlmRequest {
    /// Prompt text forwarded to the provider.
    pub prompt: String,
    /// Requested completion-token budget; `None` leaves the choice to the
    /// code that serves the request.
    pub max_tokens: Option<u32>,
    /// Optional model identifier.
    /// NOTE(review): nothing visible in this file reads this field yet —
    /// confirm the intended semantics with the provider integration.
    pub model: Option<String>,
}
10
/// The outcome of a successfully routed request.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Completion text returned for the prompt.
    pub text: String,
    /// Name of the provider that produced `text`.
    pub provider: String,
    /// Number of tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens attributed to the completion.
    pub completion_tokens: u32,
}
18
/// Static description of one LLM backend the router may dispatch to.
#[derive(Debug, Clone)]
pub struct LlmProvider {
    /// Human-readable identifier; used in log lines and echoed back in
    /// responses.
    pub name: String,
    /// Address of the provider's API (the tests in this file use base URLs).
    pub endpoint: String,
    /// Routing rank. The router sorts providers by this value in ascending
    /// order, so lower numbers are tried first.
    pub priority: u32,
    /// `true` for providers that qualify for local-only / air-gapped routing.
    pub is_local: bool,
}
26
/// Availability of a provider as seen by the router.
///
/// Equality on this fieldless enum is total, so it derives `Eq` alongside
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderStatus {
    /// The provider can be called.
    Available,
    /// The provider is known to be unreachable.
    Unavailable,
    /// Availability has not been determined.
    Unknown,
}
33
34#[derive(Debug, Clone)]
35pub struct LlmProviderStatus {
36    pub provider: LlmProvider,
37    pub status: ProviderStatus,
38}
39
/// Errors surfaced by request routing.
///
/// Both variants carry only `String` data, so equality is total and the type
/// derives `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RouterError {
    /// No candidate provider produced a response.
    NoProviderAvailable,
    /// A specific provider failed.
    /// NOTE(review): the routing code visible in this file logs provider
    /// failures and moves on rather than returning this variant — confirm
    /// which callers construct it.
    ProviderFailed {
        /// Name of the provider that failed.
        provider: String,
        /// Failure description.
        reason: String,
    },
}
45
46impl std::fmt::Display for RouterError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            RouterError::NoProviderAvailable => write!(f, "no LLM provider available"),
50            RouterError::ProviderFailed { provider, reason } => {
51                write!(f, "provider '{provider}' failed: {reason}")
52            }
53        }
54    }
55}
56
57impl std::error::Error for RouterError {}
58
59pub struct LlmRouter {
60    providers: Vec<LlmProvider>,
61    queued: Mutex<VecDeque<(u32, LlmRequest)>>,
62}
63
64impl LlmRouter {
65    pub fn new(mut providers: Vec<LlmProvider>) -> Self {
66        providers.sort_by_key(|p| p.priority);
67        Self { providers, queued: Mutex::new(VecDeque::new()) }
68    }
69
70    pub fn route_request(&self, agent_pid: u32, request: LlmRequest, local_only: bool) -> Result<LlmResponse, RouterError> {
71        let candidates: Vec<&LlmProvider> = self.providers.iter()
72            .filter(|p| !local_only || p.is_local)
73            .collect();
74
75        for provider in &candidates {
76            match self.check_availability(provider) {
77                ProviderStatus::Available => {
78                    match self.call_provider(provider, &request) {
79                        Ok(resp) => return Ok(resp),
80                        Err(reason) => {
81                            eprintln!("LLM fallback: provider '{}' failed for agent {}: {}; trying next", provider.name, agent_pid, reason);
82                        }
83                    }
84                }
85                status => {
86                    eprintln!("LLM fallback: provider '{}' status {:?} for agent {}; trying next", provider.name, status, agent_pid);
87                }
88            }
89        }
90
91        eprintln!("No LLM provider available for agent {}; request queued", agent_pid);
92        self.queued.lock().unwrap().push_back((agent_pid, request));
93        Err(RouterError::NoProviderAvailable)
94    }
95
96    pub fn check_availability(&self, _provider: &LlmProvider) -> ProviderStatus {
97        // Real impl would HTTP-ping the endpoint; always available in this stub
98        ProviderStatus::Available
99    }
100
101    pub fn get_providers(&self) -> Vec<LlmProviderStatus> {
102        self.providers.iter().map(|p| LlmProviderStatus {
103            status: self.check_availability(p),
104            provider: p.clone(),
105        }).collect()
106    }
107
108    pub fn queued_count(&self) -> usize {
109        self.queued.lock().unwrap().len()
110    }
111
112    pub fn filter_for_air_gap(providers: &[LlmProvider], air_gapped: bool) -> Vec<&LlmProvider> {
113        if !air_gapped {
114            return providers.iter().collect();
115        }
116        providers.iter().filter(|p| p.is_local).collect()
117    }
118
119    fn call_provider(&self, provider: &LlmProvider, request: &LlmRequest) -> Result<LlmResponse, String> {
120        let prompt_tokens = (request.prompt.split_whitespace().count() as u32).max(1);
121        let completion_tokens = request.max_tokens.unwrap_or(64);
122        Ok(LlmResponse {
123            text: format!("[{}] response to: {}", provider.name, request.prompt),
124            provider: provider.name.clone(),
125            prompt_tokens,
126            completion_tokens,
127        })
128    }
129}
130
#[cfg(test)]
pub mod test_support {
    use super::*;
    use std::collections::HashSet;

    /// Test double for the router: same priority ordering, local-only
    /// filtering and queueing, but availability is controlled by the test
    /// through [`MockRouter::mark_unavailable`].
    pub struct MockRouter {
        /// Providers sorted by ascending priority.
        providers: Vec<LlmProvider>,
        /// Names of providers the test has switched off.
        unavailable: HashSet<String>,
        /// Requests that could not be served, with the requesting agent's pid.
        queued: Mutex<VecDeque<(u32, LlmRequest)>>,
    }

    impl MockRouter {
        /// Builds a mock in which every provider starts out available.
        pub fn new(mut providers: Vec<LlmProvider>) -> Self {
            providers.sort_by_key(|entry| entry.priority);
            MockRouter {
                providers,
                unavailable: HashSet::new(),
                queued: Mutex::new(VecDeque::new()),
            }
        }

        /// Makes the provider called `name` unavailable for all later calls.
        pub fn mark_unavailable(&mut self, name: &str) {
            self.unavailable.insert(name.to_owned());
        }

        /// Answers from the first eligible provider that is not marked
        /// unavailable; otherwise queues the request and returns
        /// `RouterError::NoProviderAvailable`.
        pub fn route_request(&self, agent_pid: u32, request: LlmRequest, local_only: bool) -> Result<LlmResponse, RouterError> {
            let mut chosen: Option<&LlmProvider> = None;
            for entry in self.providers.iter().filter(|entry| entry.is_local || !local_only) {
                if self.is_down(&entry.name) {
                    eprintln!("LLM fallback: provider '{}' unavailable for agent {}; trying next", entry.name, agent_pid);
                } else {
                    chosen = Some(entry);
                    break;
                }
            }

            match chosen {
                Some(entry) => {
                    let word_count = request.prompt.split_whitespace().count() as u32;
                    Ok(LlmResponse {
                        text: format!("[{}] response to: {}", entry.name, request.prompt),
                        provider: entry.name.clone(),
                        prompt_tokens: std::cmp::max(word_count, 1),
                        completion_tokens: request.max_tokens.unwrap_or(64),
                    })
                }
                None => {
                    eprintln!("No LLM provider available for agent {}; request queued", agent_pid);
                    let mut parked = self.queued.lock().unwrap();
                    parked.push_back((agent_pid, request));
                    Err(RouterError::NoProviderAvailable)
                }
            }
        }

        /// Number of requests parked so far.
        pub fn queued_count(&self) -> usize {
            let parked = self.queued.lock().unwrap();
            parked.len()
        }

        /// Snapshot of every provider, in priority order, with the status
        /// implied by `mark_unavailable`.
        pub fn get_providers(&self) -> Vec<LlmProviderStatus> {
            let mut report = Vec::with_capacity(self.providers.len());
            for entry in &self.providers {
                let status = if self.is_down(&entry.name) {
                    ProviderStatus::Unavailable
                } else {
                    ProviderStatus::Available
                };
                report.push(LlmProviderStatus { status, provider: entry.clone() });
            }
            report
        }

        /// Whether the test has switched off the provider called `name`.
        fn is_down(&self, name: &str) -> bool {
            self.unavailable.contains(name)
        }
    }
}
193
#[cfg(test)]
mod tests {
    // Routing behaviour is exercised mostly through `MockRouter`, whose
    // availability can be controlled; the `llm_router_*` and
    // `filter_for_air_gap_*` tests cover the real `LlmRouter`.
    use super::*;
    use test_support::MockRouter;

    /// One cloud provider (priority 1) and one local provider (priority 2).
    fn make_providers() -> Vec<LlmProvider> {
        vec![
            LlmProvider { name: "openai".into(), endpoint: "https://api.openai.com/v1".into(), priority: 1, is_local: false },
            LlmProvider { name: "ollama".into(), endpoint: "http://localhost:11434".into(), priority: 2, is_local: true },
        ]
    }

    /// Request with the given prompt and no token budget or model.
    fn req(prompt: &str) -> LlmRequest {
        LlmRequest { prompt: prompt.to_string(), max_tokens: None, model: None }
    }

    #[test]
    fn routes_to_highest_priority_provider() {
        let router = MockRouter::new(make_providers());
        let resp = router.route_request(1, req("hello"), false).unwrap();
        assert_eq!(resp.provider, "openai");
    }

    #[test]
    fn provider_list_sorted_by_priority() {
        // Supplied out of order on purpose: the constructor must sort.
        let providers = vec![
            LlmProvider { name: "ollama".into(), endpoint: "http://localhost:11434".into(), priority: 2, is_local: true },
            LlmProvider { name: "openai".into(), endpoint: "https://api.openai.com/v1".into(), priority: 1, is_local: false },
        ];
        let router = MockRouter::new(providers);
        let resp = router.route_request(1, req("test"), false).unwrap();
        assert_eq!(resp.provider, "openai");
    }

    #[test]
    fn falls_back_to_next_provider_on_failure() {
        let mut router = MockRouter::new(make_providers());
        router.mark_unavailable("openai");
        let resp = router.route_request(1, req("hello"), false).unwrap();
        assert_eq!(resp.provider, "ollama");
    }

    #[test]
    fn response_contains_prompt_text() {
        let router = MockRouter::new(make_providers());
        let resp = router.route_request(42, req("what is rust"), false).unwrap();
        assert!(resp.text.contains("what is rust"));
    }

    #[test]
    fn local_only_skips_cloud_providers() {
        let router = MockRouter::new(make_providers());
        let resp = router.route_request(1, req("private"), true).unwrap();
        assert_eq!(resp.provider, "ollama");
    }

    #[test]
    fn local_only_fails_when_no_local_provider() {
        let providers = vec![
            LlmProvider { name: "openai".into(), endpoint: "https://api.openai.com/v1".into(), priority: 1, is_local: false },
        ];
        let router = MockRouter::new(providers);
        let err = router.route_request(1, req("private"), true).unwrap_err();
        assert_eq!(err, RouterError::NoProviderAvailable);
    }

    #[test]
    fn queues_request_when_no_provider_available() {
        let mut router = MockRouter::new(make_providers());
        router.mark_unavailable("openai");
        router.mark_unavailable("ollama");
        let err = router.route_request(7, req("queue me"), false).unwrap_err();
        assert_eq!(err, RouterError::NoProviderAvailable);
        assert_eq!(router.queued_count(), 1);
    }

    #[test]
    fn multiple_failed_requests_all_queued() {
        let mut router = MockRouter::new(make_providers());
        router.mark_unavailable("openai");
        router.mark_unavailable("ollama");
        for _ in 0..3 {
            let _ = router.route_request(1, req("q"), false);
        }
        assert_eq!(router.queued_count(), 3);
    }

    #[test]
    fn get_providers_reflects_availability() {
        let mut router = MockRouter::new(make_providers());
        router.mark_unavailable("openai");
        let statuses = router.get_providers();
        let openai = statuses.iter().find(|s| s.provider.name == "openai").unwrap();
        let ollama = statuses.iter().find(|s| s.provider.name == "ollama").unwrap();
        assert_eq!(openai.status, ProviderStatus::Unavailable);
        assert_eq!(ollama.status, ProviderStatus::Available);
    }

    #[test]
    fn llm_router_routes_successfully() {
        let router = LlmRouter::new(make_providers());
        let resp = router.route_request(1, req("hello"), false).unwrap();
        // Pin the priority contract on the real router, not just "some answer".
        assert_eq!(resp.provider, "openai");
        assert!(!resp.text.is_empty());
        assert_eq!(router.queued_count(), 0);
    }

    #[test]
    fn llm_router_orders_providers_by_priority() {
        let providers = vec![
            LlmProvider { name: "ollama".into(), endpoint: "http://localhost:11434".into(), priority: 2, is_local: true },
            LlmProvider { name: "openai".into(), endpoint: "https://api.openai.com/v1".into(), priority: 1, is_local: false },
        ];
        let router = LlmRouter::new(providers);
        let statuses = router.get_providers();
        let names: Vec<&str> = statuses.iter().map(|s| s.provider.name.as_str()).collect();
        assert_eq!(names, vec!["openai", "ollama"]);
    }

    #[test]
    fn llm_router_local_only_returns_local_provider() {
        let router = LlmRouter::new(make_providers());
        let resp = router.route_request(1, req("private"), true).unwrap();
        assert_eq!(resp.provider, "ollama");
    }

    #[test]
    fn llm_router_queues_when_no_candidate_matches() {
        // The real router's availability check is a stub, so an empty
        // candidate list is the way to reach its queueing path.
        let providers = vec![
            LlmProvider { name: "openai".into(), endpoint: "https://api.openai.com/v1".into(), priority: 1, is_local: false },
        ];
        let router = LlmRouter::new(providers);
        let err = router.route_request(7, req("private"), true).unwrap_err();
        assert_eq!(err, RouterError::NoProviderAvailable);
        assert_eq!(router.queued_count(), 1);
    }

    #[test]
    fn filter_for_air_gap_returns_only_local_when_enabled() {
        let providers = make_providers();
        let filtered = LlmRouter::filter_for_air_gap(&providers, true);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].name, "ollama");
    }

    #[test]
    fn filter_for_air_gap_returns_all_when_disabled() {
        let providers = make_providers();
        let filtered = LlmRouter::filter_for_air_gap(&providers, false);
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn filter_for_air_gap_empty_when_no_local_providers() {
        let providers = vec![
            LlmProvider { name: "openai".into(), endpoint: "https://api.openai.com/v1".into(), priority: 1, is_local: false },
            LlmProvider { name: "anthropic".into(), endpoint: "https://api.anthropic.com".into(), priority: 2, is_local: false },
        ];
        let filtered = LlmRouter::filter_for_air_gap(&providers, true);
        assert!(filtered.is_empty());
    }
}
331}