1use crate::storage::filter_stats::FilterStats;
2use crate::{
3 DBData, DBWeight, NumEntries,
4 algebra::{NegByRef, ZRingValue},
5 circuit::checkpointer::Checkpoint,
6 dynamic::{
7 DataTrait, DynDataTyped, DynPair, DynUnit, DynVec, DynWeightedPairs, Erase, Factory,
8 LeanVec, WeightTrait, WeightTraitTyped, WithFactory,
9 },
10 trace::{
11 Batch, BatchFactories, BatchReader, BatchReaderFactories, Builder, Cursor, DbspSerializer,
12 Deserializer, Filter, GroupFilter, MergeCursor, VecKeyBatch, WeightedItem,
13 cursor::Position,
14 deserialize_wset,
15 layers::{Cursor as _, Leaf, LeafCursor, LeafFactories, Trie},
16 ord::merge_batcher::MergeBatcher,
17 serialize_wset,
18 },
19 utils::Tup2,
20};
21use crate::{DynZWeight, ZWeight};
22use itertools::{EitherOrBoth, Itertools};
23use rand::Rng;
24use rkyv::{Archive, Deserialize, Serialize};
25use size_of::SizeOf;
26use std::any::TypeId;
27use std::fmt::{self, Debug, Display};
28
29pub struct VecWSetFactories<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> {
30 pub layer_factories: LeafFactories<K, R>,
31 item_factory: &'static dyn Factory<DynPair<K, DynUnit>>,
32 weighted_item_factory: &'static dyn Factory<WeightedItem<K, DynUnit, R>>,
33 weighted_items_factory: &'static dyn Factory<DynWeightedPairs<DynPair<K, DynUnit>, R>>,
34 weighted_vals_factory: &'static dyn Factory<DynWeightedPairs<DynUnit, R>>,
35 }
37
38impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Clone for VecWSetFactories<K, R> {
39 fn clone(&self) -> Self {
40 Self {
41 layer_factories: self.layer_factories.clone(),
42 item_factory: self.item_factory,
43 weighted_item_factory: self.weighted_item_factory,
44 weighted_items_factory: self.weighted_items_factory,
45 weighted_vals_factory: self.weighted_vals_factory,
46 }
47 }
48}
49
50impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> BatchReaderFactories<K, DynUnit, (), R>
51 for VecWSetFactories<K, R>
52{
53 fn new<KType, VType, RType>() -> Self
54 where
55 KType: DBData + Erase<K>,
56 VType: DBData + Erase<DynUnit>,
57 RType: DBWeight + Erase<R>,
58 {
59 Self {
60 layer_factories: LeafFactories::new::<KType, RType>(),
61 item_factory: WithFactory::<Tup2<KType, ()>>::FACTORY,
62 weighted_item_factory: WithFactory::<Tup2<Tup2<KType, ()>, RType>>::FACTORY,
63 weighted_items_factory: WithFactory::<LeanVec<Tup2<Tup2<KType, ()>, RType>>>::FACTORY,
64 weighted_vals_factory: WithFactory::<LeanVec<Tup2<(), RType>>>::FACTORY,
65 }
66 }
67
68 fn key_factory(&self) -> &'static dyn Factory<K> {
69 self.layer_factories.key
70 }
71
72 fn keys_factory(&self) -> &'static dyn Factory<DynVec<K>> {
73 self.layer_factories.keys
74 }
75
76 fn val_factory(&self) -> &'static dyn Factory<DynUnit> {
77 WithFactory::<()>::FACTORY
78 }
79
80 fn weight_factory(&self) -> &'static dyn Factory<R> {
81 self.layer_factories.diff
82 }
83}
84
85impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> BatchFactories<K, DynUnit, (), R>
86 for VecWSetFactories<K, R>
87{
88 fn item_factory(&self) -> &'static dyn Factory<DynPair<K, DynUnit>> {
91 self.item_factory
92 }
93
94 fn weighted_item_factory(&self) -> &'static dyn Factory<WeightedItem<K, DynUnit, R>> {
95 self.weighted_item_factory
96 }
97
98 fn weighted_items_factory(
99 &self,
100 ) -> &'static dyn Factory<DynWeightedPairs<DynPair<K, DynUnit>, R>> {
101 self.weighted_items_factory
102 }
103
104 fn weighted_vals_factory(&self) -> &'static dyn Factory<DynWeightedPairs<DynUnit, R>> {
105 self.weighted_vals_factory
106 }
107
108 fn time_diffs_factory(
109 &self,
110 ) -> Option<&'static dyn Factory<DynWeightedPairs<DynDataTyped<()>, R>>> {
111 None
112 }
113}
114
115pub struct VecWSet<K, R>
117where
118 K: DataTrait + ?Sized,
119 R: WeightTrait + ?Sized,
120{
121 #[doc(hidden)]
122 pub layer: Leaf<K, R>,
123 factories: VecWSetFactories<K, R>,
124 negative_weight_count: u64,
125}
126
127impl<K, R> SizeOf for VecWSet<K, R>
128where
129 K: DataTrait + ?Sized,
130 R: WeightTrait + ?Sized,
131{
132 fn size_of_children(&self, context: &mut size_of::Context) {
133 context.add(self.approximate_byte_size());
136 }
137}
138
139impl<K, R> VecWSet<K, R>
140where
141 K: DataTrait + ?Sized,
142 R: WeightTrait + ?Sized,
143{
144 pub fn from_parts(
145 factories: VecWSetFactories<K, R>,
146 keys: Box<DynVec<K>>,
147 diffs: Box<DynVec<R>>,
148 ) -> Self {
149 Self {
150 layer: Leaf::from_parts(&factories.layer_factories, keys, diffs),
151 factories,
152 negative_weight_count: 0,
153 }
154 }
155}
156
157impl<K, R> PartialEq for VecWSet<K, R>
158where
159 K: DataTrait + ?Sized,
160 R: WeightTrait + ?Sized,
161{
162 fn eq(&self, other: &Self) -> bool {
163 self.layer == other.layer
164 }
165}
166
167impl<K, R> Checkpoint for VecWSet<K, R>
168where
169 K: DataTrait + ?Sized,
170 R: WeightTrait + ?Sized,
171{
172 fn checkpoint(&self) -> Result<Vec<u8>, crate::Error> {
173 Ok(serialize_wset(self))
174 }
175
176 fn restore(&mut self, data: &[u8]) -> Result<(), crate::Error> {
177 *self = deserialize_wset(&self.factories, data);
178 Ok(())
179 }
180}
181
182impl<K, R> Eq for VecWSet<K, R>
183where
184 K: DataTrait + ?Sized,
185 R: WeightTrait + ?Sized,
186{
187}
188
189impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Debug for VecWSet<K, R> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_struct("VecWSet")
192 .field("layer", &self.layer)
193 .finish()
194 }
195}
196
197impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Clone for VecWSet<K, R> {
198 fn clone(&self) -> Self {
199 Self {
200 layer: self.layer.clone(),
201 factories: self.factories.clone(),
202 negative_weight_count: self.negative_weight_count,
203 }
204 }
205}
206
207impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> VecWSet<K, R> {
208 #[inline]
209 pub fn len(&self) -> usize {
210 self.layer.len()
211 }
212
213 #[inline]
214 pub fn is_empty(&self) -> bool {
215 self.layer.is_empty()
216 }
217
218 }
236
237impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Deserialize<VecWSet<K, R>, Deserializer>
238 for ()
239{
240 fn deserialize(
241 &self,
242 _deserializer: &mut Deserializer,
243 ) -> Result<VecWSet<K, R>, <Deserializer as rkyv::Fallible>::Error> {
244 todo!()
245 }
246}
247
248impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Archive for VecWSet<K, R> {
249 type Archived = ();
250 type Resolver = ();
251
252 unsafe fn resolve(&self, _pos: usize, _resolver: Self::Resolver, _out: *mut Self::Archived) {
253 todo!()
254 }
255}
256impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Serialize<DbspSerializer<'_>>
257 for VecWSet<K, R>
258{
259 fn serialize(
260 &self,
261 _serializer: &mut DbspSerializer,
262 ) -> Result<Self::Resolver, <DbspSerializer<'_> as rkyv::Fallible>::Error> {
263 todo!()
264 }
265}
266
267impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Display for VecWSet<K, R> {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 writeln!(
270 f,
271 "layer:\n{}",
272 textwrap::indent(&self.layer.to_string(), " ")
273 )
274 }
275}
276
277impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> NumEntries for VecWSet<K, R> {
290 const CONST_NUM_ENTRIES: Option<usize> = <Leaf<K, R>>::CONST_NUM_ENTRIES;
291
292 fn num_entries_shallow(&self) -> usize {
293 self.layer.num_entries_shallow()
294 }
295
296 fn num_entries_deep(&self) -> usize {
297 self.layer.num_entries_deep()
298 }
299}
300
301impl<K: DataTrait + ?Sized, R: WeightTraitTyped + ?Sized> NegByRef for VecWSet<K, R>
302where
303 R::Type: DBWeight + ZRingValue + Erase<R>,
304{
305 fn neg_by_ref(&self) -> Self {
306 Self {
307 layer: self.layer.neg_by_ref(),
308 factories: self.factories.clone(),
309 negative_weight_count: self.negative_weight_count,
310 }
311 }
312}
313
314impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> BatchReader for VecWSet<K, R> {
315 type Key = K;
316 type Val = DynUnit;
317 type Time = ();
318 type R = R;
319 type Cursor<'s> = VecWSetCursor<'s, K, R>;
320 type Factories = VecWSetFactories<K, R>;
321 #[inline]
324 fn factories(&self) -> Self::Factories {
325 self.factories.clone()
326 }
327
328 #[inline]
329 fn cursor(&self) -> Self::Cursor<'_> {
330 VecWSetCursor {
331 valid: true,
332 cursor: self.layer.cursor(),
333 }
334 }
335
336 fn consuming_cursor(
337 &mut self,
338 key_filter: Option<Filter<Self::Key>>,
339 value_filter: Option<GroupFilter<Self::Val>>,
340 ) -> Box<dyn crate::trace::MergeCursor<Self::Key, Self::Val, Self::Time, Self::R> + Send + '_>
341 {
342 if key_filter.is_none() && value_filter.is_none() {
343 Box::new(VecWSetConsumingCursor::new(self))
344 } else {
345 self.merge_cursor(key_filter, value_filter)
346 }
347 }
348
349 #[inline]
358 fn key_count(&self) -> usize {
359 Trie::keys(&self.layer)
360 }
361
362 #[inline]
363 fn len(&self) -> usize {
364 self.layer.tuples()
365 }
366
367 #[inline]
368 fn approximate_byte_size(&self) -> usize {
369 self.layer.approximate_byte_size()
370 }
371
372 fn membership_filter_stats(&self) -> FilterStats {
373 FilterStats::default()
374 }
375
376 fn sample_keys<RG>(&self, rng: &mut RG, sample_size: usize, sample: &mut DynVec<Self::Key>)
377 where
378 RG: Rng,
379 {
380 self.layer.sample_keys(rng, sample_size, sample);
381 }
382
383 fn keys(&self) -> Option<&DynVec<Self::Key>> {
384 Some(&*self.layer.keys)
385 }
386}
387
388impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Batch for VecWSet<K, R> {
389 type Timed<T: crate::Timestamp> = VecKeyBatch<K, T, R>;
390 type Batcher = MergeBatcher<Self>;
391 type Builder = VecWSetBuilder<K, R>;
392
393 fn negative_weight_count(&self) -> Option<u64> {
394 Some(self.negative_weight_count)
395 }
396}
397
/// Read-only cursor over a [`VecWSet`].
///
/// Every key has exactly one unit value. `valid` records whether that value is
/// still "current" for the key under the cursor.
#[derive(Debug, SizeOf)]
pub struct VecWSetCursor<'s, K: DataTrait + ?Sized, R: WeightTrait + ?Sized> {
    // Set to `false` by `step_val`/`step_val_reverse` or by a rejecting
    // `seek_val_with*` predicate. Reset to `true` whenever the key position
    // changes or the values are rewound or fast-forwarded.
    valid: bool,
    // Cursor over the underlying leaf (key, diff) layer.
    pub(crate) cursor: LeafCursor<'s, K, R>,
}
404
405impl<K, R> Clone for VecWSetCursor<'_, K, R>
406where
407 K: DataTrait + ?Sized,
408 R: WeightTrait + ?Sized,
409{
410 fn clone(&self) -> Self {
411 Self {
412 valid: self.valid,
413 cursor: self.cursor.clone(),
414 }
415 }
416}
417
418impl<K: DataTrait + ?Sized, R: WeightTrait + ?Sized> Cursor<K, DynUnit, (), R>
419 for VecWSetCursor<'_, K, R>
420{
421 fn weight_factory(&self) -> &'static dyn Factory<R> {
430 self.cursor.storage.factories.diff
431 }
432
433 fn key(&self) -> &K {
434 self.cursor.current_key()
435 }
436
437 fn val(&self) -> &DynUnit {
438 &()
439 }
440
441 fn map_times(&mut self, logic: &mut dyn FnMut(&(), &R)) {
442 if self.cursor.valid() {
443 logic(&(), self.cursor.current_diff())
444 }
445 }
446
447 fn map_times_through(&mut self, _upper: &(), logic: &mut dyn FnMut(&(), &R)) {
448 self.map_times(logic)
449 }
450
451 fn weight(&mut self) -> &R {
452 debug_assert!(&self.cursor.valid());
453 self.cursor.current_diff()
454 }
455
456 fn weight_checked(&mut self) -> &R {
457 self.weight()
458 }
459
460 fn map_values(&mut self, logic: &mut dyn FnMut(&DynUnit, &R)) {
461 if self.val_valid() {
462 logic(&(), self.cursor.current_diff())
463 }
464 }
465
466 fn key_valid(&self) -> bool {
467 self.cursor.valid()
468 }
469
470 fn val_valid(&self) -> bool {
471 self.valid
472 }
473
474 fn step_key(&mut self) {
475 self.cursor.step();
476 self.valid = true;
477 }
478
479 fn step_key_reverse(&mut self) {
480 self.cursor.step_reverse();
481 self.valid = true;
482 }
483
484 fn seek_key(&mut self, key: &K) {
485 self.cursor.seek(key);
486 self.valid = true;
487 }
488
489 fn seek_key_exact(&mut self, key: &K, _hash: Option<u64>) -> bool {
490 self.seek_key(key);
491 self.key_valid() && self.key().eq(key)
492 }
493
494 fn seek_key_with(&mut self, predicate: &dyn Fn(&K) -> bool) {
495 self.cursor.seek_key_with(predicate);
496 self.valid = true;
497 }
498
499 fn seek_key_with_reverse(&mut self, predicate: &dyn Fn(&K) -> bool) {
500 self.cursor.seek_key_with_reverse(predicate);
501 self.valid = true;
502 }
503
504 fn seek_key_reverse(&mut self, key: &K) {
505 self.cursor.seek_reverse(key);
506 self.valid = true;
507 }
508
509 fn step_val(&mut self) {
510 self.valid = false;
511 }
512
513 fn seek_val(&mut self, _val: &DynUnit) {}
514
515 fn seek_val_with(&mut self, predicate: &dyn Fn(&DynUnit) -> bool) {
516 if !predicate(&()) {
517 self.valid = false;
518 }
519 }
520
521 fn rewind_keys(&mut self) {
522 self.cursor.rewind();
523 self.valid = true;
524 }
525
526 fn fast_forward_keys(&mut self) {
527 self.cursor.fast_forward();
528 self.valid = true;
529 }
530
531 fn rewind_vals(&mut self) {
532 self.valid = true;
533 }
534
535 fn step_val_reverse(&mut self) {
536 self.valid = false;
537 }
538
539 fn seek_val_reverse(&mut self, _val: &DynUnit) {}
540
541 fn seek_val_with_reverse(&mut self, predicate: &dyn Fn(&DynUnit) -> bool) {
542 if !predicate(&()) {
543 self.valid = false;
544 }
545 }
546
547 fn fast_forward_vals(&mut self) {
548 self.valid = true;
549 }
550
551 fn position(&self) -> Option<Position> {
552 Some(Position {
553 total: self.cursor.keys() as u64,
554 offset: self.cursor.position() as u64,
555 })
556 }
557}
558
/// Builder for [`VecWSet`].
///
/// Tuples are pushed as diff, then value, then key (see the debug assertions
/// in `push_val`, `pushed_diff`, and `pushed_key`).
#[derive(SizeOf)]
pub struct VecWSetBuilder<K, R>
where
    K: DataTrait + ?Sized,
    R: WeightTrait + ?Sized,
{
    // Excluded from `SizeOf` accounting.
    #[size_of(skip)]
    factories: VecWSetFactories<K, R>,
    // Keys pushed so far; one per completed tuple.
    keys: Box<DynVec<K>>,
    // `true` after `push_val` until the next key completes the tuple.
    val: bool,
    // Diffs pushed so far; may be one ahead of `keys` while a tuple is in
    // progress.
    diffs: Box<DynVec<R>>,
    // Count of pushed weights that are negative (tracked only when `R` is
    // `DynZWeight`, see `update_total_weight`).
    negative_weight_count: u64,
}
573
574impl<K, R> VecWSetBuilder<K, R>
575where
576 K: DataTrait + ?Sized,
577 R: WeightTrait + ?Sized,
578{
579 fn pushed_key(&mut self) {
580 #[cfg(debug_assertions)]
581 {
582 debug_assert!(self.val, "every key must have exactly one value");
583 debug_assert_eq!(
584 self.keys.len(),
585 self.diffs.len(),
586 "every key must have exactly one diff"
587 );
588 }
589 self.val = false;
590
591 debug_assert!(
592 {
593 let n = self.keys.len();
594 n == 1 || self.keys[n - 2] < self.keys[n - 1]
595 },
596 "keys must be strictly monotonically increasing but {:?} >= {:?}",
597 &self.keys[self.keys.len() - 2],
598 &self.keys[self.keys.len() - 1]
599 );
600 }
601
602 fn pushed_diff(&self) {
603 #[cfg(debug_assertions)]
604 debug_assert!(!self.val, "every val must have exactly one key");
605 debug_assert_eq!(
606 self.keys.len() + 1,
607 self.diffs.len(),
608 "every diff must have exactly one key"
609 );
610 }
611
612 pub fn copy_to_builder<B, BO>(&self, dst: &mut B)
618 where
619 B: Builder<BO>,
620 BO: Batch<Key = K, Val = DynUnit, R = R, Time = ()>,
621 {
622 for key_diff in self.keys.dyn_iter().zip_longest(self.diffs.dyn_iter()) {
623 match key_diff {
624 EitherOrBoth::Both(key, diff) => {
625 dst.push_val_diff(&(), diff);
626 dst.push_key(key);
627 }
628 EitherOrBoth::Left(_) => unreachable!(),
629 EitherOrBoth::Right(diff) => {
630 dst.push_diff(diff);
631 if self.val {
632 dst.push_val(&());
633 }
634 }
635 }
636 }
637 }
638
639 fn update_total_weight(&mut self, weight: &R) {
640 if TypeId::of::<R>() == TypeId::of::<DynZWeight>() {
641 let weight = unsafe { weight.downcast::<ZWeight>() };
642 if !weight.ge0() {
643 self.negative_weight_count += 1;
644 }
645 }
646 }
647}
648
649impl<K, R> Builder<VecWSet<K, R>> for VecWSetBuilder<K, R>
650where
651 Self: SizeOf,
652 K: DataTrait + ?Sized,
653 R: WeightTrait + ?Sized,
654{
655 fn with_capacity(
656 factories: &VecWSetFactories<K, R>,
657 key_capacity: usize,
658 _value_capacity: usize,
659 ) -> Self {
660 let mut keys = factories.layer_factories.keys.default_box();
661 keys.reserve_exact(key_capacity);
662
663 let mut diffs = factories.layer_factories.diffs.default_box();
664 diffs.reserve_exact(key_capacity);
665 Self {
666 factories: factories.clone(),
667 keys,
668 val: false,
669 diffs,
670 negative_weight_count: 0,
671 }
672 }
673
674 fn reserve(&mut self, additional: usize) {
675 self.keys.reserve(additional);
676 self.diffs.reserve(additional);
677 }
678
679 fn push_key(&mut self, key: &K) {
680 self.keys.push_ref(key);
681 self.pushed_key();
682 }
683
684 fn push_key_mut(&mut self, key: &mut K) {
685 self.keys.push_val(key);
686 self.pushed_key();
687 }
688
689 fn push_val(&mut self, _val: &DynUnit) {
690 #[cfg(debug_assertions)]
691 {
692 debug_assert!(!self.val);
693 debug_assert_eq!(
694 self.diffs.len(),
695 self.keys.len() + 1,
696 "every value must have exactly one diff"
697 );
698 }
699
700 self.val = true;
701 }
702
703 fn push_time_diff(&mut self, _time: &(), weight: &R) {
704 debug_assert!(!weight.is_zero());
705 self.update_total_weight(weight);
706 self.diffs.push_ref(weight);
707 self.pushed_diff();
708 }
709
710 fn push_time_diff_mut(&mut self, _time: &mut (), weight: &mut R) {
711 debug_assert!(!weight.is_zero());
712 self.update_total_weight(weight);
713 self.diffs.push_val(weight);
714 self.pushed_diff();
715 }
716
717 fn done(self) -> VecWSet<K, R> {
718 debug_assert_eq!(self.keys.len(), self.diffs.len());
719 VecWSet {
720 layer: Leaf::from_parts(&self.factories.layer_factories, self.keys, self.diffs),
721 factories: self.factories,
722 negative_weight_count: self.negative_weight_count,
723 }
724 }
725
726 fn num_keys(&self) -> usize {
727 self.keys.len()
728 }
729
730 fn num_tuples(&self) -> usize {
731 self.diffs.len()
732 }
733}
734
735struct VecWSetConsumingCursor<'a, K, R>
737where
738 K: DataTrait + ?Sized,
739 R: WeightTrait + ?Sized,
740{
741 wset: &'a mut VecWSet<K, R>,
742 index: usize,
743 val_valid: bool,
744 value: Box<DynUnit>,
745}
746
747impl<'a, K, R> VecWSetConsumingCursor<'a, K, R>
748where
749 K: DataTrait + ?Sized,
750 R: WeightTrait + ?Sized,
751{
752 fn new(wset: &'a mut VecWSet<K, R>) -> Self {
753 let val_valid = !wset.is_empty();
754 let value = wset.factories.val_factory().default_box();
755 Self {
756 wset,
757 index: 0,
758 val_valid,
759 value,
760 }
761 }
762}
763
764impl<K, R> MergeCursor<K, DynUnit, (), R> for VecWSetConsumingCursor<'_, K, R>
765where
766 K: DataTrait + ?Sized,
767 R: WeightTrait + ?Sized,
768{
769 fn key_valid(&self) -> bool {
770 self.index < self.wset.layer.keys.len()
771 }
772 fn val_valid(&self) -> bool {
773 self.val_valid
774 }
775 fn key(&self) -> &K {
776 self.wset.layer.keys.index(self.index)
777 }
778
779 fn val(&self) -> &DynUnit {
780 ().erase()
781 }
782
783 fn map_times(&mut self, logic: &mut dyn FnMut(&(), &R)) {
784 logic(&(), &self.wset.layer.diffs[self.index])
785 }
786
787 fn weight(&mut self) -> &R {
788 &self.wset.layer.diffs[self.index]
789 }
790
791 fn has_mut(&self) -> bool {
792 true
793 }
794
795 fn key_mut(&mut self) -> &mut K {
796 &mut self.wset.layer.keys[self.index]
797 }
798
799 fn val_mut(&mut self) -> &mut DynUnit {
800 &mut *self.value
801 }
802
803 fn weight_mut(&mut self) -> &mut R {
804 &mut self.wset.layer.diffs[self.index]
805 }
806
807 fn step_key(&mut self) {
808 self.index += 1;
809 self.val_valid = self.key_valid();
810 }
811
812 fn step_val(&mut self) {
813 self.val_valid = false;
814 }
815}
816
817