// s2-lite 0.30.3
//
// Lightweight server implementation of S2, the durable streams API, backed by object storage
// Documentation
use std::{collections::VecDeque, sync::Arc};

use parking_lot::Mutex;
use slatedb::{CloseReason, DbStatus};
use tokio::sync::watch;

/// One-shot, sendable callback stored for a pending waiter.
///
/// It receives `Ok(durable_seq)` with the durable sequence number that satisfied the wait, or
/// `Err(reason)` when the notifier was closed first.
type Callback = Box<dyn FnOnce(Result<u64, CloseReason>) + Send + 'static>;

/// Handle for registering callbacks that fire once a target durable sequence number is reached.
///
/// Cloning is cheap: every clone points at the same mutex-protected [`State`].
#[derive(Clone)]
pub(super) struct DurabilityNotifier {
    /// State shared between all handles and the functions that notify or close waiters.
    state: Arc<Mutex<State>>,
}

/// Mutable bookkeeping guarded by the notifier's mutex.
#[derive(Default)]
struct State {
    /// Set once the notifier is closed; from then on subscribers fail immediately with it.
    closed_reason: Option<CloseReason>,
    /// Highest durable sequence number recorded so far.
    last_durable_seq: u64,
    /// Pending waiters, kept in non-decreasing order of their target sequence number so the
    /// ready prefix can be found with a binary search.
    waiters: VecDeque<Waiter>,
}

/// A queued subscription that has not been resolved yet.
struct Waiter {
    /// Target sequence number: the waiter is released once the durable sequence reaches it.
    durable_seq: u64,
    /// Invoked exactly once, with the durable sequence number or the close reason.
    callback: Callback,
}

fn waiters_are_sorted(waiters: &VecDeque<Waiter>) -> bool {
    let mut prev = None;
    for waiter in waiters {
        if let Some(prev_target) = prev
            && prev_target > waiter.durable_seq
        {
            return false;
        }
        prev = Some(waiter.durable_seq);
    }
    true
}

impl DurabilityNotifier {
    /// Creates a notifier seeded from the database's current status.
    ///
    /// The shared state starts with the durable sequence number and close reason read from the
    /// status receiver. A task following later status updates is started with `tokio::spawn`
    /// only when the database is not already closed; for a closed database every subscriber
    /// fails immediately with the recorded reason.
    pub fn spawn(db: &slatedb::Db) -> Self {
        let status_rx = db.subscribe();
        let snapshot = status_rx.borrow().clone();
        let already_closed = snapshot.close_reason.is_some();
        let state = Arc::new(Mutex::new(State {
            closed_reason: snapshot.close_reason,
            last_durable_seq: snapshot.durable_seq,
            waiters: VecDeque::new(),
        }));
        if !already_closed {
            tokio::spawn(run_notifier(status_rx, Arc::clone(&state)));
        }
        Self { state }
    }

    /// Registers `callback` to run once `target_durable_seq` is durable.
    ///
    /// The callback runs synchronously on the calling thread, after the lock is released, when
    /// the outcome is already known:
    /// - `Err(reason)` if the notifier has been closed;
    /// - `Ok(last_durable_seq)` if the target has already been reached.
    ///
    /// Otherwise the callback is queued, after any waiters with an equal or smaller target, so
    /// the queue stays sorted.
    pub fn subscribe(
        &self,
        target_durable_seq: u64,
        callback: impl FnOnce(Result<u64, CloseReason>) + Send + 'static,
    ) {
        let mut state = self.state.lock();

        // Decide under the lock whether the outcome is already known.
        let immediate = match state.closed_reason {
            Some(reason) => Some(Err(reason)),
            None if state.last_durable_seq >= target_durable_seq => {
                Some(Ok(state.last_durable_seq))
            }
            None => None,
        };
        if let Some(outcome) = immediate {
            // Never run user code while holding the lock.
            drop(state);
            callback(outcome);
            return;
        }

        // Scan from the back: targets usually arrive in increasing order, so the number of
        // queued waiters with a larger target is typically zero.
        let larger_targets = state
            .waiters
            .iter()
            .rev()
            .take_while(|queued| queued.durable_seq > target_durable_seq)
            .count();
        let insert_pos = state.waiters.len() - larger_targets;
        state.waiters.insert(
            insert_pos,
            Waiter {
                durable_seq: target_durable_seq,
                callback: Box::new(callback),
            },
        );
        debug_assert!(waiters_are_sorted(&state.waiters));
    }
}

/// Follows database status updates until the database closes or the status sender goes away.
///
/// Each update first releases the waiters satisfied by the reported durable sequence number.
/// If the update carries a close reason, the remaining waiters are failed with it and the task
/// ends. When the sender is dropped without reporting a reason, the notifier is closed with
/// `CloseReason::Clean`.
async fn run_notifier(mut status_rx: watch::Receiver<DbStatus>, state: Arc<Mutex<State>>) {
    while status_rx.changed().await.is_ok() {
        let snapshot = status_rx.borrow().clone();
        // Release satisfied waiters before a possible close so they observe success.
        notify_waiters(&state, snapshot.durable_seq);
        if let Some(reason) = snapshot.close_reason {
            close_with_reason(&state, reason);
            return;
        }
    }
    // `changed()` failed: the sender was dropped.
    close_with_reason(&state, CloseReason::Clean);
}

fn notify_waiters(state: &Mutex<State>, durable_seq: u64) {
    let ready = {
        let mut state = state.lock();
        if durable_seq <= state.last_durable_seq {
            return;
        }
        state.last_durable_seq = durable_seq;
        debug_assert!(waiters_are_sorted(&state.waiters));
        let split = state
            .waiters
            .partition_point(|w| w.durable_seq <= durable_seq);
        state.waiters.drain(..split).collect::<Vec<_>>()
    };
    for waiter in ready {
        (waiter.callback)(Ok(durable_seq));
    }
}

/// Marks the notifier as closed and fails every pending waiter with `reason`.
///
/// Callbacks receive `Err(reason)` and run after the lock has been released.
///
/// # Panics
///
/// Panics if a close reason had already been recorded.
fn close_with_reason(state: &Mutex<State>, reason: CloseReason) {
    let mut state = state.lock();
    let prev = state.closed_reason.replace(reason);
    assert!(prev.is_none());
    // Take the whole queue so no waiter is left behind once the notifier is closed.
    let pending = std::mem::take(&mut state.waiters);

    // Release the lock before running user callbacks.
    drop(state);
    pending
        .into_iter()
        .for_each(|waiter| (waiter.callback)(Err(reason)));
}

#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            mpsc::{self, TryRecvError},
        },
        time::Duration,
    };

    use slatedb::{Db, object_store::memory::InMemory};

    use super::*;

    /// Result type delivered to every callback.
    type Outcome = Result<u64, CloseReason>;

    /// Upper bound on how long a test waits for a callback result.
    const CALLBACK_TIMEOUT: Duration = Duration::from_millis(100);

    /// Builds open shared state with no waiters and the given durable sequence number.
    fn open_state(last_durable_seq: u64) -> Arc<Mutex<State>> {
        Arc::new(Mutex::new(State {
            closed_reason: None,
            last_durable_seq,
            waiters: VecDeque::new(),
        }))
    }

    /// Builds a notifier over fresh open state, without a background task.
    fn test_notifier(last_durable_seq: u64) -> DurabilityNotifier {
        DurabilityNotifier {
            state: open_state(last_durable_seq),
        }
    }

    /// Returns a callback that forwards its result over a channel, plus the receiving end.
    fn forwarding_callback() -> (
        impl FnOnce(Outcome) + Send + 'static,
        mpsc::Receiver<Outcome>,
    ) {
        let (tx, rx) = mpsc::channel::<Outcome>();
        let callback = move |res: Outcome| {
            tx.send(res).expect("send callback result");
        };
        (callback, rx)
    }

    /// Appends a waiter directly to the queue and returns the receiver for its result.
    fn push_waiter(state: &Mutex<State>, durable_seq: u64) -> mpsc::Receiver<Outcome> {
        let (callback, rx) = forwarding_callback();
        state.lock().waiters.push_back(Waiter {
            durable_seq,
            callback: Box::new(callback),
        });
        rx
    }

    /// Lists the targets of all queued waiters in queue order.
    fn pending_targets(state: &Mutex<State>) -> Vec<u64> {
        let guard = state.lock();
        guard.waiters.iter().map(|w| w.durable_seq).collect()
    }

    #[test]
    fn subscribe_immediate_when_target_already_durable() {
        let notifier = test_notifier(42);
        let (callback, rx) = forwarding_callback();

        notifier.subscribe(7, callback);

        let res = rx
            .recv_timeout(CALLBACK_TIMEOUT)
            .expect("callback should run");
        assert_eq!(res.expect("durable result"), 42);
        assert!(notifier.state.lock().waiters.is_empty());
    }

    #[test]
    fn notify_waiters_releases_only_ready_targets() {
        let state = open_state(0);
        let rx5 = push_waiter(&state, 5);
        let rx8 = push_waiter(&state, 8);

        notify_waiters(&state, 5);

        let res5 = rx5
            .recv_timeout(CALLBACK_TIMEOUT)
            .expect("ready waiter should run");
        assert_eq!(res5.expect("durable result"), 5);
        assert!(matches!(rx8.try_recv(), Err(TryRecvError::Empty)));
        assert_eq!(pending_targets(&state), vec![8]);
    }

    #[test]
    fn waiters_are_kept_sorted_when_insertions_arrive_out_of_order() {
        let notifier = test_notifier(0);

        for target in [10, 5, 7] {
            notifier.subscribe(target, |_| {});
        }

        assert_eq!(pending_targets(&notifier.state), vec![5, 7, 10]);
    }

    #[test]
    fn close_with_reason_fails_all_pending_waiters() {
        let state = open_state(0);
        let rx5 = push_waiter(&state, 5);
        let rx8 = push_waiter(&state, 8);

        close_with_reason(&state, CloseReason::Clean);

        let err5 = rx5
            .recv_timeout(CALLBACK_TIMEOUT)
            .expect("callback should run")
            .expect_err("close should fail waiter");
        let err8 = rx8
            .recv_timeout(CALLBACK_TIMEOUT)
            .expect("callback should run")
            .expect_err("close should fail waiter");
        assert_eq!(err5, CloseReason::Clean);
        assert_eq!(err8, CloseReason::Clean);
        let state = state.lock();
        assert!(state.waiters.is_empty());
        assert_eq!(state.closed_reason, Some(CloseReason::Clean));
    }

    #[test]
    fn subscribe_after_close_returns_error_immediately() {
        let notifier = test_notifier(0);
        close_with_reason(&notifier.state, CloseReason::Clean);

        let (callback, rx) = forwarding_callback();
        notifier.subscribe(1, callback);

        let err = rx
            .recv_timeout(CALLBACK_TIMEOUT)
            .expect("callback should run")
            .expect_err("closed notifier should fail immediately");
        assert_eq!(err, CloseReason::Clean);
        assert!(notifier.state.lock().waiters.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn spawn_on_closed_db_fails_subscribers_immediately() {
        let object_store: Arc<dyn slatedb::object_store::ObjectStore> = Arc::new(InMemory::new());
        let db = Db::builder("test", object_store)
            .build()
            .await
            .expect("build test db");
        db.close().await.expect("close test db");

        let notifier = DurabilityNotifier::spawn(&db);
        let (callback, rx) = forwarding_callback();
        notifier.subscribe(0, callback);

        let err = rx
            .recv_timeout(CALLBACK_TIMEOUT)
            .expect("callback should run")
            .expect_err("closed db should fail immediately");
        assert_eq!(err, CloseReason::Clean);
        assert_eq!(
            notifier.state.lock().closed_reason,
            Some(CloseReason::Clean)
        );
    }
}