Skip to main content

dynamo_runtime/
engine_routes.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::{Arc, RwLock};
8
9/// Callback type for engine routes (async)
10/// Takes JSON body, returns JSON response (or error) wrapped in a Future
11pub type EngineRouteCallback = Arc<
12    dyn Fn(
13            serde_json::Value,
14        ) -> Pin<Box<dyn Future<Output = anyhow::Result<serde_json::Value>> + Send>>
15        + Send
16        + Sync,
17>;
18
19/// Registry for engine route callbacks
20///
21/// This registry stores callbacks that handle requests to `/engine/*` routes.
22/// Routes are registered from Python via `runtime.register_engine_route()`.
23#[derive(Clone, Default)]
24pub struct EngineRouteRegistry {
25    routes: Arc<RwLock<HashMap<String, EngineRouteCallback>>>,
26}
27
28impl EngineRouteRegistry {
29    /// Create a new empty registry
30    pub fn new() -> Self {
31        Self {
32            routes: Arc::new(RwLock::new(HashMap::new())),
33        }
34    }
35
36    /// Register a callback for a route (e.g., "start_profile" for /engine/start_profile)
37    ///
38    /// A route name is expected to be registered exactly once. Re-registering an
39    /// existing name overwrites the previous callback and emits a warning, since
40    /// it usually signals two registration mechanisms colliding rather than an
41    /// intentional replacement.
42    pub fn register(&self, route: &str, callback: EngineRouteCallback) {
43        let mut routes = self.routes.write().unwrap();
44        if routes.insert(route.to_string(), callback).is_some() {
45            tracing::warn!("Overwriting already-registered engine route: /engine/{route}");
46        } else {
47            tracing::debug!("Registered engine route: /engine/{route}");
48        }
49    }
50
51    /// Get callback for a route
52    pub fn get(&self, route: &str) -> Option<EngineRouteCallback> {
53        let routes = self.routes.read().unwrap();
54        routes.get(route).cloned()
55    }
56
57    /// List all registered routes
58    pub fn routes(&self) -> Vec<String> {
59        let routes = self.routes.read().unwrap();
60        routes.keys().cloned().collect()
61    }
62}
63
64#[cfg(test)]
65mod tests {
66    use super::*;
67
68    #[tokio::test]
69    async fn test_registry_basic() {
70        let registry = EngineRouteRegistry::new();
71
72        // Register a simple callback
73        let callback: EngineRouteCallback =
74            Arc::new(|body| Box::pin(async move { Ok(serde_json::json!({"echo": body})) }));
75
76        registry.register("test", callback);
77
78        // Verify it's registered
79        assert!(registry.get("test").is_some());
80        assert!(registry.get("nonexistent").is_none());
81
82        // Verify routes list
83        let routes = registry.routes();
84        assert_eq!(routes.len(), 1);
85        assert!(routes.contains(&"test".to_string()));
86    }
87
88    #[tokio::test]
89    async fn test_callback_execution() {
90        let registry = EngineRouteRegistry::new();
91
92        let callback: EngineRouteCallback = Arc::new(|body| {
93            Box::pin(async move {
94                let input = body.get("input").and_then(|v| v.as_str()).unwrap_or("");
95                Ok(serde_json::json!({
96                    "output": format!("processed: {}", input)
97                }))
98            })
99        });
100
101        registry.register("process", callback);
102
103        // Get and execute callback
104        let cb = registry.get("process").unwrap();
105        let result = cb(serde_json::json!({"input": "test"})).await.unwrap();
106
107        assert_eq!(result["output"], "processed: test");
108    }
109
110    #[tokio::test]
111    async fn test_clone_shares_routes() {
112        let registry = EngineRouteRegistry::new();
113
114        let callback: EngineRouteCallback =
115            Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"ok": true})) }));
116        registry.register("test", callback);
117
118        // Clone the registry
119        let cloned = registry.clone();
120
121        // Both should see the same route
122        assert!(registry.get("test").is_some());
123        assert!(cloned.get("test").is_some());
124
125        // Register on clone
126        let callback2: EngineRouteCallback =
127            Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"ok": false})) }));
128        cloned.register("test2", callback2);
129
130        // Original should also see it (they share the Arc)
131        assert!(registry.get("test2").is_some());
132    }
133}