dynamo_runtime/
engine.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Asynchronous Engine System with Type Erasure Support
17//!
18//! This module provides the core asynchronous engine abstraction for Dynamo's runtime system.
19//! It defines the `AsyncEngine` trait for streaming engines and provides sophisticated
20//! type-erasure capabilities for managing heterogeneous engine collections.
21//!
22//! ## Type Erasure Overview
23//!
24//! Type erasure is a critical feature that allows storing different `AsyncEngine` implementations
25//! with varying generic type parameters in a single collection (e.g., `HashMap<String, Arc<dyn AnyAsyncEngine>>`).
26//! This is essential for:
27//!
28//! - **Dynamic Engine Management**: Registering and retrieving engines at runtime based on configuration
29//! - **Plugin Systems**: Loading different engine implementations without compile-time knowledge
30//! - **Service Discovery**: Managing multiple engine types in a unified registry
31//!
32//! ## Implementation Details
33//!
34//! The type-erasure system uses several advanced Rust features:
35//!
36//! - **Trait Objects (`dyn Trait`)**: For runtime polymorphism without compile-time type information
37//! - **`std::any::TypeId`**: For runtime type checking during downcasting
38//! - **`std::any::Any`**: For type-erased storage and safe downcasting
39//! - **`PhantomData`**: For maintaining type relationships in generic wrappers
40//! - **Extension Traits**: For ergonomic API design without modifying existing types
41//!
42//! ## Safety Considerations
43//!
44//! ⚠️ **IMPORTANT**: The type-erasure system relies on precise type matching at runtime.
45//! When modifying these traits or their implementations:
46//!
47//! - **Never change the type ID logic** in `AnyAsyncEngine` implementations
48//! - **Maintain the blanket `Data` implementation** for all `Send + Sync + 'static` types
49//! - **Test downcasting thoroughly** when adding new engine types
50//! - **Document any changes** that affect the type-erasure behavior
51//!
52//! ## Usage Example
53//!
//! ```rust,ignore
//! use std::collections::HashMap;
//! use std::sync::Arc;
//! use crate::engine::{AnyAsyncEngine, AsAnyAsyncEngine, AsyncEngine, DowncastAnyAsyncEngine};
//!
//! // Create typed engines. `TextResponse` and `NumberResponse` are placeholder response types:
//! // every `AsyncEngine` response type must implement `AsyncEngineContextProvider`.
//! let string_engine: Arc<dyn AsyncEngine<String, TextResponse, ()>> = Arc::new(MyStringEngine::new());
//! let int_engine: Arc<dyn AsyncEngine<i32, NumberResponse, ()>> = Arc::new(MyIntEngine::new());
//!
//! // Store in heterogeneous collection
//! let mut engines: HashMap<String, Arc<dyn AnyAsyncEngine>> = HashMap::new();
//! engines.insert("string".to_string(), string_engine.into_any_engine());
//! engines.insert("int".to_string(), int_engine.into_any_engine());
//!
//! // Retrieve and downcast safely (inside an async context)
//! if let Some(typed_engine) = engines.get("string").unwrap().downcast::<String, TextResponse, ()>() {
//!     let result = typed_engine.generate("hello".to_string()).await;
//! }
//! ```
73
74use std::{
75    any::{Any, TypeId},
76    fmt::Debug,
77    future::Future,
78    marker::PhantomData,
79    pin::Pin,
80    sync::Arc,
81};
82
83pub use async_trait::async_trait;
84use futures::stream::Stream;
85
/// Marker trait for payload types used with [`AsyncEngine`]: requests, errors, and the items
/// carried by engine response streams.
///
/// Every [`Send`] + [`Sync`] + `'static` type already implements it through the blanket impl
/// below. **Never implement this trait by hand**; just satisfy the bounds.
pub trait Data: Send + Sync + 'static {}

impl<T> Data for T where T: Send + Sync + 'static {}
92
93/// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`]
94/// by associating it with a [`AsyncEngineContext`].
95pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send>>;
96pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
97
98pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>;
99pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>;
100pub type EngineStream<Resp> = Pin<Box<dyn AsyncEngineStream<Resp>>>;
101pub type Context = Arc<dyn AsyncEngineContext>;
102
103impl<T: Data> From<EngineStream<T>> for DataStream<T> {
104    fn from(stream: EngineStream<T>) -> Self {
105        Box::pin(stream)
106    }
107}
108
/// Placeholder trait for stream controllers; it currently declares no methods.
///
/// The control operations (`stop_generating`, `stop`, `kill`, `link_child`) currently live in
/// the `// Controller` section of [`AsyncEngineContext`].
///
/// NOTE(review): the original comment read "The Controller and the Context when
/// https://github.com/rust-lang/rust/issues/65991 becomes stable". That issue tracks trait
/// upcasting, so the intent is presumably to move the controller methods into this trait once
/// a `dyn AsyncEngineContext` can be upcast to a `dyn AsyncEngineController`. Confirm.
pub trait AsyncEngineController: Send + Sync {}
111
112/// The [`AsyncEngineContext`] trait defines the interface to control the resulting stream
113/// produced by the engine.
114///
115/// This trait provides lifecycle management for async operations, including:
116/// - Stream identification via unique IDs
117/// - Graceful shutdown capabilities (`stop_generating`)
118/// - Immediate termination capabilities (`kill`)
119/// - Status checking for stopped/killed states
120///
121/// Implementations should ensure thread-safety and proper state management
122/// across concurrent access patterns.
123#[async_trait]
124pub trait AsyncEngineContext: Send + Sync + Debug {
125    /// Unique ID for the Stream
126    fn id(&self) -> &str;
127
128    /// Returns true if `stop_generating()` has been called; otherwise, false.
129    fn is_stopped(&self) -> bool;
130
131    /// Returns true if `kill()` has been called; otherwise, false.
132    /// This can be used with a `.take_while()` stream combinator to immediately terminate
133    /// the stream.
134    ///
135    /// An ideal location for a `[.take_while(!ctx.is_killed())]` stream combinator is on
136    /// the most downstream  return stream.
137    fn is_killed(&self) -> bool;
138
139    /// Calling this method when [`AsyncEngineContext::is_stopped`] is `true` will return
140    /// immediately; otherwise, it will [`AsyncEngineContext::is_stopped`] will return true.
141    async fn stopped(&self);
142
143    /// Calling this method when [`AsyncEngineContext::is_killed`] is `true` will return
144    /// immediately; otherwise, it will [`AsyncEngineContext::is_killed`] will return true.
145    async fn killed(&self);
146
147    // Controller
148
149    /// Informs the [`AsyncEngine`] to stop producing results for this particular stream.
150    /// This method is idempotent. This method does not invalidate results current in the
151    /// stream. It might take some time for the engine to stop producing results. The caller
152    /// can decided to drain the stream or drop the stream.
153    fn stop_generating(&self);
154
155    /// See [`AsyncEngineContext::stop_generating`].
156    fn stop(&self);
157
158    /// Extends the [`AsyncEngineContext::stop_generating`] also indicates a preference to
159    /// terminate without draining the remaining items in the stream. This is implementation
160    /// specific and may not be supported by all engines.
161    fn kill(&self);
162
163    /// Links child AsyncEngineContext to this AsyncEngineContext. If the `stop_generating`, `stop`
164    /// or `kill` on this AsyncEngineContext is called, the same method is called on all linked
165    /// child AsyncEngineContext, in the order they are linked, and then the method on this
166    /// AsyncEngineContext continues.
167    fn link_child(&self, child: Arc<dyn AsyncEngineContext>);
168}
169
170/// Provides access to the [`AsyncEngineContext`] associated with an engine operation.
171///
172/// This trait is implemented by both unary and streaming engine results, allowing
173/// uniform access to context information regardless of the operation type.
174pub trait AsyncEngineContextProvider: Send + Debug {
175    fn context(&self) -> Arc<dyn AsyncEngineContext>;
176}
177
178/// A unary (single-response) asynchronous engine operation.
179///
180/// This trait combines `Future` semantics with context provider capabilities,
181/// representing a single async operation that produces one result.
182pub trait AsyncEngineUnary<Resp: Data>:
183    Future<Output = Resp> + AsyncEngineContextProvider + Send
184{
185}
186
187/// A streaming asynchronous engine operation.
188///
189/// This trait combines `Stream` semantics with context provider capabilities,
190/// representing a continuous async operation that produces multiple results over time.
191pub trait AsyncEngineStream<Resp: Data>:
192    Stream<Item = Resp> + AsyncEngineContextProvider + Send
193{
194}
195
196/// Engine is a trait that defines the interface for a streaming engine.
197/// The synchronous Engine version is does not need to be awaited.
198///
199/// This is the core trait for all async engine implementations. It provides:
200/// - Generic type parameters for request, response, and error types
201/// - Async generation capabilities with proper error handling
202/// - Thread-safe design with `Send + Sync` bounds
203///
204/// ## Type Parameters
205/// - `Req`: The request type that implements `Data`
206/// - `Resp`: The response type that implements both `Data` and `AsyncEngineContextProvider`
207/// - `E`: The error type that implements `Data`
208///
209/// ## Implementation Notes
210/// Implementations should ensure proper error handling and resource management.
211/// The `generate` method should be cancellable via the response's context provider.
212#[async_trait]
213pub trait AsyncEngine<Req: Send + Sync + 'static, Resp: AsyncEngineContextProvider, E: Data>:
214    Send + Sync
215{
216    /// Generate a stream of completion responses.
217    async fn generate(&self, request: Req) -> Result<Resp, E>;
218}
219
220/// Adapter for a [`DataStream`] to a [`ResponseStream`].
221///
222/// A common pattern is to consume the [`ResponseStream`] with standard stream combinators
223/// which produces a [`DataStream`] stream, then form a [`ResponseStream`] by propagating the
224/// original [`AsyncEngineContext`].
225pub struct ResponseStream<R: Data> {
226    stream: DataStream<R>,
227    ctx: Arc<dyn AsyncEngineContext>,
228}
229
230impl<R: Data> ResponseStream<R> {
231    pub fn new(stream: DataStream<R>, ctx: Arc<dyn AsyncEngineContext>) -> Pin<Box<Self>> {
232        Box::pin(Self { stream, ctx })
233    }
234}
235
236impl<R: Data> Stream for ResponseStream<R> {
237    type Item = R;
238
239    #[inline]
240    fn poll_next(
241        mut self: Pin<&mut Self>,
242        cx: &mut std::task::Context<'_>,
243    ) -> std::task::Poll<Option<Self::Item>> {
244        Pin::new(&mut self.stream).poll_next(cx)
245    }
246}
247
248impl<R: Data> AsyncEngineStream<R> for ResponseStream<R> {}
249
250impl<R: Data> AsyncEngineContextProvider for ResponseStream<R> {
251    fn context(&self) -> Arc<dyn AsyncEngineContext> {
252        self.ctx.clone()
253    }
254}
255
256impl<R: Data> Debug for ResponseStream<R> {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.debug_struct("ResponseStream")
259            // todo: add debug for stream - possibly propagate some information about what
260            // engine created the stream
261            // .field("stream", &self.stream)
262            .field("ctx", &self.ctx)
263            .finish()
264    }
265}
266
267impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineUnary<T>>> {
268    fn context(&self) -> Arc<dyn AsyncEngineContext> {
269        AsyncEngineContextProvider::context(&**self)
270    }
271}
272
273impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineStream<T>>> {
274    fn context(&self) -> Arc<dyn AsyncEngineContext> {
275        AsyncEngineContextProvider::context(&**self)
276    }
277}
278
/// A type-erased `AsyncEngine`.
///
/// This trait lets heterogeneous `AsyncEngine` implementations live in one collection. It
/// erases their generic type parameters but keeps enough runtime type information to recover
/// the typed engine.
///
/// ## Type Erasure Mechanism
/// The request, response, and error types are recorded as [`TypeId`]s, and the typed engine
/// handle is exposed as [`Any`]. `DowncastAnyAsyncEngine::downcast` first compares the type
/// IDs. It then downcasts the `Any` value back to `Arc<dyn AsyncEngine<Req, Resp, E>>`.
///
/// ## Safety Guarantees
/// - Type IDs are preserved exactly as they were during type erasure
/// - Downcasting is only possible to the original type combination
/// - Incorrect downcasts return `None` rather than panicking
///
/// ## Implementation Notes
/// This trait is implemented by the internal `AnyEngineWrapper` struct, which is created
/// through the `AsAnyAsyncEngine::into_any_engine` extension method. Users should not
/// implement this trait directly. `downcast` relies on `as_any` returning exactly the
/// `Arc<dyn AsyncEngine<Req, Resp, E>>` described by the three type IDs.
pub trait AnyAsyncEngine: Send + Sync {
    /// Returns the `TypeId` of the request type (`Req`) used by this engine.
    fn request_type_id(&self) -> TypeId;

    /// Returns the `TypeId` of the response type (`Resp`) used by this engine.
    fn response_type_id(&self) -> TypeId;

    /// Returns the `TypeId` of the error type (`E`) used by this engine.
    fn error_type_id(&self) -> TypeId;

    /// Exposes the stored engine handle as `dyn Any` for downcasting.
    ///
    /// For engines erased with `into_any_engine`, the `Any` value is the
    /// `Arc<dyn AsyncEngine<Req, Resp, E>>` handle, **not** the concrete engine type. Downcast
    /// to that `Arc` type, as `DowncastAnyAsyncEngine::downcast` does. Downcasting to the
    /// concrete engine struct returns `None`.
    fn as_any(&self) -> &dyn Any;
}
310
311/// An internal wrapper to hold a typed `AsyncEngine` behind the `AnyAsyncEngine` trait object.
312///
313/// This struct uses `PhantomData<fn(Req, Resp, E)>` to maintain the type relationship
314/// without storing the types directly, enabling the type-erasure mechanism.
315///
316/// ## PhantomData Usage
317/// The `PhantomData<fn(Req, Resp, E)>` ensures that the compiler knows about the
318/// generic type parameters without requiring them to be `'static`, which would
319/// prevent storing non-static types in the engine.
320struct AnyEngineWrapper<Req, Resp, E>
321where
322    Req: Data,
323    Resp: Data + AsyncEngineContextProvider,
324    E: Data,
325{
326    engine: Arc<dyn AsyncEngine<Req, Resp, E>>,
327    _phantom: PhantomData<fn(Req, Resp, E)>,
328}
329
330impl<Req, Resp, E> AnyAsyncEngine for AnyEngineWrapper<Req, Resp, E>
331where
332    Req: Data,
333    Resp: Data + AsyncEngineContextProvider,
334    E: Data,
335{
336    fn request_type_id(&self) -> TypeId {
337        TypeId::of::<Req>()
338    }
339
340    fn response_type_id(&self) -> TypeId {
341        TypeId::of::<Resp>()
342    }
343
344    fn error_type_id(&self) -> TypeId {
345        TypeId::of::<E>()
346    }
347
348    fn as_any(&self) -> &dyn Any {
349        &self.engine
350    }
351}
352
353/// An extension trait that provides a convenient way to type-erase an `AsyncEngine`.
354///
355/// This trait provides the `.into_any_engine()` method on any `Arc<dyn AsyncEngine<...>>`,
356/// enabling ergonomic type erasure without explicit wrapper construction.
357///
358/// ## Usage
359/// ```rust,ignore
360/// use crate::engine::AsAnyAsyncEngine;
361///
362/// let typed_engine: Arc<dyn AsyncEngine<String, String, ()>> = Arc::new(MyEngine::new());
363/// let any_engine = typed_engine.into_any_engine();
364/// ```
365pub trait AsAnyAsyncEngine {
366    /// Converts a typed `AsyncEngine` into a type-erased `AnyAsyncEngine`.
367    fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine>;
368}
369
370impl<Req, Resp, E> AsAnyAsyncEngine for Arc<dyn AsyncEngine<Req, Resp, E>>
371where
372    Req: Data,
373    Resp: Data + AsyncEngineContextProvider,
374    E: Data,
375{
376    fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine> {
377        Arc::new(AnyEngineWrapper {
378            engine: self,
379            _phantom: PhantomData,
380        })
381    }
382}
383
384/// An extension trait that provides a convenient method to downcast an `AnyAsyncEngine`.
385///
386/// This trait provides the `.downcast<Req, Resp, E>()` method on `Arc<dyn AnyAsyncEngine>`,
387/// enabling safe downcasting back to the original typed engine.
388///
389/// ## Safety
390/// The downcast method performs runtime type checking using `TypeId` comparison.
391/// It will only succeed if the type parameters exactly match the original engine's types.
392///
393/// ## Usage
394/// ```rust,ignore
395/// use crate::engine::DowncastAnyAsyncEngine;
396///
397/// let any_engine: Arc<dyn AnyAsyncEngine> = // ... from collection
398/// if let Some(typed_engine) = any_engine.downcast::<String, String, ()>() {
399///     // Use the typed engine
400///     let result = typed_engine.generate("hello".to_string()).await;
401/// }
402/// ```
403pub trait DowncastAnyAsyncEngine {
404    /// Attempts to downcast an `AnyAsyncEngine` to a specific `AsyncEngine` type.
405    ///
406    /// Returns `Some(engine)` if the type parameters match the original engine,
407    /// or `None` if the types don't match.
408    fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
409    where
410        Req: Data,
411        Resp: Data + AsyncEngineContextProvider,
412        E: Data;
413}
414
415impl DowncastAnyAsyncEngine for Arc<dyn AnyAsyncEngine> {
416    fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
417    where
418        Req: Data,
419        Resp: Data + AsyncEngineContextProvider,
420        E: Data,
421    {
422        if self.request_type_id() == TypeId::of::<Req>()
423            && self.response_type_id() == TypeId::of::<Resp>()
424            && self.error_type_id() == TypeId::of::<E>()
425        {
426            self.as_any()
427                .downcast_ref::<Arc<dyn AsyncEngine<Req, Resp, E>>>()
428                .cloned()
429        } else {
430            None
431        }
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Request/response/error types the mock engine is built for.
    #[derive(Debug, PartialEq)]
    struct Req1(String);

    #[derive(Debug, PartialEq)]
    struct Resp1(String);

    // The tests never ask a response for its context, so this stays unimplemented.
    impl AsyncEngineContextProvider for Resp1 {
        fn context(&self) -> Arc<dyn AsyncEngineContext> {
            unimplemented!()
        }
    }

    #[derive(Debug)]
    struct Err1;

    // A second, unrelated set of types used to exercise failing downcasts.
    #[derive(Debug)]
    struct Req2;

    #[derive(Debug)]
    struct Resp2;

    impl AsyncEngineContextProvider for Resp2 {
        fn context(&self) -> Arc<dyn AsyncEngineContext> {
            unimplemented!()
        }
    }

    /// Engine that echoes the request text back with a fixed prefix.
    struct MockEngine;

    #[async_trait]
    impl AsyncEngine<Req1, Resp1, Err1> for MockEngine {
        async fn generate(&self, request: Req1) -> Result<Resp1, Err1> {
            let Req1(text) = request;
            Ok(Resp1(format!("response to {}", text)))
        }
    }

    #[tokio::test]
    async fn test_engine_type_erasure_and_downcast() {
        let typed: Arc<dyn AsyncEngine<Req1, Resp1, Err1>> = Arc::new(MockEngine);
        let erased = typed.into_any_engine();

        // Erasure must remember the exact request/response/error types.
        assert_eq!(erased.request_type_id(), TypeId::of::<Req1>());
        assert_eq!(erased.response_type_id(), TypeId::of::<Resp1>());
        assert_eq!(erased.error_type_id(), TypeId::of::<Err1>());

        // Downcasting to the original types succeeds and yields a working engine.
        let recovered = erased.downcast::<Req1, Resp1, Err1>();
        assert!(recovered.is_some());
        let reply = recovered
            .unwrap()
            .generate(Req1("hello".to_string()))
            .await;
        assert_eq!(reply.unwrap(), Resp1("response to hello".to_string()));

        // Downcasting to any other type combination fails cleanly.
        assert!(erased.downcast::<Req2, Resp2, Err1>().is_none());

        // The erased handle can live in a heterogeneous map and be recovered from it.
        let mut registry: HashMap<String, Arc<dyn AnyAsyncEngine>> = HashMap::new();
        registry.insert("mock".to_string(), erased);
        let from_registry = registry
            .get("mock")
            .unwrap()
            .downcast::<Req1, Resp1, Err1>()
            .unwrap();
        let reply = from_registry.generate(Req1("world".to_string())).await;
        assert_eq!(reply.unwrap(), Resp1("response to world".to_string()));
    }
}