redact_composer_core/render/context/
mod.rs1use std::any::{type_name, TypeId};
2use std::collections::HashSet;
3use std::hash::{Hash, Hasher};
4use std::iter::successors;
5use std::marker::PhantomData;
6use std::ops::Bound::{Excluded, Included, Unbounded};
7use std::ops::{Bound, RangeBounds};
8
9use rand::{Rng, SeedableRng};
10use rand_chacha::ChaCha12Rng;
11use twox_hash::XxHash64;
12
13use crate::render::RenderSegment;
14use crate::render::{
15 tree::{Node, Tree},
16 Result,
17};
18use crate::timing::RangeOps;
19use crate::SegmentRef;
20use crate::{CompositionOptions, Element};
21
22use crate::error::RendererError::MissingContext;
23use crate::render::context::TimingRelation::*;
24
// Unit tests for this module; only compiled when `cfg(test)` is set.
#[cfg(test)]
mod test;
27
28#[derive(Debug)]
33pub struct CompositionContext<'a> {
34 pub(crate) options: &'a CompositionOptions,
35 pub(crate) tree: &'a Tree<RenderSegment>,
36 pub(crate) start: &'a Node<RenderSegment>,
37 pub(crate) type_cache: Option<&'a Vec<HashSet<TypeId>>>,
38}
39
40impl Copy for CompositionContext<'_> {}
41
42impl Clone for CompositionContext<'_> {
43 fn clone(&self) -> Self {
44 *self
45 }
46}
47
48impl<'a> CompositionContext<'a> {
49 pub(crate) fn new(
50 options: &'a CompositionOptions,
51 tree: &'a Tree<RenderSegment>,
52 start: &'a Node<RenderSegment>,
53 type_cache: Option<&'a Vec<HashSet<TypeId>>>,
54 ) -> CompositionContext<'a> {
55 CompositionContext {
56 options,
57 tree,
58 start,
59 type_cache,
60 }
61 }
62
63 pub fn find<Element: crate::Element>(&self) -> CtxQuery<Element, impl Fn(&Element) -> bool> {
66 CtxQuery {
67 ctx: self,
68 timing: None,
69 scope: None,
70 where_fn: |_| true,
71 __: PhantomData,
72 }
73 }
74
75 pub fn beat_length(&self) -> i32 {
77 self.options.ticks_per_beat
78 }
79
80 pub fn rng(&self) -> impl Rng {
82 ChaCha12Rng::seed_from_u64(self.start.value.seed)
83 }
84
85 pub fn rng_with_seed(&self, seed: impl Hash) -> impl Rng {
88 let mut hasher = XxHash64::default();
89 self.start.value.seed.hash(&mut hasher);
90 seed.hash(&mut hasher);
91
92 ChaCha12Rng::seed_from_u64(hasher.finish())
93 }
94
95 fn get_all_segments_where<F: Element>(
100 &self,
101 where_clause: impl Fn(&F) -> bool,
102 relation: TimingConstraint,
103 scope: SearchScope,
104 ) -> Option<Vec<SegmentRef<F>>> {
105 let mut matching_segments: Vec<SegmentRef<F>> = vec![];
106
107 let search_start = (match scope {
108 SearchScope::WithinAncestor(t) => successors(Some(self.start), |node| {
109 node.parent.map(|idx| &self.tree[idx])
110 })
111 .filter(|node| {
112 successors(Some(&*node.value.segment.element), |&s| s.wrapped_element())
113 .any(|target| target.as_any().type_id() == t)
114 })
115 .last(),
116 _ => None,
117 })
118 .unwrap_or(&self.tree[0]);
119
120 for node in CtxIter::new::<F>(search_start, self.tree, self.type_cache, relation) {
121 if self.is_in_scope(&scope, node)
122 && node
123 .value
124 .segment
125 .element_as::<F>()
126 .is_some_and(&where_clause)
127 {
128 if let Ok(segment) = (&node.value.segment).try_into() {
129 matching_segments.insert(matching_segments.len(), segment);
130 }
131 }
132 }
133
134 if matching_segments.is_empty() {
135 None
136 } else {
137 Some(matching_segments)
138 }
139 }
140
141 fn is_in_scope(&self, scope: &SearchScope, node: &Node<RenderSegment>) -> bool {
142 match scope {
143 SearchScope::WithinAncestor(search_type) => {
144 let mut cursor = self.start.parent;
145 let mut opt_ancestor = None;
146
147 while let Some(cursor_node) = cursor.and_then(|p_idx| self.tree.get(p_idx)) {
148 if successors(Some(&*cursor_node.value.segment.element), |&s| {
149 s.wrapped_element()
150 })
151 .any(|s| s.as_any().type_id() == *search_type)
152 {
153 opt_ancestor = Some(cursor_node);
154 }
155
156 cursor = cursor_node.parent;
157 }
158
159 if let Some(ancestor) = opt_ancestor {
160 cursor = Some(node.idx);
161 while let Some(cursor_node) = cursor.and_then(|idx| self.tree.get(idx)) {
162 if cursor_node.idx == ancestor.idx {
163 return true;
164 }
165 cursor = cursor_node.parent;
166 }
167 }
168
169 false
170 }
171 SearchScope::Within(search_type) => {
172 let mut cursor = Some(node.idx);
173
174 while let Some(ancestor) = cursor.and_then(|p_idx| self.tree.get(p_idx)) {
175 if successors(Some(&*ancestor.value.segment.element), |&s| {
176 s.wrapped_element()
177 })
178 .any(|s| s.as_any().type_id() == *search_type)
179 {
180 return true;
181 }
182
183 cursor = ancestor.parent;
184 }
185
186 false
187 }
188 SearchScope::Anywhere => true,
189 }
190 }
191}
192
/// Builder for a search over the render tree, created by [`CompositionContext::find`].
///
/// Narrow it with [`with_timing`](CtxQuery::with_timing), [`within`](CtxQuery::within),
/// [`within_ancestor`](CtxQuery::within_ancestor) and [`matching`](CtxQuery::matching).
/// Then run it with one of the `get*` / `require*` methods.
#[derive(Debug)]
pub struct CtxQuery<'a, S: Element, F: Fn(&S) -> bool> {
    /// Context the search runs against.
    ctx: &'a CompositionContext<'a>,
    /// Explicit timing constraint; `None` means the executing method's default is used.
    timing: Option<TimingConstraint>,
    /// Tree scope; `None` means the whole tree (`SearchScope::Anywhere`).
    scope: Option<SearchScope>,
    /// Extra predicate a candidate element must satisfy.
    where_fn: F,
    /// Records the searched element type `S` without storing a value of it.
    __: PhantomData<S>,
}
202
203impl<'a, S: Element, F: Fn(&S) -> bool> CtxQuery<'a, S, F> {
204 pub fn with_timing<R: RangeBounds<i32>>(mut self, relation: TimingRelation, timing: R) -> Self {
206 self.timing = Some(TimingConstraint::from((relation, timing)));
207
208 self
209 }
210
211 pub fn within<S2: Element>(mut self) -> Self {
215 self.scope = Some(SearchScope::Within(TypeId::of::<S2>()));
216
217 self
218 }
219
220 pub fn within_ancestor<S2: Element>(mut self) -> Self {
224 self.scope = Some(SearchScope::WithinAncestor(TypeId::of::<S2>()));
225
226 self
227 }
228
229 pub fn matching(self, where_fn: impl Fn(&S) -> bool) -> CtxQuery<'a, S, impl Fn(&S) -> bool> {
231 CtxQuery {
232 ctx: self.ctx,
233 timing: self.timing,
234 scope: self.scope,
235 where_fn,
236 __: self.__,
237 }
238 }
239
240 pub fn get(self) -> Option<SegmentRef<'a, S>> {
242 self.ctx
243 .get_all_segments_where::<S>(
244 self.where_fn,
245 self.timing.unwrap_or(TimingConstraint::from((
246 During,
247 self.ctx.start.value.segment.timing,
248 ))),
249 self.scope.unwrap_or(SearchScope::Anywhere),
250 )
251 .and_then(|mut v| {
252 if v.first().is_none() {
253 None
254 } else {
255 Some(v.swap_remove(0))
256 }
257 })
258 }
259
260 pub fn get_all(self) -> Option<Vec<SegmentRef<'a, S>>> {
262 self.get_at_least(1)
263 }
264
265 pub fn get_at_least(self, min_requested: usize) -> Option<Vec<SegmentRef<'a, S>>> {
268 if let Some(results) = self.ctx.get_all_segments_where::<S>(
269 self.where_fn,
270 self.timing.unwrap_or(TimingConstraint::from((
271 Overlapping,
272 self.ctx.start.value.segment.timing,
273 ))),
274 self.scope.unwrap_or(SearchScope::Anywhere),
275 ) {
276 if results.len() >= min_requested {
277 return Some(results);
278 }
279 }
280
281 None
282 }
283
284 pub fn require(self) -> Result<SegmentRef<'a, S>> {
286 self.get()
287 .ok_or(MissingContext(type_name::<S>().to_string()))
288 }
289
290 pub fn require_all(self) -> Result<Vec<SegmentRef<'a, S>>> {
292 self.require_at_least(1)
293 }
294
295 pub fn require_at_least(self, min_requested: usize) -> Result<Vec<SegmentRef<'a, S>>> {
298 self.get_at_least(min_requested)
299 .ok_or(MissingContext(type_name::<S>().to_string()))
300 }
301}
302
/// How a candidate segment's time range must relate to a query's reference range.
///
/// Used with [`CtxQuery::with_timing`]. Each variant is checked with the correspondingly
/// named `RangeOps` method, called on the candidate's range with the reference range as
/// the argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimingRelation {
    /// The candidate's range contains the reference range (`contains_range`).
    During,
    /// The candidate's range intersects the reference range (`intersects`).
    Overlapping,
    /// The candidate's range is contained by the reference range (`is_contained_by`).
    Within,
    /// The candidate's range begins within the reference range (`begins_within`).
    BeginningWithin,
    /// The candidate's range ends within the reference range (`ends_within`).
    EndingWithin,
    /// The candidate's range is before the reference range (`is_before`).
    Before,
    /// The candidate's range is after the reference range (`is_after`).
    After,
}
321
/// Which part of the render tree a [`CtxQuery`] searches, relative to the context's
/// `start` node.
#[derive(Debug)]
enum SearchScope {
    /// Only nodes inside (or equal to) the outermost proper ancestor of `start` whose element
    /// is, or wraps, the given type. Nothing matches if there is no such ancestor.
    WithinAncestor(TypeId),
    /// Only nodes that are, or descend from, a node whose element is, or wraps, the given type.
    Within(TypeId),
    /// The whole tree.
    Anywhere,
}
332
/// A [`TimingRelation`] together with the reference range it is evaluated against.
#[derive(Debug)]
struct TimingConstraint {
    /// How a candidate's range must relate to `ref_range`.
    pub relation: TimingRelation,
    /// The reference range as owned `(start, end)` bounds.
    pub ref_range: (Bound<i32>, Bound<i32>),
}
339
340impl<R: RangeBounds<i32>> From<(TimingRelation, R)> for TimingConstraint {
341 fn from(value: (TimingRelation, R)) -> Self {
342 TimingConstraint {
343 relation: value.0,
344 ref_range: (value.1.start_bound().cloned(), value.1.end_bound().cloned()),
345 }
346 }
347}
348
349impl TimingConstraint {
350 fn matches<T: RangeBounds<i32>>(&self, target_range: &T) -> bool {
352 match self.relation {
353 During => target_range.contains_range(&self.ref_range),
354 Overlapping => target_range.intersects(&self.ref_range),
355 Within => target_range.is_contained_by(&self.ref_range),
356 BeginningWithin => target_range.begins_within(&self.ref_range),
357 EndingWithin => target_range.ends_within(&self.ref_range),
358 Before => target_range.is_before(&self.ref_range),
359 After => target_range.is_after(&self.ref_range),
360 }
361 }
362
363 fn could_match_within<T: RangeBounds<i32>>(&self, target_range: &T) -> bool {
365 match self.relation {
366 During | Overlapping => self.matches(target_range),
367 Within | BeginningWithin | EndingWithin => self.ref_range.intersects(target_range),
368 Before => match self.ref_range.start_bound() {
369 Included(v) => target_range.intersects(&(Unbounded, Excluded(v))),
370 Excluded(v) => target_range.intersects(&(Unbounded, Included(v))),
371 Unbounded => false,
372 },
373 After => match self.ref_range.end_bound() {
374 Included(v) => target_range.intersects(&(Excluded(v), Unbounded)),
375 Excluded(v) => target_range.intersects(&(Included(v), Unbounded)),
376 Unbounded => false,
377 },
378 }
379 }
380}
381
/// Breadth-first walk over render-tree nodes, starting at (and including) a given node.
/// It yields only nodes whose segment satisfies `time_relation`.
struct CtxIter<'a> {
    tree: &'a Tree<RenderSegment>,
    /// Optional per-node sets of `TypeId`s, indexed by node index. When present, a node's
    /// children are explored only if its set contains `search_type`.
    type_cache: Option<&'a Vec<HashSet<TypeId>>>,
    /// Position in `curr_nodes` of the next node to visit.
    idx: usize,
    /// Nodes of the depth level currently being visited.
    curr_nodes: Vec<&'a Node<RenderSegment>>,
    /// Children queued for the next depth level.
    next_nodes: Vec<&'a Node<RenderSegment>>,
    /// Timing filter for yielded nodes; also used to prune children.
    time_relation: TimingConstraint,
    /// `TypeId` of the element type being searched for.
    search_type: TypeId,
}
391
392impl<'a> Iterator for CtxIter<'a> {
393 type Item = &'a Node<RenderSegment>;
394
395 fn next(&mut self) -> Option<Self::Item> {
396 if let Some(node) = self.curr_nodes.get(self.idx) {
397 if self
398 .type_cache
399 .map_or(true, |cache| cache[node.idx].contains(&self.search_type))
400 {
401 let mut child_nodes: Vec<&Node<RenderSegment>> = node
402 .children
403 .iter()
404 .map(|child_idx| &self.tree[*child_idx])
405 .filter(|n| n.value.rendered && self.might_have_items(n))
406 .collect();
407
408 self.next_nodes.append(&mut child_nodes);
409 }
410 self.idx += 1;
411
412 if self.time_relation.matches(&node.value.segment) {
413 Some(node)
414 } else {
415 self.next()
416 }
417 } else if self.next_nodes.is_empty() {
418 None
419 } else {
420 self.curr_nodes = vec![];
421 self.curr_nodes.append(&mut self.next_nodes);
422 self.idx = 0;
423
424 self.next()
425 }
426 }
427}
428
429impl<'a> CtxIter<'a> {
430 fn new<S: Element>(
431 node: &'a Node<RenderSegment>,
432 tree: &'a Tree<RenderSegment>,
433 type_cache: Option<&'a Vec<HashSet<TypeId>>>,
434 relation: TimingConstraint,
435 ) -> CtxIter<'a> {
436 CtxIter {
437 tree,
438 type_cache,
439 idx: 0,
440 curr_nodes: vec![node],
441 next_nodes: vec![],
442 time_relation: relation,
443 search_type: TypeId::of::<S>(),
444 }
445 }
446
447 fn might_have_items(&self, node: &Node<RenderSegment>) -> bool {
448 self.time_relation.could_match_within(&node.value.segment)
449 }
450}