redact_composer_core/render/context/
mod.rs

1use std::any::{type_name, TypeId};
2use std::collections::HashSet;
3use std::hash::{Hash, Hasher};
4use std::iter::successors;
5use std::marker::PhantomData;
6use std::ops::Bound::{Excluded, Included, Unbounded};
7use std::ops::{Bound, RangeBounds};
8
9use rand::{Rng, SeedableRng};
10use rand_chacha::ChaCha12Rng;
11use twox_hash::XxHash64;
12
13use crate::render::RenderSegment;
14use crate::render::{
15    tree::{Node, Tree},
16    Result,
17};
18use crate::timing::RangeOps;
19use crate::SegmentRef;
20use crate::{CompositionOptions, Element};
21
22use crate::error::RendererError::MissingContext;
23use crate::render::context::TimingRelation::*;
24
25#[cfg(test)]
26mod test;
27
28/// Provides access to common utilities, such as methods to lookup other composition tree nodes, or
29/// Rng.
30///
31/// This struct is provided as an argument to [`Renderer::render`](crate::render::Renderer::render).
32#[derive(Debug)]
33pub struct CompositionContext<'a> {
34    pub(crate) options: &'a CompositionOptions,
35    pub(crate) tree: &'a Tree<RenderSegment>,
36    pub(crate) start: &'a Node<RenderSegment>,
37    pub(crate) type_cache: Option<&'a Vec<HashSet<TypeId>>>,
38}
39
40impl Copy for CompositionContext<'_> {}
41
42impl Clone for CompositionContext<'_> {
43    fn clone(&self) -> Self {
44        *self
45    }
46}
47
48impl<'a> CompositionContext<'a> {
49    pub(crate) fn new(
50        options: &'a CompositionOptions,
51        tree: &'a Tree<RenderSegment>,
52        start: &'a Node<RenderSegment>,
53        type_cache: Option<&'a Vec<HashSet<TypeId>>>,
54    ) -> CompositionContext<'a> {
55        CompositionContext {
56            options,
57            tree,
58            start,
59            type_cache,
60        }
61    }
62
63    /// Search the in-progress composition tree for nodes of type `Element`.
64    /// Returns a [`CtxQuery`], allowing further specifications before running the search.
65    pub fn find<Element: crate::Element>(&self) -> CtxQuery<Element, impl Fn(&Element) -> bool> {
66        CtxQuery {
67            ctx: self,
68            timing: None,
69            scope: None,
70            where_fn: |_| true,
71            __: PhantomData,
72        }
73    }
74
75    /// Returns the composition beat length. A composition's tempo (BPM) is relative to this value.
76    pub fn beat_length(&self) -> i32 {
77        self.options.ticks_per_beat
78    }
79
80    /// Creates an [`Rng`] seeded from the currently rendering segment's seed.
81    pub fn rng(&self) -> impl Rng {
82        ChaCha12Rng::seed_from_u64(self.start.value.seed)
83    }
84
85    /// Creates an [`Rng`] seeded from a combination of the currently rendering segment's seed
86    /// as well as the provided seed.
87    pub fn rng_with_seed(&self, seed: impl Hash) -> impl Rng {
88        let mut hasher = XxHash64::default();
89        self.start.value.seed.hash(&mut hasher);
90        seed.hash(&mut hasher);
91
92        ChaCha12Rng::seed_from_u64(hasher.finish())
93    }
94
95    /// Search the in-progress composition tree for all [`Element`]s within the given
96    /// [`TimingConstraint`] and [`SearchScope`] criteria that match the provided closure. Returns
97    /// a vector of [`SegmentRef`]s referencing the matching [`Element`]s if any were found,
98    /// or else [`None`]. This is useful if the timing data is required.
99    fn get_all_segments_where<F: Element>(
100        &self,
101        where_clause: impl Fn(&F) -> bool,
102        relation: TimingConstraint,
103        scope: SearchScope,
104    ) -> Option<Vec<SegmentRef<F>>> {
105        let mut matching_segments: Vec<SegmentRef<F>> = vec![];
106
107        let search_start = (match scope {
108            SearchScope::WithinAncestor(t) => successors(Some(self.start), |node| {
109                node.parent.map(|idx| &self.tree[idx])
110            })
111            .filter(|node| {
112                successors(Some(&*node.value.segment.element), |&s| s.wrapped_element())
113                    .any(|target| target.as_any().type_id() == t)
114            })
115            .last(),
116            _ => None,
117        })
118        .unwrap_or(&self.tree[0]);
119
120        for node in CtxIter::new::<F>(search_start, self.tree, self.type_cache, relation) {
121            if self.is_in_scope(&scope, node)
122                && node
123                    .value
124                    .segment
125                    .element_as::<F>()
126                    .is_some_and(&where_clause)
127            {
128                if let Ok(segment) = (&node.value.segment).try_into() {
129                    matching_segments.insert(matching_segments.len(), segment);
130                }
131            }
132        }
133
134        if matching_segments.is_empty() {
135            None
136        } else {
137            Some(matching_segments)
138        }
139    }
140
141    fn is_in_scope(&self, scope: &SearchScope, node: &Node<RenderSegment>) -> bool {
142        match scope {
143            SearchScope::WithinAncestor(search_type) => {
144                let mut cursor = self.start.parent;
145                let mut opt_ancestor = None;
146
147                while let Some(cursor_node) = cursor.and_then(|p_idx| self.tree.get(p_idx)) {
148                    if successors(Some(&*cursor_node.value.segment.element), |&s| {
149                        s.wrapped_element()
150                    })
151                    .any(|s| s.as_any().type_id() == *search_type)
152                    {
153                        opt_ancestor = Some(cursor_node);
154                    }
155
156                    cursor = cursor_node.parent;
157                }
158
159                if let Some(ancestor) = opt_ancestor {
160                    cursor = Some(node.idx);
161                    while let Some(cursor_node) = cursor.and_then(|idx| self.tree.get(idx)) {
162                        if cursor_node.idx == ancestor.idx {
163                            return true;
164                        }
165                        cursor = cursor_node.parent;
166                    }
167                }
168
169                false
170            }
171            SearchScope::Within(search_type) => {
172                let mut cursor = Some(node.idx);
173
174                while let Some(ancestor) = cursor.and_then(|p_idx| self.tree.get(p_idx)) {
175                    if successors(Some(&*ancestor.value.segment.element), |&s| {
176                        s.wrapped_element()
177                    })
178                    .any(|s| s.as_any().type_id() == *search_type)
179                    {
180                        return true;
181                    }
182
183                    cursor = ancestor.parent;
184                }
185
186                false
187            }
188            SearchScope::Anywhere => true,
189        }
190    }
191}
192
193/// A context query builder. Initiate a query via [`CompositionContext::find`].
194#[derive(Debug)]
195pub struct CtxQuery<'a, S: Element, F: Fn(&S) -> bool> {
196    ctx: &'a CompositionContext<'a>,
197    timing: Option<TimingConstraint>,
198    scope: Option<SearchScope>,
199    where_fn: F,
200    __: PhantomData<S>,
201}
202
203impl<'a, S: Element, F: Fn(&S) -> bool> CtxQuery<'a, S, F> {
204    /// Restrict the search to segments matching a given [`TimingRelation`].
205    pub fn with_timing<R: RangeBounds<i32>>(mut self, relation: TimingRelation, timing: R) -> Self {
206        self.timing = Some(TimingConstraint::from((relation, timing)));
207
208        self
209    }
210
211    /// Restrict the search to descendent segments a given [`Element`] type. This does
212    /// not in itself impose any timing constraints for the search -- for that, use
213    /// [`with_timing`](Self::with_timing).
214    pub fn within<S2: Element>(mut self) -> Self {
215        self.scope = Some(SearchScope::Within(TypeId::of::<S2>()));
216
217        self
218    }
219
220    /// Restrict the search to segments generated within the initiator's ancestor of the
221    /// given [`Element`]. This does not in itself impose any timing constraints for the
222    /// search -- for that, use [`with_timing`](Self::with_timing).
223    pub fn within_ancestor<S2: Element>(mut self) -> Self {
224        self.scope = Some(SearchScope::WithinAncestor(TypeId::of::<S2>()));
225
226        self
227    }
228
229    /// Restrict the search to segments matching the supplied closure.
230    pub fn matching(self, where_fn: impl Fn(&S) -> bool) -> CtxQuery<'a, S, impl Fn(&S) -> bool> {
231        CtxQuery {
232            ctx: self.ctx,
233            timing: self.timing,
234            scope: self.scope,
235            where_fn,
236            __: self.__,
237        }
238    }
239
240    /// Runs the context query, and returns a single optional result, or [`None`] if none are found.
241    pub fn get(self) -> Option<SegmentRef<'a, S>> {
242        self.ctx
243            .get_all_segments_where::<S>(
244                self.where_fn,
245                self.timing.unwrap_or(TimingConstraint::from((
246                    During,
247                    self.ctx.start.value.segment.timing,
248                ))),
249                self.scope.unwrap_or(SearchScope::Anywhere),
250            )
251            .and_then(|mut v| {
252                if v.first().is_none() {
253                    None
254                } else {
255                    Some(v.swap_remove(0))
256                }
257            })
258    }
259
260    /// Runs the context query, and returns all results, or [`None`] if none are found.
261    pub fn get_all(self) -> Option<Vec<SegmentRef<'a, S>>> {
262        self.get_at_least(1)
263    }
264
265    /// Runs the context query. Returns all results if at least `min_requested` results are found,
266    /// otherwise [`None`] is returned.
267    pub fn get_at_least(self, min_requested: usize) -> Option<Vec<SegmentRef<'a, S>>> {
268        if let Some(results) = self.ctx.get_all_segments_where::<S>(
269            self.where_fn,
270            self.timing.unwrap_or(TimingConstraint::from((
271                Overlapping,
272                self.ctx.start.value.segment.timing,
273            ))),
274            self.scope.unwrap_or(SearchScope::Anywhere),
275        ) {
276            if results.len() >= min_requested {
277                return Some(results);
278            }
279        }
280
281        None
282    }
283
284    /// Runs the context query, and returns a single result, or [`MissingContext`] error if none are found.
285    pub fn require(self) -> Result<SegmentRef<'a, S>> {
286        self.get()
287            .ok_or(MissingContext(type_name::<S>().to_string()))
288    }
289
290    /// Runs the context query, and returns all results, or [`MissingContext`] error if none are found.
291    pub fn require_all(self) -> Result<Vec<SegmentRef<'a, S>>> {
292        self.require_at_least(1)
293    }
294
295    /// Runs the context query. If at least `min_requested` results are found they are returned,
296    /// otherwise a [`MissingContext`] error is returned.
297    pub fn require_at_least(self, min_requested: usize) -> Result<Vec<SegmentRef<'a, S>>> {
298        self.get_at_least(min_requested)
299            .ok_or(MissingContext(type_name::<S>().to_string()))
300    }
301}
302
/// Describes how a target's time range relates to a reference time range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimingRelation {
    /// Describes a relationship for a target whose time range fully includes the reference time range.
    During,
    /// Describes a relationship for a target whose time range shares any part of the reference time range.
    Overlapping,
    /// Describes a relationship for a target whose time range is fully enclosed within the reference time range.
    Within,
    /// Describes a relationship for a target whose time range begins within the reference time range.
    BeginningWithin,
    /// Describes a relationship for a target whose time range ends within the reference time range.
    EndingWithin,
    /// Describes a relationship for a target whose time range ends before (or at) the beginning of the reference time range.
    Before,
    /// Describes a relationship for a target whose time range starts after (or at) the end of the reference time range.
    After,
}
321
/// Used to describe which portions of a composition tree to search during a context lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SearchScope {
    /// Only the outermost ancestor of the rendering node whose element (directly or via a
    /// wrapped element) has the given type, plus that ancestor's descendants.
    WithinAncestor(TypeId),
    /// Only nodes that are, or descend from, a node whose element (directly or via a wrapped
    /// element) has the given type.
    Within(TypeId),
    /// Describes a scope that has no restrictions.
    Anywhere,
}
332
333/// Describes a relationship between a target and reference time range.
334#[derive(Debug)]
335struct TimingConstraint {
336    pub relation: TimingRelation,
337    pub ref_range: (Bound<i32>, Bound<i32>),
338}
339
340impl<R: RangeBounds<i32>> From<(TimingRelation, R)> for TimingConstraint {
341    fn from(value: (TimingRelation, R)) -> Self {
342        TimingConstraint {
343            relation: value.0,
344            ref_range: (value.1.start_bound().cloned(), value.1.end_bound().cloned()),
345        }
346    }
347}
348
349impl TimingConstraint {
350    // Determines if a target time range matches this relationship.
351    fn matches<T: RangeBounds<i32>>(&self, target_range: &T) -> bool {
352        match self.relation {
353            During => target_range.contains_range(&self.ref_range),
354            Overlapping => target_range.intersects(&self.ref_range),
355            Within => target_range.is_contained_by(&self.ref_range),
356            BeginningWithin => target_range.begins_within(&self.ref_range),
357            EndingWithin => target_range.ends_within(&self.ref_range),
358            Before => target_range.is_before(&self.ref_range),
359            After => target_range.is_after(&self.ref_range),
360        }
361    }
362
363    // Determines if a target time range could contain a matche for this relationship.
364    fn could_match_within<T: RangeBounds<i32>>(&self, target_range: &T) -> bool {
365        match self.relation {
366            During | Overlapping => self.matches(target_range),
367            Within | BeginningWithin | EndingWithin => self.ref_range.intersects(target_range),
368            Before => match self.ref_range.start_bound() {
369                Included(v) => target_range.intersects(&(Unbounded, Excluded(v))),
370                Excluded(v) => target_range.intersects(&(Unbounded, Included(v))),
371                Unbounded => false,
372            },
373            After => match self.ref_range.end_bound() {
374                Included(v) => target_range.intersects(&(Excluded(v), Unbounded)),
375                Excluded(v) => target_range.intersects(&(Included(v), Unbounded)),
376                Unbounded => false,
377            },
378        }
379    }
380}
381
382struct CtxIter<'a> {
383    tree: &'a Tree<RenderSegment>,
384    type_cache: Option<&'a Vec<HashSet<TypeId>>>,
385    idx: usize,
386    curr_nodes: Vec<&'a Node<RenderSegment>>,
387    next_nodes: Vec<&'a Node<RenderSegment>>,
388    time_relation: TimingConstraint,
389    search_type: TypeId,
390}
391
392impl<'a> Iterator for CtxIter<'a> {
393    type Item = &'a Node<RenderSegment>;
394
395    fn next(&mut self) -> Option<Self::Item> {
396        if let Some(node) = self.curr_nodes.get(self.idx) {
397            if self
398                .type_cache
399                .map_or(true, |cache| cache[node.idx].contains(&self.search_type))
400            {
401                let mut child_nodes: Vec<&Node<RenderSegment>> = node
402                    .children
403                    .iter()
404                    .map(|child_idx| &self.tree[*child_idx])
405                    .filter(|n| n.value.rendered && self.might_have_items(n))
406                    .collect();
407
408                self.next_nodes.append(&mut child_nodes);
409            }
410            self.idx += 1;
411
412            if self.time_relation.matches(&node.value.segment) {
413                Some(node)
414            } else {
415                self.next()
416            }
417        } else if self.next_nodes.is_empty() {
418            None
419        } else {
420            self.curr_nodes = vec![];
421            self.curr_nodes.append(&mut self.next_nodes);
422            self.idx = 0;
423
424            self.next()
425        }
426    }
427}
428
429impl<'a> CtxIter<'a> {
430    fn new<S: Element>(
431        node: &'a Node<RenderSegment>,
432        tree: &'a Tree<RenderSegment>,
433        type_cache: Option<&'a Vec<HashSet<TypeId>>>,
434        relation: TimingConstraint,
435    ) -> CtxIter<'a> {
436        CtxIter {
437            tree,
438            type_cache,
439            idx: 0,
440            curr_nodes: vec![node],
441            next_nodes: vec![],
442            time_relation: relation,
443            search_type: TypeId::of::<S>(),
444        }
445    }
446
447    fn might_have_items(&self, node: &Node<RenderSegment>) -> bool {
448        self.time_relation.could_match_within(&node.value.segment)
449    }
450}