use crate::{Actor, ActorBehavior, Message, Port};
use anyhow::{Error, Result};
use reflow_actor::{message::EncodableValue, ActorContext};
use reflow_actor_macro::actor;
use serde_json::{json, Value};
use std::collections::HashMap;
/// Limit on how many chained transient (`always`) transitions are followed after a single
/// event; reaching it logs a warning and stops, so cycles between `always` states cannot loop
/// forever.
const MAX_TRANSIENT_DEPTH: usize = 10;
/// Event-driven finite state machine actor.
///
/// Configuration keys read from the actor config:
/// - `states`: map of state name → definition (`on`, `always`, `entry`, `exit`, `type`).
/// - `guards`: map of guard name → rule (`field`, `operator`, `value`).
/// - `initial`: starting state name (defaults to `"initial"`).
/// - `context`: object of initial values for the `_fsm_ctx` pool.
/// - `dt`: amount added to the `_timeout` timer per `tick` (defaults to `1/30`).
///
/// Inports:
/// - `control`: `"reset"` restores the initial state and configured context values;
///   `"set:<state>"` jumps straight to `<state>` without running entry/exit actions.
/// - `data`: an object whose keys are written into `_fsm_ctx`; produces no output.
/// - `event`: an event name string, or an object `{ "type": ..., "payload": ... }`.
/// - `tick`: advances the current state's `_timeout` timer when no event is present.
///
/// Persistent state is kept in the `_fsm` pool (`current`, `timeout_elapsed`) and the
/// `_fsm_ctx` pool (context values used by guards and `assign` actions).
#[actor(
    FsmActor,
    inports::<100>(event, tick, control, data),
    outports::<50>(state, transition, context, done, emit),
    state(MemoryState)
)]
pub async fn fsm_actor(ctx: ActorContext) -> Result<HashMap<String, Message>, Error> {
    let payload = ctx.get_payload();
    let config = ctx.get_config_hashmap();
    let states = config
        .get("states")
        .cloned()
        .unwrap_or(Value::Object(Default::default()));
    let guards_def = config
        .get("guards")
        .cloned()
        .unwrap_or(Value::Object(Default::default()));
    let initial = config
        .get("initial")
        .and_then(|v| v.as_str())
        .unwrap_or("initial")
        .to_string();

    // Seed the context pool from the configured defaults while it is still empty.
    if ctx.get_pool("_fsm_ctx").into_iter().next().is_none() {
        seed_fsm_context(&ctx, config.get("context"));
    }

    // Restore the current state; on the very first run persist and use `initial`.
    let current_state = ctx
        .get_pool("_fsm")
        .into_iter()
        .find(|(k, _)| k == "current")
        .and_then(|(_, v)| v.as_str().map(|s| s.to_string()))
        .unwrap_or_else(|| {
            ctx.pool_upsert("_fsm", "current", json!(initial));
            initial.clone()
        });

    // Control commands short-circuit normal event handling; unknown commands fall through.
    if let Some(Message::String(s)) = payload.get("control") {
        let cmd = s.to_string();
        if cmd == "reset" {
            ctx.pool_upsert("_fsm", "current", json!(initial));
            ctx.pool_upsert("_fsm", "timeout_elapsed", json!(0.0));
            seed_fsm_context(&ctx, config.get("context"));
            return Ok(HashMap::from([(
                "state".to_string(),
                Message::String(initial.into()),
            )]));
        }
        if let Some(target) = cmd.strip_prefix("set:") {
            ctx.pool_upsert("_fsm", "current", json!(target));
            ctx.pool_upsert("_fsm", "timeout_elapsed", json!(0.0));
            return Ok(HashMap::from([(
                "state".to_string(),
                Message::String(target.to_string().into()),
            )]));
        }
    }

    // `data` only updates the context; it never triggers a transition by itself.
    if let Some(Message::Object(obj)) = payload.get("data") {
        let v: Value = obj.as_ref().clone().into();
        if let Some(map) = v.as_object() {
            for (k, val) in map {
                ctx.pool_upsert("_fsm_ctx", k, val.clone());
            }
        }
        return Ok(HashMap::new());
    }

    let (event_name, event_payload) = resolve_event(&payload);
    let event = match event_name {
        Some(name) => name,
        // A bare tick only matters for states that define an `_timeout` transition.
        None if payload.contains_key("tick") => {
            // Process-wide counter: shared by every FsmActor instance, used only to
            // throttle this diagnostic to one line per 100 ticks.
            static FSM_TICK_COUNT: std::sync::atomic::AtomicU32 =
                std::sync::atomic::AtomicU32::new(0);
            let tc = FSM_TICK_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            if tc % 100 == 0 {
                eprintln!(
                    "[fsm:{}] tick={tc} state={current_state}",
                    ctx.get_config().get_node_id()
                );
            }
            let Some(timeout_trans) = states
                .get(&current_state)
                .and_then(|s| s.get("on"))
                .and_then(|on| on.get("_timeout"))
            else {
                return Ok(HashMap::new());
            };
            let delay = timeout_trans
                .get("delay")
                .and_then(|v| v.as_f64())
                .unwrap_or(1.0);
            let dt = config
                .get("dt")
                .and_then(|v| v.as_f64())
                .unwrap_or(1.0 / 30.0);
            let elapsed: f64 = ctx
                .get_pool("_fsm")
                .into_iter()
                .find(|(k, _)| k == "timeout_elapsed")
                .and_then(|(_, v)| v.as_f64())
                .unwrap_or(0.0);
            let new_elapsed = elapsed + dt;
            if new_elapsed >= delay {
                ctx.pool_upsert("_fsm", "timeout_elapsed", json!(0.0));
                "_timeout".to_string()
            } else {
                ctx.pool_upsert("_fsm", "timeout_elapsed", json!(new_elapsed));
                return Ok(HashMap::new());
            }
        }
        None => return Ok(HashMap::new()),
    };

    let eval_scope = build_eval_scope(&event_payload, &ctx);
    // An unknown current state has no transitions to take.
    let Some(state_def) = states.get(&current_state) else {
        return Ok(HashMap::new());
    };
    match find_transition(state_def, &event, &guards_def, &eval_scope) {
        Some(target) => execute_transition(
            &ctx,
            &states,
            &guards_def,
            &current_state,
            &target,
            &event,
            &eval_scope,
        ),
        None => Ok(HashMap::new()),
    }
}

/// Upserts every key of the configured initial context (when it is a JSON object) into the
/// `_fsm_ctx` pool; keys not present in the configuration are left untouched.
fn seed_fsm_context(ctx: &ActorContext, initial_ctx: Option<&Value>) {
    if let Some(Value::Object(init_ctx)) = initial_ctx {
        for (k, v) in init_ctx {
            ctx.pool_upsert("_fsm_ctx", k, v.clone());
        }
    }
}
/// Extracts the event name and its payload from the `event` inport.
///
/// - A string message is the event name; the payload is `Null`.
/// - An object message is read as `{ "type": <name>, "payload": <value> }`. A missing or
///   non-string `type` yields no name, and a missing `payload` yields `Null`.
/// - Any other message kind, or no `event` message at all, yields `(None, Null)`.
fn resolve_event(payload: &HashMap<String, Message>) -> (Option<String>, Value) {
    match payload.get("event") {
        Some(Message::String(name)) => (Some(name.to_string()), Value::Null),
        Some(Message::Object(obj)) => {
            let event: Value = obj.as_ref().clone().into();
            let name = event
                .get("type")
                .and_then(Value::as_str)
                .map(str::to_owned);
            let data = event.get("payload").cloned().unwrap_or_default();
            (name, data)
        }
        _ => (None, Value::Null),
    }
}
/// Builds the variable scope that guards are evaluated against.
///
/// Starts from every entry in the `_fsm_ctx` pool, then overlays the keys of
/// `event_payload` when it is a JSON object (event keys win on collision). A non-object
/// payload contributes nothing.
fn build_eval_scope(event_payload: &Value, ctx: &ActorContext) -> Value {
    let mut vars: serde_json::Map<String, Value> =
        ctx.get_pool("_fsm_ctx").into_iter().collect();
    if let Value::Object(fields) = event_payload {
        vars.extend(fields.iter().map(|(k, v)| (k.clone(), v.clone())));
    }
    Value::Object(vars)
}
/// Evaluates the guard named `guard_name` from `guards_def` against `scope`.
///
/// A guard is `{ "field": <key>, "operator": <op>, "value": <rule> }`:
/// - `field` selects `scope[field]`; when absent, the whole `scope` is the subject.
/// - `operator` defaults to `"is"`.
///
/// Rules:
/// - An undefined guard name passes (returns `true`).
/// - A `field` missing from `scope` only satisfies `"empty"`.
/// - Equality-based operators (`is`, `is_not`, `contains`/`not_contains` on arrays,
///   `in`/`not_in`) compare numbers by value, so `3` matches `3.0`. This matters because
///   numeric `assign` results are stored as floats.
/// - Unknown operators fail.
fn evaluate_guard(guard_name: &str, guards_def: &Value, scope: &Value) -> bool {
    let Some(guard) = guards_def.get(guard_name) else {
        return true;
    };
    let operator = guard
        .get("operator")
        .and_then(|v| v.as_str())
        .unwrap_or("is");
    let rule_value = guard.get("value");
    let field_value = match guard.get("field").and_then(|v| v.as_str()) {
        Some(field_name) => match scope.get(field_name) {
            Some(v) => v,
            None => return operator == "empty",
        },
        None => scope,
    };
    match operator {
        "is" => matches!(rule_value, Some(r) if guard_values_equal(r, field_value)),
        "is_not" => !matches!(rule_value, Some(r) if guard_values_equal(r, field_value)),
        "contains" => match (field_value, rule_value) {
            (Value::String(s), Some(Value::String(needle))) => s.contains(needle.as_str()),
            (Value::Array(arr), Some(val)) => guard_array_contains(arr, val),
            _ => false,
        },
        "not_contains" => match (field_value, rule_value) {
            (Value::String(s), Some(Value::String(needle))) => !s.contains(needle.as_str()),
            (Value::Array(arr), Some(val)) => !guard_array_contains(arr, val),
            _ => true,
        },
        "greater_than" | "gt" => guard_compare_numbers(field_value, rule_value, |a, b| a > b),
        "less_than" | "lt" => guard_compare_numbers(field_value, rule_value, |a, b| a < b),
        "greater_equal" | "gte" => {
            guard_compare_numbers(field_value, rule_value, |a, b| a >= b)
        }
        "less_equal" | "lte" => guard_compare_numbers(field_value, rule_value, |a, b| a <= b),
        "empty" => guard_is_empty(field_value),
        "not_empty" => !guard_is_empty(field_value),
        "in" => match rule_value {
            Some(Value::Array(arr)) => guard_array_contains(arr, field_value),
            _ => false,
        },
        "not_in" => match rule_value {
            Some(Value::Array(arr)) => !guard_array_contains(arr, field_value),
            _ => true,
        },
        _ => false,
    }
}

/// JSON equality that treats numbers by value (`3 == 3.0`); all other kinds use `==`.
fn guard_values_equal(a: &Value, b: &Value) -> bool {
    match (a, b) {
        (Value::Number(x), Value::Number(y)) => {
            x == y || matches!((x.as_f64(), y.as_f64()), (Some(p), Some(q)) if p == q)
        }
        _ => a == b,
    }
}

/// `true` when any element of `arr` equals `needle` under [`guard_values_equal`].
fn guard_array_contains(arr: &[Value], needle: &Value) -> bool {
    arr.iter().any(|item| guard_values_equal(item, needle))
}

/// Applies `cmp` when both the field and the rule value are numbers; otherwise `false`.
fn guard_compare_numbers(
    field_value: &Value,
    rule_value: Option<&Value>,
    cmp: fn(f64, f64) -> bool,
) -> bool {
    match (field_value, rule_value) {
        (Value::Number(a), Some(Value::Number(b))) => {
            cmp(a.as_f64().unwrap_or(0.0), b.as_f64().unwrap_or(0.0))
        }
        _ => false,
    }
}

/// `null`, `""`, `[]` and `{}` are empty; numbers and booleans never are.
fn guard_is_empty(value: &Value) -> bool {
    match value {
        Value::Null => true,
        Value::String(s) => s.is_empty(),
        Value::Array(arr) => arr.is_empty(),
        Value::Object(obj) => obj.is_empty(),
        _ => false,
    }
}
fn find_transition(
state_def: &Value,
event: &str,
guards_def: &Value,
scope: &Value,
) -> Option<String> {
let on = state_def.get("on")?;
let trans = on.get(event)?;
if let Some(target) = trans.as_str() {
return Some(target.to_string());
}
if let Some(arr) = trans.as_array() {
for t in arr {
let guard = t.get("guard").and_then(|g| g.as_str());
if let Some(guard_name) = guard {
if evaluate_guard(guard_name, guards_def, scope) {
return t
.get("target")
.and_then(|t| t.as_str())
.map(|s| s.to_string());
}
} else {
return t
.get("target")
.and_then(|t| t.as_str())
.map(|s| s.to_string());
}
}
return None;
}
if let Some(guard_name) = trans.get("guard").and_then(|g| g.as_str()) {
if !evaluate_guard(guard_name, guards_def, scope) {
return None;
}
}
trans
.get("target")
.and_then(|t| t.as_str())
.map(|s| s.to_string())
}
fn find_transient_transition(
state_def: &Value,
guards_def: &Value,
scope: &Value,
) -> Option<String> {
let always = state_def.get("always")?;
if let Some(arr) = always.as_array() {
for t in arr {
let guard = t.get("guard").and_then(|g| g.as_str());
if let Some(guard_name) = guard {
if evaluate_guard(guard_name, guards_def, scope) {
return t
.get("target")
.and_then(|t| t.as_str())
.map(|s| s.to_string());
}
} else {
return t
.get("target")
.and_then(|t| t.as_str())
.map(|s| s.to_string());
}
}
}
if always.is_object() {
let guard = always.get("guard").and_then(|g| g.as_str());
if let Some(guard_name) = guard {
if !evaluate_guard(guard_name, guards_def, scope) {
return None;
}
}
return always
.get("target")
.and_then(|t| t.as_str())
.map(|s| s.to_string());
}
None
}
fn execute_transition(
ctx: &ActorContext,
states: &Value,
guards_def: &Value,
from: &str,
to: &str,
event: &str,
scope: &Value,
) -> Result<HashMap<String, Message>, Error> {
let mut out = HashMap::new();
let mut emit_values = serde_json::Map::new();
let mut current = from.to_string();
let mut target = to.to_string();
let mut depth = 0;
loop {
if let Some(exit_action) = states.get(¤t).and_then(|s| s.get("exit")) {
apply_action(exit_action, ctx, &mut emit_values);
}
ctx.pool_upsert("_fsm", "current", json!(target));
ctx.pool_upsert("_fsm", "timeout_elapsed", json!(0.0));
if let Some(entry_action) = states.get(&target).and_then(|s| s.get("entry")) {
apply_action(entry_action, ctx, &mut emit_values);
}
let is_final = states
.get(&target)
.and_then(|s| s.get("type"))
.and_then(|t| t.as_str())
== Some("final");
if is_final {
out.insert("done".to_string(), Message::Flow);
}
let scope = build_eval_scope(&Value::Null, ctx);
let transient = states
.get(&target)
.and_then(|s| find_transient_transition(s, guards_def, &scope));
if let Some(next) = transient {
depth += 1;
if depth >= MAX_TRANSIENT_DEPTH {
eprintln!(
"[FsmActor] max transient depth ({}) reached at state '{}'",
MAX_TRANSIENT_DEPTH, target
);
break;
}
current = target;
target = next;
continue;
}
break;
}
let final_state = ctx
.get_pool("_fsm")
.into_iter()
.find(|(k, _)| k == "current")
.and_then(|(_, v)| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| target.clone());
out.insert(
"state".to_string(),
Message::String(final_state.clone().into()),
);
out.insert(
"transition".to_string(),
Message::object(EncodableValue::from(json!({
"from": from,
"to": final_state,
"event": event,
}))),
);
if !emit_values.is_empty() {
out.insert(
"emit".to_string(),
Message::object(EncodableValue::from(json!({
"id": final_state,
"data": Value::Object(emit_values.clone()),
}))),
);
out.insert(
"data".to_string(),
Message::object(EncodableValue::from(Value::Object(emit_values))),
);
}
let should_emit_ctx = states
.get(&final_state)
.and_then(|s| s.get("entry"))
.and_then(|a| a.get("emit_context"))
.and_then(|v| v.as_bool())
.unwrap_or(false);
if should_emit_ctx {
let ctx_map: serde_json::Map<String, Value> =
ctx.get_pool("_fsm_ctx").into_iter().collect();
out.insert(
"context".to_string(),
Message::object(EncodableValue::from(Value::Object(ctx_map))),
);
}
Ok(out)
}
fn apply_action(
action: &Value,
ctx: &ActorContext,
emit_values: &mut serde_json::Map<String, Value>,
) {
if let Some(obj) = action.get("emit").and_then(|v| v.as_object()) {
for (k, v) in obj {
emit_values.insert(k.clone(), v.clone());
}
}
if let Some(obj) = action.get("assign").and_then(|v| v.as_object()) {
for (field, mutation) in obj {
let current: f64 = ctx
.get_pool("_fsm_ctx")
.into_iter()
.find(|(k, _)| k == field)
.and_then(|(_, v)| v.as_f64())
.unwrap_or(0.0);
if let Some(op_obj) = mutation.as_object() {
let op = op_obj.get("op").and_then(|v| v.as_str()).unwrap_or("set");
let val = op_obj.get("value").and_then(|v| v.as_f64()).unwrap_or(1.0);
let new_val = match op {
"set" => val,
"increment" => current + val,
"decrement" => current - val,
"add" => current + val,
"multiply" => current * val,
"toggle" => {
if current == 0.0 {
1.0
} else {
0.0
}
}
_ => current,
};
ctx.pool_upsert("_fsm_ctx", field, json!(new_val));
} else {
ctx.pool_upsert("_fsm_ctx", field, mutation.clone());
}
}
}
}