Skip to main content

rhythm_open_exchange/codec/formats/sm/
encoder.rs

1#![allow(
2    clippy::doc_markdown,
3    clippy::cast_precision_loss,
4    clippy::cast_possible_truncation,
5    clippy::cast_sign_loss,
6    clippy::cast_lossless,
7    clippy::needless_range_loop,
8    clippy::match_same_arms,
9    clippy::redundant_closure_for_method_calls,
10    clippy::collapsible_if
11)]
12//! Encoder for converting `RoxChart` to StepMania (`.sm`) format.
13
14use std::fmt::Write;
15
16use crate::codec::Encoder;
17use crate::error::RoxResult;
18use crate::model::{NoteType, RoxChart};
19
/// Encoder for StepMania (`.sm`) beatmaps.
///
/// A stateless unit type: all behaviour lives in its `Encoder` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct SmEncoder;
22
23impl Encoder for SmEncoder {
24    fn encode(chart: &RoxChart) -> RoxResult<Vec<u8>> {
25        let mut output = String::new();
26
27        // Metadata
28        let _ = writeln!(output, "#TITLE:{};", chart.metadata.title);
29        let _ = writeln!(output, "#SUBTITLE:;");
30        let _ = writeln!(output, "#ARTIST:{};", chart.metadata.artist);
31        let _ = writeln!(output, "#TITLETRANSLIT:;");
32        let _ = writeln!(output, "#ARTISTTRANSLIT:;");
33        let _ = writeln!(output, "#GENRE:;");
34        let _ = writeln!(output, "#CREDIT:{};", chart.metadata.creator);
35        let _ = writeln!(output, "#BANNER:;");
36        if let Some(bg) = &chart.metadata.background_file {
37            let _ = writeln!(output, "#BACKGROUND:{bg};");
38        } else {
39            let _ = writeln!(output, "#BACKGROUND:;");
40        }
41        let _ = writeln!(output, "#LYRICSPATH:;");
42        let _ = writeln!(output, "#CDTITLE:;");
43        let _ = writeln!(output, "#MUSIC:{};", chart.metadata.audio_file);
44
45        // Determine Sync Point (Beat 0 location)
46        // SM expects Offset to be the time of the first beat.
47        // We use the time of the first uninherited timing point.
48        let first_bpm_time = chart
49            .timing_points
50            .iter()
51            .find(|tp| !tp.is_inherited)
52            .map_or(0, |tp| tp.time_us);
53
54        // Offset (SM uses "Time where Beat 0 begins" in seconds)
55        // So if beat 0 is at -0.030s, Offset should be -0.030.
56        let offset_seconds = first_bpm_time as f64 / 1_000_000.0;
57        let _ = writeln!(output, "#OFFSET:{offset_seconds:.6};");
58
59        // Sample start/length
60        #[allow(clippy::cast_precision_loss)]
61        let sample_start = chart.metadata.preview_time_us as f64 / 1_000_000.0;
62        #[allow(clippy::cast_precision_loss)]
63        let sample_length = chart.metadata.preview_duration_us as f64 / 1_000_000.0;
64        let _ = writeln!(output, "#SAMPLESTART:{sample_start:.3};");
65        let _ = writeln!(output, "#SAMPLELENGTH:{sample_length:.3};");
66
67        let _ = writeln!(output, "#SELECTABLE:YES;");
68
69        // BPMs
70        output.push_str("#BPMS:");
71        let bpm_points: Vec<_> = chart
72            .timing_points
73            .iter()
74            .filter(|tp| !tp.is_inherited)
75            .collect();
76
77        for (i, tp) in bpm_points.iter().enumerate() {
78            // Calculate beat relative to the sync point (first_bpm_time)
79            // Note: Since we set offset based on first_bpm_time, beat 0 matches that time.
80            let beat = us_to_beat(tp.time_us, &bpm_points, first_bpm_time);
81            if i > 0 {
82                output.push(',');
83            }
84            // Format beat: if integer, use integer format, else float
85            if (beat - beat.round()).abs() < 0.001 {
86                let _ = write!(output, "{:.0}={:.3}", beat, tp.bpm);
87            } else {
88                let _ = write!(output, "{:.3}={:.3}", beat, tp.bpm);
89            }
90        }
91        let _ = writeln!(output, ";");
92
93        // Stops (empty for now)
94        let _ = writeln!(output, "#STOPS:;");
95        let _ = writeln!(output);
96
97        // Notes section
98        let stepstype = match chart.key_count() {
99            4 => "dance-single",
100            6 => "dance-solo",
101            8 => "dance-double",
102            _ => "dance-single",
103        };
104
105        let _ = writeln!(output, "#NOTES:");
106        let _ = writeln!(output, "     {stepstype}:");
107        let _ = writeln!(output, "     :");
108        // Force Difficulty to "Hard" or "Challenge" to ensure Etterna/SM sees it validly.
109        // "1.0x" is not a standard difficulty name.
110        let difficulty_name = match chart.metadata.difficulty_name.as_str() {
111            "Beginner" | "Easy" | "Medium" | "Hard" | "Challenge" | "Edit" => {
112                &chart.metadata.difficulty_name
113            }
114            _ => "Hard", // Fallback for numeric versions like "1.0x"
115        };
116        let _ = writeln!(output, "     {difficulty_name}:");
117        let _ = writeln!(
118            output,
119            "     {}:",
120            chart.metadata.difficulty_value.unwrap_or(1.0) as u32
121        );
122        // Correct format for radar values
123        // Revert to simple integer format as per working 4k.sm example
124        let _ = writeln!(output, "     0,0,0,0,0:");
125
126        // Generate measures
127        let bpms_tuple: Vec<_> = chart
128            .timing_points
129            .iter()
130            .filter(|tp| !tp.is_inherited)
131            .map(|tp| (tp.time_us, tp.bpm))
132            .collect();
133
134        encode_measures(&mut output, chart, &bpms_tuple, first_bpm_time);
135
136        let _ = writeln!(output, ";");
137
138        Ok(output.into_bytes())
139    }
140}
141
142/// Convert microseconds to beat position.
143/// `start_time_us` is the time where beat count starts (beat 0).
144fn us_to_beat(time_us: i64, bpm_points: &[&crate::model::TimingPoint], start_time_us: i64) -> f64 {
145    if bpm_points.is_empty() {
146        return 0.0;
147    }
148
149    let mut current_time_us = start_time_us;
150    let mut current_beat: f64 = 0.0;
151    let mut current_bpm = bpm_points[0].bpm; // Default to first BPM
152
153    for i in 1..bpm_points.len() {
154        let tp = bpm_points[i];
155        if tp.time_us > time_us {
156            break;
157        }
158
159        let elapsed_us = tp.time_us - current_time_us;
160        let elapsed_beats = us_to_beats_at_bpm(elapsed_us, current_bpm);
161        current_beat += elapsed_beats;
162        current_time_us = tp.time_us;
163        current_bpm = tp.bpm;
164    }
165
166    let remaining_us = time_us - current_time_us;
167    current_beat + us_to_beats_at_bpm(remaining_us, current_bpm)
168}
169
/// Number of beats that elapse during `us` microseconds at a constant tempo of `bpm`.
///
/// Negative durations yield negative beat counts.
fn us_to_beats_at_bpm(us: i64, bpm: f32) -> f64 {
    let bpm = f64::from(bpm);
    (us as f64 / 1_000_000.0) * bpm / 60.0
}
174
175/// Encode all notes into SM measure format.
176#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
177fn encode_measures(output: &mut String, chart: &RoxChart, bpms: &[(i64, f32)], start_time_us: i64) {
178    if chart.notes.is_empty() {
179        // Empty chart - just one empty measure
180        for _ in 0..4 {
181            let _ = writeln!(output, "{}", "0".repeat(chart.key_count() as usize));
182        }
183        return;
184    }
185
186    // Find the total duration
187    let max_time = chart
188        .notes
189        .iter()
190        .map(|n| n.end_time_us())
191        .max()
192        .unwrap_or(0);
193
194    // Calculate number of measures needed
195    let total_beats = us_to_beat_simple(max_time, bpms, start_time_us);
196
197    let total_measures = if total_beats > 0.0 {
198        (total_beats / 4.0).ceil() as usize + 1
199    } else {
200        1
201    };
202
203    // Create note events: (time_us, column, char)
204    let mut events: Vec<(i64, u8, char)> = Vec::new();
205
206    for note in &chart.notes {
207        match &note.note_type {
208            NoteType::Tap => {
209                events.push((note.time_us, note.column, '1'));
210            }
211            NoteType::Hold { duration_us } => {
212                events.push((note.time_us, note.column, '2'));
213                events.push((note.time_us + duration_us, note.column, '3'));
214            }
215            NoteType::Burst { duration_us } => {
216                events.push((note.time_us, note.column, '4'));
217                events.push((note.time_us + duration_us, note.column, '3'));
218            }
219            NoteType::Mine => {
220                events.push((note.time_us, note.column, 'M'));
221            }
222        }
223    }
224
225    // Sort events by time
226    events.sort_by_key(|(t, _, _)| *t);
227
228    // Resolve collisions: Abutting notes (Tail overwrites Next Note or vice versa)
229    // If a Tail ('3') is at the same time/col as a Start ('1', '2', '4', 'M'),
230    // StepMania cannot handle it (Tail requires release, Start requires press).
231    // We convert the Hold into a Tap to prevent "hanging head" crashes.
232    let len = events.len();
233
234    // We iterate and mark modifications.
235    // Note: events is sorted by time.
236    // If times are equal, stable sort keeps relative order?
237    // We generated Heads then Tails.
238    // If Note A (Hold) ends at T, and Note B (Tap) starts at T.
239    // events: [..., (T, A, 3), (T, B, 1), ...] (assuming A processed before B in chart? No)
240    // chart.notes is sorted by time. A starts before B.
241    // So (A_start) < (B_start).
242    // So A processed first.
243    // So (T, A, 3) pushed, then (T, B, 1) pushed.
244    // Sort stable maintains order.
245    // So events[i] = 3, events[i+1] = 1.
246
247    for i in 0..len.saturating_sub(1) {
248        let (t1, c1, ch1) = events[i];
249        let (t2, c2, ch2) = events[i + 1];
250
251        // Check for collision
252        if t1 == t2 && c1 == c2 {
253            // Case: Tail ('3') followed by Start ('1', '2', '4', 'M')
254            if ch1 == '3' && (ch2 == '1' || ch2 == '2' || ch2 == '4' || ch2 == 'M') {
255                // Collision!
256                // Convert the Hold (Head at '2'/'4') to Tap ('1').
257                // Mark Tail ('3') for removal.
258
259                // Find the corresponding Head
260                let mut head_found = false;
261                for j in (0..i).rev() {
262                    if events[j].1 == c1 {
263                        if events[j].2 == '2' || events[j].2 == '4' {
264                            // Found head. Convert to Tap.
265                            events[j].2 = '1';
266                            head_found = true;
267                            break;
268                        } else if events[j].2 == '3' {
269                            // Another tail? Nested holds? Shouldn't happen in flat list unless logic wrong.
270                            // Stop if we hit another tail, it means we missed the head or interleaved.
271                            break;
272                        }
273                    }
274                }
275
276                if head_found {
277                    events[i].2 = '0'; // Mark tail for removal (we'll filter '0' out output logic or now)
278                }
279            }
280        }
281    }
282
283    // Group events by measure
284    // Map: measure_index -> Vec<(beat_in_measure, col, char)>
285    let mut measure_events: Vec<Vec<(f64, u8, char)>> = vec![Vec::new(); total_measures];
286
287    for (time_us, col, ch) in events {
288        if ch == '0' {
289            continue;
290        } // Skip removed tails
291
292        let raw_beat = us_to_beat_simple(time_us, bpms, start_time_us);
293
294        // If beat is negative, it's before the start. Skip or warn?
295        if raw_beat < 0.0 {
296            continue; // Cannot represent in SM M0
297        }
298
299        // Snap to grid (48th notes / 192 per measure) to handle floating point jitter
300        #[allow(clippy::items_after_statements)]
301        const GRID_RESOLUTION: f64 = 48.0;
302        let mut beat = (raw_beat * GRID_RESOLUTION).round() / GRID_RESOLUTION;
303
304        // If the snapped beat is effectively an integer + epsilon, make sure it behaves
305        if (beat - beat.round()).abs() < 1e-6 {
306            beat = beat.round();
307        }
308
309        let measure_idx = (beat / 4.0).floor() as usize;
310        let beat_in_measure = beat % 4.0;
311
312        if measure_idx < measure_events.len() {
313            measure_events[measure_idx].push((beat_in_measure, col, ch));
314        } else {
315            // Extend if needed
316            if measure_idx >= measure_events.len() {
317                measure_events.resize(measure_idx + 1, Vec::new());
318            }
319            measure_events[measure_idx].push((beat_in_measure, col, ch));
320        }
321    }
322
323    // Generate each measure
324    for (measure_num, events) in measure_events.iter().enumerate() {
325        if measure_num > 0 {
326            let _ = writeln!(output, ",");
327        }
328
329        // Try standard SM divisors
330        let divisors = [4, 8, 12, 16, 24, 32, 48, 64, 96, 192];
331        let mut best_divisor = 192;
332
333        'divisor_loop: for &div in &divisors {
334            // Check if all events align with this divisor
335            for (beat_in_measure, _, _) in events {
336                // Ideal position in lines for this divisor
337                let ideal_line = beat_in_measure * (div as f64) / 4.0;
338                let snapped_line = ideal_line.round();
339
340                // If deviation is too high, this divisor is invalid
341                if (ideal_line - snapped_line).abs() > 0.001 {
342                    continue 'divisor_loop;
343                }
344            }
345
346            // If we get here, all events aligned
347            best_divisor = div;
348            break;
349        }
350
351        // If measure is empty, force 4 lines to save space
352        if events.is_empty() {
353            best_divisor = 4;
354        }
355
356        let lines_per_measure = best_divisor;
357        for i in 0..lines_per_measure {
358            // Collect events on this line
359            let mut line_chars: Vec<char> = vec!['0'; chart.key_count() as usize];
360
361            for (beat_in_measure, col, ch) in events {
362                // Check if this event belongs to this line
363                // We use the same epsilon logic as above to match
364                let event_line_pos = beat_in_measure * (lines_per_measure as f64) / 4.0;
365                if (event_line_pos - i as f64).abs() < 0.001 {
366                    if (*col as usize) < line_chars.len() {
367                        line_chars[*col as usize] = *ch;
368                    }
369                }
370            }
371
372            let line_str: String = line_chars.into_iter().collect();
373            let _ = writeln!(output, "{line_str}");
374        }
375    }
376}
377
378fn us_to_beat_simple(time_us: i64, bpms: &[(i64, f32)], start_time_us: i64) -> f64 {
379    if bpms.is_empty() {
380        return (time_us - start_time_us) as f64 / 1_000_000.0 * 120.0 / 60.0;
381    }
382
383    let mut current_time_us = start_time_us;
384    let mut current_beat: f64 = 0.0;
385    let mut current_bpm = bpms[0].1;
386
387    for i in 1..bpms.len() {
388        let (bpm_time, new_bpm) = bpms[i];
389        if bpm_time > time_us {
390            break;
391        }
392
393        let elapsed = bpm_time - current_time_us;
394        current_beat += us_to_beats_at_bpm(elapsed, current_bpm);
395        current_time_us = bpm_time;
396        current_bpm = new_bpm;
397    }
398
399    current_beat + us_to_beats_at_bpm(time_us - current_time_us, current_bpm)
400}
401
#[cfg(test)]
mod tests {
    /// Decode the bundled 4-key `.sm` asset, re-encode it and decode the result again, then
    /// check that the key count, the notes and the timing survive unchanged.
    #[test]
    #[cfg(feature = "analysis")]
    fn test_roundtrip() {
        use crate::analysis::RoxAnalysis;
        use crate::codec::formats::sm::{SmDecoder, SmEncoder};
        use crate::codec::{Decoder, Encoder};

        let source = crate::test_utils::get_test_asset("stepmania/4k.sm");
        let original = <SmDecoder as Decoder>::decode(&source).unwrap();
        let bytes = SmEncoder::encode(&original).unwrap();
        let reparsed = <SmDecoder as Decoder>::decode(&bytes).unwrap();

        assert_eq!(original.key_count(), reparsed.key_count());

        // Floating point and grid snapping make a structural comparison fragile, so the
        // content hashes are compared instead.
        assert_eq!(
            original.notes_hash(),
            reparsed.notes_hash(),
            "Notes hash mismatch"
        );
        assert_eq!(
            original.timings_hash(),
            reparsed.timings_hash(),
            "Timings hash mismatch"
        );
    }
}