use {
    crate::{
        Uri,
        connect::lsp::{ClientId, notification::PublishDiagnostics},
        database::{
            PartitionWriteContextRef,
            Partitions,
        },
        protocol::lsp::{
            self,
            LanguageServer,
            PositionEncodingKind,
            PublishDiagnosticsParams,
        },
        record::LaburnumRecordRef,
        scheduler::{key_watcher::WatcherResult, task::TaskContext},
        source::SourceKey,
    },
    std::{
        collections::{HashMap, HashSet},
        future::Future,
        pin::Pin,
    },
};
type ClientDiagnostics = HashMap<Uri, (Option<SourceKey>, Vec<lsp::Diagnostic>)>;
pub fn diagnostic_watcher<'a, P, T>(
ctx: &'a mut TaskContext<P, T>,
_writer: &'a mut PartitionWriteContextRef<'a, P>,
) -> Pin<Box<dyn Future<Output = WatcherResult<P, T>> + Send + 'a>>
where
P: Partitions,
T: LanguageServer<P>,
{
Box::pin(async move {
let updated_keys = ctx.matched_keys_updated().to_vec();
let deleted_keys = ctx.matched_keys_deleted().to_vec();
eprintln!(
"DEBUG: diagnostic_watcher called - updated: {}, deleted: {}",
updated_keys.len(),
deleted_keys.len()
);
let source_cache_reader = ctx.source_cache_reader();
let mut per_client: HashMap<ClientId, ClientDiagnostics> = HashMap::new();
let mut broadcast: ClientDiagnostics = HashMap::new();
for key in &updated_keys {
let results = key.get_record(ctx.query_client()).await;
for record_meta in results.records.iter() {
let Some(record) = results.get(record_meta) else {
continue;
};
let Some(diagnostic) = record.as_dyn_diagnostic() else {
continue;
};
let Some(source_key) = diagnostic.source_key() else {
continue;
};
let Some(source) = source_cache_reader.get_source(source_key) else {
continue;
};
let uri = source.uri().clone();
let clients = source_cache_reader.clients_for(source_key);
if clients.is_empty() {
let lsp_diag = diagnostic.to_lsp_diagnostic(
&source_cache_reader,
&PositionEncodingKind::DEFAULT,
);
let entry = broadcast.entry(uri).or_default();
entry.0 = Some(source_key);
entry.1.push(lsp_diag);
} else {
for client_id in clients {
let encoding = ctx
.scheduler()
.registry()
.get(client_id)
.map(|c| c.position_encoding().clone())
.unwrap_or(PositionEncodingKind::DEFAULT);
let lsp_diag =
diagnostic.to_lsp_diagnostic(&source_cache_reader, &encoding);
let client_map = per_client.entry(client_id).or_default();
let entry = client_map.entry(uri.clone()).or_default();
entry.0 = Some(source_key);
entry.1.push(lsp_diag);
}
}
}
}
let mut notified_uris: std::collections::HashSet<Uri> =
std::collections::HashSet::new();
for (client_id, uri_map) in per_client {
for (uri, (source_key, diagnostics)) in uri_map {
notified_uris.insert(uri.clone());
let version =
source_key.and_then(|sk| source_cache_reader.get_lsp_version(sk));
let params = PublishDiagnosticsParams::new(uri, diagnostics, version);
ctx
.send_notification::<PublishDiagnostics>(params, Some(client_id))
.await;
}
}
for (uri, (source_key, diagnostics)) in broadcast {
notified_uris.insert(uri.clone());
let version =
source_key.and_then(|sk| source_cache_reader.get_lsp_version(sk));
let params = PublishDiagnosticsParams::new(uri, diagnostics, version);
ctx
.send_notification::<PublishDiagnostics>(params, None)
.await;
}
let mut deleted_uris: HashMap<Uri, Option<i32>> = HashMap::new();
for key in &deleted_keys {
let Some(source_key) = key.source_key() else {
continue;
};
let Some(source) = source_cache_reader.get_source(source_key) else {
continue;
};
let uri = source.uri().clone();
let version = source_cache_reader.get_lsp_version(source_key);
deleted_uris.insert(uri, version);
}
for (uri, version) in deleted_uris {
if !notified_uris.contains(&uri) {
let params = PublishDiagnosticsParams::new(uri, vec![], version);
ctx.send_notification::<PublishDiagnostics>(params, None).await;
}
}
WatcherResult::empty()
})
}