use crate::intelligent_behavior::{sub_scenario::SubScenario, visual_layout::VisualLayout};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named rule that pairs a textual `condition` with a [`RuleAction`].
///
/// `condition` is evaluated against a request's method and path by
/// [`ConsistencyRule::matches`]. How `action` and `priority` are applied is
/// decided by code outside this file (NOTE(review): confirm whether a higher
/// `priority` wins when several rules match).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct ConsistencyRule {
/// Rule identifier.
pub name: String,
/// Optional human-readable explanation of the rule.
pub description: Option<String>,
/// Predicate text, e.g. `path starts_with '/api/cart'`; see `matches` for the
/// supported clause forms.
pub condition: String,
/// Action associated with this rule.
pub action: RuleAction,
/// Rule priority; deserializes to `0` when the field is absent.
#[serde(default)]
pub priority: i32,
}
/// Action attached to a [`ConsistencyRule`].
///
/// Serialized as an internally tagged object. The variant name goes under the
/// `"type"` key, lowercased with no word separators: `"error"`, `"transform"`,
/// `"executechain"`, `"requireauth"`, `"statetransition"`. The variant's
/// fields sit alongside the tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum RuleAction {
/// Respond with an error. `status` is presumably an HTTP status code
/// (e.g. 401) — confirm with callers.
Error {
status: u16,
message: String,
},
/// A transformation described only by free text.
Transform {
description: String,
},
/// Run the chain identified by `chain_id` (the chain is defined elsewhere).
ExecuteChain {
chain_id: String,
},
/// Require authentication, carrying the message to report.
RequireAuth {
message: String,
},
/// Apply the named `transition` to resources of `resource_type`
/// (presumably matching a `StateMachine::resource_type` — confirm).
StateTransition {
resource_type: String,
transition: String,
},
}
impl ConsistencyRule {
    /// Creates a rule with no description and priority `0`.
    pub fn new(name: impl Into<String>, condition: impl Into<String>, action: RuleAction) -> Self {
        Self {
            name: name.into(),
            description: None,
            condition: condition.into(),
            action,
            priority: 0,
        }
    }

    /// Sets a human-readable description (builder style).
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Sets the rule priority (builder style).
    pub fn with_priority(mut self, priority: i32) -> Self {
        self.priority = priority;
        self
    }

    /// Returns `true` when a request with `method` and `path` satisfies this
    /// rule's condition.
    ///
    /// Two clause forms are recognised anywhere inside `condition`:
    /// - `path starts_with '<prefix>'`: `path` must start with `<prefix>`;
    /// - `method == '<METHOD>'`: `method` must equal `<METHOD>`, compared
    ///   ASCII case-insensitively.
    ///
    /// Each clause uses the single-quoted literal that directly follows its own
    /// operator, so a combined condition such as
    /// `method == 'POST' && path starts_with '/api/cart'` checks each part
    /// against the right literal. When both clauses are present, both must
    /// hold. A clause whose operator is not followed by a quoted literal is
    /// ignored. A condition with no recognised clause never matches.
    pub fn matches(&self, method: &str, path: &str) -> bool {
        let path_ok = Self::quoted_operand(&self.condition, "path starts_with")
            .map(|prefix| path.starts_with(prefix));
        let method_ok = Self::quoted_operand(&self.condition, "method ==")
            .map(|expected| method.eq_ignore_ascii_case(expected));
        match (path_ok, method_ok) {
            (None, None) => false,
            (path_ok, method_ok) => path_ok.unwrap_or(true) && method_ok.unwrap_or(true),
        }
    }

    /// Returns the single-quoted literal that immediately follows `operator`
    /// in `condition` (after optional whitespace).
    ///
    /// An unterminated literal runs to the end of the string. Returns `None`
    /// if `operator` does not occur or is not followed by a `'`.
    fn quoted_operand<'a>(condition: &'a str, operator: &str) -> Option<&'a str> {
        let (_, rest) = condition.split_once(operator)?;
        let literal = rest.trim_start().strip_prefix('\'')?;
        literal.split('\'').next()
    }
}
/// State machine for one resource type, with optionally weighted transitions.
///
/// `states` and `initial_state` are stored as given. The constructor and the
/// transition builders do not check that `initial_state` or the transition
/// endpoints appear in `states`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct StateMachine {
/// Resource type this machine describes (e.g. `"order"`).
pub resource_type: String,
/// Declared state names.
pub states: Vec<String>,
/// Name of the starting state.
pub initial_state: String,
/// Directed transitions between states, in insertion order.
pub transitions: Vec<StateTransition>,
/// Nested sub-scenarios, looked up by `id`; empty when absent from input.
#[serde(default)]
pub sub_scenarios: Vec<SubScenario>,
/// Optional layout data; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub visual_layout: Option<VisualLayout>,
/// Free-form JSON metadata; empty when absent from input.
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
}
impl StateMachine {
    /// Creates a machine for `resource_type` with the given states and
    /// initial state, and no transitions, sub-scenarios, layout, or metadata.
    ///
    /// Inputs are stored as given; `initial_state` is not checked against
    /// `states`.
    pub fn new(
        resource_type: impl Into<String>,
        states: Vec<String>,
        initial_state: impl Into<String>,
    ) -> Self {
        Self {
            resource_type: resource_type.into(),
            states,
            initial_state: initial_state.into(),
            transitions: Vec::new(),
            sub_scenarios: Vec::new(),
            visual_layout: None,
            metadata: HashMap::new(),
        }
    }

    /// Appends one transition (builder style).
    pub fn add_transition(mut self, transition: StateTransition) -> Self {
        self.transitions.push(transition);
        self
    }

    /// Appends several transitions, preserving their order (builder style).
    pub fn add_transitions(mut self, transitions: Vec<StateTransition>) -> Self {
        self.transitions.extend(transitions);
        self
    }

    /// Returns `true` if some transition goes directly from `from` to `to`.
    ///
    /// Probabilities and conditions are not consulted.
    pub fn can_transition(&self, from: &str, to: &str) -> bool {
        self.transitions.iter().any(|t| t.from_state == from && t.to_state == to)
    }

    /// Returns the targets of every transition leaving `current`, in insertion
    /// order. A target appears more than once if several transitions lead to it.
    pub fn next_states(&self, current: &str) -> Vec<String> {
        self.transitions
            .iter()
            .filter(|t| t.from_state == current)
            .map(|t| t.to_state.clone())
            .collect()
    }

    /// Randomly picks a successor of `current`, weighting each outgoing
    /// transition by its `probability`.
    ///
    /// Weights are sanitized first: NaN becomes `0.0`, and every other value
    /// is clamped to `[0.0, 1.0]`. This is the range that
    /// `StateTransition::with_probability` enforces; deserialized transitions
    /// bypass that builder.
    ///
    /// Selection rules:
    /// - A zero-weight transition is never chosen while another candidate has
    ///   positive weight.
    /// - If every candidate has zero weight, the first candidate is returned.
    ///
    /// Returns `None` when `current` has no outgoing transitions.
    pub fn select_next_state(&self, current: &str) -> Option<String> {
        let candidates: Vec<(&StateTransition, f64)> = self
            .transitions
            .iter()
            .filter(|t| t.from_state == current)
            .map(|t| (t, Self::selection_weight(t.probability)))
            .collect();
        let (first, _) = candidates.first()?;

        let total: f64 = candidates.iter().map(|(_, weight)| weight).sum();
        if total <= 0.0 {
            return Some(first.to_state.clone());
        }

        let target = rand::random::<f64>() * total;
        let mut cumulative = 0.0;
        for (transition, weight) in &candidates {
            cumulative += weight;
            // Strict `<`: a zero-weight transition does not advance
            // `cumulative`, so it can't be picked even when `target` is 0.0.
            if target < cumulative {
                return Some(transition.to_state.clone());
            }
        }

        // Floating-point rounding can leave `target` equal to the final sum;
        // fall back to the last transition that actually carries weight.
        candidates
            .iter()
            .rev()
            .find(|(_, weight)| *weight > 0.0)
            .map(|(transition, _)| transition.to_state.clone())
    }

    /// Converts a raw transition probability into a usable selection weight:
    /// NaN maps to `0.0`, everything else is clamped to `[0.0, 1.0]`.
    fn selection_weight(probability: f64) -> f64 {
        if probability.is_nan() {
            0.0
        } else {
            probability.clamp(0.0, 1.0)
        }
    }

    /// Appends a sub-scenario (builder style).
    pub fn add_sub_scenario(mut self, sub_scenario: SubScenario) -> Self {
        self.sub_scenarios.push(sub_scenario);
        self
    }

    /// Sets the visual layout (builder style).
    pub fn with_visual_layout(mut self, layout: VisualLayout) -> Self {
        self.visual_layout = Some(layout);
        self
    }

    /// Inserts a metadata entry; a later call with the same key replaces the
    /// earlier value.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Returns the first sub-scenario whose `id` equals `id`.
    pub fn get_sub_scenario(&self, id: &str) -> Option<&SubScenario> {
        self.sub_scenarios.iter().find(|s| s.id == id)
    }

    /// Mutable variant of [`StateMachine::get_sub_scenario`].
    pub fn get_sub_scenario_mut(&mut self, id: &str) -> Option<&mut SubScenario> {
        self.sub_scenarios.iter_mut().find(|s| s.id == id)
    }
}
/// A directed edge `from_state -> to_state` in a [`StateMachine`].
///
/// In serialized form the endpoints use the keys `"from"` and `"to"`, and
/// `None` optional fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct StateTransition {
/// Source state name (serialized as `"from"`).
#[serde(rename = "from")]
pub from_state: String,
/// Target state name (serialized as `"to"`).
#[serde(rename = "to")]
pub to_state: String,
/// Selection weight; defaults to `default_probability()` when absent.
/// Deserialization does not clamp it, unlike `with_probability`.
#[serde(default = "default_probability")]
pub probability: f64,
/// Optional free-text condition.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub condition: Option<String>,
/// Optional list of side-effect descriptors.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub side_effects: Option<Vec<String>>,
/// Optional condition expression string (how it differs from `condition`
/// is not visible here — confirm with its evaluator).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub condition_expression: Option<String>,
/// Never serialized; deserializes to `None`. Presumably a cached parse of
/// `condition_expression` — TODO confirm.
#[serde(skip)]
pub condition_ast: Option<serde_json::Value>,
/// Optional reference to a sub-scenario, presumably matching a
/// `SubScenario::id` — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub sub_scenario_ref: Option<String>,
}
impl StateTransition {
    /// Creates a transition from `from` to `to` with the default probability
    /// and every optional attribute (condition, side effects, expression,
    /// parsed AST, sub-scenario reference) unset.
    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
        let from_state = from.into();
        let to_state = to.into();
        Self {
            from_state,
            to_state,
            probability: default_probability(),
            condition: None,
            side_effects: None,
            condition_expression: None,
            condition_ast: None,
            sub_scenario_ref: None,
        }
    }

    /// Returns the transition with `probability` clamped into `[0.0, 1.0]`.
    /// A NaN input passes through `f64::clamp` unchanged.
    pub fn with_probability(self, probability: f64) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Returns the transition with its free-text condition set.
    pub fn with_condition(self, condition: impl Into<String>) -> Self {
        Self {
            condition: Some(condition.into()),
            ..self
        }
    }

    /// Appends `effect` to the side-effect list, creating the list on first use.
    pub fn with_side_effect(mut self, effect: impl Into<String>) -> Self {
        self.side_effects
            .get_or_insert_with(Vec::new)
            .push(effect.into());
        self
    }

    /// Returns the transition with its condition expression set.
    pub fn with_condition_expression(self, expression: impl Into<String>) -> Self {
        Self {
            condition_expression: Some(expression.into()),
            ..self
        }
    }

    /// Returns the transition with its sub-scenario reference set.
    pub fn with_sub_scenario_ref(self, sub_scenario_id: impl Into<String>) -> Self {
        Self {
            sub_scenario_ref: Some(sub_scenario_id.into()),
            ..self
        }
    }
}
/// Serde default for `StateTransition::probability` and the value used by
/// `StateTransition::new`: the full weight `1.0`, the upper bound that
/// `with_probability` allows.
fn default_probability() -> f64 {
    const FULL_WEIGHT: f64 = 1.0;
    FULL_WEIGHT
}
/// Request data used when evaluating rules and conditions.
///
/// Unlike the other types in this file, it derives neither `Serialize` nor
/// `Deserialize`.
#[derive(Debug, Clone)]
pub struct EvaluationContext {
/// Request method as given by the caller (e.g. `"GET"`).
pub method: String,
/// Request path.
pub path: String,
/// Header name/value pairs, stored exactly as provided.
pub headers: HashMap<String, String>,
/// Arbitrary JSON values associated with the session.
pub session_state: HashMap<String, serde_json::Value>,
}
impl EvaluationContext {
    /// Creates a context for `method` and `path` with no headers and an empty
    /// session state.
    pub fn new(method: impl Into<String>, path: impl Into<String>) -> Self {
        let method = method.into();
        let path = path.into();
        Self {
            method,
            path,
            headers: HashMap::default(),
            session_state: HashMap::default(),
        }
    }

    /// Replaces the whole header map (builder style).
    pub fn with_headers(self, headers: HashMap<String, String>) -> Self {
        Self { headers, ..self }
    }

    /// Replaces the whole session-state map (builder style).
    pub fn with_session_state(self, state: HashMap<String, serde_json::Value>) -> Self {
        Self {
            session_state: state,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned state-name list that `StateMachine::new` takes.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_consistency_rule_matches() {
        let action = RuleAction::RequireAuth {
            message: "Authentication required".to_string(),
        };
        let rule = ConsistencyRule::new("require_auth", "path starts_with '/api/cart'", action);

        assert!(rule.matches("GET", "/api/cart"));
        assert!(rule.matches("POST", "/api/cart/items"));
        assert!(!rule.matches("GET", "/api/products"));
    }

    #[test]
    fn test_state_machine_transitions() {
        let machine = StateMachine::new(
            "order",
            owned(&["pending", "processing", "shipped", "delivered"]),
            "pending",
        )
        .add_transitions(vec![
            StateTransition::new("pending", "processing").with_probability(0.8),
            StateTransition::new("processing", "shipped").with_probability(0.9),
            StateTransition::new("shipped", "delivered").with_probability(1.0),
        ]);

        let declared = [("pending", "processing"), ("processing", "shipped")];
        for &(from, to) in declared.iter() {
            assert!(machine.can_transition(from, to));
        }
        // Skipping a step is not a declared transition.
        assert!(!machine.can_transition("pending", "shipped"));
    }

    #[test]
    fn test_state_machine_next_states() {
        let machine = StateMachine::new(
            "order",
            owned(&["pending", "processing", "cancelled"]),
            "pending",
        )
        .add_transitions(vec![
            StateTransition::new("pending", "processing"),
            StateTransition::new("pending", "cancelled"),
        ]);

        let next = machine.next_states("pending");
        assert_eq!(next.len(), 2);
        for expected in ["processing", "cancelled"].iter() {
            assert!(next.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_rule_action_serialization() {
        let original = RuleAction::Error {
            status: 401,
            message: "Unauthorized".to_string(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("\"type\":\"error\""));
        assert!(encoded.contains("401"));

        let decoded: RuleAction = serde_json::from_str(&encoded).unwrap();
        if let RuleAction::Error { status, message } = decoded {
            assert_eq!(status, 401);
            assert_eq!(message, "Unauthorized");
        } else {
            panic!("Unexpected action type");
        }
    }

    #[test]
    fn test_state_transition_probability() {
        let in_range = StateTransition::new("pending", "processing").with_probability(0.75);
        assert_eq!(in_range.probability, 0.75);

        // Values above 1.0 are clamped to the upper bound.
        let clamped = StateTransition::new("a", "b").with_probability(1.5);
        assert_eq!(clamped.probability, 1.0);
    }
}