rosu_pattern_detector/mania/models/pattern.rs

1use crate::mania::models::base::{Notes, ManiaMeasure};
2use std::collections::HashMap;
3
4#[derive(Debug, PartialEq, Clone,Hash,Eq)]
5pub enum Pattern
6{
7    Jack(JackPattern),
8    Handstream(HandstreamPattern),
9    Jumpstream(JumpstreamPattern),
10    Singlestream(SinglestreamPattern),
11    None,
12} 
13
14impl Pattern {
15    pub fn to_all(&self) -> Pattern {
16        match self {
17            Pattern::Jack(_) => Pattern::Jack(JackPattern::All),
18            Pattern::Handstream(_) => Pattern::Handstream(HandstreamPattern::All),
19            Pattern::Jumpstream(_) => Pattern::Jumpstream(JumpstreamPattern::All),
20            Pattern::Singlestream(_) => Pattern::Singlestream(SinglestreamPattern::All),
21            Pattern::None => Pattern::None,
22        }
23    }
24}
25
26impl std::fmt::Display for Pattern {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        let pattern = match self {
29            Pattern::Jack(pattern) => pattern.to_string(),
30            Pattern::Handstream(pattern) => pattern.to_string(),
31            Pattern::Jumpstream(pattern) => pattern.to_string(),
32            Pattern::Singlestream(_) => "SingleStream".to_string(),
33            Pattern::None => return write!(f, "None"),
34        };
35        write!(f, "{}", pattern)
36    }
37}
38
/// Sub-classification of a jack measure, assigned by
/// [`JackPattern::determine_jack_type`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JackPattern {
    /// At least one quad row, and more jump + hand rows than single rows.
    Chordjack,
    /// More hand rows than jump + single rows.
    DenseChordjack,
    /// More than 6 rows with at most one pair of adjacent rows sharing a column.
    ChordStream,
    /// Fallback when none of the other jack rules match.
    Speedjack,
    /// Wildcard for any jack sub-type (see `Pattern::to_all`).
    All,
}
47impl JackPattern {
48    pub fn determine_jack_type(measure: &mut ManiaMeasure) -> JackPattern {
49        let mut pattern_count: HashMap<Notes, usize> = HashMap::new();
50    
51        for note in measure.notes.iter() {
52            *pattern_count.entry(note.get_pattern()).or_insert(0) += 1;
53        }
54    
55        let single = *pattern_count.get(&Notes::Single).unwrap_or(&0);
56        let jump = *pattern_count.get(&Notes::Jump).unwrap_or(&0);
57        let hand = *pattern_count.get(&Notes::Hand).unwrap_or(&0);
58        let quad = *pattern_count.get(&Notes::Quad).unwrap_or(&0);
59    
60        if hand > jump + single {
61            JackPattern::DenseChordjack
62        } 
63        else if quad > 0 && jump + hand > single 
64        {
65            JackPattern::Chordjack
66        } 
67        else
68        {
69            JackPattern::determine_jackspeed_or_chordstream(measure)
70        }
71    }    
72
73    fn determine_jackspeed_or_chordstream(measure: &mut ManiaMeasure) -> JackPattern {
74        let mut jack_count = 0;
75        for (i, note) in measure.notes.iter().enumerate() {
76            if i > 0 {
77                let prev = &measure.notes[i - 1];
78                if note.notes.iter().zip(prev.notes.iter()).any(|(n, p)| *n && *p) {
79                    jack_count += 1;
80                }
81            }
82        }
83        if jack_count <= 1 && measure.notes.len() > 6 {
84            JackPattern::ChordStream
85        } else {
86            JackPattern::Speedjack
87        }
88    }
89
90    pub fn to_string(&self) -> String {
91        match self {
92            JackPattern::Chordjack => "Chordjack".to_string(),
93            JackPattern::DenseChordjack => "DenseChordjack".to_string(),
94            JackPattern::ChordStream => "ChordStream".to_string(),
95            JackPattern::Speedjack => "Speedjack".to_string(),
96            JackPattern::All => "All".to_string(),
97        }
98    }
99}
100
101
/// Sub-classification of a jumpstream measure, assigned by
/// [`JumpstreamPattern::determine_js_type`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JumpstreamPattern {
    /// Fewer jump rows than single rows.
    LightJs,
    /// Some column is active in more than 3 rows of the measure.
    AnchorJs,
    /// Plain jumpstream; the fallback classification.
    JS,
    /// Two consecutive jump rows (presumably "jumptrill").
    JT,
    /// Wildcard for any jumpstream sub-type (see `Pattern::to_all`).
    All,
}
110impl JumpstreamPattern {
111    fn has_two_consecutive_jumps(measure: &ManiaMeasure) -> bool {
112        let mut last_was_jump = false;
113    
114        for note in measure.notes.iter() {
115            let is_jump = matches!(note.get_pattern(), Notes::Jump);
116    
117            if is_jump && last_was_jump {
118                return true;
119            }
120    
121            last_was_jump = is_jump;
122        }
123    
124        false
125    }
126
127    pub fn determine_js_type(measure: &mut ManiaMeasure) -> JumpstreamPattern {
128        // Compte les différents types de patterns
129        if JumpstreamPattern::has_two_consecutive_jumps(measure) {
130            return JumpstreamPattern::JT;
131        }
132        
133        let mut pattern_count: HashMap<Notes, usize> = HashMap::new();
134        for note in measure.notes.iter() {
135            *pattern_count.entry(note.get_pattern()).or_insert(0) += 1;
136        }
137        let single = *pattern_count.get(&Notes::Single).unwrap_or(&0);
138        let jump = *pattern_count.get(&Notes::Jump).unwrap_or(&0);
139    
140        // Crée un vecteur pour compter les notes actives par colonne
141        let mut vect_int = vec![0; measure.notes[0].notes.len()];
142    
143        // Compte combien de fois chaque colonne est utilisée
144        for note in measure.notes.iter() {
145            for (i, &is_active) in note.notes.iter().enumerate() {
146                if is_active {
147                    vect_int[i] += 1;
148                }
149            }
150        }
151    
152        // Détermine le type de jumpstream basé sur les statistiques
153        if let Some(&max_value) = vect_int.iter().max() {
154            if max_value > 3 {
155                JumpstreamPattern::AnchorJs
156            } else if jump < single {
157                JumpstreamPattern::LightJs
158            } else {
159                JumpstreamPattern::JS
160            }
161        } else {
162            JumpstreamPattern::JS
163        }
164    }
165
166    pub fn to_string(&self) -> String {
167        match self {
168            JumpstreamPattern::LightJs => "LightJs".to_string(),
169            JumpstreamPattern::AnchorJs => "AnchorJs".to_string(),
170            JumpstreamPattern::JS => "JS".to_string(),
171            JumpstreamPattern::JT => "JT".to_string(),
172            JumpstreamPattern::All => "All".to_string(),
173        }
174    }
175}
176
/// Sub-classification of a handstream measure, assigned by
/// [`HandstreamPattern::determine_hs_type`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HandstreamPattern {
    /// The measure contains no jump rows.
    LightHs,
    /// NOTE(review): not currently produced by `determine_hs_type`.
    AnchorHs,
    /// The measure contains at least one jump row.
    DenseHs,
    /// NOTE(review): not currently produced by `determine_hs_type`.
    HS,
    /// Wildcard for any handstream sub-type (see `Pattern::to_all`).
    All,
}
185impl HandstreamPattern {
186    pub fn determine_hs_type(measure: &mut ManiaMeasure) -> HandstreamPattern {
187        let mut pattern_count: HashMap<Notes, usize> = HashMap::new();
188    
189        for note in measure.notes.iter() {
190            *pattern_count.entry(note.get_pattern()).or_insert(0) += 1;
191        }
192        let jump = *pattern_count.get(&Notes::Jump).unwrap_or(&0);
193    
194        if jump == 0 {
195            HandstreamPattern::LightHs
196        } else if jump > 0 {
197            HandstreamPattern::DenseHs
198        } else {
199            HandstreamPattern::HS
200        }
201    }
202
203    pub fn to_string(&self) -> String {
204        match self {
205            HandstreamPattern::LightHs => "LightHs".to_string(),
206            HandstreamPattern::AnchorHs => "AnchorHs".to_string(),
207            HandstreamPattern::DenseHs => "DenseHs".to_string(),
208            HandstreamPattern::HS => "HS".to_string(),
209            HandstreamPattern::All => "All".to_string(),
210        }
211    }
212}
213
/// Sub-classification of a singlestream; there is one concrete kind plus the wildcard.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SinglestreamPattern {
    /// A plain singlestream.
    Singlestream,
    /// Wildcard for any singlestream (see `Pattern::to_all`).
    All,
}
219// Used to calculate the weight of a pattern
220// Problem is vibro or jumptrill have way HIGHER NPM as a base so map could be detected wrongfully
221pub fn get_pattern_weight(pattern: &Pattern) -> f64 {
222    match pattern {
223        // Jack patterns
224        Pattern::Jack(jack_type) => match jack_type {
225            JackPattern::DenseChordjack => 0.8,
226            JackPattern::Speedjack => 0.9,
227            JackPattern::Chordjack => 1.0,
228            JackPattern::ChordStream => 1.1,
229            JackPattern::All => 1.0,
230        },
231        
232        // Handstream patterns
233        Pattern::Handstream(hs_type) => match hs_type {
234            HandstreamPattern::DenseHs => 0.8,
235            HandstreamPattern::AnchorHs => 1.1,
236            HandstreamPattern::HS => 1.0,
237            HandstreamPattern::LightHs => 1.1,
238            HandstreamPattern::All => 1.0,
239        },
240        
241        // Jumpstream patterns
242        Pattern::Jumpstream(js_type) => match js_type {
243            JumpstreamPattern::JT => 0.7,
244            JumpstreamPattern::AnchorJs => 1.1,
245            JumpstreamPattern::JS => 1.0,
246            JumpstreamPattern::LightJs => 1.1,
247            JumpstreamPattern::All => 1.0,
248        },
249        
250        // Singlestream patterns
251        Pattern::Singlestream(ss_type) => match ss_type {
252            SinglestreamPattern::Singlestream => 1.1,
253            SinglestreamPattern::All => 1.0,
254        },
255        
256        // None
257        Pattern::None => 0.0,
258    }
259}