aedb 0.1.11

Embedded Rust storage engine with transactional commits, WAL durability, and snapshot-consistent reads
Documentation
use crate::catalog::Catalog;
use crate::commit::apply::apply_mutation;
use crate::commit::validation::Mutation;
use crate::config::PrimaryIndexBackend;
use crate::error::AedbError;
use crate::permission::CallerContext;
use crate::storage::keyspace::{Keyspace, Namespace, NamespaceId};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::{self as std_mpsc, Receiver, Sender};

/// Handle to a fixed set of worker threads that execute [`ParallelTask`]s.
///
/// Every field is an `Arc`, so cloning the handle is cheap and all clones
/// share the same worker channels and counters.
#[derive(Clone)]
pub(super) struct ParallelApplyRuntime {
    /// One channel sender per spawned worker thread; a task is handed to a
    /// worker by sending it on that worker's channel.
    workers: Arc<Vec<Sender<ParallelTask>>>,
    /// Count of tasks that have been submitted but not yet received by a
    /// worker (incremented on submit, decremented when a worker dequeues).
    queued_tasks: Arc<AtomicUsize>,
    /// Ever-increasing counter; taken modulo the worker count to pick the
    /// target worker in round-robin order.
    next_worker: Arc<AtomicUsize>,
}

/// A unit of work for a parallel-apply worker: a batch of mutations to apply
/// to a private copy of one namespace, plus the channel to report the result.
pub(super) struct ParallelTask {
    /// Identifier under which `base_namespace` is registered in the worker's
    /// local keyspace, and which is echoed back with the result.
    pub(super) namespace_id: NamespaceId,
    /// Starting state of the namespace; the mutations are applied on top of it.
    pub(super) base_namespace: Namespace,
    /// Mutations to apply, in order.
    pub(super) mutations: Vec<Mutation>,
    /// Commit sequence number passed to every `apply_mutation` call.
    pub(super) commit_seq: u64,
    /// Backend used to construct the worker's local `Keyspace`.
    pub(super) backend: PrimaryIndexBackend,
    /// Shared catalog; the worker clones it into a task-local copy before
    /// applying mutations.
    pub(super) catalog: Arc<Catalog>,
    /// Passed to `apply_mutation` as its (always present) scan-row limit.
    pub(super) max_scan_rows: usize,
    /// Optional caller identity forwarded to `apply_mutation`.
    pub(super) caller: Option<CallerContext>,
    /// Cancellation flag; the worker checks it before starting and before
    /// each mutation, and reports `ParallelApplyCancelled` when it is set.
    pub(super) cancel: Arc<AtomicBool>,
    /// Channel on which the worker reports either the updated namespace
    /// (paired with its id) or the error that stopped the task.
    pub(super) response_tx: Sender<Result<(NamespaceId, Namespace), AedbError>>,
}

impl ParallelApplyRuntime {
    /// Spawns the worker threads and returns a handle for submitting tasks.
    ///
    /// `worker_threads` is clamped to at least one. Each worker owns the
    /// receiving end of its own channel and runs `run_worker` until every
    /// sender for that channel has been dropped.
    pub(super) fn new(worker_threads: usize) -> Self {
        let queued_tasks = Arc::new(AtomicUsize::new(0));
        let senders: Vec<Sender<ParallelTask>> = (0..worker_threads.max(1))
            .map(|_| {
                let (sender, receiver) = std_mpsc::channel::<ParallelTask>();
                let pending = Arc::clone(&queued_tasks);
                std::thread::spawn(move || run_worker(receiver, pending));
                sender
            })
            .collect();
        Self {
            workers: Arc::new(senders),
            queued_tasks,
            next_worker: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Hands `task` to the next worker in round-robin order.
    ///
    /// # Errors
    ///
    /// Returns `AedbError::Validation` when the chosen worker's channel is
    /// closed; the queued-task counter is rolled back in that case.
    pub(super) fn submit(&self, task: ParallelTask) -> Result<(), AedbError> {
        let slot = self.next_worker.fetch_add(1, Ordering::Relaxed) % self.workers.len();
        // Count the task before sending so the worker's matching decrement
        // can never observe a counter that has not been incremented yet.
        self.queued_tasks.fetch_add(1, Ordering::Relaxed);
        self.workers[slot].send(task).map_err(|e| {
            self.queued_tasks.fetch_sub(1, Ordering::Relaxed);
            AedbError::Validation(format!("parallel runtime unavailable: {e}"))
        })
    }

    /// Number of tasks submitted but not yet picked up by a worker.
    pub(super) fn queued_tasks(&self) -> usize {
        self.queued_tasks.load(Ordering::Relaxed)
    }
}

/// Worker thread main loop: receives tasks until the channel closes and runs
/// each one through `execute_task`.
///
/// `queued_tasks` is decremented as soon as a task is dequeued, so it counts
/// tasks that are waiting, not tasks that are running.
///
/// A panic inside `execute_task` is caught so the worker thread survives and
/// the submitter is told via `AedbError::ParallelApplyWorkerPanicked`. An
/// `Err` returned by `execute_task` is likewise forwarded on the task's
/// response channel instead of being dropped.
fn run_worker(rx: Receiver<ParallelTask>, queued_tasks: Arc<AtomicUsize>) {
    while let Ok(task) = rx.recv() {
        queued_tasks.fetch_sub(1, Ordering::Relaxed);
        // `task` is moved into `execute_task`, so keep a sender of our own to
        // be able to report failures afterwards.
        let fallback_tx = task.response_tx.clone();
        let outcome =
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| execute_task(task)));
        let failure = match outcome {
            Ok(Ok(())) => continue,
            Ok(Err(e)) => e,
            Err(_) => AedbError::ParallelApplyWorkerPanicked,
        };
        // A send error only means the submitter stopped listening.
        let _ = fallback_tx.send(Err(failure));
    }
}

fn execute_task(task: ParallelTask) -> Result<(), AedbError> {
    if task.cancel.load(Ordering::Relaxed) {
        let _ = task
            .response_tx
            .send(Err(AedbError::ParallelApplyCancelled));
        return Ok(());
    }
    let mut local_catalog = (*task.catalog).clone();
    let mut local_keyspace = Keyspace::with_backend(task.backend);
    local_keyspace.insert_namespace(task.namespace_id.clone(), task.base_namespace);
    for mutation in &task.mutations {
        if task.cancel.load(Ordering::Relaxed) {
            let _ = task
                .response_tx
                .send(Err(AedbError::ParallelApplyCancelled));
            return Ok(());
        }
        super::parallel_worker_test_hook_for_mutation(mutation);
        if let Err(e) = apply_mutation(
            &mut local_catalog,
            &mut local_keyspace,
            mutation.clone(),
            task.commit_seq,
            Some(task.max_scan_rows),
            task.caller.as_ref(),
        ) {
            let _ = task.response_tx.send(Err(e));
            return Ok(());
        }
    }
    let namespace = local_keyspace
        .namespaces
        .get(&task.namespace_id)
        .cloned()
        .ok_or_else(|| AedbError::Validation("parallel apply namespace missing".into()))?;
    let _ = task.response_tx.send(Ok((task.namespace_id, namespace)));
    Ok(())
}