use crate::integration::error::{
Error as IntegrationError, LockResultExt, Result as IntegrationResult,
};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::{Arc, RwLock};
use tracing::{trace, warn};
#[derive(Clone, Default, Debug)]
pub struct SharedContext {
data: Arc<RwLock<serde_json::Value>>,
}
impl SharedContext {
pub fn new() -> Self {
Self {
data: Arc::new(RwLock::new(serde_json::json!({}))),
}
}
pub fn get<T: DeserializeOwned>(&self, key: &str) -> IntegrationResult<Option<T>> {
trace!(key = key, "Attempting to get value from shared context");
let _data_guard = self.data.read().lock_err()?;
match &*_data_guard {
serde_json::Value::Object(map) => match map.get(key) {
Some(value) => {
match serde_json::from_value(value.clone()) {
Ok(deserialized) => Ok(Some(deserialized)),
Err(e) => {
warn!(key = key, error = %e, "Deserialization failed for shared context value");
Err(IntegrationError::from(e)) }
}
}
None => Ok(None), },
_ => Ok(None), }
}
pub fn set<T: Serialize>(&self, key: &str, value: T) -> IntegrationResult<()> {
trace!(key = key, "Attempting to set value in shared context");
let mut data_guard = self.data.write().lock_err()?;
let json_value = serde_json::to_value(value)?;
match &mut *data_guard {
serde_json::Value::Object(map) => {
map.insert(key.to_string(), json_value);
}
_ => {
warn!("Shared context was not an object, replacing with new object containing key: {}", key);
*data_guard = serde_json::json!({ key: json_value });
}
}
Ok(())
}
pub fn contains_key(&self, key: &str) -> IntegrationResult<bool> {
trace!(key = key, "Checking if key exists in shared context");
let _data_guard = self.data.read().lock_err()?;
match &*_data_guard {
serde_json::Value::Object(map) => Ok(map.contains_key(key)),
_ => Ok(false),
}
}
pub fn remove(&self, key: &str) -> IntegrationResult<Option<serde_json::Value>> {
trace!(key = key, "Attempting to remove key from shared context");
let mut _data_guard = self.data.write().lock_err()?;
match &mut *_data_guard {
serde_json::Value::Object(map) => Ok(map.remove(key)),
_ => Ok(None),
}
}
pub async fn dump(&self) -> IntegrationResult<serde_json::Value> {
let _data_guard = self.data.read().lock_err()?;
Ok(_data_guard.clone())
}
pub fn increment(&self, key: &str) -> IntegrationResult<()> {
trace!(key = key, "Attempting to increment value in shared context");
let current_value: Option<i32> = self.get(key)?;
let new_value = current_value.unwrap_or(0) + 1;
self.set(key, new_value)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::integration::error::{Error as IntegrationError, Result as IntegrationResult};
    use crate::Context;
    use crate::{Action, Event, Machine, MachineBuilder, State, Transition, TransitionType};
    use futures::FutureExt;
    use tokio::sync::RwLock;

    /// Builds the two machines used by the integration tests.
    ///
    /// Machine A goes `idle --EVENT_A--> done` through an external transition
    /// with no actions. Machine B stays in `waiting`. On `EVENT_B` it runs an
    /// internal-transition action that reads `status` and `counter` from
    /// `shared` and, when `status` is present, copies it into B's local context
    /// as `local_status_copy`.
    async fn create_machines(
        shared: SharedContext,
    ) -> (
        Machine<(), Event, String, ()>,
        Machine<Context, Event, String, ()>,
    ) {
        // --- Machine A -----------------------------------------------------
        let to_done = Transition::new(
            "idle".to_string(),
            Some("done".to_string()),
            Some(Event::from("EVENT_A")),
            None,
            vec![],
            TransitionType::External,
        );
        let machine_a = MachineBuilder::<(), Event, String, ()>::new(
            "machine_a".to_string(),
            "idle".to_string(),
        )
        .state(State::new("idle".to_string()))
        .state(State::new_final("done".to_string()))
        .transition(to_done)
        .build()
        .await
        .expect("Machine A async build failed");

        // --- Machine B -----------------------------------------------------
        let read_shared = Action::from_fn(move |local_ctx: Arc<RwLock<Context>>, _evt: &Event| {
            let ctx = shared.clone();
            async move {
                println!("Reader Action: Reading shared context");
                let status = ctx.get::<String>("status")?;
                let counter = ctx.get::<i32>("counter")?;
                println!("Reader: Read status: {:?}, counter: {:?}", status, counter);
                if let Some(s) = status {
                    local_ctx.write().await.set("local_status_copy", s)?;
                }
                Ok(())
            }
            .boxed()
        });

        let stay_waiting = Transition::new(
            "waiting".to_string(),
            None::<String>,
            Some(Event::from("EVENT_B")),
            None,
            vec![read_shared],
            TransitionType::Internal,
        );
        let mut waiting = State::new("waiting".to_string());
        waiting.add_transition("EVENT_B".to_string(), stay_waiting);

        // NOTE(review): machine A passes "machine_a" as the first builder
        // argument, but machine B passes "waiting". Confirm whether
        // "machine_b" was intended.
        let machine_b = MachineBuilder::<Context, Event, String, ()>::new(
            "waiting".to_string(),
            "waiting".to_string(),
        )
        .state(waiting)
        .state(State::new_final("processed".to_string()))
        .build()
        .await
        .expect("Machine B async build failed");

        (machine_a, machine_b)
    }

    #[tokio::test]
    async fn test_context_sharing_flow() -> IntegrationResult<()> {
        let shared = SharedContext::new();
        let (mut machine_a, mut machine_b) = create_machines(shared.clone()).await;
        let idle = "idle".to_string();
        let done = "done".to_string();

        assert!(machine_a.is_in(&idle));
        assert!(!machine_a.is_in(&done));

        let handled = machine_a.send(Event::from("EVENT_A")).await?;
        assert!(handled, "Machine A should handle EVENT_A");
        assert!(!machine_a.is_in(&idle));
        assert!(machine_a.is_in(&done), "Machine A should be in 'done' state");

        // Machine A has no actions, so nothing has written to the shared context yet.
        assert_eq!(shared.get::<String>("status")?, None);
        assert_eq!(shared.get::<i32>("counter")?, None);

        machine_b.send(Event::from("EVENT_B")).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_complex_assertions() -> IntegrationResult<()> {
        let shared = SharedContext::new();
        let (mut machine_a, _machine_b) = create_machines(shared.clone()).await;
        machine_a.send(Event::from("EVENT_A")).await?;

        shared.set("local_status", "active")?;
        shared.set("local_counter", 1i64)?;

        let status = shared.get::<String>("local_status")?;
        assert!(status.is_some_and(|s| s == "active"));
        let counter = shared.get::<i64>("local_counter")?;
        assert!(counter.is_some_and(|c| c == 1));
        Ok::<(), IntegrationError>(())
    }

    #[tokio::test]
    async fn test_contains_remove() -> IntegrationResult<()> {
        let shared = SharedContext::new();
        shared.set("key1", "value1")?;
        shared.set("key2", 123)?;

        for present in ["key1", "key2"] {
            assert!(shared.contains_key(present)?);
        }
        assert!(!shared.contains_key("key3")?);

        assert_eq!(shared.remove("key1")?, Some(serde_json::json!("value1")));
        assert!(!shared.contains_key("key1")?);

        assert!(shared.remove("key3")?.is_none());
        Ok(())
    }
}