rustkernel_procint/prediction.rs

//! Next activity prediction kernels.
//!
//! This module provides process activity prediction:
//! - Markov chain-based prediction
//! - N-gram model prediction
//! - Batch inference for multiple traces

use crate::types::EventLog;
use rustkernel_core::traits::GpuKernel;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

// ============================================================================
// Next Activity Prediction Kernel
// ============================================================================

/// Prediction model type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum PredictionModelType {
    /// First-order Markov chain (single previous activity).
    #[default]
    Markov1,
    /// Second-order Markov chain (two previous activities).
    Markov2,
    /// N-gram model with configurable n.
    NGram,
}

/// Configuration for prediction.
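///
/// A construction sketch (illustrative only, relying on the `Default`
/// implementation below for the remaining fields):
///
/// ```ignore
/// let config = PredictionConfig {
///     model_type: PredictionModelType::Markov2,
///     top_k: 3,
///     ..Default::default()
/// };
/// ```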
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionConfig {
    /// Type of prediction model.
    pub model_type: PredictionModelType,
    /// N for N-gram model (ignored for Markov).
    pub n_gram_size: usize,
    /// Number of top predictions to return.
    pub top_k: usize,
    /// Minimum probability threshold.
    pub min_probability: f64,
    /// Use Laplace smoothing for unseen transitions.
    pub laplace_smoothing: bool,
}

impl Default for PredictionConfig {
    fn default() -> Self {
        Self {
            model_type: PredictionModelType::Markov1,
            n_gram_size: 3,
            top_k: 5,
            min_probability: 0.01,
            laplace_smoothing: true,
        }
    }
}

/// Transition matrix for first-order Markov model.
/// Key: current activity, Value: map of next activity -> count
pub type TransitionMatrix = HashMap<String, HashMap<String, u64>>;

/// Higher-order transition matrix.
/// Key: sequence of activities (as a vector), Value: map of next activity -> count
pub type HigherOrderTransitions = HashMap<Vec<String>, HashMap<String, u64>>;

/// A trained prediction model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionModel {
    /// Model type.
    pub model_type: PredictionModelType,
    /// First-order transitions (activity -> next -> count).
    pub transitions: TransitionMatrix,
    /// Higher-order transitions (for Markov2 and N-gram).
    pub higher_order: HigherOrderTransitions,
    /// Start activity frequencies.
    pub start_activities: HashMap<String, u64>,
    /// End activity frequencies.
    pub end_activities: HashMap<String, u64>,
    /// Activity vocabulary.
    pub vocabulary: Vec<String>,
    /// Total traces trained on.
    pub trace_count: u64,
    /// Total events trained on.
    pub event_count: u64,
}

impl Default for PredictionModel {
    fn default() -> Self {
        Self {
            model_type: PredictionModelType::Markov1,
            transitions: HashMap::new(),
            higher_order: HashMap::new(),
            start_activities: HashMap::new(),
            end_activities: HashMap::new(),
            vocabulary: Vec::new(),
            trace_count: 0,
            event_count: 0,
        }
    }
}

impl PredictionModel {
    /// Create a new model from an event log.
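    ///
    /// A minimal training sketch (not run as a doctest; assumes `log` is an
    /// `EventLog` built elsewhere, e.g. via `EventLog::new` and `add_event`):
    ///
    /// ```ignore
    /// let config = PredictionConfig::default();
    /// let model = PredictionModel::train(&log, &config);
    /// assert_eq!(model.model_type, config.model_type);
    /// ```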
    pub fn train(log: &EventLog, config: &PredictionConfig) -> Self {
        let mut model = Self {
            model_type: config.model_type,
            ..Default::default()
        };

        let mut vocab_set = std::collections::HashSet::new();

        for trace in log.traces.values() {
            if trace.events.is_empty() {
                continue;
            }

            model.trace_count += 1;
            model.event_count += trace.events.len() as u64;

            let activities: Vec<&str> = trace.events.iter().map(|e| e.activity.as_str()).collect();

            // Record start/end activities
            if let Some(first) = activities.first() {
                *model.start_activities.entry(first.to_string()).or_default() += 1;
            }
            if let Some(last) = activities.last() {
                *model.end_activities.entry(last.to_string()).or_default() += 1;
            }

            // Build vocabulary
            for act in &activities {
                vocab_set.insert(act.to_string());
            }

            // Build transition matrix
            for window in activities.windows(2) {
                let from = window[0].to_string();
                let to = window[1].to_string();
                *model
                    .transitions
                    .entry(from)
                    .or_default()
                    .entry(to)
                    .or_default() += 1;
            }

            // Build higher-order transitions if needed
            match config.model_type {
                PredictionModelType::Markov2 => {
                    for window in activities.windows(3) {
                        let key = vec![window[0].to_string(), window[1].to_string()];
                        let next = window[2].to_string();
                        *model
                            .higher_order
                            .entry(key)
                            .or_default()
                            .entry(next)
                            .or_default() += 1;
                    }
                }
                PredictionModelType::NGram => {
                    let n = config.n_gram_size.max(1); // Guard against a zero window size.
                    if activities.len() >= n {
                        for window in activities.windows(n) {
                            let key: Vec<String> =
                                window[..n - 1].iter().map(|s| s.to_string()).collect();
                            let next = window[n - 1].to_string();
                            *model
                                .higher_order
                                .entry(key)
                                .or_default()
                                .entry(next)
                                .or_default() += 1;
                        }
                    }
                }
                PredictionModelType::Markov1 => {}
            }
        }

        model.vocabulary = vocab_set.into_iter().collect();
        model.vocabulary.sort();

        model
    }

    /// Predict next activities for a given sequence.
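    ///
    /// Sketch of a single prediction call (illustrative only; assumes `model`
    /// was produced by `PredictionModel::train` and `config` matches it):
    ///
    /// ```ignore
    /// let history = vec!["A".to_string(), "B".to_string()];
    /// let top = model.predict(&history, &config);
    /// if let Some(best) = top.first() {
    ///     println!("next: {} (p = {:.2})", best.activity, best.probability);
    /// }
    /// ```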
    pub fn predict(
        &self,
        history: &[String],
        config: &PredictionConfig,
    ) -> Vec<ActivityPrediction> {
        let vocab_size = self.vocabulary.len();
        let smoothing = if config.laplace_smoothing { 1.0 } else { 0.0 };

        // Get transition counts based on model type
        let counts: Option<&HashMap<String, u64>> = match self.model_type {
            PredictionModelType::Markov1 => {
                history.last().and_then(|last| self.transitions.get(last))
            }
            PredictionModelType::Markov2 => {
                if history.len() >= 2 {
                    let key = vec![
                        history[history.len() - 2].clone(),
                        history[history.len() - 1].clone(),
                    ];
                    self.higher_order.get(&key)
                } else if history.len() == 1 {
                    // Fall back to first-order
                    self.transitions.get(&history[0])
                } else {
                    None
                }
            }
            PredictionModelType::NGram => {
                let n = config.n_gram_size.max(1); // Guard against a zero window size.
                if history.len() >= n - 1 {
                    let key: Vec<String> = history[history.len() - (n - 1)..].to_vec();
                    self.higher_order.get(&key)
                } else if !history.is_empty() {
                    // Fall back to first-order
                    self.transitions.get(&history[history.len() - 1])
                } else {
                    None
                }
            }
        };

        // Calculate probabilities
        let mut predictions: Vec<ActivityPrediction> = if let Some(counts) = counts {
            let total: u64 = counts.values().sum();
            let total_with_smoothing = total as f64 + smoothing * vocab_size as f64;

            self.vocabulary
                .iter()
                .map(|activity| {
                    let count = counts.get(activity).copied().unwrap_or(0);
                    let prob = (count as f64 + smoothing) / total_with_smoothing;
                    ActivityPrediction {
                        activity: activity.clone(),
                        probability: prob,
                        confidence: if total > 10 { prob } else { prob * 0.5 },
                        is_end: self.end_activities.contains_key(activity),
                    }
                })
                .filter(|p| p.probability >= config.min_probability)
                .collect()
        } else if config.laplace_smoothing && !self.vocabulary.is_empty() {
            // Uniform distribution with smoothing for unseen context
            let prob = 1.0 / vocab_size as f64;
            self.vocabulary
                .iter()
                .map(|activity| ActivityPrediction {
                    activity: activity.clone(),
                    probability: prob,
                    confidence: 0.1, // Low confidence for uniform
                    is_end: self.end_activities.contains_key(activity),
                })
                .collect()
        } else {
            Vec::new()
        };

        // Sort by probability descending and take top_k
        predictions.sort_by(|a, b| {
            b.probability
                .partial_cmp(&a.probability)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        predictions.truncate(config.top_k);

        predictions
    }

    /// Predict from activity names (convenience method).
    pub fn predict_from_names(
        &self,
        history: &[&str],
        config: &PredictionConfig,
    ) -> Vec<ActivityPrediction> {
        let history: Vec<String> = history.iter().map(|s| s.to_string()).collect();
        self.predict(&history, config)
    }
}

/// A predicted next activity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivityPrediction {
    /// Predicted activity name.
    pub activity: String,
    /// Probability of this activity.
    pub probability: f64,
    /// Confidence in the prediction (adjusted for data sparsity).
    pub confidence: f64,
    /// Whether this is commonly an end activity.
    pub is_end: bool,
}

/// Input for batch prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionInput {
    /// Traces to predict next activities for.
    pub traces: Vec<TraceHistory>,
    /// Trained model.
    pub model: PredictionModel,
    /// Configuration.
    pub config: PredictionConfig,
}

/// A trace history for prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceHistory {
    /// Case/trace ID.
    pub case_id: String,
    /// Activity history (most recent last).
    pub activities: Vec<String>,
}

/// Output from batch prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionOutput {
    /// Predictions per trace.
    pub predictions: Vec<TracePrediction>,
    /// Compute time in microseconds.
    pub compute_time_us: u64,
}

/// Predictions for a single trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracePrediction {
    /// Case/trace ID.
    pub case_id: String,
    /// Top-k predictions.
    pub predictions: Vec<ActivityPrediction>,
    /// Expected remaining activities (if the model supports it).
    pub expected_remaining: Option<f64>,
}

/// Next activity prediction kernel.
///
/// Predicts the next activity in a business process using
/// Markov chains or N-gram models trained on historical data.
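///
/// End-to-end sketch of the batch flow (illustrative only; `log` is assumed
/// to be an `EventLog` built elsewhere):
///
/// ```ignore
/// let config = PredictionConfig::default();
/// let model = NextActivityPrediction::train(&log, &config);
/// let input = PredictionInput {
///     traces: vec![TraceHistory {
///         case_id: "case-1".to_string(),
///         activities: vec!["A".to_string()],
///     }],
///     model,
///     config,
/// };
/// let output = NextActivityPrediction::compute(&input);
/// assert_eq!(output.predictions.len(), 1);
/// ```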
#[derive(Debug, Clone)]
pub struct NextActivityPrediction {
    metadata: KernelMetadata,
}

impl Default for NextActivityPrediction {
    fn default() -> Self {
        Self::new()
    }
}

impl NextActivityPrediction {
    /// Create a new next activity prediction kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("procint/next-activity", Domain::ProcessIntelligence)
                .with_description("Markov/N-gram next activity prediction")
                .with_throughput(100_000)
                .with_latency_us(50.0),
        }
    }

    /// Train a model from an event log.
    pub fn train(log: &EventLog, config: &PredictionConfig) -> PredictionModel {
        PredictionModel::train(log, config)
    }

    /// Batch predict for multiple traces.
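    ///
    /// Sketch (illustrative; `model` and `config` are assumed to come from
    /// `NextActivityPrediction::train` and `PredictionConfig::default()`):
    ///
    /// ```ignore
    /// let histories = vec![TraceHistory {
    ///     case_id: "case-2".to_string(),
    ///     activities: vec!["A".to_string(), "B".to_string()],
    /// }];
    /// let results = NextActivityPrediction::predict_batch(&histories, &model, &config);
    /// ```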
    pub fn predict_batch(
        traces: &[TraceHistory],
        model: &PredictionModel,
        config: &PredictionConfig,
    ) -> Vec<TracePrediction> {
        traces
            .iter()
            .map(|trace| {
                let predictions = model.predict(&trace.activities, config);
                TracePrediction {
                    case_id: trace.case_id.clone(),
                    predictions,
                    expected_remaining: None,
                }
            })
            .collect()
    }

    /// Compute batch predictions.
    pub fn compute(input: &PredictionInput) -> PredictionOutput {
        let start = Instant::now();
        let predictions = Self::predict_batch(&input.traces, &input.model, &input.config);
        PredictionOutput {
            predictions,
            compute_time_us: start.elapsed().as_micros() as u64,
        }
    }
}

impl GpuKernel for NextActivityPrediction {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProcessEvent;

    fn create_test_log() -> EventLog {
        let mut log = EventLog::new("test".to_string());

        // Trace 1: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: i as u64,
                case_id: "trace1".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 2: A -> B -> C -> D (same pattern)
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (10 + i) as u64,
                case_id: "trace2".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 3: A -> B -> E -> D (different path)
        for (i, activity) in ["A", "B", "E", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (20 + i) as u64,
                case_id: "trace3".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 4: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (30 + i) as u64,
                case_id: "trace4".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        log
    }

    #[test]
    fn test_next_activity_prediction_metadata() {
        let kernel = NextActivityPrediction::new();
        assert_eq!(kernel.metadata().id, "procint/next-activity");
        assert_eq!(kernel.metadata().domain, Domain::ProcessIntelligence);
    }

    #[test]
    fn test_model_training() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        assert_eq!(model.trace_count, 4);
        assert!(model.vocabulary.contains(&"A".to_string()));
        assert!(model.vocabulary.contains(&"B".to_string()));
        assert!(model.vocabulary.contains(&"C".to_string()));
        assert!(model.vocabulary.contains(&"D".to_string()));
        assert!(model.vocabulary.contains(&"E".to_string()));

        // Check transitions
        assert!(model.transitions.contains_key("A"));
        assert!(model.transitions.contains_key("B"));
    }

    #[test]
    fn test_first_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov1,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // After A, B should be predicted with high probability
        let predictions = model.predict_from_names(&["A"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "B");
        assert!(predictions[0].probability > 0.9);

        // After B, C should be most likely (3 traces), E second (1 trace)
        let predictions = model.predict_from_names(&["B"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_second_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov2,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // After A, B -> C should be predicted (using 2nd order)
        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
        // C appears after A,B in 3 traces, E in 1 trace
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_batch_prediction() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let traces = vec![
            TraceHistory {
                case_id: "test1".to_string(),
                activities: vec!["A".to_string()],
            },
            TraceHistory {
                case_id: "test2".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            },
        ];

        let results = NextActivityPrediction::predict_batch(&traces, &model, &config);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].case_id, "test1");
        assert_eq!(results[1].case_id, "test2");
    }

    #[test]
    fn test_laplace_smoothing() {
        let log = create_test_log();
        let config_no_smooth = PredictionConfig {
            laplace_smoothing: false,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let config_smooth = PredictionConfig {
            laplace_smoothing: true,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config_no_smooth);

        // Without smoothing, unseen transition should have 0 probability
        let pred_no_smooth = model.predict_from_names(&["D"], &config_no_smooth);
        // "D" is always the final activity, so it has no outgoing transitions
        // and no predictions are returned.
        assert!(pred_no_smooth.is_empty());

        // With smoothing, should have non-zero probabilities
        let pred_smooth = model.predict_from_names(&["D"], &config_smooth);
        assert!(!pred_smooth.is_empty());
        assert!(pred_smooth.iter().all(|p| p.probability > 0.0));
    }

    #[test]
    fn test_start_end_activities() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        // A should be start activity
        assert!(model.start_activities.contains_key("A"));
        assert_eq!(model.start_activities.get("A"), Some(&4));

        // D should be end activity
        assert!(model.end_activities.contains_key("D"));
        assert_eq!(model.end_activities.get("D"), Some(&4));
    }

    #[test]
    fn test_ngram_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::NGram,
            n_gram_size: 3,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
        };
        let model = PredictionModel::train(&log, &config);

        // With 3-gram: A, B -> C or E
        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_empty_history() {
        let log = create_test_log();
        let config = PredictionConfig {
            laplace_smoothing: true,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // Empty history should return uniform or start distribution
        let predictions = model.predict(&[], &config);
        // With smoothing, the unseen (empty) context falls back to a uniform
        // distribution over the vocabulary, so predictions are returned.
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_compute_output() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let input = PredictionInput {
            traces: vec![TraceHistory {
                case_id: "test".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            }],
            model,
            config,
        };

        let output = NextActivityPrediction::compute(&input);
        assert_eq!(output.predictions.len(), 1);
        assert!(output.compute_time_us < 1_000_000); // Should be fast
    }
}