1use std::sync::Arc;
7
8use serde_json::Value as JsonValue;
9use tokio::sync::{broadcast, RwLock};
10
11use crate::matrixrpc::{
12 callback::{CallbackConfig, CallbackHandler, CallbackResult, CallbackType, SecurityValidator},
13 lifecycle::{LifecycleConfig, LifecycleManager},
14 protocol::{ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse},
15 registry::{RegistryBuilder, RegistryService, RegistryStats},
16 router::{NodeCapability, NodeContext, NodeDefinition, NodeRouter, NodeRouteResult},
17 service::{ExtensionService, ServiceId, ServiceStatus},
18 ToolDefinition, ToolRouter, ToolRouteResult,
19};
20
21#[derive(Debug, Clone)]
23pub struct GatewayConfig {
24 pub registry_heartbeat_timeout_secs: u64,
26
27 pub lifecycle_config: LifecycleConfig,
29
30 pub callback_config: CallbackConfig,
32
33 pub default_tool_timeout_ms: u64,
35
36 pub default_node_timeout_ms: u64,
38
39 pub auto_discovery: bool,
41
42 pub max_services: u32,
44}
45
46impl Default for GatewayConfig {
47 fn default() -> Self {
48 Self {
49 registry_heartbeat_timeout_secs: 60,
50 lifecycle_config: LifecycleConfig::default(),
51 callback_config: CallbackConfig::default(),
52 default_tool_timeout_ms: 30_000,
53 default_node_timeout_ms: 60_000,
54 auto_discovery: true,
55 max_services: 100,
56 }
57 }
58}
59
/// Counters maintained by the gateway; a snapshot is returned by
/// `ExtensionGateway::gateway_stats`.
///
/// All counters start at zero.
#[derive(Debug, Clone, Default)]
pub struct GatewayStats {
    /// Services currently counted as registered: raised by `register_service`,
    /// lowered by `unregister_service`, reset by `stop_all`.
    pub total_services: usize,

    /// Number of running services.
    // NOTE(review): never written in this file; confirm who maintains it.
    pub running_services: usize,

    /// Tools added through `register_service`; reset by `stop_all`.
    // NOTE(review): not lowered by `unregister_service`.
    pub total_tools: usize,

    /// Nodes added through `register_service`; reset by `stop_all`.
    // NOTE(review): not lowered by `unregister_service`.
    pub total_nodes: usize,

    /// Callback requests received by `handle_callback`, successful or not.
    pub total_callbacks: usize,

    /// Callback requests that `handle_callback` reported as failed.
    pub total_errors: usize,
}
81
82#[derive(Debug, thiserror::Error)]
84pub enum GatewayError {
85 #[error("Service '{0}' not found")]
87 ServiceNotFound(String),
88
89 #[error("Maximum service limit ({0}) exceeded")]
91 ServiceLimitExceeded(u32),
92
93 #[error("Registration failed: {0}")]
95 RegistrationFailed(String),
96
97 #[error("Routing failed: {0}")]
99 RoutingFailed(String),
100
101 #[error("Execution failed: {0}")]
103 ExecutionFailed(String),
104
105 #[error("Callback failed: {0}")]
107 CallbackFailed(String),
108
109 #[error("Internal error: {0}")]
111 Internal(String),
112}
113
114#[derive(Debug, Clone)]
116pub struct ServiceRegistrationRequest {
117 pub name: String,
119
120 pub version: String,
122
123 pub description: Option<String>,
125
126 pub tools: Vec<ToolDefinition>,
128
129 pub nodes: Vec<NodeDefinition>,
131
132 pub transport_type: String,
134
135 pub command: Option<String>,
137
138 pub args: Vec<String>,
140}
141
142impl ServiceRegistrationRequest {
143 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
145 Self {
146 name: name.into(),
147 version: version.into(),
148 description: None,
149 tools: Vec::new(),
150 nodes: Vec::new(),
151 transport_type: "stdio".to_string(),
152 command: None,
153 args: Vec::new(),
154 }
155 }
156
157 pub fn description(mut self, desc: impl Into<String>) -> Self {
159 self.description = Some(desc.into());
160 self
161 }
162
163 pub fn tool(mut self, tool: ToolDefinition) -> Self {
165 self.tools.push(tool);
166 self
167 }
168
169 pub fn node(mut self, node: NodeDefinition) -> Self {
171 self.nodes.push(node);
172 self
173 }
174
175 pub fn transport(mut self, transport_type: impl Into<String>) -> Self {
177 self.transport_type = transport_type.into();
178 self
179 }
180
181 pub fn command(mut self, cmd: impl Into<String>, args: Vec<String>) -> Self {
183 self.command = Some(cmd.into());
184 self.args = args;
185 self
186 }
187}
188
189pub struct ExtensionGateway {
194 config: GatewayConfig,
196
197 registry: Arc<RegistryService>,
199
200 lifecycle: Arc<LifecycleManager>,
202
203 tool_router: Arc<ToolRouter>,
205
206 node_router: Arc<NodeRouter>,
208
209 callback: Arc<CallbackHandler>,
211
212 security: Arc<SecurityValidator>,
214
215 stats: Arc<RwLock<GatewayStats>>,
217
218 event_tx: broadcast::Sender<GatewayEvent>,
220}
221
222#[derive(Debug, Clone)]
224pub enum GatewayEvent {
225 ServiceRegistered(ServiceId),
227
228 ServiceUnregistered(ServiceId),
230
231 ServiceStatusChanged {
233 service_id: ServiceId,
234 old_status: ServiceStatus,
235 new_status: ServiceStatus,
236 },
237
238 ToolRegistered {
240 tool_name: String,
241 service_id: ServiceId,
242 },
243
244 NodeRegistered {
246 node_id: String,
247 service_id: ServiceId,
248 },
249
250 CallbackReceived {
252 callback_type: CallbackType,
253 service_id: ServiceId,
254 },
255
256 Error(String),
258}
259
260impl ExtensionGateway {
261 pub fn new() -> Self {
263 Self::with_config(GatewayConfig::default())
264 }
265
266 pub fn with_config(config: GatewayConfig) -> Self {
268 let (event_tx, _) = broadcast::channel(256);
269
270 let registry = Arc::new(
272 RegistryBuilder::new()
273 .heartbeat_timeout(config.registry_heartbeat_timeout_secs)
274 .build(),
275 );
276
277 let lifecycle = Arc::new(
279 LifecycleManager::with_config(registry.clone(), config.lifecycle_config.clone()),
280 );
281
282 let tool_router = Arc::new(ToolRouter::new(registry.clone()));
284 let node_router = Arc::new(NodeRouter::new(registry.clone()));
285
286 let security = Arc::new(SecurityValidator::new());
288 let callback = Arc::new(
289 CallbackHandler::with_config(
290 security.clone(),
291 tool_router.clone(),
292 node_router.clone(),
293 config.callback_config.clone(),
294 ),
295 );
296
297 Self {
298 config,
299 registry,
300 lifecycle,
301 tool_router,
302 node_router,
303 callback,
304 security,
305 stats: Arc::new(RwLock::new(GatewayStats::default())),
306 event_tx,
307 }
308 }
309
310 pub fn subscribe(&self) -> broadcast::Receiver<GatewayEvent> {
312 self.event_tx.subscribe()
313 }
314
315 pub async fn register_service(&self, request: ServiceRegistrationRequest) -> Result<ServiceId, GatewayError> {
317 {
319 let stats = self.stats.read().await;
320 if stats.total_services >= self.config.max_services as usize {
321 return Err(GatewayError::ServiceLimitExceeded(self.config.max_services));
322 }
323 }
324
325 let mut service = ExtensionService::new(&request.name, &request.version);
327
328 if let Some(desc) = &request.description {
329 service = service.description(desc);
330 }
331
332 let transport_config = crate::matrixrpc::service::TransportConfig {
334 transport_type: if request.transport_type == "tcp" {
335 crate::matrixrpc::service::TransportType::Tcp
336 } else {
337 crate::matrixrpc::service::TransportType::Stdio
338 },
339 command: request.command.clone(),
340 args: request.args.clone(),
341 ..Default::default()
342 };
343 service = service.transport(transport_config);
344
345 let tools_count = request.tools.len();
347 let nodes_count = request.nodes.len();
348
349 let service_id = self
351 .lifecycle
352 .start_service(service)
353 .await
354 .map_err(|e| GatewayError::RegistrationFailed(e.to_string()))?;
355
356 for tool in request.tools {
358 self.tool_router
359 .register_tool(service_id.clone(), tool.clone())
360 .await;
361 let _ = self.event_tx.send(GatewayEvent::ToolRegistered {
362 tool_name: tool.name,
363 service_id: service_id.clone(),
364 });
365 }
366
367 for node in request.nodes {
369 self.node_router
370 .register_node(service_id.clone(), node.clone())
371 .await;
372 let _ = self.event_tx.send(GatewayEvent::NodeRegistered {
373 node_id: node.id,
374 service_id: service_id.clone(),
375 });
376 }
377
378 {
380 let mut stats = self.stats.write().await;
381 stats.total_services += 1;
382 stats.total_tools += tools_count;
383 stats.total_nodes += nodes_count;
384 }
385
386 let _ = self.event_tx.send(GatewayEvent::ServiceRegistered(service_id.clone()));
388
389 Ok(service_id)
390 }
391
392 pub async fn unregister_service(&self, service_id: &ServiceId) -> Result<(), GatewayError> {
394 self.tool_router.unregister_service_tools(service_id).await;
396 self.node_router.unregister_service_nodes(service_id).await;
397
398 self.lifecycle
400 .stop_service(service_id)
401 .await
402 .map_err(|e| GatewayError::Internal(e.to_string()))?;
403
404 self.security.invalidate_service_tokens(service_id).await;
406
407 {
409 let mut stats = self.stats.write().await;
410 stats.total_services -= 1;
411 }
412
413 let _ = self.event_tx.send(GatewayEvent::ServiceUnregistered(service_id.clone()));
415
416 Ok(())
417 }
418
419 pub async fn execute_tool(
421 &self,
422 tool_name: &str,
423 params: JsonValue,
424 request_id: JsonRpcId,
425 ) -> Result<ToolRouteResult, GatewayError> {
426 let route_result = self
428 .tool_router
429 .route(tool_name, params, request_id)
430 .await
431 .map_err(|e| GatewayError::RoutingFailed(e.to_string()))?;
432
433 Ok(route_result)
437 }
438
439 pub async fn execute_node(
441 &self,
442 node_id: &str,
443 context: NodeContext,
444 request_id: JsonRpcId,
445 required_capabilities: Vec<NodeCapability>,
446 ) -> Result<NodeRouteResult, GatewayError> {
447 let callback_types = self.get_callback_types_for_capabilities(&required_capabilities);
449
450 let route_result = self
452 .node_router
453 .route(node_id, context, request_id, required_capabilities)
454 .await
455 .map_err(|e| GatewayError::RoutingFailed(e.to_string()))?;
456
457 let _token = self
459 .callback
460 .generate_token(route_result.service_id.clone(), route_result.request_id.to_string(), callback_types)
461 .await
462 .map_err(|e| GatewayError::CallbackFailed(e.to_string()))?;
463
464 Ok(route_result)
467 }
468
469 pub async fn handle_callback(&self, request: JsonRpcRequest) -> Result<CallbackResult, GatewayError> {
471 {
473 let mut stats = self.stats.write().await;
474 stats.total_callbacks += 1;
475 }
476
477 let result = self
479 .callback
480 .handle_request(request.clone())
481 .await;
482
483 match result {
484 Ok(res) => {
485 if let Some(params) = &request.params {
487 if let Some(service_id) = params.get("service_id").and_then(|v| v.as_str()) {
488 let _ = self.event_tx.send(GatewayEvent::CallbackReceived {
489 callback_type: res.callback_type(),
490 service_id: ServiceId::new(service_id),
491 });
492 }
493 }
494 Ok(res)
495 }
496 Err(e) => {
497 {
499 let mut stats = self.stats.write().await;
500 stats.total_errors += 1;
501 }
502 Err(GatewayError::CallbackFailed(e.to_string()))
503 }
504 }
505 }
506
507 fn get_callback_types_for_capabilities(&self, capabilities: &[NodeCapability]) -> Vec<String> {
509 let mut types = Vec::new();
510
511 for cap in capabilities {
512 match cap {
513 NodeCapability::AiExecution => types.push("ai".to_string()),
514 NodeCapability::ToolExecution => types.push("tool".to_string()),
515 NodeCapability::ContextAccess => types.push("context".to_string()),
516 }
517 }
518
519 types
520 }
521
522 pub async fn list_services(&self) -> Vec<ExtensionService> {
524 self.registry.list_all().await
525 }
526
527 pub async fn get_service(&self, service_id: &ServiceId) -> Option<ExtensionService> {
529 self.registry.get(service_id).await
530 }
531
532 pub async fn get_service_by_name(&self, name: &str) -> Option<ExtensionService> {
534 self.registry.get_by_name(name).await
535 }
536
537 pub async fn list_tools(&self) -> Vec<ToolDefinition> {
539 self.tool_router.list_tools().await
540 }
541
542 pub async fn list_nodes(&self) -> Vec<NodeDefinition> {
544 self.node_router.list_nodes().await
545 }
546
547 pub async fn has_tool(&self, tool_name: &str) -> bool {
549 self.tool_router.has_tool(tool_name).await
550 }
551
552 pub async fn has_node(&self, node_id: &str) -> bool {
554 self.node_router.has_node(node_id).await
555 }
556
557 pub async fn registry_stats(&self) -> RegistryStats {
559 self.registry.stats().await
560 }
561
562 pub async fn gateway_stats(&self) -> GatewayStats {
564 self.stats.read().await.clone()
565 }
566
567 pub async fn health_check(&self) -> Vec<ServiceId> {
569 self.lifecycle.health_check().await
570 }
571
572 pub async fn heartbeat(&self, service_id: &ServiceId) -> Result<(), GatewayError> {
574 self.lifecycle
575 .handle_heartbeat(service_id)
576 .await
577 .map_err(|e| GatewayError::Internal(e.to_string()))
578 }
579
580 pub async fn get_service_status(&self, service_id: &ServiceId) -> Option<ServiceStatus> {
582 self.lifecycle.get_status(service_id).await
583 }
584
585 pub async fn stop_all(&self) {
587 self.lifecycle.stop_all().await;
588 self.registry.clear().await;
589 self.security.cleanup_expired().await;
590
591 let mut stats = self.stats.write().await;
592 stats.total_services = 0;
593 stats.total_tools = 0;
594 stats.total_nodes = 0;
595 }
596
597 pub fn create_error_response(&self, error: GatewayError, id: JsonRpcId) -> JsonRpcResponse {
599 let (code, message, data) = match error {
600 GatewayError::ServiceNotFound(id) => (
601 ErrorCode::RESOURCE_NOT_FOUND,
602 format!("Service '{}' not found", id),
603 None,
604 ),
605 GatewayError::ServiceLimitExceeded(limit) => (
606 ErrorCode::PERMISSION_DENIED,
607 format!("Maximum service limit ({}) exceeded", limit),
608 None,
609 ),
610 GatewayError::RegistrationFailed(msg) => (
611 ErrorCode::INTERNAL_ERROR,
612 "Registration failed".to_string(),
613 Some(serde_json::json!({ "reason": msg })),
614 ),
615 GatewayError::RoutingFailed(msg) => (
616 ErrorCode::INTERNAL_ERROR,
617 "Routing failed".to_string(),
618 Some(serde_json::json!({ "reason": msg })),
619 ),
620 GatewayError::ExecutionFailed(msg) => (
621 ErrorCode::INTERNAL_ERROR,
622 "Execution failed".to_string(),
623 Some(serde_json::json!({ "reason": msg })),
624 ),
625 GatewayError::CallbackFailed(msg) => (
626 ErrorCode::CALLBACK_ERROR,
627 "Callback failed".to_string(),
628 Some(serde_json::json!({ "reason": msg })),
629 ),
630 GatewayError::Internal(msg) => (
631 ErrorCode::INTERNAL_ERROR,
632 msg,
633 None,
634 ),
635 };
636
637 JsonRpcResponse::error(
638 id,
639 JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
640 )
641 }
642
643 pub fn registry(&self) -> Arc<RegistryService> {
645 self.registry.clone()
646 }
647
648 pub fn lifecycle(&self) -> Arc<LifecycleManager> {
650 self.lifecycle.clone()
651 }
652
653 pub fn tool_router(&self) -> Arc<ToolRouter> {
655 self.tool_router.clone()
656 }
657
658 pub fn node_router(&self) -> Arc<NodeRouter> {
660 self.node_router.clone()
661 }
662
663 pub fn callback(&self) -> Arc<CallbackHandler> {
665 self.callback.clone()
666 }
667
668 pub fn security(&self) -> Arc<SecurityValidator> {
670 self.security.clone()
671 }
672}
673
674impl Default for ExtensionGateway {
675 fn default() -> Self {
676 Self::new()
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::NodeType;

    /// Builds a tool definition that carries only a name.
    fn bare_tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            service_id: ServiceId::generate(),
            description: None,
            risk_level: None,
            timeout_ms: None,
        }
    }

    /// A fresh gateway starts with all counters at zero.
    #[tokio::test]
    async fn test_gateway_creation() {
        let gateway = ExtensionGateway::new();
        let stats = gateway.gateway_stats().await;

        assert_eq!(stats.total_services, 0);
        assert_eq!(stats.total_tools, 0);
        assert_eq!(stats.total_nodes, 0);
    }

    /// Overridden settings are kept by the gateway.
    #[tokio::test]
    async fn test_gateway_with_config() {
        let config = GatewayConfig {
            max_services: 10,
            default_tool_timeout_ms: 10_000,
            ..Default::default()
        };

        let gateway = ExtensionGateway::with_config(config);
        assert_eq!(gateway.config.max_services, 10);
    }

    /// A registered service can be looked up by the returned id.
    #[tokio::test]
    async fn test_register_service() {
        let gateway = ExtensionGateway::new();

        let request = ServiceRegistrationRequest::new("test-service", "1.0.0")
            .description("A test service");

        let service_id = gateway.register_service(request).await.unwrap();
        assert!(gateway.get_service(&service_id).await.is_some());
    }

    /// Unregistering removes the service and gives its slot back.
    #[tokio::test]
    async fn test_unregister_service() {
        let gateway = ExtensionGateway::new();

        let request = ServiceRegistrationRequest::new("test-service", "1.0.0");
        let service_id = gateway.register_service(request).await.unwrap();

        gateway.unregister_service(&service_id).await.unwrap();
        assert!(gateway.get_service(&service_id).await.is_none());
        assert_eq!(gateway.gateway_stats().await.total_services, 0);
    }

    /// Tools listed in the request are routable and counted after registration.
    #[tokio::test]
    async fn test_register_service_with_tools() {
        let gateway = ExtensionGateway::new();

        let request = ServiceRegistrationRequest::new("test-service", "1.0.0")
            .tool(ToolDefinition {
                name: "test_tool".to_string(),
                service_id: ServiceId::generate(),
                description: Some("Test tool".to_string()),
                risk_level: Some("safe".to_string()),
                timeout_ms: Some(5000),
            });

        let service_id = gateway.register_service(request).await.unwrap();

        assert!(gateway.get_service(&service_id).await.is_some());
        assert!(gateway.has_tool("test_tool").await);
        assert_eq!(gateway.gateway_stats().await.total_tools, 1);
    }

    /// Nodes listed in the request are routable and counted after registration.
    #[tokio::test]
    async fn test_register_service_with_nodes() {
        let gateway = ExtensionGateway::new();

        let request = ServiceRegistrationRequest::new("test-service", "1.0.0")
            .node(NodeDefinition {
                id: "test_node".to_string(),
                name: "Test Node".to_string(),
                service_id: ServiceId::generate(),
                node_type: NodeType::Task,
                description: Some("Test node".to_string()),
                capabilities: vec![NodeCapability::AiExecution],
                timeout_ms: Some(10_000),
                params_schema: None,
            });

        let service_id = gateway.register_service(request).await.unwrap();

        assert!(gateway.get_service(&service_id).await.is_some());
        assert!(gateway.has_node("test_node").await);
        assert_eq!(gateway.gateway_stats().await.total_nodes, 1);
    }

    /// Every registered service shows up in the listing.
    #[tokio::test]
    async fn test_list_services() {
        let gateway = ExtensionGateway::new();

        for name in ["service1", "service2"] {
            gateway
                .register_service(ServiceRegistrationRequest::new(name, "1.0.0"))
                .await
                .unwrap();
        }

        let services = gateway.list_services().await;
        assert_eq!(services.len(), 2);
    }

    /// Every tool of a registered service shows up in the listing.
    #[tokio::test]
    async fn test_list_tools() {
        let gateway = ExtensionGateway::new();

        let request = ServiceRegistrationRequest::new("test", "1.0")
            .tool(bare_tool("tool1"))
            .tool(bare_tool("tool2"));

        gateway.register_service(request).await.unwrap();

        let tools = gateway.list_tools().await;
        assert_eq!(tools.len(), 2);
    }

    /// Registration beyond `max_services` is rejected with the limit.
    #[tokio::test]
    async fn test_service_limit() {
        let config = GatewayConfig {
            max_services: 2,
            ..Default::default()
        };

        let gateway = ExtensionGateway::with_config(config);

        for name in ["s1", "s2"] {
            gateway
                .register_service(ServiceRegistrationRequest::new(name, "1.0"))
                .await
                .unwrap();
        }

        let result = gateway
            .register_service(ServiceRegistrationRequest::new("s3", "1.0"))
            .await;

        assert!(matches!(result, Err(GatewayError::ServiceLimitExceeded(2))));
    }

    /// A registration without tools or nodes emits `ServiceRegistered` as its
    /// first event.
    #[tokio::test]
    async fn test_gateway_events() {
        let gateway = ExtensionGateway::new();
        let mut event_rx = gateway.subscribe();

        let request = ServiceRegistrationRequest::new("test-service", "1.0.0");
        gateway.register_service(request).await.unwrap();

        let event = event_rx.try_recv();
        assert!(matches!(event, Ok(GatewayEvent::ServiceRegistered(_))));
    }

    /// Smoke test: the health check completes with a registered service.
    #[tokio::test]
    async fn test_health_check() {
        let gateway = ExtensionGateway::new();

        let request = ServiceRegistrationRequest::new("test-service", "1.0.0");
        gateway.register_service(request).await.unwrap();

        // The expected health of a freshly started service is decided by the
        // lifecycle manager, so nothing is asserted about the returned ids.
        let _unhealthy = gateway.health_check().await;
    }

    /// `stop_all` empties the registry and resets the service counter.
    #[tokio::test]
    async fn test_stop_all() {
        let gateway = ExtensionGateway::new();

        for name in ["s1", "s2"] {
            gateway
                .register_service(ServiceRegistrationRequest::new(name, "1.0"))
                .await
                .unwrap();
        }

        gateway.stop_all().await;

        let services = gateway.list_services().await;
        assert!(services.is_empty());
        assert_eq!(gateway.gateway_stats().await.total_services, 0);
    }
}