use {
crate::{
Ident, Uri,
connect::lsp::{ClientId, request::WorkDoneProgressCreate},
database::{PartitionWriteContextRef, Partitions},
partitions::work_done_progress::ProgressStage,
protocol::{
jsonrpc::Message,
lsp::{
LanguageServer, NumberOrString, ProgressParams, ProgressParamsValue,
WorkDoneProgress, WorkDoneProgressBegin, WorkDoneProgressCreateParams,
WorkDoneProgressEnd, WorkDoneProgressReport,
},
},
record::LaburnumRecordRef,
scheduler::{key_watcher::WatcherResult, task::TaskContext},
},
async_channel::Sender,
dashmap::DashMap,
std::{
collections::{HashMap, HashSet},
future::Future,
pin::Pin,
sync::Arc,
},
};
/// Lifecycle state of a tracked progress entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressState {
/// The progress is active; `get_token` only returns a token in this state.
Open,
/// The progress has ended; `try_open` / `create_progress` may reopen it
/// under a new token.
Closed,
}
/// Bookkeeping kept for a stage while it is active.
#[derive(Debug, Clone)]
struct StageInfo {
/// Task that started the stage; `on_task_complete` compares against this
/// to finish stages the task left active.
task_id: Ident,
}
/// Per-progress state stored in the tracker.
#[derive(Debug, Clone)]
struct ProgressEntry {
/// Whether the progress is currently open or closed.
state: ProgressState,
/// Token of the current (or most recent) open cycle, as produced by
/// `generate_token`.
token: String,
/// Starts at 0 and is incremented each time the entry is reopened, so every
/// cycle gets a distinct token.
counter: u64,
/// Title supplied to `create_progress`; empty for entries first created by
/// `try_open`.
title: String,
/// Stages that must all be completed before the progress closes itself.
declared_stages: HashSet<Ident>,
/// Stages that have been started and not yet completed, keyed by stage id.
active_stages: HashMap<Ident, StageInfo>,
/// Stages that have been completed.
completed_stages: HashSet<Ident>,
}
/// Tracks work-done progress entries and forwards the corresponding LSP
/// messages to every registered client sender.
#[derive(Debug)]
pub struct ProgressTracker {
/// Per-progress bookkeeping, keyed by progress id.
entries: DashMap<Ident, ProgressEntry>,
/// Outgoing message channel for each registered client.
senders: DashMap<ClientId, Sender<Message>>,
}
/// Collects the settings for a progress; `build` registers it on the tracker.
pub struct ProgressBuilder {
/// Tracker the progress will be created on.
tracker: Arc<ProgressTracker>,
/// Identifier of the progress to create.
progress_id: Ident,
/// Title passed to `ProgressTracker::create_progress`.
title: String,
/// Stages declared up front; empty unless `with_stages` is called.
stages: Vec<Ident>,
/// Task id handed to the resulting `Progress`.
task_id: Ident,
}
impl ProgressBuilder {
  /// Starts a builder for `progress_id` with the given title and task id and
  /// no declared stages.
  pub fn new(
    tracker: Arc<ProgressTracker>,
    progress_id: Ident,
    title: String,
    task_id: Ident,
  ) -> Self {
    let stages = Vec::new();
    Self {
      tracker,
      progress_id,
      title,
      stages,
      task_id,
    }
  }

  /// Replaces the declared stages with the items yielded by `stages`.
  pub fn with_stages<I>(self, stages: I) -> Self
  where
    I: IntoIterator<Item = Ident>,
  {
    let mut builder = self;
    let declared: Vec<Ident> = stages.into_iter().collect();
    builder.stages = declared;
    builder
  }

  /// Consumes the builder and creates the progress on the tracker.
  pub fn build(self) -> Progress {
    let Self {
      tracker,
      progress_id,
      title,
      stages,
      task_id,
    } = self;
    tracker.create_progress(progress_id, &title, stages, task_id)
  }
}
/// Handle to a progress on a `ProgressTracker`; used to start stages and to
/// declare additional ones.
pub struct Progress {
/// Tracker that owns the progress entry.
tracker: Arc<ProgressTracker>,
/// Identifier of the progress this handle refers to.
progress_id: Ident,
/// Task id recorded for every stage started through this handle.
task_id: Ident,
}
impl Progress {
  /// Wraps an existing progress id in a handle. Only stores the arguments; the
  /// tracker is not contacted.
  pub fn new(
    tracker: Arc<ProgressTracker>,
    progress_id: Ident,
    task_id: Ident,
  ) -> Self {
    Self {
      task_id,
      progress_id,
      tracker,
    }
  }

  /// Begins describing the stage `stage`, labelled `name`. Nothing is sent
  /// until [`StageBuilder::build`] is called.
  pub fn stage(&self, stage: Ident, name: &str) -> StageBuilder<'_> {
    let name = String::from(name);
    StageBuilder {
      progress: self,
      stage,
      name,
      uri: None,
    }
  }

  /// Adds `stage` to the declared stages of this progress.
  pub fn add_stage(&self, stage: Ident) {
    let Self {
      tracker,
      progress_id,
      ..
    } = self;
    tracker.add_dynamic_stage(*progress_id, stage);
  }
}
/// Builder for a single stage of a `Progress`.
pub struct StageBuilder<'a> {
/// Progress the stage belongs to.
progress: &'a Progress,
/// Identifier of the stage.
stage: Ident,
/// Human-readable stage name used in the report message.
name: String,
/// Optional URI; when set, its last path segment is appended to the message.
uri: Option<Uri>,
}
impl<'a> StageBuilder<'a> {
pub fn with_uri(mut self, uri: &Uri) -> Self {
self.uri = Some(uri.clone());
self
}
pub fn build(self) -> StageHandle {
let message = match &self.uri {
| Some(uri) => {
let filename = uri
.path_segments()
.and_then(|mut s| s.next_back())
.unwrap_or("");
format!("{} {}", self.name, filename)
},
| None => self.name.clone(),
};
self.progress.tracker.begin_stage_and_notify(
self.progress.progress_id,
self.stage,
self.progress.task_id,
&message,
);
StageHandle {
tracker: self.progress.tracker.clone(),
progress_id: self.progress.progress_id,
stage: self.stage,
}
}
}
/// Guard for a started stage: dropping it reports the stage as completed to
/// the tracker.
pub struct StageHandle {
/// Tracker to notify on drop.
tracker: Arc<ProgressTracker>,
/// Progress the stage belongs to.
progress_id: Ident,
/// Identifier of the stage being guarded.
stage: Ident,
}
impl Drop for StageHandle {
  /// Completes the guarded stage on the tracker.
  fn drop(&mut self) {
    let Self {
      tracker,
      progress_id,
      stage,
    } = self;
    tracker.complete_stage(*progress_id, *stage);
  }
}
impl ProgressTracker {
pub fn new(client_id: ClientId, sender: Sender<Message>) -> Self {
let senders = DashMap::new();
senders.insert(client_id, sender);
Self {
entries: DashMap::new(),
senders,
}
}
pub fn new_disconnected() -> Self {
Self {
entries: DashMap::new(),
senders: DashMap::new(),
}
}
pub fn register_client(&self, client_id: ClientId, sender: Sender<Message>) {
self.senders.insert(client_id, sender);
}
pub fn unregister_client(&self, client_id: ClientId) {
self.senders.remove(&client_id);
}
pub fn has_sender(&self) -> bool {
!self.senders.is_empty()
}
fn generate_token(progress_id: Ident, counter: u64) -> String {
format!("laburnum/progress/{}/{}", progress_id.as_u64(), counter)
}
fn send_notification(&self, token: &str, progress: WorkDoneProgress) {
if self.senders.is_empty() {
return;
}
let params = ProgressParams {
token: NumberOrString::String(token.to_string()),
value: ProgressParamsValue::WorkDone(progress),
};
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
let notification =
crate::protocol::jsonrpc::Notification::build("$/progress")
.params(params_value)
.finish();
let message = Message::Notification(notification);
for sender in self.senders.iter() {
let _ = sender.value().send_blocking(message.clone());
}
}
fn send_create_request(&self, token: &str) {
use {
crate::protocol::lsp::WorkDoneProgressCreateParams,
std::sync::atomic::{AtomicI64, Ordering},
};
static REQUEST_ID_COUNTER: AtomicI64 = AtomicI64::new(1_000_000);
if self.senders.is_empty() {
return;
}
let params = WorkDoneProgressCreateParams {
token: NumberOrString::String(token.to_string()),
};
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
for sender in self.senders.iter() {
let id = crate::protocol::jsonrpc::Id::Number(
REQUEST_ID_COUNTER.fetch_add(1, Ordering::Relaxed),
);
let request = crate::protocol::jsonrpc::Request::build(
"window/workDoneProgress/create",
id,
)
.params(params_value.clone())
.finish();
let message = Message::Request(request);
let _ = sender.value().send_blocking(message);
}
}
pub fn has_progress(&self, progress_id: Ident) -> bool {
self.entries.contains_key(&progress_id)
}
pub fn create_progress(
self: &Arc<Self>,
progress_id: Ident,
title: &str,
stages: Vec<Ident>,
task_id: Ident,
) -> Progress {
use dashmap::Entry;
let token = match self.entries.entry(progress_id) {
| Entry::Occupied(mut entry) => {
let e = entry.get_mut();
if e.state == ProgressState::Open {
e.token.clone()
} else {
let new_counter = e.counter + 1;
let token = Self::generate_token(progress_id, new_counter);
e.state = ProgressState::Open;
e.token = token.clone();
e.counter = new_counter;
e.title = title.to_string();
e.declared_stages = stages.iter().copied().collect();
e.active_stages.clear();
e.completed_stages.clear();
token
}
},
| Entry::Vacant(entry) => {
let token = Self::generate_token(progress_id, 0);
entry.insert(ProgressEntry {
state: ProgressState::Open,
token: token.clone(),
counter: 0,
title: title.to_string(),
declared_stages: stages.iter().copied().collect(),
active_stages: HashMap::new(),
completed_stages: HashSet::new(),
});
token
},
};
self.send_create_request(&token);
self.send_notification(
&token,
WorkDoneProgress::Begin(WorkDoneProgressBegin {
title: title.to_string(),
cancellable: Some(false),
message: None,
percentage: Some(0),
}),
);
Progress {
tracker: self.clone(),
progress_id,
task_id,
}
}
pub fn add_dynamic_stage(&self, progress_id: Ident, stage: Ident) {
if let Some(mut entry) = self.entries.get_mut(&progress_id) {
entry.declared_stages.insert(stage);
}
}
pub fn begin_stage_and_notify(
&self,
progress_id: Ident,
stage: Ident,
task_id: Ident,
message: &str,
) {
let token = {
let Some(mut entry) = self.entries.get_mut(&progress_id) else {
return;
};
entry.active_stages.insert(stage, StageInfo { task_id });
entry.token.clone()
};
self.send_notification(
&token,
WorkDoneProgress::Report(WorkDoneProgressReport {
cancellable: Some(false),
message: Some(message.to_string()),
percentage: None,
}),
);
}
pub fn complete_stage(&self, progress_id: Ident, stage: Ident) {
let (is_complete, token) = {
let Some(mut entry) = self.entries.get_mut(&progress_id) else {
return;
};
entry.active_stages.remove(&stage);
entry.completed_stages.insert(stage);
let is_complete = entry.active_stages.is_empty()
&& entry.declared_stages.is_subset(&entry.completed_stages);
if is_complete {
entry.state = ProgressState::Closed;
}
(is_complete, entry.token.clone())
};
if is_complete {
self.send_notification(
&token,
WorkDoneProgress::End(WorkDoneProgressEnd {
message: Some("Complete".to_string()),
}),
);
}
}
pub fn on_task_complete(&self, task_id: Ident) {
let mut entries_to_check: Vec<(Ident, String)> = Vec::new();
for mut entry in self.entries.iter_mut() {
let stages_to_complete: Vec<Ident> = entry
.active_stages
.iter()
.filter(|(_, info)| info.task_id == task_id)
.map(|(stage, _)| *stage)
.collect();
for stage in stages_to_complete {
entry.active_stages.remove(&stage);
entry.completed_stages.insert(stage);
}
let is_complete = entry.active_stages.is_empty()
&& entry.declared_stages.is_subset(&entry.completed_stages);
if is_complete && entry.state == ProgressState::Open {
entry.state = ProgressState::Closed;
entries_to_check.push((*entry.key(), entry.token.clone()));
}
}
for (_, token) in entries_to_check {
self.send_notification(
&token,
WorkDoneProgress::End(WorkDoneProgressEnd {
message: Some("Complete".to_string()),
}),
);
}
}
pub fn is_open(&self, progress_id: Ident) -> bool {
self
.entries
.get(&progress_id)
.map(|e| e.state == ProgressState::Open)
.unwrap_or(false)
}
pub fn try_open(&self, progress_id: Ident) -> Option<String> {
use dashmap::Entry;
match self.entries.entry(progress_id) {
| Entry::Occupied(mut entry) => {
let e = entry.get_mut();
if e.state == ProgressState::Open {
None
} else {
let new_counter = e.counter + 1;
let token = Self::generate_token(progress_id, new_counter);
e.state = ProgressState::Open;
e.token = token.clone();
e.counter = new_counter;
Some(token)
}
},
| Entry::Vacant(entry) => {
let token = Self::generate_token(progress_id, 0);
entry.insert(ProgressEntry {
state: ProgressState::Open,
token: token.clone(),
counter: 0,
title: String::new(),
declared_stages: HashSet::new(),
active_stages: HashMap::new(),
completed_stages: HashSet::new(),
});
Some(token)
},
}
}
pub fn get_token(&self, progress_id: Ident) -> Option<String> {
self
.entries
.get(&progress_id)
.filter(|e| e.state == ProgressState::Open)
.map(|e| e.token.clone())
}
pub fn close(&self, progress_id: Ident) -> Option<String> {
use dashmap::Entry;
match self.entries.entry(progress_id) {
| Entry::Occupied(mut entry) => {
let e = entry.get_mut();
if e.state == ProgressState::Open {
e.state = ProgressState::Closed;
Some(e.token.clone())
} else {
None
}
},
| Entry::Vacant(_) => None,
}
}
}
/// One progress update extracted from a record by the watcher, queued so that
/// all records are read before any message is sent.
struct ProgressAction {
/// Ident of the updated key; used as the progress id on the tracker.
key_ident: Ident,
/// Which phase of the progress the record describes.
stage: ProgressStage,
/// LSP payload to forward to the client.
lsp_progress: WorkDoneProgress,
}
/// Watcher that forwards updated work-done-progress records to the client.
///
/// For every updated key, the records that are work-done-progress values are
/// collected first. Each is then mapped onto the tracker: `Begin` opens the
/// progress (and sends a create request first), `Report` reuses the open
/// token, `End` closes it. Updates for which the tracker yields no token are
/// skipped. Always returns an empty watcher result.
pub fn work_done_progress_watcher<'a, P, T>(
  ctx: &'a mut TaskContext<P, T>,
  _writer: &'a mut PartitionWriteContextRef<'a, P>,
) -> Pin<Box<dyn Future<Output = WatcherResult<P, T>> + Send + 'a>>
where
  P: Partitions,
  T: LanguageServer<P>,
{
  Box::pin(async move {
    let updated_keys = ctx.matched_keys_updated().to_vec();
    let progress_tracker = ctx.progress_tracker();

    // Phase 1: read every record before sending anything.
    let mut pending = Vec::new();
    for key in &updated_keys {
      let key_ident = key.key_ident();
      let results = key.get_record(ctx.query_client()).await;
      pending.extend(results.records.iter().filter_map(|record_meta| {
        let record = results.get(record_meta)?;
        let progress = record.as_work_done_progress()?;
        Some(ProgressAction {
          key_ident,
          stage: progress.stage(),
          lsp_progress: progress.to_lsp_progress(),
        })
      }));
    }

    // Phase 2: translate each update into LSP messages.
    for pending_action in pending {
      let ProgressAction {
        key_ident: progress_id,
        stage,
        lsp_progress,
      } = pending_action;

      let token = match stage {
        | ProgressStage::Begin => progress_tracker.try_open(progress_id),
        | ProgressStage::Report => progress_tracker.get_token(progress_id),
        | ProgressStage::End => progress_tracker.close(progress_id),
      };
      let Some(token) = token else {
        continue;
      };
      let lsp_token = NumberOrString::String(token);

      // A newly opened progress must be announced before its first
      // notification.
      if matches!(stage, ProgressStage::Begin) {
        ctx
          .send_request_fire_and_forget::<WorkDoneProgressCreate>(
            WorkDoneProgressCreateParams {
              token: lsp_token.clone(),
            },
          )
          .await;
      }

      ctx
        .send_notification::<crate::connect::lsp::notification::Progress>(
          ProgressParams {
            token: lsp_token,
            value: ProgressParamsValue::WorkDone(lsp_progress),
          },
          None,
        )
        .await;
    }

    WatcherResult::empty()
  })
}
#[cfg(test)]
mod tests {
  use {super::*, async_channel::Receiver};

  const PROJECT: Ident = Ident::new("project");
  const PROGRESS_A: Ident = Ident::new("progress_a");
  const PROGRESS_B: Ident = Ident::new("progress_b");
  const PROGRESS_C: Ident = Ident::new("progress_c");
  const TEST_PROGRESS: Ident = Ident::new("test_progress");

  /// Client id used for every tracker in these tests.
  fn test_client_id() -> ClientId {
    ClientId::INTERNAL
  }

  /// Unbounded channel pair standing in for a client connection.
  fn test_channel() -> (Sender<Message>, Receiver<Message>) {
    async_channel::unbounded()
  }

  /// Tracker whose receiving end is dropped right away; for tests that only
  /// look at tracker state.
  fn detached_tracker() -> ProgressTracker {
    let (sender, _) = test_channel();
    ProgressTracker::new(test_client_id(), sender)
  }

  /// Shared tracker plus the receiver that observes everything it sends.
  fn observed_tracker() -> (Arc<ProgressTracker>, Receiver<Message>) {
    let (sender, receiver) = test_channel();
    let tracker = Arc::new(ProgressTracker::new(test_client_id(), sender));
    (tracker, receiver)
  }

  /// Drains every message currently queued on `receiver`.
  fn collect_messages(receiver: &Receiver<Message>) -> Vec<Message> {
    std::iter::from_fn(|| receiver.try_recv().ok()).collect()
  }

  /// Classifies a message as "create", "begin", "report" or "end"; anything
  /// else yields `None`.
  fn extract_progress_kind(msg: &Message) -> Option<&'static str> {
    match msg {
      | Message::Request(r) if r.method() == "window/workDoneProgress/create" => {
        Some("create")
      },
      | Message::Notification(n) if n.method() == "$/progress" => {
        let params = n.params()?;
        let kind = params.get("value")?.get("kind")?.as_str()?;
        ["begin", "report", "end"]
          .iter()
          .copied()
          .find(|known| *known == kind)
      },
      | _ => None,
    }
  }

  /// Returns the `message` field of a "report" progress notification.
  fn extract_report_message(msg: &Message) -> Option<String> {
    let Message::Notification(n) = msg else {
      return None;
    };
    if n.method() != "$/progress" {
      return None;
    }
    let params = n.params()?;
    let value = params.get("value")?;
    if value.get("kind")?.as_str()? != "report" {
      return None;
    }
    value.get("message")?.as_str().map(str::to_owned)
  }

  /// Kinds of all recognised messages, in order.
  fn progress_kinds(messages: &[Message]) -> Vec<&'static str> {
    messages.iter().filter_map(extract_progress_kind).collect()
  }

  #[test]
  fn test_try_open_returns_token_first_time() {
    let tracker = detached_tracker();
    let token = tracker.try_open(PROJECT);
    assert!(token.is_some());
    let token = token.unwrap();
    assert!(token.starts_with("laburnum/progress/"));
  }

  #[test]
  fn test_try_open_returns_none_when_already_open() {
    let tracker = detached_tracker();
    assert!(tracker.try_open(PROJECT).is_some());
    assert!(tracker.try_open(PROJECT).is_none());
  }

  #[test]
  fn test_get_token_returns_token_when_open() {
    let tracker = detached_tracker();
    let opened = tracker.try_open(PROJECT).unwrap();
    assert_eq!(tracker.get_token(PROJECT), Some(opened));
  }

  #[test]
  fn test_get_token_returns_none_when_closed() {
    let tracker = detached_tracker();
    tracker.try_open(PROJECT);
    tracker.close(PROJECT);
    assert!(tracker.get_token(PROJECT).is_none());
  }

  #[test]
  fn test_close_returns_token_and_marks_closed() {
    let tracker = detached_tracker();
    let opened = tracker.try_open(PROJECT).unwrap();
    assert_eq!(tracker.close(PROJECT), Some(opened));
    assert!(!tracker.is_open(PROJECT));
  }

  #[test]
  fn test_can_reopen_after_close() {
    let tracker = detached_tracker();
    let before = tracker.try_open(PROJECT).unwrap();
    tracker.close(PROJECT);
    let after = tracker.try_open(PROJECT).unwrap();
    assert_ne!(after, before);
  }

  #[test]
  fn test_tokens_are_unique() {
    let tracker = detached_tracker();
    let token_a = tracker.try_open(PROGRESS_A).unwrap();
    let token_b = tracker.try_open(PROGRESS_B).unwrap();
    let token_c = tracker.try_open(PROGRESS_C).unwrap();
    assert_ne!(token_a, token_b);
    assert_ne!(token_b, token_c);
    assert_ne!(token_a, token_c);
  }

  #[test]
  fn test_staged_progress_completes_when_all_stages_done() {
    let tracker = Arc::new(detached_tracker());
    let stage1 = Ident::new("stage1");
    let stage2 = Ident::new("stage2");
    let progress = tracker.create_progress(
      TEST_PROGRESS,
      "Testing",
      vec![stage1, stage2],
      Ident::new("test_task"),
    );
    assert!(tracker.is_open(TEST_PROGRESS));

    // Building and immediately dropping the handle completes the stage.
    drop(progress.stage(stage1, "Stage 1").build());
    assert!(tracker.is_open(TEST_PROGRESS));

    drop(progress.stage(stage2, "Stage 2").build());
    assert!(!tracker.is_open(TEST_PROGRESS));
  }

  #[test]
  fn test_on_task_complete_cleans_up_stages() {
    let tracker = Arc::new(detached_tracker());
    let task_id = Ident::new("test_task");
    let stage1 = Ident::new("stage1");
    let progress =
      tracker.create_progress(TEST_PROGRESS, "Testing", vec![stage1], task_id);
    let handle = progress.stage(stage1, "Stage 1").build();
    assert!(tracker.is_open(TEST_PROGRESS));

    // Leak the handle so only `on_task_complete` can finish the stage.
    std::mem::forget(handle);
    tracker.on_task_complete(task_id);
    assert!(!tracker.is_open(TEST_PROGRESS));
  }

  #[test]
  fn test_create_progress_sends_create_and_begin_notifications() {
    let (tracker, receiver) = observed_tracker();
    let _progress = tracker.create_progress(
      TEST_PROGRESS,
      "Testing",
      vec![Ident::new("stage1")],
      Ident::new("test_task"),
    );
    let kinds = progress_kinds(&collect_messages(&receiver));
    assert_eq!(
      kinds,
      vec!["create", "begin"],
      "create_progress should send Create request then Begin notification"
    );
  }

  #[test]
  fn test_stage_start_sends_report_notification() {
    let (tracker, receiver) = observed_tracker();
    let parsing = Ident::new("parsing");
    let progress = tracker.create_progress(
      TEST_PROGRESS,
      "Testing",
      vec![parsing],
      Ident::new("test_task"),
    );
    // Discard the create/begin traffic.
    let _ = collect_messages(&receiver);

    let _stage = progress.stage(parsing, "Parsing").build();
    let messages = collect_messages(&receiver);
    assert_eq!(
      progress_kinds(&messages),
      vec!["report"],
      "stage().build() should send Report notification"
    );
    assert_eq!(
      extract_report_message(&messages[0]),
      Some("Parsing".to_string()),
      "Report should contain stage name"
    );
  }

  #[test]
  fn test_stage_completion_sends_end_notification() {
    let (tracker, receiver) = observed_tracker();
    let stage1 = Ident::new("stage1");
    let progress = tracker.create_progress(
      TEST_PROGRESS,
      "Testing",
      vec![stage1],
      Ident::new("test_task"),
    );
    // Discard the create/begin traffic.
    let _ = collect_messages(&receiver);

    drop(progress.stage(stage1, "Stage 1").build());
    let kinds = progress_kinds(&collect_messages(&receiver));
    assert_eq!(
      kinds,
      vec!["report", "end"],
      "stage drop should send End notification when all stages complete"
    );
  }

  #[test]
  fn test_full_progress_lifecycle_notifications() {
    let (tracker, receiver) = observed_tracker();
    let parsing = Ident::new("parsing");
    let analyzing = Ident::new("analyzing");
    let progress = tracker.create_progress(
      TEST_PROGRESS,
      "Processing",
      vec![parsing, analyzing],
      Ident::new("test_task"),
    );

    drop(progress.stage(parsing, "Parsing").build());
    drop(progress.stage(analyzing, "Analyzing").build());

    let kinds = progress_kinds(&collect_messages(&receiver));
    assert_eq!(
      kinds,
      vec!["create", "begin", "report", "report", "end"],
      "Full lifecycle: create, begin, report (parsing), report (analyzing), end"
    );
  }
}