#![allow(clippy::mutable_key_type)]
mod task_pool;
use std::collections::HashMap;
use std::panic::AssertUnwindSafe;
use std::path::{Path, PathBuf};
use std::thread::JoinHandle;
use crossbeam_channel::{Receiver, Sender, select, unbounded};
use lsp_server::{Connection, ErrorCode, Message, Notification, Request, RequestId, Response};
use lsp_types::notification::{
DidChangeConfiguration, DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument,
Notification as _, PublishDiagnostics,
};
use lsp_types::request::{Formatting, Request as _};
use lsp_types::{
Diagnostic, DiagnosticSeverity, DidChangeConfigurationParams, DidChangeTextDocumentParams,
DidCloseTextDocumentParams, DidOpenTextDocumentParams, DocumentFormattingParams,
FormattingOptions, OneOf, Position, PublishDiagnosticsParams, Range, ServerCapabilities,
TextDocumentContentChangeEvent, TextDocumentSyncCapability, TextDocumentSyncKind, TextEdit,
Uri,
};
use salsa::Database as _;
use serde::Deserialize;
use crate::formatter::{FormatStyle, format_node, format_with_style};
use crate::incremental::{Analysis, IncrementalDatabase};
use crate::text::LineIndex;
use task_pool::{Spawner, TaskPool, read_pool_size};
/// Boxed, thread-safe error type returned by the server's fallible entry points.
type DynError = Box<dyn std::error::Error + Send + Sync>;
/// Runs the language server over stdin/stdout for one client session.
///
/// `serve` consumes the connection, so its channels are closed by the time the stdio
/// reader/writer threads are joined.
///
/// # Errors
/// Returns any protocol error from the session or an I/O error from joining the IO threads.
pub fn run() -> Result<(), DynError> {
    let (conn, threads) = Connection::stdio();
    serve(conn)?;
    threads.join().map_err(Into::into)
}
/// Performs the LSP `initialize` handshake on `connection`, advertising
/// [`server_capabilities`], then runs the main loop until the session ends.
///
/// # Errors
/// Fails if the capabilities cannot be serialized, the handshake fails, or the main loop
/// reports a protocol error.
pub fn serve(connection: Connection) -> Result<(), DynError> {
    let init_params = connection.initialize(serde_json::to_value(server_capabilities())?)?;
    main_loop(connection, init_params)
}
fn server_capabilities() -> ServerCapabilities {
ServerCapabilities {
text_document_sync: Some(TextDocumentSyncCapability::Kind(
TextDocumentSyncKind::INCREMENTAL,
)),
document_formatting_provider: Some(OneOf::Left(true)),
..Default::default()
}
}
/// Main-thread copy of an open text document, kept in sync with the client.
struct Document {
    /// Version number most recently reported by the client for this document.
    version: i32,
    /// Current contents, after every received change notification has been applied.
    text: String,
}
/// State owned by the main-loop thread.
struct GlobalState {
    /// Settings received from the client, used when resolving formatting style.
    editor_settings: EditorSettings,
    /// Documents the client currently has open, keyed by URI.
    documents: HashMap<Uri, Document>,
}
/// User-configurable server settings, deserialized from camelCase JSON
/// (`lineWidth`, `indentWidth`). Missing keys fall back to `None`.
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(default)]
struct EditorSettings {
    /// Preferred maximum line width; `None` keeps the formatter default.
    line_width: Option<u32>,
    /// Preferred indent width; `None` keeps the formatter default.
    indent_width: Option<u32>,
}
impl EditorSettings {
    /// Parses a client settings payload (from `initializationOptions` or
    /// `workspace/didChangeConfiguration`).
    ///
    /// The payload may be namespaced under a `"badness"` object or be the bare settings
    /// object itself. A payload that fails to deserialize yields the default settings.
    fn from_client_value(value: &serde_json::Value) -> Self {
        let section = match value.get("badness") {
            Some(nested) if nested.is_object() => nested,
            _ => value,
        };
        serde_json::from_value(section.clone()).unwrap_or_default()
    }

    /// Starts from the formatter defaults and overrides each width the user configured.
    fn to_format_style(&self) -> FormatStyle {
        let mut style = FormatStyle::default();
        style.line_width = self.line_width.map_or(style.line_width, |w| w as usize);
        style.indent_width = self.indent_width.map_or(style.indent_width, |w| w as usize);
        style
    }
}
/// Resolves the effective [`FormatStyle`] for one formatting request.
///
/// The line width comes from the server settings (or the formatter default). The indent
/// width is chosen in this order:
/// 1. an explicit `indentWidth` server setting;
/// 2. the request's `tab_size`, when non-zero;
/// 3. the formatter default.
fn resolve_style(settings: &EditorSettings, options: &FormattingOptions) -> FormatStyle {
    let mut style = settings.to_format_style();
    // `FormattingOptions::tab_size` is a required field, so clients send it on every request.
    // Letting it win unconditionally would make the `indentWidth` setting a no-op for
    // formatting, so it is only a fallback when the user configured nothing.
    if settings.indent_width.is_none() && options.tab_size > 0 {
        style.indent_width = options.tab_size as usize;
    }
    style
}
/// Work sent from the main loop to the worker thread.
enum WorkerJob {
    /// A document was opened or changed: store `text` and queue diagnostics for `version`.
    Edit {
        uri: Uri,
        version: i32,
        path: PathBuf,
        text: String,
    },
    /// A document was closed: remove it from the database.
    Close { path: PathBuf },
    /// Answer formatting request `id` for the given document text and style.
    Format {
        id: RequestId,
        style: FormatStyle,
        path: PathBuf,
        text: String,
    },
}
/// Results sent from the worker and its read-pool tasks back to the main loop.
enum Outbound {
    /// Parse diagnostics computed for `uri` at document `version`.
    Diagnostics {
        uri: Uri,
        version: i32,
        diags: Vec<Diagnostic>,
    },
    /// A finished response to a client request.
    Response(Response),
}
fn uri_to_path(uri: &Uri) -> PathBuf {
PathBuf::from(uri.as_str())
}
fn main_loop(connection: Connection, init_params: serde_json::Value) -> Result<(), DynError> {
let editor_settings = init_params
.get("initializationOptions")
.map(EditorSettings::from_client_value)
.unwrap_or_default();
let mut state = GlobalState {
documents: HashMap::new(),
editor_settings,
};
let read_pool = TaskPool::new("badness-lsp-read", read_pool_size());
let (job_tx, job_rx) = unbounded::<WorkerJob>();
let (out_tx, out_rx) = unbounded::<Outbound>();
let worker = spawn_worker(job_rx, out_tx, read_pool.spawner());
loop {
select! {
recv(connection.receiver) -> msg => {
let Ok(msg) = msg else { break };
match msg {
Message::Request(req) => {
if connection.handle_shutdown(&req)? {
break;
}
match req.method.as_str() {
Formatting::METHOD => on_formatting(&connection, &state, &job_tx, req),
_ => respond_unhandled(&connection, req),
}
}
Message::Notification(not) => {
on_notification(&connection, &mut state, &job_tx, not);
}
Message::Response(_) => {}
}
}
recv(out_rx) -> outbound => {
let Ok(outbound) = outbound else { continue };
forward_outbound(&connection, &state, outbound);
}
}
}
drop(job_tx);
let _ = worker.join();
Ok(())
}
/// Handles one client notification.
///
/// - Keeps the main-thread document mirror and settings up to date.
/// - Forwards edits and closes to the analysis worker.
/// - Clears diagnostics for closed documents.
///
/// Malformed notifications and unknown methods are ignored.
fn on_notification(
    connection: &Connection,
    state: &mut GlobalState,
    job_tx: &Sender<WorkerJob>,
    not: Notification,
) {
    match not.method.as_str() {
        DidOpenTextDocument::METHOD => {
            let Ok(params) = not.extract::<DidOpenTextDocumentParams>(DidOpenTextDocument::METHOD)
            else {
                return;
            };
            let doc = params.text_document;
            let uri = doc.uri;
            state.documents.insert(
                uri.clone(),
                Document {
                    text: doc.text.clone(),
                    version: doc.version,
                },
            );
            // A send error only means the worker has already exited.
            let _ = job_tx.send(WorkerJob::Edit {
                path: uri_to_path(&uri),
                uri,
                text: doc.text,
                version: doc.version,
            });
        }
        DidChangeTextDocument::METHOD => {
            let Ok(params) =
                not.extract::<DidChangeTextDocumentParams>(DidChangeTextDocument::METHOD)
            else {
                return;
            };
            let uri = params.text_document.uri;
            let version = params.text_document.version;
            // Changes for a document we never saw opened cannot be applied.
            let Some(doc) = state.documents.get_mut(&uri) else {
                return;
            };
            apply_content_changes(&mut doc.text, params.content_changes);
            doc.version = version;
            let text = doc.text.clone();
            let _ = job_tx.send(WorkerJob::Edit {
                path: uri_to_path(&uri),
                uri,
                text,
                version,
            });
        }
        DidCloseTextDocument::METHOD => {
            let Ok(params) =
                not.extract::<DidCloseTextDocumentParams>(DidCloseTextDocument::METHOD)
            else {
                return;
            };
            let uri = params.text_document.uri;
            state.documents.remove(&uri);
            let _ = job_tx.send(WorkerJob::Close {
                path: uri_to_path(&uri),
            });
            // Clear whatever diagnostics the client is still showing for this document.
            send_diagnostics(connection, uri, Vec::new(), None);
        }
        DidChangeConfiguration::METHOD => {
            let Ok(params) =
                not.extract::<DidChangeConfigurationParams>(DidChangeConfiguration::METHOD)
            else {
                return;
            };
            // Pull-model clients send `{"settings": null}` as a bare "configuration changed"
            // signal. Deserializing that would reset every setting (including those from
            // `initializationOptions`) to defaults, so only object payloads are applied.
            if params.settings.is_object() {
                state.editor_settings = EditorSettings::from_client_value(&params.settings);
            }
        }
        _ => {}
    }
}
fn apply_content_changes(text: &mut String, changes: Vec<TextDocumentContentChangeEvent>) {
for change in changes {
match change.range {
None => *text = change.text,
Some(range) => {
let idx = LineIndex::new(text);
let start = idx.offset_at(text, range.start.line, range.start.character);
let end = idx.offset_at(text, range.end.line, range.end.character);
let (start, end) = (start.min(end), start.max(end));
text.replace_range(start..end, &change.text);
}
}
}
}
/// Handles `textDocument/formatting`.
///
/// Malformed params get an `InvalidParams` error, and unknown documents get a `null`
/// result (no edits). Otherwise the document text and resolved style are handed to the
/// worker, which sends the response later. If the worker is no longer running, the
/// request is answered immediately with `InternalError` so the client does not wait
/// forever.
fn on_formatting(
    connection: &Connection,
    state: &GlobalState,
    job_tx: &Sender<WorkerJob>,
    req: Request,
) {
    // `extract` consumes the request, so keep the id for the error paths.
    let id = req.id.clone();
    let params = match req.extract::<DocumentFormattingParams>(Formatting::METHOD) {
        Ok((_, params)) => params,
        Err(_) => {
            let resp = Response::new_err(
                id,
                ErrorCode::InvalidParams as i32,
                "invalid formatting params".to_owned(),
            );
            let _ = connection.sender.send(Message::Response(resp));
            return;
        }
    };
    let uri = params.text_document.uri;
    let Some(doc) = state.documents.get(&uri) else {
        let _ = connection.sender.send(Message::Response(Response::new_ok(
            id,
            serde_json::Value::Null,
        )));
        return;
    };
    let style = resolve_style(&state.editor_settings, &params.options);
    let job = WorkerJob::Format {
        id: id.clone(),
        path: uri_to_path(&uri),
        text: doc.text.clone(),
        style,
    };
    if job_tx.send(job).is_err() {
        // The worker thread is gone (e.g. it panicked); nobody else will answer this request.
        let resp = Response::new_err(
            id,
            ErrorCode::InternalError as i32,
            "formatting worker is not running".to_owned(),
        );
        let _ = connection.sender.send(Message::Response(resp));
    }
}
/// Delivers a worker result to the client.
///
/// Responses are always forwarded. Diagnostics are only published when the document is
/// still open at exactly the version they were computed for, so stale results are dropped.
fn forward_outbound(connection: &Connection, state: &GlobalState, outbound: Outbound) {
    let (uri, version, diags) = match outbound {
        Outbound::Response(resp) => {
            let _ = connection.sender.send(Message::Response(resp));
            return;
        }
        Outbound::Diagnostics {
            uri,
            version,
            diags,
        } => (uri, version, diags),
    };
    let current = state.documents.get(&uri).map(|doc| doc.version);
    if current == Some(version) {
        send_diagnostics(connection, uri, diags, Some(version));
    }
}
/// Completion notice from an analysis task, sent whether or not the run was cancelled.
struct AnalyzeDone {
    version: i32,
    uri: Uri,
}
/// Identity (document and version) of the analysis occupying the worker's single slot.
struct InflightAnalyze {
    version: i32,
    uri: Uri,
}
/// A queued request to compute diagnostics for one version of one document.
struct AnalyzeRequest {
    /// Document identity used when publishing diagnostics.
    uri: Uri,
    /// Document version the diagnostics will be tagged with.
    version: i32,
    /// Database key of the document.
    path: PathBuf,
}
/// What the worker should do with its analysis slot, as chosen by [`decide`].
#[derive(PartialEq, Eq, Debug)]
enum DispatchAction {
    /// Leave the slot as it is.
    Wait,
    /// The slot is free: start analyzing this document.
    Start(Uri),
    /// Cancel the running analysis and restart it for a newer version of the same document.
    SupersedeAndStart(Uri),
}
/// Chooses the next dispatch step from the running analysis (if any) and the queued
/// versions per document.
fn decide(inflight: Option<(&Uri, i32)>, pending: &HashMap<Uri, i32>) -> DispatchAction {
    let Some((running_uri, running_version)) = inflight else {
        // Idle: start whichever queued document the map yields first (order is unspecified).
        return pending
            .keys()
            .next()
            .map_or(DispatchAction::Wait, |uri| DispatchAction::Start(uri.clone()));
    };
    // Busy: preempt only for a strictly newer version of the same document; work for other
    // documents waits for the slot to free up.
    match pending.get(running_uri) {
        Some(&queued) if queued > running_version => {
            DispatchAction::SupersedeAndStart(running_uri.clone())
        }
        _ => DispatchAction::Wait,
    }
}
/// Spawns the `badness-lsp-worker` thread, which owns the incremental database and runs
/// [`Worker::run`] until `job_rx` is closed.
///
/// # Panics
/// Panics if the operating system refuses to spawn the thread.
fn spawn_worker(
    job_rx: Receiver<WorkerJob>,
    out_tx: Sender<Outbound>,
    read_spawner: Spawner,
) -> JoinHandle<()> {
    let (done_tx, done_rx) = unbounded::<AnalyzeDone>();
    let body = move || {
        // The database is created on the worker thread itself and never leaves it.
        let mut state = Worker {
            db: IncrementalDatabase::default(),
            out_tx,
            done_tx,
            read_spawner,
            inflight: None,
            pending: HashMap::new(),
        };
        state.run(&job_rx, &done_rx);
    };
    std::thread::Builder::new()
        .name(String::from("badness-lsp-worker"))
        .spawn(body)
        .expect("spawn LSP worker thread")
}
/// State owned by the worker thread.
struct Worker {
    /// Incremental database holding the latest text of each open document.
    db: IncrementalDatabase,
    /// Channel back to the main loop for diagnostics and responses.
    out_tx: Sender<Outbound>,
    /// Cloned into each analysis task so it can report completion.
    done_tx: Sender<AnalyzeDone>,
    /// Handle for running read-only work on the read pool.
    read_spawner: Spawner,
    /// The analysis currently running on the read pool, if any.
    inflight: Option<InflightAnalyze>,
    /// At most one queued analysis per document; the newest version wins.
    pending: HashMap<Uri, AnalyzeRequest>,
}
impl Worker {
fn run(&mut self, job_rx: &Receiver<WorkerJob>, done_rx: &Receiver<AnalyzeDone>) {
loop {
select! {
recv(job_rx) -> job => {
let Ok(job) = job else { break }; self.handle_job(job);
while let Ok(j) = job_rx.try_recv() {
self.handle_job(j);
}
self.try_dispatch();
}
recv(done_rx) -> done => {
let Ok(done) = done else { continue };
if matches!(&self.inflight, Some(f) if f.uri == done.uri && f.version == done.version)
{
self.inflight = None;
}
self.try_dispatch();
}
}
}
}
fn handle_job(&mut self, job: WorkerJob) {
match job {
WorkerJob::Edit {
uri,
path,
text,
version,
} => {
self.db.upsert_file(&path, text);
self.enqueue(AnalyzeRequest { uri, path, version });
}
WorkerJob::Close { path } => {
self.db.remove_file(&path);
}
WorkerJob::Format {
id,
path,
text,
style,
} => {
let snapshot = self.db.snapshot();
let out_tx = self.out_tx.clone();
self.read_spawner
.spawn(move || run_format(&snapshot, id, &path, &text, style, &out_tx));
}
}
}
fn enqueue(&mut self, req: AnalyzeRequest) {
match self.pending.get(&req.uri) {
Some(existing) if existing.version >= req.version => {}
_ => {
self.pending.insert(req.uri.clone(), req);
}
}
}
fn try_dispatch(&mut self) {
let versions: HashMap<Uri, i32> = self
.pending
.iter()
.map(|(uri, req)| (uri.clone(), req.version))
.collect();
let inflight = self.inflight.as_ref().map(|f| (&f.uri, f.version));
let uri = match decide(inflight, &versions) {
DispatchAction::Wait => return,
DispatchAction::Start(uri) => uri,
DispatchAction::SupersedeAndStart(uri) => {
self.db.trigger_cancellation();
self.inflight = None;
uri
}
};
let Some(req) = self.pending.remove(&uri) else {
return;
};
self.start_analyze(req);
}
fn start_analyze(&mut self, req: AnalyzeRequest) {
let snapshot = self.db.snapshot();
let out_tx = self.out_tx.clone();
let done_tx = self.done_tx.clone();
let AnalyzeRequest { uri, path, version } = req;
self.inflight = Some(InflightAnalyze {
uri: uri.clone(),
version,
});
self.read_spawner.spawn(move || {
let result = salsa::Cancelled::catch(AssertUnwindSafe(|| {
let file = snapshot.lookup_file(&path)?;
let text = snapshot.file_text(file).to_owned();
let idx = LineIndex::new(&text);
let diags: Vec<Diagnostic> = snapshot
.parse_diagnostics(file)
.iter()
.map(|d| Diagnostic {
range: byte_range_to_lsp(&idx, &text, d.start, d.end),
severity: Some(DiagnosticSeverity::ERROR),
source: Some("badness".to_owned()),
message: d.message.clone(),
..Default::default()
})
.collect();
Some(diags)
}));
if let Ok(Some(diags)) = result {
let _ = out_tx.send(Outbound::Diagnostics {
uri: uri.clone(),
version,
diags,
});
}
drop(snapshot);
let _ = done_tx.send(AnalyzeDone { uri, version });
});
}
}
/// Read-pool entry point for a formatting request.
///
/// Sends the response for `id` to the main loop. The result is a one-element edit list,
/// or `null` when there is nothing to change or formatting is not possible.
fn run_format(
    snapshot: &Analysis,
    id: RequestId,
    path: &Path,
    text: &str,
    style: FormatStyle,
    out_tx: &Sender<Outbound>,
) {
    let payload = compute_format(snapshot, path, text, style).map_or(
        serde_json::Value::Null,
        |edit| serde_json::to_value(vec![edit]).unwrap_or(serde_json::Value::Null),
    );
    let _ = out_tx.send(Outbound::Response(Response::new_ok(id, payload)));
}
fn compute_format(
snapshot: &Analysis,
path: &Path,
text: &str,
style: FormatStyle,
) -> Option<TextEdit> {
let cached = salsa::Cancelled::catch(AssertUnwindSafe(|| {
let file = snapshot.lookup_file(path)?;
if snapshot.file_text(file) != text {
return None;
}
if !snapshot.parse_diagnostics(file).is_empty() {
return Some(None);
}
let root = snapshot.parsed_tree(file);
Some(format_node(&root, style).ok())
}));
let formatted = match cached {
Ok(Some(opt)) => opt,
Ok(None) | Err(_) => format_with_style(text, style).ok(),
}?;
if formatted == text {
return None;
}
let idx = LineIndex::new(text);
let (end_line, end_col) = idx.utf16_position(text, text.len());
Some(TextEdit {
range: Range {
start: Position::new(0, 0),
end: Position::new(end_line, end_col),
},
new_text: formatted,
})
}
/// Publishes `diagnostics` for `uri` (optionally tagged with a document `version`).
///
/// An empty list clears the client's diagnostics for that document. Send failures are
/// ignored.
fn send_diagnostics(
    connection: &Connection,
    uri: Uri,
    diagnostics: Vec<Diagnostic>,
    version: Option<i32>,
) {
    let message = Message::Notification(Notification::new(
        PublishDiagnostics::METHOD.to_owned(),
        PublishDiagnosticsParams {
            uri,
            diagnostics,
            version,
        },
    ));
    let _ = connection.sender.send(message);
}
fn respond_unhandled(connection: &Connection, req: Request) {
let resp = Response::new_err(
req.id,
ErrorCode::MethodNotFound as i32,
format!("unhandled request: {}", req.method),
);
let _ = connection.sender.send(Message::Response(resp));
}
fn byte_range_to_lsp(idx: &LineIndex, text: &str, start: usize, end: usize) -> Range {
let (sl, sc) = idx.utf16_position(text, start);
let (el, ec) = idx.utf16_position(text, end);
Range {
start: Position::new(sl, sc),
end: Position::new(el, ec),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a URI literal; test inputs are always well-formed.
    fn uri(s: &str) -> Uri {
        s.parse().unwrap()
    }

    /// Builds a single-line ranged change event replacing columns `from..to` of `line`.
    fn ranged_change(line: u32, from: u32, to: u32, text: &str) -> TextDocumentContentChangeEvent {
        TextDocumentContentChangeEvent {
            range: Some(Range::new(Position::new(line, from), Position::new(line, to))),
            range_length: None,
            text: text.to_owned(),
        }
    }

    #[test]
    fn decide_starts_when_idle() {
        let a = uri("file:///a.tex");
        let pending = HashMap::from([(a.clone(), 1)]);
        assert_eq!(decide(None, &pending), DispatchAction::Start(a));
    }

    #[test]
    fn decide_waits_when_idle_and_empty() {
        let empty: HashMap<Uri, i32> = HashMap::new();
        assert_eq!(decide(None, &empty), DispatchAction::Wait);
    }

    #[test]
    fn decide_supersedes_only_on_newer_same_uri() {
        let a = uri("file:///a.tex");
        let pending = HashMap::from([(a.clone(), 5)]);
        let newer_queued = decide(Some((&a, 3)), &pending);
        assert_eq!(newer_queued, DispatchAction::SupersedeAndStart(a.clone()));
        let same_queued = decide(Some((&a, 5)), &pending);
        assert_eq!(same_queued, DispatchAction::Wait);
    }

    #[test]
    fn decide_never_cancels_inflight_for_a_different_uri() {
        let a = uri("file:///a.tex");
        let pending = HashMap::from([(uri("file:///b.tex"), 9)]);
        assert_eq!(decide(Some((&a, 1)), &pending), DispatchAction::Wait);
    }

    #[test]
    fn apply_content_changes_splices_ranged_edit() {
        let mut text = String::from("hello world\n");
        apply_content_changes(&mut text, vec![ranged_change(0, 6, 11, "there")]);
        assert_eq!(text, "hello there\n");
    }

    #[test]
    fn apply_content_changes_full_replace_on_no_range() {
        let mut text = String::from("old");
        let full = TextDocumentContentChangeEvent {
            range: None,
            range_length: None,
            text: String::from("new"),
        };
        apply_content_changes(&mut text, vec![full]);
        assert_eq!(text, "new");
    }

    #[test]
    fn editor_settings_namespaced_and_bare() {
        let bare = EditorSettings::from_client_value(&serde_json::json!({
            "lineWidth": 100,
            "indentWidth": 4
        }));
        assert_eq!((bare.line_width, bare.indent_width), (Some(100), Some(4)));
        let style = bare.to_format_style();
        assert_eq!((style.line_width, style.indent_width), (100, 4));

        let namespaced = EditorSettings::from_client_value(&serde_json::json!({
            "badness": { "lineWidth": 72 }
        }));
        assert_eq!((namespaced.line_width, namespaced.indent_width), (Some(72), None));
    }
}