// midi_toolkit/sequence/event/merge_events.rs

1use crate::gen_iter::GenIter;
2
3use crate::{
4    events::MIDIDelta,
5    num::MIDINum,
6    sequence::{grouped_multithreaded_merge, MergableStreams},
7    unwrap, yield_error,
8};
9
10enum Sequence<D: MIDINum, E: MIDIDelta<D>, Err, I: Iterator<Item = Result<E, Err>> + Sized> {
11    HasNext {
12        iter: I,
13        next: E,
14        _phantom: std::marker::PhantomData<D>,
15    },
16    Empty,
17}
18
19impl<D: MIDINum, E: MIDIDelta<D>, Err, I: Iterator<Item = Result<E, Err>> + Sized>
20    Sequence<D, E, Err, I>
21{
22    fn new(iter: I) -> Result<Self, Err> {
23        let mut iter = iter;
24        match iter.next() {
25            Some(Ok(next)) => Ok(Sequence::HasNext {
26                iter,
27                next,
28                _phantom: std::marker::PhantomData,
29            }),
30            Some(Err(e)) => Err(e),
31            None => Ok(Sequence::Empty),
32        }
33    }
34
35    fn next(&mut self) -> Result<Option<E>, Err> {
36        match self {
37            Sequence::HasNext { iter, next, .. } => match iter.next() {
38                Some(Ok(mut iter_next)) => {
39                    let new_time = next.delta().saturating_add(iter_next.delta());
40                    iter_next.set_delta(new_time);
41                    let old_next = std::mem::replace(next, iter_next);
42                    Ok(Some(old_next))
43                }
44                Some(Err(e)) => Err(e),
45                None => {
46                    let old = std::mem::replace(self, Sequence::Empty);
47
48                    let old_next = match old {
49                        Sequence::HasNext { next, .. } => next,
50                        _ => unreachable!(),
51                    };
52
53                    Ok(Some(old_next))
54                }
55            },
56            Sequence::Empty => Ok(None),
57        }
58    }
59
60    fn time(&self) -> Option<D> {
61        match self {
62            Sequence::HasNext { next, .. } => Some(next.delta()),
63            Sequence::Empty => None,
64        }
65    }
66
67    fn is_empty(&self) -> bool {
68        match self {
69            Sequence::HasNext { .. } => false,
70            Sequence::Empty => true,
71        }
72    }
73}
74
75struct BinaryTreeSequenceMerge<
76    D: MIDINum,
77    E: MIDIDelta<D>,
78    Err,
79    I: Iterator<Item = Result<E, Err>> + Sized,
80> {
81    sequences: Vec<Sequence<D, E, Err, I>>,
82    heap: Vec<Option<D>>,
83}
84
85impl<D: MIDINum, E: MIDIDelta<D>, Err, I: Iterator<Item = Result<E, Err>> + Sized>
86    BinaryTreeSequenceMerge<D, E, Err, I>
87{
88    fn new(iters: impl Iterator<Item = I>) -> Result<Self, Err> {
89        let mut sequences = vec![];
90
91        for iter in iters {
92            let seq = Sequence::new(iter)?;
93            if !seq.is_empty() {
94                sequences.push(seq);
95            }
96        }
97
98        if sequences.is_empty() {
99            sequences.push(Sequence::Empty);
100        }
101
102        let heap = vec![None; sequences.len() - 1];
103
104        let mut tree = Self { heap, sequences };
105
106        for i in (0..tree.heap.len()).rev() {
107            tree.update_time_from_children_for(i);
108        }
109
110        Ok(tree)
111    }
112
113    fn get_time_for(&self, index: usize) -> Option<D> {
114        if index >= self.heap.len() {
115            let index = index - self.heap.len();
116            self.sequences.get(index).and_then(|x| x.time())
117        } else {
118            self.heap.get(index).and_then(|x| *x)
119        }
120    }
121
122    fn calculate_time_from_children_for(&self, index: usize) -> Option<D> {
123        let left = index * 2 + 1;
124        let right = index * 2 + 2;
125
126        let left_time = self.get_time_for(left);
127        let right_time = self.get_time_for(right);
128
129        match (left_time, right_time) {
130            (Some(left_time), Some(right_time)) => {
131                if left_time < right_time {
132                    Some(left_time)
133                } else {
134                    Some(right_time)
135                }
136            }
137            (Some(left_time), None) => Some(left_time),
138            (None, Some(right_time)) => Some(right_time),
139            (None, None) => None,
140        }
141    }
142
143    fn update_time_from_children_for(&mut self, index: usize) -> Option<D> {
144        let time = self.calculate_time_from_children_for(index);
145        self.heap[index] = time;
146        time
147    }
148
149    fn propagate_time_change_from(&mut self, index: usize) {
150        let mut index = index;
151        while index > 0 {
152            index = (index - 1) / 2;
153            self.update_time_from_children_for(index);
154        }
155    }
156
157    fn find_smallest_sequence_index(&self) -> Option<usize> {
158        #[allow(clippy::question_mark)]
159        // Empty if the root time is None
160        if self.get_time_for(0).is_none() {
161            return None;
162        }
163
164        let mut index = 0;
165        loop {
166            if index >= self.heap.len() {
167                return Some(index - self.heap.len());
168            }
169
170            let left = index * 2 + 1;
171            let right = index * 2 + 2;
172
173            let left_time = self.get_time_for(left);
174            let right_time = self.get_time_for(right);
175
176            let next_index = match (left_time, right_time) {
177                (Some(left_time), Some(right_time)) => {
178                    if left_time < right_time {
179                        left
180                    } else {
181                        right
182                    }
183                }
184                (Some(_), None) => left,
185                (None, Some(_)) => right,
186                (None, None) => unreachable!(),
187            };
188
189            index = next_index;
190        }
191    }
192
193    fn next(&mut self) -> Result<Option<E>, Err> {
194        let index = self.find_smallest_sequence_index();
195        if let Some(index) = index {
196            let sequence = &mut self.sequences[index];
197            let item = sequence.next()?.unwrap();
198
199            self.propagate_time_change_from(index + self.heap.len());
200
201            Ok(Some(item))
202        } else {
203            Ok(None)
204        }
205    }
206
207    #[inline(always)]
208    fn assert_all_empty(&self) {
209        for seq in &self.sequences {
210            debug_assert!(seq.is_empty());
211        }
212    }
213}
214
215/// Merge an array of event iterators together into one iterator.
216pub fn merge_events_array<
217    D: MIDINum,
218    E: MIDIDelta<D>,
219    Err,
220    I: Iterator<Item = Result<E, Err>> + Sized,
221>(
222    array: Vec<I>,
223) -> impl Iterator<Item = Result<E, Err>> {
224    GenIter(
225        #[coroutine]
226        move || {
227            let tree = BinaryTreeSequenceMerge::new(array.into_iter());
228            match tree {
229                Err(e) => yield_error!(Err(e)),
230                Ok(mut tree) => {
231                    let mut time = D::zero();
232                    while let Some(mut e) = unwrap!(tree.next()) {
233                        let new_time = e.delta();
234                        e.set_delta(e.delta() - time);
235                        time = new_time;
236                        yield Ok(e);
237                    }
238                    tree.assert_all_empty();
239                }
240            }
241        },
242    )
243}
244
245struct SeqTime<D: MIDINum, E: MIDIDelta<D>, Err, I: Iterator<Item = Result<E, Err>> + Sized> {
246    iter: I,
247    time: D,
248    next: Option<E>,
249}
250
251/// Merge a pair of two different event iterators together into one iterator.
252pub fn merge_events<
253    D: MIDINum,
254    E: MIDIDelta<D>,
255    Err,
256    I1: Iterator<Item = Result<E, Err>> + Sized,
257    I2: Iterator<Item = Result<E, Err>> + Sized,
258>(
259    iter1: I1,
260    iter2: I2,
261) -> impl Iterator<Item = Result<E, Err>> {
262    fn seq_from_iter<
263        D: MIDINum,
264        E: MIDIDelta<D>,
265        Err,
266        I: Iterator<Item = Result<E, Err>> + Sized,
267    >(
268        mut iter: I,
269    ) -> Result<SeqTime<D, E, Err, I>, Err> {
270        let first = iter.next();
271        match first {
272            None => Ok(SeqTime {
273                iter,
274                time: D::zero(),
275                next: None,
276            }),
277            Some(e) => match e {
278                Err(e) => Err(e),
279                Ok(e) => Ok(SeqTime {
280                    iter,
281                    time: e.delta(),
282                    next: Some(e),
283                }),
284            },
285        }
286    }
287
288    fn move_next<D: MIDINum, E: MIDIDelta<D>, Err, I: Iterator<Item = Result<E, Err>> + Sized>(
289        seq: &mut SeqTime<D, E, Err, I>,
290    ) -> Result<(), Err> {
291        let next = seq.iter.next();
292        let next = match next {
293            None => None,
294            Some(e) => match e {
295                Err(e) => return Err(e),
296                Ok(e) => {
297                    seq.time += e.delta();
298                    Some(e)
299                }
300            },
301        };
302        seq.next = next;
303        Ok(())
304    }
305
306    GenIter(
307        #[coroutine]
308        move || {
309            let mut seq1 = unwrap!(seq_from_iter(iter1));
310            let mut seq2 = unwrap!(seq_from_iter(iter2));
311
312            let mut time = D::zero();
313
314            macro_rules! yield_event {
315                ($ev:ident, $time:expr) => {
316                    $ev.set_delta($time - time);
317                    time = $time;
318                    yield Ok($ev);
319                };
320            }
321
322            macro_rules! flush_seq_and_return {
323                ($seq:ident) => {
324                    while let Some(mut ev) = $seq.next.take() {
325                        yield_event!(ev, $seq.time);
326                        unwrap!(move_next(&mut $seq));
327                    }
328                    return;
329                };
330            }
331
332            loop {
333                if seq1.next.is_none() {
334                    if seq2.next.is_none() {
335                        break;
336                    } else {
337                        flush_seq_and_return!(seq2);
338                    }
339                }
340                if seq2.next.is_none() {
341                    flush_seq_and_return!(seq1);
342                }
343
344                if seq1.time < seq2.time {
345                    let mut ev = seq1.next.take().unwrap();
346                    yield_event!(ev, seq1.time);
347                    unwrap!(move_next(&mut seq1));
348                } else {
349                    let mut ev = seq2.next.take().unwrap();
350                    yield_event!(ev, seq2.time);
351                    unwrap!(move_next(&mut seq2));
352                }
353            }
354        },
355    )
356}
357
358struct EventMerger<D: 'static + MIDINum, E: 'static + MIDIDelta<D> + Send, Err: 'static + Send> {
359    _phantom: std::marker::PhantomData<(D, E, Err)>,
360}
361impl<D: 'static + MIDINum, E: 'static + MIDIDelta<D> + Send, Err: 'static + Send> MergableStreams
362    for EventMerger<D, E, Err>
363{
364    type Item = Result<E, Err>;
365
366    fn merge_two(
367        iter1: impl Iterator<Item = Self::Item> + Send + 'static,
368        iter2: impl Iterator<Item = Self::Item> + Send + 'static,
369    ) -> impl Iterator<Item = Self::Item> + Send + 'static {
370        merge_events(iter1, iter2)
371    }
372
373    fn merge_array(
374        array: Vec<impl Iterator<Item = Self::Item> + Send + 'static>,
375    ) -> impl Iterator<Item = Self::Item> + Send + 'static {
376        merge_events_array(array)
377    }
378}
379
380/// Group tracks into separate threads and merge them together
381pub fn grouped_multithreaded_merge_event_arrays<
382    D: 'static + MIDINum,
383    E: 'static + MIDIDelta<D> + Send,
384    Err: 'static + Send,
385    I: 'static + Iterator<Item = Result<E, Err>> + Sized + Send,
386>(
387    array: Vec<I>,
388) -> impl Iterator<Item = Result<E, Err>> {
389    grouped_multithreaded_merge::<EventMerger<D, E, Err>>(array)
390}