// rosu_pattern_detector/mania/detector.rs

1use crate::mania::models::base::ManiaMeasure;
2use crate::mania::models::pattern::Pattern;
3use std::collections::HashMap;
4
/// Hit objects grouped by measure; this is the map that `analyze_patterns`
/// and the NPM / pattern-value calculations in this module operate on.
///
/// * key (`i32`): timestamp of the start of the measure
///   (unit not stated here — presumably milliseconds; TODO confirm)
/// * value ([`ManiaMeasure`]): the measure containing those hit objects
pub struct HitObjects(pub HashMap<i32, ManiaMeasure>);
9
/// Accumulated value per [`Pattern`]; `HitObjects::get_patterns_values`
/// builds it by summing each measure's `value` under that measure's pattern.
pub struct PatternsValues(pub HashMap<Pattern, f64>);
11
12impl std::fmt::Display for PatternsValues {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        write!(f, "PatternsValues {{")?;
15        for (pattern, value) in self.0.iter() {
16            write!(f, "{}: {}, ", pattern, value)?;
17        }
18        write!(f, "}}")
19    }
20}
21impl PatternsValues {
22    fn add_pattern(&mut self, pattern: Pattern, value: f64) {
23        *self.0.entry(pattern).or_insert(0.0) += value;
24    }
25    
26    pub fn ordered_print(&self) {
27        let mut sorted: Vec<(Pattern, f64)> = self.0.iter()
28            .map(|(pattern, &value)| (pattern.clone(), value))
29            .collect();
30        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
31        
32        println!("Patterns (sorted by value):");
33        for (pattern, value) in sorted {
34            println!("  {:?}: {:.2}", pattern, value);
35        }
36    }
37}
38
39impl HitObjects {
40    pub fn get_npm(&self) -> f64 {
41        self.0.values().map(|measure| measure.measure.npm as f64).sum::<f64>() / self.0.len() as f64
42    }
43
44    pub fn get_patterns_values(&self) -> PatternsValues {
45        let mut patterns = PatternsValues(HashMap::new());
46        for (_, measure) in self.0.iter() {
47            patterns.add_pattern(measure.pattern.clone(), measure.value);        
48        }
49        patterns
50    }
51}
52
53
54pub(crate) fn analyze_patterns(hit_objects: &mut HitObjects)
55{
56    let average_npm = hit_objects.get_npm();
57
58    for (_, measure) in hit_objects.0.iter_mut() {
59        measure.pattern = measure.detect_pattern();
60        measure.value = measure.get_pattern_weight_modifier(average_npm);
61    }
62}