gluon_salsa/runtime.rs

use crate::durability::Durability;
use crate::plumbing::CycleDetected;
use crate::revision::{AtomicRevision, Revision};
use crate::{Database, DatabaseKeyIndex, Event, EventKind, ForkState};
use log::debug;
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
use parking_lot::{Mutex, RwLock};
use rustc_hash::{FxHashMap, FxHasher};
use smallvec::SmallVec;
use std::hash::{BuildHasherDefault, Hash};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;

mod local_state;
use local_state::{ActiveQueryGuard, LocalState};

/// The salsa runtime stores the storage for all queries as well as
/// tracking the query stack and the dependencies between queries
/// (which is how cycles are detected).
///
/// Each new runtime you create (e.g., via `Runtime::new` or
/// `Runtime::default`) will have an independent set of query storage
/// associated with it. Normally, therefore, you only do this once, at
/// the start of your application.
pub struct Runtime {
    /// Our unique runtime id.
    id: RuntimeId,

    /// If this is a "forked" runtime, then the `revision_guard` will
    /// be `Some`; this guard holds a read-lock on the global query
    /// lock.
    revision_guard: Option<RevisionGuard>,

    /// Local state that is specific to this runtime (thread).
    local_state: LocalState,

    pub(super) parent: Option<ForkState>,

    /// Shared state that is accessible via all runtimes.
    shared_state: Arc<SharedState>,
}

impl Drop for Runtime {
    fn drop(&mut self) {
        if self.parent.is_some() {
            self.unblock_queries_blocked_on_self(None);
        }
    }
}

impl Default for Runtime {
    fn default() -> Self {
        Runtime {
            id: RuntimeId { counter: 0 },
            revision_guard: None,
            shared_state: Default::default(),
            local_state: Default::default(),
            parent: Default::default(),
        }
    }
}

impl std::fmt::Debug for Runtime {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fmt.debug_struct("Runtime")
            .field("id", &self.id())
            .field("forked", &self.revision_guard.is_some())
            .field("shared_state", &self.shared_state)
            .finish()
    }
}

impl Runtime {
    /// Create a new runtime; equivalent to `Self::default`. This is
    /// used when creating a new database.
    pub fn new() -> Self {
        Self::default()
    }

    /// See [`crate::storage::Storage::snapshot`].
    pub fn snapshot(&self) -> Self {
        if self.local_state.query_in_progress() {
            panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
        }

        let revision_guard = RevisionGuard::new(&self.shared_state);

        let id = RuntimeId {
            counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
        };

        Runtime {
            id,
            revision_guard: Some(revision_guard),
            shared_state: self.shared_state.clone(),
            local_state: Default::default(),
            parent: self.parent.clone(),
        }
    }

    /// Returns a "forked" runtime, suitable for calling queries concurrently.
    pub fn fork(&self, state: ForkState) -> Self {
        let revision_guard = RevisionGuard::new(&self.shared_state);

        let id = RuntimeId {
            counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
        };

        assert!(self.try_block_on_fork(id));

        Runtime {
            id,
            revision_guard: Some(revision_guard),
            shared_state: self.shared_state.clone(),
            local_state: Default::default(),
            parent: Some(state),
        }
    }

    /// A "synthetic write" causes the system to act *as though* some
    /// input of durability `durability` has changed. This is mostly
    /// useful for profiling scenarios, but it also has interactions
    /// with garbage collection. In general, a synthetic write to
    /// durability level D will cause the system to fully trace all
    /// queries of durability level D and below. When running a GC, then:
    ///
    /// - Synthetic writes will cause more derived values to be
    ///   *retained*.  This is because derived values are only
    ///   retained if they are traced, and a synthetic write can cause
    ///   more things to be traced.
    /// - Synthetic writes can cause more interned values to be
    ///   *collected*. This is because interned values can only be
    ///   collected if they were not yet traced in the current
    ///   revision. Therefore, if you issue a synthetic write, execute
    ///   some query Q, and then start collecting interned values, you
    ///   will be able to recycle interned values not used in Q.
    ///
    /// In general, then, one can do a "full GC" that retains only
    /// those things that are used by some query Q by (a) doing a
    /// synthetic write at `Durability::HIGH`, (b) executing the query
    /// Q, and then (c) doing a sweep.
    ///
    /// **WARNING:** Just like an ordinary write, this method triggers
    /// cancellation. If you invoke it while a snapshot exists, it
    /// will block until that snapshot is dropped -- if that snapshot
    /// is owned by the current thread, this could trigger deadlock.
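    ///
    /// A hedged sketch of that "full GC" sequence; `db`, `my_query`, and the
    /// sweep call are illustrative names, not items defined in this module,
    /// and the exact sweep API depends on how your database exposes it:
    ///
    /// ```ignore
    /// // (a) act as though a high-durability input changed
    /// runtime.synthetic_write(Durability::HIGH);
    /// // (b) execute the query whose results you want retained
    /// let _ = db.my_query(key);
    /// // (c) sweep; values not traced by `my_query` become collectable
    /// db.sweep_all(sweep_strategy);
    /// ```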
    pub fn synthetic_write(&mut self, durability: Durability) {
        self.with_incremented_revision(&mut |_next_revision| Some(durability));
    }

    /// The unique identifier attached to this `SalsaRuntime`. Each
    /// snapshotted runtime has a distinct identifier.
    #[inline]
    pub fn id(&self) -> RuntimeId {
        self.id
    }

    /// The unique identifier attached to this `SalsaRuntime` and the ids of its parents.
    /// Each snapshotted runtime has a distinct identifier.
    pub fn ids<'a>(&'a self) -> impl Iterator<Item = RuntimeId> + 'a {
        self.parent
            .iter()
            .flat_map(|state| state.0.parents.iter().cloned())
            .chain(Some(self.id()))
    }

    /// Returns the database-key for the query that this thread is
    /// actively executing (if any).
    pub fn active_query(&self) -> Option<DatabaseKeyIndex> {
        self.local_state.active_query()
    }

    /// Read the current value of the revision counter.
    #[inline]
    pub(crate) fn current_revision(&self) -> Revision {
        self.shared_state.revisions[0].load()
    }

    /// The revision in which values with durability `d` may have last
    /// changed.  For D0, this is just the current revision. But for
    /// higher levels of durability, this value may lag behind the
    /// current revision. If we encounter a value of durability Di,
    /// we can check this function to get a "bound" on when the
    /// value may have changed, which allows us to skip walking its
    /// dependencies.
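    ///
    /// A hedged sketch of how a caller might use that bound (the `memo`
    /// fields here are illustrative, not types from this module):
    ///
    /// ```ignore
    /// // If the memo was verified at or after the last change to its
    /// // durability level, none of its inputs can have changed since.
    /// let still_valid = memo.verified_at >= runtime.last_changed_revision(memo.durability);
    /// ```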
    #[inline]
    pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
        self.shared_state.revisions[d.index()].load()
    }

    /// Read the current value of the pending-revision counter.
    #[inline]
    fn pending_revision(&self) -> Revision {
        self.shared_state.pending_revision.load()
    }

    /// Check if the current revision is canceled. If this method ever
    /// returns true, the currently executing query is also marked as
    /// having an *untracked read* -- this means that, in the next
    /// revision, we will always recompute its value "as if" some
    /// input had changed. This means that, if your revision is
    /// canceled (which indicates that current query results will be
    /// ignored), your query is free to short-circuit and return
    /// whatever it likes.
    ///
    /// This method is useful for implementing cancellation of queries.
    /// You can do it in one of two ways, via `Result`s or via unwinding.
    ///
    /// The `Result` approach looks like this:
    ///
    ///   * Some queries invoke `is_current_revision_canceled` and
    ///     return a special value, like `Err(Canceled)`, if it returns
    ///     `true`.
    ///   * Other queries propagate the special value using the `?` operator.
    ///   * API around top-level queries checks if the result is `Ok` or
    ///     `Err(Canceled)`.
    ///
    /// The `panic` approach works in a similar way:
    ///
    ///   * Some queries invoke `is_current_revision_canceled` and
    ///     panic with a special value, like `Canceled`, if it returns
    ///     true.
    ///   * The implementation of the `Database` trait overrides
    ///     `on_propagated_panic` to throw this special value as well.
    ///     This way, the panic gets propagated naturally through dependent
    ///     queries, even across threads.
    ///   * API around top-level queries converts a `panic` into `Result` by
    ///     catching the panic (using either `std::panic::catch_unwind` or
    ///     threads) and downcasting the payload to `Canceled` (re-raising
    ///     the panic if the downcast fails).
    ///
    /// Note that salsa is explicitly designed to be panic-safe, so cancellation
    /// via unwinding is a perfectly valid approach to cancellation.
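    ///
    /// A hedged sketch of the `Result` approach (the `Canceled` type and the
    /// query names below are illustrative, not items from this crate):
    ///
    /// ```ignore
    /// fn my_query(db: &impl MyDatabase, key: Key) -> Result<Output, Canceled> {
    ///     if db.salsa_runtime().is_current_revision_canceled() {
    ///         return Err(Canceled);
    ///     }
    ///     let input = db.other_query(key)?; // propagate the sentinel with `?`
    ///     Ok(compute(input))
    /// }
    /// ```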
    #[inline]
    pub fn is_current_revision_canceled(&self) -> bool {
        let current_revision = self.current_revision();
        let pending_revision = self.pending_revision();
        debug!(
            "is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
            current_revision, pending_revision
        );
        if pending_revision > current_revision {
            self.report_untracked_read();
            true
        } else {
            // Subtle: If the current revision is not canceled, we
            // still report an **anonymous** read, which will bump up
            // the revision number to be at least the last
            // non-canceled revision. This is needed to ensure
            // deterministic reads and avoid salsa-rs/salsa#66. The
            // specific scenario we are trying to avoid is tested by
            // `no_back_dating_in_cancellation`; it works like
            // this. Imagine we have 3 queries, where Query3 invokes
            // Query2 which invokes Query1. Then:
            //
            // - In Revision R1:
            //   - Query1: Observes cancellation and returns sentinel S.
            //     - Recorded inputs: Untracked, because we observed cancellation.
            //   - Query2: Reads Query1 and propagates sentinel S.
            //     - Recorded inputs: Query1, changed-at=R1
            //   - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
            //     - Recorded inputs: Query2, changed-at=R1
            // - In Revision R2:
            //   - Query1: Observes no cancellation. All of its inputs last changed in R0,
            //     so it returns a valid value with "changed at" of R0.
            //     - Recorded inputs: ..., changed-at=R0
            //   - Query2: Recomputes its value and returns correct result.
            //     - Recorded inputs: Query1, changed-at=R0 <-- key problem!
            //   - Query3: sees that Query2's result last changed in R0, so it thinks it
            //     can re-use its value from R1 (which is the sentinel value).
            //
            // The anonymous read here prevents that scenario: Query1
            // winds up with a changed-at setting of R2, which is the
            // "pending revision", and hence Query2 and Query3
            // are recomputed.
            assert_eq!(pending_revision, current_revision);
            self.report_anon_read(pending_revision);
            false
        }
    }

    /// Acquires the **global query write lock** (ensuring that no queries are
    /// executing) and then increments the current revision counter; invokes
    /// `op` with the global query write lock still held.
    ///
    /// While we wait to acquire the global query write lock, this method will
    /// also increment `pending_revision_increments`, thus signalling to queries
    /// that their results are "canceled" and they should abort as expeditiously
    /// as possible.
    ///
    /// The `op` closure should actually perform the writes needed. It is given
    /// the new revision as an argument, and its return value indicates whether
    /// any pre-existing value was modified:
    ///
    /// - returning `None` means that no pre-existing value was modified (this
    ///   could occur e.g. when setting some key on an input that was never set
    ///   before)
    /// - returning `Some(d)` indicates that a pre-existing value was modified
    ///   and it had the durability `d`. This will update the records for when
    ///   values with each durability were modified.
    ///
    /// Note that, given our writer model, we can assume that only one thread is
    /// attempting to increment the global revision at a time.
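    ///
    /// A hedged sketch of a typical caller (an input query's `set` path;
    /// `previous_value` and its fields are illustrative names):
    ///
    /// ```ignore
    /// runtime.with_incremented_revision(&mut |_new_revision| {
    ///     // Overwrite the stored value here, then report the durability of
    ///     // the value being replaced (or `None` if the key was never set).
    ///     previous_value.map(|stamped| stamped.durability)
    /// });
    /// ```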
    pub(crate) fn with_incremented_revision(
        &mut self,
        op: &mut dyn FnMut(Revision) -> Option<Durability>,
    ) {
        log::debug!("increment_revision()");

        if !self.permits_increment() {
            panic!("increment_revision invoked during a query computation");
        }

        // Set the `pending_revision` field so that people
        // know the current revision is canceled.
        let current_revision = self.shared_state.pending_revision.fetch_then_increment();

        // To modify the revision, we need the lock.
        let shared_state = self.shared_state.clone();
        let _lock = shared_state.query_lock.write();

        let old_revision = self.shared_state.revisions[0].fetch_then_increment();
        assert_eq!(current_revision, old_revision);

        let new_revision = current_revision.next();

        debug!("increment_revision: incremented to {:?}", new_revision);

        if let Some(d) = op(new_revision) {
            for rev in &self.shared_state.revisions[1..=d.index()] {
                rev.store(new_revision);
            }
        }
    }

    pub(crate) fn permits_increment(&self) -> bool {
        self.revision_guard.is_none() && !self.local_state.query_in_progress()
    }

    pub(crate) fn prepare_query_implementation<DB>(
        db: &mut DB,
        database_key_index: DatabaseKeyIndex,
    ) -> ActiveQueryGuard<'_, DB>
    where
        DB: std::ops::Deref,
        DB::Target: Database,
    {
        debug!(
            "{:?}: prepare_query_implementation invoked",
            database_key_index
        );

        let runtime = db.salsa_runtime();
        db.salsa_event(Event {
            runtime_id: runtime.id(),
            kind: EventKind::WillExecute {
                database_key: database_key_index,
            },
        });

        // Push the active query onto the stack.
        let max_durability = Durability::MAX;
        LocalState::push_query(db, database_key_index, max_durability)
    }

    pub(crate) fn complete_query<DB, V>(
        active_query: ActiveQueryGuard<'_, DB>,
        value: V,
    ) -> ComputedQueryResult<V>
    where
        DB: std::ops::Deref,
        DB::Target: Database,
    {
        let ActiveQuery {
            dependencies,
            changed_at,
            durability,
            cycle,
            ..
        } = active_query.complete();

        ComputedQueryResult {
            value,
            durability,
            changed_at,
            dependencies,
            cycle,
        }
    }

    /// Reports that the currently active query read the result from
    /// another query.
    ///
    /// # Parameters
    ///
    /// - `input`: the query whose result was read
    /// - `durability`: the durability of that result
    /// - `changed_at`: the last revision in which the result of that
    ///   query had changed
    pub(crate) fn report_query_read<'hack>(
        &self,
        input: DatabaseKeyIndex,
        durability: Durability,
        changed_at: Revision,
    ) {
        self.local_state
            .report_query_read(input, durability, changed_at);
    }

    /// Reports that the query depends on some state unknown to salsa.
    ///
    /// Queries which report untracked reads will be re-executed in the next
    /// revision.
    pub fn report_untracked_read(&self) {
        self.local_state
            .report_untracked_read(self.current_revision());
    }

    /// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`.
    ///
    /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
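    ///
    /// A hedged sketch of that pattern, in the spirit of the linked chapter
    /// (the query name and `VfsDatabase` trait are illustrative, not part of
    /// this crate):
    ///
    /// ```ignore
    /// fn read_file(db: &impl VfsDatabase, path: PathBuf) -> String {
    ///     // This value is recomputed from the file system on demand, so
    ///     // cap the query's durability at LOW.
    ///     db.salsa_runtime().report_synthetic_read(Durability::LOW);
    ///     std::fs::read_to_string(&path).unwrap_or_default()
    /// }
    /// ```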
    pub fn report_synthetic_read(&self, durability: Durability) {
        self.local_state.report_synthetic_read(durability);
    }

    /// An "anonymous" read is a read that doesn't come from executing
    /// a query, but from some other internal operation. It just
    /// modifies the "changed at" to be at least the given revision.
    /// (It also does not disqualify a query from being considered
    /// constant, since it is used for queries that don't give back
    /// actual *data*.)
    ///
    /// This is used when queries check if they have been canceled.
    fn report_anon_read(&self, revision: Revision) {
        self.local_state.report_anon_read(revision)
    }

    /// Obviously, this should be user configurable at some point.
    pub(crate) fn report_unexpected_cycle(
        &self,
        database_key_index: DatabaseKeyIndex,
        error: CycleDetected,
        changed_at: Revision,
    ) -> crate::CycleError<DatabaseKeyIndex> {
        debug!(
            "report_unexpected_cycle(database_key={:?})",
            database_key_index
        );

        let mut query_stack = self.local_state.borrow_query_stack_mut();

        if error.from == error.to {
            // All queries in the cycle are local.
            let start_index = query_stack
                .iter()
                .rposition(|active_query| active_query.database_key_index == database_key_index)
                .expect("bug: query is not on the stack");
            let cycle_participants = &mut query_stack[start_index..];
            let cycle: Vec<_> = cycle_participants
                .iter()
                .map(|active_query| active_query.database_key_index)
                .collect();

            assert!(!cycle.is_empty());

            for active_query in cycle_participants {
                active_query.cycle = cycle.clone();
            }

            crate::CycleError {
                cycle,
                changed_at,
                durability: Durability::MAX,
            }
        } else {
            // Part of the cycle is on another thread, so we need to lock and
            // inspect the shared state.
            let dependency_graph = self.shared_state.dependency_graph.lock();

            let mut cycle = Vec::new();
            {
                let cycle_iter = dependency_graph
                    .get_cycle_path(
                        &database_key_index,
                        error.from,
                        error.to,
                        query_stack.iter().map(|query| &query.database_key_index),
                    )
                    .chain(Some(&database_key_index));

                cycle.extend(cycle_iter.cloned());
            }

            assert!(!cycle.is_empty());

            for active_query in query_stack
                .iter_mut()
                .filter(|query| cycle.iter().any(|key| *key == query.database_key_index))
            {
                active_query.cycle = cycle.clone();
            }

            crate::CycleError {
                cycle,
                changed_at,
                durability: Durability::MAX,
            }
        }
    }

    pub(crate) fn mark_cycle_participants(&self, cycle: &[DatabaseKeyIndex]) {
        for active_query in self
            .local_state
            .borrow_query_stack_mut()
            .iter_mut()
            .rev()
            .take_while(|active_query| cycle.iter().any(|e| *e == active_query.database_key_index))
        {
            active_query.cycle = cycle.to_owned();
        }
    }

    /// Try to make this runtime blocked on `other_id`. Returns true
    /// upon success or false if `other_id` is already blocked on us.
    pub(crate) fn try_block_on(&self, database_key: DatabaseKeyIndex, other_id: RuntimeId) -> bool {
        let mut graph = self.shared_state.dependency_graph.lock();

        graph.add_edge(
            self.id(),
            Some(&database_key),
            other_id,
            self.local_state
                .borrow_query_stack()
                .iter()
                .map(|query| query.database_key_index),
        )
    }

    pub(crate) fn try_block_on_fork(&self, other_id: RuntimeId) -> bool {
        let mut graph = self.shared_state.dependency_graph.lock();

        graph.add_edge(
            self.id(),
            None,
            other_id,
            self.local_state
                .borrow_query_stack()
                .iter()
                .map(|query| query.database_key_index),
        )
    }

    pub(crate) fn unblock_queries_blocked_on_self(
        &self,
        database_key_index: Option<DatabaseKeyIndex>,
    ) {
        self.shared_state
            .dependency_graph
            .lock()
            .remove_edge(database_key_index.as_ref(), self.id())
    }
}

/// State that will be common to all threads (when we support multiple threads)
struct SharedState {
    /// Stores the next id to use for a snapshotted runtime (starts at 1).
    next_id: AtomicUsize,

    /// Whenever derived queries are executing, they acquire this lock
    /// in read mode. Mutating inputs (and thus creating a new
    /// revision) requires a write lock (thus guaranteeing that no
    /// derived queries are in progress). Note that this is not needed
    /// to prevent **race conditions** -- the revision counter itself
    /// is stored in an `AtomicUsize` so it can be cheaply read
    /// without acquiring the lock.  Rather, the `query_lock` is used
    /// to ensure a higher-level consistency property.
    query_lock: RwLock<()>,

    /// This is typically equal to `revision` -- set to `revision+1`
    /// when a new revision is pending (which implies that the current
    /// revision is canceled).
    pending_revision: AtomicRevision,

    /// Stores the "last change" revision for values of each durability
    /// level. This vector is always of length at least 1 (for Durability 0)
    /// but its total length depends on the number of durability levels. The
    /// element at index 0 is special as it represents the "current
    /// revision".  In general, we have the invariant that revisions
    /// in here are *declining* -- that is, `revisions[i] >=
    /// revisions[i + 1]`, for all `i`. This is because when you
    /// modify a value with durability D, that implies that values
    /// with durability less than D may have changed too.
    revisions: Vec<AtomicRevision>,

    /// The dependency graph tracks which runtimes are blocked on one
    /// another, waiting for queries to terminate.
    dependency_graph: Mutex<DependencyGraph<DatabaseKeyIndex>>,
}

impl SharedState {
    fn with_durabilities(durabilities: usize) -> Self {
        SharedState {
            next_id: AtomicUsize::new(1),
            query_lock: Default::default(),
            revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
            pending_revision: AtomicRevision::start(),
            dependency_graph: Default::default(),
        }
    }
}

impl std::panic::RefUnwindSafe for SharedState {}

impl Default for SharedState {
    fn default() -> Self {
        Self::with_durabilities(Durability::LEN)
    }
}

impl std::fmt::Debug for SharedState {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let query_lock = if self.query_lock.try_write().is_some() {
            "<unlocked>"
        } else if self.query_lock.try_read().is_some() {
            "<rlocked>"
        } else {
            "<wlocked>"
        };
        fmt.debug_struct("SharedState")
            .field("query_lock", &query_lock)
            .field("revisions", &self.revisions)
            .field("pending_revision", &self.pending_revision)
            .finish()
    }
}

struct ActiveQuery {
    /// What query is executing
    database_key_index: DatabaseKeyIndex,

    /// Minimum durability of inputs observed so far.
    durability: Durability,

    /// Maximum revision of all inputs observed. If we observe an
    /// untracked read, this will be set to the most recent revision.
    changed_at: Revision,

    /// Set of subqueries that were accessed thus far, or `None` if
    /// there was an untracked read.
    dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,

    /// Stores the entire cycle, if one is found and this query is part of it.
    cycle: Vec<DatabaseKeyIndex>,
}

pub(crate) struct ComputedQueryResult<V> {
    /// Final value produced
    pub(crate) value: V,

    /// Minimum durability of inputs observed so far.
    pub(crate) durability: Durability,

    /// Maximum revision of all inputs observed. If we observe an
    /// untracked read, this will be set to the most recent revision.
    pub(crate) changed_at: Revision,

    /// Complete set of subqueries that were accessed, or `None` if
    /// there was an untracked read.
    pub(crate) dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,

    /// The cycle, if one occurred while computing this value.
    pub(crate) cycle: Vec<DatabaseKeyIndex>,
}

impl ActiveQuery {
    fn new(database_key_index: DatabaseKeyIndex, max_durability: Durability) -> Self {
        ActiveQuery {
            database_key_index,
            durability: max_durability,
            changed_at: Revision::start(),
            dependencies: Some(FxIndexSet::default()),
            cycle: Vec::new(),
        }
    }

    fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) {
        if let Some(set) = &mut self.dependencies {
            set.insert(input);
        }

        self.durability = self.durability.min(durability);
        self.changed_at = self.changed_at.max(revision);
    }

    fn add_untracked_read(&mut self, changed_at: Revision) {
        self.dependencies = None;
        self.durability = Durability::LOW;
        self.changed_at = changed_at;
    }

    fn add_synthetic_read(&mut self, durability: Durability) {
        self.durability = self.durability.min(durability);
    }

    fn add_anon_read(&mut self, changed_at: Revision) {
        self.changed_at = self.changed_at.max(changed_at);
    }
}

/// A unique identifier for a particular runtime. Each time you create
/// a snapshot, a fresh `RuntimeId` is generated. Once a snapshot is
/// complete, its `RuntimeId` may potentially be re-used.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuntimeId {
    counter: usize,
}

#[derive(Clone, Debug)]
pub(crate) struct StampedValue<V> {
    pub(crate) value: V,
    pub(crate) durability: Durability,
    pub(crate) changed_at: Revision,
}

#[derive(Debug)]
struct Edge<K> {
    id: RuntimeId,
    path: Vec<K>,
}

#[derive(Debug)]
struct DependencyGraph<K: Hash + Eq> {
    /// A `(K -> V)` pair in this map indicates that the runtime
    /// `K` is blocked on some query executing in the runtime `V`.
    /// This encodes a graph that must be acyclic (or else deadlock
    /// will result).
    edges: FxHashMap<RuntimeId, SmallVec<[Edge<K>; 1]>>,
    labels: FxHashMap<K, SmallVec<[RuntimeId; 4]>>,
    forks: FxHashMap<RuntimeId, SmallVec<[RuntimeId; 4]>>,
}

impl<K> Default for DependencyGraph<K>
where
    K: Hash + Eq,
{
    fn default() -> Self {
        DependencyGraph {
            edges: Default::default(),
            labels: Default::default(),
            forks: Default::default(),
        }
    }
}

impl<K> DependencyGraph<K>
where
    K: Hash + Eq + Clone,
{
    fn can_add_edge(&self, from_id: RuntimeId, to_id: RuntimeId) -> bool {
        !self.find_edge(from_id, to_id, &mut |_| ())
    }

    fn find_edge(
        &self,
        from_id: RuntimeId,
        to_id: RuntimeId,
        f: &mut impl FnMut(RuntimeId),
    ) -> bool {
        // First: walk the chain of things that `to_id` depends on,
        // looking for us.
        if from_id == to_id {
            return true;
        }
        if let Some(qs) = self.edges.get(&to_id) {
            return qs.iter().any(|q| {
                if self.find_edge(from_id, q.id, f) {
                    f(q.id);
                    true
                } else {
                    false
                }
            });
        }
        false
    }

    /// Attempt to add an edge `from_id -> to_id` into the dependency graph.
    fn add_edge(
        &mut self,
        from_id: RuntimeId,
        database_key: Option<&K>,
        to_id: RuntimeId,
        path: impl IntoIterator<Item = K>,
    ) -> bool {
        assert_ne!(from_id, to_id);

        if !self.can_add_edge(from_id, to_id) {
            return false;
        }

        self.edges.entry(from_id).or_default().push(Edge {
            id: to_id,
            path: path.into_iter().chain(database_key.cloned()).collect(),
        });

        if let Some(database_key) = database_key.cloned() {
            self.labels.entry(database_key).or_default().push(from_id);
        } else {
            self.forks.entry(to_id).or_default().push(from_id);
        }
        true
    }

    fn remove_edge(&mut self, database_key: Option<&K>, to_id: RuntimeId) {
        let vec = match database_key {
            Some(database_key) => self.labels.remove(database_key).unwrap_or_default(),
            None => self.forks.remove(&to_id).unwrap_or_default(),
        };

        for from_id in &vec {
            use std::collections::hash_map::Entry;
            match self.edges.entry(*from_id) {
                Entry::Occupied(mut entry) => {
                    let edges = entry.get_mut();
                    let i = edges
                        .iter()
                        .position(|edge| edge.id == to_id)
                        .expect("Tried to remove edge which did not exist in the edge list");
                    edges.swap_remove(i);

                    if edges.is_empty() {
                        entry.remove();
                    }
                }
                Entry::Vacant(_) => unreachable!(),
            }
        }
    }

    fn get_cycle_path<'a>(
        &'a self,
        database_key: &'a K,
        from: RuntimeId,
        to: RuntimeId,
        local_path: impl IntoIterator<Item = &'a K>,
    ) -> impl Iterator<Item = &'a K>
    where
        K: std::fmt::Debug,
    {
        let mut vec = Vec::new();
        assert!(self.find_edge(from, to, &mut |id| vec.push(id)));
        vec.push(to);

        let mut current = Some(std::slice::from_ref(database_key));
        let mut last = None;
        let mut local_path = Some(local_path);
        let mut vec_iter = vec.into_iter().rev().peekable();
        std::iter::from_fn(move || match current.take() {
            Some(path) => {
                let id = vec_iter.next()?;
                let link_key = path.last().unwrap();

                current = self.edges.get(&id).and_then(|out_edges| {
                    let next_id = vec_iter.peek()?;
                    let edge = out_edges.iter().find(|edge| edge.id == *next_id)?;

                    Some(
                        edge.path
                            .iter()
                            .rposition(|p| p == link_key)
                            .map(|i| &edge.path[i + 1..])
                            .unwrap_or_else(|| &edge.path[..]),
                    )
                });

                if current.is_none() {
                    last = local_path.take().map(|local_path| {
                        local_path
                            .into_iter()
                            .skip_while(move |p| *p != link_key)
                            .skip(1)
                    });
                }

                Some(path)
            }
            None => match &mut last {
                Some(iter) => iter.next().map(std::slice::from_ref),
                None => None,
            },
        })
        .flatten()
    }
}

struct RevisionGuard {
    shared_state: Arc<SharedState>,
}

impl RevisionGuard {
    fn new(shared_state: &Arc<SharedState>) -> Self {
        // Subtle: we use a "recursive" lock here so that it is not an
        // error to acquire a read-lock when one is already held (this
        // happens when a query uses `snapshot` to spawn off parallel
        // workers, for example).
        //
        // This has the side-effect that we are responsible for ensuring
        // that people contending for the write lock do not starve,
        // but this is what we achieve via the cancellation mechanism.
        //
        // (In particular, since we only ever have one "mutating
        // handle" to the database, the only contention for the global
        // query lock occurs when there are "futures" evaluating
        // queries in parallel, and those futures hold a read-lock
        // already, so the starvation problem is more about them bringing
        // themselves to a close, versus preventing other people from
        // *starting* work).
        unsafe {
            shared_state.query_lock.raw().lock_shared_recursive();
        }

        Self {
            shared_state: shared_state.clone(),
        }
    }
}

impl Drop for RevisionGuard {
    fn drop(&mut self) {
        // Release our read-lock without using RAII. As documented in
        // `RevisionGuard::new` above, this requires the unsafe keyword.
        unsafe {
            self.shared_state.query_lock.raw().unlock_shared();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dependency_graph_path1() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        assert!(graph.add_edge(a, Some(&2), b, vec![1]));
        // assert!(graph.add_edge(b, &1, a, vec![3, 2]));
        assert_eq!(
            graph
                .get_cycle_path(&1, b, a, &[3, 2][..])
                .cloned()
                .collect::<Vec<i32>>(),
            vec![1, 2]
        );
    }

    #[test]
    fn dependency_graph_path2() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        let c = RuntimeId { counter: 2 };
        assert!(graph.add_edge(a, Some(&3), b, vec![1]));
        assert!(graph.add_edge(b, Some(&4), c, vec![2, 3]));
        // assert!(graph.add_edge(c, &1, a, vec![5, 6, 4, 7]));
        assert_eq!(
            graph
                .get_cycle_path(&1, c, a, &[5, 6, 4, 7][..])
                .cloned()
                .collect::<Vec<i32>>(),
            vec![1, 3, 4, 7]
        );
    }
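
    // A small added check (not from the original test suite): sketches the
    // acyclicity guarantee documented on `DependencyGraph::edges`, namely
    // that `add_edge` refuses an edge that would close a cycle.
    #[test]
    fn dependency_graph_rejects_cycle() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        assert!(graph.add_edge(a, Some(&1), b, vec![1]));
        // The reverse edge would complete the cycle a -> b -> a, so
        // `can_add_edge` (and hence `add_edge`) must reject it.
        assert!(!graph.add_edge(b, Some(&2), a, vec![2]));
    }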
}