1use std::collections::{HashSet, VecDeque};
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
/// A named state machine: a list of states, the state it starts in, and the
/// transitions between states.
///
/// Assembled with the consuming builder methods on this type and checked with
/// `StateMachine::validate`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct StateMachine {
    /// Identifier of the machine.
    pub name: String,
    /// Optional human-readable label; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional free-form description; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Name of the state the machine starts in; must match the name of one of
    /// `states` for validation to succeed.
    pub initial_state: String,
    /// The states of the machine, in declaration order.
    pub states: Vec<StateDef>,
    /// The transitions of the machine; each `from`/`to` must name a state in `states`.
    pub transitions: Vec<Transition>,
}
40
/// A single state of a `StateMachine`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct StateDef {
    /// Name of the state, referenced by the machine's `initial_state` and by
    /// transition `from`/`to` fields.
    pub name: String,
    /// Optional human-readable label; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional free-form description; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Whether this is a terminal state; defaults to `false` when missing from input.
    /// Non-final states without outgoing transitions are reported as dead ends.
    #[serde(default)]
    pub is_final: bool,
    /// Effect identifiers associated with entering the state; defaults to empty and
    /// is omitted from output when empty.
    // NOTE(review): these are only stored here; presumably an executor elsewhere
    // runs them — confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub on_enter: Vec<String>,
    /// Effect identifiers associated with leaving the state; defaults to empty and
    /// is omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub on_exit: Vec<String>,
    /// Arbitrary JSON attached to the state; omitted from output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
68
/// An edge of a `StateMachine`: when `event` occurs in state `from`, the machine
/// moves to state `to`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct Transition {
    /// Name of the source state.
    pub from: String,
    /// Name of the event that triggers this transition.
    pub event: String,
    /// Name of the target state.
    pub to: String,
    /// Optional guard identifier; omitted from output when absent.
    // NOTE(review): the guard is only stored here, never evaluated in this file —
    // presumably interpreted by the runtime; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub guard: Option<String>,
    /// Action identifiers attached to the transition; defaults to empty and is
    /// omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub actions: Vec<String>,
    /// Optional free-form description; omitted from output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
91
/// A non-fatal issue reported by validation.
///
/// Uses serde's default (externally tagged) enum representation with snake_case
/// variant names, e.g. `"no_final_states"` or `{"unreachable_state": "orphan"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Warning {
    /// A state that cannot be reached from the initial state; carries its name.
    UnreachableState(String),
    /// A non-final state with no outgoing transitions; carries its name.
    DeadEndState(String),
    /// No state in the machine is marked final.
    NoFinalStates,
    // NOTE(review): the variants below are not produced in this file; their
    // meaning is inferred from names only — confirm against the emitting validators.
    UnusedGuard(String),
    TransitionTriggerWithoutStateMachine(String),
    DuplicateRelationship(String),
    ManyToManyWithForeignKey { relationship: String },
    ConflictingIntentHints { intent: String },
    MultiplePrimaryIntentHints,
}
118
119impl StateMachine {
120 pub fn new(name: impl Into<String>) -> Self {
122 Self {
123 name: name.into(),
124 display_name: None,
125 description: None,
126 initial_state: String::new(),
127 states: Vec::new(),
128 transitions: Vec::new(),
129 }
130 }
131
132 pub fn display_name(mut self, name: impl Into<String>) -> Self {
134 self.display_name = Some(name.into());
135 self
136 }
137
138 pub fn description(mut self, desc: impl Into<String>) -> Self {
140 self.description = Some(desc.into());
141 self
142 }
143
144 pub fn initial(mut self, state: impl Into<String>) -> Self {
146 self.initial_state = state.into();
147 self
148 }
149
150 pub fn state(mut self, state: StateDef) -> Self {
152 self.states.push(state);
153 self
154 }
155
156 pub fn transition(mut self, transition: Transition) -> Self {
158 self.transitions.push(transition);
159 self
160 }
161
162 pub fn validate(&self) -> Result<Vec<Warning>, crate::Error> {
167 let mut warnings = Vec::new();
168 let state_names: HashSet<&str> = self.states.iter().map(|s| s.name.as_str()).collect();
169
170 if self.initial_state.is_empty() {
172 return Err(crate::Error::Validation("initial state not set".into()));
173 }
174
175 if !state_names.contains(self.initial_state.as_str()) {
177 return Err(crate::Error::Validation(format!(
178 "initial state '{}' not found in states",
179 self.initial_state
180 )));
181 }
182
183 for t in &self.transitions {
185 if !state_names.contains(t.from.as_str()) {
186 return Err(crate::Error::Validation(format!(
187 "transition source '{}' not found in states",
188 t.from
189 )));
190 }
191 if !state_names.contains(t.to.as_str()) {
192 return Err(crate::Error::Validation(format!(
193 "transition target '{}' not found in states",
194 t.to
195 )));
196 }
197 }
198
199 let mut reachable = HashSet::new();
201 let mut queue = VecDeque::new();
202 queue.push_back(self.initial_state.as_str());
203 reachable.insert(self.initial_state.as_str());
204
205 while let Some(current) = queue.pop_front() {
206 for t in &self.transitions {
207 if t.from == current && !reachable.contains(t.to.as_str()) {
208 reachable.insert(t.to.as_str());
209 queue.push_back(t.to.as_str());
210 }
211 }
212 }
213
214 for state in &self.states {
215 if !reachable.contains(state.name.as_str()) {
216 warnings.push(Warning::UnreachableState(state.name.clone()));
217 }
218 }
219
220 let states_with_outgoing: HashSet<&str> =
222 self.transitions.iter().map(|t| t.from.as_str()).collect();
223 for state in &self.states {
224 if !state.is_final && !states_with_outgoing.contains(state.name.as_str()) {
225 warnings.push(Warning::DeadEndState(state.name.clone()));
226 }
227 }
228
229 if !self.states.iter().any(|s| s.is_final) {
231 warnings.push(Warning::NoFinalStates);
232 }
233
234 Ok(warnings)
235 }
236
237 pub fn states_for_event(&self, event: &str) -> Vec<&Transition> {
239 self.transitions
240 .iter()
241 .filter(|t| t.event == event)
242 .collect()
243 }
244
245 pub fn events_from_state(&self, state: &str) -> Vec<&Transition> {
247 self.transitions
248 .iter()
249 .filter(|t| t.from == state)
250 .collect()
251 }
252}
253
254impl StateDef {
255 pub fn new(name: impl Into<String>) -> Self {
257 Self {
258 name: name.into(),
259 display_name: None,
260 description: None,
261 is_final: false,
262 on_enter: Vec::new(),
263 on_exit: Vec::new(),
264 metadata: None,
265 }
266 }
267
268 pub fn display_name(mut self, name: impl Into<String>) -> Self {
270 self.display_name = Some(name.into());
271 self
272 }
273
274 pub fn description(mut self, desc: impl Into<String>) -> Self {
276 self.description = Some(desc.into());
277 self
278 }
279
280 pub fn final_state(mut self) -> Self {
282 self.is_final = true;
283 self
284 }
285
286 pub fn on_enter(mut self, effects: Vec<impl Into<String>>) -> Self {
288 self.on_enter = effects.into_iter().map(Into::into).collect();
289 self
290 }
291
292 pub fn on_exit(mut self, effects: Vec<impl Into<String>>) -> Self {
294 self.on_exit = effects.into_iter().map(Into::into).collect();
295 self
296 }
297
298 pub fn metadata(mut self, metadata: serde_json::Value) -> Self {
300 self.metadata = Some(metadata);
301 self
302 }
303}
304
305impl Transition {
306 pub fn new(from: impl Into<String>, event: impl Into<String>, to: impl Into<String>) -> Self {
308 Self {
309 from: from.into(),
310 event: event.into(),
311 to: to.into(),
312 guard: None,
313 actions: Vec::new(),
314 description: None,
315 }
316 }
317
318 pub fn guard(mut self, guard: impl Into<String>) -> Self {
320 self.guard = Some(guard.into());
321 self
322 }
323
324 pub fn actions(mut self, actions: Vec<impl Into<String>>) -> Self {
326 self.actions = actions.into_iter().map(Into::into).collect();
327 self
328 }
329
330 pub fn description(mut self, desc: impl Into<String>) -> Self {
332 self.description = Some(desc.into());
333 self
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear four-state workflow (draft -> pending -> approved -> completed)
    /// that validates with no warnings.
    fn sample_machine() -> StateMachine {
        StateMachine::new("order_lifecycle")
            .initial("draft")
            .state(StateDef::new("draft").display_name("Draft"))
            .state(
                StateDef::new("pending")
                    .display_name("Pending")
                    .on_enter(vec!["notify_reviewer"]),
            )
            .state(StateDef::new("approved").display_name("Approved"))
            .state(
                StateDef::new("completed")
                    .display_name("Completed")
                    .final_state(),
            )
            .transition(Transition::new("draft", "submit", "pending").guard("has_required_fields"))
            .transition(Transition::new("pending", "approve", "approved").guard("is_reviewer"))
            .transition(Transition::new("approved", "complete", "completed"))
    }

    // The round-trip tests compare whole values (all types derive PartialEq), so a
    // field that is dropped or altered during (de)serialization fails the test.
    #[test]
    fn state_machine_serde_round_trip() {
        let machine = sample_machine();
        let json = serde_json::to_string_pretty(&machine).unwrap();
        let parsed: StateMachine = serde_json::from_str(&json).unwrap();

        assert_eq!(machine, parsed);
    }

    #[test]
    fn state_def_serde_round_trip() {
        let state = StateDef::new("pending")
            .display_name("Pending Review")
            .description("Awaiting reviewer approval")
            .on_enter(vec!["notify_reviewer", "start_sla_timer"])
            .on_exit(vec!["stop_sla_timer"])
            .metadata(serde_json::json!({"color": "yellow"}));

        let json = serde_json::to_string(&state).unwrap();
        let parsed: StateDef = serde_json::from_str(&json).unwrap();

        assert_eq!(state, parsed);
    }

    #[test]
    fn transition_serde_round_trip() {
        let transition = Transition::new("pending", "reject", "rejected")
            .guard("is_reviewer")
            .actions(vec!["log_rejection_reason", "notify_submitter"])
            .description("Reviewer rejects the submission");

        let json = serde_json::to_string(&transition).unwrap();
        let parsed: Transition = serde_json::from_str(&json).unwrap();

        assert_eq!(transition, parsed);
    }

    #[test]
    fn json_omits_empty_optional_fields() {
        let state = StateDef::new("draft");
        let json = serde_json::to_string(&state).unwrap();
        assert!(!json.contains("display_name"));
        assert!(!json.contains("description"));
        assert!(!json.contains("on_enter"));
        assert!(!json.contains("on_exit"));
        assert!(!json.contains("metadata"));

        let transition = Transition::new("a", "go", "b");
        let json = serde_json::to_string(&transition).unwrap();
        assert!(!json.contains("guard"));
        assert!(!json.contains("actions"));
        assert!(!json.contains("description"));

        let machine = StateMachine::new("test").initial("a");
        let json = serde_json::to_string(&machine).unwrap();
        assert!(!json.contains("display_name"));
        assert!(!json.contains("description"));
    }

    #[test]
    fn warning_serializes_with_snake_case_tags() {
        // Pins the wire format: externally tagged, snake_case variant names.
        assert_eq!(
            serde_json::to_value(Warning::NoFinalStates).unwrap(),
            serde_json::json!("no_final_states")
        );
        assert_eq!(
            serde_json::to_value(Warning::UnreachableState("orphan".into())).unwrap(),
            serde_json::json!({"unreachable_state": "orphan"})
        );
        assert_eq!(
            serde_json::to_value(Warning::ManyToManyWithForeignKey {
                relationship: "tags".into()
            })
            .unwrap(),
            serde_json::json!({"many_to_many_with_foreign_key": {"relationship": "tags"}})
        );
    }

    #[test]
    fn validate_valid_machine() {
        let machine = sample_machine();
        let warnings = machine.validate().unwrap();
        assert!(warnings.is_empty());
    }

    #[test]
    fn validate_missing_initial_state() {
        let machine = StateMachine::new("test").state(StateDef::new("a"));
        assert!(machine.validate().is_err());
    }

    #[test]
    fn validate_initial_state_not_in_states() {
        let machine = StateMachine::new("test")
            .initial("nonexistent")
            .state(StateDef::new("a").final_state());
        assert!(machine.validate().is_err());
    }

    #[test]
    fn validate_invalid_transition_source() {
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a").final_state())
            .transition(Transition::new("missing", "go", "a"));
        assert!(machine.validate().is_err());
    }

    #[test]
    fn validate_invalid_transition_target() {
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a").final_state())
            .transition(Transition::new("a", "go", "missing"));
        assert!(machine.validate().is_err());
    }

    #[test]
    fn validate_unreachable_state() {
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a").final_state())
            .state(StateDef::new("orphan"));
        let warnings = machine.validate().unwrap();
        assert!(warnings.contains(&Warning::UnreachableState("orphan".into())));
    }

    #[test]
    fn validate_unreachable_cycle_is_not_reachable() {
        // b and c point at each other but nothing reaches them from the initial
        // state; both have outgoing transitions, so neither is a dead end.
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a").final_state())
            .state(StateDef::new("b"))
            .state(StateDef::new("c"))
            .transition(Transition::new("b", "go", "c"))
            .transition(Transition::new("c", "back", "b"));
        let warnings = machine.validate().unwrap();
        assert_eq!(
            warnings,
            vec![
                Warning::UnreachableState("b".into()),
                Warning::UnreachableState("c".into()),
            ]
        );
    }

    #[test]
    fn validate_dead_end_state() {
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a"))
            .state(StateDef::new("b"))
            .transition(Transition::new("a", "go", "b"));
        let warnings = machine.validate().unwrap();
        assert!(warnings.contains(&Warning::DeadEndState("b".into())));
    }

    #[test]
    fn validate_no_final_states() {
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a"))
            .state(StateDef::new("b"))
            .transition(Transition::new("a", "go", "b"))
            .transition(Transition::new("b", "back", "a"));
        let warnings = machine.validate().unwrap();
        assert!(warnings.contains(&Warning::NoFinalStates));
    }

    #[test]
    fn states_for_event_returns_matching_transitions() {
        let machine = sample_machine();
        let submit_transitions = machine.states_for_event("submit");
        assert_eq!(submit_transitions.len(), 1);
        assert_eq!(submit_transitions[0].from, "draft");
        assert_eq!(submit_transitions[0].to, "pending");
    }

    #[test]
    fn states_for_event_returns_empty_for_unknown() {
        let machine = sample_machine();
        assert!(machine.states_for_event("nonexistent").is_empty());
    }

    #[test]
    fn events_from_state_returns_outgoing() {
        let machine = sample_machine();
        let from_draft = machine.events_from_state("draft");
        assert_eq!(from_draft.len(), 1);
        assert_eq!(from_draft[0].event, "submit");
    }

    #[test]
    fn events_from_state_returns_empty_for_final() {
        let machine = sample_machine();
        assert!(machine.events_from_state("completed").is_empty());
    }

    #[test]
    fn state_machine_json_structure() {
        let machine = sample_machine();
        let json = serde_json::to_string(&machine).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert!(value.get("name").is_some());
        assert!(value.get("initial_state").is_some());
        assert!(value.get("states").is_some());
        assert!(value.get("transitions").is_some());

        let states = value["states"].as_array().unwrap();
        assert_eq!(states.len(), 4);

        let transitions = value["transitions"].as_array().unwrap();
        assert_eq!(transitions.len(), 3);
    }

    #[test]
    fn state_machine_builder_chain() {
        let machine = StateMachine::new("workflow")
            .display_name("Workflow")
            .description("A test workflow")
            .initial("start")
            .state(StateDef::new("start"))
            .state(StateDef::new("end").final_state())
            .transition(Transition::new("start", "go", "end"));

        assert_eq!(machine.name, "workflow");
        assert_eq!(machine.display_name.as_deref(), Some("Workflow"));
        assert_eq!(machine.description.as_deref(), Some("A test workflow"));
        assert_eq!(machine.initial_state, "start");
        assert_eq!(machine.states.len(), 2);
        assert_eq!(machine.transitions.len(), 1);
    }

    #[test]
    fn state_def_builder_chain() {
        let state = StateDef::new("processing")
            .display_name("Processing")
            .description("Order is being processed")
            .final_state()
            .on_enter(vec!["start_timer", "notify"])
            .on_exit(vec!["stop_timer"])
            .metadata(serde_json::json!({"color": "blue", "icon": "gear"}));

        assert_eq!(state.name, "processing");
        assert_eq!(state.display_name.as_deref(), Some("Processing"));
        assert_eq!(
            state.description.as_deref(),
            Some("Order is being processed")
        );
        assert!(state.is_final);
        assert_eq!(state.on_enter, vec!["start_timer", "notify"]);
        assert_eq!(state.on_exit, vec!["stop_timer"]);
        assert!(state.metadata.is_some());
    }

    #[test]
    fn transition_builder_chain() {
        let transition = Transition::new("draft", "submit", "pending")
            .guard("has_required_fields")
            .actions(vec!["validate", "log_submission"])
            .description("Submit draft for review");

        assert_eq!(transition.from, "draft");
        assert_eq!(transition.event, "submit");
        assert_eq!(transition.to, "pending");
        assert_eq!(transition.guard.as_deref(), Some("has_required_fields"));
        assert_eq!(transition.actions, vec!["validate", "log_submission"]);
        assert_eq!(
            transition.description.as_deref(),
            Some("Submit draft for review")
        );
    }

    #[test]
    fn state_def_defaults() {
        let state = StateDef::new("x");
        assert_eq!(state.name, "x");
        assert!(!state.is_final);
        assert!(state.on_enter.is_empty());
        assert!(state.on_exit.is_empty());
        assert!(state.display_name.is_none());
        assert!(state.description.is_none());
        assert!(state.metadata.is_none());
    }

    #[test]
    fn validate_all_warnings_combined() {
        // "orphan" is both unreachable and a dead end, so it yields two warnings;
        // "b" is a dead end, and no state is final.
        let machine = StateMachine::new("test")
            .initial("a")
            .state(StateDef::new("a"))
            .state(StateDef::new("b"))
            .state(StateDef::new("orphan"))
            .transition(Transition::new("a", "go", "b"));

        let warnings = machine.validate().unwrap();
        assert!(warnings.contains(&Warning::UnreachableState("orphan".into())));
        assert!(warnings.contains(&Warning::DeadEndState("b".into())));
        assert!(warnings.contains(&Warning::DeadEndState("orphan".into())));
        assert!(warnings.contains(&Warning::NoFinalStates));
        assert_eq!(warnings.len(), 4);
    }

    #[test]
    fn full_order_lifecycle() {
        let machine = StateMachine::new("order_lifecycle")
            .display_name("Order Lifecycle")
            .description("Tracks an order from creation to fulfillment")
            .initial("draft")
            .state(
                StateDef::new("draft")
                    .display_name("Draft")
                    .description("Order is being prepared"),
            )
            .state(
                StateDef::new("submitted")
                    .display_name("Submitted")
                    .on_enter(vec!["validate_inventory", "calculate_totals"]),
            )
            .state(
                StateDef::new("processing")
                    .display_name("Processing")
                    .on_enter(vec!["charge_payment", "reserve_inventory"]),
            )
            .state(
                StateDef::new("shipped")
                    .display_name("Shipped")
                    .on_enter(vec!["generate_tracking", "notify_customer"]),
            )
            .state(
                StateDef::new("delivered")
                    .display_name("Delivered")
                    .final_state(),
            )
            .state(
                StateDef::new("cancelled")
                    .display_name("Cancelled")
                    .final_state()
                    .on_enter(vec!["refund_payment", "release_inventory"]),
            )
            .transition(
                Transition::new("draft", "submit", "submitted")
                    .guard("has_items")
                    .description("Customer submits the order"),
            )
            .transition(
                Transition::new("submitted", "process", "processing")
                    .guard("payment_valid")
                    .actions(vec!["lock_prices"]),
            )
            .transition(
                Transition::new("processing", "ship", "shipped").guard("inventory_fulfilled"),
            )
            .transition(Transition::new("shipped", "deliver", "delivered"))
            .transition(Transition::new("draft", "cancel", "cancelled"))
            .transition(
                Transition::new("submitted", "cancel", "cancelled").guard("cancellation_allowed"),
            )
            .transition(
                Transition::new("processing", "cancel", "cancelled")
                    .guard("cancellation_allowed")
                    .actions(vec!["reverse_payment"]),
            );

        let warnings = machine.validate().unwrap();
        assert!(warnings.is_empty());

        // "cancel" is accepted from draft, submitted, and processing.
        let cancel_transitions = machine.states_for_event("cancel");
        assert_eq!(cancel_transitions.len(), 3);

        // draft can either submit or cancel.
        let from_draft = machine.events_from_state("draft");
        assert_eq!(from_draft.len(), 2);
    }
}