use {
crate::{
Ident,
Partitions,
TRACER,
connect::{
ipc::Connection,
lsp::{
ClientId,
notification::Notification,
request::Request,
},
},
database::{
Database,
GenerationEpoch,
RecordKey,
chunk::RecordWriter,
query::QueryClient,
},
progress::{
Progress,
ProgressBuilder,
ProgressTracker,
},
protocol::{
jsonrpc,
lsp::LanguageServer,
otel::TraceContext,
},
scheduler::{
Scheduler,
lanes::Lane,
},
},
parking_lot::Mutex,
std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::{
Arc,
atomic::{
AtomicU64,
Ordering,
},
},
task::{
Context,
Poll,
},
},
};
/// Process-wide counter that feeds [`generate_execution_id`].
static TASK_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Produces the execution id for a newly constructed task.
///
/// Each call consumes the next value of a process-wide counter, so no two
/// calls receive the same counter value (until the `u64` wraps). `Relaxed`
/// ordering suffices: only the uniqueness of the fetched value matters, not
/// its ordering relative to other memory operations.
fn generate_execution_id() -> Ident {
  Ident::from_hash(TASK_COUNTER.fetch_add(1, Ordering::Relaxed))
}
/// Boxed, type-erased future driven by a [`LaburnumTask`].
///
/// It resolves to the record writer produced by the task, or to `None` when
/// the task has nothing to commit.
type TaskFuture<P> =
  Pin<Box<dyn Future<Output = Option<RecordWriter<P>>> + Send>>;

/// A schedulable unit of work: a future plus the bookkeeping the scheduler
/// needs to poll it, wake it, and account for its lifetime.
///
/// Field order is significant for drop order; keep it stable.
pub(crate) struct LaburnumTask<P: Partitions, T: LanguageServer<P>> {
  /// The task's future. It is taken out while being polled and is put back
  /// only if it is still pending, so it is `None` once the task has finished.
  future: Mutex<Option<TaskFuture<P>>>,
  /// Scheduling lane. Waking uses it to choose between the RPC queue and the
  /// regular task queue.
  pub(crate) lane: Lane,
  /// Scheduler that owns the database, queues, and trackers this task uses.
  pub(crate) scheduler: Arc<Scheduler<P, T>>,
  /// Trace context captured when the task was created; attached on each poll.
  trace_context: TraceContext,
  /// Id reported to the progress tracker when the task completes.
  execution_id: Ident,
  /// Database epoch registered with the scheduler on the first poll and
  /// deregistered when the task completes.
  birth_epoch: Mutex<Option<GenerationEpoch>>,
  _phantom: PhantomData<T>,
}
impl<P: Partitions, T: LanguageServer<P>> LaburnumTask<P, T> {
pub(crate) fn poll_once(self: &Arc<Self>) -> Poll<()> {
let _guard = self.trace_context.attach();
{
let mut epoch_slot = self.birth_epoch.lock();
if epoch_slot.is_none() {
let epoch = self.scheduler.db.get_current_epoch();
*epoch_slot = Some(epoch);
self.scheduler.register_active_epoch(epoch);
}
}
let mut future_slot = self.future.lock();
if let Some(mut future) = future_slot.take() {
let waker = futures::task::waker_ref(self);
let mut cx = Context::from_waker(&waker);
match future.as_mut().poll(&mut cx) {
| Poll::Ready(Some(chunk_builder)) => {
let mut chunk = chunk_builder.build();
if !chunk.is_empty() || !chunk.clear_prefixes().is_empty() || chunk.has_pending_sources() {
otel::span!("laburnum.task.complete", in |cx|{
let task_id = chunk.task_id();
let db = self.scheduler.db.clone();
let error_source_keys = chunk.error_diagnostic_source_keys();
let pending_sources = chunk.take_sources();
let source_cache_reader = self.scheduler.source_cache.read().reader();
let mut result = db.commit_chunk(chunk, &source_cache_reader);
let deferred_decrements = std::mem::take(&mut result.deferred_decrements);
if !deferred_decrements.is_empty() {
self.scheduler.reaper.queue_decrements(deferred_decrements);
}
if self.scheduler.gc.is_marking() {
self.scheduler.gc.add_to_gray(result.new_hashes.iter().copied());
}
if !pending_sources.is_empty() {
let mut source_cache = self.scheduler.source_cache.write();
for (source_key, source) in pending_sources {
source_cache.complete_version(source_key, source);
}
}
if !error_source_keys.is_empty() {
let mut source_cache = self.scheduler.source_cache.write();
for source_key in &error_source_keys {
source_cache.set_has_errors(*source_key, true);
}
}
if result.has_changes() {
self.scheduler.on_new_chunk(task_id, result);
}
});
}
self.deregister_epoch();
self
.scheduler
.progress_tracker
.on_task_complete(self.execution_id);
Poll::Ready(())
},
| Poll::Ready(None) => {
self.deregister_epoch();
self
.scheduler
.progress_tracker
.on_task_complete(self.execution_id);
Poll::Ready(())
},
| Poll::Pending => {
*future_slot = Some(future);
Poll::Pending
},
}
} else {
Poll::Ready(())
}
}
fn deregister_epoch(&self) {
if let Some(epoch) = self.birth_epoch.lock().take() {
self.scheduler.deregister_active_epoch(epoch);
}
}
pub(crate) fn new<F, Fut>(
scheduler: Arc<Scheduler<P, T>>,
task_fn: F,
lane: Lane,
client_id: ClientId,
) -> Arc<Self>
where
F: FnOnce(TaskContext<P, T>) -> Fut + Send + 'static,
Fut: Future<Output = Option<RecordWriter<P>>> + Send + 'static,
{
Self::new_with_parent(scheduler, task_fn, lane, None, client_id)
}
pub(crate) fn new_with_parent<F, Fut>(
scheduler: Arc<Scheduler<P, T>>,
task_fn: F,
lane: Lane,
parent_task_id: Option<crate::Ident>,
client_id: ClientId,
) -> Arc<Self>
where
F: FnOnce(TaskContext<P, T>) -> Fut + Send + 'static,
Fut: Future<Output = Option<RecordWriter<P>>> + Send + 'static,
{
let trace_context = TraceContext::from_current_span();
let execution_id = generate_execution_id();
let ctx = TaskContext::new(
scheduler.db.clone(),
scheduler.clone(),
client_id,
)
.with_execution_id(execution_id);
let ctx = if let Some(parent_id) = parent_task_id {
ctx.with_task_id(parent_id)
} else {
ctx
};
Arc::new(Self {
future: Mutex::new(Some(Box::pin(task_fn(ctx)))),
lane,
scheduler,
trace_context,
execution_id,
birth_epoch: Mutex::new(None),
_phantom: PhantomData,
})
}
}
impl<P: Partitions, T: LanguageServer<P>> futures::task::ArcWake
  for LaburnumTask<P, T>
{
  /// Re-queues the woken task on its scheduler.
  ///
  /// Tasks on an RPC lane go to the scheduler's RPC queue; all other tasks
  /// go to the regular task queue.
  fn wake_by_ref(arc_self: &Arc<Self>) {
    let to_rpc_queue = crate::scheduler::lanes::is_rpc_lane(arc_self.lane);
    let scheduler = &arc_self.scheduler;
    let task = Arc::clone(arc_self);
    if to_rpc_queue {
      scheduler.queue_rpc_task(task);
    } else {
      scheduler.queue_task(task);
    }
  }
}
/// Handle given to a task's closure. It exposes the database, the scheduler,
/// the client connection, and helpers for progress reporting, notifications,
/// and spawning follow-up tasks.
///
/// Field order determines drop order; keep it stable.
pub struct TaskContext<P: Partitions, T: LanguageServer<P>> {
  /// Database handle; each query client handed out is built on a clone of it.
  db: Database<P>,
  /// Scheduler that created this context.
  scheduler: Arc<Scheduler<P, T>>,
  /// Connection cloned from the scheduler at construction. Used for
  /// notifications to the INTERNAL client or to no explicit client, and for
  /// fire-and-forget requests.
  connection: Connection,
  /// Query clients handed out by `query_client`; they are retained until the
  /// context is dropped.
  query_clients: Vec<QueryClient<P>>,
  /// Keys recorded by `set_matched_keys`; drained by `matched_keys_updated`.
  matched_keys_updated: Vec<RecordKey>,
  /// Keys recorded by `set_matched_keys`; drained by `matched_keys_deleted`.
  matched_keys_deleted: Vec<RecordKey>,
  /// Task id stamped as the parent on new record writers and inherited as the
  /// parent id by spawned tasks.
  task_id: Option<Ident>,
  /// Execution id of the owning task. It is the placeholder `"unset"` until
  /// `with_execution_id` is called.
  execution_id: Ident,
  /// Client this task acts on behalf of.
  client_id: ClientId,
  _phantom: PhantomData<T>,
}
impl<P: Partitions, T: LanguageServer<P>> TaskContext<P, T> {
pub(crate) fn new(
db: Database<P>,
scheduler: Arc<Scheduler<P, T>>,
client_id: ClientId,
) -> Self {
let connection = scheduler.connection.clone();
Self {
db,
scheduler,
connection,
query_clients: Vec::new(),
matched_keys_updated: Vec::new(),
matched_keys_deleted: Vec::new(),
task_id: None,
execution_id: Ident::new("unset"),
client_id,
_phantom: PhantomData,
}
}
pub(crate) fn with_execution_id(mut self, execution_id: Ident) -> Self {
self.execution_id = execution_id;
self
}
pub(crate) fn with_task_id(mut self, task_id: Ident) -> Self {
self.task_id = Some(task_id);
self
}
pub fn client_id(&self) -> ClientId {
self.client_id
}
pub fn position_encoding(&self) -> crate::protocol::lsp::PositionEncodingKind {
self
.scheduler
.registry()
.get(self.client_id)
.map(|c| c.position_encoding().clone())
.unwrap_or(crate::protocol::lsp::PositionEncodingKind::DEFAULT)
}
pub(crate) fn new_record_writer(&self, task_id: Ident) -> RecordWriter<P> {
RecordWriter::new(task_id).with_parent_task_id(self.task_id)
}
pub fn query_client(&mut self) -> &mut QueryClient<P> {
let client = QueryClient::new(self.db.clone());
self.query_clients.push(client);
let index = self.query_clients.len() - 1;
&mut self.query_clients[index]
}
pub(crate) fn spawn_task<F, Fut>(
&self,
task_fn: F,
lane: Lane,
) -> Arc<LaburnumTask<P, T>>
where
F: FnOnce(TaskContext<P, T>) -> Fut + Send + 'static,
Fut: std::future::Future<
Output = Option<crate::database::chunk::RecordWriter<P>>,
> + Send
+ 'static,
{
self.spawn_task_for_client(task_fn, lane, self.client_id)
}
pub(crate) fn spawn_task_for_client<F, Fut>(
&self,
task_fn: F,
lane: Lane,
client_id: ClientId,
) -> Arc<LaburnumTask<P, T>>
where
F: FnOnce(TaskContext<P, T>) -> Fut + Send + 'static,
Fut: std::future::Future<
Output = Option<crate::database::chunk::RecordWriter<P>>,
> + Send
+ 'static,
{
let parent_task_id = self.task_id;
let task = LaburnumTask::new_with_parent(
self.scheduler.clone(),
task_fn,
lane,
parent_task_id,
client_id,
);
self.scheduler.queue_task(task.clone());
task
}
pub fn filesystems(&self) -> Arc<parking_lot::RwLock<Vec<crate::fs::FS>>> {
self.scheduler.filesystems.clone()
}
pub fn source_cache(
&self,
) -> Arc<parking_lot::RwLock<crate::SourceCache<P, T>>> {
self.scheduler.source_cache.clone()
}
pub fn source_cache_reader(
&self,
) -> crate::source::cache::reporter::SourceCacheReader {
let guard = self.scheduler.source_cache.read();
guard.reader()
}
pub fn scheduler(&self) -> Arc<Scheduler<P, T>> {
self.scheduler.clone()
}
pub fn server(&self) -> Arc<T> {
self.scheduler.server.clone()
}
pub async fn send_notification<N: Notification>(
&self,
params: N::Params,
client_id: Option<ClientId>,
) {
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
let notification = jsonrpc::Notification::build(N::METHOD)
.params(params_value)
.finish();
let message = jsonrpc::Message::Notification(notification);
if let Some(id) = client_id {
if id == ClientId::INTERNAL {
if let Err(e) = self.connection.sender.send(message).await {
otel::error!(
"notification_send_failed",
format!("Failed to send notification {} to INTERNAL client: {:?}", N::METHOD, e)
);
}
} else {
if let Err(e) = self.scheduler.registry().send_to(id, message).await {
otel::error!(
"notification_send_failed",
format!("Failed to send notification {} to client {:?}: {:?}", N::METHOD, id, e)
);
}
}
} else {
if let Err(e) = self.connection.sender.send(message).await {
otel::error!(
"notification_send_failed",
format!("Failed to send notification {}: {:?}", N::METHOD, e)
);
}
}
}
pub async fn broadcast_notification<N: Notification>(
&self,
topic: impl crate::connect::lsp::Topic,
params: N::Params,
) {
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
let notification = jsonrpc::Notification::build(N::METHOD)
.params(params_value)
.finish();
self
.scheduler
.registry()
.broadcast(topic, notification)
.await;
}
pub(crate) async fn send_request_fire_and_forget<R: Request>(
&self,
params: R::Params,
) {
use std::sync::atomic::{
AtomicI64,
Ordering,
};
static REQUEST_ID_COUNTER: AtomicI64 = AtomicI64::new(0);
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
let id =
jsonrpc::Id::Number(REQUEST_ID_COUNTER.fetch_add(1, Ordering::Relaxed));
let request = jsonrpc::Request::build(R::METHOD, id)
.params(params_value)
.finish();
let message = jsonrpc::Message::Request(request);
let _ = self.connection.sender.send(message).await;
}
pub fn progress_tracker(&self) -> Arc<ProgressTracker> {
self.scheduler.progress_tracker.clone()
}
pub fn progress(&self, progress_id: Ident, title: &str) -> ProgressBuilder {
ProgressBuilder::new(
self.scheduler.progress_tracker.clone(),
progress_id,
title.to_string(),
self.execution_id,
)
}
pub fn get_progress(&self, progress_id: Ident) -> Option<Progress> {
if self.scheduler.progress_tracker.has_progress(progress_id) {
Some(Progress::new(
self.scheduler.progress_tracker.clone(),
progress_id,
self.execution_id,
))
} else {
None
}
}
pub fn matched_keys_updated(&mut self) -> Vec<RecordKey> {
std::mem::take(&mut self.matched_keys_updated)
}
pub fn matched_keys_deleted(&mut self) -> Vec<RecordKey> {
std::mem::take(&mut self.matched_keys_deleted)
}
pub(crate) fn set_matched_keys(
&mut self,
updated: Vec<RecordKey>,
deleted: Vec<RecordKey>,
) {
self.matched_keys_updated = updated;
self.matched_keys_deleted = deleted;
}
pub fn check_source_available(
&self,
source_key: crate::SourceKey,
) -> Result<(), crate::LaburnumError> {
let source_cache_reader = self.source_cache_reader();
if source_cache_reader.get_source(source_key).is_none() {
Err(crate::LaburnumError::TaskCancelled)
} else {
Ok(())
}
}
pub fn client_count(&self) -> usize {
self.scheduler.registry().client_count()
}
pub fn client_ids(&self) -> Vec<crate::connect::lsp::ClientId> {
self.scheduler.registry().client_ids()
}
pub fn idle_duration(&self) -> Option<std::time::Duration> {
self.scheduler.registry().idle_duration()
}
pub fn connected_clients(&self) -> Vec<crate::connect::lsp::ConnectedClientInfo> {
self.scheduler.registry().connected_clients()
}
pub fn request_shutdown(&self) {
self.scheduler.request_shutdown();
}
}
/// Unit tests for `TaskContext` helpers that can run without a live task.
#[cfg(test)]
mod tests {
  use {
    super::*,
    crate::{
      connect::lsp::ClientId,
      database::tests::storage::TestPartitions,
      protocol::lsp::LspVersion,
      server::LaburnumLanguageServer,
    },
  };
  /// Non-INTERNAL client id used as the upserting client in source-cache tests.
  fn test_client_id() -> ClientId {
    ClientId::from_raw(1)
  }
  /// Builds a `TaskContext` for the INTERNAL client.
  ///
  /// It uses a fresh database, an in-memory connection, an empty filesystem
  /// list, a new source cache, and a scheduler with a single worker. The
  /// client end of the memory connection is dropped when this function
  /// returns; none of these tests send messages.
  fn create_test_context() -> TaskContext<TestPartitions, LaburnumLanguageServer> {
    use crate::{
      connect::lsp::ClientId,
      database::Database,
      scheduler::Scheduler,
    };
    let (server_conn, _client_conn) = crate::connect::ipc::Connection::memory();
    let filesystems = std::sync::Arc::new(parking_lot::RwLock::new(Vec::new()));
    let source_cache =
      std::sync::Arc::new(parking_lot::RwLock::new(crate::SourceCache::new()));
    let scheduler = Scheduler::new_with_worker_count(
      server_conn,
      Arc::new(LaburnumLanguageServer),
      filesystems,
      source_cache,
      1,
    );
    TaskContext::new(Database::new(), scheduler, ClientId::INTERNAL)
  }
  /// A fresh context starts with no matched keys.
  #[test]
  fn test_task_context_default_empty_matched_keys() {
    let mut ctx = create_test_context();
    assert_eq!(ctx.matched_keys_updated().len(), 0);
    assert_eq!(ctx.matched_keys_deleted().len(), 0);
  }
  /// Keys stored by `set_matched_keys` are returned by the accessors.
  #[test]
  fn test_set_matched_keys() {
    use crate::database::RecordKey;
    let mut ctx = create_test_context();
    let updated = vec![RecordKey::new(Ident::new("pk"), "sk1".to_string())];
    let deleted = vec![RecordKey::new(Ident::new("pk"), "sk2".to_string())];
    ctx.set_matched_keys(updated.clone(), deleted.clone());
    assert_eq!(ctx.matched_keys_updated(), updated);
    assert_eq!(ctx.matched_keys_deleted(), deleted);
  }
  /// A key that was never inserted into the source cache reports `TaskCancelled`.
  #[test]
  fn test_check_source_available_returns_error_when_not_found() {
    let ctx = create_test_context();
    let key = crate::SourceKey::new(1, 0);
    let result = ctx.check_source_available(key);
    assert!(matches!(result, Err(crate::LaburnumError::TaskCancelled)));
  }
  /// A source that was upserted and completed is reported as available.
  #[test]
  fn test_check_source_available_returns_ok_when_found() {
    let ctx = create_test_context();
    let uri = crate::Uri::parse("file:///test.txt").unwrap();
    let source_cache_arc = ctx.source_cache();
    // The write guard is scoped to this block so it is released before
    // `check_source_available` takes a read lock on the same cache.
    let key = {
      let mut source_cache = source_cache_arc.write();
      let (key, _) = source_cache
        .upsert_with_version(uri, "content".to_string(), LspVersion::new(0), test_client_id())
        .unwrap();
      source_cache.complete_pending_for_test(key);
      key
    };
    let result = ctx.check_source_available(key);
    assert!(result.is_ok());
  }
}