Skip to main content

dynamo_runtime/pipeline/
context.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::BTreeMap;
5use std::ops::{Deref, DerefMut};
6use std::sync::{Arc, Mutex};
7
8use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
9use crate::engine::AsyncEngineController;
10use async_trait::async_trait;
11
12use super::registry::Registry;
13
14pub struct Context<T: Data> {
15    current: T,
16    controller: Arc<Controller>, //todo: hold this as an arc
17    registry: Registry,
18    stages: Vec<String>,
19    metadata: BTreeMap<String, String>,
20}
21
22impl<T: Send + Sync + 'static> Context<T> {
23    // Create a new context with initial data
24    pub fn new(current: T) -> Self {
25        Context {
26            current,
27            controller: Arc::new(Controller::default()),
28            registry: Registry::new(),
29            stages: Vec::new(),
30            metadata: BTreeMap::new(),
31        }
32    }
33
34    pub fn rejoin<U: Send + Sync + 'static>(current: T, context: Context<U>) -> Self {
35        Context {
36            current,
37            controller: context.controller,
38            registry: context.registry,
39            stages: context.stages,
40            metadata: context.metadata,
41        }
42    }
43
44    pub fn with_controller(current: T, controller: Controller) -> Self {
45        Context {
46            current,
47            controller: Arc::new(controller),
48            registry: Registry::new(),
49            stages: Vec::new(),
50            metadata: BTreeMap::new(),
51        }
52    }
53
54    #[deprecated(
55        since = "1.1.2",
56        note = "Use `Context::with_id_and_metadata` instead; pass `Default::default()` \
57                when you have no metadata to propagate. `with_id` will be removed once \
58                all call sites have been migrated."
59    )]
60    pub fn with_id(current: T, id: String) -> Self {
61        Context {
62            current,
63            controller: Arc::new(Controller::new(id)),
64            registry: Registry::new(),
65            stages: Vec::new(),
66            metadata: BTreeMap::new(),
67        }
68    }
69
70    pub fn with_id_and_metadata(
71        current: T,
72        id: String,
73        metadata: BTreeMap<String, String>,
74    ) -> Self {
75        Context {
76            current,
77            controller: Arc::new(Controller::new(id)),
78            registry: Registry::new(),
79            stages: Vec::new(),
80            metadata,
81        }
82    }
83
84    /// Get the id of the context
85    pub fn id(&self) -> &str {
86        self.controller.id()
87    }
88
89    /// Get the content of the context
90    pub fn content(&self) -> &T {
91        &self.current
92    }
93
94    pub fn controller(&self) -> &Controller {
95        &self.controller
96    }
97
98    pub fn metadata(&self) -> &BTreeMap<String, String> {
99        &self.metadata
100    }
101
102    pub fn metadata_mut(&mut self) -> &mut BTreeMap<String, String> {
103        &mut self.metadata
104    }
105
106    pub fn set_metadata(&mut self, metadata: BTreeMap<String, String>) {
107        self.metadata = metadata;
108    }
109
110    pub fn insert_metadata<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
111        self.metadata.insert(key.into(), value.into());
112    }
113
114    /// Insert an object into the registry with a specific key.
115    pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
116        self.registry.insert_shared(key, value);
117    }
118
119    /// Insert a unique and takable object into the registry with a specific key.
120    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
121        self.registry.insert_unique(key, value);
122    }
123
124    /// Retrieve an object from the registry by key and type.
125    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
126        self.registry.get_shared(key)
127    }
128
129    /// Clone a unique object from the registry by key and type.
130    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
131        self.registry.clone_unique(key)
132    }
133
134    /// Take a unique object from the registry by key and type.
135    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
136        self.registry.take_unique(key)
137    }
138
139    /// Transfer the Context to a new Object without updating the registry
140    /// This returns a tuple of the previous object and the new Context
141    pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
142        (
143            self.current,
144            Context {
145                current: new_current,
146                controller: self.controller,
147                registry: self.registry,
148                stages: self.stages,
149                metadata: self.metadata,
150            },
151        )
152    }
153
154    /// Separate out the current object and context
155    pub fn into_parts(self) -> (T, Context<()>) {
156        self.transfer(())
157    }
158
159    pub fn stages(&self) -> &Vec<String> {
160        &self.stages
161    }
162
163    pub fn add_stage(&mut self, stage: &str) {
164        self.stages.push(stage.to_string());
165    }
166
167    /// Transforms the current context to another type using a provided function.
168    pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
169    where
170        F: FnOnce(T) -> U,
171    {
172        // Use the transfer method to move the current value out
173        let (current, temp_context) = self.transfer(());
174
175        // Apply the transformation function to the current value
176        let new_current = f(current);
177
178        // Use transfer again to create the new context with the transformed type
179        temp_context.transfer(new_current).1
180    }
181
182    pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
183    where
184        F: FnOnce(T) -> Result<U, E>,
185        U: Send + Sync + 'static,
186    {
187        // Use the transfer method to move the current value out
188        let (current, temp_context) = self.transfer(());
189
190        // Apply the transformation function to the current value
191        let new_current = f(current)?;
192
193        // Use transfer again to create the new context with the transformed type
194        Ok(temp_context.transfer(new_current).1)
195    }
196}
197
198impl<T: Data> std::fmt::Debug for Context<T> {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("Context")
201            .field("id", &self.controller.id())
202            .finish()
203    }
204}
205
206// Implement Deref to allow Context<T> to act like &T
207impl<T: Data> Deref for Context<T> {
208    type Target = T;
209
210    fn deref(&self) -> &Self::Target {
211        &self.current
212    }
213}
214
215// Implement DerefMut to allow Context<T> to act like &mut T
216impl<T: Data> DerefMut for Context<T> {
217    fn deref_mut(&mut self) -> &mut Self::Target {
218        &mut self.current
219    }
220}
221
222// Implement the custom trait for Context<T>
223impl<T> From<T> for Context<T>
224where
225    T: Send + Sync + 'static,
226{
227    fn from(current: T) -> Self {
228        Context::new(current)
229    }
230}
231
232// Define a custom trait for conversion from Context<T> to Context<U>
233pub trait IntoContext<U: Data> {
234    fn into_context(self) -> Context<U>;
235}
236
237// Implement the custom trait for converting Context<T> to Context<U>
238impl<T, U> IntoContext<U> for Context<T>
239where
240    T: Send + Sync + 'static + Into<U>,
241    U: Send + Sync + 'static,
242{
243    fn into_context(self) -> Context<U> {
244        self.map(|current| current.into())
245    }
246}
247
248impl<T: Data> AsyncEngineContextProvider for Context<T> {
249    fn context(&self) -> Arc<dyn AsyncEngineContext> {
250        self.controller.clone()
251    }
252}
253
254#[derive(Debug, Clone)]
255pub struct StreamContext {
256    controller: Arc<Controller>,
257    registry: Arc<Registry>,
258    stages: Vec<String>,
259    metadata: BTreeMap<String, String>,
260}
261
262impl StreamContext {
263    fn new(
264        controller: Arc<Controller>,
265        registry: Registry,
266        metadata: BTreeMap<String, String>,
267    ) -> Self {
268        StreamContext {
269            controller,
270            registry: Arc::new(registry),
271            stages: Vec::new(),
272            metadata,
273        }
274    }
275
276    /// Retrieve an object from the registry by key and type.
277    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
278        self.registry.get_shared(key)
279    }
280
281    /// Clone a unique object from the registry by key and type.
282    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
283        self.registry.clone_unique(key)
284    }
285
286    pub fn registry(&self) -> Arc<Registry> {
287        self.registry.clone()
288    }
289
290    pub fn stages(&self) -> &Vec<String> {
291        &self.stages
292    }
293
294    pub fn add_stage(&mut self, stage: &str) {
295        self.stages.push(stage.to_string());
296    }
297
298    pub fn metadata(&self) -> &BTreeMap<String, String> {
299        &self.metadata
300    }
301}
302
303#[async_trait]
304impl AsyncEngineContext for StreamContext {
305    fn id(&self) -> &str {
306        self.controller.id()
307    }
308
309    fn stop(&self) {
310        self.controller.stop();
311    }
312
313    fn kill(&self) {
314        self.controller.kill();
315    }
316
317    fn stop_generating(&self) {
318        self.controller.stop_generating();
319    }
320
321    fn is_stopped(&self) -> bool {
322        self.controller.is_stopped()
323    }
324
325    fn is_killed(&self) -> bool {
326        self.controller.is_killed()
327    }
328
329    async fn stopped(&self) {
330        self.controller.stopped().await
331    }
332
333    async fn killed(&self) {
334        self.controller.killed().await
335    }
336
337    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
338        self.controller.link_child(child);
339    }
340}
341
342impl AsyncEngineContextProvider for StreamContext {
343    fn context(&self) -> Arc<dyn AsyncEngineContext> {
344        self.controller.clone()
345    }
346}
347
348impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
349    fn from(value: Context<T>) -> Self {
350        StreamContext::new(value.controller, value.registry, value.metadata)
351    }
352}
353
354// TODO - refactor here - this came from the dynamo.llm-async-engine crate
355
356use tokio::sync::watch::{Receiver, Sender, channel};
357
/// Lifecycle state published by a `Controller` over its `watch` channel.
#[derive(PartialEq, Eq, Debug)]
enum State {
    /// Running normally.
    Live,
    /// A stop was requested (`stop` / `stop_generating`).
    Stopped,
    /// A kill was requested (`kill`).
    Killed,
}
364
365/// A context implementation with cancellation propagation.
366#[derive(Debug)]
367pub struct Controller {
368    id: String,
369    tx: Sender<State>,
370    rx: Receiver<State>,
371    child_context: Mutex<Vec<Arc<dyn AsyncEngineContext>>>,
372}
373
374impl Controller {
375    pub fn new(id: String) -> Self {
376        let (tx, rx) = channel(State::Live);
377        Self {
378            id,
379            tx,
380            rx,
381            child_context: Mutex::new(Vec::new()),
382        }
383    }
384
385    pub fn id(&self) -> &str {
386        &self.id
387    }
388}
389
390impl Default for Controller {
391    fn default() -> Self {
392        Self::new(uuid::Uuid::new_v4().to_string())
393    }
394}
395
396impl AsyncEngineController for Controller {}
397
398#[async_trait]
399impl AsyncEngineContext for Controller {
400    fn id(&self) -> &str {
401        &self.id
402    }
403
404    fn is_stopped(&self) -> bool {
405        *self.rx.borrow() != State::Live
406    }
407
408    fn is_killed(&self) -> bool {
409        *self.rx.borrow() == State::Killed
410    }
411
412    async fn stopped(&self) {
413        let mut rx = self.rx.clone();
414        loop {
415            if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
416                return;
417            }
418        }
419    }
420
421    async fn killed(&self) {
422        let mut rx = self.rx.clone();
423        loop {
424            if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
425                return;
426            }
427        }
428    }
429
430    fn stop_generating(&self) {
431        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
432        let children = self
433            .child_context
434            .lock()
435            .expect("Failed to lock child context")
436            .iter()
437            .cloned()
438            .collect::<Vec<_>>();
439        for child in children {
440            child.stop_generating();
441        }
442
443        let _ = self.tx.send(State::Stopped);
444    }
445
446    fn stop(&self) {
447        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
448        let children = self
449            .child_context
450            .lock()
451            .expect("Failed to lock child context")
452            .iter()
453            .cloned()
454            .collect::<Vec<_>>();
455        for child in children {
456            child.stop();
457        }
458
459        let _ = self.tx.send(State::Stopped);
460    }
461
462    fn kill(&self) {
463        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
464        let children = self
465            .child_context
466            .lock()
467            .expect("Failed to lock child context")
468            .iter()
469            .cloned()
470            .collect::<Vec<_>>();
471        for child in children {
472            child.kill();
473        }
474
475        let _ = self.tx.send(State::Killed);
476    }
477
478    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
479        self.child_context
480            .lock()
481            .expect("Failed to lock child context")
482            .push(child);
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Starting payload for every test.
    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    /// Intermediate payload: the length of the input string.
    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    /// Final payload: a human-readable summary.
    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(Input { value }: Input) -> Self {
            Self {
                length: value.len(),
            }
        }
    }

    impl From<Processed> for Final {
        fn from(Processed { length }: Processed) -> Self {
            let message = format!("Processed length: {length}");
            Self { message }
        }
    }

    /// A fresh context wrapping the `"Hello"` input.
    fn hello_context() -> Context<Input> {
        Context::new(Input {
            value: String::from("Hello"),
        })
    }

    /// Look up a metadata entry as `&str`.
    fn meta<'a>(map: &'a BTreeMap<String, String>, key: &str) -> Option<&'a str> {
        map.get(key).map(String::as_str)
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = hello_context();
        ctx.insert("key1", 42);
        ctx.insert("key2", String::from("some data"));

        let number = ctx.get::<i32>("key1").unwrap();
        assert_eq!(*number, 42);
        let text = ctx.get::<String>("key2").unwrap();
        assert_eq!(*text, "some data");
        // Asking for the wrong type yields an error from the downcast.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_metadata_preserved_across_transfers() {
        let mut ctx = hello_context();
        ctx.insert_metadata("tenant", "alpha");

        let (_, transferred) = ctx.transfer(Processed { length: 5 });
        assert_eq!(meta(transferred.metadata(), "tenant"), Some("alpha"));
    }

    #[test]
    fn test_with_id_and_metadata_constructor() {
        let metadata: BTreeMap<String, String> = [("tenant".to_string(), "alpha".to_string())]
            .into_iter()
            .collect();
        let ctx = Context::with_id_and_metadata(
            Input {
                value: "Hello".into(),
            },
            "request-123".into(),
            metadata,
        );

        assert_eq!(ctx.id(), "request-123");
        assert_eq!(meta(ctx.metadata(), "tenant"), Some("alpha"));
    }

    #[test]
    fn test_metadata_preserved_across_rejoin() {
        let mut ctx = hello_context();
        ctx.insert_metadata("tenant", "alpha");

        let (input, parts) = ctx.into_parts();
        let rejoined = Context::rejoin(input, parts);
        assert_eq!(meta(rejoined.metadata(), "tenant"), Some("alpha"));
    }

    #[test]
    fn test_metadata_preserved_in_stream_context() {
        let mut ctx = hello_context();
        ctx.insert_metadata("tenant", "alpha");

        let stream_ctx: StreamContext = ctx.into();
        assert_eq!(meta(stream_ctx.metadata(), "tenant"), Some("alpha"));
    }

    #[test]
    fn test_transfer() {
        let (input, ctx) = hello_context().transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let ctx: Context<Final> = hello_context().map(Processed::from).map(Final::from);

        assert_eq!(ctx.current.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let processed: Context<Processed> = hello_context().into_context();
        let finished: Context<Final> = processed.into_context();

        assert_eq!(finished.current.message, "Processed length: 5");
    }
}