use crate::types::EventLog;
use rustkernel_core::traits::GpuKernel;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

/// Sequence model used for next-activity prediction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PredictionModelType {
    /// First-order Markov chain: the next activity depends only on the last activity.
    Markov1,
    /// Second-order Markov chain: the next activity depends on the last two activities.
    Markov2,
    /// N-gram model with a configurable context length (`n_gram_size`).
    NGram,
}

impl Default for PredictionModelType {
    fn default() -> Self {
        Self::Markov1
    }
}

/// Configuration for training and querying a next-activity prediction model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionConfig {
    /// Which sequence model to train and query.
    pub model_type: PredictionModelType,
    /// Context window size used by the `NGram` model.
    pub n_gram_size: usize,
    /// Maximum number of predictions returned per query.
    pub top_k: usize,
    /// Predictions below this probability are filtered out.
    pub min_probability: f64,
    /// Apply Laplace (add-one) smoothing over the vocabulary.
    pub laplace_smoothing: bool,
}

impl Default for PredictionConfig {
    fn default() -> Self {
        Self {
            model_type: PredictionModelType::Markov1,
            n_gram_size: 3,
            top_k: 5,
            min_probability: 0.01,
            laplace_smoothing: true,
        }
    }
}

/// First-order transition counts: from-activity -> (to-activity -> count).
pub type TransitionMatrix = HashMap<String, HashMap<String, u64>>;

/// Higher-order transition counts: context sequence -> (next-activity -> count).
pub type HigherOrderTransitions = HashMap<Vec<String>, HashMap<String, u64>>;
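// Illustrative shape only (hypothetical counts): after training on a single trace A -> B -> C,
// `transitions` maps "A" to {"B": 1}, and for a second-order model `higher_order` maps the
// context ["A", "B"] to {"C": 1}.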

/// Trained next-activity prediction model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionModel {
    /// Model type this instance was trained as.
    pub model_type: PredictionModelType,
    /// First-order transition counts.
    pub transitions: TransitionMatrix,
    /// Higher-order transition counts (used by `Markov2` and `NGram`).
    pub higher_order: HigherOrderTransitions,
    /// How often each activity starts a trace.
    pub start_activities: HashMap<String, u64>,
    /// How often each activity ends a trace.
    pub end_activities: HashMap<String, u64>,
    /// Sorted list of all activities seen during training.
    pub vocabulary: Vec<String>,
    /// Number of non-empty traces used for training.
    pub trace_count: u64,
    /// Total number of events used for training.
    pub event_count: u64,
}

impl Default for PredictionModel {
    fn default() -> Self {
        Self {
            model_type: PredictionModelType::Markov1,
            transitions: HashMap::new(),
            higher_order: HashMap::new(),
            start_activities: HashMap::new(),
            end_activities: HashMap::new(),
            vocabulary: Vec::new(),
            trace_count: 0,
            event_count: 0,
        }
    }
}

impl PredictionModel {
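    /// Train a prediction model from an event log.
    ///
    /// A minimal usage sketch (assumes an `EventLog` named `log` has been built elsewhere;
    /// the activity names are illustrative only):
    ///
    /// ```ignore
    /// let config = PredictionConfig::default();
    /// let model = PredictionModel::train(&log, &config);
    /// let next = model.predict_from_names(&["Receive Order", "Check Stock"], &config);
    /// ```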
    pub fn train(log: &EventLog, config: &PredictionConfig) -> Self {
        let mut model = Self {
            model_type: config.model_type,
            ..Default::default()
        };

        let mut vocab_set = std::collections::HashSet::new();

        for trace in log.traces.values() {
            if trace.events.is_empty() {
                continue;
            }

            model.trace_count += 1;
            model.event_count += trace.events.len() as u64;

            let activities: Vec<&str> = trace.events.iter().map(|e| e.activity.as_str()).collect();

            // Record which activities start and end the trace.
            if let Some(first) = activities.first() {
                *model.start_activities.entry(first.to_string()).or_default() += 1;
            }
            if let Some(last) = activities.last() {
                *model.end_activities.entry(last.to_string()).or_default() += 1;
            }

            // Collect the activity vocabulary.
            for act in &activities {
                vocab_set.insert(act.to_string());
            }

            // First-order transitions are always collected.
            for window in activities.windows(2) {
                let from = window[0].to_string();
                let to = window[1].to_string();
                *model
                    .transitions
                    .entry(from)
                    .or_default()
                    .entry(to)
                    .or_default() += 1;
            }

            // Higher-order transitions depend on the configured model type.
            match config.model_type {
                PredictionModelType::Markov2 => {
                    for window in activities.windows(3) {
                        let key = vec![window[0].to_string(), window[1].to_string()];
                        let next = window[2].to_string();
                        *model
                            .higher_order
                            .entry(key)
                            .or_default()
                            .entry(next)
                            .or_default() += 1;
                    }
                }
                PredictionModelType::NGram => {
                    let n = config.n_gram_size;
                    if activities.len() >= n {
                        for window in activities.windows(n) {
                            let key: Vec<String> =
                                window[..n - 1].iter().map(|s| s.to_string()).collect();
                            let next = window[n - 1].to_string();
                            *model
                                .higher_order
                                .entry(key)
                                .or_default()
                                .entry(next)
                                .or_default() += 1;
                        }
                    }
                }
                PredictionModelType::Markov1 => {}
            }
        }

        model.vocabulary = vocab_set.into_iter().collect();
        model.vocabulary.sort();

        model
    }

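    /// Predict the most likely next activities for an activity history.
    ///
    /// With Laplace smoothing enabled, each candidate activity `a` is scored as
    /// `P(a | context) = (count(context, a) + 1) / (total(context) + |V|)`, where `|V|` is the
    /// vocabulary size. For example (hypothetical counts), if "B" followed "A" 3 times out of 4
    /// and the vocabulary has 5 activities, the smoothed estimate is (3 + 1) / (4 + 5) ≈ 0.44.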
    pub fn predict(
        &self,
        history: &[String],
        config: &PredictionConfig,
    ) -> Vec<ActivityPrediction> {
        let vocab_size = self.vocabulary.len();
        let smoothing = if config.laplace_smoothing { 1.0 } else { 0.0 };

        // Select the transition counts that match the model type and the available history.
        let counts: Option<&HashMap<String, u64>> = match self.model_type {
            PredictionModelType::Markov1 => {
                history.last().and_then(|last| self.transitions.get(last))
            }
            PredictionModelType::Markov2 => {
                if history.len() >= 2 {
                    let key = vec![
                        history[history.len() - 2].clone(),
                        history[history.len() - 1].clone(),
                    ];
                    self.higher_order.get(&key)
                } else if history.len() == 1 {
                    // Fall back to first-order transitions for short histories.
                    self.transitions.get(&history[0])
                } else {
                    None
                }
            }
            PredictionModelType::NGram => {
                let n = config.n_gram_size;
                if history.len() >= n - 1 {
                    let key: Vec<String> = history[history.len() - (n - 1)..].to_vec();
                    self.higher_order.get(&key)
                } else if !history.is_empty() {
                    // Fall back to first-order transitions for short histories.
                    self.transitions.get(&history[history.len() - 1])
                } else {
                    None
                }
            }
        };

        // Score every vocabulary activity against the selected counts.
        let mut predictions: Vec<ActivityPrediction> = if let Some(counts) = counts {
            let total: u64 = counts.values().sum();
            let total_with_smoothing = total as f64 + smoothing * vocab_size as f64;

            self.vocabulary
                .iter()
                .map(|activity| {
                    let count = counts.get(activity).copied().unwrap_or(0);
                    let prob = (count as f64 + smoothing) / total_with_smoothing;
                    ActivityPrediction {
                        activity: activity.clone(),
                        probability: prob,
                        // Heuristic: halve the confidence when the context was seen 10 times or fewer.
                        confidence: if total > 10 { prob } else { prob * 0.5 },
                        is_end: self.end_activities.contains_key(activity),
                    }
                })
                .filter(|p| p.probability >= config.min_probability)
                .collect()
        } else if config.laplace_smoothing && !self.vocabulary.is_empty() {
            // Unseen context: fall back to a uniform distribution over the vocabulary.
            let prob = 1.0 / vocab_size as f64;
            self.vocabulary
                .iter()
                .map(|activity| ActivityPrediction {
                    activity: activity.clone(),
                    probability: prob,
                    confidence: 0.1,
                    is_end: self.end_activities.contains_key(activity),
                })
                .collect()
        } else {
            Vec::new()
        };

        // Sort by descending probability and keep the top-k results.
        predictions.sort_by(|a, b| {
            b.probability
                .partial_cmp(&a.probability)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        predictions.truncate(config.top_k);

        predictions
    }

    /// Convenience wrapper over [`Self::predict`] that accepts `&str` history entries.
    pub fn predict_from_names(
        &self,
        history: &[&str],
        config: &PredictionConfig,
    ) -> Vec<ActivityPrediction> {
        let history: Vec<String> = history.iter().map(|s| s.to_string()).collect();
        self.predict(&history, config)
    }
}

/// A single predicted next activity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivityPrediction {
    /// Predicted activity name.
    pub activity: String,
    /// Estimated probability of this activity occurring next.
    pub probability: f64,
    /// Confidence score (probability discounted for sparsely observed contexts).
    pub confidence: f64,
    /// Whether this activity has been observed to end a trace.
    pub is_end: bool,
}

/// Input for batch next-activity prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionInput {
    /// Partial traces to predict continuations for.
    pub traces: Vec<TraceHistory>,
    /// Trained prediction model.
    pub model: PredictionModel,
    /// Prediction configuration.
    pub config: PredictionConfig,
}

/// Activity history of a single (possibly unfinished) case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceHistory {
    /// Case identifier.
    pub case_id: String,
    /// Activities executed so far, in order.
    pub activities: Vec<String>,
}

/// Output of batch next-activity prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionOutput {
    /// Per-trace predictions.
    pub predictions: Vec<TracePrediction>,
    /// Compute time in microseconds.
    pub compute_time_us: u64,
}

/// Predictions for a single trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracePrediction {
    /// Case identifier.
    pub case_id: String,
    /// Ranked next-activity predictions.
    pub predictions: Vec<ActivityPrediction>,
    /// Expected number of remaining events, if estimated.
    pub expected_remaining: Option<f64>,
}

/// Batch kernel for Markov/N-gram next-activity prediction.
#[derive(Debug, Clone)]
pub struct NextActivityPrediction {
    metadata: KernelMetadata,
}

impl Default for NextActivityPrediction {
    fn default() -> Self {
        Self::new()
    }
}

impl NextActivityPrediction {
    /// Create a new kernel instance with default metadata.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("procint/next-activity", Domain::ProcessIntelligence)
                .with_description("Markov/N-gram next activity prediction")
                .with_throughput(100_000)
                .with_latency_us(50.0),
        }
    }

    /// Train a prediction model from an event log.
    pub fn train(log: &EventLog, config: &PredictionConfig) -> PredictionModel {
        PredictionModel::train(log, config)
    }

    /// Predict next activities for a batch of partial traces.
    pub fn predict_batch(
        traces: &[TraceHistory],
        model: &PredictionModel,
        config: &PredictionConfig,
    ) -> Vec<TracePrediction> {
        traces
            .iter()
            .map(|trace| {
                let predictions = model.predict(&trace.activities, config);
                TracePrediction {
                    case_id: trace.case_id.clone(),
                    predictions,
                    expected_remaining: None,
                }
            })
            .collect()
    }

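    /// Run batch prediction and report the elapsed compute time.
    ///
    /// A minimal usage sketch (assumes a model trained from an `EventLog` named `log`;
    /// the case id and activity names are illustrative only):
    ///
    /// ```ignore
    /// let config = PredictionConfig::default();
    /// let input = PredictionInput {
    ///     traces: vec![TraceHistory {
    ///         case_id: "case-1".to_string(),
    ///         activities: vec!["A".to_string(), "B".to_string()],
    ///     }],
    ///     model: NextActivityPrediction::train(&log, &config),
    ///     config,
    /// };
    /// let output = NextActivityPrediction::compute(&input);
    /// assert_eq!(output.predictions.len(), 1);
    /// ```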
    pub fn compute(input: &PredictionInput) -> PredictionOutput {
        let start = Instant::now();
        let predictions = Self::predict_batch(&input.traces, &input.model, &input.config);
        PredictionOutput {
            predictions,
            compute_time_us: start.elapsed().as_micros() as u64,
        }
    }
}

impl GpuKernel for NextActivityPrediction {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProcessEvent;

    fn create_test_log() -> EventLog {
        let mut log = EventLog::new("test".to_string());

        // Trace 1: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: i as u64,
                case_id: "trace1".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 2: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (10 + i) as u64,
                case_id: "trace2".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 3: A -> B -> E -> D
        for (i, activity) in ["A", "B", "E", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (20 + i) as u64,
                case_id: "trace3".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 4: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (30 + i) as u64,
                case_id: "trace4".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        log
    }

    #[test]
    fn test_next_activity_prediction_metadata() {
        let kernel = NextActivityPrediction::new();
        assert_eq!(kernel.metadata().id, "procint/next-activity");
        assert_eq!(kernel.metadata().domain, Domain::ProcessIntelligence);
    }

    #[test]
    fn test_model_training() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        assert_eq!(model.trace_count, 4);
        assert!(model.vocabulary.contains(&"A".to_string()));
        assert!(model.vocabulary.contains(&"B".to_string()));
        assert!(model.vocabulary.contains(&"C".to_string()));
        assert!(model.vocabulary.contains(&"D".to_string()));
        assert!(model.vocabulary.contains(&"E".to_string()));

        assert!(model.transitions.contains_key("A"));
        assert!(model.transitions.contains_key("B"));
    }

    #[test]
    fn test_first_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov1,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // A is followed by B in all four traces.
        let predictions = model.predict_from_names(&["A"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "B");
        assert!(predictions[0].probability > 0.9);

        // B is followed by C in three of the four traces.
        let predictions = model.predict_from_names(&["B"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_second_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov2,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // The context (A, B) is followed by C in three of the four traces.
        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_batch_prediction() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let traces = vec![
            TraceHistory {
                case_id: "test1".to_string(),
                activities: vec!["A".to_string()],
            },
            TraceHistory {
                case_id: "test2".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            },
        ];

        let results = NextActivityPrediction::predict_batch(&traces, &model, &config);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].case_id, "test1");
        assert_eq!(results[1].case_id, "test2");
    }

    #[test]
    fn test_laplace_smoothing() {
        let log = create_test_log();
        let config_no_smooth = PredictionConfig {
            laplace_smoothing: false,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let config_smooth = PredictionConfig {
            laplace_smoothing: true,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config_no_smooth);

        // "D" only ends traces, so it has no outgoing transitions and yields nothing unsmoothed.
        let pred_no_smooth = model.predict_from_names(&["D"], &config_no_smooth);
        assert!(pred_no_smooth.is_empty());

        // With smoothing, the unseen context falls back to a uniform distribution.
        let pred_smooth = model.predict_from_names(&["D"], &config_smooth);
        assert!(!pred_smooth.is_empty());
        assert!(pred_smooth.iter().all(|p| p.probability > 0.0));
    }

    #[test]
    fn test_start_end_activities() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        // All four traces start with "A"...
        assert!(model.start_activities.contains_key("A"));
        assert_eq!(model.start_activities.get("A"), Some(&4));

        // ...and end with "D".
        assert!(model.end_activities.contains_key("D"));
        assert_eq!(model.end_activities.get("D"), Some(&4));
    }

    #[test]
    fn test_ngram_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::NGram,
            n_gram_size: 3,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
        };
        let model = PredictionModel::train(&log, &config);

        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_empty_history() {
        let log = create_test_log();
        let config = PredictionConfig {
            laplace_smoothing: true,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // With smoothing enabled and a non-empty vocabulary, an empty history still yields
        // a uniform fallback distribution.
        let predictions = model.predict(&[], &config);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_compute_output() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let input = PredictionInput {
            traces: vec![TraceHistory {
                case_id: "test".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            }],
            model,
            config,
        };

        let output = NextActivityPrediction::compute(&input);
        assert_eq!(output.predictions.len(), 1);
        assert!(output.compute_time_us < 1_000_000);
    }
}