1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
pub(crate) use maybe_changed_after::VerifyResult;
pub(crate) use sync::{ClaimGuard, ClaimResult, Reentrancy, SyncGuard, SyncOwner, SyncTable};
use std::any::Any;
use std::fmt;
use std::ptr::NonNull;
use std::sync::OnceLock;
use std::sync::atomic::Ordering;
use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus};
use crate::database::RawDatabase;
use crate::function::delete::DeletedEntries;
use crate::hash::{FxHashSet, FxIndexSet};
use crate::ingredient::{Ingredient, WaitForResult};
use crate::key::DatabaseKeyIndex;
use crate::plumbing::{self, MemoIngredientMap};
use crate::salsa_struct::SalsaStructInDb;
use crate::sync::Arc;
use crate::table::Table;
use crate::table::memo::MemoTableTypes;
use crate::views::DatabaseDownCaster;
use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa};
use crate::zalsa_local::{QueryEdge, QueryOriginRef};
use crate::{Cycle, Id, Revision};
#[cfg(feature = "accumulator")]
mod accumulated;
mod backdate;
mod delete;
mod diff_outputs;
mod eviction;
mod execute;
mod fetch;
mod inputs;
mod maybe_changed_after;
mod memo;
mod specify;
mod sync;
pub use eviction::{EvictionPolicy, HasCapacity, Lru, NoopEviction};
/// Alias for [`memo::Memo`] with its lifetime parameter fixed to `'static`.
pub type Memo<C> = memo::Memo<'static, C>;
/// Compile-time configuration of a function ingredient.
///
/// Supplies the associated types, constants and callbacks that [`IngredientImpl`] is generic over.
pub trait Configuration: Any {
/// Name returned by the ingredient's `debug_name`.
const DEBUG_NAME: &'static str;
/// Location value returned (by reference) from the ingredient's `location`.
const LOCATION: crate::ingredient::Location;
/// Whether the ingredient reports itself as persistable; when `false`, the ingredient's
/// `should_serialize` always returns `false`.
const PERSIST: bool;
/// The database that this function is associated with.
type DbView: ?Sized + crate::Database;
/// The "salsa struct type" that this function is associated with.
/// This can be just `salsa::Id` for functions that intern their arguments
/// and are not clearly associated with any one salsa struct.
type SalsaStruct<'db>: SalsaStructInDb;
/// The input to the function.
type Input<'db>: Send + Sync;
/// The value computed by the function.
type Output<'db>: Send + Sync;
/// The eviction policy for this function's memoized values.
type Eviction: EvictionPolicy;
/// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how).
const CYCLE_STRATEGY: CycleRecoveryStrategy;
/// Invoked after a new result `new_value` has been computed for which an older memoized value
/// `old_value` existed, or in fixpoint iteration. Returns true if the new value is equal to
/// the older one.
///
/// This invokes user code in the form of the `Eq` impl.
fn values_equal<'db>(old_value: &Self::Output<'db>, new_value: &Self::Output<'db>) -> bool;
/// Convert from the id used internally to the value that execute is expecting.
/// This is a no-op if the input to the function is a salsa struct.
fn id_to_input(zalsa: &Zalsa, key: Id) -> Self::Input<'_>;
/// Returns the size of any heap allocations in the output value, in bytes.
///
/// The default implementation returns `None`.
fn heap_size(_value: &Self::Output<'_>) -> Option<usize> {
None
}
/// Invoked when we need to compute the value for the given key, either because we've never
/// computed it before or because the old one relied on inputs that have changed.
///
/// This invokes the function the user wrote.
fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;
/// Get the cycle recovery initial value.
fn cycle_initial<'db>(
db: &'db Self::DbView,
id: Id,
input: Self::Input<'db>,
) -> Self::Output<'db>;
/// Decide what value to use for this cycle iteration. Takes ownership of the new value
/// and returns an owned value to use.
///
/// The function is called for every iteration of the cycle head, regardless of whether the cycle
/// has converged (the values are equal).
///
/// NOTE(review): the sections below talk about an `id` and an `iteration` value, but this
/// signature has no such arguments; presumably both are exposed through `cycle` — confirm
/// against the `Cycle` type.
///
/// # Id
///
/// The id can be used to uniquely identify the query instance. This can be helpful
/// if the cycle function has to re-identify a value it returned previously.
///
/// # Values
///
/// The `last_provisional_value` is the value from the previous iteration of this cycle
/// and `value` is the new value that was computed in the current iteration.
///
/// # Iteration count
///
/// The `iteration` parameter isn't guaranteed to start from zero or to be contiguous:
///
/// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that
/// query becomes the outermost cycle head after a nested cycle completes a few iterations. In this case,
/// `iteration` continues from the nested cycle's iteration count rather than resetting to zero.
/// * **Non-contiguous values**: The iteration count can be non-contiguous for cycle heads
/// that are only conditionally part of a cycle.
///
/// # Return value
///
/// The function should return the value to use for this iteration. This can be the `value`
/// that was computed, or a different value (e.g., a fallback value). This cycle will continue
/// iterating until the returned value equals the previous iteration's value.
fn recover_from_cycle<'db>(
db: &'db Self::DbView,
cycle: &Cycle,
last_provisional_value: &Self::Output<'db>,
value: Self::Output<'db>,
input: Self::Input<'db>,
) -> Self::Output<'db>;
/// Serialize the output type using `serde`.
///
/// # Panics
///
/// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`.
fn serialize<S>(value: &Self::Output<'_>, serializer: S) -> Result<S::Ok, S::Error>
where
S: plumbing::serde::Serializer;
/// Deserialize the output type using `serde`.
///
/// # Panics
///
/// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`.
fn deserialize<'de, D>(deserializer: D) -> Result<Self::Output<'static>, D::Error>
where
D: plumbing::serde::Deserializer<'de>;
}
/// Function ingredients are the "workhorse" of salsa.
///
/// They are used for tracked functions, for the "value" fields of tracked structs, and for the fields of input structs.
/// The function ingredient is fairly complex and so its code is spread across multiple modules, typically one per method.
/// The main entry points are:
///
/// * the `fetch` method, which is invoked when the function is called by the user's code;
/// it will return a memoized value if one exists, or execute the function otherwise.
/// * the `specify` method, which can only be used when the key is an entity created by the active query.
/// It sets the value of the function imperatively, so that when later fetches occur, they'll return this value.
/// * the `store` method, which can only be invoked with an `&mut` reference, and is to set input fields.
pub struct IngredientImpl<C: Configuration> {
/// The ingredient index we were assigned in the database.
/// Used to construct `DatabaseKeyIndex` values.
index: IngredientIndex,
/// The index for the memo/sync tables.
///
/// This may be a [`crate::memo_ingredient_indices::MemoIngredientSingletonIndex`] or a
/// [`crate::memo_ingredient_indices::MemoIngredientIndices`], depending on whether the
/// tracked function's struct is a plain salsa struct or an enum `#[derive(Supertype)]`.
memo_ingredient_indices: <C::SalsaStruct<'static> as SalsaStructInDb>::MemoIngredientMap,
/// Eviction policy - type determined by Configuration.
/// Used to find memos to throw out when we have too many memoized values.
eviction: C::Eviction,
/// A downcaster to `C::DbView`. Set lazily through `get_or_init`.
///
/// # Safety
///
/// The supplied database must be the same as the database used to construct the [`Views`]
/// instances that this downcaster was derived from.
view_caster: OnceLock<DatabaseDownCaster<C::DbView>>,
/// Table through which keys of this ingredient are claimed
/// (used by `wait_for` and `mark_as_transfer_target`).
sync_table: SyncTable,
/// When `fetch` and friends execute, they return a reference to the
/// value stored in the memo that is extended to live as long as the `&self`
/// reference we start with. This means that whenever we remove something
/// from `memo_map` with an `&self` reference, there *could* be references to its
/// internals still in use. Therefore we push the memo into this queue and
/// only *actually* free up memory when a new revision starts (which means
/// we have an `&mut` reference to self).
///
/// You might think that we could do this only if the memo was verified in the
/// current revision: you would be right, but we are being defensive, because
/// we don't know that we can trust the database to give us the same runtime
/// every time and so forth.
deleted_entries: DeletedEntries<C>,
}
impl<C> IngredientImpl<C>
where
C: Configuration,
{
/// Builds the ingredient registered under `index`.
///
/// `memo_ingredient_indices` is stored unchanged and `eviction_capacity` is handed to the
/// eviction policy's constructor. The view caster starts out unset; it is installed later
/// through `get_or_init`.
pub fn new(
    index: IngredientIndex,
    memo_ingredient_indices: <C::SalsaStruct<'static> as SalsaStructInDb>::MemoIngredientMap,
    eviction_capacity: usize,
) -> Self {
    let eviction = C::Eviction::new(eviction_capacity);
    let deleted_entries = Default::default();
    let view_caster = OnceLock::new();
    let sync_table = SyncTable::new(index);

    Self {
        index,
        memo_ingredient_indices,
        eviction,
        view_caster,
        sync_table,
        deleted_entries,
    }
}
/// Installs the view-caster for this tracked function ingredient unless one has already
/// been set, then hands back `self` so the call can be chained.
#[inline]
pub fn get_or_init(
    &self,
    view_caster: impl FnOnce() -> DatabaseDownCaster<C::DbView>,
) -> &Self {
    // The caster has to be set lazily: the database type is not available at the point
    // where ingredients are registered into the `Zalsa`.
    let _ = self.view_caster.get_or_init(view_caster);
    self
}
/// Pairs this ingredient's index with `key` to form a `DatabaseKeyIndex`.
#[inline]
pub fn database_key_index(&self, key: Id) -> DatabaseKeyIndex {
    let ingredient_index = self.index;
    DatabaseKeyIndex::new(ingredient_index, key)
}
/// Forwards a new capacity to the eviction policy.
///
/// Only callable when the configured eviction policy implements `HasCapacity`.
pub fn set_capacity(&mut self, capacity: usize)
where
    C::Eviction: HasCapacity,
{
    let policy = &mut self.eviction;
    policy.set_capacity(capacity);
}
/// Returns a reference to the memo value that lives as long as self.
///
/// # Safety
///
/// The caller is responsible for ensuring that the
/// memo will not be released so long as the `&self` is valid.
/// This is done by (a) ensuring the memo is present in the memo-map
/// when this function is called and (b) ensuring that any entries
/// removed from the memo-map are added to `deleted_entries`, which is
/// only cleared with `&mut self`.
unsafe fn extend_memo_lifetime<'this>(
&'this self,
memo: &memo::Memo<'this, C>,
) -> &'this memo::Memo<'this, C> {
// SAFETY: the caller must guarantee that the memo will not be released before `&self`
unsafe { std::mem::transmute(memo) }
}
/// Boxes `memo`, installs it in the memo table slot for `id` / `memo_ingredient_index`,
/// and returns a reference to the stored memo.
///
/// A memo previously stored in that slot is not freed here; it is pushed onto
/// `deleted_entries` instead.
fn insert_memo<'db>(
&'db self,
zalsa: &'db Zalsa,
id: Id,
mut memo: memo::Memo<'db, C>,
memo_ingredient_index: MemoIngredientIndex,
) -> &'db memo::Memo<'db, C> {
// Release spare capacity of the tracked-struct id list (if any) before the memo is stored.
if let Some(tracked_struct_ids) = memo.revisions.tracked_struct_ids_mut() {
tracked_struct_ids.shrink_to_fit();
}
// We convert to a `NonNull` here as soon as possible because we are going to alias
// into the `Box`, which is a `noalias` type.
// FIXME: Use `Box::into_non_null` once stable
let memo = NonNull::from(Box::leak(Box::new(memo)));
if let Some(old_value) =
self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index)
{
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
//
// SAFETY: Once the revision starts, there will be no outstanding borrows to the
// memo contents, and so it will be safe to free.
unsafe { self.deleted_entries.push(old_value) };
}
// SAFETY: memo has been inserted into the table
unsafe { self.extend_memo_lifetime(memo.as_ref()) }
}
/// Looks up the memo ingredient index that applies to `id` in the configured index map.
#[inline]
fn memo_ingredient_index(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex {
    let index_map = &self.memo_ingredient_indices;
    index_map.get_zalsa_id(zalsa, id)
}
/// Returns the installed view caster.
///
/// # Panics
///
/// Panics if no caster has been installed yet (see `get_or_init`).
fn view_caster(&self) -> &DatabaseDownCaster<C::DbView> {
    let installed = self.view_caster.get();
    installed.expect("tracked function ingredients cannot be accessed before calling `init`")
}
}
impl<C> Ingredient for IngredientImpl<C>
where
C: Configuration,
{
/// Returns a reference to the configuration's `LOCATION` constant.
fn location(&self) -> &'static crate::ingredient::Location {
&C::LOCATION
}
/// Returns the ingredient index stored at construction time.
fn ingredient_index(&self) -> IngredientIndex {
self.index
}
/// Downcasts `db` to `C::DbView` and forwards to the three-argument
/// `maybe_changed_after(db, input, revision)` on `self` (not defined in this block).
///
/// # Safety
///
/// `db` must belong to this ingredient; the downcast is unchecked.
unsafe fn maybe_changed_after(
&self,
_zalsa: &Zalsa,
db: RawDatabase<'_>,
input: Id,
revision: Revision,
) -> VerifyResult {
// SAFETY: The `db` belongs to the ingredient as per caller invariant
let db = unsafe { self.view_caster().downcast_unchecked(db) };
self.maybe_changed_after(db, input, revision)
}
/// Walks the dependency edges recorded in the memo that `edge` points at and asks each
/// dependency's ingredient to do the same, recursively.
///
/// `edge` is recorded in `visited_edges` first. Dependency edges that are already in
/// `visited_edges` or in `serialized_edges` are skipped. Does nothing when there is no
/// memo for the key.
fn collect_minimum_serialized_edges(
    &self,
    zalsa: &Zalsa,
    edge: QueryEdge,
    serialized_edges: &mut FxIndexSet<QueryEdge>,
    visited_edges: &mut FxHashSet<QueryEdge>,
) {
    let key = edge.key().key_index();
    let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);

    if let Some(memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) {
        let origin = memo.revisions.origin.as_ref();
        visited_edges.insert(edge);

        // Collect the minimum dependency tree.
        for dependency_edge in origin.edges() {
            // Skip edges we have already walked (avoids cycles) as well as edges that
            // are going to be serialized directly (no need to flatten those).
            let skip = visited_edges.contains(dependency_edge)
                || serialized_edges.contains(dependency_edge);
            if skip {
                continue;
            }

            zalsa
                .lookup_ingredient(dependency_edge.key().ingredient_index())
                .collect_minimum_serialized_edges(
                    zalsa,
                    *dependency_edge,
                    serialized_edges,
                    visited_edges,
                );
        }
    }
}
/// Reports whether the memo stored for `input` is final or still provisional.
///
/// The status is `Final` when the memo's `verified_final` flag is set and `Provisional`
/// otherwise (in which case the memo's cycle heads are included). Both variants carry the
/// memo's iteration and its `verified_at` value. Returns `None` when no memo is stored.
fn provisional_status<'db>(
    &self,
    zalsa: &'db Zalsa,
    input: Id,
) -> Option<ProvisionalStatus<'db>> {
    let memo_ingredient_index = self.memo_ingredient_index(zalsa, input);
    let memo = self.get_memo_from_table_for(zalsa, input, memo_ingredient_index)?;

    let iteration = memo.revisions.iteration();
    let is_final = memo.revisions.verified_final.load(Ordering::Relaxed);
    let verified_at = memo.verified_at.load();

    let status = if is_final {
        ProvisionalStatus::Final {
            iteration,
            verified_at,
        }
    } else {
        ProvisionalStatus::Provisional {
            iteration,
            verified_at,
            cycle_heads: memo.cycle_heads(),
        }
    };
    Some(status)
}
fn set_cycle_iteration_count(&self, zalsa: &Zalsa, input: Id, iteration_count: IterationCount) {
let Some(memo) =
self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input))
else {
return;
};
memo.revisions
.set_iteration_count(Self::database_key_index(self, input), iteration_count);
}
/// Sets the `verified_final` flag on the memo for `input`.
/// Does nothing when no memo is stored for `input`.
fn finalize_cycle_head(&self, zalsa: &Zalsa, input: Id) {
    let memo_ingredient_index = self.memo_ingredient_index(zalsa, input);
    if let Some(memo) = self.get_memo_from_table_for(zalsa, input, memo_ingredient_index) {
        memo.revisions.verified_final.store(true, Ordering::Release);
    }
}
/// Adds the flattened dependencies of the query `id` to `flattened_input_outputs`.
///
/// * No memo stored: nothing happens.
/// * Memo is not provisional: the query itself is added as an input edge.
/// * Query already in `seen`: nothing happens.
/// * Otherwise the memo's inputs are either added directly (for configurations with a
///   cycle-handling strategy) or flattened recursively through their own ingredients.
fn flatten_cycle_head_dependencies(
    &self,
    zalsa: &Zalsa,
    id: Id,
    flattened_input_outputs: &mut FxIndexSet<QueryEdge>,
    seen: &mut FxHashSet<DatabaseKeyIndex>,
) {
    let memo_ingredient_index = self.memo_ingredient_index(zalsa, id);
    let memo = match self.get_memo_from_table_for(zalsa, id, memo_ingredient_index) {
        Some(memo) => memo,
        None => return,
    };

    let this_query = self.database_key_index(id);

    // Only provisional queries contain cyclic dependencies, so anything else is
    // recorded as a plain input and not flattened further.
    if !memo.may_be_provisional() {
        flattened_input_outputs.insert(QueryEdge::input(this_query));
        return;
    }

    // A query that was visited before has nothing left to contribute.
    let first_visit = seen.insert(this_query);
    if !first_visit {
        return;
    }

    let inputs = memo.revisions.origin.as_ref().inputs();

    match C::CYCLE_STRATEGY {
        // Regular queries: recurse into every input.
        CycleRecoveryStrategy::Panic => {
            for dependency in inputs {
                zalsa
                    .lookup_ingredient(dependency.ingredient_index())
                    .flatten_cycle_head_dependencies(
                        zalsa,
                        dependency.key_index(),
                        flattened_input_outputs,
                        seen,
                    );
            }
        }
        // Queries with cycle handling already flattened their own dependencies when
        // they completed, so their inputs can be taken over as they are.
        CycleRecoveryStrategy::FallbackImmediate | CycleRecoveryStrategy::Fixpoint => {
            flattened_input_outputs.extend(inputs.map(QueryEdge::input));
        }
    }
}
/// Returns the `cycle_converged` state recorded on the memo for `input`,
/// or `true` when no memo is stored for `input`.
fn cycle_converged(&self, zalsa: &Zalsa, input: Id) -> bool {
    let memo_ingredient_index = self.memo_ingredient_index(zalsa, input);
    match self.get_memo_from_table_for(zalsa, input, memo_ingredient_index) {
        Some(memo) => memo.revisions.cycle_converged(),
        None => true,
    }
}
/// Forwards to the sync table's `mark_as_transfer_target` for `key_index` and returns its result.
fn mark_as_transfer_target(&self, key_index: Id) -> Option<SyncOwner> {
self.sync_table.mark_as_transfer_target(key_index)
}
/// Attempts to claim `key_index` without blocking.
///
/// * [`WaitForResult::Running`]: the `key_index` is running on another thread. It's up to
///   the caller to block on the other thread until the result becomes available.
/// * [`WaitForResult::Available`]: it is (or at least was) possible to claim the `key_index`.
/// * [`WaitForResult::Cycle`]: claiming the `key_index` results in a cycle because it's on
///   the current thread's query stack or running on another thread that is blocked on this
///   thread.
fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> {
    let claim = self
        .sync_table
        .peek_claim(zalsa, key_index, Reentrancy::Deny);

    match claim {
        ClaimResult::Claimed(()) => WaitForResult::Available,
        ClaimResult::Cycle { inner } => WaitForResult::Cycle { inner },
        ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on),
    }
}
/// Returns the origin recorded for `key`, if any, by delegating to `self.origin`.
///
/// NOTE(review): presumably this resolves to an inherent `origin` method defined in another
/// module (inherent methods take precedence over trait methods); otherwise the call would
/// recurse — confirm.
fn origin<'db>(&self, zalsa: &'db Zalsa, key: Id) -> Option<QueryOriginRef<'db>> {
self.origin(zalsa, key)
}
/// Forwards to `validate_specified_value` for the output `output_key` of `executor`.
fn mark_validated_output(
&self,
zalsa: &Zalsa,
executor: DatabaseKeyIndex,
output_key: crate::Id,
) {
self.validate_specified_value(zalsa, executor, output_key);
}
/// Intentionally a no-op: all arguments are ignored and the stale memo is left in place.
fn remove_stale_output(
&self,
_zalsa: &Zalsa,
_executor: DatabaseKeyIndex,
_stale_output_key: crate::Id,
) {
// This function is invoked when a query Q specifies the value for `stale_output_key` in rev 1,
// but not in rev 2. We don't do anything in this case, we just leave the (now stale) memo.
// Since its `verified_at` field has not changed, it will be considered dirty if it is invoked.
}
/// Always `true` for function ingredients; see `reset_for_new_revision` for the work done.
fn requires_reset_for_new_revision(&self) -> bool {
true
}
/// Performs the per-revision cleanup of this ingredient.
///
/// For every id the eviction policy reports as evicted, the value of the corresponding
/// memo is evicted from `table`. Afterwards the queue of deleted memos is cleared.
fn reset_for_new_revision(&mut self, table: &mut Table) {
    let memo_indices = &self.memo_ingredient_indices;

    self.eviction.for_each_evicted(|evicted| {
        let ingredient_index = table.ingredient_index(evicted);
        let memos = table.memos_mut(evicted);
        Self::evict_value_from_memo_for(memos, memo_indices.get(ingredient_index))
    });

    self.deleted_entries.clear();
}
/// Returns the configuration's `DEBUG_NAME`.
fn debug_name(&self) -> &'static str {
C::DEBUG_NAME
}
/// Function ingredients always report `JarKind::TrackedFn`.
fn jar_kind(&self) -> JarKind {
JarKind::TrackedFn
}
/// Not supported for function ingredients.
///
/// # Panics
///
/// Always panics via `unreachable!`.
fn memo_table_types(&self) -> &Arc<MemoTableTypes> {
unreachable!("function does not allocate pages")
}
/// Not supported for function ingredients.
///
/// # Panics
///
/// Always panics via `unreachable!`.
fn memo_table_types_mut(&mut self) -> &mut Arc<MemoTableTypes> {
unreachable!("function does not allocate pages")
}
/// Downcasts `db` to `C::DbView` and returns the result of `accumulated_map` for `key_index`.
///
/// Only compiled with the `accumulator` feature.
///
/// # Safety
///
/// `db` must belong to this ingredient; the downcast is unchecked.
#[cfg(feature = "accumulator")]
unsafe fn accumulated<'db>(
&'db self,
db: RawDatabase<'db>,
key_index: Id,
) -> (
Option<&'db crate::accumulator::accumulated_map::AccumulatedMap>,
crate::accumulator::accumulated_map::InputAccumulatedValues,
) {
// SAFETY: The `db` belongs to the ingredient as per caller invariant
let db = unsafe { self.view_caster().downcast_unchecked(db) };
self.accumulated_map(db, key_index)
}
/// Returns the configuration's `PERSIST` flag.
fn is_persistable(&self) -> bool {
C::PERSIST
}
/// Returns `true` when the configuration is persistable and at least one entry of the
/// associated salsa struct has a memo for this function that wants to be serialized.
fn should_serialize(&self, zalsa: &Zalsa) -> bool {
    if !C::PERSIST {
        return false;
    }

    // The query is only worth serializing if some memo is associated with it.
    <C::SalsaStruct<'_> as SalsaStructInDb>::entries(zalsa).any(|entry| {
        let memo_ingredient_index = self.memo_ingredient_indices.get(entry.ingredient_index());
        self.get_memo_from_table_for(zalsa, entry.key_index(), memo_ingredient_index)
            .is_some_and(|memo| memo.should_serialize())
    })
}
/// Hands a serializable view of this ingredient's memos to the callback `f`.
///
/// Only compiled with the `persistence` feature.
#[cfg(feature = "persistence")]
unsafe fn serialize<'db>(
    &'db self,
    zalsa: &'db Zalsa,
    f: &mut dyn FnMut(&dyn erased_serde::Serialize),
) {
    let view = persistence::SerializeIngredient {
        zalsa,
        ingredient: self,
    };
    f(&view)
}
/// Reads memos for this ingredient from `deserializer`, driving deserialization with a
/// `persistence::DeserializeIngredient` seed.
///
/// Only compiled with the `persistence` feature.
#[cfg(feature = "persistence")]
fn deserialize(
    &mut self,
    zalsa: &mut Zalsa,
    deserializer: &mut dyn erased_serde::Deserializer,
) -> Result<(), erased_serde::Error> {
    let seed = persistence::DeserializeIngredient {
        zalsa,
        ingredient: self,
    };

    serde::de::DeserializeSeed::deserialize(seed, deserializer)
}
}
/// Debug output consists of the ingredient's full type name and its `index` field.
impl<C: Configuration> std::fmt::Debug for IngredientImpl<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = std::any::type_name::<Self>();
        let mut builder = f.debug_struct(type_name);
        builder.field("index", &self.index);
        builder.finish()
    }
}
#[cfg(feature = "persistence")]
mod persistence {
use super::{Configuration, IngredientImpl, Memo};
use crate::hash::{FxHashSet, FxIndexSet};
use crate::plumbing::{MemoIngredientMap, SalsaStructInDb};
use crate::zalsa::Zalsa;
use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryOriginRef};
use crate::{Id, IngredientIndex};
use serde::de;
use serde::ser::SerializeMap;
use std::ptr::NonNull;
/// Borrowed view of a function ingredient whose `serde::Serialize` impl writes out the
/// ingredient's serializable memos.
pub struct SerializeIngredient<'db, C>
where
C: Configuration,
{
/// Database storage the memos are read from.
pub zalsa: &'db Zalsa,
/// The ingredient whose memos are serialized.
pub ingredient: &'db IngredientImpl<C>,
}
impl<C> serde::Serialize for SerializeIngredient<'_, C>
where
C: Configuration,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let Self { ingredient, zalsa } = self;
let count = <C::SalsaStruct<'_> as SalsaStructInDb>::entries(zalsa)
.filter(|entry| {
let memo_ingredient_index = ingredient
.memo_ingredient_indices
.get(entry.ingredient_index());
let memo = ingredient.get_memo_from_table_for(
zalsa,
entry.key_index(),
memo_ingredient_index,
);
memo.is_some_and(|memo| memo.should_serialize())
})
.count();
let mut map = serializer.serialize_map(Some(count))?;
let mut visited_edges = FxHashSet::default();
let mut flattened_edges = FxIndexSet::default();
for entry in <C::SalsaStruct<'_> as SalsaStructInDb>::entries(zalsa) {
let memo_ingredient_index = ingredient
.memo_ingredient_indices
.get(entry.ingredient_index());
let memo = ingredient.get_memo_from_table_for(
zalsa,
entry.key_index(),
memo_ingredient_index,
);
if let Some(memo) = memo.filter(|memo| memo.should_serialize()) {
// Flatten the dependencies of this query down to the base inputs.
let flattened_origin = match memo.revisions.origin.as_ref() {
QueryOriginRef::Derived(edges) => {
collect_minimum_serialized_edges(
zalsa,
edges,
&mut visited_edges,
&mut flattened_edges,
);
QueryOrigin::derived(flattened_edges.drain(..).collect())
}
QueryOriginRef::DerivedUntracked(edges) => {
collect_minimum_serialized_edges(
zalsa,
edges,
&mut visited_edges,
&mut flattened_edges,
);
QueryOrigin::derived_untracked(flattened_edges.drain(..).collect())
}
QueryOriginRef::Assigned(key) => {
let dependency = zalsa.lookup_ingredient(key.ingredient_index());
assert!(
dependency.is_persistable(),
"specified query `{}` must be persistable",
dependency.debug_name()
);
QueryOrigin::assigned(key)
}
};
let memo = memo.with_origin(flattened_origin);
// TODO: Group structs by ingredient index into a nested map.
let key = format!(
"{}:{}",
entry.ingredient_index().as_u32(),
entry.key_index().as_bits()
);
map.serialize_entry(&key, &memo)?;
visited_edges.clear();
}
}
map.end()
}
}
// Flattens `edges` ahead of serialization: edges to persistable ingredients are kept as
// they are, every other edge is replaced by whatever its ingredient collects for it.
fn collect_minimum_serialized_edges(
    zalsa: &Zalsa,
    edges: &[QueryEdge],
    visited_edges: &mut FxHashSet<QueryEdge>,
    flattened_edges: &mut FxIndexSet<QueryEdge>,
) {
    for edge in edges.iter().copied() {
        let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index());

        if !dependency.is_persistable() {
            // The dependency itself won't be serialized, so record the minimum set of
            // edges necessary to cover it instead.
            dependency.collect_minimum_serialized_edges(
                zalsa,
                edge,
                flattened_edges,
                visited_edges,
            );
            continue;
        }

        // The dependency will be serialized, so the edge can be stored directly.
        flattened_edges.insert(edge);
    }
}
/// Deserialization seed/visitor that reads a map of memos and inserts them into the memo
/// tables of `zalsa`.
pub struct DeserializeIngredient<'db, C>
where
C: Configuration,
{
/// Database storage whose memo tables receive the deserialized memos.
pub zalsa: &'db Zalsa,
/// The ingredient the memos belong to; used to resolve memo ingredient indices.
pub ingredient: &'db mut IngredientImpl<C>,
}
/// Seed implementation: asks the deserializer for a map and visits it with `self`.
impl<'de, C: Configuration> de::DeserializeSeed<'de> for DeserializeIngredient<'_, C> {
    type Value = ();

    fn deserialize<D: serde::Deserializer<'de>>(
        self,
        deserializer: D,
    ) -> Result<Self::Value, D::Error> {
        let visitor = self;
        deserializer.deserialize_map(visitor)
    }
}
/// Visitor that consumes a map of `"<ingredient index>:<id bits>"` keys to memos and
/// inserts every memo into the memo table of its id.
impl<'de, C> de::Visitor<'de> for DeserializeIngredient<'_, C>
where
C: Configuration,
{
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a map")
}
/// Reads entries until the map is exhausted.
///
/// Fails with a custom error when a key contains no `:` or when either half of the key
/// does not parse as an integer (`u32` for the ingredient index, `u64` for the id bits).
fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: de::MapAccess<'de>,
{
let DeserializeIngredient { zalsa, ingredient } = self;
while let Some((key, memo)) = access.next_entry::<&str, Memo<C>>()? {
// Split the key back into its ingredient-index and id halves.
let (ingredient_index, id) = key
.split_once(':')
.ok_or_else(|| de::Error::custom("invalid database key"))?;
let ingredient_index = IngredientIndex::new(
ingredient_index.parse::<u32>().map_err(de::Error::custom)?,
);
let id = Id::from_bits(id.parse::<u64>().map_err(de::Error::custom)?);
let memo_ingredient_index =
ingredient.memo_ingredient_indices.get(ingredient_index);
// SAFETY: We provide the current revision.
let memo_table = unsafe { zalsa.table().dyn_memos(id, zalsa.current_revision()) };
// The memo is boxed and leaked; the table stores the resulting pointer.
memo_table.insert(
memo_ingredient_index,
// FIXME: Use `Box::into_non_null` once stable.
NonNull::from(Box::leak(Box::new(memo))),
);
}
Ok(())
}
}
}