1use crate::types::EventLog;
9use rustkernel_core::traits::GpuKernel;
10use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::time::Instant;
14
/// Selects which sequence model is trained and queried for prediction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum PredictionModelType {
    /// First-order Markov chain: next activity depends only on the last one.
    #[default]
    Markov1,
    /// Second-order Markov chain: next activity depends on the last two activities.
    Markov2,
    /// N-gram model with context length `PredictionConfig::n_gram_size`.
    NGram,
}
30
/// Tuning parameters for training and querying a [`PredictionModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionConfig {
    /// Which model variant to train/use.
    pub model_type: PredictionModelType,
    /// Window size for `PredictionModelType::NGram` (context is `n_gram_size - 1` activities).
    pub n_gram_size: usize,
    /// Maximum number of predictions returned per query.
    pub top_k: usize,
    /// Predictions below this probability are dropped.
    pub min_probability: f64,
    /// When true, apply add-one (Laplace) smoothing over the vocabulary.
    pub laplace_smoothing: bool,
}
45
46impl Default for PredictionConfig {
47 fn default() -> Self {
48 Self {
49 model_type: PredictionModelType::Markov1,
50 n_gram_size: 3,
51 top_k: 5,
52 min_probability: 0.01,
53 laplace_smoothing: true,
54 }
55 }
56}
57
/// First-order transition counts: activity → (next activity → occurrence count).
pub type TransitionMatrix = HashMap<String, HashMap<String, u64>>;

/// Higher-order transition counts: context sequence → (next activity → occurrence count).
pub type HigherOrderTransitions = HashMap<Vec<String>, HashMap<String, u64>>;
65
/// A trained next-activity prediction model built from an event log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionModel {
    /// Variant this model was trained as.
    pub model_type: PredictionModelType,
    /// First-order transition counts (always populated, used as back-off).
    pub transitions: TransitionMatrix,
    /// Higher-order transition counts (populated for `Markov2`/`NGram` only).
    pub higher_order: HigherOrderTransitions,
    /// How often each activity started a trace.
    pub start_activities: HashMap<String, u64>,
    /// How often each activity ended a trace.
    pub end_activities: HashMap<String, u64>,
    /// Sorted list of all distinct activities seen during training.
    pub vocabulary: Vec<String>,
    /// Number of non-empty traces used for training.
    pub trace_count: u64,
    /// Total number of events across those traces.
    pub event_count: u64,
}
86
87impl Default for PredictionModel {
88 fn default() -> Self {
89 Self {
90 model_type: PredictionModelType::Markov1,
91 transitions: HashMap::new(),
92 higher_order: HashMap::new(),
93 start_activities: HashMap::new(),
94 end_activities: HashMap::new(),
95 vocabulary: Vec::new(),
96 trace_count: 0,
97 event_count: 0,
98 }
99 }
100}
101
102impl PredictionModel {
103 pub fn train(log: &EventLog, config: &PredictionConfig) -> Self {
105 let mut model = Self {
106 model_type: config.model_type,
107 ..Default::default()
108 };
109
110 let mut vocab_set = std::collections::HashSet::new();
111
112 for trace in log.traces.values() {
113 if trace.events.is_empty() {
114 continue;
115 }
116
117 model.trace_count += 1;
118 model.event_count += trace.events.len() as u64;
119
120 let activities: Vec<&str> = trace.events.iter().map(|e| e.activity.as_str()).collect();
121
122 if let Some(first) = activities.first() {
124 *model.start_activities.entry(first.to_string()).or_default() += 1;
125 }
126 if let Some(last) = activities.last() {
127 *model.end_activities.entry(last.to_string()).or_default() += 1;
128 }
129
130 for act in &activities {
132 vocab_set.insert(act.to_string());
133 }
134
135 for window in activities.windows(2) {
137 let from = window[0].to_string();
138 let to = window[1].to_string();
139 *model
140 .transitions
141 .entry(from)
142 .or_default()
143 .entry(to)
144 .or_default() += 1;
145 }
146
147 match config.model_type {
149 PredictionModelType::Markov2 => {
150 for window in activities.windows(3) {
151 let key = vec![window[0].to_string(), window[1].to_string()];
152 let next = window[2].to_string();
153 *model
154 .higher_order
155 .entry(key)
156 .or_default()
157 .entry(next)
158 .or_default() += 1;
159 }
160 }
161 PredictionModelType::NGram => {
162 let n = config.n_gram_size;
163 if activities.len() >= n {
164 for window in activities.windows(n) {
165 let key: Vec<String> =
166 window[..n - 1].iter().map(|s| s.to_string()).collect();
167 let next = window[n - 1].to_string();
168 *model
169 .higher_order
170 .entry(key)
171 .or_default()
172 .entry(next)
173 .or_default() += 1;
174 }
175 }
176 }
177 PredictionModelType::Markov1 => {}
178 }
179 }
180
181 model.vocabulary = vocab_set.into_iter().collect();
182 model.vocabulary.sort();
183
184 model
185 }
186
187 pub fn predict(
189 &self,
190 history: &[String],
191 config: &PredictionConfig,
192 ) -> Vec<ActivityPrediction> {
193 let vocab_size = self.vocabulary.len();
194 let smoothing = if config.laplace_smoothing { 1.0 } else { 0.0 };
195
196 let counts: Option<&HashMap<String, u64>> = match self.model_type {
198 PredictionModelType::Markov1 => {
199 history.last().and_then(|last| self.transitions.get(last))
200 }
201 PredictionModelType::Markov2 => {
202 if history.len() >= 2 {
203 let key = vec![
204 history[history.len() - 2].clone(),
205 history[history.len() - 1].clone(),
206 ];
207 self.higher_order.get(&key)
208 } else if history.len() == 1 {
209 self.transitions.get(&history[0])
211 } else {
212 None
213 }
214 }
215 PredictionModelType::NGram => {
216 let n = config.n_gram_size;
217 if history.len() >= n - 1 {
218 let key: Vec<String> = history[history.len() - (n - 1)..].to_vec();
219 self.higher_order.get(&key)
220 } else if !history.is_empty() {
221 self.transitions.get(&history[history.len() - 1])
223 } else {
224 None
225 }
226 }
227 };
228
229 let mut predictions: Vec<ActivityPrediction> = if let Some(counts) = counts {
231 let total: u64 = counts.values().sum();
232 let total_with_smoothing = total as f64 + smoothing * vocab_size as f64;
233
234 self.vocabulary
235 .iter()
236 .map(|activity| {
237 let count = counts.get(activity).copied().unwrap_or(0);
238 let prob = (count as f64 + smoothing) / total_with_smoothing;
239 ActivityPrediction {
240 activity: activity.clone(),
241 probability: prob,
242 confidence: if total > 10 { prob } else { prob * 0.5 },
243 is_end: self.end_activities.contains_key(activity),
244 }
245 })
246 .filter(|p| p.probability >= config.min_probability)
247 .collect()
248 } else if config.laplace_smoothing && !self.vocabulary.is_empty() {
249 let prob = 1.0 / vocab_size as f64;
251 self.vocabulary
252 .iter()
253 .map(|activity| ActivityPrediction {
254 activity: activity.clone(),
255 probability: prob,
256 confidence: 0.1, is_end: self.end_activities.contains_key(activity),
258 })
259 .collect()
260 } else {
261 Vec::new()
262 };
263
264 predictions.sort_by(|a, b| {
266 b.probability
267 .partial_cmp(&a.probability)
268 .unwrap_or(std::cmp::Ordering::Equal)
269 });
270 predictions.truncate(config.top_k);
271
272 predictions
273 }
274
275 pub fn predict_from_names(
277 &self,
278 history: &[&str],
279 config: &PredictionConfig,
280 ) -> Vec<ActivityPrediction> {
281 let history: Vec<String> = history.iter().map(|s| s.to_string()).collect();
282 self.predict(&history, config)
283 }
284}
285
/// A single candidate next activity with its estimated likelihood.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivityPrediction {
    /// Predicted activity name.
    pub activity: String,
    /// Estimated probability of this activity occurring next.
    pub probability: f64,
    /// Heuristic confidence score (discounted for sparse contexts).
    pub confidence: f64,
    /// True when this activity has been observed ending a trace.
    pub is_end: bool,
}
298
/// Batch prediction request: trace histories plus a trained model and config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionInput {
    /// Partial traces to predict the next activity for.
    pub traces: Vec<TraceHistory>,
    /// Model used for all predictions in the batch.
    pub model: PredictionModel,
    /// Prediction parameters (top-k, smoothing, etc.).
    pub config: PredictionConfig,
}
309
/// The activity history of one in-flight case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceHistory {
    /// Case identifier the history belongs to.
    pub case_id: String,
    /// Activities executed so far, in order.
    pub activities: Vec<String>,
}
318
/// Batch prediction result with timing information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionOutput {
    /// One prediction set per input trace, in input order.
    pub predictions: Vec<TracePrediction>,
    /// Wall-clock compute time in microseconds.
    pub compute_time_us: u64,
}
327
/// Next-activity candidates for a single trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracePrediction {
    /// Case identifier the predictions belong to.
    pub case_id: String,
    /// Candidate next activities, sorted by descending probability.
    pub predictions: Vec<ActivityPrediction>,
    /// Estimated remaining trace length; currently always `None`.
    pub expected_remaining: Option<f64>,
}
338
/// Kernel wrapper exposing next-activity prediction to the kernel registry.
#[derive(Debug, Clone)]
pub struct NextActivityPrediction {
    // Registration metadata reported via the `GpuKernel` trait.
    metadata: KernelMetadata,
}
347
348impl Default for NextActivityPrediction {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354impl NextActivityPrediction {
355 #[must_use]
357 pub fn new() -> Self {
358 Self {
359 metadata: KernelMetadata::batch("procint/next-activity", Domain::ProcessIntelligence)
360 .with_description("Markov/N-gram next activity prediction")
361 .with_throughput(100_000)
362 .with_latency_us(50.0),
363 }
364 }
365
366 pub fn train(log: &EventLog, config: &PredictionConfig) -> PredictionModel {
368 PredictionModel::train(log, config)
369 }
370
371 pub fn predict_batch(
373 traces: &[TraceHistory],
374 model: &PredictionModel,
375 config: &PredictionConfig,
376 ) -> Vec<TracePrediction> {
377 traces
378 .iter()
379 .map(|trace| {
380 let predictions = model.predict(&trace.activities, config);
381 TracePrediction {
382 case_id: trace.case_id.clone(),
383 predictions,
384 expected_remaining: None,
385 }
386 })
387 .collect()
388 }
389
390 pub fn compute(input: &PredictionInput) -> PredictionOutput {
392 let start = Instant::now();
393 let predictions = Self::predict_batch(&input.traces, &input.model, &input.config);
394 PredictionOutput {
395 predictions,
396 compute_time_us: start.elapsed().as_micros() as u64,
397 }
398 }
399}
400
impl GpuKernel for NextActivityPrediction {
    /// Exposes the kernel's registration metadata to the registry.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProcessEvent;

    /// Appends one trace's events to `log`, assigning ids from `id_base`.
    fn add_trace(log: &mut EventLog, case_id: &str, id_base: u64, activities: &[&str]) {
        for (i, activity) in activities.iter().enumerate() {
            log.add_event(ProcessEvent {
                id: id_base + i as u64,
                case_id: case_id.to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }
    }

    /// Four traces: A→B→C→D three times and A→B→E→D once.
    fn create_test_log() -> EventLog {
        let mut log = EventLog::new("test".to_string());
        add_trace(&mut log, "trace1", 0, &["A", "B", "C", "D"]);
        add_trace(&mut log, "trace2", 10, &["A", "B", "C", "D"]);
        add_trace(&mut log, "trace3", 20, &["A", "B", "E", "D"]);
        add_trace(&mut log, "trace4", 30, &["A", "B", "C", "D"]);
        log
    }

    #[test]
    fn test_next_activity_prediction_metadata() {
        let kernel = NextActivityPrediction::new();
        assert_eq!(kernel.metadata().id, "procint/next-activity");
        assert_eq!(kernel.metadata().domain, Domain::ProcessIntelligence);
    }

    #[test]
    fn test_model_training() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        assert_eq!(model.trace_count, 4);
        for activity in ["A", "B", "C", "D", "E"] {
            assert!(model.vocabulary.contains(&activity.to_string()));
        }

        // Non-terminal activities must have outgoing transitions.
        assert!(model.transitions.contains_key("A"));
        assert!(model.transitions.contains_key("B"));
    }

    #[test]
    fn test_first_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov1,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // A is always followed by B.
        let predictions = model.predict_from_names(&["A"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "B");
        assert!(predictions[0].probability > 0.9);

        // B is followed by C in 3 of 4 traces.
        let predictions = model.predict_from_names(&["B"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_second_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov2,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // Context (A, B) is most often followed by C.
        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_batch_prediction() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let traces = vec![
            TraceHistory {
                case_id: "test1".to_string(),
                activities: vec!["A".to_string()],
            },
            TraceHistory {
                case_id: "test2".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            },
        ];

        let results = NextActivityPrediction::predict_batch(&traces, &model, &config);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].case_id, "test1");
        assert_eq!(results[1].case_id, "test2");
    }

    #[test]
    fn test_laplace_smoothing() {
        let log = create_test_log();
        let config_no_smooth = PredictionConfig {
            laplace_smoothing: false,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let config_smooth = PredictionConfig {
            laplace_smoothing: true,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config_no_smooth);

        // "D" always ends a trace, so without smoothing there is no
        // transition data and no predictions at all.
        let pred_no_smooth = model.predict_from_names(&["D"], &config_no_smooth);
        assert!(pred_no_smooth.is_empty());

        // With smoothing every vocabulary activity gets non-zero mass.
        let pred_smooth = model.predict_from_names(&["D"], &config_smooth);
        assert!(!pred_smooth.is_empty());
        assert!(pred_smooth.iter().all(|p| p.probability > 0.0));
    }

    #[test]
    fn test_start_end_activities() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        // Every trace starts with A and ends with D.
        assert_eq!(model.start_activities.get("A"), Some(&4));
        assert_eq!(model.end_activities.get("D"), Some(&4));
    }

    #[test]
    fn test_ngram_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::NGram,
            n_gram_size: 3,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
        };
        let model = PredictionModel::train(&log, &config);

        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_empty_history() {
        let log = create_test_log();
        let config = PredictionConfig {
            laplace_smoothing: true,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // With smoothing enabled, an empty history must still produce the
        // uniform fallback distribution. (The previous assertion,
        // `!predictions.is_empty() || config.laplace_smoothing`, was a
        // tautology because smoothing is true here.)
        let predictions = model.predict(&[], &config);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_compute_output() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let input = PredictionInput {
            traces: vec![TraceHistory {
                case_id: "test".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            }],
            model,
            config,
        };

        let output = NextActivityPrediction::compute(&input);
        assert_eq!(output.predictions.len(), 1);
        assert!(output.compute_time_us < 1_000_000);
    }
}