Skip to main content

datasynth_audit_optimizer/
shortest_path.rs

//! Shortest-path analysis for audit procedure FSMs.
//!
//! Uses breadth-first search (BFS) over each procedure's `ProcedureAggregate`
//! to find the minimum number of transitions required to move from
//! `initial_state` to the nearest terminal state (a state with no outgoing
//! transitions).
6
7use std::collections::{HashMap, HashSet, VecDeque};
8
9use serde::Serialize;
10
11use datasynth_audit_fsm::schema::AuditBlueprint;
12
13// ---------------------------------------------------------------------------
14// Report types
15// ---------------------------------------------------------------------------
16
17/// The minimum-transition path through a single procedure's FSM.
18#[derive(Debug, Clone, Serialize)]
19pub struct ProcedurePath {
20    /// Ordered sequence of states visited, from `initial_state` to terminal.
21    pub states: Vec<String>,
22    /// Number of transitions taken (= `states.len() - 1`).
23    pub transition_count: usize,
24    /// Commands used along the path (one per transition; empty string when the
25    /// transition has no associated command).
26    pub commands: Vec<String>,
27}
28
29/// Aggregate shortest-path report across all procedures in a blueprint.
30#[derive(Debug, Clone, Serialize)]
31pub struct ShortestPathReport {
32    /// Per-procedure minimum paths, keyed by procedure id.
33    pub procedure_paths: HashMap<String, ProcedurePath>,
34    /// Sum of `transition_count` across all procedures.
35    pub total_minimum_transitions: usize,
36}
37
38// ---------------------------------------------------------------------------
39// Analysis
40// ---------------------------------------------------------------------------
41
42/// Analyse every procedure in `blueprint` and return the shortest path from
43/// `initial_state` to any terminal state for each procedure that has a
44/// non-empty aggregate.
45///
46/// A terminal state is any state that has no outgoing transitions in the
47/// procedure's FSM.  BFS guarantees the first path found to a terminal state
48/// is the shortest.
49pub fn analyze_shortest_paths(blueprint: &AuditBlueprint) -> ShortestPathReport {
50    let mut procedure_paths: HashMap<String, ProcedurePath> = HashMap::new();
51
52    for phase in &blueprint.phases {
53        for procedure in &phase.procedures {
54            let agg = &procedure.aggregate;
55
56            // Skip procedures with no FSM content.
57            if agg.transitions.is_empty() && agg.initial_state.is_empty() {
58                continue;
59            }
60
61            // ------------------------------------------------------------------
62            // Build adjacency list: state -> Vec<(to_state, command)>
63            // ------------------------------------------------------------------
64            let mut adj: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
65
66            // Ensure every declared state has an entry (even if it has no
67            // outgoing transitions) so we can detect terminal states correctly.
68            for state in &agg.states {
69                adj.entry(state.as_str()).or_default();
70            }
71
72            for transition in &agg.transitions {
73                // Also lazily create entries for states only referenced in
74                // transitions (not in the explicit states list).
75                adj.entry(transition.from_state.as_str()).or_default();
76                adj.entry(transition.to_state.as_str()).or_default();
77
78                let cmd = transition.command.as_deref().unwrap_or("");
79                adj.entry(transition.from_state.as_str())
80                    .or_default()
81                    .push((transition.to_state.as_str(), cmd));
82            }
83
84            let initial = agg.initial_state.as_str();
85
86            // Nothing to do if there is no initial state or it is not in the
87            // adjacency map.
88            if initial.is_empty() || !adj.contains_key(initial) {
89                continue;
90            }
91
92            // ------------------------------------------------------------------
93            // Identify terminal states (no outgoing transitions).
94            // ------------------------------------------------------------------
95            let terminal_states: HashSet<&str> = adj
96                .iter()
97                .filter(|(_, neighbours)| neighbours.is_empty())
98                .map(|(state, _)| *state)
99                .collect();
100
101            if terminal_states.is_empty() {
102                // No terminal state exists; skip this procedure.
103                continue;
104            }
105
106            // ------------------------------------------------------------------
107            // BFS from initial_state.
108            // Each queue entry: (current_state, path_of_states, path_of_commands)
109            // ------------------------------------------------------------------
110            // Store (state, Vec<state>, Vec<command>) per BFS node.
111            let mut visited: HashSet<&str> = HashSet::new();
112            let mut queue: VecDeque<(&str, Vec<&str>, Vec<&str>)> = VecDeque::new();
113
114            visited.insert(initial);
115            queue.push_back((initial, vec![initial], vec![]));
116
117            let mut best: Option<ProcedurePath> = None;
118
119            'bfs: while let Some((current, states_path, commands_path)) = queue.pop_front() {
120                if terminal_states.contains(current) {
121                    let transition_count = commands_path.len();
122                    best = Some(ProcedurePath {
123                        states: states_path.iter().map(|s| s.to_string()).collect(),
124                        transition_count,
125                        commands: commands_path.iter().map(|c| c.to_string()).collect(),
126                    });
127                    break 'bfs;
128                }
129
130                if let Some(neighbours) = adj.get(current) {
131                    for &(next_state, cmd) in neighbours {
132                        if !visited.contains(next_state) {
133                            visited.insert(next_state);
134                            let mut new_states = states_path.clone();
135                            new_states.push(next_state);
136                            let mut new_cmds = commands_path.clone();
137                            new_cmds.push(cmd);
138                            queue.push_back((next_state, new_states, new_cmds));
139                        }
140                    }
141                }
142            }
143
144            if let Some(path) = best {
145                procedure_paths.insert(procedure.id.clone(), path);
146            }
147        }
148    }
149
150    let total_minimum_transitions = procedure_paths.values().map(|p| p.transition_count).sum();
151
152    ShortestPathReport {
153        procedure_paths,
154        total_minimum_transitions,
155    }
156}
157
158// ---------------------------------------------------------------------------
159// Tests
160// ---------------------------------------------------------------------------
161
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_audit_fsm::loader::BlueprintWithPreconditions;

    #[test]
    fn test_fsa_shortest_paths() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let report = analyze_shortest_paths(&bwp.blueprint);

        assert!(
            !report.procedure_paths.is_empty(),
            "Expected at least one procedure path in the FSA blueprint, got none"
        );

        for (proc_id, path) in &report.procedure_paths {
            assert!(
                path.transition_count >= 2,
                "Procedure '{}' path has {} transitions; expected >= 2",
                proc_id,
                path.transition_count
            );
            assert_eq!(
                path.states.len(),
                path.transition_count + 1,
                "Procedure '{}': states.len() ({}) should equal transition_count + 1 ({})",
                proc_id,
                path.states.len(),
                path.transition_count + 1,
            );
            assert_eq!(
                path.commands.len(),
                path.transition_count,
                "Procedure '{}': commands.len() ({}) should equal transition_count ({})",
                proc_id,
                path.commands.len(),
                path.transition_count,
            );
        }

        // The aggregate total must agree with the individual paths.
        let expected_total: usize = report
            .procedure_paths
            .values()
            .map(|p| p.transition_count)
            .sum();
        assert_eq!(report.total_minimum_transitions, expected_total);
    }

    /// Every reported path must be a real walk through its procedure's FSM.
    ///
    /// That means it:
    /// - starts at `initial_state`;
    /// - follows declared transitions, each with the recorded command;
    /// - ends in a state that has no outgoing transitions.
    ///
    /// NOTE(review): assumes procedure ids are unique in the builtin FSA
    /// blueprint. A duplicated id would make this compare one procedure's FSM
    /// against another procedure's path.
    #[test]
    fn test_fsa_paths_follow_declared_transitions() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let report = analyze_shortest_paths(&bwp.blueprint);

        for phase in &bwp.blueprint.phases {
            for procedure in &phase.procedures {
                let path = match report.procedure_paths.get(&procedure.id) {
                    Some(path) => path,
                    None => continue,
                };
                let agg = &procedure.aggregate;

                assert_eq!(
                    path.states.first().map(String::as_str),
                    Some(agg.initial_state.as_str()),
                    "Procedure '{}': path must start at the initial state",
                    procedure.id
                );

                for (step, command) in path.states.windows(2).zip(&path.commands) {
                    let (from, to) = (step[0].as_str(), step[1].as_str());
                    assert!(
                        agg.transitions.iter().any(|t| {
                            t.from_state.as_str() == from
                                && t.to_state.as_str() == to
                                && t.command.as_deref().unwrap_or("") == command.as_str()
                        }),
                        "Procedure '{}': no declared transition {} -> {} with command '{}'",
                        procedure.id,
                        from,
                        to,
                        command
                    );
                }

                let last = path.states.last().expect("a path always contains at least one state");
                assert!(
                    agg.transitions
                        .iter()
                        .all(|t| t.from_state.as_str() != last.as_str()),
                    "Procedure '{}': final state '{}' still has outgoing transitions",
                    procedure.id,
                    last
                );
            }
        }
    }

    #[test]
    fn test_shortest_path_report_serializes() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let report = analyze_shortest_paths(&bwp.blueprint);
        let json = serde_json::to_string(&report).expect("serialization should succeed");
        assert!(
            json.contains("procedure_paths"),
            "Serialized JSON should contain 'procedure_paths'"
        );
        assert!(
            json.contains("total_minimum_transitions"),
            "Serialized JSON should contain 'total_minimum_transitions'"
        );
    }
}