1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use crate::compiler::CompiledAction;
8use crate::executor::{CompiledExecutor, ExecutionResult};
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum RoutingDecision {
14 Compiled {
16 compiled_id: String,
17 signature: String,
18 },
19 Llm { reason: String },
21}
22
23pub struct ExecutionRouter {
25 compiled: parking_lot::Mutex<HashMap<String, CompiledAction>>,
27 compiled_hits: parking_lot::Mutex<u64>,
29 llm_fallbacks: parking_lot::Mutex<u64>,
30}
31
32impl ExecutionRouter {
33 pub fn new() -> Self {
34 Self {
35 compiled: parking_lot::Mutex::new(HashMap::new()),
36 compiled_hits: parking_lot::Mutex::new(0),
37 llm_fallbacks: parking_lot::Mutex::new(0),
38 }
39 }
40
41 pub fn register(&self, compiled: CompiledAction) {
43 self.compiled
44 .lock()
45 .insert(compiled.signature.clone(), compiled);
46 }
47
48 pub fn route(&self, signature: &str) -> RoutingDecision {
50 let compiled = self.compiled.lock();
51 if let Some(action) = compiled.get(signature) {
52 *self.compiled_hits.lock() += 1;
53 RoutingDecision::Compiled {
54 compiled_id: action.id.clone(),
55 signature: action.signature.clone(),
56 }
57 } else {
58 *self.llm_fallbacks.lock() += 1;
59 RoutingDecision::Llm {
60 reason: format!("No compiled action for signature: {}", signature),
61 }
62 }
63 }
64
65 pub fn execute_compiled(
67 &self,
68 signature: &str,
69 variables: HashMap<String, serde_json::Value>,
70 ) -> Option<ExecutionResult> {
71 let compiled = self.compiled.lock();
72 let action = compiled.get(signature)?;
73 let mut executor = CompiledExecutor::with_variables(variables);
74 Some(executor.execute(action))
75 }
76
77 pub fn compiled_count(&self) -> usize {
79 self.compiled.lock().len()
80 }
81
82 pub fn stats(&self) -> RouterStats {
84 RouterStats {
85 compiled_actions: self.compiled.lock().len(),
86 compiled_hits: *self.compiled_hits.lock(),
87 llm_fallbacks: *self.llm_fallbacks.lock(),
88 }
89 }
90
91 pub fn deregister(&self, signature: &str) -> bool {
93 self.compiled.lock().remove(signature).is_some()
94 }
95}
96
97impl Default for ExecutionRouter {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct RouterStats {
105 pub compiled_actions: usize,
106 pub compiled_hits: u64,
107 pub llm_fallbacks: u64,
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ActionNode;

    /// Builds a minimal compiled action (single `test` tool call) for `sig`,
    /// with a fresh random id.
    fn make_compiled(sig: &str) -> CompiledAction {
        CompiledAction {
            id: uuid::Uuid::new_v4().to_string(),
            signature: sig.into(),
            ast: ActionNode::Action {
                tool: "test".into(),
                params: HashMap::new(),
            },
            required_variables: vec![],
            compiled_at: chrono::Utc::now().to_rfc3339(),
            source_occurrences: 5,
            source_success_rate: 1.0,
        }
    }

    #[test]
    fn test_router_prefers_compiled() {
        let router = ExecutionRouter::new();
        let action = make_compiled("git_push_flow");
        let expected_id = action.id.clone();
        router.register(action);

        match router.route("git_push_flow") {
            RoutingDecision::Compiled {
                compiled_id,
                signature,
            } => {
                assert_eq!(compiled_id, expected_id);
                assert_eq!(signature, "git_push_flow");
            }
            other => panic!("expected compiled route, got {:?}", other),
        }
    }

    #[test]
    fn test_router_fallback_to_llm() {
        let router = ExecutionRouter::new();
        match router.route("unknown_action") {
            RoutingDecision::Llm { reason } => assert!(reason.contains("unknown_action")),
            other => panic!("expected LLM fallback, got {:?}", other),
        }
    }

    #[test]
    fn test_router_execute_compiled() {
        let router = ExecutionRouter::new();
        router.register(make_compiled("deploy"));

        let result = router.execute_compiled("deploy", HashMap::new());
        assert!(result.is_some());
        let result = result.unwrap();
        assert!(result.success);
        assert_eq!(result.tokens_used, 0);
    }

    #[test]
    fn test_router_stats() {
        let router = ExecutionRouter::new();
        router.register(make_compiled("flow_a"));

        router.route("flow_a");
        router.route("flow_a");
        router.route("flow_b");

        let stats = router.stats();
        assert_eq!(stats.compiled_hits, 2);
        assert_eq!(stats.llm_fallbacks, 1);
        assert_eq!(stats.compiled_actions, 1);
    }

    #[test]
    fn test_router_deregister() {
        let router = ExecutionRouter::new();
        router.register(make_compiled("temp"));
        assert_eq!(router.compiled_count(), 1);
        assert!(router.deregister("temp"));
        assert_eq!(router.compiled_count(), 0);
        // Once deregistered, the signature must route to the LLM again.
        assert!(matches!(router.route("temp"), RoutingDecision::Llm { .. }));
    }

    #[test]
    fn test_register_replaces_existing_signature() {
        let router = ExecutionRouter::new();
        let first = make_compiled("dup");
        let second = make_compiled("dup");
        let second_id = second.id.clone();

        router.register(first);
        router.register(second);
        assert_eq!(router.compiled_count(), 1);

        match router.route("dup") {
            RoutingDecision::Compiled { compiled_id, .. } => assert_eq!(compiled_id, second_id),
            other => panic!("expected compiled route, got {:?}", other),
        }
    }

    #[test]
    fn test_router_default() {
        let router = ExecutionRouter::default();
        assert_eq!(router.compiled_count(), 0);
    }

    #[test]
    fn test_deregister_nonexistent() {
        let router = ExecutionRouter::new();
        assert!(!router.deregister("nope"));
    }

    #[test]
    fn test_execute_compiled_nonexistent() {
        let router = ExecutionRouter::new();
        let result = router.execute_compiled("nope", HashMap::new());
        assert!(result.is_none());
    }

    #[test]
    fn test_routing_decision_serde() {
        let decision = RoutingDecision::Compiled {
            compiled_id: "id".into(),
            signature: "sig".into(),
        };
        let json = serde_json::to_string(&decision).unwrap();
        let restored: RoutingDecision = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, decision);
    }

    #[test]
    fn test_routing_decision_llm_serde() {
        let decision = RoutingDecision::Llm {
            reason: "no match".into(),
        };
        let json = serde_json::to_string(&decision).unwrap();
        let restored: RoutingDecision = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, decision);
    }

    #[test]
    fn test_routing_decision_wire_format() {
        // Pins the externally tagged, snake_case representation from
        // `#[serde(rename_all = "snake_case")]`.
        let compiled = RoutingDecision::Compiled {
            compiled_id: "id".into(),
            signature: "sig".into(),
        };
        assert_eq!(
            serde_json::to_string(&compiled).unwrap(),
            r#"{"compiled":{"compiled_id":"id","signature":"sig"}}"#
        );

        let llm = RoutingDecision::Llm {
            reason: "no match".into(),
        };
        assert_eq!(
            serde_json::to_string(&llm).unwrap(),
            r#"{"llm":{"reason":"no match"}}"#
        );
    }

    #[test]
    fn test_router_stats_initial() {
        let router = ExecutionRouter::new();
        let stats = router.stats();
        assert_eq!(stats.compiled_actions, 0);
        assert_eq!(stats.compiled_hits, 0);
        assert_eq!(stats.llm_fallbacks, 0);
    }

    #[test]
    fn test_router_stats_serde() {
        let stats = RouterStats {
            compiled_actions: 5,
            compiled_hits: 10,
            llm_fallbacks: 3,
        };
        let json = serde_json::to_string(&stats).unwrap();
        let restored: RouterStats = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.compiled_actions, 5);
        assert_eq!(restored.compiled_hits, 10);
        assert_eq!(restored.llm_fallbacks, 3);
    }
}