Skip to main content

rhythm_open_exchange/codec/formats/osu/
encoder.rs

1//! Encoder for converting `RoxChart` to .osu format.
2
3use std::fmt::Write;
4
5use crate::codec::Encoder;
6use crate::error::RoxResult;
7use crate::model::RoxChart;
8
9/// Encoder for osu!mania beatmaps.
10pub struct OsuEncoder;
11
12impl Encoder for OsuEncoder {
13    fn encode(chart: &RoxChart) -> RoxResult<Vec<u8>> {
14        let mut output = String::new();
15
16        // Format version
17        output.push_str("osu file format v14\n\n");
18
19        write_general_section(&mut output, chart);
20        write_editor_section(&mut output);
21        write_metadata_section(&mut output, chart);
22        write_difficulty_section(&mut output, chart);
23        write_events_section(&mut output, chart);
24        write_timing_points_section(&mut output, chart);
25        write_hit_objects_section(&mut output, chart);
26
27        Ok(output.into_bytes())
28    }
29}
30
31/// Write the [General] section.
32fn write_general_section(output: &mut String, chart: &RoxChart) {
33    output.push_str("[General]\n");
34    let _ = writeln!(output, "AudioFilename: {}", chart.metadata.audio_file);
35    let _ = writeln!(
36        output,
37        "AudioLeadIn: {}",
38        chart.metadata.audio_offset_us / 1000
39    );
40    let _ = writeln!(
41        output,
42        "PreviewTime: {}",
43        chart.metadata.preview_time_us / 1000
44    );
45    output.push_str("Countdown: 0\n");
46    output.push_str("SampleSet: Normal\n");
47    output.push_str("StackLeniency: 0.7\n");
48    output.push_str("Mode: 3\n");
49    output.push_str("LetterboxInBreaks: 0\n");
50    output.push_str("SpecialStyle: 0\n");
51    output.push_str("WidescreenStoryboard: 0\n\n");
52}
53
/// Write the [Editor] section.
///
/// Every value is a fixed default; nothing here depends on the chart.
/// The section ends with a blank separator line.
fn write_editor_section(output: &mut String) {
    const LINES: [&str; 5] = [
        "[Editor]",
        "DistanceSpacing: 1",
        "BeatDivisor: 4",
        "GridSize: 4",
        "TimelineZoom: 1",
    ];
    for line in LINES {
        output.push_str(line);
        output.push('\n');
    }
    output.push('\n');
}
62
63/// Write the [Metadata] section.
64fn write_metadata_section(output: &mut String, chart: &RoxChart) {
65    output.push_str("[Metadata]\n");
66    let _ = writeln!(output, "Title:{}", chart.metadata.title);
67    let _ = writeln!(output, "TitleUnicode:{}", chart.metadata.title);
68    let _ = writeln!(output, "Artist:{}", chart.metadata.artist);
69    let _ = writeln!(output, "ArtistUnicode:{}", chart.metadata.artist);
70    let _ = writeln!(output, "Creator:{}", chart.metadata.creator);
71    let _ = writeln!(output, "Version:{}", chart.metadata.difficulty_name);
72    if let Some(source) = &chart.metadata.source {
73        let _ = writeln!(output, "Source:{source}");
74    }
75    if !chart.metadata.tags.is_empty() {
76        let _ = writeln!(output, "Tags:{}", chart.metadata.tags.join(" "));
77    }
78    // Export chart IDs (default to 0/-1 if not set)
79    let _ = writeln!(output, "BeatmapID:{}", chart.metadata.chart_id.unwrap_or(0));
80    // Safe: osu format uses -1 for missing set ID
81    #[allow(clippy::cast_possible_wrap)]
82    let _ = writeln!(
83        output,
84        "BeatmapSetID:{}",
85        chart.metadata.chartset_id.map_or(-1, |id| id as i64)
86    );
87    output.push('\n');
88}
89
90/// Write the [Difficulty] section.
91fn write_difficulty_section(output: &mut String, chart: &RoxChart) {
92    output.push_str("[Difficulty]\n");
93    output.push_str("HPDrainRate:8\n");
94    let _ = writeln!(output, "CircleSize:{}", chart.key_count());
95    let _ = writeln!(
96        output,
97        "OverallDifficulty:{}",
98        chart.metadata.difficulty_value.unwrap_or(8.0)
99    );
100    output.push_str("ApproachRate:5\n");
101    output.push_str("SliderMultiplier:1.4\n");
102    output.push_str("SliderTickRate:1\n\n");
103}
104
105/// Write the [Events] section.
106fn write_events_section(output: &mut String, chart: &RoxChart) {
107    output.push_str("[Events]\n");
108    output.push_str("//Background and Video events\n");
109    if let Some(bg) = &chart.metadata.background_file {
110        let _ = writeln!(output, "0,0,\"{bg}\",0,0");
111    }
112    output.push_str("//Break Periods\n");
113    output.push_str("//Storyboard Layer 0 (Background)\n");
114    output.push_str("//Storyboard Layer 1 (Fail)\n");
115    output.push_str("//Storyboard Layer 2 (Pass)\n");
116    output.push_str("//Storyboard Layer 3 (Foreground)\n");
117    output.push_str("//Storyboard Sound Samples\n\n");
118}
119
120/// Write the [`TimingPoints`] section.
121fn write_timing_points_section(output: &mut String, chart: &RoxChart) {
122    output.push_str("[TimingPoints]\n");
123    for tp in &chart.timing_points {
124        #[allow(clippy::cast_precision_loss)]
125        let time_ms = tp.time_us as f64 / 1000.0;
126
127        if tp.is_inherited {
128            // SV point: beatLength = -100 / sv
129            let beat_length = -100.0 / f64::from(tp.scroll_speed);
130            let _ = writeln!(output, "{time_ms},{beat_length},4,1,0,100,0,0");
131        } else {
132            // BPM point: beatLength = 60000 / bpm
133            let beat_length = 60000.0 / f64::from(tp.bpm);
134            let _ = writeln!(
135                output,
136                "{},{},{},1,0,100,1,0",
137                time_ms, beat_length, tp.signature
138            );
139        }
140    }
141    output.push_str("\n\n");
142}
143
144/// Write the [`HitObjects`] section.
145fn write_hit_objects_section(output: &mut String, chart: &RoxChart) {
146    output.push_str("[HitObjects]\n");
147    for note in &chart.notes {
148        // Safe: time_us / 1000 fits in i32 for typical beatmaps
149        #[allow(clippy::cast_possible_truncation)]
150        let time_ms = (note.time_us / 1000) as i32;
151        let x = column_to_x(note.column, chart.key_count());
152
153        match &note.note_type {
154            crate::model::NoteType::Tap => {
155                // x,y,time,type,hitSound,extras
156                let _ = writeln!(output, "{x},192,{time_ms},1,0,0:0:0:0:");
157            }
158            crate::model::NoteType::Hold { duration_us } => {
159                #[allow(clippy::cast_possible_truncation)]
160                let end_time = time_ms + (*duration_us / 1000) as i32;
161                // x,y,time,type,hitSound,endTime:extras
162                let _ = writeln!(output, "{x},192,{time_ms},128,0,{end_time}:0:0:0:0:");
163            }
164            crate::model::NoteType::Burst { .. } | crate::model::NoteType::Mine => {
165                // Burst and Mine - convert to tap for osu
166                let _ = writeln!(output, "{x},192,{time_ms},1,0,0:0:0:0:");
167            }
168        }
169    }
170}
171
/// Convert a zero-based column index to its osu! `x` coordinate.
///
/// The 512-unit-wide playfield is split into `key_count` equal lanes. The
/// result is the floored centre of lane `column`, i.e.
/// `floor((2 * column + 1) * 256 / key_count)`. Integer arithmetic keeps the
/// value exact and reproducible (no floating-point rounding).
///
/// For 7K: 36, 109, 182, 256, 329, 402, 475
///
/// A `key_count` of 0 is treated as a single lane (so column 0 maps to
/// x = 256) instead of panicking on a division by zero.
#[must_use]
pub fn column_to_x(column: u8, key_count: u8) -> i32 {
    // Guard against division by zero for degenerate (0-key) input.
    let lanes = i32::from(key_count.max(1));
    (2 * i32::from(column) + 1) * 256 / lanes
}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Note, TimingPoint};

    /// Assert that `column_to_x` yields exactly `expected[col]` for every
    /// column of a `key_count`-key layout.
    fn verify_columns(key_count: u8, expected: &[i32]) {
        assert_eq!(
            expected.len(),
            key_count as usize,
            "Wrong number of expected values for {}K",
            key_count
        );
        for (col, &expected_x) in expected.iter().enumerate() {
            let actual = column_to_x(col as u8, key_count);
            assert_eq!(
                actual, expected_x,
                "{}K column {} failed: expected {}, got {}",
                key_count, col, expected_x, actual
            );
        }
    }

    #[test]
    fn test_column_to_x_4k() {
        verify_columns(4, &[64, 192, 320, 448]);
    }

    #[test]
    fn test_column_to_x_5k() {
        verify_columns(5, &[51, 153, 256, 358, 460]);
    }

    #[test]
    fn test_column_to_x_6k() {
        verify_columns(6, &[42, 128, 213, 298, 384, 469]);
    }

    #[test]
    fn test_column_to_x_7k() {
        verify_columns(7, &[36, 109, 182, 256, 329, 402, 475]);
    }

    #[test]
    fn test_column_to_x_8k() {
        verify_columns(8, &[32, 96, 160, 224, 288, 352, 416, 480]);
    }

    #[test]
    fn test_column_to_x_9k() {
        verify_columns(9, &[28, 85, 142, 199, 256, 312, 369, 426, 483]);
    }

    #[test]
    fn test_column_to_x_10k() {
        verify_columns(10, &[25, 76, 128, 179, 230, 281, 332, 384, 435, 486]);
    }

    #[test]
    fn test_column_to_x_12k() {
        verify_columns(
            12,
            &[21, 64, 106, 149, 192, 234, 277, 320, 362, 405, 448, 490],
        );
    }

    #[test]
    fn test_column_to_x_14k() {
        verify_columns(
            14,
            &[
                18, 54, 91, 128, 164, 201, 237, 274, 310, 347, 384, 420, 457, 493,
            ],
        );
    }

    #[test]
    fn test_column_to_x_16k() {
        verify_columns(
            16,
            &[
                16, 48, 80, 112, 144, 176, 208, 240, 272, 304, 336, 368, 400, 432, 464, 496,
            ],
        );
    }

    #[test]
    fn test_column_to_x_18k() {
        verify_columns(
            18,
            &[
                14, 42, 71, 99, 128, 156, 184, 213, 241, 270, 298, 327, 355, 384, 412, 440, 469,
                497,
            ],
        );
    }

    /// Every column must map back to itself under `column = x * key_count / 512`.
    /// Covers every key count from 1K through 18K, not just 4K–10K, so the
    /// larger layouts tested above are also checked for invertibility.
    #[test]
    fn test_column_roundtrip() {
        for key_count in 1..=18u8 {
            for col in 0..key_count {
                let x = column_to_x(col, key_count);
                #[allow(clippy::cast_possible_truncation)]
                let decoded_col = ((x * i32::from(key_count)) / 512) as u8;
                assert_eq!(
                    decoded_col, col,
                    "Roundtrip failed for {}K column {}",
                    key_count, col
                );
            }
        }
    }

    #[test]
    fn test_encode_basic() {
        let mut chart = RoxChart::new(7);
        chart.metadata.title = "Test".into();
        chart.metadata.artist = "Artist".into();
        chart.metadata.creator = "Mapper".into();
        chart.metadata.difficulty_name = "Hard".into();
        chart.metadata.audio_file = "audio.mp3".into();
        chart.timing_points.push(TimingPoint::bpm(0, 180.0));
        chart.notes.push(Note::tap(1_000_000, 0));
        chart.notes.push(Note::tap(1_500_000, 3));
        chart.notes.push(Note::hold(2_000_000, 500_000, 6));

        let encoded = OsuEncoder::encode(&chart).unwrap();
        let output = String::from_utf8_lossy(&encoded);

        assert!(output.contains("osu file format v14"));
        assert!(output.contains("Mode: 3"));
        assert!(output.contains("CircleSize:7"));
        assert!(output.contains("AudioFilename: audio.mp3\n"));
        assert!(output.contains("Title:Test\n"));
        assert!(output.contains("Version:Hard\n"));

        // 7K lanes: column 0 -> x=36, column 3 -> x=256, column 6 -> x=475.
        assert!(output.contains("[HitObjects]\n36,192,1000,1,0,0:0:0:0:\n"));
        assert!(output.contains("256,192,1500,1,0,0:0:0:0:\n"));
        assert!(output.contains("475,192,2000,128,0,2500:0:0:0:0:\n"));
    }

    #[test]
    #[cfg(feature = "analysis")]
    fn test_roundtrip() {
        use crate::analysis::RoxAnalysis;
        use crate::codec::formats::osu::OsuDecoder;
        use crate::codec::Decoder;
        let data = crate::test_utils::get_test_asset("osu/mania_7k.osu");
        let chart1 = <OsuDecoder as Decoder>::decode(&data).unwrap();
        let encoded = OsuEncoder::encode(&chart1).unwrap();
        let chart2 = <OsuDecoder as Decoder>::decode(&encoded).unwrap();

        assert_eq!(chart1.key_count(), chart2.key_count());

        // Compare using hashes
        assert_eq!(
            chart1.notes_hash(),
            chart2.notes_hash(),
            "Notes hash mismatch"
        );
        assert_eq!(
            chart1.timings_hash(),
            chart2.timings_hash(),
            "Timings hash mismatch"
        );
    }
}