Skip to main content

matrixcode_core/matrixrpc/router/
node_router.rs

1//! Node Router Implementation
2//!
3//! Routes workflow node execution requests to the appropriate extension service.
4//! Handles node discovery, capability checking, and callback endpoint setup.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use serde_json::Value as JsonValue;
11
12use crate::matrixrpc::{
13    ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
14    RegistryService, ServiceId, ServiceStatus,
15};
16
17/// Error type for node routing operations
18#[derive(Debug, thiserror::Error)]
19pub enum NodeRouterError {
20    /// Node not found in any registered service
21    #[error("Node '{0}' not found in any registered service")]
22    NodeNotFound(String),
23
24    /// Service that provides the node is not running
25    #[error("Service '{service_id}' for node '{node_id}' is not running (status: {status:?})")]
26    ServiceNotRunning {
27        node_id: String,
28        service_id: ServiceId,
29        status: ServiceStatus,
30    },
31
32    /// No services registered
33    #[error("No services registered in the registry")]
34    NoServicesRegistered,
35
36    /// Routing failed
37    #[error("Routing failed: {0}")]
38    RoutingFailed(String),
39
40    /// Node does not support required capability
41    #[error("Node '{node_id}' does not support capability '{capability}'")]
42    CapabilityNotSupported { node_id: String, capability: String },
43
44    /// Invalid node context
45    #[error("Invalid context for node '{node_id}': {reason}")]
46    InvalidContext { node_id: String, reason: String },
47
48    /// Internal error
49    #[error("Internal error: {0}")]
50    Internal(String),
51}
52
53/// Result of a node routing operation
54#[derive(Debug, Clone)]
55pub struct NodeRouteResult {
56    /// The service ID that should handle the node execution
57    pub service_id: ServiceId,
58    /// The node ID
59    pub node_id: String,
60    /// The execution context
61    pub context: NodeContext,
62    /// The original request ID for correlation
63    pub request_id: JsonRpcId,
64    /// Callback endpoint URL
65    pub callback_endpoint: String,
66    /// Timeout for node execution (milliseconds)
67    pub timeout_ms: u64,
68}
69
70/// Node execution context
71#[derive(Debug, Clone, Default)]
72pub struct NodeContext {
73    /// Input data for the node
74    pub input_data: JsonValue,
75    /// Workflow variables
76    pub variables: HashMap<String, JsonValue>,
77    /// Previous node results
78    pub previous_results: HashMap<String, JsonValue>,
79    /// Custom metadata
80    pub metadata: HashMap<String, JsonValue>,
81}
82
83impl NodeContext {
84    /// Create a new empty context
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Create a context with input data
90    pub fn with_input(data: JsonValue) -> Self {
91        Self {
92            input_data: data,
93            ..Default::default()
94        }
95    }
96
97    /// Add a variable to the context
98    pub fn variable(mut self, key: impl Into<String>, value: JsonValue) -> Self {
99        self.variables.insert(key.into(), value);
100        self
101    }
102
103    /// Add a previous result
104    pub fn previous_result(mut self, node_id: impl Into<String>, result: JsonValue) -> Self {
105        self.previous_results.insert(node_id.into(), result);
106        self
107    }
108
109    /// Add metadata
110    pub fn metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
111        self.metadata.insert(key.into(), value);
112        self
113    }
114
115    /// Convert to JSON value
116    pub fn to_json(&self) -> JsonValue {
117        serde_json::json!({
118            "input_data": self.input_data,
119            "variables": self.variables,
120            "previous_results": self.previous_results,
121            "metadata": self.metadata
122        })
123    }
124}
125
/// Kind of work a workflow node performs.
///
/// Derives `Hash` alongside `Eq` so node types can be used as `HashMap`/`HashSet`
/// keys (e.g. when grouping nodes by type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Task node - executes a specific task
    Task,
    /// Condition node - evaluates conditions
    Condition,
    /// Validation node - validates data
    Validate,
    /// AI node - calls AI for processing
    Ai,
    /// Composite node - combines multiple operations
    Composite,
}
140
141impl NodeType {
142    /// Get the string representation
143    pub fn as_str(&self) -> &'static str {
144        match self {
145            NodeType::Task => "task",
146            NodeType::Condition => "condition",
147            NodeType::Validate => "validate",
148            NodeType::Ai => "ai",
149            NodeType::Composite => "composite",
150        }
151    }
152
153    /// Parse from string
154    pub fn from_str(s: &str) -> Option<Self> {
155        match s {
156            "task" => Some(NodeType::Task),
157            "condition" => Some(NodeType::Condition),
158            "validate" => Some(NodeType::Validate),
159            "ai" => Some(NodeType::Ai),
160            "composite" => Some(NodeType::Composite),
161            _ => None,
162        }
163    }
164}
165
/// Capability a node may advertise and a caller may require when routing.
///
/// Derives `Hash` alongside `Eq` so capabilities can be collected into a `HashSet`
/// or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeCapability {
    /// Can call AI
    AiExecution,
    /// Can execute tools
    ToolExecution,
    /// Can access workflow context
    ContextAccess,
}
176
177impl NodeCapability {
178    /// Get the string representation
179    pub fn as_str(&self) -> &'static str {
180        match self {
181            NodeCapability::AiExecution => "ai_execution",
182            NodeCapability::ToolExecution => "tool_execution",
183            NodeCapability::ContextAccess => "context_access",
184        }
185    }
186
187    /// Parse from string
188    pub fn from_str(s: &str) -> Option<Self> {
189        match s {
190            "ai_execution" => Some(NodeCapability::AiExecution),
191            "tool_execution" => Some(NodeCapability::ToolExecution),
192            "context_access" => Some(NodeCapability::ContextAccess),
193            _ => None,
194        }
195    }
196}
197
198/// Node definition from an extension service
199#[derive(Debug, Clone)]
200pub struct NodeDefinition {
201    /// Node ID
202    pub id: String,
203    /// Node name
204    pub name: String,
205    /// Service ID that provides this node
206    pub service_id: ServiceId,
207    /// Node type
208    pub node_type: NodeType,
209    /// Node description
210    pub description: Option<String>,
211    /// Supported capabilities
212    pub capabilities: Vec<NodeCapability>,
213    /// Timeout in milliseconds
214    pub timeout_ms: Option<u64>,
215    /// Parameter schema (JSON Schema)
216    pub params_schema: Option<JsonValue>,
217}
218
219/// Node Router
220///
221/// Routes workflow node execution requests to the appropriate extension service.
222/// Uses the registry service to discover which services provide each node.
223#[derive(Debug)]
224pub struct NodeRouter {
225    /// Reference to the registry service
226    registry: Arc<RegistryService>,
227    /// Node ID to definition mapping (cached)
228    node_index: Arc<RwLock<HashMap<String, NodeDefinition>>>,
229    /// Default timeout for node execution (milliseconds)
230    default_timeout_ms: u64,
231    /// Default callback endpoint
232    default_callback_endpoint: String,
233}
234
235impl NodeRouter {
236    /// Create a new node router with a registry service
237    pub fn new(registry: Arc<RegistryService>) -> Self {
238        Self {
239            registry,
240            node_index: Arc::new(RwLock::new(HashMap::new())),
241            default_timeout_ms: 60_000, // 60 seconds default for nodes
242            default_callback_endpoint: "matrixcode://callback".to_string(),
243        }
244    }
245
246    /// Set default timeout for node execution
247    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
248        self.default_timeout_ms = timeout_ms;
249        self
250    }
251
252    /// Set default callback endpoint
253    pub fn with_callback_endpoint(mut self, endpoint: impl Into<String>) -> Self {
254        self.default_callback_endpoint = endpoint.into();
255        self
256    }
257
258    /// Register a node from a service
259    pub async fn register_node(&self, _service_id: ServiceId, node_def: NodeDefinition) {
260        let mut index = self.node_index.write().await;
261        index.insert(node_def.id.clone(), node_def);
262    }
263
264    /// Unregister all nodes from a service
265    pub async fn unregister_service_nodes(&self, service_id: &ServiceId) {
266        let mut index = self.node_index.write().await;
267        index.retain(|_, def| def.service_id != *service_id);
268    }
269
270    /// Rebuild the node index from the registry
271    pub async fn rebuild_index(&self) -> Result<(), NodeRouterError> {
272        let services = self.registry.list_all().await;
273        let mut index = self.node_index.write().await;
274        index.clear();
275
276        for service in services {
277            if service.status != ServiceStatus::Running {
278                continue;
279            }
280
281            // Extract nodes from service capabilities
282            for cap in &service.capabilities {
283                if cap.name == "nodes" {
284                    // Parse nodes from capability config
285                    if let Some(nodes_json) = cap.config.get("nodes") {
286                        if let Ok(nodes) = serde_json::from_value::<Vec<JsonValue>>(nodes_json.clone()) {
287                            for node in nodes {
288                                if let Some(id) = node.get("id").and_then(|n| n.as_str()) {
289                                    let node_type = node
290                                        .get("type")
291                                        .and_then(|t| t.as_str())
292                                        .and_then(NodeType::from_str)
293                                        .unwrap_or(NodeType::Task);
294
295                                    let capabilities: Vec<NodeCapability> = node
296                                        .get("capabilities")
297                                        .and_then(|c| c.as_array())
298                                        .map(|arr| {
299                                            arr.iter()
300                                                .filter_map(|c| c.as_str().and_then(NodeCapability::from_str))
301                                                .collect()
302                                        })
303                                        .unwrap_or_default();
304
305                                    let def = NodeDefinition {
306                                        id: id.to_string(),
307                                        name: node.get("name").and_then(|n| n.as_str()).map(|s| s.to_string()).unwrap_or_else(|| id.to_string()),
308                                        service_id: service.id.clone(),
309                                        node_type,
310                                        description: node.get("description").and_then(|d| d.as_str()).map(|s| s.to_string()),
311                                        capabilities,
312                                        timeout_ms: node.get("timeout_ms").and_then(|t| t.as_u64()),
313                                        params_schema: node.get("params_schema").cloned(),
314                                    };
315                                    index.insert(id.to_string(), def);
316                                }
317                            }
318                        }
319                    }
320                }
321            }
322        }
323
324        Ok(())
325    }
326
327    /// Route a node execution request
328    pub async fn route(
329        &self,
330        node_id: &str,
331        context: NodeContext,
332        request_id: JsonRpcId,
333        required_capabilities: Vec<NodeCapability>,
334    ) -> Result<NodeRouteResult, NodeRouterError> {
335        // Look up the node in the index
336        let index = self.node_index.read().await;
337        let node_def = index
338            .get(node_id)
339            .cloned()
340            .ok_or_else(|| NodeRouterError::NodeNotFound(node_id.to_string()))?;
341
342        // Check required capabilities
343        for cap in required_capabilities {
344            if !node_def.capabilities.contains(&cap) {
345                return Err(NodeRouterError::CapabilityNotSupported {
346                    node_id: node_id.to_string(),
347                    capability: cap.as_str().to_string(),
348                });
349            }
350        }
351
352        // Check if the service is running
353        let service = self.registry.get(&node_def.service_id).await;
354        match service {
355            Some(s) if s.status == ServiceStatus::Running => {
356                // Service is healthy, proceed with routing
357                let timeout = node_def.timeout_ms.unwrap_or(self.default_timeout_ms);
358                Ok(NodeRouteResult {
359                    service_id: node_def.service_id,
360                    node_id: node_def.id,
361                    context,
362                    request_id,
363                    callback_endpoint: self.default_callback_endpoint.clone(),
364                    timeout_ms: timeout,
365                })
366            }
367            Some(s) => {
368                // Service exists but not running
369                Err(NodeRouterError::ServiceNotRunning {
370                    node_id: node_id.to_string(),
371                    service_id: node_def.service_id,
372                    status: s.status,
373                })
374            }
375            None => {
376                // Service not found in registry
377                Err(NodeRouterError::NodeNotFound(node_id.to_string()))
378            }
379        }
380    }
381
382    /// Check if a node is available
383    pub async fn has_node(&self, node_id: &str) -> bool {
384        let index = self.node_index.read().await;
385        index.contains_key(node_id)
386    }
387
388    /// List all available nodes
389    pub async fn list_nodes(&self) -> Vec<NodeDefinition> {
390        let index = self.node_index.read().await;
391        index.values().cloned().collect()
392    }
393
394    /// Get node definition
395    pub async fn get_node(&self, node_id: &str) -> Option<NodeDefinition> {
396        let index = self.node_index.read().await;
397        index.get(node_id).cloned()
398    }
399
400    /// Get nodes by type
401    pub async fn get_nodes_by_type(&self, node_type: NodeType) -> Vec<NodeDefinition> {
402        let index = self.node_index.read().await;
403        index
404            .values()
405            .filter(|def| def.node_type == node_type)
406            .cloned()
407            .collect()
408    }
409
410    /// Get nodes by capability
411    pub async fn get_nodes_by_capability(&self, capability: NodeCapability) -> Vec<NodeDefinition> {
412        let index = self.node_index.read().await;
413        index
414            .values()
415            .filter(|def| def.capabilities.contains(&capability))
416            .cloned()
417            .collect()
418    }
419
420    /// Create a JSON-RPC request for node execution
421    pub fn create_node_request(&self, route_result: NodeRouteResult) -> JsonRpcRequest {
422        JsonRpcRequest::with_id("node.execute", route_result.request_id)
423            .params(serde_json::json!({
424                "node_id": route_result.node_id,
425                "context": route_result.context.to_json(),
426                "callback_endpoint": route_result.callback_endpoint
427            }))
428    }
429
430    /// Create an error response for routing failures
431    pub fn create_error_response(
432        &self,
433        error: NodeRouterError,
434        request_id: JsonRpcId,
435    ) -> JsonRpcResponse {
436        let (code, message, data) = match error {
437            NodeRouterError::NodeNotFound(node) => {
438                let available: Vec<String> = {
439                    // We can't read async here, so use empty list for simplicity
440                    Vec::new()
441                };
442                (
443                    ErrorCode::RESOURCE_NOT_FOUND,
444                    format!("Node '{}' not found", node),
445                    Some(serde_json::json!({ "available_nodes": available })),
446                )
447            }
448            NodeRouterError::ServiceNotRunning { node_id, service_id, status } => (
449                ErrorCode::INVALID_STATE,
450                format!("Service '{}' is not running", service_id),
451                Some(serde_json::json!({
452                    "node_id": node_id,
453                    "service_id": service_id.to_string(),
454                    "status": serde_json::to_string(&status).unwrap_or_default()
455                })),
456            ),
457            NodeRouterError::CapabilityNotSupported { node_id, capability } => (
458                ErrorCode::CAPABILITY_NOT_SUPPORTED,
459                format!("Node '{}' does not support capability '{}'", node_id, capability),
460                None,
461            ),
462            NodeRouterError::NoServicesRegistered => (
463                ErrorCode::RESOURCE_NOT_FOUND,
464                "No services registered".to_string(),
465                None,
466            ),
467            NodeRouterError::InvalidContext { node_id, reason } => (
468                ErrorCode::INVALID_PARAMS,
469                format!("Invalid context for node '{}'", node_id),
470                Some(serde_json::json!({ "reason": reason })),
471            ),
472            NodeRouterError::RoutingFailed(msg) | NodeRouterError::Internal(msg) => (
473                ErrorCode::INTERNAL_ERROR,
474                msg,
475                None,
476            ),
477        };
478
479        JsonRpcResponse::error(
480            request_id,
481            JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
482        )
483    }
484
485    /// Get the count of registered nodes
486    pub async fn node_count(&self) -> usize {
487        let index = self.node_index.read().await;
488        index.len()
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// Router backed by a freshly constructed registry.
    fn empty_router() -> NodeRouter {
        NodeRouter::new(Arc::new(RegistryService::new()))
    }

    /// Minimal node definition owned by `service_id`, named after its id.
    fn make_node(
        id: &str,
        service_id: &ServiceId,
        node_type: NodeType,
        capabilities: Vec<NodeCapability>,
    ) -> NodeDefinition {
        NodeDefinition {
            id: id.to_string(),
            name: id.to_string(),
            service_id: service_id.clone(),
            node_type,
            description: None,
            capabilities,
            timeout_ms: None,
            params_schema: None,
        }
    }

    #[tokio::test]
    async fn test_node_router_creation() {
        let router = empty_router();
        assert_eq!(router.default_timeout_ms, 60_000);
        assert_eq!(router.default_callback_endpoint, "matrixcode://callback");
    }

    #[tokio::test]
    async fn test_builder_overrides() {
        let router = empty_router()
            .with_timeout(5_000)
            .with_callback_endpoint("matrixcode://custom");
        assert_eq!(router.default_timeout_ms, 5_000);
        assert_eq!(router.default_callback_endpoint, "matrixcode://custom");
    }

    #[tokio::test]
    async fn test_register_node() {
        let router = empty_router();

        let service_id = ServiceId::new("test-service");
        let node_def = NodeDefinition {
            id: "validate-node".to_string(),
            name: "Validate Node".to_string(),
            service_id: service_id.clone(),
            node_type: NodeType::Validate,
            description: Some("Validates input data".to_string()),
            capabilities: vec![NodeCapability::ContextAccess],
            timeout_ms: Some(10_000),
            params_schema: None,
        };

        router.register_node(service_id, node_def).await;
        assert!(router.has_node("validate-node").await);
        assert_eq!(router.node_count().await, 1);
    }

    #[tokio::test]
    async fn test_register_node_replaces_same_id() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");

        router
            .register_node(service_id.clone(), make_node("dup", &service_id, NodeType::Task, vec![]))
            .await;
        router
            .register_node(service_id.clone(), make_node("dup", &service_id, NodeType::Ai, vec![]))
            .await;

        assert_eq!(router.node_count().await, 1);
        let stored = router.get_node("dup").await.expect("node should be indexed");
        assert_eq!(stored.node_type, NodeType::Ai);
    }

    #[tokio::test]
    async fn test_unregister_service_nodes() {
        let router = empty_router();
        let service_a = ServiceId::new("service-a");
        let service_b = ServiceId::new("service-b");

        router
            .register_node(service_a.clone(), make_node("a1", &service_a, NodeType::Task, vec![]))
            .await;
        router
            .register_node(service_a.clone(), make_node("a2", &service_a, NodeType::Task, vec![]))
            .await;
        router
            .register_node(service_b.clone(), make_node("b1", &service_b, NodeType::Task, vec![]))
            .await;

        router.unregister_service_nodes(&service_a).await;

        assert_eq!(router.node_count().await, 1);
        assert!(!router.has_node("a1").await);
        assert!(!router.has_node("a2").await);
        assert!(router.has_node("b1").await);
        assert!(router.get_node("a1").await.is_none());
    }

    #[test]
    fn test_node_types() {
        assert_eq!(NodeType::Task.as_str(), "task");
        assert_eq!(NodeType::from_str("condition"), Some(NodeType::Condition));
        assert_eq!(NodeType::from_str("unknown"), None);

        let all = [
            NodeType::Task,
            NodeType::Condition,
            NodeType::Validate,
            NodeType::Ai,
            NodeType::Composite,
        ];
        for node_type in all.iter().copied() {
            assert_eq!(NodeType::from_str(node_type.as_str()), Some(node_type));
        }
    }

    #[test]
    fn test_node_capabilities() {
        assert_eq!(NodeCapability::AiExecution.as_str(), "ai_execution");
        assert_eq!(
            NodeCapability::from_str("tool_execution"),
            Some(NodeCapability::ToolExecution)
        );
        assert_eq!(NodeCapability::from_str("unknown"), None);

        let all = [
            NodeCapability::AiExecution,
            NodeCapability::ToolExecution,
            NodeCapability::ContextAccess,
        ];
        for cap in all.iter().copied() {
            assert_eq!(NodeCapability::from_str(cap.as_str()), Some(cap));
        }
    }

    #[test]
    fn test_node_context() {
        let context = NodeContext::new()
            .variable("key", serde_json::json!("value"))
            .previous_result("prev-node", serde_json::json!({"result": "ok"}));

        let json = context.to_json();
        assert!(json.get("variables").is_some());
        assert!(json.get("previous_results").is_some());
        assert_eq!(json["variables"]["key"], serde_json::json!("value"));
        assert_eq!(json["previous_results"]["prev-node"]["result"], serde_json::json!("ok"));
    }

    #[test]
    fn test_node_context_with_input_and_metadata() {
        let context = NodeContext::with_input(serde_json::json!({"x": 1}))
            .metadata("source", serde_json::json!("unit-test"));

        let json = context.to_json();
        assert_eq!(json["input_data"]["x"], serde_json::json!(1));
        assert_eq!(json["metadata"]["source"], serde_json::json!("unit-test"));
    }

    #[tokio::test]
    async fn test_create_node_request() {
        let router = empty_router();

        let route_result = NodeRouteResult {
            service_id: ServiceId::new("test-service"),
            node_id: "test-node".to_string(),
            context: NodeContext::new(),
            request_id: JsonRpcId::Number(1),
            callback_endpoint: "matrixcode://callback".to_string(),
            timeout_ms: 30_000,
        };

        let request = router.create_node_request(route_result);
        assert_eq!(request.method, "node.execute");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_route_node_not_found() {
        let router = empty_router();

        let result = router
            .route("unknown-node", NodeContext::new(), JsonRpcId::Number(1), vec![])
            .await;

        assert!(matches!(result, Err(NodeRouterError::NodeNotFound(_))));
    }

    #[tokio::test]
    async fn test_route_missing_capability() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");
        router
            .register_node(
                service_id.clone(),
                make_node("plain", &service_id, NodeType::Task, vec![NodeCapability::ContextAccess]),
            )
            .await;

        let result = router
            .route(
                "plain",
                NodeContext::new(),
                JsonRpcId::Number(7),
                vec![NodeCapability::ContextAccess, NodeCapability::AiExecution],
            )
            .await;

        match result {
            Err(NodeRouterError::CapabilityNotSupported { node_id, capability }) => {
                assert_eq!(node_id, "plain");
                assert_eq!(capability, "ai_execution");
            }
            other => panic!("expected CapabilityNotSupported, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_route_service_not_in_registry() {
        // The node is indexed, but its owning service was never added to the registry.
        let router = empty_router();
        let service_id = ServiceId::new("ghost-service");
        router
            .register_node(service_id.clone(), make_node("orphan", &service_id, NodeType::Task, vec![]))
            .await;

        let result = router
            .route("orphan", NodeContext::new(), JsonRpcId::Number(8), vec![])
            .await;

        assert!(matches!(result, Err(NodeRouterError::NodeNotFound(_))));
    }

    #[tokio::test]
    async fn test_create_error_response_covers_all_variants() {
        // Smoke test: every variant maps to a response without panicking.
        let router = empty_router();
        let errors = vec![
            NodeRouterError::NodeNotFound("missing".to_string()),
            NodeRouterError::ServiceNotRunning {
                node_id: "node".to_string(),
                service_id: ServiceId::new("svc"),
                status: ServiceStatus::Running,
            },
            NodeRouterError::NoServicesRegistered,
            NodeRouterError::RoutingFailed("no route".to_string()),
            NodeRouterError::CapabilityNotSupported {
                node_id: "node".to_string(),
                capability: "ai_execution".to_string(),
            },
            NodeRouterError::InvalidContext {
                node_id: "node".to_string(),
                reason: "bad input".to_string(),
            },
            NodeRouterError::Internal("boom".to_string()),
        ];

        for error in errors {
            let _response = router.create_error_response(error, JsonRpcId::Number(1));
        }
    }

    #[tokio::test]
    async fn test_list_nodes() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");

        router
            .register_node(service_id.clone(), make_node("node1", &service_id, NodeType::Task, vec![]))
            .await;
        router
            .register_node(
                service_id.clone(),
                make_node("node2", &service_id, NodeType::Condition, vec![NodeCapability::AiExecution]),
            )
            .await;

        let nodes = router.list_nodes().await;
        assert_eq!(nodes.len(), 2);

        let task_nodes = router.get_nodes_by_type(NodeType::Task).await;
        assert_eq!(task_nodes.len(), 1);

        let ai_nodes = router.get_nodes_by_capability(NodeCapability::AiExecution).await;
        assert_eq!(ai_nodes.len(), 1);
    }
}