use std::collections::HashMap;
use std::sync::Arc;
use blazen_events::{AnyEvent, Event, EventEnvelope};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::{RwLock, broadcast, mpsc};
use uuid::Uuid;
use crate::value::{BytesWrapper, StateValue};
/// Key/value store holding a context's user state; each entry is either JSON
/// (written by `Context::set`) or raw bytes (written by `Context::set_bytes`).
type StateMap = HashMap<String, StateValue>;
/// Everything a `Context` guards behind its `RwLock`.
///
/// Field order is also drop order; the two channel senders are dropped in the
/// order declared here.
struct ContextInner {
/// User state read/written by `get`/`set` and `get_bytes`/`set_bytes`.
state: StateMap,
/// Sender that `Context::send_event` pushes wrapped events into.
event_tx: mpsc::UnboundedSender<EventEnvelope>,
/// Broadcast sender that `Context::write_event_to_stream` publishes to.
stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
/// Events buffered per event-type string (as JSON), drained by `Context::collect_events`.
collected: HashMap<String, Vec<serde_json::Value>>,
/// Free-form metadata; `Context::run_id` reads the `"run_id"` entry from here.
metadata: HashMap<String, serde_json::Value>,
}
/// Shared handle to a workflow run's state, event channels and metadata.
///
/// Cloning a `Context` is cheap: all clones share the same inner state through an
/// `Arc<RwLock<..>>`, so a write made through one clone is visible through every other.
#[derive(Clone)]
pub struct Context {
    inner: Arc<RwLock<ContextInner>>,
}

impl Context {
    /// Creates a context wired to the given event and stream channels.
    ///
    /// State, collected-event buffers and metadata all start empty.
    pub(crate) fn new(
        event_tx: mpsc::UnboundedSender<EventEnvelope>,
        stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
    ) -> Self {
        Self {
            inner: Arc::new(RwLock::new(ContextInner {
                state: HashMap::new(),
                event_tx,
                stream_tx,
                collected: HashMap::new(),
                metadata: HashMap::new(),
            })),
        }
    }

    /// Stores `value` under `key` as JSON, replacing whatever (JSON or bytes) was there.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized to JSON.
    pub async fn set<T: Serialize + Send + Sync + 'static>(&self, key: &str, value: T) {
        // Serialize before locking so the write lock only covers the insert.
        let json_value =
            serde_json::to_value(&value).expect("Context::set: value must be JSON-serializable");
        let mut inner = self.inner.write().await;
        inner
            .state
            .insert(key.to_owned(), StateValue::Json(json_value));
    }

    /// Reads the JSON value stored under `key` and deserializes it as `T`.
    ///
    /// Returns `None` if the key is missing, holds bytes (from `set_bytes`), or its
    /// JSON does not deserialize as `T`.
    pub async fn get<T: DeserializeOwned + Send + Sync + Clone + 'static>(
        &self,
        key: &str,
    ) -> Option<T> {
        let inner = self.inner.read().await;
        match inner.state.get(key)? {
            // `&serde_json::Value` is itself a deserializer, so there is no need to
            // deep-clone the stored JSON tree just to read it.
            StateValue::Json(v) => T::deserialize(v).ok(),
            StateValue::Bytes(_) => None,
        }
    }

    /// Stores raw bytes under `key`, replacing whatever (JSON or bytes) was there.
    pub async fn set_bytes(&self, key: &str, data: Vec<u8>) {
        let mut inner = self.inner.write().await;
        inner
            .state
            .insert(key.to_owned(), StateValue::Bytes(BytesWrapper(data)));
    }

    /// Returns a copy of the bytes stored under `key`.
    ///
    /// Returns `None` if the key is missing or holds a JSON value.
    pub async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
        let inner = self.inner.read().await;
        match inner.state.get(key)? {
            StateValue::Bytes(b) => Some(b.0.clone()),
            StateValue::Json(_) => None,
        }
    }

    /// Wraps `event` in an [`EventEnvelope`] (built with `None` as its optional second
    /// argument) and queues it on the event channel.
    ///
    /// Best-effort: if the channel's receiver has been dropped, the send error is
    /// ignored and the event is discarded.
    pub async fn send_event<E: Event + Serialize>(&self, event: E) {
        // Build the envelope before locking so the lock only guards the send.
        let envelope = EventEnvelope::new(Box::new(event), None);
        let inner = self.inner.read().await;
        let _ = inner.event_tx.send(envelope);
    }

    /// Publishes `event` on the broadcast stream.
    ///
    /// Best-effort: a broadcast send fails when there are no active receivers; that
    /// error is ignored.
    pub async fn write_event_to_stream<E: Event + Serialize>(&self, event: E) {
        let boxed: Box<dyn AnyEvent> = Box::new(event);
        let inner = self.inner.read().await;
        let _ = inner.stream_tx.send(boxed);
    }

    /// Removes and returns the oldest `expected_count` buffered events of type `E`,
    /// or `None` if fewer than that many are buffered (nothing is removed then).
    ///
    /// Requesting zero events always succeeds with an empty `Vec`. Buffered entries
    /// whose JSON fails to deserialize as `E` are still removed but are skipped, so
    /// the returned `Vec` can be shorter than `expected_count`.
    pub async fn collect_events<E: Event + DeserializeOwned>(
        &self,
        expected_count: usize,
    ) -> Option<Vec<E>> {
        let mut inner = self.inner.write().await;
        // Plain lookup instead of `entry(..).or_default()`: an unsuccessful poll must
        // not insert an empty buffer, which would otherwise leak into
        // `snapshot_collected`.
        match inner.collected.get_mut(E::event_type()) {
            Some(queue) if queue.len() >= expected_count => Some(
                queue
                    .drain(..expected_count)
                    .filter_map(|json_val| serde_json::from_value::<E>(json_val).ok())
                    .collect(),
            ),
            // No buffer for this type yet: a request for zero events is still satisfied.
            None if expected_count == 0 => Some(Vec::new()),
            _ => None,
        }
    }

    /// Appends `event` (as JSON) to the buffer keyed by its event-type id.
    pub(crate) async fn push_collected(&self, event: &dyn AnyEvent) {
        // Serialize outside the lock; only the push needs exclusive access.
        let type_key = event.event_type_id().to_owned();
        let json_val = event.to_json();
        let mut inner = self.inner.write().await;
        inner.collected.entry(type_key).or_default().push(json_val);
    }

    /// Discards every buffered event of type `E`.
    // `dead_code` is allowed because nothing in the crate currently calls this.
    #[allow(dead_code)]
    pub(crate) async fn clear_collected<E: Event>(&self) {
        let mut inner = self.inner.write().await;
        inner.collected.remove(E::event_type());
    }

    /// Returns a copy of the whole user-state map.
    pub async fn snapshot_state(&self) -> HashMap<String, StateValue> {
        let inner = self.inner.read().await;
        inner.state.clone()
    }

    /// Replaces the whole user-state map with `state`.
    pub async fn restore_state(&self, state: HashMap<String, StateValue>) {
        let mut inner = self.inner.write().await;
        inner.state = state;
    }

    /// Returns a copy of all buffered (not yet collected) events, keyed by event type.
    pub async fn snapshot_collected(&self) -> HashMap<String, Vec<serde_json::Value>> {
        let inner = self.inner.read().await;
        inner.collected.clone()
    }

    /// Replaces all buffered events with `collected`.
    pub async fn restore_collected(&self, collected: HashMap<String, Vec<serde_json::Value>>) {
        let mut inner = self.inner.write().await;
        inner.collected = collected;
    }

    /// Returns a copy of the metadata map.
    pub async fn snapshot_metadata(&self) -> HashMap<String, serde_json::Value> {
        let inner = self.inner.read().await;
        inner.metadata.clone()
    }

    /// Replaces the metadata map with `metadata`.
    pub(crate) async fn restore_metadata(&self, metadata: HashMap<String, serde_json::Value>) {
        let mut inner = self.inner.write().await;
        inner.metadata = metadata;
    }

    /// Returns the run id stored as a UUID string under the `"run_id"` metadata key.
    ///
    /// # Panics
    ///
    /// Panics if `"run_id"` is absent, not a string, or not a valid UUID.
    pub async fn run_id(&self) -> Uuid {
        let inner = self.inner.read().await;
        inner
            .metadata
            .get("run_id")
            .and_then(|v| v.as_str())
            .and_then(|s| Uuid::parse_str(s).ok())
            .expect("run_id must be set in workflow metadata")
    }

    /// Inserts or overwrites a single metadata entry.
    pub(crate) async fn set_metadata(&self, key: &str, value: serde_json::Value) {
        let mut inner = self.inner.write().await;
        inner.metadata.insert(key.to_owned(), value);
    }

    /// Publishes a `DynamicEvent` with type `"blazen::StreamEnd"` and `null` data on the
    /// stream, which stream consumers can treat as an end marker.
    pub(crate) async fn signal_stream_end(&self) {
        self.write_event_to_stream(blazen_events::DynamicEvent {
            event_type: "blazen::StreamEnd".to_owned(),
            data: serde_json::Value::Null,
        })
        .await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh context whose channel receivers are dropped immediately;
    /// sends are best-effort, so that is fine for state-only tests.
    fn test_context() -> Context {
        let (event_tx, _event_rx) = mpsc::unbounded_channel();
        let (stream_tx, _stream_rx) = broadcast::channel(16);
        Context::new(event_tx, stream_tx)
    }

    #[tokio::test]
    async fn set_and_get_typed_value() {
        let ctx = test_context();
        ctx.set("counter", 42_u64).await;
        assert_eq!(ctx.get::<u64>("counter").await, Some(42));
    }

    #[tokio::test]
    async fn get_wrong_type_returns_none() {
        let ctx = test_context();
        ctx.set("counter", 42_u64).await;
        assert_eq!(ctx.get::<String>("counter").await, None);
    }

    #[tokio::test]
    async fn get_missing_key_returns_none() {
        let ctx = test_context();
        assert_eq!(ctx.get::<u64>("nope").await, None);
    }

    #[tokio::test]
    async fn get_compound_value() {
        let ctx = test_context();
        ctx.set("names", vec!["a".to_string(), "b".to_string()]).await;
        assert_eq!(
            ctx.get::<Vec<String>>("names").await,
            Some(vec!["a".to_string(), "b".to_string()])
        );
    }

    #[tokio::test]
    async fn run_id_roundtrip() {
        let ctx = test_context();
        let id = Uuid::new_v4();
        ctx.set_metadata("run_id", serde_json::Value::String(id.to_string()))
            .await;
        assert_eq!(ctx.run_id().await, id);
    }

    #[tokio::test]
    async fn collect_events_accumulation() {
        use blazen_events::StartEvent;
        let ctx = test_context();
        let e1 = StartEvent {
            data: serde_json::json!(1),
        };
        let e2 = StartEvent {
            data: serde_json::json!(2),
        };
        ctx.push_collected(&e1).await;
        assert!(ctx.collect_events::<StartEvent>(2).await.is_none());
        ctx.push_collected(&e2).await;
        let events = ctx.collect_events::<StartEvent>(2).await.unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, serde_json::json!(1));
        assert_eq!(events[1].data, serde_json::json!(2));
    }

    #[tokio::test]
    async fn collect_events_leaves_surplus_buffered() {
        use blazen_events::StartEvent;
        let ctx = test_context();
        for i in 1..=3 {
            let event = StartEvent {
                data: serde_json::json!(i),
            };
            ctx.push_collected(&event).await;
        }
        let first = ctx.collect_events::<StartEvent>(2).await.unwrap();
        assert_eq!(first.len(), 2);
        assert_eq!(first[0].data, serde_json::json!(1));
        assert_eq!(first[1].data, serde_json::json!(2));
        // Only one event remains, so asking for two must fail without consuming it.
        assert!(ctx.collect_events::<StartEvent>(2).await.is_none());
        let rest = ctx.collect_events::<StartEvent>(1).await.unwrap();
        assert_eq!(rest.len(), 1);
        assert_eq!(rest[0].data, serde_json::json!(3));
    }

    #[tokio::test]
    async fn collect_zero_events_succeeds_immediately() {
        use blazen_events::StartEvent;
        let ctx = test_context();
        let events = ctx.collect_events::<StartEvent>(0).await.unwrap();
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn clear_collected_discards_buffer() {
        use blazen_events::StartEvent;
        let ctx = test_context();
        let event = StartEvent {
            data: serde_json::json!("gone"),
        };
        ctx.push_collected(&event).await;
        ctx.clear_collected::<StartEvent>().await;
        assert!(ctx.collect_events::<StartEvent>(1).await.is_none());
    }

    #[tokio::test]
    async fn snapshot_and_restore_state() {
        let ctx = test_context();
        ctx.set("name", "alice".to_string()).await;
        ctx.set("count", 10_u32).await;
        let snap = ctx.snapshot_state().await;
        assert_eq!(snap.len(), 2);
        assert_eq!(
            snap.get("name").unwrap(),
            &StateValue::Json(serde_json::json!("alice"))
        );
        assert_eq!(
            snap.get("count").unwrap(),
            &StateValue::Json(serde_json::json!(10))
        );
        ctx.set("name", "bob".to_string()).await;
        assert_eq!(ctx.get::<String>("name").await, Some("bob".to_string()));
        ctx.restore_state(snap).await;
        assert_eq!(ctx.get::<String>("name").await, Some("alice".to_string()));
        assert_eq!(ctx.get::<u32>("count").await, Some(10));
    }

    #[tokio::test]
    async fn set_and_get_bytes() {
        let ctx = test_context();
        let data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        ctx.set_bytes("binary", data.clone()).await;
        assert_eq!(ctx.get_bytes("binary").await, Some(data));
        assert_eq!(ctx.get::<String>("binary").await, None);
    }

    #[tokio::test]
    async fn set_replaces_bytes_under_same_key() {
        let ctx = test_context();
        ctx.set_bytes("slot", vec![1, 2, 3]).await;
        ctx.set("slot", 7_i32).await;
        assert_eq!(ctx.get_bytes("slot").await, None);
        assert_eq!(ctx.get::<i32>("slot").await, Some(7));
    }

    #[tokio::test]
    async fn get_bytes_returns_none_for_json() {
        let ctx = test_context();
        ctx.set("key", "value".to_string()).await;
        assert_eq!(ctx.get_bytes("key").await, None);
    }

    #[tokio::test]
    async fn get_bytes_returns_none_for_missing_key() {
        let ctx = test_context();
        assert_eq!(ctx.get_bytes("nope").await, None);
    }

    #[tokio::test]
    async fn snapshot_collected() {
        use blazen_events::StartEvent;
        let ctx = test_context();
        let e1 = StartEvent {
            data: serde_json::json!("a"),
        };
        ctx.push_collected(&e1).await;
        let snap = ctx.snapshot_collected().await;
        assert_eq!(snap.len(), 1);
        let start_events = snap.get("blazen::StartEvent").unwrap();
        assert_eq!(start_events.len(), 1);
    }

    #[tokio::test]
    async fn restore_collected_feeds_collect_events() {
        use blazen_events::StartEvent;
        let source = test_context();
        let event = StartEvent {
            data: serde_json::json!("x"),
        };
        source.push_collected(&event).await;
        let snap = source.snapshot_collected().await;

        let target = test_context();
        target.restore_collected(snap).await;
        let events = target.collect_events::<StartEvent>(1).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, serde_json::json!("x"));
    }

    #[tokio::test]
    async fn snapshot_and_restore_metadata() {
        let source = test_context();
        source.set_metadata("answer", serde_json::json!(42)).await;
        let snap = source.snapshot_metadata().await;
        assert_eq!(snap.get("answer"), Some(&serde_json::json!(42)));

        let target = test_context();
        target.restore_metadata(snap.clone()).await;
        assert_eq!(target.snapshot_metadata().await, snap);
    }
}