rosu_pattern_detector/mania/models/
pattern.rs1use crate::mania::models::base::{Notes, ManiaMeasure};
2use std::collections::HashMap;
3
4#[derive(Debug, PartialEq, Clone,Hash,Eq)]
5pub enum Pattern
6{
7 Jack(JackPattern),
8 Handstream(HandstreamPattern),
9 Jumpstream(JumpstreamPattern),
10 Singlestream(SinglestreamPattern),
11 None,
12}
13
14impl Pattern {
15 pub fn to_all(&self) -> Pattern {
16 match self {
17 Pattern::Jack(_) => Pattern::Jack(JackPattern::All),
18 Pattern::Handstream(_) => Pattern::Handstream(HandstreamPattern::All),
19 Pattern::Jumpstream(_) => Pattern::Jumpstream(JumpstreamPattern::All),
20 Pattern::Singlestream(_) => Pattern::Singlestream(SinglestreamPattern::All),
21 Pattern::None => Pattern::None,
22 }
23 }
24}
25
26impl std::fmt::Display for Pattern {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 let pattern = match self {
29 Pattern::Jack(pattern) => pattern.to_string(),
30 Pattern::Handstream(pattern) => pattern.to_string(),
31 Pattern::Jumpstream(pattern) => pattern.to_string(),
32 Pattern::Singlestream(_) => "SingleStream".to_string(),
33 Pattern::None => return write!(f, "None"),
34 };
35 write!(f, "{}", pattern)
36 }
37}
38
39#[derive(Debug, PartialEq, Clone, Hash, Eq)]
40pub enum JackPattern {
41 Chordjack,
42 DenseChordjack,
43 ChordStream,
44 Speedjack,
45 All,
46}
47impl JackPattern {
48 pub fn determine_jack_type(measure: &mut ManiaMeasure) -> JackPattern {
49 let mut pattern_count: HashMap<Notes, usize> = HashMap::new();
50
51 for note in measure.notes.iter() {
52 *pattern_count.entry(note.get_pattern()).or_insert(0) += 1;
53 }
54
55 let single = *pattern_count.get(&Notes::Single).unwrap_or(&0);
56 let jump = *pattern_count.get(&Notes::Jump).unwrap_or(&0);
57 let hand = *pattern_count.get(&Notes::Hand).unwrap_or(&0);
58 let quad = *pattern_count.get(&Notes::Quad).unwrap_or(&0);
59
60 if hand > jump + single {
61 JackPattern::DenseChordjack
62 }
63 else if quad > 0 && jump + hand > single
64 {
65 JackPattern::Chordjack
66 }
67 else
68 {
69 JackPattern::determine_jackspeed_or_chordstream(measure)
70 }
71 }
72
73 fn determine_jackspeed_or_chordstream(measure: &mut ManiaMeasure) -> JackPattern {
74 let mut jack_count = 0;
75 for (i, note) in measure.notes.iter().enumerate() {
76 if i > 0 {
77 let prev = &measure.notes[i - 1];
78 if note.notes.iter().zip(prev.notes.iter()).any(|(n, p)| *n && *p) {
79 jack_count += 1;
80 }
81 }
82 }
83 if jack_count <= 1 && measure.notes.len() > 6 {
84 JackPattern::ChordStream
85 } else {
86 JackPattern::Speedjack
87 }
88 }
89
90 pub fn to_string(&self) -> String {
91 match self {
92 JackPattern::Chordjack => "Chordjack".to_string(),
93 JackPattern::DenseChordjack => "DenseChordjack".to_string(),
94 JackPattern::ChordStream => "ChordStream".to_string(),
95 JackPattern::Speedjack => "Speedjack".to_string(),
96 JackPattern::All => "All".to_string(),
97 }
98 }
99}
100
101
102#[derive(Debug, PartialEq, Clone, Hash, Eq)]
103pub enum JumpstreamPattern {
104 LightJs,
105 AnchorJs,
106 JS,
107 JT,
108 All,
109}
110impl JumpstreamPattern {
111 fn has_two_consecutive_jumps(measure: &ManiaMeasure) -> bool {
112 let mut last_was_jump = false;
113
114 for note in measure.notes.iter() {
115 let is_jump = matches!(note.get_pattern(), Notes::Jump);
116
117 if is_jump && last_was_jump {
118 return true;
119 }
120
121 last_was_jump = is_jump;
122 }
123
124 false
125 }
126
127 pub fn determine_js_type(measure: &mut ManiaMeasure) -> JumpstreamPattern {
128 if JumpstreamPattern::has_two_consecutive_jumps(measure) {
130 return JumpstreamPattern::JT;
131 }
132
133 let mut pattern_count: HashMap<Notes, usize> = HashMap::new();
134 for note in measure.notes.iter() {
135 *pattern_count.entry(note.get_pattern()).or_insert(0) += 1;
136 }
137 let single = *pattern_count.get(&Notes::Single).unwrap_or(&0);
138 let jump = *pattern_count.get(&Notes::Jump).unwrap_or(&0);
139
140 let mut vect_int = vec![0; measure.notes[0].notes.len()];
142
143 for note in measure.notes.iter() {
145 for (i, &is_active) in note.notes.iter().enumerate() {
146 if is_active {
147 vect_int[i] += 1;
148 }
149 }
150 }
151
152 if let Some(&max_value) = vect_int.iter().max() {
154 if max_value > 3 {
155 JumpstreamPattern::AnchorJs
156 } else if jump < single {
157 JumpstreamPattern::LightJs
158 } else {
159 JumpstreamPattern::JS
160 }
161 } else {
162 JumpstreamPattern::JS
163 }
164 }
165
166 pub fn to_string(&self) -> String {
167 match self {
168 JumpstreamPattern::LightJs => "LightJs".to_string(),
169 JumpstreamPattern::AnchorJs => "AnchorJs".to_string(),
170 JumpstreamPattern::JS => "JS".to_string(),
171 JumpstreamPattern::JT => "JT".to_string(),
172 JumpstreamPattern::All => "All".to_string(),
173 }
174 }
175}
176
177#[derive(Debug, PartialEq, Clone, Hash, Eq)]
178pub enum HandstreamPattern {
179 LightHs,
180 AnchorHs,
181 DenseHs,
182 HS,
183 All,
184}
185impl HandstreamPattern {
186 pub fn determine_hs_type(measure: &mut ManiaMeasure) -> HandstreamPattern {
187 let mut pattern_count: HashMap<Notes, usize> = HashMap::new();
188
189 for note in measure.notes.iter() {
190 *pattern_count.entry(note.get_pattern()).or_insert(0) += 1;
191 }
192 let jump = *pattern_count.get(&Notes::Jump).unwrap_or(&0);
193
194 if jump == 0 {
195 HandstreamPattern::LightHs
196 } else if jump > 0 {
197 HandstreamPattern::DenseHs
198 } else {
199 HandstreamPattern::HS
200 }
201 }
202
203 pub fn to_string(&self) -> String {
204 match self {
205 HandstreamPattern::LightHs => "LightHs".to_string(),
206 HandstreamPattern::AnchorHs => "AnchorHs".to_string(),
207 HandstreamPattern::DenseHs => "DenseHs".to_string(),
208 HandstreamPattern::HS => "HS".to_string(),
209 HandstreamPattern::All => "All".to_string(),
210 }
211 }
212}
213
/// Sub-classification of a singlestream-family pattern.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SinglestreamPattern {
    /// The only concrete singlestream sub-type.
    Singlestream,
    /// Wildcard sub-type produced by `Pattern::to_all`.
    All,
}
219pub fn get_pattern_weight(pattern: &Pattern) -> f64 {
222 match pattern {
223 Pattern::Jack(jack_type) => match jack_type {
225 JackPattern::DenseChordjack => 0.8,
226 JackPattern::Speedjack => 0.9,
227 JackPattern::Chordjack => 1.0,
228 JackPattern::ChordStream => 1.1,
229 JackPattern::All => 1.0,
230 },
231
232 Pattern::Handstream(hs_type) => match hs_type {
234 HandstreamPattern::DenseHs => 0.8,
235 HandstreamPattern::AnchorHs => 1.1,
236 HandstreamPattern::HS => 1.0,
237 HandstreamPattern::LightHs => 1.1,
238 HandstreamPattern::All => 1.0,
239 },
240
241 Pattern::Jumpstream(js_type) => match js_type {
243 JumpstreamPattern::JT => 0.7,
244 JumpstreamPattern::AnchorJs => 1.1,
245 JumpstreamPattern::JS => 1.0,
246 JumpstreamPattern::LightJs => 1.1,
247 JumpstreamPattern::All => 1.0,
248 },
249
250 Pattern::Singlestream(ss_type) => match ss_type {
252 SinglestreamPattern::Singlestream => 1.1,
253 SinglestreamPattern::All => 1.0,
254 },
255
256 Pattern::None => 0.0,
258 }
259}