use std::collections::{HashMap, HashSet, VecDeque};
use serde::Serialize;
use datasynth_audit_fsm::schema::AuditBlueprint;
/// A shortest path through one procedure's state machine, as produced by
/// [`analyze_shortest_paths`]: from the procedure's initial state to a state
/// that has no outgoing transitions.
#[derive(Debug, Clone, Serialize)]
pub struct ProcedurePath {
    /// States visited in order, starting with the initial state and ending
    /// with the terminal state (always `transition_count + 1` entries).
    pub states: Vec<String>,
    /// Number of transitions taken along the path.
    pub transition_count: usize,
    /// Command attached to each transition taken, in order. A transition
    /// without a command is recorded as an empty string.
    pub commands: Vec<String>,
}
/// Result of [`analyze_shortest_paths`] over a whole blueprint.
#[derive(Debug, Clone, Serialize)]
pub struct ShortestPathReport {
    /// Shortest path per procedure, keyed by procedure id. Procedures for
    /// which no path was found (no usable initial state, or no reachable
    /// state without outgoing transitions) have no entry.
    pub procedure_paths: HashMap<String, ProcedurePath>,
    /// Sum of `transition_count` over every entry in `procedure_paths`.
    pub total_minimum_transitions: usize,
}
/// Finds, for every procedure in `blueprint`, the shortest sequence of
/// transitions from the procedure's initial state to any terminal state
/// (a state with no outgoing transitions).
///
/// A procedure is left out of the report when its initial state is empty,
/// when the initial state is neither a declared state nor a transition
/// endpoint, or when no terminal state is reachable from it.
///
/// Each procedure is solved with a breadth-first search, so the reported path
/// has the minimum number of transitions. When several terminal states are
/// equally close, the first one discovered wins; discovery follows the order
/// in which transitions are declared. Transitions without a command are
/// reported with an empty command string. If two procedures share an id, the
/// one visited later replaces the earlier entry.
///
/// `total_minimum_transitions` is the sum of the per-procedure minimum
/// transition counts.
pub fn analyze_shortest_paths(blueprint: &AuditBlueprint) -> ShortestPathReport {
    let mut procedure_paths: HashMap<String, ProcedurePath> = HashMap::new();

    for phase in &blueprint.phases {
        for procedure in &phase.procedures {
            let agg = &procedure.aggregate;

            // Without an initial state there is nowhere to start the search.
            let initial = agg.initial_state.as_str();
            if initial.is_empty() {
                continue;
            }

            // state -> outgoing (target_state, command) edges, in declaration
            // order. Every declared state and every transition endpoint gets an
            // entry, so states with no outgoing edges are recognised as terminal.
            let mut adj: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
            for state in &agg.states {
                adj.entry(state.as_str()).or_default();
            }
            for transition in &agg.transitions {
                adj.entry(transition.to_state.as_str()).or_default();
                let cmd = transition.command.as_deref().unwrap_or("");
                adj.entry(transition.from_state.as_str())
                    .or_default()
                    .push((transition.to_state.as_str(), cmd));
            }

            if let Some(path) = shortest_path_to_terminal(&adj, initial) {
                procedure_paths.insert(procedure.id.clone(), path);
            }
        }
    }

    let total_minimum_transitions = procedure_paths.values().map(|p| p.transition_count).sum();
    ShortestPathReport {
        procedure_paths,
        total_minimum_transitions,
    }
}

/// Breadth-first search over `adj` from `initial` to the nearest state with no
/// outgoing edges. Returns `None` if `initial` is not a key of `adj` or no such
/// state is reachable.
fn shortest_path_to_terminal<'a>(
    adj: &HashMap<&'a str, Vec<(&'a str, &'a str)>>,
    initial: &'a str,
) -> Option<ProcedurePath> {
    if !adj.contains_key(initial) {
        return None;
    }

    // States are marked visited when enqueued, so each one is reached through
    // the first edge that discovers it. `parent` records that edge as
    // (predecessor, command); the initial state has no entry. Storing parent
    // links instead of a full path per queue item avoids re-cloning the path
    // for every discovered state.
    let mut visited: HashSet<&'a str> = HashSet::from([initial]);
    let mut parent: HashMap<&'a str, (&'a str, &'a str)> = HashMap::new();
    let mut queue: VecDeque<&'a str> = VecDeque::from([initial]);

    while let Some(current) = queue.pop_front() {
        // Every enqueued state is a key of `adj`; the default is only defensive.
        let neighbours = adj.get(current).map(Vec::as_slice).unwrap_or_default();
        if neighbours.is_empty() {
            return Some(rebuild_path(&parent, current));
        }
        for &(next, cmd) in neighbours {
            if visited.insert(next) {
                parent.insert(next, (current, cmd));
                queue.push_back(next);
            }
        }
    }

    None
}

/// Follows `parent` links back from `target` to the search root and returns
/// the forward path: visited states and the command of each transition taken.
fn rebuild_path<'a>(
    parent: &HashMap<&'a str, (&'a str, &'a str)>,
    target: &'a str,
) -> ProcedurePath {
    let mut states = vec![target.to_string()];
    let mut commands = Vec::new();
    let mut current = target;
    while let Some(&(prev, cmd)) = parent.get(current) {
        states.push(prev.to_string());
        commands.push(cmd.to_string());
        current = prev;
    }
    states.reverse();
    commands.reverse();

    ProcedurePath {
        transition_count: commands.len(),
        states,
        commands,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_audit_fsm::loader::BlueprintWithPreconditions;

    /// Loads the built-in FSA blueprint and runs the shortest-path analysis on it.
    fn fsa_report() -> ShortestPathReport {
        let loaded = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        analyze_shortest_paths(&loaded.blueprint)
    }

    /// Every path found in the FSA blueprint has at least two transitions and
    /// consistent state/command counts.
    #[test]
    fn test_fsa_shortest_paths() {
        let report = fsa_report();
        assert!(
            !report.procedure_paths.is_empty(),
            "Expected at least one procedure path in the FSA blueprint, got none"
        );

        for (proc_id, path) in &report.procedure_paths {
            let ProcedurePath {
                states,
                transition_count,
                commands,
            } = path;
            let hops = *transition_count;
            let expected_states = hops + 1;

            assert!(
                hops >= 2,
                "Procedure '{}' path has {} transitions; expected >= 2",
                proc_id,
                hops
            );
            assert_eq!(
                states.len(),
                expected_states,
                "Procedure '{}': states.len() ({}) should equal transition_count + 1 ({})",
                proc_id,
                states.len(),
                expected_states,
            );
            assert_eq!(
                commands.len(),
                hops,
                "Procedure '{}': commands.len() ({}) should equal transition_count ({})",
                proc_id,
                commands.len(),
                hops,
            );
        }
    }

    /// The report serializes to JSON and exposes its top-level field.
    #[test]
    fn test_shortest_path_report_serializes() {
        let json = serde_json::to_string(&fsa_report()).expect("serialization should succeed");
        assert!(
            json.contains("procedure_paths"),
            "Serialized JSON should contain 'procedure_paths'"
        );
    }
}