use super::LanguageServerState;
use crate::cancelation::is_canceled;
use crate::from_json;
use crate::state::{LanguageServerSnapshot, Task};
use lsp_server::ExtractError;
use serde::de::DeserializeOwned;
use serde::Serialize;
pub(crate) struct RequestDispatcher<'a> {
state: &'a mut LanguageServerState,
request: Option<lsp_server::Request>,
}
impl<'a> RequestDispatcher<'a> {
pub fn new(state: &'a mut LanguageServerState, request: lsp_server::Request) -> Self {
RequestDispatcher {
state,
request: Some(request),
}
}
pub fn on_sync<R>(
&mut self,
compute_response_fn: fn(&mut LanguageServerState, R::Params) -> anyhow::Result<R::Result>,
) -> anyhow::Result<&mut Self>
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + 'static,
R::Result: Serialize + 'static,
{
let (id, params) = match self.parse::<R>() {
Some(it) => it,
None => return Ok(self),
};
let result = compute_response_fn(self.state, params);
let response = result_to_response::<R>(id, result);
self.state.respond(response);
Ok(self)
}
pub fn on<R>(
&mut self,
compute_response_fn: fn(LanguageServerSnapshot, R::Params) -> anyhow::Result<R::Result>,
) -> anyhow::Result<&mut Self>
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + 'static + Send,
R::Result: Serialize + 'static,
{
let (id, params) = match self.parse::<R>() {
Some(it) => it,
None => return Ok(self),
};
self.state.thread_pool.execute({
let snapshot = self.state.snapshot();
let sender = self.state.task_sender.clone();
move || {
let result = compute_response_fn(snapshot, params);
sender
.send(Task::Response(result_to_response::<R>(id, result)))
.unwrap();
}
});
Ok(self)
}
fn parse<R>(&mut self) -> Option<(lsp_server::RequestId, R::Params)>
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + 'static,
{
let req = match &self.request {
Some(req) if req.method == R::METHOD => self.request.take().unwrap(),
_ => return None,
};
match from_json(R::METHOD, req.params) {
Ok(params) => Some((req.id, params)),
Err(err) => {
let response = lsp_server::Response::new_err(
req.id,
lsp_server::ErrorCode::InvalidParams as i32,
err.to_string(),
);
self.state.respond(response);
None
}
}
}
pub fn finish(&mut self) {
if let Some(req) = self.request.take() {
log::error!("unknown request: {:?}", req);
let response = lsp_server::Response::new_err(
req.id,
lsp_server::ErrorCode::MethodNotFound as i32,
"unknown request".to_string(),
);
self.state.respond(response);
}
}
}
pub(crate) struct NotificationDispatcher<'a> {
state: &'a mut LanguageServerState,
notification: Option<lsp_server::Notification>,
}
impl<'a> NotificationDispatcher<'a> {
pub fn new(state: &'a mut LanguageServerState, notification: lsp_server::Notification) -> Self {
NotificationDispatcher {
state,
notification: Some(notification),
}
}
pub fn on<N>(
&mut self,
handle_notification_fn: fn(&mut LanguageServerState, N::Params) -> anyhow::Result<()>,
) -> anyhow::Result<&mut Self>
where
N: lsp_types::notification::Notification + 'static,
N::Params: DeserializeOwned + Send + 'static,
{
let notification = match self.notification.take() {
Some(it) => it,
None => return Ok(self),
};
let params = match notification.extract::<N::Params>(N::METHOD) {
Ok(it) => it,
Err(ExtractError::JsonError { method, error }) => {
panic!("Invalid request\nMethod: {method}\n error: {error}",)
}
Err(ExtractError::MethodMismatch(notification)) => {
self.notification = Some(notification);
return Ok(self);
}
};
handle_notification_fn(self.state, params)?;
Ok(self)
}
pub fn finish(&mut self) {
if let Some(notification) = &self.notification {
if !notification.method.starts_with("$/") {
log::error!("unhandled notification: {:?}", notification);
}
}
}
}
fn result_to_response<R>(
id: lsp_server::RequestId,
result: anyhow::Result<R::Result>,
) -> lsp_server::Response
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + 'static,
R::Result: Serialize + 'static,
{
match result {
Ok(resp) => lsp_server::Response::new_ok(id, &resp),
Err(e) => {
if is_canceled(&*e) {
lsp_server::Response::new_err(
id,
lsp_server::ErrorCode::ContentModified as i32,
"content modified".to_string(),
)
} else {
lsp_server::Response::new_err(
id,
lsp_server::ErrorCode::InternalError as i32,
e.to_string(),
)
}
}
}
}