1use std::collections::HashMap;
14use std::sync::Arc;
15
16use blazen_events::{AnyEvent, Event, EventEnvelope};
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19use tokio::sync::{RwLock, broadcast, mpsc};
20use uuid::Uuid;
21
22use crate::value::{BytesWrapper, StateValue};
23
/// Map from state key to stored value; used as the type of `ContextInner::state`.
type StateMap = HashMap<String, StateValue>;
26
/// Data shared by every clone of a [`Context`]; always accessed through the
/// `RwLock` that `Context` wraps it in.
struct ContextInner {
    /// Keyed workflow state written by `set` / `set_value` / `set_bytes`.
    state: StateMap,
    /// Unbounded channel that `send_event` pushes envelopes onto.
    event_tx: mpsc::UnboundedSender<EventEnvelope>,
    /// Broadcast channel that `write_event_to_stream` publishes boxed events on.
    stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
    /// JSON forms of events added by `push_collected`, grouped by event type id
    /// and kept in insertion order until `collect_events` drains them.
    collected: HashMap<String, Vec<serde_json::Value>>,
    /// Free-form JSON metadata; `run_id` reads the `"run_id"` entry from here.
    metadata: HashMap<String, serde_json::Value>,
}
40
/// Handle to the shared workflow context.
///
/// Cloning copies only the `Arc`, so every clone reads and writes the same
/// underlying `ContextInner` behind an async `RwLock`.
#[derive(Clone)]
pub struct Context {
    /// Shared, lock-protected context data.
    inner: Arc<RwLock<ContextInner>>,
}
55
56impl Context {
57 pub(crate) fn new(
63 event_tx: mpsc::UnboundedSender<EventEnvelope>,
64 stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
65 ) -> Self {
66 Self {
67 inner: Arc::new(RwLock::new(ContextInner {
68 state: HashMap::new(),
69 event_tx,
70 stream_tx,
71 collected: HashMap::new(),
72 metadata: HashMap::new(),
73 })),
74 }
75 }
76
77 pub async fn set<T: Serialize + Send + Sync + 'static>(&self, key: &str, value: T) {
91 let json_value =
92 serde_json::to_value(&value).expect("Context::set: value must be JSON-serializable");
93 let mut inner = self.inner.write().await;
94 inner
95 .state
96 .insert(key.to_owned(), StateValue::Json(json_value));
97 }
98
99 pub async fn get<T: DeserializeOwned + Send + Sync + Clone + 'static>(
105 &self,
106 key: &str,
107 ) -> Option<T> {
108 let inner = self.inner.read().await;
109 inner.state.get(key).and_then(|sv| match sv {
110 StateValue::Json(v) => serde_json::from_value::<T>(v.clone()).ok(),
111 StateValue::Bytes(_) | StateValue::Native(_) => None,
112 })
113 }
114
115 pub async fn set_value(&self, key: &str, value: StateValue) {
120 let mut inner = self.inner.write().await;
121 inner.state.insert(key.to_owned(), value);
122 }
123
124 pub async fn get_value(&self, key: &str) -> Option<StateValue> {
129 let inner = self.inner.read().await;
130 inner.state.get(key).cloned()
131 }
132
133 pub async fn set_bytes(&self, key: &str, data: Vec<u8>) {
138 let mut inner = self.inner.write().await;
139 inner
140 .state
141 .insert(key.to_owned(), StateValue::Bytes(BytesWrapper(data)));
142 }
143
144 pub async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
149 let inner = self.inner.read().await;
150 inner.state.get(key).and_then(|sv| match sv {
151 StateValue::Bytes(b) => Some(b.0.clone()),
152 StateValue::Json(_) | StateValue::Native(_) => None,
153 })
154 }
155
156 pub async fn send_event<E: Event + Serialize>(&self, event: E) {
165 let inner = self.inner.read().await;
166 let envelope = EventEnvelope::new(Box::new(event), None);
167 let _ = inner.event_tx.send(envelope);
170 }
171
172 pub async fn write_event_to_stream<E: Event + Serialize>(&self, event: E) {
178 let inner = self.inner.read().await;
179 let _ = inner.stream_tx.send(Box::new(event));
181 }
182
183 pub async fn collect_events<E: Event + DeserializeOwned>(
195 &self,
196 expected_count: usize,
197 ) -> Option<Vec<E>> {
198 let mut inner = self.inner.write().await;
199 let type_key = E::event_type().to_owned();
200
201 let collected = inner.collected.entry(type_key).or_default();
202 if collected.len() >= expected_count {
203 let drained: Vec<serde_json::Value> = collected.drain(..expected_count).collect();
204 let mut results = Vec::with_capacity(drained.len());
205 for json_val in drained {
206 if let Ok(concrete) = serde_json::from_value::<E>(json_val) {
207 results.push(concrete);
208 }
209 }
210 Some(results)
211 } else {
212 None
213 }
214 }
215
216 pub(crate) async fn push_collected(&self, event: &dyn AnyEvent) {
221 let mut inner = self.inner.write().await;
222 let type_key = event.event_type_id().to_owned();
223 let json_val = event.to_json();
224 inner.collected.entry(type_key).or_default().push(json_val);
225 }
226
227 #[allow(dead_code)]
229 pub(crate) async fn clear_collected<E: Event>(&self) {
230 let mut inner = self.inner.write().await;
231 let type_key = E::event_type().to_owned();
232 inner.collected.remove(&type_key);
233 }
234
235 pub async fn snapshot_state(&self) -> HashMap<String, StateValue> {
244 let inner = self.inner.read().await;
245 inner.state.clone()
246 }
247
248 pub async fn restore_state(&self, state: HashMap<String, StateValue>) {
253 let mut inner = self.inner.write().await;
254 inner.state = state;
255 }
256
257 pub async fn snapshot_collected(&self) -> HashMap<String, Vec<serde_json::Value>> {
261 let inner = self.inner.read().await;
262 inner.collected.clone()
263 }
264
265 pub async fn restore_collected(&self, collected: HashMap<String, Vec<serde_json::Value>>) {
270 let mut inner = self.inner.write().await;
271 inner.collected = collected;
272 }
273
274 pub async fn snapshot_metadata(&self) -> HashMap<String, serde_json::Value> {
278 let inner = self.inner.read().await;
279 inner.metadata.clone()
280 }
281
282 pub(crate) async fn restore_metadata(&self, metadata: HashMap<String, serde_json::Value>) {
287 let mut inner = self.inner.write().await;
288 inner.metadata = metadata;
289 }
290
291 pub async fn run_id(&self) -> Uuid {
302 let inner = self.inner.read().await;
303 inner
304 .metadata
305 .get("run_id")
306 .and_then(|v| v.as_str())
307 .and_then(|s| Uuid::parse_str(s).ok())
308 .expect("run_id must be set in workflow metadata")
309 }
310
311 pub(crate) async fn set_metadata(&self, key: &str, value: serde_json::Value) {
313 let mut inner = self.inner.write().await;
314 inner.metadata.insert(key.to_owned(), value);
315 }
316
317 pub(crate) async fn signal_stream_end(&self) {
323 self.write_event_to_stream(blazen_events::DynamicEvent {
324 event_type: "blazen::StreamEnd".to_owned(),
325 data: serde_json::Value::Null,
326 })
327 .await;
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Context` on fresh channels; the receiving halves are dropped
    /// because these tests never read from them.
    fn test_context() -> Context {
        let (event_tx, _) = mpsc::unbounded_channel();
        let (stream_tx, _) = broadcast::channel(16);
        Context::new(event_tx, stream_tx)
    }

    /// A value written with `set` reads back with the same type.
    #[tokio::test]
    async fn set_and_get_typed_value() {
        let context = test_context();
        context.set("counter", 42_u64).await;

        let stored: Option<u64> = context.get("counter").await;
        assert_eq!(stored, Some(42));
    }

    /// Asking for a type the stored JSON cannot deserialize into yields `None`.
    #[tokio::test]
    async fn get_wrong_type_returns_none() {
        let context = test_context();
        context.set("counter", 42_u64).await;

        let as_text: Option<String> = context.get("counter").await;
        assert!(as_text.is_none());
    }

    /// An unknown key yields `None`.
    #[tokio::test]
    async fn get_missing_key_returns_none() {
        let context = test_context();

        let absent: Option<u64> = context.get("nope").await;
        assert!(absent.is_none());
    }

    /// A UUID stored as the `"run_id"` metadata string is returned by `run_id`.
    #[tokio::test]
    async fn run_id_roundtrip() {
        let context = test_context();
        let expected = Uuid::new_v4();
        let as_json = serde_json::Value::String(expected.to_string());
        context.set_metadata("run_id", as_json).await;

        assert_eq!(context.run_id().await, expected);
    }

    /// `collect_events` returns nothing until enough events are buffered, then
    /// hands them back in insertion order.
    #[tokio::test]
    async fn collect_events_accumulation() {
        use blazen_events::StartEvent;

        let context = test_context();
        let first = StartEvent {
            data: serde_json::json!(1),
        };
        let second = StartEvent {
            data: serde_json::json!(2),
        };

        // One buffered event is not enough for a request of two.
        context.push_collected(&first).await;
        let too_early = context.collect_events::<StartEvent>(2).await;
        assert!(too_early.is_none());

        // With the second event buffered the request is satisfied.
        context.push_collected(&second).await;
        let ready = context.collect_events::<StartEvent>(2).await.unwrap();
        assert_eq!(ready.len(), 2);
        assert_eq!(ready[0].data, serde_json::json!(1));
        assert_eq!(ready[1].data, serde_json::json!(2));
    }

    /// A state snapshot is unaffected by later writes and can be restored.
    #[tokio::test]
    async fn snapshot_and_restore_state() {
        let context = test_context();
        context.set("name", "alice".to_string()).await;
        context.set("count", 10_u32).await;

        let saved = context.snapshot_state().await;
        assert_eq!(saved.len(), 2);
        assert_eq!(
            saved.get("name"),
            Some(&StateValue::Json(serde_json::json!("alice")))
        );
        assert_eq!(
            saved.get("count"),
            Some(&StateValue::Json(serde_json::json!(10)))
        );

        // Overwrite a key after the snapshot was taken.
        context.set("name", "bob".to_string()).await;
        let changed: Option<String> = context.get("name").await;
        assert_eq!(changed, Some("bob".to_string()));

        // Restoring brings back the snapshotted values.
        context.restore_state(saved).await;
        let restored_name: Option<String> = context.get("name").await;
        let restored_count: Option<u32> = context.get("count").await;
        assert_eq!(restored_name, Some("alice".to_string()));
        assert_eq!(restored_count, Some(10));
    }

    /// Bytes round-trip through `set_bytes` / `get_bytes` and are invisible to
    /// the typed JSON getter.
    #[tokio::test]
    async fn set_and_get_bytes() {
        let context = test_context();
        let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
        context.set_bytes("binary", payload.clone()).await;

        assert_eq!(context.get_bytes("binary").await, Some(payload));
        assert!(context.get::<String>("binary").await.is_none());
    }

    /// `get_bytes` ignores JSON values.
    #[tokio::test]
    async fn get_bytes_returns_none_for_json() {
        let context = test_context();
        context.set("key", "value".to_string()).await;

        assert!(context.get_bytes("key").await.is_none());
    }

    /// `get_bytes` on an unknown key yields `None`.
    #[tokio::test]
    async fn get_bytes_returns_none_for_missing_key() {
        let context = test_context();

        assert!(context.get_bytes("nope").await.is_none());
    }

    /// A pushed event shows up in the collected snapshot under its type id.
    #[tokio::test]
    async fn snapshot_collected() {
        use blazen_events::StartEvent;

        let context = test_context();
        let event = StartEvent {
            data: serde_json::json!("a"),
        };
        context.push_collected(&event).await;

        let buffers = context.snapshot_collected().await;
        assert_eq!(buffers.len(), 1);
        let start_buffer = buffers.get("blazen::StartEvent").unwrap();
        assert_eq!(start_buffer.len(), 1);
    }

    /// A native value round-trips unchanged through `set_value` / `get_value`.
    #[tokio::test]
    async fn set_value_and_get_value() {
        let context = test_context();
        let original = StateValue::native(vec![0x80, 0x04, 0x95]);
        context.set_value("pickled", original.clone()).await;

        let fetched = context.get_value("pickled").await;
        assert_eq!(fetched, Some(original));
    }

    /// `get_value` returns every variant as stored and `None` for unknown keys.
    #[tokio::test]
    async fn get_value_returns_all_variants() {
        let context = test_context();
        context.set("json_key", "hello".to_string()).await;
        context.set_bytes("bytes_key", vec![1, 2, 3]).await;
        context
            .set_value("native_key", StateValue::native(vec![4, 5, 6]))
            .await;

        let json = context.get_value("json_key").await.unwrap();
        let bytes = context.get_value("bytes_key").await.unwrap();
        let native = context.get_value("native_key").await.unwrap();
        assert!(json.is_json());
        assert!(bytes.is_bytes());
        assert!(native.is_native());
        assert!(context.get_value("missing").await.is_none());
    }

    /// The typed JSON getter ignores native values.
    #[tokio::test]
    async fn get_returns_none_for_native() {
        let context = test_context();
        context
            .set_value("key", StateValue::native(vec![0x80, 0x04]))
            .await;

        assert!(context.get::<String>("key").await.is_none());
    }

    /// `get_bytes` ignores native values.
    #[tokio::test]
    async fn get_bytes_returns_none_for_native() {
        let context = test_context();
        context
            .set_value("key", StateValue::native(vec![0x80, 0x04]))
            .await;

        assert!(context.get_bytes("key").await.is_none());
    }
}