dynamo_runtime/pipeline/
context.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
//! Context Module
//!
//! There are two context objects defined in this module:
//!
//! - [`Context`] is an input context which is propagated through the processing pipeline,
//!   up to the point where the input is passed to an [`crate::engine::AsyncEngine`] for processing.
//! - [`StreamContext`] is the input context transformed into a type-erased context that maintains
//!   the input's registry and visitors. `StreamAdaptors` will amend themselves to the
//!   [`StreamContext`].

24use std::ops::{Deref, DerefMut};
25use std::sync::Arc;
26
27use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
28use crate::engine::AsyncEngineController;
29use async_trait::async_trait;
30
31use super::registry::Registry;
32
/// Input context carried through the processing pipeline.
///
/// It pairs the current payload with a shared cancellation [`Controller`], a [`Registry`]
/// of attached objects, and the list of stage names recorded via [`Context::add_stage`].
pub struct Context<T: Data> {
    /// The payload currently held by the context; exposed through `Deref`/`DerefMut`.
    current: T,
    /// Shared controller providing the context id and stop/kill signalling.
    controller: Arc<Controller>,
    /// Key/value store of objects attached to this context.
    registry: Registry,
    /// Names of pipeline stages recorded with [`Context::add_stage`].
    stages: Vec<String>,
}
39
40impl<T: Send + Sync + 'static> Context<T> {
41    // Create a new context with initial data
42    pub fn new(current: T) -> Self {
43        Context {
44            current,
45            controller: Arc::new(Controller::default()),
46            registry: Registry::new(),
47            stages: Vec::new(),
48        }
49    }
50
51    pub fn rejoin<U: Send + Sync + 'static>(current: T, context: Context<U>) -> Self {
52        Context {
53            current,
54            controller: context.controller,
55            registry: context.registry,
56            stages: context.stages,
57        }
58    }
59
60    pub fn with_controller(current: T, controller: Controller) -> Self {
61        Context {
62            current,
63            controller: Arc::new(controller),
64            registry: Registry::new(),
65            stages: Vec::new(),
66        }
67    }
68
69    pub fn with_id(current: T, id: String) -> Self {
70        Context {
71            current,
72            controller: Arc::new(Controller::new(id)),
73            registry: Registry::new(),
74            stages: Vec::new(),
75        }
76    }
77
78    /// Get the id of the context
79    pub fn id(&self) -> &str {
80        self.controller.id()
81    }
82
83    /// Get the content of the context
84    pub fn content(&self) -> &T {
85        &self.current
86    }
87
88    pub fn controller(&self) -> &Controller {
89        &self.controller
90    }
91
92    /// Insert an object into the registry with a specific key.
93    pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
94        self.registry.insert_shared(key, value);
95    }
96
97    /// Insert a unique and takable object into the registry with a specific key.
98    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
99        self.registry.insert_unique(key, value);
100    }
101
102    /// Retrieve an object from the registry by key and type.
103    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
104        self.registry.get_shared(key)
105    }
106
107    /// Clone a unique object from the registry by key and type.
108    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
109        self.registry.clone_unique(key)
110    }
111
112    /// Take a unique object from the registry by key and type.
113    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
114        self.registry.take_unique(key)
115    }
116
117    /// Transfer the Context to a new Object without updating the registry
118    /// This returns a tuple of the previous object and the new Context
119    pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
120        (
121            self.current,
122            Context {
123                current: new_current,
124                controller: self.controller,
125                registry: self.registry,
126                stages: self.stages,
127            },
128        )
129    }
130
131    /// Separate out the current object and context
132    pub fn into_parts(self) -> (T, Context<()>) {
133        self.transfer(())
134    }
135
136    pub fn stages(&self) -> &Vec<String> {
137        &self.stages
138    }
139
140    pub fn add_stage(&mut self, stage: &str) {
141        self.stages.push(stage.to_string());
142    }
143
144    /// Transforms the current context to another type using a provided function.
145    pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
146    where
147        F: FnOnce(T) -> U,
148    {
149        // Use the transfer method to move the current value out
150        let (current, temp_context) = self.transfer(());
151
152        // Apply the transformation function to the current value
153        let new_current = f(current);
154
155        // Use transfer again to create the new context with the transformed type
156        temp_context.transfer(new_current).1
157    }
158
159    pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
160    where
161        F: FnOnce(T) -> Result<U, E>,
162        U: Send + Sync + 'static,
163    {
164        // Use the transfer method to move the current value out
165        let (current, temp_context) = self.transfer(());
166
167        // Apply the transformation function to the current value
168        let new_current = f(current)?;
169
170        // Use transfer again to create the new context with the transformed type
171        Ok(temp_context.transfer(new_current).1)
172    }
173}
174
175impl<T: Data> std::fmt::Debug for Context<T> {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        f.debug_struct("Context")
178            .field("id", &self.controller.id())
179            .finish()
180    }
181}
182
183// Implement Deref to allow Context<T> to act like &T
184impl<T: Data> Deref for Context<T> {
185    type Target = T;
186
187    fn deref(&self) -> &Self::Target {
188        &self.current
189    }
190}
191
192// Implement DerefMut to allow Context<T> to act like &mut T
193impl<T: Data> DerefMut for Context<T> {
194    fn deref_mut(&mut self) -> &mut Self::Target {
195        &mut self.current
196    }
197}
198
199// Implement the custom trait for Context<T>
200impl<T> From<T> for Context<T>
201where
202    T: Send + Sync + 'static,
203{
204    fn from(current: T) -> Self {
205        Context::new(current)
206    }
207}
208
209// Define a custom trait for conversion from Context<T> to Context<U>
210pub trait IntoContext<U: Data> {
211    fn into_context(self) -> Context<U>;
212}
213
214// Implement the custom trait for converting Context<T> to Context<U>
215impl<T, U> IntoContext<U> for Context<T>
216where
217    T: Send + Sync + 'static + Into<U>,
218    U: Send + Sync + 'static,
219{
220    fn into_context(self) -> Context<U> {
221        self.map(|current| current.into())
222    }
223}
224
225impl<T: Data> AsyncEngineContextProvider for Context<T> {
226    fn context(&self) -> Arc<dyn AsyncEngineContext> {
227        self.controller.clone()
228    }
229}
230
/// Type-erased view of an input [`Context`], without its payload.
///
/// It shares the context's controller and holds its registry behind an `Arc`, so
/// cloning a `StreamContext` is cheap. It keeps its own list of stage names.
#[derive(Debug, Clone)]
pub struct StreamContext {
    /// Shared controller providing the id and stop/kill signalling.
    controller: Arc<Controller>,
    /// Registry taken over from the originating [`Context`], shared between clones.
    registry: Arc<Registry>,
    /// Stage names recorded with [`StreamContext::add_stage`].
    stages: Vec<String>,
}
237
238impl StreamContext {
239    fn new(controller: Arc<Controller>, registry: Registry) -> Self {
240        StreamContext {
241            controller,
242            registry: Arc::new(registry),
243            stages: Vec::new(),
244        }
245    }
246
247    /// Retrieve an object from the registry by key and type.
248    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
249        self.registry.get_shared(key)
250    }
251
252    /// Clone a unique object from the registry by key and type.
253    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
254        self.registry.clone_unique(key)
255    }
256
257    pub fn registry(&self) -> Arc<Registry> {
258        self.registry.clone()
259    }
260
261    pub fn stages(&self) -> &Vec<String> {
262        &self.stages
263    }
264
265    pub fn add_stage(&mut self, stage: &str) {
266        self.stages.push(stage.to_string());
267    }
268}
269
270#[async_trait]
271impl AsyncEngineContext for StreamContext {
272    fn id(&self) -> &str {
273        self.controller.id()
274    }
275
276    fn stop(&self) {
277        self.controller.stop();
278    }
279
280    fn kill(&self) {
281        self.controller.kill();
282    }
283
284    fn stop_generating(&self) {
285        self.controller.stop_generating();
286    }
287
288    fn is_stopped(&self) -> bool {
289        self.controller.is_stopped()
290    }
291
292    fn is_killed(&self) -> bool {
293        self.controller.is_killed()
294    }
295
296    async fn stopped(&self) {
297        self.controller.stopped().await
298    }
299
300    async fn killed(&self) {
301        self.controller.killed().await
302    }
303}
304
305impl AsyncEngineContextProvider for StreamContext {
306    fn context(&self) -> Arc<dyn AsyncEngineContext> {
307        self.controller.clone()
308    }
309}
310
311impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
312    fn from(value: Context<T>) -> Self {
313        StreamContext::new(value.controller, value.registry)
314    }
315}
316
317// TODO - refactor here - this came from the dynamo.llm-async-engine crate
318
319use tokio::sync::watch::{channel, Receiver, Sender};
320
/// Lifecycle state published by a [`Controller`] over its watch channel.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Initial state: neither stopped nor killed.
    Live,
    /// Set by `stop` / `stop_generating`.
    Stopped,
    /// Set by `kill`.
    Killed,
}
327
/// A context implementation with cancellation propagation.
///
/// The current [`State`] is held in a `tokio::sync::watch` channel. Both ends are owned
/// by the controller, so the channel stays open for as long as the controller lives.
#[derive(Debug)]
pub struct Controller {
    /// Identifier reported by `id()`.
    id: String,
    /// Publishes state transitions.
    tx: Sender<State>,
    /// Template receiver; waiters clone it so each tracks its own "seen" version.
    rx: Receiver<State>,
}
335
336impl Controller {
337    pub fn new(id: String) -> Self {
338        let (tx, rx) = channel(State::Live);
339        Self { id, tx, rx }
340    }
341
342    pub fn id(&self) -> &str {
343        &self.id
344    }
345}
346
347impl Default for Controller {
348    fn default() -> Self {
349        Self::new(uuid::Uuid::new_v4().to_string())
350    }
351}
352
// Opts `Controller` into `AsyncEngineController`; this impl provides no method bodies of its own.
impl AsyncEngineController for Controller {}
354
355#[async_trait]
356impl AsyncEngineContext for Controller {
357    fn id(&self) -> &str {
358        &self.id
359    }
360
361    fn is_stopped(&self) -> bool {
362        *self.rx.borrow() != State::Live
363    }
364
365    fn is_killed(&self) -> bool {
366        *self.rx.borrow() == State::Killed
367    }
368
369    async fn stopped(&self) {
370        let mut rx = self.rx.clone();
371        if *rx.borrow_and_update() != State::Live {
372            return;
373        }
374        let _ = rx.changed().await;
375    }
376
377    async fn killed(&self) {
378        let mut rx = self.rx.clone();
379        if *rx.borrow_and_update() == State::Killed {
380            return;
381        }
382        let _ = rx.changed().await;
383    }
384
385    fn stop_generating(&self) {
386        self.stop();
387    }
388
389    fn stop(&self) {
390        let _ = self.tx.send(State::Stopped);
391    }
392
393    fn kill(&self) {
394        let _ = self.tx.send(State::Killed);
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            let length = input.value.len();
            Self { length }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            let message = format!("Processed length: {}", processed.length);
            Self { message }
        }
    }

    /// Shared starting point for every test: a context holding `Input { value: "Hello" }`.
    fn hello_context() -> Context<Input> {
        Context::new(Input {
            value: String::from("Hello"),
        })
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = hello_context();
        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        let number = ctx.get::<i32>("key1").unwrap();
        let text = ctx.get::<String>("key2").unwrap();
        assert_eq!(*number, 42);
        assert_eq!(*text, "some data");

        // Asking for the wrong type must fail the downcast.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_transfer() {
        let (original, ctx) = hello_context().transfer(Processed { length: 5 });

        assert_eq!(original.value, "Hello");
        // `length` is reached through `Deref` on the new context.
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let processed: Context<Processed> = hello_context().map(Processed::from);
        let finished: Context<Final> = processed.map(Final::from);

        assert_eq!(finished.current.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let processed: Context<Processed> = hello_context().into_context();
        let finished: Context<Final> = processed.into_context();

        assert_eq!(finished.current.message, "Processed length: 5");
    }
}