use crate::discovery::{project_to_event_log, ACTIVITY_KEY};
use crate::types::AdmittedReceipt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use wasm4pm::ilp_discovery::discover_optimized_dfg_from_log;
use wasm4pm::models::DFG;
/// One candidate for the next activity, paired with the score the predictor gave it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ActivityPrediction {
/// Label of the predicted activity.
pub activity: String,
/// Relative-frequency score of this candidate (an observed count divided by a total count).
pub confidence: f64,
}
/// Result of a next-activity prediction run.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PredictionReport {
/// Candidate next activities; the predictor in this file emits them ordered by
/// descending confidence and capped at the requested `top_k`.
pub predictions: Vec<ActivityPrediction>,
/// Number of events in the trace the prediction was made for.
pub context_length: usize,
/// Human-readable name of the model that produced the predictions.
pub model_type: String,
}
/// Errors that next-activity prediction can report.
#[derive(Debug, thiserror::Error)]
pub enum PredictionError {
/// The requested number of predictions was rejected; carries the offending value.
#[error("invalid top-k: {0} (must be at least 1)")]
InvalidTopK(usize),
/// A failure attributed to the wasm4pm discovery step; carries a descriptive message.
#[error("wasm4pm discovery failed: {0}")]
Wasm4pm(String),
}
/// Predicts the most likely next activities for `current_trace` using a
/// directly-follows graph discovered from `model`.
///
/// * `model` - receipt whose events are projected to an event log and mined into the graph.
/// * `current_trace` - receipt holding the events observed so far; only its last event
///   (or the fact that it has none) influences the prediction.
/// * `top_k` - maximum number of predictions to return; must be at least 1.
///
/// When the trace has no events, candidates are the graph's start activities, each scored
/// by its share of all start occurrences. Otherwise candidates are the targets of edges
/// leaving every node whose label equals the last event's `event_type`, scored as edge
/// frequency divided by the summed frequency of those nodes. Because the divisor is the
/// node frequency rather than the outgoing-edge total, the scores need not sum to 1.
///
/// The returned predictions are sorted by descending confidence (ties broken by activity
/// name, ascending) and truncated to `top_k`. An activity that is absent from the graph
/// yields an empty prediction list rather than an error.
///
/// # Errors
///
/// Returns [`PredictionError::InvalidTopK`] when `top_k` is zero.
pub fn predict_next_with_model(
    model: &AdmittedReceipt,
    current_trace: &AdmittedReceipt,
    top_k: usize,
) -> Result<PredictionReport, PredictionError> {
    if top_k == 0 {
        return Err(PredictionError::InvalidTopK(top_k));
    }

    let model_receipt = &model.value;
    let trace_receipt = &current_trace.value;
    let context_length = trace_receipt.events.len();

    let log = project_to_event_log(model_receipt);
    // NOTE(review): the two `1.0` arguments are passed through unchanged; their exact
    // meaning is defined by wasm4pm and should be confirmed against its documentation.
    let dfg = discover_optimized_dfg_from_log(&log, ACTIVITY_KEY, 1.0, 1.0);

    // Matching on `last()` covers the empty trace and the "last event" case in one place,
    // so no unreachable error path is needed for a missing last event.
    let candidates: Vec<(String, f64)> = match trace_receipt.events.last() {
        None => {
            let total_starts: usize = dfg.start_activities.values().sum();
            if total_starts > 0 {
                dfg.start_activities
                    .iter()
                    .map(|(id, &freq)| {
                        (find_label(&dfg, id), freq as f64 / total_starts as f64)
                    })
                    .collect()
            } else {
                Vec::new()
            }
        }
        Some(last_event) => {
            let last_activity = &last_event.event_type;
            let mut next_freqs: HashMap<String, f64> = HashMap::new();
            let mut total_node_freq = 0.0;

            // Several nodes may carry the same label; pool their frequencies and edges.
            for node in dfg.nodes.iter().filter(|n| &n.label == last_activity) {
                total_node_freq += node.frequency as f64;
                for edge in dfg.edges.iter().filter(|e| e.from == node.id) {
                    let target_label = find_label(&dfg, &edge.to);
                    *next_freqs.entry(target_label).or_insert(0.0) += edge.frequency as f64;
                }
            }

            if total_node_freq > 0.0 {
                next_freqs
                    .into_iter()
                    .map(|(label, freq)| (label, freq / total_node_freq))
                    .collect()
            } else {
                Vec::new()
            }
        }
    };

    let mut predictions: Vec<ActivityPrediction> = candidates
        .into_iter()
        .map(|(activity, confidence)| ActivityPrediction {
            activity,
            confidence,
        })
        .collect();

    // Highest confidence first; the name tie-break keeps the order deterministic even
    // though the candidates come out of hash maps in arbitrary order.
    predictions.sort_by(|a, b| {
        b.confidence
            .partial_cmp(&a.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.activity.cmp(&b.activity))
    });
    predictions.truncate(top_k);

    Ok(PredictionReport {
        predictions,
        context_length,
        model_type: "Directly-Follows Graph (DFG)".to_string(),
    })
}
/// Predicts the next activities for a receipt using that same receipt as the model.
///
/// * `admitted` - receipt that supplies both the mined model and the observed trace.
/// * `top_k` - maximum number of predictions to return.
///
/// # Errors
///
/// Propagates any error returned by [`predict_next_with_model`].
pub fn predict_next(
    admitted: &AdmittedReceipt,
    top_k: usize,
) -> Result<PredictionReport, PredictionError> {
    // The single receipt plays both roles: source of the graph and trace being extended.
    let (model, observed) = (admitted, admitted);
    predict_next_with_model(model, observed, top_k)
}
/// Returns the label of the first node in `dfg` whose id equals `id`.
///
/// Falls back to the literal `"unknown"` when no node has that id.
fn find_label(dfg: &DFG, id: &str) -> String {
    for node in dfg.nodes.iter() {
        if node.id == id {
            return node.label.clone();
        }
    }
    "unknown".to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::admission::admit;
    use crate::ocel::{build_event, object_ref, SeqCounter};

    /// Builds an admitted receipt containing one event per entry of `activities`, in order.
    ///
    /// Each event references its own `obj-<index>` object of type `artifact` and uses the
    /// activity name's bytes as payload.
    fn test_receipt(activities: &[&str]) -> AdmittedReceipt {
        let mut assembler = crate::chain::ChainAssembler::new();
        let mut seq = SeqCounter::new();
        for (index, &activity) in activities.iter().enumerate() {
            let object_id = format!("obj-{}", index);
            let event = build_event(
                activity,
                vec![object_ref(&object_id, "artifact")],
                activity.as_bytes(),
                &mut seq,
            )
            .expect("build_event failure in test");
            assembler.append(event).expect("append failure in test");
        }
        admit(assembler.finalize()).expect("test receipt must be admittable")
    }

    /// A strictly sequential model yields exactly one successor with confidence 1.
    #[test]
    fn predicts_sequential_activity_with_full_confidence() -> Result<(), Box<dyn std::error::Error>>
    {
        let trained = test_receipt(&["A", "B", "C"]);
        let observed = test_receipt(&["A"]);

        let report = predict_next_with_model(&trained, &observed, 5)?;

        assert_eq!(report.predictions.len(), 1);
        let only = &report.predictions[0];
        assert_eq!(only.activity, "B");
        assert_eq!(only.confidence, 1.0);
        Ok(())
    }

    /// When `A` is followed by `B` twice and `C` once, the scores are 2/3 and 1/3.
    #[test]
    fn handles_multiple_branches_with_weighting() -> Result<(), Box<dyn std::error::Error>> {
        let trained = test_receipt(&["A", "B", "A", "C", "A", "B"]);
        let observed = test_receipt(&["A", "B", "A", "C", "A"]);

        let report = predict_next_with_model(&trained, &observed, 5)?;
        assert_eq!(report.predictions.len(), 2);

        // Looks up the confidence reported for a given activity, if present.
        let confidence_of = |activity: &str| {
            report
                .predictions
                .iter()
                .find(|p| p.activity == activity)
                .map(|p| p.confidence)
        };
        let b_confidence = confidence_of("B").ok_or("B not found")?;
        let c_confidence = confidence_of("C").ok_or("C not found")?;

        assert!((b_confidence - 0.6666).abs() < 0.001);
        assert!((c_confidence - 0.3333).abs() < 0.001);
        // The more frequent branch must be ranked first.
        assert_eq!(report.predictions[0].activity, "B");
        Ok(())
    }

    /// A last activity that the model never saw produces no predictions.
    #[test]
    fn returns_empty_predictions_for_unknown_state() -> Result<(), Box<dyn std::error::Error>> {
        let trained = test_receipt(&["A", "B"]);
        let observed = test_receipt(&["C"]);

        let report = predict_next_with_model(&trained, &observed, 5)?;

        assert!(report.predictions.is_empty());
        Ok(())
    }

    /// Three possible successors are cut down to the requested two.
    #[test]
    fn respects_top_k_limit() -> Result<(), Box<dyn std::error::Error>> {
        let trained = test_receipt(&["A", "B", "A", "C", "A", "D"]);
        let observed = test_receipt(&["A"]);

        let report = predict_next_with_model(&trained, &observed, 2)?;

        assert_eq!(report.predictions.len(), 2);
        Ok(())
    }
}