use {
crate::{
connect::lsp::{ClientId, ClientRegistry},
daemon::{
DaemonConfig, daemon_task::DaemonTask, idle_monitor::idle_monitor_task,
},
server::IpcServer,
},
concurrent_queue::ConcurrentQueue,
std::{
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::Duration,
},
};
mod gc;
pub mod key_watcher;
pub mod lanes;
pub mod task;
mod worker;
use {
crate::{
Ident, Partitions, TRACER,
connect::ipc::Connection,
database::{
Database, GenerationEpoch, chunk::RecordWriter, gc::GarbageCollector,
query::SortKeyCondition, reaper::Reaper,
},
progress::ProgressTracker,
protocol::{lsp::LanguageServer, task::RpcTask},
scheduler::{
key_watcher::{WatcherResult, dispatch_builtin_watcher},
lanes::{Lane, lane_priority},
task::{LaburnumTask, TaskContext},
},
},
std::{
collections::{BTreeMap, HashMap},
future::Future,
},
};
/// One-shot callback stored by `register_pending_redispatch`. In this file
/// every such thunk re-runs `spawn_watcher_task` for a watcher whose recorded
/// dependency was touched by a later commit.
type ReDispatch = Box<dyn FnOnce() + Send>;
/// Pending re-dispatch thunks, keyed by the partition `Ident` they depend on.
/// Each thunk is paired with the sort-key condition that a committed
/// (inserted or deleted) key in that partition must satisfy for it to fire.
type PendingRedispatch<P> =
HashMap<Ident, Vec<(SortKeyCondition<<P as Partitions>::SortKey>, ReDispatch)>>;
/// Watcher handler function pointer. It receives the task context and a write
/// context for the watcher's partition, both borrowed for the same lifetime
/// `'a`. It returns a boxed `Send` future that yields the watcher's result
/// (which carries, among other things, the follow-up tasks to queue).
type WatcherHandlerFn<P, T> = for<'a> fn(
&'a mut TaskContext<P, T>,
&'a mut crate::database::PartitionWriteContextRef<'a, P>,
) -> std::pin::Pin<
Box<dyn Future<Output = WatcherResult<P, T>> + Send + 'a>,
>;
/// Tunables for a `Scheduler`.
#[derive(Debug, Clone)]
pub struct SchedulerConfiguration {
  /// Response capacity handed to `RpcTask::create` for every RPC task.
  pub rpc_response_capacity: usize,
  /// Whether the periodic garbage-collection task is queued at startup.
  pub enable_periodic_gc: bool,
  /// Delay between the in-flight user work count dropping to zero and the
  /// scheduler re-checking it and reporting idle to the progress tracker.
  pub idle_debounce: Duration,
}
impl Default for SchedulerConfiguration {
  /// Default tunables: a response capacity of 100 and a 10 ms idle debounce.
  /// Periodic GC is enabled unless the crate is built with the `test`
  /// feature.
  fn default() -> Self {
    let enable_periodic_gc = !cfg!(feature = "test");
    let idle_debounce = Duration::from_millis(10);
    Self {
      rpc_response_capacity: 100,
      enable_periodic_gc,
      idle_debounce,
    }
  }
}
/// Multi-lane task scheduler. It owns the database, the worker threads and
/// the per-lane task queues, and it dispatches watcher tasks in response to
/// committed chunks (see `on_new_chunk`).
pub struct Scheduler<P: Partitions, T: LanguageServer<P>> {
/// The database owned by this scheduler.
db: Database<P>,
/// Connection used by the internal RPC task. Its sender is registered with
/// the progress tracker as `ClientId::INTERNAL`. In daemon mode
/// (`new_daemon`) this is a placeholder channel.
pub(crate) connection: Connection,
/// Shared filesystem list supplied at construction.
filesystems: Arc<parking_lot::RwLock<Vec<crate::fs::FS>>>,
/// Shared source cache. It is given a weak back-reference to this scheduler
/// on construction.
source_cache: Arc<parking_lot::RwLock<crate::SourceCache<P, T>>>,
/// One task queue per lane (31 lanes), indexed by `lane_priority(lane)`.
lane_queues: [ConcurrentQueue<Arc<LaburnumTask<P, T>>>; 31],
/// Serializes the RPC-lane rotation performed by `queue_rpc_task`.
rpc_rotation_lock: parking_lot::Mutex<()>,
/// Join handles of the running worker threads. Filled by `spawn_workers`;
/// drained and joined at shutdown.
worker_threads: parking_lot::RwLock<Vec<std::thread::JoinHandle<()>>>,
/// Number of worker threads started by `spawn_workers`.
worker_count: usize,
/// The language server implementation.
pub server: Arc<T>,
/// Stop signal polled by `run`/`run_daemon`; clones are also handed to
/// several of the system tasks they queue.
pub shutdown_flag: Arc<AtomicBool>,
/// Configuration supplied at construction.
config: SchedulerConfiguration,
/// Set by `request_shutdown` and read by `is_shutdown_requested`. This is a
/// separate flag from `shutdown_flag`.
shutdown_requested: Arc<AtomicBool>,
/// Progress reporting; receives `on_idle` once user work has drained.
pub(crate) progress_tracker: Arc<ProgressTracker>,
/// Number of user tasks started but not yet finished.
pub(crate) work_in_flight: AtomicUsize,
/// True while an idle-debounce task is queued and has not yet fired.
pub(crate) idle_debounce_armed: AtomicBool,
/// Re-dispatch thunks waiting for a matching commit, keyed by partition.
pub(crate) pending_redispatch: parking_lot::Mutex<PendingRedispatch<P>>,
/// Total number of thunks held in `pending_redispatch`. Commits use it to
/// skip taking the lock when nothing is pending.
pub(crate) pending_redispatch_count: AtomicUsize,
/// Client registry supplied by the caller of `new_daemon`, or a fresh one.
pub(crate) registry: Arc<ClientRegistry>,
/// Reaper built over the database's CAS stores.
pub(crate) reaper: Reaper<P>,
/// Garbage collector instance.
pub(crate) gc: GarbageCollector,
/// Reference counts of generation epochs currently in use. The smallest key
/// is the oldest running epoch.
pub(crate) active_epochs:
parking_lot::Mutex<BTreeMap<GenerationEpoch, usize>>,
}
impl<P: Partitions, T: LanguageServer<P>> Scheduler<P, T> {
/// Creates a scheduler with `SchedulerConfiguration::default()` and
/// `num_cpus::get() - 1` worker threads, never fewer than one.
pub fn new(
  connection: Connection,
  server: Arc<T>,
  filesystems: Arc<parking_lot::RwLock<Vec<crate::fs::FS>>>,
  source_cache: Arc<parking_lot::RwLock<crate::SourceCache<P, T>>>,
) -> Arc<Self> {
  // Reserve one CPU (presumably for the thread blocked in `run`), but always
  // keep at least one worker.
  let available = num_cpus::get();
  let workers = std::cmp::max(available.saturating_sub(1), 1);
  let config = SchedulerConfiguration::default();
  Self::new_with_config(
    connection,
    server,
    filesystems,
    source_cache,
    workers,
    config,
  )
}
/// Creates a scheduler with an explicit worker count and the default
/// configuration.
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) fn new_with_worker_count(
  connection: Connection,
  server: Arc<T>,
  filesystems: Arc<parking_lot::RwLock<Vec<crate::fs::FS>>>,
  source_cache: Arc<parking_lot::RwLock<crate::SourceCache<P, T>>>,
  worker_count: usize,
) -> Arc<Self> {
  let config = SchedulerConfiguration::default();
  Self::new_with_config(
    connection,
    server,
    filesystems,
    source_cache,
    worker_count,
    config,
  )
}
/// Creates a scheduler with an explicit worker count and configuration,
/// backed by a fresh, empty `ClientRegistry`.
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) fn new_with_config(
  connection: Connection,
  server: Arc<T>,
  filesystems: Arc<parking_lot::RwLock<Vec<crate::fs::FS>>>,
  source_cache: Arc<parking_lot::RwLock<crate::SourceCache<P, T>>>,
  worker_count: usize,
  config: SchedulerConfiguration,
) -> Arc<Self>
where
  T: crate::hooks::LaburnumHooks<P, T>,
{
  let fresh_registry = Arc::new(ClientRegistry::new());
  Self::new_inner(
    connection,
    server,
    filesystems,
    source_cache,
    worker_count,
    config,
    fresh_registry,
  )
}
/// Creates a scheduler for daemon mode that shares `registry` with the
/// caller.
///
/// There is no client connection at this point. The scheduler's own
/// `connection` is therefore a fresh unbounded channel whose sender and
/// receiver both live in the scheduler; the sender is also registered with
/// the progress tracker.
pub fn new_daemon(
  server: Arc<T>,
  filesystems: Arc<parking_lot::RwLock<Vec<crate::fs::FS>>>,
  source_cache: Arc<parking_lot::RwLock<crate::SourceCache<P, T>>>,
  worker_count: usize,
  config: SchedulerConfiguration,
  registry: Arc<ClientRegistry>,
) -> Arc<Self>
where
  T: crate::hooks::LaburnumHooks<P, T>,
{
  let (sender, receiver) = async_channel::unbounded();
  let placeholder = Connection { sender, receiver };
  Self::new_inner(
    placeholder,
    server,
    filesystems,
    source_cache,
    worker_count,
    config,
    registry,
  )
}
/// Shared constructor behind `new_with_config` and `new_daemon`.
///
/// It does the following:
/// - creates a fresh database, a reaper over the database's CAS stores, and
///   a garbage collector;
/// - registers `connection`'s sender with a new progress tracker as
///   `ClientId::INTERNAL`;
/// - gives the source cache a weak back-reference to the new scheduler.
fn new_inner(
  connection: Connection,
  server: Arc<T>,
  filesystems: Arc<parking_lot::RwLock<Vec<crate::fs::FS>>>,
  source_cache: Arc<parking_lot::RwLock<crate::SourceCache<P, T>>>,
  worker_count: usize,
  config: SchedulerConfiguration,
  registry: Arc<ClientRegistry>,
) -> Arc<Self>
where
  T: crate::hooks::LaburnumHooks<P, T>,
{
  otel::span!("laburnum.scheduler.new");
  let progress_tracker = Arc::new(ProgressTracker::new_disconnected());
  progress_tracker
    .register_client(ClientId::INTERNAL, connection.sender.clone());
  let db = Database::new();
  let reaper = Reaper::new(db.cas.stores_arc());
  let gc = GarbageCollector::new();
  // `connection`, `server` and `config` are owned and not used after this
  // point, so they are moved into the scheduler rather than cloned.
  let s = Arc::new(Self {
    db,
    connection,
    filesystems,
    source_cache,
    lane_queues: std::array::from_fn(|_| ConcurrentQueue::unbounded()),
    rpc_rotation_lock: parking_lot::Mutex::new(()),
    worker_threads: parking_lot::RwLock::new(Vec::new()),
    worker_count,
    server,
    shutdown_flag: Arc::new(AtomicBool::new(false)),
    config,
    shutdown_requested: Arc::new(AtomicBool::new(false)),
    progress_tracker,
    work_in_flight: AtomicUsize::new(0),
    idle_debounce_armed: AtomicBool::new(false),
    pending_redispatch: parking_lot::Mutex::new(HashMap::new()),
    pending_redispatch_count: AtomicUsize::new(0),
    registry,
    reaper,
    gc,
    active_epochs: parking_lot::Mutex::new(BTreeMap::new()),
  });
  // The scheduler holds a strong `Arc` to the cache. A weak pointer back
  // avoids a reference cycle that would keep both alive forever.
  s.source_cache.write().set_scheduler(Arc::downgrade(&s));
  s
}
/// Returns the client registry this scheduler was built with.
pub fn registry(&self) -> &Arc<ClientRegistry> {
  &self.registry
}
/// Returns the shared source cache. The cache holds a weak back-reference
/// to this scheduler.
pub(crate) fn source_cache(
  &self,
) -> &Arc<parking_lot::RwLock<crate::SourceCache<P, T>>> {
  let cache = &self.source_cache;
  cache
}
/// Sets the `shutdown_requested` flag.
///
/// This does not touch `shutdown_flag`, which is the flag `run` and
/// `run_daemon` poll.
pub fn request_shutdown(&self) {
  let flag = &self.shutdown_requested;
  flag.store(true, Ordering::Release);
}
/// Reports whether `request_shutdown` has been called.
pub fn is_shutdown_requested(&self) -> bool {
  let flag = &self.shutdown_requested;
  flag.load(Ordering::Acquire)
}
/// Builds, but does not queue, an RPC task that serves `connection` on
/// behalf of `client_id` using the given `shutdown_flag`.
pub(crate) fn create_rpc_task_for_client(
  self: &Arc<Self>,
  connection: Connection,
  client_id: ClientId,
  shutdown_flag: Arc<AtomicBool>,
) -> Arc<LaburnumTask<P, T>> {
  let scheduler = Arc::clone(self);
  let server = Arc::clone(&self.server);
  let capacity = self.config.rpc_response_capacity;
  RpcTask::create(
    scheduler,
    connection,
    client_id,
    server,
    shutdown_flag,
    capacity,
  )
}
/// Creates an RPC task for `client_id` and places it on the RPC lanes.
pub fn queue_client_rpc_task(
  self: &Arc<Self>,
  connection: Connection,
  client_id: ClientId,
  shutdown_flag: Arc<AtomicBool>,
) {
  self.queue_rpc_task(self.create_rpc_task_for_client(
    connection,
    client_id,
    shutdown_flag,
  ));
}
/// Returns the scheduler's progress tracker.
pub fn progress_tracker(&self) -> &Arc<ProgressTracker> {
  &self.progress_tracker
}
/// Records the start of a user task. Each call is expected to be balanced
/// by a later `user_task_finished`.
pub(crate) fn user_task_started(&self) {
  let counter = &self.work_in_flight;
  counter.fetch_add(1, Ordering::AcqRel);
}
/// Records the end of a user task. If it was the last one in flight, this
/// arms the idle debounce.
pub(crate) fn user_task_finished(self: &Arc<Self>) {
  let before = self.work_in_flight.fetch_sub(1, Ordering::AcqRel);
  let became_idle = before == 1;
  if became_idle {
    self.arm_idle_debounce();
  }
}
/// Queues a one-shot idle-lane task that waits `config.idle_debounce` and
/// then clears the armed flag. If no user work is in flight at that moment,
/// it also calls `progress_tracker.on_idle()`.
///
/// At most one such task is pending at a time. If the flag is already armed,
/// this returns without queueing anything.
fn arm_idle_debounce(self: &Arc<Self>) {
  let newly_armed = self
    .idle_debounce_armed
    .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
    .is_ok();
  if !newly_armed {
    return;
  }
  let debounce = self.config.idle_debounce;
  let watcher = Arc::clone(self);
  let task = LaburnumTask::new_system_with_parent(
    Arc::clone(self),
    move |_ctx| {
      Box::pin(async move {
        smol::Timer::after(debounce).await;
        watcher.idle_debounce_armed.store(false, Ordering::Release);
        let still_idle = watcher.work_in_flight.load(Ordering::Acquire) == 0;
        if still_idle {
          watcher.progress_tracker.on_idle();
        }
        None
      })
    },
    lanes::IDLE_LANE,
    None,
    ClientId::INTERNAL,
  );
  self.queue_task(task);
}
/// Registers `thunk` to run once a commit inserts or deletes a key in
/// `partition` whose sort key satisfies `condition`.
///
/// The pending counter is incremented while the registry lock is still
/// held, so the counter never lags behind the map. Otherwise a concurrent
/// `fire_pending_redispatch` could remove this entry, and subtract it from
/// the counter, before it had been counted, wrapping the counter below zero.
pub(crate) fn register_pending_redispatch(
  &self,
  partition: Ident,
  condition: SortKeyCondition<P::SortKey>,
  thunk: ReDispatch,
) {
  let mut registry = self.pending_redispatch.lock();
  registry.entry(partition).or_default().push((condition, thunk));
  self.pending_redispatch_count.fetch_add(1, Ordering::AcqRel);
}
/// Runs every pending re-dispatch thunk whose condition matches a key that
/// `result` inserted or deleted in the thunk's partition.
///
/// Matching thunks are removed from the registry; non-matching ones stay
/// registered. Thunks are invoked only after the registry lock has been
/// released. The lock is a non-reentrant `parking_lot::Mutex`, so a thunk
/// can then safely call back into the scheduler.
fn fire_pending_redispatch(&self, result: &crate::database::CommitResult<P>) {
  // Fast path: nothing is registered, so there is nothing to match.
  if self.pending_redispatch_count.load(Ordering::Acquire) == 0 {
    return;
  }
  let mut to_fire: Vec<ReDispatch> = Vec::new();
  {
    let mut registry = self.pending_redispatch.lock();
    for pk in result.affected_partition_keys() {
      let Some(entries) = registry.remove(&pk) else {
        continue;
      };
      // The changed keys depend only on the partition, so look them up once
      // per partition rather than once per registered entry.
      let inserted = result.inserted_keys.get(&pk);
      let deleted = result.deleted_keys.get(&pk);
      let (matched, kept): (Vec<_>, Vec<_>) =
        entries.into_iter().partition(|(condition, _)| {
          inserted
            .into_iter()
            .flat_map(|keys| keys.iter())
            .chain(deleted.into_iter().flat_map(|keys| keys.iter()))
            .any(|rk| condition.matches(rk.sort_key()))
        });
      to_fire.extend(matched.into_iter().map(|(_, thunk)| thunk));
      if !kept.is_empty() {
        registry.insert(pk, kept);
      }
    }
    // Decrement while still holding the lock so the counter stays in
    // lockstep with the map.
    if !to_fire.is_empty() {
      self
        .pending_redispatch_count
        .fetch_sub(to_fire.len(), Ordering::AcqRel);
    }
  }
  for thunk in to_fire {
    thunk();
  }
}
/// Runs the scheduler as a daemon on the calling thread until
/// `shutdown_flag` is set.
///
/// Queues, in this order:
/// 1. an idle monitor, only when `config.idle_timeout` is set;
/// 2. the daemon task serving `ipc_server`;
/// 3. the periodic GC task, when enabled in the scheduler configuration.
///
/// It then spawns the workers and polls `shutdown_flag` every 100 ms. On
/// shutdown it wakes and joins the workers, logging any worker that
/// panicked.
pub fn run_daemon(
  self: &Arc<Self>,
  ipc_server: IpcServer,
  config: DaemonConfig,
) where
  T: crate::hooks::LaburnumHooks<P, T>,
{
  otel::span!("laburnum.scheduler.run_daemon");
  let idle_triggered = Arc::new(AtomicBool::new(false));
  if let Some(idle_timeout) = config.idle_timeout {
    let monitor = LaburnumTask::new_system_with_parent(
      Arc::clone(self),
      idle_monitor_task(
        Arc::clone(&self.shutdown_flag),
        Arc::clone(&idle_triggered),
        idle_timeout,
      ),
      lanes::IDLE_LANE,
      None,
      ClientId::INTERNAL,
    );
    self.queue_task(monitor);
  }
  let daemon =
    DaemonTask::create(Arc::clone(self), ipc_server, config, idle_triggered);
  self.queue_task(daemon);
  if self.config.enable_periodic_gc {
    let gc_task = LaburnumTask::new_system_with_parent(
      Arc::clone(self),
      gc::periodic_gc_task(Arc::clone(&self.shutdown_flag)),
      lanes::IDLE_LANE,
      None,
      ClientId::INTERNAL,
    );
    self.queue_task(gc_task);
  }
  self.spawn_workers();
  let poll_interval = Duration::from_millis(100);
  loop {
    if self.shutdown_flag.load(Ordering::Acquire) {
      break;
    }
    std::thread::park_timeout(poll_interval);
  }
  self.notify_workers();
  let handles = std::mem::take(&mut *self.worker_threads.write());
  for handle in handles {
    let Err(panic_payload) = handle.join() else {
      continue;
    };
    otel::error!(
      "worker_thread_join_failed",
      format!("Failed to join worker thread: {:?}", panic_payload)
    );
  }
}
/// Starts `worker_count` worker threads, each carrying the caller's current
/// trace context. The resulting join handles replace `worker_threads`.
pub fn spawn_workers(self: &Arc<Self>) {
  let trace_context =
    crate::protocol::otel::TraceContext::from_current_span();
  let spawned: Vec<_> = (0..self.worker_count)
    .map(|id| {
      worker::Worker::spawn(id, Arc::clone(self), trace_context.clone())
    })
    .collect();
  *self.worker_threads.write() = spawned;
}
/// Wraps `task_fn` in an internal (`ClientId::INTERNAL`) task and queues it
/// on `lane`.
pub fn queue<F, Fut>(self: &Arc<Self>, task_fn: F, lane: Lane)
where
  F: FnOnce(TaskContext<P, T>) -> Fut + Send + 'static,
  Fut: Future<Output = Option<RecordWriter<P>>> + Send + 'static,
{
  let task =
    LaburnumTask::new(Arc::clone(self), task_fn, lane, ClientId::INTERNAL);
  self.queue_task(task);
}
/// Pushes `task` onto the queue for its lane and wakes the workers.
///
/// The lane index is `lane_priority(task.lane)`. If that lane's queue
/// rejects the task, successively lower-indexed lanes are tried, down to
/// lane 0, and each rejection is logged. An index outside `lane_queues` is
/// logged, and the task is dropped without waking the workers.
pub(crate) fn queue_task(&self, task: Arc<LaburnumTask<P, T>>) {
  let requested = lane_priority(task.lane) as usize;
  // Valid indices are `0..lane_queues.len()`. Anything at or above the
  // length has no queue.
  if requested >= self.lane_queues.len() {
    otel::error!(
      "scheduler.lane_out_of_bounds",
      "unable to push task onto queue: lane out of bounds",
      "lane_idx" = requested as i64
    );
    return;
  }
  for lane_idx in (0..=requested).rev() {
    match self.lane_queues[lane_idx].push(task.clone()) {
      | Ok(_) => break,
      | Err(_) if lane_idx == 0 => {
        otel::error!(
          "scheduler.lowest_lane_push_failed",
          "lane queue is full"
        );
      },
      | Err(_) => {
        otel::error!(
          "scheduler.lane_push_failed",
          "lane queue is full",
          "lane_idx" = lane_idx as i64
        );
      },
    }
  }
  self.notify_workers();
}
pub(crate) fn queue_rpc_task(&self, task: Arc<LaburnumTask<P, T>>) {
use lanes::{RPC_LANE_HIGH_IDX, RPC_LANE_LOW_IDX};
let _guard = self.rpc_rotation_lock.lock();
for to_idx in RPC_LANE_HIGH_IDX..RPC_LANE_LOW_IDX {
let from_idx = to_idx + 1;
while let Ok(t) = self.lane_queues[from_idx].pop() {
let _ = self.lane_queues[to_idx].push(t);
}
}
let _ = self.lane_queues[RPC_LANE_LOW_IDX].push(task);
self.notify_workers();
}
/// Queues the startup tasks for `run`:
/// - the internal RPC task serving the scheduler's own connection;
/// - the periodic GC task, when enabled.
pub(crate) fn add_initial_tasks(self: &Arc<Self>)
where
  T: crate::hooks::LaburnumHooks<P, T>,
{
  let rpc_task = RpcTask::create(
    Arc::clone(self),
    self.connection.clone(),
    ClientId::INTERNAL,
    Arc::clone(&self.server),
    Arc::clone(&self.shutdown_flag),
    self.config.rpc_response_capacity,
  );
  self.queue_task(rpc_task);
  if !self.config.enable_periodic_gc {
    return;
  }
  let gc_task = LaburnumTask::new_system_with_parent(
    Arc::clone(self),
    gc::periodic_gc_task(Arc::clone(&self.shutdown_flag)),
    lanes::IDLE_LANE,
    None,
    ClientId::INTERNAL,
  );
  self.queue_task(gc_task);
}
/// Runs the scheduler on the calling thread until `shutdown_flag` is set.
///
/// Queues the startup tasks (see `add_initial_tasks`), spawns the workers,
/// and polls `shutdown_flag` every 100 ms. Once the flag is set, it wakes
/// and joins the workers, logging any worker that panicked.
pub fn run(self: &Arc<Self>)
where
  T: crate::hooks::LaburnumHooks<P, T>,
{
  otel::span!("laburnum.scheduler.run");
  self.add_initial_tasks();
  self.spawn_workers();
  let poll_interval = Duration::from_millis(100);
  loop {
    if self.shutdown_flag.load(Ordering::Acquire) {
      break;
    }
    std::thread::park_timeout(poll_interval);
  }
  self.notify_workers();
  let handles = std::mem::take(&mut *self.worker_threads.write());
  for handle in handles {
    let Err(panic_payload) = handle.join() else {
      continue;
    };
    otel::error!(
      "worker_thread_join_failed",
      format!("Failed to join worker thread: {:?}", panic_payload)
    );
  }
}
}
/// Queues a `DEFAULT_LANE` watcher task. The task runs `handler_fn` for
/// partition `task_pk`, with `updated`/`deleted` exposed to the handler as
/// the matched keys.
///
/// After the handler completes:
/// - each follow-up in its result is queued as a separate internal task
///   that writes to `task_pk`;
/// - each pending dependency the handler recorded on its context is passed
///   to `register_pending_redispatch`, so a later matching commit re-runs
///   this same watcher with the same keys and parent.
///
/// The writer the handler wrote into is returned as the task's output.
fn spawn_watcher_task<P: Partitions, T: LanguageServer<P>>(
  scheduler: Arc<Scheduler<P, T>>,
  task_pk: crate::Ident,
  updated: Vec<crate::database::RecordKey<P>>,
  deleted: Vec<crate::database::RecordKey<P>>,
  handler_fn: WatcherHandlerFn<P, T>,
  parent_task_id: Option<crate::Ident>,
) {
  let body_scheduler = scheduler.clone();
  let task = LaburnumTask::new_with_parent(
    scheduler.clone(),
    move |mut ctx: TaskContext<P, T>| {
      let scheduler = body_scheduler;
      async move {
        // Cloned because the keys are reused below if the handler records
        // dependencies that need a re-dispatch.
        ctx.set_matched_keys(updated.clone(), deleted.clone());
        let mut writer = RecordWriter::new(task_pk);
        let mut writer_ctx =
          crate::database::PartitionWriteContextRef::new_for_watcher(
            &mut writer,
            task_pk,
          );
        let result = handler_fn(&mut ctx, &mut writer_ctx).await;
        for follow_up in result.follow_ups {
          scheduler.queue_task(LaburnumTask::new(
            scheduler.clone(),
            move |mut ctx| async move {
              let mut writer = RecordWriter::new(task_pk);
              let mut writer_ctx =
                crate::database::PartitionWriteContextRef::new_for_watcher(
                  &mut writer,
                  task_pk,
                );
              (follow_up.task_fn)(&mut ctx, &mut writer_ctx).await;
              Some(writer)
            },
            follow_up.lane,
            ClientId::INTERNAL,
          ));
        }
        for (dep_pk, condition) in ctx.take_pending_deps() {
          let sched = scheduler.clone();
          let updated = updated.clone();
          let deleted = deleted.clone();
          scheduler.register_pending_redispatch(
            dep_pk,
            condition,
            Box::new(move || {
              spawn_watcher_task(
                sched,
                task_pk,
                updated,
                deleted,
                handler_fn,
                parent_task_id,
              );
            }),
          );
        }
        Some(writer)
      }
    },
    lanes::DEFAULT_LANE,
    parent_task_id,
    ClientId::INTERNAL,
  );
  scheduler.queue_task(task);
}
impl<P: Partitions, T: LanguageServer<P>> Scheduler<P, T> {
/// Returns a new shared handle to the language server.
pub fn server(&self) -> Arc<T> {
  Arc::clone(&self.server)
}
/// Handles a commit produced by task `task_id`.
///
/// First fires any pending re-dispatches the commit satisfies. Then, for
/// each affected partition, it hands that partition's inserted and deleted
/// keys to two sets of watchers:
/// - the built-in watchers, whose tasks are parented to `task_id`;
/// - the server's watchers via `T::dispatch_watcher`, whose tasks have no
///   parent.
pub(crate) fn on_new_chunk(
  self: &Arc<Self>,
  task_id: crate::Ident,
  result: crate::database::CommitResult<P>,
) {
  let inserted_count = result
    .inserted_keys
    .values()
    .fold(0usize, |total, keys| total + keys.len());
  let deleted_count = result
    .deleted_keys
    .values()
    .fold(0usize, |total, keys| total + keys.len());
  otel::span!(
    "laburnum.scheduler.on_new_chunk",
    "inserted_keys.count" = inserted_count as i64,
    "deleted_keys.count" = deleted_count as i64
  );
  self.fire_pending_redispatch(&result);
  for pk in result.affected_partition_keys() {
    let updated: Vec<crate::database::RecordKey<P>> =
      match result.inserted_keys.get(&pk) {
        | Some(keys) => keys.clone(),
        | None => Vec::new(),
      };
    let deleted: Vec<crate::database::RecordKey<P>> =
      match result.deleted_keys.get(&pk) {
        | Some(keys) => keys.clone(),
        | None => Vec::new(),
      };
    let builtin_scheduler = Arc::clone(self);
    dispatch_builtin_watcher(
      pk,
      updated.clone(),
      deleted.clone(),
      move |task_pk, kept_updated, kept_deleted, handler_fn| {
        spawn_watcher_task(
          Arc::clone(&builtin_scheduler),
          task_pk,
          kept_updated,
          kept_deleted,
          handler_fn,
          Some(task_id),
        );
      },
    );
    let server_scheduler = Arc::clone(self);
    T::dispatch_watcher(
      pk,
      updated,
      deleted,
      move |task_pk, kept_updated, kept_deleted, handler_fn| {
        spawn_watcher_task(
          Arc::clone(&server_scheduler),
          task_pk,
          kept_updated,
          kept_deleted,
          handler_fn,
          None,
        );
      },
    );
  }
}
/// Increments the reference count of `epoch`, starting it at 1 if it is not
/// yet tracked.
pub(crate) fn register_active_epoch(&self, epoch: GenerationEpoch) {
  self
    .active_epochs
    .lock()
    .entry(epoch)
    .and_modify(|count| *count += 1)
    .or_insert(1);
}
/// Decrements the reference count of `epoch` and drops the entry when the
/// count reaches zero. Epochs that are not tracked are ignored.
pub(crate) fn deregister_active_epoch(&self, epoch: GenerationEpoch) {
  let mut epochs = self.active_epochs.lock();
  let Some(count) = epochs.get_mut(&epoch) else {
    return;
  };
  *count -= 1;
  if *count == 0 {
    epochs.remove(&epoch);
  }
}
/// Returns the smallest epoch still registered as active. When none are
/// registered, returns the database's current epoch. The epoch lock is held
/// for the whole call, including the fallback lookup.
pub(crate) fn oldest_running_epoch(&self) -> GenerationEpoch {
  let epochs = self.active_epochs.lock();
  match epochs.first_key_value() {
    | Some((&oldest, _)) => oldest,
    | None => self.db.get_current_epoch(),
  }
}
/// Unparks every registered worker thread so it re-checks its queues and
/// the shutdown flag.
fn notify_workers(&self) {
  let workers = self.worker_threads.read();
  for handle in workers.iter() {
    handle.thread().unpark();
  }
}
}
#[cfg(test)]
pub mod tests;