datasynth_audit_optimizer/shortest_path.rs
1//! Shortest-path analysis for audit procedure FSMs.
2//!
3//! Uses BFS on each procedure's `ProcedureAggregate` to find the minimum
4//! number of transitions required to move from `initial_state` to any
5//! terminal state (a state with no outgoing transitions).
6
7use std::collections::{HashMap, HashSet, VecDeque};
8
9use serde::Serialize;
10
11use datasynth_audit_fsm::schema::AuditBlueprint;
12
13// ---------------------------------------------------------------------------
14// Report types
15// ---------------------------------------------------------------------------
16
17/// The minimum-transition path through a single procedure's FSM.
18#[derive(Debug, Clone, Serialize)]
19pub struct ProcedurePath {
20 /// Ordered sequence of states visited, from `initial_state` to terminal.
21 pub states: Vec<String>,
22 /// Number of transitions taken (= `states.len() - 1`).
23 pub transition_count: usize,
24 /// Commands used along the path (one per transition; empty string when the
25 /// transition has no associated command).
26 pub commands: Vec<String>,
27}
28
29/// Aggregate shortest-path report across all procedures in a blueprint.
30#[derive(Debug, Clone, Serialize)]
31pub struct ShortestPathReport {
32 /// Per-procedure minimum paths, keyed by procedure id.
33 pub procedure_paths: HashMap<String, ProcedurePath>,
34 /// Sum of `transition_count` across all procedures.
35 pub total_minimum_transitions: usize,
36}
37
38// ---------------------------------------------------------------------------
39// Analysis
40// ---------------------------------------------------------------------------
41
42/// Analyse every procedure in `blueprint` and return the shortest path from
43/// `initial_state` to any terminal state for each procedure that has a
44/// non-empty aggregate.
45///
46/// A terminal state is any state that has no outgoing transitions in the
47/// procedure's FSM. BFS guarantees the first path found to a terminal state
48/// is the shortest.
49pub fn analyze_shortest_paths(blueprint: &AuditBlueprint) -> ShortestPathReport {
50 let mut procedure_paths: HashMap<String, ProcedurePath> = HashMap::new();
51
52 for phase in &blueprint.phases {
53 for procedure in &phase.procedures {
54 let agg = &procedure.aggregate;
55
56 // Skip procedures with no FSM content.
57 if agg.transitions.is_empty() && agg.initial_state.is_empty() {
58 continue;
59 }
60
61 // ------------------------------------------------------------------
62 // Build adjacency list: state -> Vec<(to_state, command)>
63 // ------------------------------------------------------------------
64 let mut adj: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
65
66 // Ensure every declared state has an entry (even if it has no
67 // outgoing transitions) so we can detect terminal states correctly.
68 for state in &agg.states {
69 adj.entry(state.as_str()).or_default();
70 }
71
72 for transition in &agg.transitions {
73 // Also lazily create entries for states only referenced in
74 // transitions (not in the explicit states list).
75 adj.entry(transition.from_state.as_str()).or_default();
76 adj.entry(transition.to_state.as_str()).or_default();
77
78 let cmd = transition.command.as_deref().unwrap_or("");
79 adj.entry(transition.from_state.as_str())
80 .or_default()
81 .push((transition.to_state.as_str(), cmd));
82 }
83
84 let initial = agg.initial_state.as_str();
85
86 // Nothing to do if there is no initial state or it is not in the
87 // adjacency map.
88 if initial.is_empty() || !adj.contains_key(initial) {
89 continue;
90 }
91
92 // ------------------------------------------------------------------
93 // Identify terminal states (no outgoing transitions).
94 // ------------------------------------------------------------------
95 let terminal_states: HashSet<&str> = adj
96 .iter()
97 .filter(|(_, neighbours)| neighbours.is_empty())
98 .map(|(state, _)| *state)
99 .collect();
100
101 if terminal_states.is_empty() {
102 // No terminal state exists; skip this procedure.
103 continue;
104 }
105
106 // ------------------------------------------------------------------
107 // BFS from initial_state.
108 // Each queue entry: (current_state, path_of_states, path_of_commands)
109 // ------------------------------------------------------------------
110 // Store (state, Vec<state>, Vec<command>) per BFS node.
111 let mut visited: HashSet<&str> = HashSet::new();
112 let mut queue: VecDeque<(&str, Vec<&str>, Vec<&str>)> = VecDeque::new();
113
114 visited.insert(initial);
115 queue.push_back((initial, vec![initial], vec![]));
116
117 let mut best: Option<ProcedurePath> = None;
118
119 'bfs: while let Some((current, states_path, commands_path)) = queue.pop_front() {
120 if terminal_states.contains(current) {
121 let transition_count = commands_path.len();
122 best = Some(ProcedurePath {
123 states: states_path.iter().map(|s| s.to_string()).collect(),
124 transition_count,
125 commands: commands_path.iter().map(|c| c.to_string()).collect(),
126 });
127 break 'bfs;
128 }
129
130 if let Some(neighbours) = adj.get(current) {
131 for &(next_state, cmd) in neighbours {
132 if !visited.contains(next_state) {
133 visited.insert(next_state);
134 let mut new_states = states_path.clone();
135 new_states.push(next_state);
136 let mut new_cmds = commands_path.clone();
137 new_cmds.push(cmd);
138 queue.push_back((next_state, new_states, new_cmds));
139 }
140 }
141 }
142 }
143
144 if let Some(path) = best {
145 procedure_paths.insert(procedure.id.clone(), path);
146 }
147 }
148 }
149
150 let total_minimum_transitions = procedure_paths.values().map(|p| p.transition_count).sum();
151
152 ShortestPathReport {
153 procedure_paths,
154 total_minimum_transitions,
155 }
156}
157
158// ---------------------------------------------------------------------------
159// Tests
160// ---------------------------------------------------------------------------
161
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_audit_fsm::loader::BlueprintWithPreconditions;

    /// Every path found in the built-in FSA blueprint must be internally
    /// consistent, start at its procedure's initial state, and the report
    /// total must match the per-procedure counts.
    #[test]
    fn test_fsa_shortest_paths() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let report = analyze_shortest_paths(&bwp.blueprint);

        assert!(
            !report.procedure_paths.is_empty(),
            "Expected at least one procedure path in the FSA blueprint, got none"
        );

        // procedure id -> declared initial state. Later duplicates overwrite
        // earlier ones, mirroring how the report itself is filled.
        let initial_states: HashMap<&str, &str> = bwp
            .blueprint
            .phases
            .iter()
            .flat_map(|phase| phase.procedures.iter())
            .map(|p| (p.id.as_str(), p.aggregate.initial_state.as_str()))
            .collect();

        for (proc_id, path) in &report.procedure_paths {
            assert!(
                path.transition_count >= 2,
                "Procedure '{}' path has {} transitions; expected >= 2",
                proc_id,
                path.transition_count
            );
            assert_eq!(
                path.states.len(),
                path.transition_count + 1,
                "Procedure '{}': states.len() ({}) should equal transition_count + 1 ({})",
                proc_id,
                path.states.len(),
                path.transition_count + 1,
            );
            assert_eq!(
                path.commands.len(),
                path.transition_count,
                "Procedure '{}': commands.len() ({}) should equal transition_count ({})",
                proc_id,
                path.commands.len(),
                path.transition_count,
            );
            assert_eq!(
                path.states.first().map(String::as_str),
                initial_states.get(proc_id.as_str()).copied(),
                "Procedure '{}': path should start at the procedure's initial state",
                proc_id,
            );
        }

        let expected_total: usize = report
            .procedure_paths
            .values()
            .map(|p| p.transition_count)
            .sum();
        assert_eq!(
            report.total_minimum_transitions, expected_total,
            "total_minimum_transitions should equal the sum of per-procedure transition counts"
        );
    }

    /// The report must serialize to JSON with its top-level field present.
    #[test]
    fn test_shortest_path_report_serializes() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let report = analyze_shortest_paths(&bwp.blueprint);
        let json = serde_json::to_string(&report).expect("serialization should succeed");
        assert!(
            json.contains("procedure_paths"),
            "Serialized JSON should contain 'procedure_paths'"
        );
    }
}