dynamo_runtime/pipeline/
context.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::ops::{Deref, DerefMut};
5use std::sync::{Arc, Mutex};
6
7use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
8use crate::engine::AsyncEngineController;
9use async_trait::async_trait;
10
11use super::registry::Registry;
12
/// A payload flowing through a pipeline, bundled with request-scoped state.
///
/// `Context<T>` dereferences to `T`, so the payload can be used directly. The
/// controller, registry and recorded stages are carried across payload type
/// changes by `transfer`, `map`, `try_map` and `rejoin`.
pub struct Context<T: Data> {
    /// The payload currently carried by this context.
    current: T,
    /// Identity and stop/kill state, held behind an `Arc` so it can be shared
    /// with a `StreamContext` built from this context and with callers of
    /// `context()`.
    controller: Arc<Controller>,
    /// Side-channel storage; values are inserted under a key and read back by
    /// key and type (a type mismatch is reported as an error).
    registry: Registry,
    /// Names of pipeline stages recorded with `add_stage`.
    stages: Vec<String>,
}
19
20impl<T: Send + Sync + 'static> Context<T> {
21    // Create a new context with initial data
22    pub fn new(current: T) -> Self {
23        Context {
24            current,
25            controller: Arc::new(Controller::default()),
26            registry: Registry::new(),
27            stages: Vec::new(),
28        }
29    }
30
31    pub fn rejoin<U: Send + Sync + 'static>(current: T, context: Context<U>) -> Self {
32        Context {
33            current,
34            controller: context.controller,
35            registry: context.registry,
36            stages: context.stages,
37        }
38    }
39
40    pub fn with_controller(current: T, controller: Controller) -> Self {
41        Context {
42            current,
43            controller: Arc::new(controller),
44            registry: Registry::new(),
45            stages: Vec::new(),
46        }
47    }
48
49    pub fn with_id(current: T, id: String) -> Self {
50        Context {
51            current,
52            controller: Arc::new(Controller::new(id)),
53            registry: Registry::new(),
54            stages: Vec::new(),
55        }
56    }
57
58    /// Get the id of the context
59    pub fn id(&self) -> &str {
60        self.controller.id()
61    }
62
63    /// Get the content of the context
64    pub fn content(&self) -> &T {
65        &self.current
66    }
67
68    pub fn controller(&self) -> &Controller {
69        &self.controller
70    }
71
72    /// Insert an object into the registry with a specific key.
73    pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
74        self.registry.insert_shared(key, value);
75    }
76
77    /// Insert a unique and takable object into the registry with a specific key.
78    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
79        self.registry.insert_unique(key, value);
80    }
81
82    /// Retrieve an object from the registry by key and type.
83    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
84        self.registry.get_shared(key)
85    }
86
87    /// Clone a unique object from the registry by key and type.
88    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
89        self.registry.clone_unique(key)
90    }
91
92    /// Take a unique object from the registry by key and type.
93    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
94        self.registry.take_unique(key)
95    }
96
97    /// Transfer the Context to a new Object without updating the registry
98    /// This returns a tuple of the previous object and the new Context
99    pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
100        (
101            self.current,
102            Context {
103                current: new_current,
104                controller: self.controller,
105                registry: self.registry,
106                stages: self.stages,
107            },
108        )
109    }
110
111    /// Separate out the current object and context
112    pub fn into_parts(self) -> (T, Context<()>) {
113        self.transfer(())
114    }
115
116    pub fn stages(&self) -> &Vec<String> {
117        &self.stages
118    }
119
120    pub fn add_stage(&mut self, stage: &str) {
121        self.stages.push(stage.to_string());
122    }
123
124    /// Transforms the current context to another type using a provided function.
125    pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
126    where
127        F: FnOnce(T) -> U,
128    {
129        // Use the transfer method to move the current value out
130        let (current, temp_context) = self.transfer(());
131
132        // Apply the transformation function to the current value
133        let new_current = f(current);
134
135        // Use transfer again to create the new context with the transformed type
136        temp_context.transfer(new_current).1
137    }
138
139    pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
140    where
141        F: FnOnce(T) -> Result<U, E>,
142        U: Send + Sync + 'static,
143    {
144        // Use the transfer method to move the current value out
145        let (current, temp_context) = self.transfer(());
146
147        // Apply the transformation function to the current value
148        let new_current = f(current)?;
149
150        // Use transfer again to create the new context with the transformed type
151        Ok(temp_context.transfer(new_current).1)
152    }
153}
154
155impl<T: Data> std::fmt::Debug for Context<T> {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("Context")
158            .field("id", &self.controller.id())
159            .finish()
160    }
161}
162
163// Implement Deref to allow Context<T> to act like &T
164impl<T: Data> Deref for Context<T> {
165    type Target = T;
166
167    fn deref(&self) -> &Self::Target {
168        &self.current
169    }
170}
171
172// Implement DerefMut to allow Context<T> to act like &mut T
173impl<T: Data> DerefMut for Context<T> {
174    fn deref_mut(&mut self) -> &mut Self::Target {
175        &mut self.current
176    }
177}
178
179// Implement the custom trait for Context<T>
180impl<T> From<T> for Context<T>
181where
182    T: Send + Sync + 'static,
183{
184    fn from(current: T) -> Self {
185        Context::new(current)
186    }
187}
188
189// Define a custom trait for conversion from Context<T> to Context<U>
190pub trait IntoContext<U: Data> {
191    fn into_context(self) -> Context<U>;
192}
193
194// Implement the custom trait for converting Context<T> to Context<U>
195impl<T, U> IntoContext<U> for Context<T>
196where
197    T: Send + Sync + 'static + Into<U>,
198    U: Send + Sync + 'static,
199{
200    fn into_context(self) -> Context<U> {
201        self.map(|current| current.into())
202    }
203}
204
205impl<T: Data> AsyncEngineContextProvider for Context<T> {
206    fn context(&self) -> Arc<dyn AsyncEngineContext> {
207        self.controller.clone()
208    }
209}
210
/// Request-scoped state without a payload.
///
/// When built from a `Context` via `From`, it shares that context's
/// controller (and therefore its id and stop/kill state) and takes over its
/// registry.
#[derive(Debug, Clone)]
pub struct StreamContext {
    /// Shared identity and stop/kill controller.
    controller: Arc<Controller>,
    /// Registry held behind an `Arc`, so clones of this value share it.
    registry: Arc<Registry>,
    /// Stages recorded via `add_stage`.
    ///
    /// This list starts empty: `StreamContext::new` does not copy the
    /// originating context's stages.
    stages: Vec<String>,
}
217
218impl StreamContext {
219    fn new(controller: Arc<Controller>, registry: Registry) -> Self {
220        StreamContext {
221            controller,
222            registry: Arc::new(registry),
223            stages: Vec::new(),
224        }
225    }
226
227    /// Retrieve an object from the registry by key and type.
228    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
229        self.registry.get_shared(key)
230    }
231
232    /// Clone a unique object from the registry by key and type.
233    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
234        self.registry.clone_unique(key)
235    }
236
237    pub fn registry(&self) -> Arc<Registry> {
238        self.registry.clone()
239    }
240
241    pub fn stages(&self) -> &Vec<String> {
242        &self.stages
243    }
244
245    pub fn add_stage(&mut self, stage: &str) {
246        self.stages.push(stage.to_string());
247    }
248}
249
250#[async_trait]
251impl AsyncEngineContext for StreamContext {
252    fn id(&self) -> &str {
253        self.controller.id()
254    }
255
256    fn stop(&self) {
257        self.controller.stop();
258    }
259
260    fn kill(&self) {
261        self.controller.kill();
262    }
263
264    fn stop_generating(&self) {
265        self.controller.stop_generating();
266    }
267
268    fn is_stopped(&self) -> bool {
269        self.controller.is_stopped()
270    }
271
272    fn is_killed(&self) -> bool {
273        self.controller.is_killed()
274    }
275
276    async fn stopped(&self) {
277        self.controller.stopped().await
278    }
279
280    async fn killed(&self) {
281        self.controller.killed().await
282    }
283
284    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
285        self.controller.link_child(child);
286    }
287}
288
289impl AsyncEngineContextProvider for StreamContext {
290    fn context(&self) -> Arc<dyn AsyncEngineContext> {
291        self.controller.clone()
292    }
293}
294
295impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
296    fn from(value: Context<T>) -> Self {
297        StreamContext::new(value.controller, value.registry)
298    }
299}
300
301// TODO - refactor here - this came from the dynamo.llm-async-engine crate
302
303use tokio::sync::watch::{Receiver, Sender, channel};
304
/// Lifecycle state published on a [`Controller`]'s watch channel.
///
/// "Stopped" is tested as `!= Live` and "killed" as `== Killed`, so `Killed`
/// also counts as stopped.
#[derive(Debug, Eq, PartialEq)]
enum State {
    /// Initial state: no stop, stop_generating or kill has been requested yet.
    Live,
    /// Published by `stop` and `stop_generating`.
    Stopped,
    /// Published by `kill`.
    Killed,
}
311
/// A context implementation with cancellation propagation.
///
/// The current [`State`] is published through a `tokio::sync::watch` channel.
/// Stop/kill signals are forwarded to every child registered with
/// `link_child`.
#[derive(Debug)]
pub struct Controller {
    /// Identifier returned by `id()`.
    id: String,
    /// Publishes state transitions.
    tx: Sender<State>,
    /// Receiver owned by the controller itself.
    ///
    /// It is read by `is_stopped`/`is_killed` and cloned by
    /// `stopped`/`killed`. Keeping it here also means the channel always has
    /// at least one receiver.
    rx: Receiver<State>,
    /// Children registered via `link_child`; stop/kill signals are forwarded to them.
    child_context: Mutex<Vec<Arc<dyn AsyncEngineContext>>>,
}
320
321impl Controller {
322    pub fn new(id: String) -> Self {
323        let (tx, rx) = channel(State::Live);
324        Self {
325            id,
326            tx,
327            rx,
328            child_context: Mutex::new(Vec::new()),
329        }
330    }
331
332    pub fn id(&self) -> &str {
333        &self.id
334    }
335}
336
337impl Default for Controller {
338    fn default() -> Self {
339        Self::new(uuid::Uuid::new_v4().to_string())
340    }
341}
342
// Controller opts into `AsyncEngineController` without providing any items here.
// NOTE(review): presumably a marker/extension trait whose items (if any) all have
// defaults — confirm in `crate::engine`.
impl AsyncEngineController for Controller {}
344
345#[async_trait]
346impl AsyncEngineContext for Controller {
347    fn id(&self) -> &str {
348        &self.id
349    }
350
351    fn is_stopped(&self) -> bool {
352        *self.rx.borrow() != State::Live
353    }
354
355    fn is_killed(&self) -> bool {
356        *self.rx.borrow() == State::Killed
357    }
358
359    async fn stopped(&self) {
360        let mut rx = self.rx.clone();
361        loop {
362            if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
363                return;
364            }
365        }
366    }
367
368    async fn killed(&self) {
369        let mut rx = self.rx.clone();
370        loop {
371            if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
372                return;
373            }
374        }
375    }
376
377    fn stop_generating(&self) {
378        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
379        let children = self
380            .child_context
381            .lock()
382            .expect("Failed to lock child context")
383            .iter()
384            .cloned()
385            .collect::<Vec<_>>();
386        for child in children {
387            child.stop_generating();
388        }
389
390        let _ = self.tx.send(State::Stopped);
391    }
392
393    fn stop(&self) {
394        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
395        let children = self
396            .child_context
397            .lock()
398            .expect("Failed to lock child context")
399            .iter()
400            .cloned()
401            .collect::<Vec<_>>();
402        for child in children {
403            child.stop();
404        }
405
406        let _ = self.tx.send(State::Stopped);
407    }
408
409    fn kill(&self) {
410        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
411        let children = self
412            .child_context
413            .lock()
414            .expect("Failed to lock child context")
415            .iter()
416            .cloned()
417            .collect::<Vec<_>>();
418        for child in children {
419            child.kill();
420        }
421
422        let _ = self.tx.send(State::Killed);
423    }
424
425    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
426        self.child_context
427            .lock()
428            .expect("Failed to lock child context")
429            .push(child);
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            Processed {
                length: input.value.len(),
            }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            Final {
                message: format!("Processed length: {}", processed.length),
            }
        }
    }

    fn hello() -> Input {
        Input {
            value: "Hello".to_string(),
        }
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = Context::new(hello());

        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        assert_eq!(*ctx.get::<i32>("key1").unwrap(), 42);
        assert_eq!(*ctx.get::<String>("key2").unwrap(), "some data");
        assert!(ctx.get::<f64>("key1").is_err()); // Testing a downcast failure
    }

    #[test]
    fn test_transfer() {
        let ctx = Context::new(hello());

        let (input, ctx) = ctx.transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let ctx = Context::new(hello());

        let ctx: Context<Processed> = ctx.map(|input| input.into());
        let ctx: Context<Final> = ctx.map(|processed| processed.into());

        assert_eq!(ctx.current.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let ctx = Context::new(hello());

        let ctx: Context<Processed> = ctx.into_context();
        let ctx: Context<Final> = ctx.into_context();

        assert_eq!(ctx.current.message, "Processed length: 5");
    }

    #[test]
    fn test_try_map_ok_and_err() {
        let ctx = Context::with_id(hello(), "req-try".to_string());

        let ctx: Context<Processed> = ctx
            .try_map(|input| Ok::<Processed, String>(input.into()))
            .unwrap();
        assert_eq!(ctx.length, 5);
        assert_eq!(ctx.id(), "req-try");

        let err = ctx
            .try_map(|_processed| Err::<Final, String>("boom".to_string()))
            .unwrap_err();
        assert_eq!(err, "boom");
    }

    #[test]
    fn test_rejoin_preserves_controller() {
        let ctx = Context::with_id(hello(), "req-1".to_string());

        let (input, rest) = ctx.into_parts();
        let ctx = Context::rejoin(Processed::from(input), rest);

        assert_eq!(ctx.id(), "req-1");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_stages_survive_conversions() {
        let mut ctx = Context::new(hello());
        ctx.add_stage("preprocess");

        let mut ctx: Context<Processed> = ctx.into_context();
        ctx.add_stage("route");

        assert_eq!(
            ctx.stages(),
            &vec!["preprocess".to_string(), "route".to_string()]
        );
    }

    #[test]
    fn test_stream_context_shares_controller_and_registry() {
        let mut ctx = Context::with_id(hello(), "req-2".to_string());
        ctx.insert("answer", 42u32);
        let handle = ctx.context();

        let stream = StreamContext::from(ctx);
        assert_eq!(stream.id(), "req-2");
        assert_eq!(*stream.get::<u32>("answer").unwrap(), 42);

        stream.stop();
        assert!(stream.is_stopped());
        assert!(handle.is_stopped());
    }

    #[test]
    fn test_controller_state_and_child_propagation() {
        let parent = Controller::new("parent".to_string());
        let child = Arc::new(Controller::default());
        parent.link_child(child.clone());

        assert_eq!(parent.id(), "parent");
        assert!(!parent.is_stopped());
        assert!(!child.is_stopped());

        parent.stop();
        assert!(parent.is_stopped() && !parent.is_killed());
        assert!(child.is_stopped() && !child.is_killed());

        parent.kill();
        assert!(parent.is_stopped() && parent.is_killed());
        assert!(child.is_killed());
    }
}