qbice 0.6.5

The Query-Based Incremental Computation Engine
Documentation
use std::{
    sync::{Arc, Weak},
    thread,
};

use crossbeam::deque::{Injector, Stealer, Worker};
use dashmap::DashSet;
use parking_lot::RwLock;
use tokio::sync::{Mutex, Notify};
use tracing::instrument;

use crate::{
    Engine, ExecutionStyle,
    config::{Config, WriteTransaction},
    engine::computation_graph::{
        QueryKind,
        database::{Database, Edge},
        dirty_worker::task::{Batch, DirtyTask, StrippedBuffer},
        statistic::Statistic,
    },
    query::QueryID,
};

mod task;

/// Work-stealing dirty propagation worker pool.
///
/// Uses a global [`Injector`] queue paired with per-worker local
/// deques. Workers batch-steal from the global queue to amortize
/// synchronization overhead, and can steal from sibling workers' local
/// deques when both the local deque and global queue are empty.
///
/// Dropping the pool sets the shutdown flag and wakes parked workers; each
/// worker exits once it has drained the work it can still find.
pub struct DirtyWorker<C: Config> {
    /// Global queue that [`DirtyWorker::submit_task`] pushes into and that
    /// workers batch-steal from.
    injector: Arc<Injector<DirtyTask<C>>>,
    /// Used to wake parked workers when new work is submitted, and to wake
    /// all of them on shutdown.
    notify: Arc<Notify>,
    /// Shutdown flag. It sits behind a lock so a worker can hold the read
    /// guard while registering for wake-ups; a shutdown therefore cannot slip
    /// in between the worker's flag check and its registration.
    shutdown: Arc<RwLock<bool>>,
}

impl<C: Config> DirtyWorker<C> {
    pub fn new(
        database: &Arc<Database<C>>,
        stats: &Arc<Statistic>,
        dirtied_queries: &Arc<DashSet<QueryID, C::BuildHasher>>,
    ) -> Self {
        let parallelism = thread::available_parallelism()
            .map(std::num::NonZero::get)
            .unwrap_or(8);

        let injector = Arc::new(Injector::new());
        let notify = Arc::new(Notify::new());
        let shutdown = Arc::new(RwLock::new(false));

        // Create per-worker local deques and their stealers.
        let mut workers = Vec::with_capacity(parallelism);
        let mut stealers = Vec::with_capacity(parallelism);

        for _ in 0..parallelism {
            let w = Worker::new_fifo();
            stealers.push(w.stealer());
            workers.push(w);
        }

        let stealers: Arc<[Stealer<DirtyTask<C>>]> = Arc::from(stealers);

        // Spawn worker tasks.
        for worker in workers {
            let injector = injector.clone();
            let stealers = stealers.clone();
            let notify = notify.clone();
            let shutdown = shutdown.clone();
            let database = Arc::downgrade(database);
            let stats = Arc::downgrade(stats);
            let dirtied_queries = Arc::downgrade(dirtied_queries);

            tokio::spawn(async move {
                Self::worker_loop(
                    worker,
                    &injector,
                    &stealers,
                    &notify,
                    &shutdown,
                    &database,
                    &stats,
                    &dirtied_queries,
                )
                .await;
            });
        }

        Self { injector, notify, shutdown }
    }

    /// Submit a dirty-propagation task to the global injector queue.
    pub fn submit_task(&self, dirty_task: DirtyTask<C>) {
        self.injector.push(dirty_task);
        self.notify.notify_one();
    }

    /// Attempt to find a task using the three-tier strategy:
    /// 1. Pop from the thread-local deque (zero contention).
    /// 2. Batch-steal from the global [`Injector`].
    /// 3. Steal from a sibling worker's deque.
    fn find_task(
        local: &Worker<DirtyTask<C>>,
        injector: &Injector<DirtyTask<C>>,
        stealers: &[Stealer<DirtyTask<C>>],
    ) -> Option<DirtyTask<C>> {
        // Pop a task from the local queue, if not empty.
        local.pop().or_else(|| {
            // Otherwise, we need to look for a task elsewhere.
            std::iter::repeat_with(|| {
                // Try stealing a batch of tasks from the global queue.
                injector
                    .steal_batch_and_pop(local)
                    // Or try stealing a task from one of the other threads.
                    .or_else(|| {
                        stealers
                            .iter()
                            .map(crossbeam::deque::Stealer::steal)
                            .collect()
                    })
            })
            // Loop while no task was stolen and any steal operation needs to be
            // retried.
            .find(|s| !s.is_retry())
            // Extract the stolen task, if there is one.
            .and_then(crossbeam::deque::Steal::success)
        })
    }

    #[allow(clippy::too_many_arguments, clippy::await_holding_lock)]
    async fn worker_loop(
        local: Worker<DirtyTask<C>>,
        injector: &Injector<DirtyTask<C>>,
        stealers: &[Stealer<DirtyTask<C>>],
        notify: &Notify,
        shutdown: &RwLock<bool>,
        database: &Weak<Database<C>>,
        statistic: &Weak<Statistic>,
        dirtied_queries: &Weak<DashSet<QueryID, C::BuildHasher>>,
    ) {
        loop {
            // Drain all available work before parking.
            let mut count = 0;
            while let Some(task) = Self::find_task(&local, injector, stealers) {
                Self::process_task(
                    task,
                    injector,
                    notify,
                    database,
                    statistic,
                    dirtied_queries,
                )
                .await;

                // every 32 tasks, yield to allow other tasks to run
                count += 1;
                if count >= 32 {
                    count = 0;
                    tokio::task::yield_now().await;
                }
            }

            let shutdown_guard = shutdown.read();

            if *shutdown_guard {
                break;
            }

            // IMPORTANT: has to hold the lock until we have acquired
            // notified
            let mut notified = std::pin::pin!(notify.notified());
            notified.as_mut().enable();

            drop(shutdown_guard);

            // Double-check: work may have arrived between the while-
            // loop exit and the enable() call above.
            if let Some(task) = Self::find_task(&local, injector, stealers) {
                Self::process_task(
                    task,
                    injector,
                    notify,
                    database,
                    statistic,
                    dirtied_queries,
                )
                .await;

                // work was found, reset backoff
                continue;
            }

            notified.await;
        }
    }

    async fn process_task(
        task: DirtyTask<C>,
        injector: &Injector<DirtyTask<C>>,
        notify: &Notify,
        database: &Weak<Database<C>>,
        statistic: &Weak<Statistic>,
        dirtied_queries: &Weak<DashSet<QueryID, C::BuildHasher>>,
    ) {
        let dirtied_queries = dirtied_queries.upgrade().unwrap();
        let statistic = statistic.upgrade().unwrap();
        let database = database.upgrade().unwrap();

        if !dirtied_queries.insert(*task.query_id()) {
            return;
        }

        let query_id = *task.query_id();
        let mut counter = 0;
        let mut pushed = false;

        for caller in
            unsafe { database.get_backward_edges_unchecked(&query_id).await }
        {
            counter += 1;
            // every 16 edges, yield to allow other tasks to run
            if counter >= 16 {
                counter = 0;
                tokio::task::yield_now().await;
            }

            {
                // opportunisitically try to use the write transaction
                if let Some(mut write_tx) = task.try_load_write_tx() {
                    database
                        .mark_dirty_forward_edge(
                            caller,
                            *task.query_id(),
                            &mut *write_tx,
                        )
                        .await;

                    // maintenance the remaining buffer edges
                    for edge in task.drain_limited() {
                        database
                            .mark_dirty_forward_edge_from(edge, &mut *write_tx)
                            .await;
                    }
                } else {
                    // couldn't get the write transaction, push to the buffer
                    task.push_to_buffer(Edge::new(caller, *task.query_id()));
                }

                statistic.add_dirtied_edge_count();
            }

            let query_kind = database.get_query_kind(&caller).await;

            if matches!(
                query_kind,
                QueryKind::Executable(
                    ExecutionStyle::Projection | ExecutionStyle::Firewall
                )
            ) {
                // don't continue propagation through firewall or
                // projection nodes
                continue;
            }

            injector.push(task.propagate_to(caller));

            if !pushed {
                notify.notify_one();
                pushed = true;
            }
        }

        drop(task);
    }
}

impl<C: Config> Drop for DirtyWorker<C> {
    /// Signals shutdown to every worker and wakes all parked ones.
    ///
    /// The flag is written under the write lock, and the lock is released
    /// before any worker is woken, so woken workers can immediately read it.
    fn drop(&mut self) {
        {
            let mut is_shutting_down = self.shutdown.write();
            *is_shutting_down = true;
        }

        self.notify.notify_waiters();
    }
}

impl<C: Config> Engine<C> {
    /// Propagates dirtiness from every query in `query_ids` to its
    /// (transitive) callers using the dirty-worker pool, and returns the write
    /// transaction once all propagation work for this batch is done.
    ///
    /// The transaction is shared with the spawned tasks behind a `Mutex`.
    /// Edges that could not be written while the tasks ran are buffered, and
    /// the buffer is flushed into the transaction here after the batch
    /// completes.
    ///
    /// # Panics
    ///
    /// Panics if the transaction is still shared after the batch's completion
    /// notification fires, which indicates a bug in the batch bookkeeping.
    #[instrument(
        skip(self, query_ids, transaction),
        level = "debug",
        name = "dirty_propagation",
        target = "qbice"
    )]
    pub(super) async fn dirty_propagate_from_batch(
        self: &Arc<Self>,
        query_ids: impl IntoIterator<Item = QueryID>,
        transaction: WriteTransaction<C>,
    ) -> WriteTransaction<C> {
        let write_tx = Arc::new(Mutex::new(transaction));
        let stripped_buffer = Arc::new(StrippedBuffer::new());

        let batch = Batch::new(write_tx.clone(), stripped_buffer.clone());
        let notified = batch.notified_owned();

        for query_id in query_ids {
            self.computation_graph
                .dirty_worker
                .submit_task(batch.new_task(query_id));
        }

        // Release our own handle on the batch. Presumably completion then
        // depends only on the submitted tasks.
        drop(batch);

        // Wait for all tasks to complete.
        notified.await;

        // Every task is finished, so this must be the last handle on the
        // transaction.
        let mut write_tx = Arc::try_unwrap(write_tx)
            .unwrap_or_else(|_| {
                panic!("should be unique, notified system is broken")
            })
            .into_inner();

        // Flush the edges that tasks buffered while the transaction was busy.
        for remaining_edge in stripped_buffer.drain_all() {
            self.computation_graph
                .database
                .mark_dirty_forward_edge_from(remaining_edge, &mut write_tx)
                .await;
        }

        write_tx
    }

    /// Forgets which queries have already been dirtied, so a later
    /// propagation will walk through them again.
    pub(super) fn clear_dirtied_queries(&self) {
        self.computation_graph.dirtied_queries.clear();
    }
}