1use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use serde_json::Value as JsonValue;
11
12use crate::matrixrpc::{
13 ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
14 RegistryService, ServiceId, ServiceStatus,
15};
16
/// Errors produced while resolving a node id to the service that executes it.
///
/// `NodeRouter::create_error_response` maps each variant to a JSON-RPC error.
#[derive(Debug, thiserror::Error)]
pub enum NodeRouterError {
    /// The node id is not present in the router's index.
    ///
    /// `NodeRouter::route` also returns this when the node is indexed but its
    /// owning service is no longer present in the registry.
    #[error("Node '{0}' not found in any registered service")]
    NodeNotFound(String),

    /// The node's owning service is registered but its status is not `Running`.
    #[error("Service '{service_id}' for node '{node_id}' is not running (status: {status:?})")]
    ServiceNotRunning {
        node_id: String,
        service_id: ServiceId,
        status: ServiceStatus,
    },

    /// No services are registered.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("No services registered in the registry")]
    NoServicesRegistered,

    /// Generic routing failure carrying a free-form message.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Routing failed: {0}")]
    RoutingFailed(String),

    /// The node does not advertise a capability the caller required.
    /// `capability` holds the capability's wire name (e.g. `"ai_execution"`).
    #[error("Node '{node_id}' does not support capability '{capability}'")]
    CapabilityNotSupported { node_id: String, capability: String },

    /// The supplied execution context is unusable for this node.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Invalid context for node '{node_id}': {reason}")]
    InvalidContext { node_id: String, reason: String },

    /// Unexpected internal failure carrying a free-form message.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Internal error: {0}")]
    Internal(String),
}
52
/// Successful outcome of `NodeRouter::route`: where a node execution should be
/// sent and with which parameters.
#[derive(Debug, Clone)]
pub struct NodeRouteResult {
    /// Service that owns, and should execute, the node.
    pub service_id: ServiceId,
    /// Id of the node being executed.
    pub node_id: String,
    /// Execution context forwarded to the node.
    pub context: NodeContext,
    /// JSON-RPC id used for the outgoing `node.execute` request.
    pub request_id: JsonRpcId,
    /// Endpoint passed to the service as `callback_endpoint` in the request params.
    pub callback_endpoint: String,
    /// Execution timeout in milliseconds: the node's own `timeout_ms`, else the
    /// router's default.
    pub timeout_ms: u64,
}
69
70#[derive(Debug, Clone, Default)]
72pub struct NodeContext {
73 pub input_data: JsonValue,
75 pub variables: HashMap<String, JsonValue>,
77 pub previous_results: HashMap<String, JsonValue>,
79 pub metadata: HashMap<String, JsonValue>,
81}
82
83impl NodeContext {
84 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn with_input(data: JsonValue) -> Self {
91 Self {
92 input_data: data,
93 ..Default::default()
94 }
95 }
96
97 pub fn variable(mut self, key: impl Into<String>, value: JsonValue) -> Self {
99 self.variables.insert(key.into(), value);
100 self
101 }
102
103 pub fn previous_result(mut self, node_id: impl Into<String>, result: JsonValue) -> Self {
105 self.previous_results.insert(node_id.into(), result);
106 self
107 }
108
109 pub fn metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
111 self.metadata.insert(key.into(), value);
112 self
113 }
114
115 pub fn to_json(&self) -> JsonValue {
117 serde_json::json!({
118 "input_data": self.input_data,
119 "variables": self.variables,
120 "previous_results": self.previous_results,
121 "metadata": self.metadata
122 })
123 }
124}
125
/// Kind of work a node performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeType {
    Task,
    Condition,
    Validate,
    Ai,
    Composite,
}

impl NodeType {
    /// Returns the lowercase wire name of this node type (e.g. `"task"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Task => "task",
            Self::Condition => "condition",
            Self::Validate => "validate",
            Self::Ai => "ai",
            Self::Composite => "composite",
        }
    }

    /// Parses a wire name produced by [`NodeType::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        // Parsing goes through `as_str`, so the strings live in one place; this
        // list must still name every variant.
        [
            Self::Task,
            Self::Condition,
            Self::Validate,
            Self::Ai,
            Self::Composite,
        ]
        .iter()
        .copied()
        .find(|candidate| candidate.as_str() == s)
    }
}
165
/// Capability a node may advertise and a caller may require when routing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeCapability {
    AiExecution,
    ToolExecution,
    ContextAccess,
}

impl NodeCapability {
    /// Returns the snake_case wire name of this capability (e.g. `"ai_execution"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::AiExecution => "ai_execution",
            Self::ToolExecution => "tool_execution",
            Self::ContextAccess => "context_access",
        }
    }

    /// Parses a wire name produced by [`NodeCapability::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        // Parsing goes through `as_str`, so the strings live in one place; this
        // list must still name every variant.
        [Self::AiExecution, Self::ToolExecution, Self::ContextAccess]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
    }
}
197
/// Indexed description of a node and the service that owns it.
#[derive(Debug, Clone)]
pub struct NodeDefinition {
    /// Node id; the key under which the router indexes this definition.
    pub id: String,
    /// Display name (`NodeRouter::rebuild_index` falls back to the id when absent).
    pub name: String,
    /// Service that owns and executes this node.
    pub service_id: ServiceId,
    /// Kind of node.
    pub node_type: NodeType,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Capabilities the node advertises; `NodeRouter::route` rejects requests
    /// that require any capability missing from this list.
    pub capabilities: Vec<NodeCapability>,
    /// Per-node timeout in milliseconds; `None` means the router default applies.
    pub timeout_ms: Option<u64>,
    /// Parameter description copied verbatim from the service config.
    /// NOTE(review): presumably a JSON Schema; it is not validated in this file.
    pub params_schema: Option<JsonValue>,
}
218
/// Routes node-execution requests to the registered service that owns each node.
///
/// Keeps an in-memory index from node id to [`NodeDefinition`] (filled via
/// `register_node` or `rebuild_index`) and checks the owning service's status
/// in the shared [`RegistryService`] at routing time.
#[derive(Debug)]
pub struct NodeRouter {
    /// Registry consulted for service status, and by `rebuild_index` for node lists.
    registry: Arc<RegistryService>,
    /// Node id -> definition, guarded by a tokio async `RwLock`.
    node_index: Arc<RwLock<HashMap<String, NodeDefinition>>>,
    /// Timeout (ms) applied when a node's definition has no `timeout_ms`.
    default_timeout_ms: u64,
    /// Callback endpoint placed in every successful route result.
    default_callback_endpoint: String,
}
234
235impl NodeRouter {
236 pub fn new(registry: Arc<RegistryService>) -> Self {
238 Self {
239 registry,
240 node_index: Arc::new(RwLock::new(HashMap::new())),
241 default_timeout_ms: 60_000, default_callback_endpoint: "matrixcode://callback".to_string(),
243 }
244 }
245
246 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
248 self.default_timeout_ms = timeout_ms;
249 self
250 }
251
252 pub fn with_callback_endpoint(mut self, endpoint: impl Into<String>) -> Self {
254 self.default_callback_endpoint = endpoint.into();
255 self
256 }
257
258 pub async fn register_node(&self, _service_id: ServiceId, node_def: NodeDefinition) {
260 let mut index = self.node_index.write().await;
261 index.insert(node_def.id.clone(), node_def);
262 }
263
264 pub async fn unregister_service_nodes(&self, service_id: &ServiceId) {
266 let mut index = self.node_index.write().await;
267 index.retain(|_, def| def.service_id != *service_id);
268 }
269
270 pub async fn rebuild_index(&self) -> Result<(), NodeRouterError> {
272 let services = self.registry.list_all().await;
273 let mut index = self.node_index.write().await;
274 index.clear();
275
276 for service in services {
277 if service.status != ServiceStatus::Running {
278 continue;
279 }
280
281 for cap in &service.capabilities {
283 if cap.name == "nodes" {
284 if let Some(nodes_json) = cap.config.get("nodes") {
286 if let Ok(nodes) = serde_json::from_value::<Vec<JsonValue>>(nodes_json.clone()) {
287 for node in nodes {
288 if let Some(id) = node.get("id").and_then(|n| n.as_str()) {
289 let node_type = node
290 .get("type")
291 .and_then(|t| t.as_str())
292 .and_then(NodeType::from_str)
293 .unwrap_or(NodeType::Task);
294
295 let capabilities: Vec<NodeCapability> = node
296 .get("capabilities")
297 .and_then(|c| c.as_array())
298 .map(|arr| {
299 arr.iter()
300 .filter_map(|c| c.as_str().and_then(NodeCapability::from_str))
301 .collect()
302 })
303 .unwrap_or_default();
304
305 let def = NodeDefinition {
306 id: id.to_string(),
307 name: node.get("name").and_then(|n| n.as_str()).map(|s| s.to_string()).unwrap_or_else(|| id.to_string()),
308 service_id: service.id.clone(),
309 node_type,
310 description: node.get("description").and_then(|d| d.as_str()).map(|s| s.to_string()),
311 capabilities,
312 timeout_ms: node.get("timeout_ms").and_then(|t| t.as_u64()),
313 params_schema: node.get("params_schema").cloned(),
314 };
315 index.insert(id.to_string(), def);
316 }
317 }
318 }
319 }
320 }
321 }
322 }
323
324 Ok(())
325 }
326
327 pub async fn route(
329 &self,
330 node_id: &str,
331 context: NodeContext,
332 request_id: JsonRpcId,
333 required_capabilities: Vec<NodeCapability>,
334 ) -> Result<NodeRouteResult, NodeRouterError> {
335 let index = self.node_index.read().await;
337 let node_def = index
338 .get(node_id)
339 .cloned()
340 .ok_or_else(|| NodeRouterError::NodeNotFound(node_id.to_string()))?;
341
342 for cap in required_capabilities {
344 if !node_def.capabilities.contains(&cap) {
345 return Err(NodeRouterError::CapabilityNotSupported {
346 node_id: node_id.to_string(),
347 capability: cap.as_str().to_string(),
348 });
349 }
350 }
351
352 let service = self.registry.get(&node_def.service_id).await;
354 match service {
355 Some(s) if s.status == ServiceStatus::Running => {
356 let timeout = node_def.timeout_ms.unwrap_or(self.default_timeout_ms);
358 Ok(NodeRouteResult {
359 service_id: node_def.service_id,
360 node_id: node_def.id,
361 context,
362 request_id,
363 callback_endpoint: self.default_callback_endpoint.clone(),
364 timeout_ms: timeout,
365 })
366 }
367 Some(s) => {
368 Err(NodeRouterError::ServiceNotRunning {
370 node_id: node_id.to_string(),
371 service_id: node_def.service_id,
372 status: s.status,
373 })
374 }
375 None => {
376 Err(NodeRouterError::NodeNotFound(node_id.to_string()))
378 }
379 }
380 }
381
382 pub async fn has_node(&self, node_id: &str) -> bool {
384 let index = self.node_index.read().await;
385 index.contains_key(node_id)
386 }
387
388 pub async fn list_nodes(&self) -> Vec<NodeDefinition> {
390 let index = self.node_index.read().await;
391 index.values().cloned().collect()
392 }
393
394 pub async fn get_node(&self, node_id: &str) -> Option<NodeDefinition> {
396 let index = self.node_index.read().await;
397 index.get(node_id).cloned()
398 }
399
400 pub async fn get_nodes_by_type(&self, node_type: NodeType) -> Vec<NodeDefinition> {
402 let index = self.node_index.read().await;
403 index
404 .values()
405 .filter(|def| def.node_type == node_type)
406 .cloned()
407 .collect()
408 }
409
410 pub async fn get_nodes_by_capability(&self, capability: NodeCapability) -> Vec<NodeDefinition> {
412 let index = self.node_index.read().await;
413 index
414 .values()
415 .filter(|def| def.capabilities.contains(&capability))
416 .cloned()
417 .collect()
418 }
419
420 pub fn create_node_request(&self, route_result: NodeRouteResult) -> JsonRpcRequest {
422 JsonRpcRequest::with_id("node.execute", route_result.request_id)
423 .params(serde_json::json!({
424 "node_id": route_result.node_id,
425 "context": route_result.context.to_json(),
426 "callback_endpoint": route_result.callback_endpoint
427 }))
428 }
429
430 pub fn create_error_response(
432 &self,
433 error: NodeRouterError,
434 request_id: JsonRpcId,
435 ) -> JsonRpcResponse {
436 let (code, message, data) = match error {
437 NodeRouterError::NodeNotFound(node) => {
438 let available: Vec<String> = {
439 Vec::new()
441 };
442 (
443 ErrorCode::RESOURCE_NOT_FOUND,
444 format!("Node '{}' not found", node),
445 Some(serde_json::json!({ "available_nodes": available })),
446 )
447 }
448 NodeRouterError::ServiceNotRunning { node_id, service_id, status } => (
449 ErrorCode::INVALID_STATE,
450 format!("Service '{}' is not running", service_id),
451 Some(serde_json::json!({
452 "node_id": node_id,
453 "service_id": service_id.to_string(),
454 "status": serde_json::to_string(&status).unwrap_or_default()
455 })),
456 ),
457 NodeRouterError::CapabilityNotSupported { node_id, capability } => (
458 ErrorCode::CAPABILITY_NOT_SUPPORTED,
459 format!("Node '{}' does not support capability '{}'", node_id, capability),
460 None,
461 ),
462 NodeRouterError::NoServicesRegistered => (
463 ErrorCode::RESOURCE_NOT_FOUND,
464 "No services registered".to_string(),
465 None,
466 ),
467 NodeRouterError::InvalidContext { node_id, reason } => (
468 ErrorCode::INVALID_PARAMS,
469 format!("Invalid context for node '{}'", node_id),
470 Some(serde_json::json!({ "reason": reason })),
471 ),
472 NodeRouterError::RoutingFailed(msg) | NodeRouterError::Internal(msg) => (
473 ErrorCode::INTERNAL_ERROR,
474 msg,
475 None,
476 ),
477 };
478
479 JsonRpcResponse::error(
480 request_id,
481 JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
482 )
483 }
484
485 pub async fn node_count(&self) -> usize {
487 let index = self.node_index.read().await;
488 index.len()
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal definition (name = id, no description/timeout/schema).
    fn make_node(
        id: &str,
        service_id: &ServiceId,
        node_type: NodeType,
        capabilities: Vec<NodeCapability>,
    ) -> NodeDefinition {
        NodeDefinition {
            id: id.to_string(),
            name: id.to_string(),
            service_id: service_id.clone(),
            node_type,
            description: None,
            capabilities,
            timeout_ms: None,
            params_schema: None,
        }
    }

    #[tokio::test]
    async fn test_node_router_creation() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);
        assert_eq!(router.default_timeout_ms, 60_000);
        assert_eq!(router.default_callback_endpoint, "matrixcode://callback");
    }

    #[tokio::test]
    async fn test_builder_overrides_defaults() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry)
            .with_timeout(5_000)
            .with_callback_endpoint("matrixcode://custom");
        assert_eq!(router.default_timeout_ms, 5_000);
        assert_eq!(router.default_callback_endpoint, "matrixcode://custom");
    }

    #[tokio::test]
    async fn test_register_node() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);

        let service_id = ServiceId::new("test-service");
        let definition = NodeDefinition {
            id: "validate-node".to_string(),
            name: "Validate Node".to_string(),
            service_id: service_id.clone(),
            node_type: NodeType::Validate,
            description: Some("Validates input data".to_string()),
            capabilities: vec![NodeCapability::ContextAccess],
            timeout_ms: Some(10_000),
            params_schema: None,
        };

        router.register_node(service_id, definition).await;
        assert!(router.has_node("validate-node").await);
    }

    #[tokio::test]
    async fn test_register_node_replaces_existing_id() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);
        let service_id = ServiceId::new("test-service");

        router
            .register_node(service_id.clone(), make_node("dup", &service_id, NodeType::Task, vec![]))
            .await;
        router
            .register_node(service_id.clone(), make_node("dup", &service_id, NodeType::Ai, vec![]))
            .await;

        assert_eq!(router.node_count().await, 1);
        let node = router.get_node("dup").await.expect("dup is registered");
        assert_eq!(node.node_type, NodeType::Ai);
    }

    #[test]
    fn test_node_types() {
        assert_eq!(NodeType::Task.as_str(), "task");
        assert_eq!(NodeType::from_str("condition"), Some(NodeType::Condition));
        assert_eq!(NodeType::from_str("unknown"), None);
    }

    #[test]
    fn test_node_type_round_trip() {
        for t in [
            NodeType::Task,
            NodeType::Condition,
            NodeType::Validate,
            NodeType::Ai,
            NodeType::Composite,
        ]
        .iter()
        {
            assert_eq!(NodeType::from_str(t.as_str()), Some(*t));
        }
    }

    #[test]
    fn test_node_capabilities() {
        assert_eq!(NodeCapability::AiExecution.as_str(), "ai_execution");
        assert_eq!(
            NodeCapability::from_str("tool_execution"),
            Some(NodeCapability::ToolExecution)
        );
        for cap in [
            NodeCapability::AiExecution,
            NodeCapability::ToolExecution,
            NodeCapability::ContextAccess,
        ]
        .iter()
        {
            assert_eq!(NodeCapability::from_str(cap.as_str()), Some(*cap));
        }
    }

    #[test]
    fn test_node_context() {
        let context = NodeContext::new()
            .variable("key", serde_json::json!("value"))
            .previous_result("prev-node", serde_json::json!({"result": "ok"}));

        let json = context.to_json();
        assert!(json.get("variables").is_some());
        assert!(json.get("previous_results").is_some());
    }

    #[test]
    fn test_node_context_builders() {
        let context = NodeContext::with_input(serde_json::json!({"a": 1}))
            .metadata("trace", serde_json::json!("t-1"));

        let json = context.to_json();
        assert_eq!(json["input_data"], serde_json::json!({"a": 1}));
        assert_eq!(json["metadata"]["trace"], serde_json::json!("t-1"));
        assert_eq!(json["variables"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn test_create_node_request() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);

        let route_result = NodeRouteResult {
            service_id: ServiceId::new("test-service"),
            node_id: "test-node".to_string(),
            context: NodeContext::new(),
            request_id: JsonRpcId::Number(1),
            callback_endpoint: "matrixcode://callback".to_string(),
            timeout_ms: 30_000,
        };

        let request = router.create_node_request(route_result);
        assert_eq!(request.method, "node.execute");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_route_node_not_found() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);

        let result = router
            .route("unknown-node", NodeContext::new(), JsonRpcId::Number(1), vec![])
            .await;

        assert!(matches!(result, Err(NodeRouterError::NodeNotFound(_))));
    }

    #[tokio::test]
    async fn test_route_rejects_missing_capability() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);
        let service_id = ServiceId::new("test-service");

        router
            .register_node(
                service_id.clone(),
                make_node(
                    "plain-node",
                    &service_id,
                    NodeType::Task,
                    vec![NodeCapability::ContextAccess],
                ),
            )
            .await;

        let result = router
            .route(
                "plain-node",
                NodeContext::new(),
                JsonRpcId::Number(2),
                vec![NodeCapability::ContextAccess, NodeCapability::AiExecution],
            )
            .await;

        match result {
            Err(NodeRouterError::CapabilityNotSupported { node_id, capability }) => {
                assert_eq!(node_id, "plain-node");
                assert_eq!(capability, "ai_execution");
            }
            other => panic!("expected CapabilityNotSupported, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_route_indexed_node_with_unregistered_service() {
        // The node is indexed, but its service was never added to the registry.
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);
        let service_id = ServiceId::new("ghost-service");

        router
            .register_node(service_id.clone(), make_node("orphan", &service_id, NodeType::Task, vec![]))
            .await;

        let result = router
            .route("orphan", NodeContext::new(), JsonRpcId::Number(3), vec![])
            .await;

        assert!(matches!(result, Err(NodeRouterError::NodeNotFound(_))));
    }

    #[tokio::test]
    async fn test_unregister_service_nodes() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);
        let svc_a = ServiceId::new("service-a");
        let svc_b = ServiceId::new("service-b");

        router
            .register_node(svc_a.clone(), make_node("a1", &svc_a, NodeType::Task, vec![]))
            .await;
        router
            .register_node(svc_a.clone(), make_node("a2", &svc_a, NodeType::Ai, vec![]))
            .await;
        router
            .register_node(svc_b.clone(), make_node("b1", &svc_b, NodeType::Task, vec![]))
            .await;
        assert_eq!(router.node_count().await, 3);

        router.unregister_service_nodes(&svc_a).await;

        assert_eq!(router.node_count().await, 1);
        assert!(!router.has_node("a1").await);
        assert!(!router.has_node("a2").await);
        assert!(router.has_node("b1").await);
    }

    #[tokio::test]
    async fn test_list_nodes() {
        let registry = Arc::new(RegistryService::new());
        let router = NodeRouter::new(registry);
        let service_id = ServiceId::new("test-service");

        router
            .register_node(service_id.clone(), make_node("node1", &service_id, NodeType::Task, vec![]))
            .await;
        router
            .register_node(
                service_id.clone(),
                make_node(
                    "node2",
                    &service_id,
                    NodeType::Condition,
                    vec![NodeCapability::AiExecution],
                ),
            )
            .await;

        assert_eq!(router.list_nodes().await.len(), 2);
        assert_eq!(router.get_nodes_by_type(NodeType::Task).await.len(), 1);
        assert_eq!(
            router
                .get_nodes_by_capability(NodeCapability::AiExecution)
                .await
                .len(),
            1
        );

        let node2 = router.get_node("node2").await.expect("node2 is registered");
        assert_eq!(node2.node_type, NodeType::Condition);
        assert!(router.get_node("missing").await.is_none());
    }
}