dynamo_runtime/pipeline/
context.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
//! Context Module
//!
//! There are two context objects defined in this module:
//!
//! - [`Context`] is an input context which is propagated through the processing pipeline,
//!   up to the point where the input is passed to an [`crate::engine::AsyncEngine`] for processing.
//! - [`StreamContext`] is the input context transformed into a type-erased context that maintains the
//!   input's registry and visitors. `StreamAdaptors` will amend themselves to the [`StreamContext`]
//!   (for example, by recording a stage with [`StreamContext::add_stage`]).

24use std::ops::{Deref, DerefMut};
25use std::sync::Arc;
26
27use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
28use crate::engine::AsyncEngineController;
29use async_trait::async_trait;
30
31use super::registry::Registry;
32
33pub struct Context<T: Data> {
34    current: T,
35    controller: Arc<Controller>, //todo: hold this as an arc
36    registry: Registry,
37    stages: Vec<String>,
38}
39
40impl<T: Send + Sync + 'static> Context<T> {
41    // Create a new context with initial data
42    pub fn new(current: T) -> Self {
43        Context {
44            current,
45            controller: Arc::new(Controller::default()),
46            registry: Registry::new(),
47            stages: Vec::new(),
48        }
49    }
50
51    pub fn with_controller(current: T, controller: Controller) -> Self {
52        Context {
53            current,
54            controller: Arc::new(controller),
55            registry: Registry::new(),
56            stages: Vec::new(),
57        }
58    }
59
60    pub fn with_id(current: T, id: String) -> Self {
61        Context {
62            current,
63            controller: Arc::new(Controller::new(id)),
64            registry: Registry::new(),
65            stages: Vec::new(),
66        }
67    }
68
69    pub fn id(&self) -> &str {
70        self.controller.id()
71    }
72
73    pub fn controller(&self) -> &Controller {
74        &self.controller
75    }
76
77    /// Insert an object into the registry with a specific key.
78    pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
79        self.registry.insert_shared(key, value);
80    }
81
82    /// Insert a unique and takable object into the registry with a specific key.
83    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
84        self.registry.insert_unique(key, value);
85    }
86
87    /// Retrieve an object from the registry by key and type.
88    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
89        self.registry.get_shared(key)
90    }
91
92    /// Clone a unique object from the registry by key and type.
93    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
94        self.registry.clone_unique(key)
95    }
96
97    /// Take a unique object from the registry by key and type.
98    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
99        self.registry.take_unique(key)
100    }
101
102    /// Transfer the Context to a new Object without updating the registry
103    /// This returns a tuple of the previous object and the new Context
104    pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
105        (
106            self.current,
107            Context {
108                current: new_current,
109                controller: self.controller,
110                registry: self.registry,
111                stages: self.stages,
112            },
113        )
114    }
115
116    /// Separate out the current object and context
117    pub fn into_parts(self) -> (T, Context<()>) {
118        self.transfer(())
119    }
120
121    pub fn stages(&self) -> &Vec<String> {
122        &self.stages
123    }
124
125    pub fn add_stage(&mut self, stage: &str) {
126        self.stages.push(stage.to_string());
127    }
128
129    /// Transforms the current context to another type using a provided function.
130    pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
131    where
132        F: FnOnce(T) -> U,
133    {
134        // Use the transfer method to move the current value out
135        let (current, temp_context) = self.transfer(());
136
137        // Apply the transformation function to the current value
138        let new_current = f(current);
139
140        // Use transfer again to create the new context with the transformed type
141        temp_context.transfer(new_current).1
142    }
143
144    pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
145    where
146        F: FnOnce(T) -> Result<U, E>,
147        U: Send + Sync + 'static,
148    {
149        // Use the transfer method to move the current value out
150        let (current, temp_context) = self.transfer(());
151
152        // Apply the transformation function to the current value
153        let new_current = f(current)?;
154
155        // Use transfer again to create the new context with the transformed type
156        Ok(temp_context.transfer(new_current).1)
157    }
158}
159
160impl<T: Data> std::fmt::Debug for Context<T> {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("Context")
163            .field("id", &self.controller.id())
164            .finish()
165    }
166}
167
168// Implement Deref to allow Context<T> to act like &T
169impl<T: Data> Deref for Context<T> {
170    type Target = T;
171
172    fn deref(&self) -> &Self::Target {
173        &self.current
174    }
175}
176
177// Implement DerefMut to allow Context<T> to act like &mut T
178impl<T: Data> DerefMut for Context<T> {
179    fn deref_mut(&mut self) -> &mut Self::Target {
180        &mut self.current
181    }
182}
183
184// Implement the custom trait for Context<T>
185impl<T> From<T> for Context<T>
186where
187    T: Send + Sync + 'static,
188{
189    fn from(current: T) -> Self {
190        Context::new(current)
191    }
192}
193
194// Define a custom trait for conversion from Context<T> to Context<U>
195pub trait IntoContext<U: Data> {
196    fn into_context(self) -> Context<U>;
197}
198
199// Implement the custom trait for converting Context<T> to Context<U>
200impl<T, U> IntoContext<U> for Context<T>
201where
202    T: Send + Sync + 'static + Into<U>,
203    U: Send + Sync + 'static,
204{
205    fn into_context(self) -> Context<U> {
206        self.map(|current| current.into())
207    }
208}
209
210impl<T: Data> AsyncEngineContextProvider for Context<T> {
211    fn context(&self) -> Arc<dyn AsyncEngineContext> {
212        self.controller.clone()
213    }
214}
215
216#[derive(Debug, Clone)]
217pub struct StreamContext {
218    controller: Arc<Controller>,
219    registry: Arc<Registry>,
220    stages: Vec<String>,
221}
222
223impl StreamContext {
224    fn new(controller: Arc<Controller>, registry: Registry) -> Self {
225        StreamContext {
226            controller,
227            registry: Arc::new(registry),
228            stages: Vec::new(),
229        }
230    }
231
232    /// Retrieve an object from the registry by key and type.
233    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
234        self.registry.get_shared(key)
235    }
236
237    /// Clone a unique object from the registry by key and type.
238    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
239        self.registry.clone_unique(key)
240    }
241
242    pub fn registry(&self) -> Arc<Registry> {
243        self.registry.clone()
244    }
245
246    pub fn stages(&self) -> &Vec<String> {
247        &self.stages
248    }
249
250    pub fn add_stage(&mut self, stage: &str) {
251        self.stages.push(stage.to_string());
252    }
253}
254
255#[async_trait]
256impl AsyncEngineContext for StreamContext {
257    fn id(&self) -> &str {
258        self.controller.id()
259    }
260
261    fn stop(&self) {
262        self.controller.stop();
263    }
264
265    fn kill(&self) {
266        self.controller.kill();
267    }
268
269    fn stop_generating(&self) {
270        self.controller.stop_generating();
271    }
272
273    fn is_stopped(&self) -> bool {
274        self.controller.is_stopped()
275    }
276
277    fn is_killed(&self) -> bool {
278        self.controller.is_killed()
279    }
280
281    async fn stopped(&self) {
282        self.controller.stopped().await
283    }
284
285    async fn killed(&self) {
286        self.controller.killed().await
287    }
288}
289
290impl AsyncEngineContextProvider for StreamContext {
291    fn context(&self) -> Arc<dyn AsyncEngineContext> {
292        self.controller.clone()
293    }
294}
295
296impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
297    fn from(value: Context<T>) -> Self {
298        StreamContext::new(value.controller, value.registry)
299    }
300}
301
302// TODO - refactor here - this came from the dynamo.llm-async-engine crate
303
304use tokio::sync::watch::{channel, Receiver, Sender};
305
/// Lifecycle state published on a [`Controller`]'s watch channel.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Initial state: neither `stop` nor `kill` has been requested.
    Live,
    /// Set by `stop` (and therefore `stop_generating`).
    Stopped,
    /// Set by `kill`.
    Killed,
}
312
313/// A context implementation with cancellation propagation.
314#[derive(Debug)]
315pub struct Controller {
316    id: String,
317    tx: Sender<State>,
318    rx: Receiver<State>,
319}
320
321impl Controller {
322    pub fn new(id: String) -> Self {
323        let (tx, rx) = channel(State::Live);
324        Self { id, tx, rx }
325    }
326
327    pub fn id(&self) -> &str {
328        &self.id
329    }
330}
331
332impl Default for Controller {
333    fn default() -> Self {
334        Self::new(uuid::Uuid::new_v4().to_string())
335    }
336}
337
338impl AsyncEngineController for Controller {}
339
340#[async_trait]
341impl AsyncEngineContext for Controller {
342    fn id(&self) -> &str {
343        &self.id
344    }
345
346    fn is_stopped(&self) -> bool {
347        *self.rx.borrow() != State::Live
348    }
349
350    fn is_killed(&self) -> bool {
351        *self.rx.borrow() == State::Killed
352    }
353
354    async fn stopped(&self) {
355        let mut rx = self.rx.clone();
356        if *rx.borrow_and_update() != State::Live {
357            return;
358        }
359        let _ = rx.changed().await;
360    }
361
362    async fn killed(&self) {
363        let mut rx = self.rx.clone();
364        if *rx.borrow_and_update() == State::Killed {
365            return;
366        }
367        let _ = rx.changed().await;
368    }
369
370    fn stop_generating(&self) {
371        self.stop();
372    }
373
374    fn stop(&self) {
375        let _ = self.tx.send(State::Stopped);
376    }
377
378    fn kill(&self) {
379        let _ = self.tx.send(State::Killed);
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            Self {
                length: input.value.len(),
            }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            Self {
                message: format!("Processed length: {}", processed.length),
            }
        }
    }

    /// Fixture shared by every test: a context wrapping the input "Hello".
    fn hello_context() -> Context<Input> {
        Context::new(Input {
            value: String::from("Hello"),
        })
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = hello_context();
        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        let number = ctx.get::<i32>("key1").unwrap();
        assert_eq!(*number, 42);

        let text = ctx.get::<String>("key2").unwrap();
        assert_eq!(*text, "some data");

        // "key1" holds an i32, so requesting an f64 must fail the downcast.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_transfer() {
        let (input, ctx) = hello_context().transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let processed: Context<Processed> = hello_context().map(Processed::from);
        let finished: Context<Final> = processed.map(Final::from);

        assert_eq!(finished.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let processed: Context<Processed> = hello_context().into_context();
        let finished: Context<Final> = processed.into_context();

        assert_eq!(finished.message, "Processed length: 5");
    }
}