Skip to main content

blazen_core/
context.rs

1//! Shared workflow state accessible by all steps.
2//!
3//! [`Context`] wraps an `Arc<RwLock<ContextInner>>` so it can be cheaply
4//! cloned and shared across concurrent step executions. It provides:
5//!
6//! - Typed key/value state storage (backed by JSON for serializability)
7//! - Event emission to the internal routing queue
8//! - Fan-in event collection
9//! - Publishing events to the external streaming channel
10//! - Workflow metadata (e.g. run ID)
11//! - State snapshotting and restoration for pause/resume/checkpoint
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use blazen_events::{AnyEvent, Event, EventEnvelope};
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19use tokio::sync::{RwLock, broadcast, mpsc};
20use uuid::Uuid;
21
22use crate::value::{BytesWrapper, StateValue};
23
24/// Type alias for the state map (supports both JSON and binary values).
25type StateMap = HashMap<String, StateValue>;
26
27/// Internal state behind the `Arc<RwLock<_>>`.
28struct ContextInner {
29    /// JSON-serialized key/value store shared across all steps.
30    state: StateMap,
31    /// Sender side of the internal event routing channel.
32    event_tx: mpsc::UnboundedSender<EventEnvelope>,
33    /// Sender side of the external broadcast channel for streaming.
34    stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
35    /// Fan-in accumulator keyed by event type string.
36    collected: HashMap<String, Vec<serde_json::Value>>,
37    /// Arbitrary JSON metadata (e.g. `run_id`, workflow name).
38    metadata: HashMap<String, serde_json::Value>,
39}
40
41/// Shared workflow context.
42///
43/// Cheaply clonable handle to the shared state. Every step receives a
44/// `Context` and can read/write state, emit events, and publish to the
45/// external stream.
46///
47/// State values are stored as JSON internally, enabling serialization for
48/// pause/resume/checkpoint functionality. Users can still use ergonomic
49/// typed accessors (`set`/`get`) as long as their types implement
50/// `Serialize`/`DeserializeOwned`.
51#[derive(Clone)]
52pub struct Context {
53    inner: Arc<RwLock<ContextInner>>,
54}
55
56impl Context {
57    // -----------------------------------------------------------------
58    // Construction (crate-internal)
59    // -----------------------------------------------------------------
60
61    /// Create a new context wired to the given channels.
62    pub(crate) fn new(
63        event_tx: mpsc::UnboundedSender<EventEnvelope>,
64        stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
65    ) -> Self {
66        Self {
67            inner: Arc::new(RwLock::new(ContextInner {
68                state: HashMap::new(),
69                event_tx,
70                stream_tx,
71                collected: HashMap::new(),
72                metadata: HashMap::new(),
73            })),
74        }
75    }
76
77    // -----------------------------------------------------------------
78    // Public state accessors
79    // -----------------------------------------------------------------
80
81    /// Store a typed value under `key`.
82    ///
83    /// The value is serialized to JSON before storage. Overwrites any
84    /// previous value stored under the same key regardless of its type.
85    ///
86    /// # Panics
87    ///
88    /// Panics if the value cannot be serialized to JSON. In practice this
89    /// should never happen for well-formed serde types.
90    pub async fn set<T: Serialize + Send + Sync + 'static>(&self, key: &str, value: T) {
91        let json_value =
92            serde_json::to_value(&value).expect("Context::set: value must be JSON-serializable");
93        let mut inner = self.inner.write().await;
94        inner
95            .state
96            .insert(key.to_owned(), StateValue::Json(json_value));
97    }
98
99    /// Retrieve a typed value previously stored under `key`.
100    ///
101    /// The stored JSON is deserialized back into type `T`. Returns `None`
102    /// if the key does not exist or the stored JSON cannot be deserialized
103    /// into `T`.
104    pub async fn get<T: DeserializeOwned + Send + Sync + Clone + 'static>(
105        &self,
106        key: &str,
107    ) -> Option<T> {
108        let inner = self.inner.read().await;
109        inner.state.get(key).and_then(|sv| match sv {
110            StateValue::Json(v) => serde_json::from_value::<T>(v.clone()).ok(),
111            StateValue::Bytes(_) | StateValue::Native(_) => None,
112        })
113    }
114
115    /// Store a raw [`StateValue`] directly.
116    ///
117    /// Used by language bindings for polymorphic dispatch (e.g. storing
118    /// platform-serialized opaque objects via [`StateValue::Native`]).
119    pub async fn set_value(&self, key: &str, value: StateValue) {
120        let mut inner = self.inner.write().await;
121        inner.state.insert(key.to_owned(), value);
122    }
123
124    /// Retrieve the raw [`StateValue`] stored under `key`.
125    ///
126    /// Returns `None` if the key does not exist. Unlike [`get`](Self::get),
127    /// this returns the value regardless of its variant.
128    pub async fn get_value(&self, key: &str) -> Option<StateValue> {
129        let inner = self.inner.read().await;
130        inner.state.get(key).cloned()
131    }
132
133    /// Store raw binary data under `key`.
134    ///
135    /// Useful for files, images, audio, and other binary artifacts that
136    /// should not be JSON-serialized.
137    pub async fn set_bytes(&self, key: &str, data: Vec<u8>) {
138        let mut inner = self.inner.write().await;
139        inner
140            .state
141            .insert(key.to_owned(), StateValue::Bytes(BytesWrapper(data)));
142    }
143
144    /// Retrieve raw binary data previously stored under `key`.
145    ///
146    /// Returns `None` if the key does not exist or the stored value is
147    /// a JSON variant rather than bytes.
148    pub async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
149        let inner = self.inner.read().await;
150        inner.state.get(key).and_then(|sv| match sv {
151            StateValue::Bytes(b) => Some(b.0.clone()),
152            StateValue::Json(_) | StateValue::Native(_) => None,
153        })
154    }
155
156    // -----------------------------------------------------------------
157    // Event emission
158    // -----------------------------------------------------------------
159
160    /// Emit an event into the internal routing queue.
161    ///
162    /// The event will be picked up by the event loop and routed to any
163    /// step whose `accepts` list includes its event type.
164    pub async fn send_event<E: Event + Serialize>(&self, event: E) {
165        let inner = self.inner.read().await;
166        let envelope = EventEnvelope::new(Box::new(event), None);
167        // Ignore send errors -- the receiver may have been dropped if the
168        // workflow already terminated.
169        let _ = inner.event_tx.send(envelope);
170    }
171
172    /// Publish an event to the external broadcast stream.
173    ///
174    /// Consumers that called [`crate::WorkflowHandler::stream_events`] will
175    /// receive this event. Unlike [`send_event`](Self::send_event), this does
176    /// **not** route the event through the internal step registry.
177    pub async fn write_event_to_stream<E: Event + Serialize>(&self, event: E) {
178        let inner = self.inner.read().await;
179        // Ignore send errors -- there may be no active subscribers.
180        let _ = inner.stream_tx.send(Box::new(event));
181    }
182
183    // -----------------------------------------------------------------
184    // Fan-in collection
185    // -----------------------------------------------------------------
186
187    /// Accumulate events of type `E` until `expected_count` are available.
188    ///
189    /// Returns `Some(Vec<E>)` when exactly `expected_count` events have been
190    /// collected, or `None` if not enough have arrived yet.
191    ///
192    /// Once the threshold is reached the internal buffer for this type is
193    /// cleared automatically so a subsequent call starts fresh.
194    pub async fn collect_events<E: Event + DeserializeOwned>(
195        &self,
196        expected_count: usize,
197    ) -> Option<Vec<E>> {
198        let mut inner = self.inner.write().await;
199        let type_key = E::event_type().to_owned();
200
201        let collected = inner.collected.entry(type_key).or_default();
202        if collected.len() >= expected_count {
203            let drained: Vec<serde_json::Value> = collected.drain(..expected_count).collect();
204            let mut results = Vec::with_capacity(drained.len());
205            for json_val in drained {
206                if let Ok(concrete) = serde_json::from_value::<E>(json_val) {
207                    results.push(concrete);
208                }
209            }
210            Some(results)
211        } else {
212            None
213        }
214    }
215
216    /// Push a type-erased event into the fan-in accumulator.
217    ///
218    /// The event is serialized to JSON and stored under its event type
219    /// string (obtained via `AnyEvent::event_type_id`).
220    pub(crate) async fn push_collected(&self, event: &dyn AnyEvent) {
221        let mut inner = self.inner.write().await;
222        let type_key = event.event_type_id().to_owned();
223        let json_val = event.to_json();
224        inner.collected.entry(type_key).or_default().push(json_val);
225    }
226
227    /// Clear the collection buffer for a specific event type.
228    #[allow(dead_code)]
229    pub(crate) async fn clear_collected<E: Event>(&self) {
230        let mut inner = self.inner.write().await;
231        let type_key = E::event_type().to_owned();
232        inner.collected.remove(&type_key);
233    }
234
235    // -----------------------------------------------------------------
236    // Snapshotting & restoration
237    // -----------------------------------------------------------------
238
239    /// Returns a clone of the entire state map.
240    ///
241    /// Useful for checkpointing or pausing a workflow so it can be
242    /// resumed later.
243    pub async fn snapshot_state(&self) -> HashMap<String, StateValue> {
244        let inner = self.inner.read().await;
245        inner.state.clone()
246    }
247
248    /// Replace the state map wholesale.
249    ///
250    /// Used to restore state from a previous checkpoint. Any existing
251    /// state is discarded.
252    pub async fn restore_state(&self, state: HashMap<String, StateValue>) {
253        let mut inner = self.inner.write().await;
254        inner.state = state;
255    }
256
257    /// Returns a clone of the collected events map (serialized as JSON).
258    ///
259    /// Useful for checkpointing fan-in state alongside the main state map.
260    pub async fn snapshot_collected(&self) -> HashMap<String, Vec<serde_json::Value>> {
261        let inner = self.inner.read().await;
262        inner.collected.clone()
263    }
264
265    /// Replace the collected events map wholesale.
266    ///
267    /// Used to restore fan-in state from a previous checkpoint. Any existing
268    /// collected events are discarded.
269    pub async fn restore_collected(&self, collected: HashMap<String, Vec<serde_json::Value>>) {
270        let mut inner = self.inner.write().await;
271        inner.collected = collected;
272    }
273
274    /// Returns a clone of the metadata map.
275    ///
276    /// Useful for checkpointing metadata alongside the main state map.
277    pub async fn snapshot_metadata(&self) -> HashMap<String, serde_json::Value> {
278        let inner = self.inner.read().await;
279        inner.metadata.clone()
280    }
281
282    /// Replace the metadata map wholesale.
283    ///
284    /// Used to restore metadata from a previous checkpoint. Any existing
285    /// metadata is discarded.
286    pub(crate) async fn restore_metadata(&self, metadata: HashMap<String, serde_json::Value>) {
287        let mut inner = self.inner.write().await;
288        inner.metadata = metadata;
289    }
290
291    // -----------------------------------------------------------------
292    // Metadata
293    // -----------------------------------------------------------------
294
295    /// Get the workflow run ID from metadata.
296    ///
297    /// # Panics
298    ///
299    /// Panics if the `run_id` metadata key was never set (this is always
300    /// set by the workflow engine before any step executes).
301    pub async fn run_id(&self) -> Uuid {
302        let inner = self.inner.read().await;
303        inner
304            .metadata
305            .get("run_id")
306            .and_then(|v| v.as_str())
307            .and_then(|s| Uuid::parse_str(s).ok())
308            .expect("run_id must be set in workflow metadata")
309    }
310
311    /// Store a metadata key/value pair.
312    pub(crate) async fn set_metadata(&self, key: &str, value: serde_json::Value) {
313        let mut inner = self.inner.write().await;
314        inner.metadata.insert(key.to_owned(), value);
315    }
316
317    /// Send a sentinel event through the broadcast stream to signal that
318    /// no more events will be published.
319    ///
320    /// Consumers that check for `"blazen::StreamEnd"` can use this to
321    /// terminate their iteration.
322    pub(crate) async fn signal_stream_end(&self) {
323        self.write_event_to_stream(blazen_events::DynamicEvent {
324            event_type: "blazen::StreamEnd".to_owned(),
325            data: serde_json::Value::Null,
326        })
327        .await;
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a [`Context`] wired to channels nobody listens on.
    fn test_context() -> Context {
        let (event_tx, _unused_events) = mpsc::unbounded_channel();
        let (stream_tx, _unused_stream) = broadcast::channel(16);
        Context::new(event_tx, stream_tx)
    }

    #[tokio::test]
    async fn set_and_get_typed_value() {
        let ctx = test_context();
        ctx.set("counter", 42_u64).await;

        let stored: Option<u64> = ctx.get("counter").await;
        assert_eq!(stored, Some(42));
    }

    #[tokio::test]
    async fn get_wrong_type_returns_none() {
        let ctx = test_context();
        ctx.set("counter", 42_u64).await;

        // A JSON number does not deserialize into a `String`, so the typed
        // lookup must come back empty.
        let as_text: Option<String> = ctx.get("counter").await;
        assert_eq!(as_text, None);
    }

    #[tokio::test]
    async fn get_missing_key_returns_none() {
        let ctx = test_context();
        let absent: Option<u64> = ctx.get("nope").await;
        assert_eq!(absent, None);
    }

    #[tokio::test]
    async fn run_id_roundtrip() {
        let ctx = test_context();
        let expected = Uuid::new_v4();
        let encoded = serde_json::Value::String(expected.to_string());
        ctx.set_metadata("run_id", encoded).await;

        assert_eq!(ctx.run_id().await, expected);
    }

    #[tokio::test]
    async fn collect_events_accumulation() {
        use blazen_events::StartEvent;

        let ctx = test_context();
        let first = StartEvent {
            data: serde_json::json!(1),
        };
        let second = StartEvent {
            data: serde_json::json!(2),
        };

        // One event buffered, two requested: nothing is handed out yet.
        ctx.push_collected(&first).await;
        assert!(ctx.collect_events::<StartEvent>(2).await.is_none());

        // The second event completes the batch.
        ctx.push_collected(&second).await;
        let batch = ctx.collect_events::<StartEvent>(2).await.unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].data, serde_json::json!(1));
        assert_eq!(batch[1].data, serde_json::json!(2));
    }

    #[tokio::test]
    async fn snapshot_and_restore_state() {
        let ctx = test_context();
        ctx.set("name", "alice".to_string()).await;
        ctx.set("count", 10_u32).await;

        // Take the snapshot and check its contents.
        let saved = ctx.snapshot_state().await;
        assert_eq!(saved.len(), 2);
        assert_eq!(
            saved.get("name").unwrap(),
            &StateValue::Json(serde_json::json!("alice"))
        );
        assert_eq!(
            saved.get("count").unwrap(),
            &StateValue::Json(serde_json::json!(10))
        );

        // Diverge from the snapshot.
        ctx.set("name", "bob".to_string()).await;
        let renamed: Option<String> = ctx.get("name").await;
        assert_eq!(renamed, Some("bob".to_string()));

        // Restoring brings the original values back.
        ctx.restore_state(saved).await;
        let name: Option<String> = ctx.get("name").await;
        let count: Option<u32> = ctx.get("count").await;
        assert_eq!(name, Some("alice".to_string()));
        assert_eq!(count, Some(10));
    }

    #[tokio::test]
    async fn set_and_get_bytes() {
        let ctx = test_context();
        let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
        ctx.set_bytes("binary", payload.clone()).await;

        assert_eq!(ctx.get_bytes("binary").await, Some(payload));
        // The typed accessor does not decode bytes values.
        let typed: Option<String> = ctx.get("binary").await;
        assert_eq!(typed, None);
    }

    #[tokio::test]
    async fn get_bytes_returns_none_for_json() {
        let ctx = test_context();
        ctx.set("key", "value".to_string()).await;

        let bytes = ctx.get_bytes("key").await;
        assert_eq!(bytes, None);
    }

    #[tokio::test]
    async fn get_bytes_returns_none_for_missing_key() {
        let ctx = test_context();
        let bytes = ctx.get_bytes("nope").await;
        assert_eq!(bytes, None);
    }

    #[tokio::test]
    async fn snapshot_collected() {
        use blazen_events::StartEvent;

        let ctx = test_context();
        let event = StartEvent {
            data: serde_json::json!("a"),
        };
        ctx.push_collected(&event).await;

        let buffers = ctx.snapshot_collected().await;
        assert_eq!(buffers.len(), 1);
        let start_events = buffers.get("blazen::StartEvent").unwrap();
        assert_eq!(start_events.len(), 1);
    }

    #[tokio::test]
    async fn set_value_and_get_value() {
        let ctx = test_context();
        let opaque = StateValue::native(vec![0x80, 0x04, 0x95]);
        ctx.set_value("pickled", opaque.clone()).await;

        assert_eq!(ctx.get_value("pickled").await, Some(opaque));
    }

    #[tokio::test]
    async fn get_value_returns_all_variants() {
        let ctx = test_context();
        ctx.set("json_key", "hello".to_string()).await;
        ctx.set_bytes("bytes_key", vec![1, 2, 3]).await;
        ctx.set_value("native_key", StateValue::native(vec![4, 5, 6]))
            .await;

        let json = ctx.get_value("json_key").await.unwrap();
        let bytes = ctx.get_value("bytes_key").await.unwrap();
        let native = ctx.get_value("native_key").await.unwrap();
        assert!(json.is_json());
        assert!(bytes.is_bytes());
        assert!(native.is_native());
        assert!(ctx.get_value("missing").await.is_none());
    }

    #[tokio::test]
    async fn get_returns_none_for_native() {
        let ctx = test_context();
        ctx.set_value("key", StateValue::native(vec![0x80, 0x04]))
            .await;

        let typed: Option<String> = ctx.get("key").await;
        assert_eq!(typed, None);
    }

    #[tokio::test]
    async fn get_bytes_returns_none_for_native() {
        let ctx = test_context();
        ctx.set_value("key", StateValue::native(vec![0x80, 0x04]))
            .await;

        let bytes = ctx.get_bytes("key").await;
        assert_eq!(bytes, None);
    }
}