Skip to main content

dynamo_runtime/
engine.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Asynchronous Engine System with Type Erasure Support
5//!
6//! This module provides the core asynchronous engine abstraction for Dynamo's runtime system.
7//! It defines the `AsyncEngine` trait for streaming engines and provides sophisticated
8//! type-erasure capabilities for managing heterogeneous engine collections.
9//!
10//! ## Type Erasure Overview
11//!
12//! Type erasure is a critical feature that allows storing different `AsyncEngine` implementations
13//! with varying generic type parameters in a single collection (e.g., `HashMap<String, Arc<dyn AnyAsyncEngine>>`).
14//! This is essential for:
15//!
16//! - **Dynamic Engine Management**: Registering and retrieving engines at runtime based on configuration
17//! - **Plugin Systems**: Loading different engine implementations without compile-time knowledge
18//! - **Service Discovery**: Managing multiple engine types in a unified registry
19//!
20//! ## Implementation Details
21//!
22//! The type-erasure system uses several advanced Rust features:
23//!
24//! - **Trait Objects (`dyn Trait`)**: For runtime polymorphism without compile-time type information
25//! - **`std::any::TypeId`**: For runtime type checking during downcasting
26//! - **`std::any::Any`**: For type-erased storage and safe downcasting
27//! - **`PhantomData`**: For maintaining type relationships in generic wrappers
28//! - **Extension Traits**: For ergonomic API design without modifying existing types
29//!
30//! ## Safety Considerations
31//!
32//! ⚠️ **IMPORTANT**: The type-erasure system relies on precise type matching at runtime.
33//! When modifying these traits or their implementations:
34//!
35//! - **Never change the type ID logic** in `AnyAsyncEngine` implementations
36//! - **Maintain the blanket `Data` implementation** for all `Send + Sync + 'static` types
37//! - **Test downcasting thoroughly** when adding new engine types
38//! - **Document any changes** that affect the type-erasure behavior
39//!
40//! ## Usage Example
41//!
42//! ```rust,ignore
43//! use std::collections::HashMap;
44//! use std::sync::Arc;
//! use crate::engine::{AnyAsyncEngine, AsAnyAsyncEngine, AsyncEngine, DowncastAnyAsyncEngine};
46//!
47//! // Create typed engines
48//! let string_engine: Arc<dyn AsyncEngine<String, String, ()>> = Arc::new(MyStringEngine::new());
49//! let int_engine: Arc<dyn AsyncEngine<i32, i32, ()>> = Arc::new(MyIntEngine::new());
50//!
51//! // Store in heterogeneous collection
52//! let mut engines: HashMap<String, Arc<dyn AnyAsyncEngine>> = HashMap::new();
53//! engines.insert("string".to_string(), string_engine.into_any_engine());
54//! engines.insert("int".to_string(), int_engine.into_any_engine());
55//!
56//! // Retrieve and downcast safely
57//! if let Some(typed_engine) = engines.get("string").unwrap().downcast::<String, String, ()>() {
58//!     let result = typed_engine.generate("hello".to_string()).await;
59//! }
60//! ```
61
62use std::{
63    any::{Any, TypeId},
64    fmt::Debug,
65    future::Future,
66    marker::PhantomData,
67    pin::Pin,
68    sync::Arc,
69};
70
71pub use async_trait::async_trait;
72use futures::stream::Stream;
73
/// Marker trait for values that can be used as [`AsyncEngine`] request and response types.
///
/// Every type that is [`Send`] + [`Sync`] + `'static` receives this trait through the
/// blanket implementation below.
/// **Do not manually implement this trait** - the blanket implementation already covers
/// every valid type.
pub trait Data: Send + Sync + 'static {}

impl<T> Data for T where T: Send + Sync + 'static {}
80
/// [`DataUnary`] is a type alias for a pinned, boxed, [`Send`] future that resolves to a
/// single `T`.
pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`]
/// by associating it with a [`AsyncEngineContext`].
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
85
/// Shared, reference-counted handle to a dynamically dispatched [`AsyncEngine`].
pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>;
/// Pinned, boxed trait object for an [`AsyncEngineUnary`] (single-response) operation.
pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>;
/// Trait-object alias for an [`AsyncEngineStream`] — used on both sides of an
/// engine: the input side via [`crate::pipeline::ManyIn`] and the output side via
/// [`crate::pipeline::ManyOut`]. The directional names exist
/// at the [`crate::pipeline`] alias layer for documentary clarity at use sites.
pub type EngineStream<T> = Pin<Box<dyn AsyncEngineStream<T>>>;
/// Shared, reference-counted handle to a dynamically dispatched [`AsyncEngineContext`].
pub type Context = Arc<dyn AsyncEngineContext>;
94
95impl<T: Data> From<EngineStream<T>> for DataStream<T> {
96    fn from(stream: EngineStream<T>) -> Self {
97        Box::pin(stream)
98    }
99}
100
// The Controller and the Context when https://github.com/rust-lang/rust/issues/65991 becomes stable
// NOTE(review): the sentence above is missing its verb; presumably it records a plan to
// revisit how the controller and the context traits relate once that issue is
// stabilized — confirm with the original author.
/// Marker trait with no methods of its own; implementors must be `Send + Sync`.
pub trait AsyncEngineController: Send + Sync {}
103
/// The [`AsyncEngineContext`] trait defines the interface to control the resulting stream
/// produced by the engine.
///
/// This trait provides lifecycle management for async operations, including:
/// - Stream identification via unique IDs
/// - Graceful shutdown capabilities (`stop_generating`)
/// - Immediate termination capabilities (`kill`)
/// - Status checking for stopped/killed states
///
/// Implementations should ensure thread-safety and proper state management
/// across concurrent access patterns.
#[async_trait]
pub trait AsyncEngineContext: Send + Sync + Debug {
    /// Unique ID for the stream.
    fn id(&self) -> &str;

    /// Returns `true` if `stop_generating()` has been called; otherwise, `false`.
    fn is_stopped(&self) -> bool;

    /// Returns `true` if `kill()` has been called; otherwise, `false`.
    /// This can be used with a `.take_while()` stream combinator to immediately terminate
    /// the stream.
    ///
    /// An ideal location for a `.take_while()` combinator that checks `!ctx.is_killed()`
    /// is on the most downstream return stream.
    fn is_killed(&self) -> bool;

    /// Returns immediately when [`AsyncEngineContext::is_stopped`] is already `true`.
    ///
    /// NOTE(review): the original wording for the other case was garbled; presumably the
    /// future resolves once [`AsyncEngineContext::is_stopped`] would return `true` —
    /// confirm against the implementations.
    async fn stopped(&self);

    /// Returns immediately when [`AsyncEngineContext::is_killed`] is already `true`.
    ///
    /// NOTE(review): the original wording for the other case was garbled; presumably the
    /// future resolves once [`AsyncEngineContext::is_killed`] would return `true` —
    /// confirm against the implementations.
    async fn killed(&self);

    // Controller

    /// Informs the [`AsyncEngine`] to stop producing results for this particular stream.
    /// This method is idempotent. This method does not invalidate results currently in the
    /// stream. It might take some time for the engine to stop producing results. The caller
    /// can decide to drain the stream or drop the stream.
    fn stop_generating(&self);

    /// See [`AsyncEngineContext::stop_generating`].
    fn stop(&self);

    /// Extends [`AsyncEngineContext::stop_generating`]: in addition to stopping generation,
    /// it indicates a preference to terminate without draining the remaining items in the
    /// stream. This is implementation specific and may not be supported by all engines.
    fn kill(&self);

    /// Links a child [`AsyncEngineContext`] to this one. When `stop_generating`, `stop`
    /// or `kill` is called on this context, the same method is called on all linked
    /// children, in the order they were linked, and then the method on this context
    /// continues.
    fn link_child(&self, child: Arc<dyn AsyncEngineContext>);
}
161
/// Provides access to the [`AsyncEngineContext`] associated with an engine operation.
///
/// This trait is implemented by both unary and streaming engine results, allowing
/// uniform access to context information regardless of the operation type.
pub trait AsyncEngineContextProvider: Send + Debug {
    /// Returns a shared handle to the context associated with this operation.
    fn context(&self) -> Arc<dyn AsyncEngineContext>;
}
169
/// A unary (single-response) asynchronous engine operation.
///
/// This trait combines `Future` semantics with context provider capabilities,
/// representing a single async operation that produces one result.
///
/// It declares no methods of its own; everything comes from its supertraits.
pub trait AsyncEngineUnary<Resp: Data>:
    Future<Output = Resp> + AsyncEngineContextProvider + Send
{
}
178
/// A streaming asynchronous engine operation.
///
/// This trait combines `Stream` semantics with context provider capabilities,
/// representing a continuous async operation that produces multiple messages over time.
/// It declares no methods of its own; everything comes from its supertraits.
///
/// - **Output side:** wrapped as [`EngineStream<T>`] = `crate::pipeline::ManyOut<T>`
///   — the stream of response chunks an engine emits.
/// - **Input side:** same `EngineStream<T>` shape, exposed as
///   `crate::pipeline::ManyIn<T>` for documentary clarity at the call site.
///
/// [`ResponseStream`] is the canonical concrete implementor; [`RequestStream`]
/// is a type alias of it for the input side.
pub trait AsyncEngineStream<T: Data>: Stream<Item = T> + AsyncEngineContextProvider + Send {}
192
/// [`AsyncEngine`] is the trait that defines the interface for a streaming engine.
/// The synchronous engine version does not need to be awaited.
///
/// This is the core trait for all async engine implementations. It provides:
/// - Generic type parameters for request, response, and error types
/// - Async generation capabilities with proper error handling
/// - Thread-safe design with `Send + Sync` bounds
///
/// ## Type Parameters
/// - `Req`: The request type — required to be `Send + 'static`. The `Sync`
///   bound was removed from `Req` for convenience: forcing `Sync` on `Req`
///   propagates a `+ Sync` constraint onto every type that flows in (in
///   particular, every input-side trait-object alias), and no
///   existing implementation of `AsyncEngine` relies on the `Sync` nature of
///   the request. Revisit if a future implementation genuinely needs
///   shared-reference access to a request value across threads.
/// - `Resp`: The response type that implements `AsyncEngineContextProvider`
/// - `E`: The error type that implements `Data`
///
/// ## Implementation Notes
/// Implementations should ensure proper error handling and resource management.
/// The `generate` method should be cancellable via the response's context provider.
#[async_trait]
pub trait AsyncEngine<Req: Send + 'static, Resp: AsyncEngineContextProvider, E: Data>:
    Send + Sync
{
    /// Generate a stream of completion responses.
    ///
    /// Takes ownership of `request` and resolves to the response value `Resp` on
    /// success or to an error of type `E` on failure.
    async fn generate(&self, request: Req) -> Result<Resp, E>;
}
222
/// Adapter for a [`DataStream`] to a [`ResponseStream`].
///
/// A common pattern is to consume the [`ResponseStream`] with standard stream combinators
/// which produces a [`DataStream`] stream, then form a [`ResponseStream`] by propagating the
/// original [`AsyncEngineContext`].
pub struct ResponseStream<R: Data> {
    // The wrapped stream; polling the `ResponseStream` polls this stream.
    stream: DataStream<R>,
    // The context handed out by `AsyncEngineContextProvider::context`.
    ctx: Arc<dyn AsyncEngineContext>,
}
232
233impl<R: Data> ResponseStream<R> {
234    pub fn new(stream: DataStream<R>, ctx: Arc<dyn AsyncEngineContext>) -> Pin<Box<Self>> {
235        Box::pin(Self { stream, ctx })
236    }
237}
238
239impl<R: Data> Stream for ResponseStream<R> {
240    type Item = R;
241
242    #[inline]
243    fn poll_next(
244        mut self: Pin<&mut Self>,
245        cx: &mut std::task::Context<'_>,
246    ) -> std::task::Poll<Option<Self::Item>> {
247        Pin::new(&mut self.stream).poll_next(cx)
248    }
249}
250
251impl<R: Data> AsyncEngineStream<R> for ResponseStream<R> {}
252
253impl<R: Data> AsyncEngineContextProvider for ResponseStream<R> {
254    fn context(&self) -> Arc<dyn AsyncEngineContext> {
255        self.ctx.clone()
256    }
257}
258
259impl<R: Data> Debug for ResponseStream<R> {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        f.debug_struct("ResponseStream")
262            // todo: add debug for stream - possibly propagate some information about what
263            // engine created the stream
264            // .field("stream", &self.stream)
265            .field("ctx", &self.ctx)
266            .finish()
267    }
268}
269
/// Input-side type alias of [`ResponseStream`] — same struct, different name to
/// signal role at the call site.
///
/// The shape is identical: a `(stream, ctx)` pair that implements [`Stream`],
/// [`AsyncEngineContextProvider`], and [`AsyncEngineStream`]. Use `RequestStream`
/// when you're constructing a value to feed into the `Req` slot of an engine,
/// and [`ResponseStream`] when constructing a value to emit from the `Resp` slot.
/// Because this is a plain type alias, the two names are functionally interchangeable.
pub type RequestStream<R> = ResponseStream<R>;
279
280impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineUnary<T>>> {
281    fn context(&self) -> Arc<dyn AsyncEngineContext> {
282        AsyncEngineContextProvider::context(&**self)
283    }
284}
285
286impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineStream<T>>> {
287    fn context(&self) -> Arc<dyn AsyncEngineContext> {
288        AsyncEngineContextProvider::context(&**self)
289    }
290}
291
/// A type-erased `AsyncEngine`.
///
/// This trait enables storing heterogeneous `AsyncEngine` implementations in collections
/// by erasing their specific generic type parameters. It provides runtime type information
/// and safe downcasting capabilities.
///
/// ## Type Erasure Mechanism
/// The trait uses `std::any::TypeId` to preserve type information at runtime, allowing
/// safe downcasting back to the original `AsyncEngine<Req, Resp, E>` types.
///
/// ## Safety Guarantees
/// - Type IDs are preserved exactly as they were during type erasure
/// - Downcasting is only possible to the original type combination
/// - Incorrect downcasts return `None` rather than panicking
///
/// ## Implementation Notes
/// This trait is implemented by the internal `AnyEngineWrapper` struct. Users should
/// not implement this trait directly - use the `AsAnyAsyncEngine` extension trait instead.
pub trait AnyAsyncEngine: Send + Sync {
    /// Returns the `TypeId` of the request type used by this engine.
    fn request_type_id(&self) -> TypeId;

    /// Returns the `TypeId` of the response type used by this engine.
    fn response_type_id(&self) -> TypeId;

    /// Returns the `TypeId` of the error type used by this engine.
    fn error_type_id(&self) -> TypeId;

    /// Provides access to the underlying engine as a `dyn Any` for downcasting.
    ///
    /// The `AnyEngineWrapper` implementation in this file exposes its inner
    /// `Arc<dyn AsyncEngine<Req, Resp, E>>`, which is the concrete type the
    /// `DowncastAnyAsyncEngine` implementation passes to `downcast_ref`.
    fn as_any(&self) -> &dyn Any;
}
323
/// An internal wrapper to hold a typed `AsyncEngine` behind the `AnyAsyncEngine` trait object.
///
/// This struct carries a `PhantomData<fn(Req, Resp, E)>` marker that names the generic
/// parameters without storing values of those types.
///
/// ## PhantomData Usage
/// NOTE(review): all three parameters are bounded by `Data` (which requires `'static`)
/// and already appear in the type of the `engine` field, so the marker does not relax
/// any lifetime requirement; it looks redundant but is kept because the constructor in
/// `AsAnyAsyncEngine::into_any_engine` initializes it — confirm before removing.
struct AnyEngineWrapper<Req, Resp, E>
where
    Req: Data,
    Resp: Data + AsyncEngineContextProvider,
    E: Data,
{
    // The typed engine; this exact `Arc` is what `as_any` exposes for downcasting.
    engine: Arc<dyn AsyncEngine<Req, Resp, E>>,
    // Zero-sized marker naming the three type parameters.
    _phantom: PhantomData<fn(Req, Resp, E)>,
}
342
343impl<Req, Resp, E> AnyAsyncEngine for AnyEngineWrapper<Req, Resp, E>
344where
345    Req: Data,
346    Resp: Data + AsyncEngineContextProvider,
347    E: Data,
348{
349    fn request_type_id(&self) -> TypeId {
350        TypeId::of::<Req>()
351    }
352
353    fn response_type_id(&self) -> TypeId {
354        TypeId::of::<Resp>()
355    }
356
357    fn error_type_id(&self) -> TypeId {
358        TypeId::of::<E>()
359    }
360
361    fn as_any(&self) -> &dyn Any {
362        &self.engine
363    }
364}
365
/// An extension trait that provides a convenient way to type-erase an `AsyncEngine`.
///
/// This trait provides the `.into_any_engine()` method on any `Arc<dyn AsyncEngine<...>>`,
/// enabling ergonomic type erasure without explicit wrapper construction.
///
/// ## Usage
/// ```rust,ignore
/// use crate::engine::AsAnyAsyncEngine;
///
/// let typed_engine: Arc<dyn AsyncEngine<String, String, ()>> = Arc::new(MyEngine::new());
/// let any_engine = typed_engine.into_any_engine();
/// ```
pub trait AsAnyAsyncEngine {
    /// Converts a typed `AsyncEngine` into a type-erased `AnyAsyncEngine`.
    ///
    /// Consumes `self` and returns the erased engine behind a new `Arc`.
    fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine>;
}
382
383impl<Req, Resp, E> AsAnyAsyncEngine for Arc<dyn AsyncEngine<Req, Resp, E>>
384where
385    Req: Data,
386    Resp: Data + AsyncEngineContextProvider,
387    E: Data,
388{
389    fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine> {
390        Arc::new(AnyEngineWrapper {
391            engine: self,
392            _phantom: PhantomData,
393        })
394    }
395}
396
/// An extension trait that provides a convenient method to downcast an `AnyAsyncEngine`.
///
/// This trait provides the `.downcast<Req, Resp, E>()` method on `Arc<dyn AnyAsyncEngine>`,
/// enabling safe downcasting back to the original typed engine.
///
/// ## Safety
/// The downcast method performs runtime type checking using `TypeId` comparison.
/// It will only succeed if the type parameters exactly match the original engine's types.
///
/// ## Usage
/// ```rust,ignore
/// use crate::engine::DowncastAnyAsyncEngine;
///
/// let any_engine: Arc<dyn AnyAsyncEngine> = // ... from collection
/// if let Some(typed_engine) = any_engine.downcast::<String, String, ()>() {
///     // Use the typed engine
///     let result = typed_engine.generate("hello".to_string()).await;
/// }
/// ```
pub trait DowncastAnyAsyncEngine {
    /// Attempts to downcast an `AnyAsyncEngine` to a specific `AsyncEngine` type.
    ///
    /// Returns `Some(engine)` if the type parameters match the original engine,
    /// or `None` if the types don't match. The receiver is borrowed, so the erased
    /// handle remains usable after the call.
    fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
    where
        Req: Data,
        Resp: Data + AsyncEngineContextProvider,
        E: Data;
}
427
428impl DowncastAnyAsyncEngine for Arc<dyn AnyAsyncEngine> {
429    fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
430    where
431        Req: Data,
432        Resp: Data + AsyncEngineContextProvider,
433        E: Data,
434    {
435        if self.request_type_id() == TypeId::of::<Req>()
436            && self.response_type_id() == TypeId::of::<Resp>()
437            && self.error_type_id() == TypeId::of::<E>()
438        {
439            self.as_any()
440                .downcast_ref::<Arc<dyn AsyncEngine<Req, Resp, E>>>()
441                .cloned()
442        } else {
443            None
444        }
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // --- Mock payload types used by the engine under test ---

    #[derive(Debug, PartialEq)]
    struct Req1(String);

    #[derive(Debug, PartialEq)]
    struct Resp1(String);

    // The response must be a context provider; the test never asks for the context.
    impl AsyncEngineContextProvider for Resp1 {
        fn context(&self) -> Arc<dyn AsyncEngineContext> {
            unimplemented!()
        }
    }

    #[derive(Debug)]
    struct Err1;

    // --- A second, unrelated set of types for the negative downcast ---

    #[derive(Debug)]
    struct Req2;

    #[derive(Debug)]
    struct Resp2;

    impl AsyncEngineContextProvider for Resp2 {
        fn context(&self) -> Arc<dyn AsyncEngineContext> {
            unimplemented!()
        }
    }

    // --- Mock engine that echoes the request text back ---

    struct MockEngine;

    #[async_trait]
    impl AsyncEngine<Req1, Resp1, Err1> for MockEngine {
        async fn generate(&self, request: Req1) -> Result<Resp1, Err1> {
            let Req1(text) = request;
            Ok(Resp1(format!("response to {}", text)))
        }
    }

    #[tokio::test]
    async fn test_engine_type_erasure_and_downcast() {
        // Build a typed engine and erase its type parameters.
        let erased: Arc<dyn AnyAsyncEngine> = {
            let concrete: Arc<dyn AsyncEngine<Req1, Resp1, Err1>> = Arc::new(MockEngine);
            concrete.into_any_engine()
        };

        // The erased engine must still report the original type ids.
        assert_eq!(erased.request_type_id(), TypeId::of::<Req1>());
        assert_eq!(erased.response_type_id(), TypeId::of::<Resp1>());
        assert_eq!(erased.error_type_id(), TypeId::of::<Err1>());

        // Downcasting with the original parameters succeeds...
        let recovered = erased.downcast::<Req1, Resp1, Err1>();
        assert!(recovered.is_some());

        // ...and the recovered engine is fully usable.
        let reply = recovered
            .unwrap()
            .generate(Req1("hello".to_string()))
            .await;
        assert_eq!(reply.unwrap(), Resp1("response to hello".to_string()));

        // Downcasting with mismatched parameters yields `None`.
        let mismatched = erased.downcast::<Req2, Resp2, Err1>();
        assert!(mismatched.is_none());

        // Erased engines can live in a heterogeneous registry and be recovered later.
        let registry: HashMap<String, Arc<dyn AnyAsyncEngine>> =
            HashMap::from([("mock".to_string(), erased)]);

        let looked_up = registry.get("mock").unwrap();
        let typed_again = looked_up.downcast::<Req1, Resp1, Err1>().unwrap();
        let second_reply = typed_again.generate(Req1("world".to_string())).await;
        assert_eq!(
            second_reply.unwrap(),
            Resp1("response to world".to_string())
        );
    }
}