use {
crate::{
Ident,
connect::lsp::{ClientId, request::WorkDoneProgressCreate},
database::{PartitionWriteContextRef, Partitions},
partitions::work_done_progress::ProgressStage,
protocol::{
jsonrpc::Message,
lsp::{
LanguageServer, NumberOrString, ProgressParams, ProgressParamsValue,
WorkDoneProgress, WorkDoneProgressBegin, WorkDoneProgressCreateParams,
WorkDoneProgressEnd, WorkDoneProgressReport,
},
},
record::LaburnumRecordRef,
scheduler::{key_watcher::WatcherResult, task::TaskContext},
},
async_channel::Sender,
dashmap::DashMap,
std::{
future::Future,
pin::Pin,
},
};
/// Lifecycle state of a tracked progress identifier.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProgressState {
  /// A `begin` has been issued and no matching `end` has followed yet.
  Open,
  /// The most recent session has ended; the next open mints a fresh token.
  Closed,
}
/// Bookkeeping kept for a single progress identifier.
#[derive(Clone, Debug)]
struct ProgressEntry {
  /// Whether the current session is open or closed.
  state: ProgressState,
  /// Token of the current (or most recently closed) session.
  token: String,
  /// Session number embedded in `token`; incremented on every reopen so that
  /// each session gets a distinct token.
  counter: u64,
  /// Title recorded when the entry was opened through `begin`.
  /// NOTE(review): only ever written in this file, never read — confirm it is
  /// still needed.
  title: String,
}
/// Tracks LSP work-done-progress sessions and broadcasts their lifecycle
/// messages (`create`, `begin`, `report`, `end`) to registered clients.
#[derive(Debug)]
pub struct ProgressTracker {
  /// Per-progress session state, keyed by progress identifier.
  entries: DashMap<Ident, ProgressEntry>,
  /// Outgoing message channel for each registered client.
  senders: DashMap<ClientId, Sender<Message>>,
}
impl ProgressTracker {
pub fn new(client_id: ClientId, sender: Sender<Message>) -> Self {
let senders = DashMap::new();
senders.insert(client_id, sender);
Self {
entries: DashMap::new(),
senders,
}
}
pub fn new_disconnected() -> Self {
Self {
entries: DashMap::new(),
senders: DashMap::new(),
}
}
pub fn register_client(&self, client_id: ClientId, sender: Sender<Message>) {
self.senders.insert(client_id, sender);
}
pub fn unregister_client(&self, client_id: ClientId) {
self.senders.remove(&client_id);
}
pub fn has_sender(&self) -> bool {
!self.senders.is_empty()
}
fn generate_token(progress_id: Ident, counter: u64) -> String {
format!("laburnum/progress/{progress_id}/{counter}")
}
fn send_notification(&self, token: &str, progress: WorkDoneProgress) {
if self.senders.is_empty() {
return;
}
let params = ProgressParams {
token: NumberOrString::String(token.to_string()),
value: ProgressParamsValue::WorkDone(progress),
};
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
let notification =
crate::protocol::jsonrpc::Notification::build("$/progress")
.params(params_value)
.finish();
let message = Message::Notification(notification);
for sender in self.senders.iter() {
let _ = sender.value().send_blocking(message.clone());
}
}
fn send_create_request(&self, token: &str) {
use {
crate::protocol::lsp::WorkDoneProgressCreateParams,
std::sync::atomic::{AtomicI64, Ordering},
};
static REQUEST_ID_COUNTER: AtomicI64 = AtomicI64::new(1_000_000);
if self.senders.is_empty() {
return;
}
let params = WorkDoneProgressCreateParams {
token: NumberOrString::String(token.to_string()),
};
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
for sender in self.senders.iter() {
let id = crate::protocol::jsonrpc::Id::Number(
REQUEST_ID_COUNTER.fetch_add(1, Ordering::Relaxed),
);
let request = crate::protocol::jsonrpc::Request::build(
"window/workDoneProgress/create",
id,
)
.params(params_value.clone())
.finish();
let message = Message::Request(request);
let _ = sender.value().send_blocking(message);
}
}
pub fn has_progress(&self, progress_id: Ident) -> bool {
self.entries.contains_key(&progress_id)
}
pub fn begin(&self, progress_id: Ident, title: &str) {
use dashmap::Entry;
let token = match self.entries.entry(progress_id) {
| Entry::Occupied(mut entry) => {
let e = entry.get_mut();
if e.state == ProgressState::Open {
return;
}
let new_counter = e.counter + 1;
e.state = ProgressState::Open;
e.token = Self::generate_token(progress_id, new_counter);
e.counter = new_counter;
e.title = title.to_string();
e.token.clone()
},
| Entry::Vacant(entry) => {
let token = Self::generate_token(progress_id, 0);
entry.insert(ProgressEntry {
state: ProgressState::Open,
token: token.clone(),
counter: 0,
title: title.to_string(),
});
token
},
};
self.send_create_request(&token);
self.send_notification(
&token,
WorkDoneProgress::Begin(WorkDoneProgressBegin {
title: title.to_string(),
cancellable: Some(false),
message: None,
percentage: None,
}),
);
}
pub fn report(
&self,
progress_id: Ident,
message: &str,
percentage: Option<u32>,
) {
let token = {
let Some(entry) = self.entries.get(&progress_id) else {
return;
};
if entry.state != ProgressState::Open {
return;
}
entry.token.clone()
};
self.send_notification(
&token,
WorkDoneProgress::Report(WorkDoneProgressReport {
cancellable: Some(false),
message: Some(message.to_string()),
percentage,
}),
);
}
pub fn on_idle(&self) {
let mut ended: Vec<String> = Vec::new();
for mut entry in self.entries.iter_mut() {
if entry.state == ProgressState::Open {
entry.state = ProgressState::Closed;
ended.push(entry.token.clone());
}
}
for token in ended {
self.send_notification(
&token,
WorkDoneProgress::End(WorkDoneProgressEnd {
message: Some("Complete".to_string()),
}),
);
}
}
pub fn end(&self, progress_id: Ident, message: &str) {
let Some(token) = self.close(progress_id) else {
return;
};
self.send_notification(
&token,
WorkDoneProgress::End(WorkDoneProgressEnd {
message: Some(message.to_string()),
}),
);
}
pub fn is_open(&self, progress_id: Ident) -> bool {
self
.entries
.get(&progress_id)
.map(|e| e.state == ProgressState::Open)
.unwrap_or(false)
}
pub fn try_open(&self, progress_id: Ident) -> Option<String> {
use dashmap::Entry;
match self.entries.entry(progress_id) {
| Entry::Occupied(mut entry) => {
let e = entry.get_mut();
if e.state == ProgressState::Open {
None
} else {
let new_counter = e.counter + 1;
let token = Self::generate_token(progress_id, new_counter);
e.state = ProgressState::Open;
e.token = token.clone();
e.counter = new_counter;
Some(token)
}
},
| Entry::Vacant(entry) => {
let token = Self::generate_token(progress_id, 0);
entry.insert(ProgressEntry {
state: ProgressState::Open,
token: token.clone(),
counter: 0,
title: String::new(),
});
Some(token)
},
}
}
pub fn get_token(&self, progress_id: Ident) -> Option<String> {
self
.entries
.get(&progress_id)
.filter(|e| e.state == ProgressState::Open)
.map(|e| e.token.clone())
}
pub fn close(&self, progress_id: Ident) -> Option<String> {
use dashmap::Entry;
match self.entries.entry(progress_id) {
| Entry::Occupied(mut entry) => {
let e = entry.get_mut();
if e.state == ProgressState::Open {
e.state = ProgressState::Closed;
Some(e.token.clone())
} else {
None
}
},
| Entry::Vacant(_) => None,
}
}
}
/// A progress record read from storage, queued for replay against the tracker.
struct ProgressAction {
  /// Identifier of the key the record came from; used as the progress id.
  key_ident: Ident,
  /// Which lifecycle transition the record represents.
  stage: ProgressStage,
  /// Payload forwarded verbatim in the `$/progress` notification.
  lsp_progress: WorkDoneProgress,
}
pub fn work_done_progress_watcher<'a, P, T>(
ctx: &'a mut TaskContext<P, T>,
_writer: &'a mut PartitionWriteContextRef<'a, P>,
) -> Pin<Box<dyn Future<Output = WatcherResult<P, T>> + Send + 'a>>
where
P: Partitions,
T: LanguageServer<P>,
{
Box::pin(async move {
let updated_keys = ctx.matched_keys_updated().to_vec();
let progress_tracker = ctx.progress_tracker();
let mut actions = Vec::new();
for key in &updated_keys {
let key_ident = key.key_ident();
let results = key.get_record(ctx.query_client()).await;
for record_meta in results.records.iter() {
let Some(record) = results.get(record_meta) else {
continue;
};
let Some(progress) = record.as_work_done_progress() else {
continue;
};
actions.push(ProgressAction {
key_ident,
stage: progress.stage(),
lsp_progress: progress.to_lsp_progress(),
});
}
}
for action in actions {
let progress_id = action.key_ident;
match action.stage {
| ProgressStage::Begin => {
let Some(token) = progress_tracker.try_open(progress_id) else {
continue;
};
let lsp_token = NumberOrString::String(token);
ctx
.send_request_fire_and_forget::<WorkDoneProgressCreate>(
WorkDoneProgressCreateParams {
token: lsp_token.clone(),
},
)
.await;
ctx
.send_notification::<crate::connect::lsp::notification::Progress>(
ProgressParams {
token: lsp_token,
value: ProgressParamsValue::WorkDone(action.lsp_progress),
},
None,
)
.await;
},
| ProgressStage::Report => {
let Some(token) = progress_tracker.get_token(progress_id) else {
continue;
};
let lsp_token = NumberOrString::String(token);
ctx
.send_notification::<crate::connect::lsp::notification::Progress>(
ProgressParams {
token: lsp_token,
value: ProgressParamsValue::WorkDone(action.lsp_progress),
},
None,
)
.await;
},
| ProgressStage::End => {
let Some(token) = progress_tracker.close(progress_id) else {
continue;
};
let lsp_token = NumberOrString::String(token);
ctx
.send_notification::<crate::connect::lsp::notification::Progress>(
ProgressParams {
token: lsp_token,
value: ProgressParamsValue::WorkDone(action.lsp_progress),
},
None,
)
.await;
},
}
}
WatcherResult::empty()
})
}
// Unit tests for `ProgressTracker` state transitions and the messages it emits.
#[cfg(test)]
mod tests {
  use {super::*, async_channel::Receiver};

  fn test_channel() -> (Sender<Message>, Receiver<Message>) {
    async_channel::unbounded()
  }

  /// Drains every message currently queued on `receiver`.
  fn collect_messages(receiver: &Receiver<Message>) -> Vec<Message> {
    let mut messages = Vec::new();
    while let Ok(msg) = receiver.try_recv() {
      messages.push(msg);
    }
    messages
  }

  /// Classifies a message as "create", "begin", "report" or "end".
  fn extract_progress_kind(msg: &Message) -> Option<&'static str> {
    match msg {
      | Message::Notification(n) if n.method() == "$/progress" => {
        let params = n.params()?;
        let value = params.get("value")?;
        let kind = value.get("kind")?.as_str()?;
        Some(match kind {
          | "begin" => "begin",
          | "report" => "report",
          | "end" => "end",
          | _ => return None,
        })
      },
      | Message::Request(r) if r.method() == "window/workDoneProgress/create" => {
        Some("create")
      },
      | _ => None,
    }
  }

  /// Drains `receiver` and returns the progress kinds of the drained messages.
  fn drain_kinds(receiver: &Receiver<Message>) -> Vec<&'static str> {
    collect_messages(receiver)
      .iter()
      .filter_map(extract_progress_kind)
      .collect()
  }

  /// Returns the message text of a `report` progress notification.
  fn extract_report_message(msg: &Message) -> Option<String> {
    match msg {
      | Message::Notification(n) if n.method() == "$/progress" => {
        let params = n.params()?;
        let value = params.get("value")?;
        if value.get("kind")?.as_str()? == "report" {
          return value.get("message")?.as_str().map(|s| s.to_string());
        }
        None
      },
      | _ => None,
    }
  }

  const PROJECT: Ident = Ident::new_const("project");
  const PROGRESS_A: Ident = Ident::new_const("progress_a");
  const PROGRESS_B: Ident = Ident::new_const("progress_b");
  const PROGRESS_C: Ident = Ident::new_const("progress_c");
  const TEST_PROGRESS: Ident = Ident::new_const("test_progress");

  fn test_client_id() -> ClientId {
    ClientId::INTERNAL
  }

  #[test]
  fn test_try_open_returns_token_first_time() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    let token = tracker.try_open(PROJECT);
    assert!(token.is_some());
    assert!(token.unwrap().starts_with("laburnum/progress/"));
  }

  #[test]
  fn test_try_open_returns_none_when_already_open() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    let first = tracker.try_open(PROJECT);
    assert!(first.is_some());
    let second = tracker.try_open(PROJECT);
    assert!(second.is_none());
  }

  #[test]
  fn test_get_token_returns_token_when_open() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    let expected = tracker.try_open(PROJECT).unwrap();
    let actual = tracker.get_token(PROJECT);
    assert_eq!(actual, Some(expected));
  }

  #[test]
  fn test_get_token_returns_none_when_closed() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.try_open(PROJECT);
    tracker.close(PROJECT);
    let token = tracker.get_token(PROJECT);
    assert!(token.is_none());
  }

  #[test]
  fn test_close_returns_token_and_marks_closed() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    let expected = tracker.try_open(PROJECT).unwrap();
    let closed = tracker.close(PROJECT);
    assert_eq!(closed, Some(expected));
    assert!(!tracker.is_open(PROJECT));
  }

  #[test]
  fn test_close_twice_returns_none_second_time() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.try_open(PROJECT);
    assert!(tracker.close(PROJECT).is_some());
    assert!(tracker.close(PROJECT).is_none());
  }

  #[test]
  fn test_can_reopen_after_close() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    let first_token = tracker.try_open(PROJECT).unwrap();
    tracker.close(PROJECT);
    let second_token = tracker.try_open(PROJECT).unwrap();
    assert_ne!(second_token, first_token);
  }

  #[test]
  fn test_tokens_are_unique() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    let t1 = tracker.try_open(PROGRESS_A).unwrap();
    let t2 = tracker.try_open(PROGRESS_B).unwrap();
    let t3 = tracker.try_open(PROGRESS_C).unwrap();
    assert_ne!(t1, t2);
    assert_ne!(t2, t3);
    assert_ne!(t1, t3);
  }

  #[test]
  fn test_begin_then_idle_closes_progress() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Testing");
    assert!(tracker.is_open(TEST_PROGRESS));
    tracker.on_idle();
    assert!(!tracker.is_open(TEST_PROGRESS));
  }

  #[test]
  fn test_begin_is_idempotent_while_open() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Testing");
    let _ = collect_messages(&receiver);
    tracker.begin(TEST_PROGRESS, "Testing");
    let kinds = drain_kinds(&receiver);
    assert!(kinds.is_empty(), "begin while open should be a no-op");
  }

  #[test]
  fn test_begin_sends_create_and_begin_notifications() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Testing");
    assert_eq!(
      drain_kinds(&receiver),
      vec!["create", "begin"],
      "begin should send Create request then Begin notification"
    );
  }

  #[test]
  fn test_report_sends_report_notification() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Testing");
    let _ = collect_messages(&receiver);
    tracker.report(TEST_PROGRESS, "Parsing", None);
    let messages = collect_messages(&receiver);
    let kinds: Vec<_> =
      messages.iter().filter_map(extract_progress_kind).collect();
    assert_eq!(
      kinds,
      vec!["report"],
      "report should send Report notification"
    );
    assert_eq!(
      messages.first().and_then(extract_report_message),
      Some("Parsing".to_string()),
      "Report should carry the message text"
    );
  }

  #[test]
  fn test_report_ignored_when_not_open() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.report(TEST_PROGRESS, "Parsing", None);
    let kinds = drain_kinds(&receiver);
    assert!(kinds.is_empty(), "report without begin should do nothing");
  }

  #[test]
  fn test_end_sends_single_end_notification() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Testing");
    let _ = collect_messages(&receiver);
    tracker.end(TEST_PROGRESS, "Done");
    tracker.end(TEST_PROGRESS, "Done");
    assert_eq!(
      drain_kinds(&receiver),
      vec!["end"],
      "end should notify once and be a no-op when already closed"
    );
    assert!(!tracker.is_open(TEST_PROGRESS));
  }

  #[test]
  fn test_full_progress_lifecycle_notifications() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Processing");
    tracker.report(TEST_PROGRESS, "Parsing", None);
    tracker.report(TEST_PROGRESS, "Analyzing", None);
    tracker.on_idle();
    assert_eq!(
      drain_kinds(&receiver),
      vec!["create", "begin", "report", "report", "end"],
      "Full lifecycle: create, begin, report, report, end"
    );
  }

  #[test]
  fn test_reopen_after_idle_uses_new_token() {
    let (sender, _) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.begin(TEST_PROGRESS, "Testing");
    let first = tracker.get_token(TEST_PROGRESS).unwrap();
    tracker.on_idle();
    tracker.begin(TEST_PROGRESS, "Testing");
    let second = tracker.get_token(TEST_PROGRESS).unwrap();
    assert_ne!(first, second, "reopen after idle should mint a new token");
  }

  #[test]
  fn test_disconnected_tracker_still_tracks_state() {
    let tracker = ProgressTracker::new_disconnected();
    assert!(!tracker.has_sender());
    tracker.begin(TEST_PROGRESS, "Testing");
    assert!(tracker.has_progress(TEST_PROGRESS));
    assert!(tracker.is_open(TEST_PROGRESS));
    tracker.end(TEST_PROGRESS, "Done");
    assert!(!tracker.is_open(TEST_PROGRESS));
  }

  #[test]
  fn test_unregistered_client_receives_nothing() {
    let (sender, receiver) = test_channel();
    let tracker = ProgressTracker::new(test_client_id(), sender);
    tracker.unregister_client(test_client_id());
    assert!(!tracker.has_sender());
    tracker.begin(TEST_PROGRESS, "Testing");
    assert!(collect_messages(&receiver).is_empty());
  }
}