// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, PoisonError, RwLock};

9/// Callback type for engine routes (async)
10/// Takes JSON body, returns JSON response (or error) wrapped in a Future
11pub type EngineRouteCallback = Arc<
12    dyn Fn(
13            serde_json::Value,
14        ) -> Pin<Box<dyn Future<Output = anyhow::Result<serde_json::Value>> + Send>>
15        + Send
16        + Sync,
17>;
18
19/// Registry for engine route callbacks
20///
21/// This registry stores callbacks that handle requests to `/engine/*` routes.
22/// Routes are registered from Python via `runtime.register_engine_route()`.
23#[derive(Clone, Default)]
24pub struct EngineRouteRegistry {
25    routes: Arc<RwLock<HashMap<String, EngineRouteCallback>>>,
26}
27
28impl EngineRouteRegistry {
29    /// Create a new empty registry
30    pub fn new() -> Self {
31        Self {
32            routes: Arc::new(RwLock::new(HashMap::new())),
33        }
34    }
35
36    /// Register a callback for a route (e.g., "start_profile" for /engine/start_profile)
37    pub fn register(&self, route: &str, callback: EngineRouteCallback) {
38        let mut routes = self.routes.write().unwrap();
39        routes.insert(route.to_string(), callback);
40        tracing::debug!("Registered engine route: /engine/{}", route);
41    }
42
43    /// Get callback for a route
44    pub fn get(&self, route: &str) -> Option<EngineRouteCallback> {
45        let routes = self.routes.read().unwrap();
46        routes.get(route).cloned()
47    }
48
49    /// List all registered routes
50    pub fn routes(&self) -> Vec<String> {
51        let routes = self.routes.read().unwrap();
52        routes.keys().cloned().collect()
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that never await use plain `#[test]`; only tests that actually drive
    // a callback future need a tokio runtime.

    #[test]
    fn test_registry_basic() {
        let registry = EngineRouteRegistry::new();

        // Register a simple echo callback
        let callback: EngineRouteCallback =
            Arc::new(|body| Box::pin(async move { Ok(serde_json::json!({"echo": body})) }));

        registry.register("test", callback);

        // The registered route is found; an unknown one is not
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());

        // The listing contains exactly the registered route
        assert_eq!(registry.routes(), vec!["test".to_string()]);
    }

    #[test]
    fn test_empty_registry() {
        let registry = EngineRouteRegistry::new();

        assert!(registry.routes().is_empty());
        assert!(registry.get("anything").is_none());
    }

    #[tokio::test]
    async fn test_callback_execution() {
        let registry = EngineRouteRegistry::new();

        let callback: EngineRouteCallback = Arc::new(|body| {
            Box::pin(async move {
                let input = body.get("input").and_then(|v| v.as_str()).unwrap_or("");
                Ok(serde_json::json!({
                    "output": format!("processed: {}", input)
                }))
            })
        });

        registry.register("process", callback);

        // Get and execute callback
        let cb = registry.get("process").unwrap();
        let result = cb(serde_json::json!({"input": "test"})).await.unwrap();

        assert_eq!(result["output"], "processed: test");
    }

    #[tokio::test]
    async fn test_register_replaces_existing_route() {
        let registry = EngineRouteRegistry::new();

        let first: EngineRouteCallback =
            Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"version": 1})) }));
        let second: EngineRouteCallback =
            Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"version": 2})) }));

        registry.register("dup", first);
        registry.register("dup", second);

        // Re-registering keeps a single entry, and the latest callback wins
        assert_eq!(registry.routes(), vec!["dup".to_string()]);
        let cb = registry.get("dup").expect("route should be registered");
        let result = cb(serde_json::Value::Null).await.unwrap();
        assert_eq!(result, serde_json::json!({"version": 2}));
    }

    #[test]
    fn test_clone_shares_routes() {
        let registry = EngineRouteRegistry::new();

        let callback: EngineRouteCallback =
            Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"ok": true})) }));
        registry.register("test", callback);

        // Clone the registry
        let cloned = registry.clone();

        // Both should see the same route
        assert!(registry.get("test").is_some());
        assert!(cloned.get("test").is_some());

        // Register on clone
        let callback2: EngineRouteCallback =
            Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"ok": false})) }));
        cloned.register("test2", callback2);

        // Original should also see it (they share the Arc)
        assert!(registry.get("test2").is_some());
    }
}