1use alloc::collections::BinaryHeap;
11use alloc::collections::VecDeque;
12use alloc::vec::Vec;
13use core::cmp::Reverse;
14use core::hash::Hash;
15
16use hashbrown::HashMap;
17use hashbrown::hash_map::Entry;
18
19use crate::DrainBuilder;
20use crate::channel::Channel;
21use crate::graph::InvalidationGraph;
22use crate::scratch::TraversalScratch;
23use crate::trace::InvalidationTrace;
24
/// A key that maps onto a dense `usize` index.
///
/// The index is used directly as a position in `Vec`-backed tables, which are
/// grown to `index + 1` entries. Keys should therefore occupy a compact range
/// starting near zero; sparse or huge indices make that storage grow accordingly.
pub trait DenseKey: Copy {
    /// Returns the dense table position for this key.
    fn index(self) -> usize;
}

impl DenseKey for u32 {
    #[inline]
    fn index(self) -> usize {
        // Widening cast: lossless on targets whose pointers are at least 32 bits.
        self as usize
    }
}

impl DenseKey for usize {
    #[inline]
    fn index(self) -> usize {
        // Already a table position; identity mapping.
        self
    }
}
57
/// Marker stored in dense in-degree tables for slots that are not part of the
/// drained set: either never inserted (freshly grown) or already yielded.
const DENSE_SENTINEL: u32 = u32::MAX;
60
/// Reserves enough capacity in `vec` to cover position `idx` and returns the
/// length needed to do so (`idx + 1`).
///
/// Only capacity is reserved; the caller performs the actual resize. `storage`
/// names the table being grown and appears in panic messages.
///
/// # Panics
///
/// Panics if `idx + 1` overflows `usize`, or if reserving the extra capacity
/// fails. The latter message points callers at a compact dense key space or
/// `intern::Interner`.
#[inline]
pub(crate) fn prepare_dense_growth<T>(vec: &mut Vec<T>, idx: usize, storage: &str) -> usize {
    let Some(target_len) = idx.checked_add(1) else {
        panic!("DenseKey index {idx} overflows addressable capacity for {storage}");
    };

    // How many slots past the current length are needed (0 if already long enough).
    let missing = target_len.saturating_sub(vec.len());
    if missing > 0 {
        // Fallible reservation so an absurd index yields a descriptive panic
        // instead of a generic allocation failure.
        if let Err(err) = vec.try_reserve_exact(missing) {
            panic!(
                "DenseKey index {idx} requires growing {storage} to length {target_len}: {err:?}; use a compact dense key space or intern::Interner"
            );
        }
    }

    target_len
}
77
/// Outcome of a topologically sorted drain.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DrainCompletion {
    /// Every key in the drained set was yielded.
    Complete,
    /// Iteration ended while some keys were still unyielded because they never
    /// became ready (for example, because they form a dependency cycle).
    Stalled {
        /// How many keys were never yielded.
        remaining: usize,
    },
}
89
90#[derive(Debug)]
146pub struct DrainSorted<'a, K>
147where
148 K: Copy + Eq + Hash + DenseKey,
149{
150 graph: &'a InvalidationGraph<K>,
151 channel: Channel,
152 queue: VecDeque<K>,
154 in_degree: HashMap<K, usize>,
156 stalled: bool,
157}
158
159#[derive(Debug)]
167pub struct DrainSortedDeterministic<'a, K>
168where
169 K: Copy + Eq + Hash + Ord + DenseKey,
170{
171 graph: &'a InvalidationGraph<K>,
172 channel: Channel,
173 ready: BinaryHeap<Reverse<K>>,
175 in_degree: Vec<u32>,
178 remaining: usize,
180 stalled: bool,
181}
182
183impl<'a, K> DrainSorted<'a, K>
184where
185 K: Copy + Eq + Hash + DenseKey,
186{
187 pub(crate) fn from_iter_with_capacity<I>(
188 invalidated_keys: I,
189 cap: usize,
190 graph: &'a InvalidationGraph<K>,
191 channel: Channel,
192 ) -> Self
193 where
194 I: Iterator<Item = K>,
195 {
196 let mut in_degree: HashMap<K, usize> = HashMap::with_capacity(cap);
198 let mut unique_keys = Vec::with_capacity(cap);
199 for key in invalidated_keys {
200 if let Entry::Vacant(e) = in_degree.entry(key) {
201 e.insert(0);
202 unique_keys.push(key);
203 }
204 }
205
206 for &key in &unique_keys {
209 for dep in graph.dependencies(key, channel) {
210 if in_degree.contains_key(&dep) {
211 *in_degree.get_mut(&key).expect("key is in in_degree") += 1;
212 }
213 }
214 }
215
216 let mut queue = VecDeque::with_capacity(in_degree.len());
218 queue.extend(
219 unique_keys
220 .into_iter()
221 .filter(|&k| in_degree.get(&k).is_some_and(|°| deg == 0)),
222 );
223
224 Self {
225 graph,
226 channel,
227 queue,
228 in_degree,
229 stalled: false,
230 }
231 }
232
233 #[must_use]
239 pub fn is_empty(&self) -> bool {
240 self.queue.is_empty()
241 }
242
243 #[must_use]
245 pub fn remaining(&self) -> usize {
246 self.in_degree.len()
247 }
248
249 #[must_use]
256 pub fn is_stalled(&self) -> bool {
257 self.stalled
258 }
259
260 #[must_use]
268 pub fn completion(&self) -> DrainCompletion {
269 if self.stalled {
270 DrainCompletion::Stalled {
271 remaining: self.remaining(),
272 }
273 } else {
274 DrainCompletion::Complete
275 }
276 }
277
278 #[must_use]
280 pub fn collect_with_completion(mut self) -> (Vec<K>, DrainCompletion) {
281 let mut out = Vec::with_capacity(self.in_degree.len());
282 out.extend(&mut self);
283 let completion = self.completion();
284 (out, completion)
285 }
286}
287
288impl<'a, K> DrainSortedDeterministic<'a, K>
289where
290 K: Copy + Eq + Hash + Ord + DenseKey,
291{
292 pub(crate) fn from_iter_with_capacity<I>(
293 invalidated_keys: I,
294 cap: usize,
295 graph: &'a InvalidationGraph<K>,
296 channel: Channel,
297 ) -> Self
298 where
299 I: Iterator<Item = K>,
300 {
301 let mut in_degree: Vec<u32> = Vec::new();
303 let mut unique_keys = Vec::with_capacity(cap);
304 for key in invalidated_keys {
305 let idx = key.index();
306 if idx >= in_degree.len() {
307 let target_len =
308 prepare_dense_growth(&mut in_degree, idx, "deterministic drain in-degree");
309 in_degree.resize(target_len, DENSE_SENTINEL);
310 }
311 if in_degree[idx] == DENSE_SENTINEL {
312 in_degree[idx] = 0;
313 unique_keys.push(key);
314 }
315 }
316
317 for &key in &unique_keys {
319 for dep in graph.dependencies(key, channel) {
320 let dep_idx = dep.index();
321 if dep_idx < in_degree.len() && in_degree[dep_idx] != DENSE_SENTINEL {
322 in_degree[key.index()] += 1;
323 }
324 }
325 }
326
327 let remaining = unique_keys.len();
328
329 let mut ready = BinaryHeap::with_capacity(remaining);
331 for key in unique_keys {
332 if in_degree[key.index()] == 0 {
333 ready.push(Reverse(key));
334 }
335 }
336
337 Self {
338 graph,
339 channel,
340 ready,
341 in_degree,
342 remaining,
343 stalled: false,
344 }
345 }
346
347 #[must_use]
352 pub fn is_empty(&self) -> bool {
353 self.ready.is_empty()
354 }
355
356 #[must_use]
358 pub fn remaining(&self) -> usize {
359 self.remaining
360 }
361
362 #[must_use]
369 pub fn is_stalled(&self) -> bool {
370 self.stalled
371 }
372
373 #[must_use]
375 pub fn completion(&self) -> DrainCompletion {
376 if self.stalled {
377 DrainCompletion::Stalled {
378 remaining: self.remaining(),
379 }
380 } else {
381 DrainCompletion::Complete
382 }
383 }
384
385 #[must_use]
387 pub fn collect_with_completion(mut self) -> (Vec<K>, DrainCompletion) {
388 let mut out = Vec::with_capacity(self.remaining);
389 out.extend(&mut self);
390 let completion = self.completion();
391 (out, completion)
392 }
393}
394
395impl<K> Iterator for DrainSorted<'_, K>
396where
397 K: Copy + Eq + Hash + DenseKey,
398{
399 type Item = K;
400
401 fn next(&mut self) -> Option<Self::Item> {
402 let Some(key) = self.queue.pop_front() else {
403 if !self.in_degree.is_empty() {
404 self.stalled = true;
405 }
406 return None;
407 };
408
409 self.in_degree.remove(&key);
411
412 for dependent in self.graph.dependents(key, self.channel) {
414 if let Some(deg) = self.in_degree.get_mut(&dependent) {
415 *deg -= 1;
416 if *deg == 0 {
417 self.queue.push_back(dependent);
418 }
419 }
420 }
421
422 Some(key)
423 }
424
425 fn size_hint(&self) -> (usize, Option<usize>) {
426 let remaining = self.in_degree.len();
427 (remaining, Some(remaining))
428 }
429}
430
431impl<K> Iterator for DrainSortedDeterministic<'_, K>
432where
433 K: Copy + Eq + Hash + Ord + DenseKey,
434{
435 type Item = K;
436
437 fn next(&mut self) -> Option<Self::Item> {
438 let Some(Reverse(key)) = self.ready.pop() else {
439 if self.remaining > 0 {
440 self.stalled = true;
441 }
442 return None;
443 };
444
445 self.in_degree[key.index()] = DENSE_SENTINEL;
446 self.remaining -= 1;
447
448 for dependent in self.graph.dependents(key, self.channel) {
449 let idx = dependent.index();
450 if idx < self.in_degree.len() && self.in_degree[idx] != DENSE_SENTINEL {
451 self.in_degree[idx] -= 1;
452 if self.in_degree[idx] == 0 {
453 self.ready.push(Reverse(dependent));
454 }
455 }
456 }
457
458 Some(key)
459 }
460
461 fn size_hint(&self) -> (usize, Option<usize>) {
462 (self.remaining, Some(self.remaining))
463 }
464}
465
466pub fn drain_sorted<'a, K>(
498 invalidated: &mut crate::InvalidationSet<K>,
499 graph: &'a InvalidationGraph<K>,
500 channel: Channel,
501) -> DrainSorted<'a, K>
502where
503 K: Copy + Eq + Hash + DenseKey,
504{
505 DrainBuilder::new(invalidated, graph, channel)
506 .invalidated_only()
507 .run()
508}
509
510pub fn drain_sorted_deterministic<'a, K>(
515 invalidated: &mut crate::InvalidationSet<K>,
516 graph: &'a InvalidationGraph<K>,
517 channel: Channel,
518) -> DrainSortedDeterministic<'a, K>
519where
520 K: Copy + Eq + Hash + Ord + DenseKey,
521{
522 DrainBuilder::new(invalidated, graph, channel)
523 .invalidated_only()
524 .deterministic()
525 .run()
526}
527
528pub fn drain_affected_sorted<'a, K>(
570 invalidated: &mut crate::InvalidationSet<K>,
571 graph: &'a InvalidationGraph<K>,
572 channel: Channel,
573) -> DrainSorted<'a, K>
574where
575 K: Copy + Eq + Hash + DenseKey,
576{
577 DrainBuilder::new(invalidated, graph, channel)
578 .affected()
579 .run()
580}
581
582pub fn drain_affected_sorted_with_trace<'a, K, T>(
593 invalidated: &mut crate::InvalidationSet<K>,
594 graph: &'a InvalidationGraph<K>,
595 channel: Channel,
596 scratch: &mut TraversalScratch<K>,
597 trace: &mut T,
598) -> DrainSorted<'a, K>
599where
600 K: Copy + Eq + Hash + DenseKey,
601 T: InvalidationTrace<K>,
602{
603 DrainBuilder::new(invalidated, graph, channel)
604 .affected()
605 .trace(scratch, trace)
606 .run()
607}
608
609pub fn drain_affected_sorted_deterministic<'a, K>(
615 invalidated: &mut crate::InvalidationSet<K>,
616 graph: &'a InvalidationGraph<K>,
617 channel: Channel,
618) -> DrainSortedDeterministic<'a, K>
619where
620 K: Copy + Eq + Hash + Ord + DenseKey,
621{
622 DrainBuilder::new(invalidated, graph, channel)
623 .affected()
624 .deterministic()
625 .run()
626}
627
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    use crate::TraversalScratch;
    use crate::graph::CycleHandling;
    use crate::set::InvalidationSet;
    use crate::trace::OneParentRecorder;

    const LAYOUT: Channel = Channel::new(0);

    #[test]
    fn topological_order_chain() {
        let mut graph = InvalidationGraph::<u32>::new();
        // Chain: 2 depends on 1, 3 on 2, 4 on 3.
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();

        let invalidated_keys = vec![4, 2, 1, 3];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();

        assert_eq!(sorted, vec![1, 2, 3, 4]);
    }

    #[test]
    fn topological_order_diamond() {
        let mut graph = InvalidationGraph::<u32>::new();
        // Diamond: 2 and 3 depend on 1; 4 depends on both 2 and 3.
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();

        let invalidated_keys = vec![4, 3, 2, 1];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();

        // The root comes first and the sink last; the middle pair may be in either order.
        assert_eq!(sorted[0], 1);
        assert_eq!(sorted[3], 4);
        assert!(sorted[1] == 2 || sorted[1] == 3);
        assert!(sorted[2] == 2 || sorted[2] == 3);
    }

    #[test]
    fn partial_invalidated_set() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();

        // Key 1 is not invalidated, so 2 has no in-set dependency.
        let invalidated_keys = vec![3, 2];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();

        assert_eq!(sorted, vec![2, 3]);
    }

    #[test]
    fn no_dependencies() {
        let graph = InvalidationGraph::<u32>::new();
        let invalidated_keys = vec![3, 1, 2];
        let cap = invalidated_keys.len();
        let mut sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();

        // Any order is valid; check that every key is yielded exactly once.
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3]);
    }

    #[test]
    fn drain_sorted_function() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();

        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);

        let sorted: Vec<_> = drain_sorted(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted, vec![1, 2]);

        // Draining removes the channel's invalidated keys from the set.
        assert!(!invalidated.has_invalidated(LAYOUT));
    }

    #[test]
    fn collect_with_completion_reports_complete_for_dag() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();

        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);

        let (sorted, completion) =
            drain_sorted(&mut invalidated, &graph, LAYOUT).collect_with_completion();
        assert_eq!(sorted, vec![1, 2]);
        assert_eq!(completion, DrainCompletion::Complete);
    }

    #[test]
    fn empty_invalidated_set() {
        let graph = InvalidationGraph::<u32>::new();
        let invalidated_keys: Vec<u32> = vec![];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert!(sorted.is_empty());
    }

    #[test]
    fn size_hint_accurate() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();

        let invalidated_keys = vec![1, 2];
        let cap = invalidated_keys.len();
        let mut drain =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT);

        assert_eq!(drain.size_hint(), (2, Some(2)));
        assert_eq!(drain.remaining(), 2);

        let _ = drain.next();
        assert_eq!(drain.size_hint(), (1, Some(1)));

        let _ = drain.next();
        assert_eq!(drain.size_hint(), (0, Some(0)));
        assert!(drain.is_empty());
    }

    #[test]
    fn duplicate_keys_deduplicated() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();

        // Repeated keys must be yielded only once.
        let invalidated_keys = vec![1, 2, 1, 2, 1];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();

        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted, vec![1, 2]);
    }

    #[test]
    fn cycles_stall_drain() {
        let mut graph = InvalidationGraph::<u32>::new();
        // Cycle: 1 -> 2 -> 3 -> 1.
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();

        let invalidated_keys = vec![1, 2, 3];
        let cap = invalidated_keys.len();
        let mut drain =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT);
        let sorted: Vec<_> = drain.by_ref().collect();

        assert!(
            sorted.is_empty(),
            "cycle should prevent any keys from being yielded"
        );
        assert!(drain.is_stalled());
        assert_eq!(
            drain.completion(),
            DrainCompletion::Stalled { remaining: 3 }
        );
    }

    #[test]
    fn cycles_stall_drain_collect_with_completion() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();

        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);
        invalidated.mark(3, LAYOUT);

        let (sorted, completion) =
            drain_sorted(&mut invalidated, &graph, LAYOUT).collect_with_completion();
        assert!(sorted.is_empty());
        assert_eq!(completion, DrainCompletion::Stalled { remaining: 3 });
    }

    #[test]
    fn deterministic_cycles_stall_drain() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();

        let invalidated_keys: Vec<u32> = vec![1, 2, 3];
        let cap = invalidated_keys.len();
        let mut drain = DrainSortedDeterministic::from_iter_with_capacity(
            invalidated_keys.into_iter(),
            cap,
            &graph,
            LAYOUT,
        );
        let sorted: Vec<_> = drain.by_ref().collect();

        assert!(sorted.is_empty());
        assert!(drain.is_stalled());
        assert_eq!(
            drain.completion(),
            DrainCompletion::Stalled { remaining: 3 }
        );
    }

    #[test]
    fn drain_affected_sorted_expands_dependents() {
        let mut graph = InvalidationGraph::<u32>::new();
        // Chain: 2 depends on 1, 3 on 2, 4 on 3.
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();

        // Only the root is marked; its transitive dependents must be included.
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);

        let sorted: Vec<_> = drain_affected_sorted(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted, vec![1, 2, 3, 4]);

        assert!(!invalidated.has_invalidated(LAYOUT));
    }

    #[test]
    fn drain_affected_sorted_multiple_roots() {
        let mut graph = InvalidationGraph::<u32>::new();
        // Two independent chains: 1 <- 2 and 3 <- 4.
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();

        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(3, LAYOUT);

        let sorted: Vec<_> = drain_affected_sorted(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted.len(), 4);

        // Each chain must respect dependency order, whatever the interleaving.
        let pos = |k: u32| sorted.iter().position(|&x| x == k).unwrap();
        assert!(pos(1) < pos(2));
        assert!(pos(3) < pos(4));

        let mut all = sorted.clone();
        all.sort_unstable();
        assert_eq!(all, vec![1, 2, 3, 4]);
    }

    #[test]
    fn deterministic_topological_order_diamond_is_total() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();

        let invalidated_keys: Vec<u32> = vec![4, 3, 2, 1];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> = DrainSortedDeterministic::from_iter_with_capacity(
            invalidated_keys.into_iter(),
            cap,
            &graph,
            LAYOUT,
        )
        .collect();

        // Ties (2 vs 3) are broken by ascending key.
        assert_eq!(sorted, vec![1, 2, 3, 4]);
    }

    #[test]
    fn deterministic_dedups_and_breaks_ties_ascending() {
        let graph = InvalidationGraph::<u32>::new();
        let invalidated_keys: Vec<u32> = vec![3, 1, 3, 2, 1];
        let cap = invalidated_keys.len();
        let mut drain = DrainSortedDeterministic::from_iter_with_capacity(
            invalidated_keys.into_iter(),
            cap,
            &graph,
            LAYOUT,
        );

        assert_eq!(drain.remaining(), 3);
        let sorted: Vec<_> = drain.by_ref().collect();
        assert_eq!(sorted, vec![1, 2, 3]);
        assert!(!drain.is_stalled());
        assert_eq!(drain.completion(), DrainCompletion::Complete);
    }

    #[test]
    #[should_panic(expected = "DenseKey index")]
    fn deterministic_drain_rejects_sparse_key_space() {
        let graph = InvalidationGraph::<usize>::new();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(usize::MAX, LAYOUT);

        let _: Vec<_> = drain_sorted_deterministic(&mut invalidated, &graph, LAYOUT).collect();
    }

    #[test]
    fn affected_sorted_with_trace_records_one_path() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();

        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);

        let mut scratch = TraversalScratch::new();
        let mut rec = OneParentRecorder::new();
        let sorted: Vec<_> = drain_affected_sorted_with_trace(
            &mut invalidated,
            &graph,
            LAYOUT,
            &mut scratch,
            &mut rec,
        )
        .collect();

        assert_eq!(sorted, vec![1, 2, 3]);
        assert_eq!(rec.explain_path(3, LAYOUT).unwrap(), vec![1, 2, 3]);
    }
}