mod laburnum_commands;
mod router;
pub use router::RequestRouter;
use {
crate::{
Ident,
Partitions,
TRACER,
connect::ipc::Connection,
protocol::{
jsonrpc::{
self,
Id,
Message,
Response,
},
lsp::LanguageServer,
otel::TraceContext,
},
scheduler::{
Scheduler,
lanes::RPC_LANE_3,
task::{
LaburnumTask,
TaskContext,
},
},
connect::lsp::ClientId,
},
async_channel::Receiver,
std::{
collections::HashSet,
ops::ControlFlow,
sync::{
Arc,
atomic::{
AtomicBool,
Ordering,
},
},
},
};
/// Thin wrapper around the `Arc<AtomicBool>` handed to `RpcTask::create`.
///
/// The `Arc` may also be held by the caller of `create`, so the flag can be
/// raised from outside this task. Every access uses `SeqCst` ordering.
struct ShutdownFlag(Arc<AtomicBool>);
impl ShutdownFlag {
    /// Raises the flag. This type never clears it again.
    fn set(&self) {
        let Self(shared) = self;
        shared.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once the flag has been raised through any handle to
    /// the shared atomic.
    pub(crate) fn is_shutdown(&self) -> bool {
        let Self(shared) = self;
        shared.load(Ordering::SeqCst)
    }
}
/// Per-connection JSON-RPC dispatch task.
///
/// Owns the client connection, the method router, the LSP lifecycle state,
/// and the channel through which spawned request handlers hand their
/// responses back to this task for delivery to the client.
pub struct RpcTask<P: Partitions, T: LanguageServer<P>> {
    /// LSP lifecycle state of this connection.
    state: ServerState,
    /// Scheduler context of this task; used to spawn per-message handler tasks.
    ctx: TaskContext<P, T>,
    /// Client connection: messages are read from `receiver`, written to `sender`.
    conn: Connection,
    /// Resolves method names to handlers and the lane they run on.
    router: RequestRouter<P, T>,
    /// Ids inserted when a request is dispatched and removed when its
    /// response is forwarded.
    pending_requests: Arc<parking_lot::Mutex<HashSet<Id>>>,
    /// Cloned into every spawned request handler to return its response.
    response_tx: async_channel::Sender<Response>,
    /// Drained by the event loop; each item is forwarded to the client.
    response_rx: async_channel::Receiver<Response>,
    /// Id of the `initialize` request awaiting its response, if any.
    init_request_id: Option<Id>,
    /// Shared shutdown signal checked at the top of every event-loop iteration.
    shutdown_flag: ShutdownFlag,
}
impl<P: Partitions, T: LanguageServer<P>> RpcTask<P, T> {
    /// Builds the RPC task for one client connection and registers it through
    /// `LaburnumTask::new_with_parent` on lane `RPC_LANE_3` for `client_id`.
    ///
    /// When the task runs, it creates a response channel bounded by
    /// `response_capacity` and then loops, waiting on:
    /// 1. the next incoming message on `conn`,
    /// 2. the next response produced by a spawned request handler,
    /// 3. a 100 ms timer, so an externally raised `shutdown_flag` is noticed
    ///    even when both channels are idle.
    ///
    /// `futures_lite::future::or` prefers its first argument when both are
    /// ready, which gives the priority order above.
    ///
    /// The loop stops when any of these happens:
    /// - the shutdown flag is observed at the top of an iteration;
    /// - `handle_message` returns `Break`;
    /// - either channel reports an error on `recv`.
    ///
    /// The task future resolves to `None`.
    pub(crate) fn create(
        scheduler: Arc<Scheduler<P, T>>,
        conn: Connection,
        client_id: ClientId,
        server: Arc<T>,
        shutdown_flag: Arc<AtomicBool>,
        response_capacity: usize,
    ) -> Arc<LaburnumTask<P, T>> {
        let flag = ShutdownFlag(shutdown_flag);
        LaburnumTask::new_with_parent(
            scheduler,
            move |ctx| {
                Box::pin(async move {
                    // What woke the event loop on a given iteration.
                    enum Wake {
                        Incoming(Message),
                        Reply(Response),
                        Tick,
                        Disconnected,
                    }

                    let (response_tx, response_rx) = async_channel::bounded(response_capacity);
                    let mut task = RpcTask {
                        state: ServerState::new(),
                        ctx,
                        conn,
                        router: RequestRouter::new(server),
                        pending_requests: Arc::new(parking_lot::Mutex::new(HashSet::new())),
                        response_tx,
                        response_rx,
                        init_request_id: None,
                        shutdown_flag: flag,
                    };

                    while !task.shutdown_flag.is_shutdown() {
                        let incoming = async {
                            task.conn
                                .receiver
                                .recv()
                                .await
                                .map_or(Wake::Disconnected, Wake::Incoming)
                        };
                        let replies = async {
                            task.response_rx
                                .recv()
                                .await
                                .map_or(Wake::Disconnected, Wake::Reply)
                        };
                        // Periodic wake-up so a shutdown flag raised elsewhere is
                        // seen even while both channels are idle.
                        let tick = async {
                            smol::Timer::after(std::time::Duration::from_millis(100)).await;
                            Wake::Tick
                        };
                        let wake = futures_lite::future::or(
                            futures_lite::future::or(incoming, replies),
                            tick,
                        )
                        .await;
                        match wake {
                            | Wake::Incoming(message) => {
                                if task.handle_message(message).is_break() {
                                    break;
                                }
                            },
                            | Wake::Reply(response) => task.handle_response(response),
                            | Wake::Tick => {},
                            | Wake::Disconnected => break,
                        }
                    }
                    None
                })
            },
            RPC_LANE_3,
            None,
            client_id,
        )
    }
}
impl<P: Partitions, T: LanguageServer<P>> RpcTask<P, T> {
fn is_initialize_request(method: &str) -> bool {
method == "initialize"
}
fn is_shutdown_request(method: &str) -> bool {
method == "shutdown"
}
fn is_initialized_notification(method: &str) -> bool {
method == "initialized"
}
fn is_exit_notification(method: &str) -> bool {
method == "exit"
}
fn is_execute_command_request(method: &str) -> bool {
method == "workspace/executeCommand"
}
fn send_error_response(&self, id: jsonrpc::Id, error: jsonrpc::Error) {
use opentelemetry::trace::SpanKind;
otel::span!("rpc.send_error_response", kind = SpanKind::Producer);
let response = Response::from_error(id.clone(), error);
if let Err(e) = self.conn.sender.try_send(Message::Response(response)) {
otel::error!(
"error_response_send_failed",
format!("Failed to send error response for request {}: {:?}", id, e)
);
}
}
fn handle_message(&mut self, message: Message) -> ControlFlow<()> {
match message {
| Message::Request(request) => {
self.handle_request(request);
ControlFlow::Continue(())
},
| Message::Notification(notification) => {
self.handle_notification(notification)
},
| Message::Response(_) => ControlFlow::Continue(()),
}
}
fn handle_request(&mut self, request: jsonrpc::Request) {
let trace_ctx = request.trace_context();
let _guard = trace_ctx.attach();
use opentelemetry::trace::SpanKind;
otel::span!("rpc.receive_request", kind = SpanKind::Consumer);
let msg_client_id = request.client_id();
if let Some(id) = msg_client_id
&& id != self.ctx.client_id()
{
let request_id = request.id().clone();
self.send_error_response(request_id, jsonrpc::Error {
code: jsonrpc::ErrorCode::InvalidRequest,
message: "client_id mismatch".into(),
data: None,
});
return;
}
let task_client_id = msg_client_id.unwrap_or_else(|| self.ctx.client_id());
let (method, request_id, params) = request.into_parts();
eprintln!("Received request: method={}, id={}", method, request_id);
let current_state = self.state.get();
if Self::is_initialize_request(&method) {
if current_state != State::Uninitialized {
self.send_error_response(request_id, jsonrpc::Error {
code: jsonrpc::ErrorCode::InvalidRequest,
message: "Server already initialized".into(),
data: None,
});
return;
}
self.state.set(State::Initializing);
self.init_request_id = Some(request_id.clone());
}
if Self::is_shutdown_request(&method) {
self.handle_shutdown(request_id);
return;
}
if current_state != State::Initialized
&& !Self::is_initialize_request(&method)
&& !Self::is_execute_command_request(&method)
{
self.send_error_response(
request_id,
crate::protocol::jsonrpc::error::not_initialized_error(),
);
return;
}
self.pending_requests.lock().insert(request_id.clone());
#[cfg(feature = "testing-commands")]
if let Some(params_value) = ¶ms
&& self.try_handle_laburnum_command(
&method,
params_value,
request_id.clone(),
)
{
return;
}
let (handler, method_lane) = self.router.get_request_handler(&method);
let tx = self.response_tx.clone();
let request_trace_context = TraceContext::from_current_span();
self.ctx.spawn_task_for_client(
move |ctx| {
use opentelemetry::trace::FutureExt as _;
let parent_cx = {
use opentelemetry::global;
let mut map = std::collections::HashMap::new();
if let Some(tp) = &request_trace_context.traceparent {
map.insert("traceparent".to_string(), tp.clone());
}
if let Some(ts) = &request_trace_context.tracestate {
map.insert("tracestate".to_string(), ts.clone());
}
global::get_text_map_propagator(|propagator| {
propagator
.extract(&crate::protocol::otel::JsonRpcExtractor::new(&map))
})
};
async move {
let writer = ctx
.new_record_writer(Ident::new(&format!("request({})", request_id)));
let (response, writer) = handler(
method.to_string(),
request_id.clone(),
params,
ctx,
writer,
)
.await;
if let Some(response) = response
&& let Err(e) = tx.send(response).await
{
otel::error!(
"response_channel_send_failed",
format!(
"Failed to send response to RPC task for request {}: {:?}",
request_id, e
)
);
}
Some(writer)
}
.with_context(parent_cx)
},
method_lane,
task_client_id,
);
}
fn handle_response(&mut self, response: jsonrpc::Response) {
let request_id = response.id().clone();
if let Some(init_id) = &self.init_request_id
&& *init_id == request_id
{
self.init_request_id = None;
if response.error().is_some() {
self.state.set(State::Uninitialized);
} else {
self.state.set(State::Initialized);
}
}
self.pending_requests.lock().remove(&request_id);
use opentelemetry::trace::SpanKind;
otel::span!("rpc.send_response", kind = SpanKind::Producer);
if let Err(e) = self.conn.sender.try_send(Message::Response(response)) {
otel::error!(
"response_send_failed",
format!(
"Failed to send response for request {}: {:?}",
request_id, e
)
);
}
}
fn handle_notification(
&mut self,
notification: jsonrpc::Notification,
) -> ControlFlow<()> {
let trace_ctx = notification.trace_context();
let _guard = trace_ctx.attach();
use opentelemetry::trace::SpanKind;
otel::span!("rpc.receive_notification", kind = SpanKind::Consumer);
let msg_client_id = notification.client_id();
if let Some(id) = msg_client_id
&& id != self.ctx.client_id()
{
otel::error!(
"notification_client_id_mismatch",
format!(
"Notification has client_id {:?} but connection has {:?}",
id,
self.ctx.client_id()
)
);
return ControlFlow::Continue(());
}
let task_client_id = msg_client_id.unwrap_or_else(|| self.ctx.client_id());
let (method, params) = notification.into_parts();
if Self::is_initialized_notification(&method) {
return ControlFlow::Continue(());
}
if Self::is_exit_notification(&method) {
self.state.set(State::Exited);
self.shutdown_flag.set();
return ControlFlow::Break(());
}
let (handler, method_lane) = self.router.get_notification_handler(&method);
let notification_trace_context = TraceContext::from_current_span();
self.ctx.spawn_task_for_client(
move |ctx| {
use opentelemetry::trace::FutureExt as _;
let parent_cx = {
use opentelemetry::global;
let mut map = std::collections::HashMap::new();
if let Some(tp) = ¬ification_trace_context.traceparent {
map.insert("traceparent".to_string(), tp.clone());
}
if let Some(ts) = ¬ification_trace_context.tracestate {
map.insert("tracestate".to_string(), ts.clone());
}
global::get_text_map_propagator(|propagator| {
propagator
.extract(&crate::protocol::otel::JsonRpcExtractor::new(&map))
})
};
async move {
let writer = ctx.new_record_writer(Ident::new(&format!(
"notification({})",
method
)));
let writer = handler(method.to_string(), params, ctx, writer).await;
Some(writer)
}
.with_context(parent_cx)
},
method_lane,
task_client_id,
);
ControlFlow::Continue(())
}
fn handle_shutdown(&mut self, request_id: jsonrpc::Id) {
let current_state = self.state.get();
if current_state != State::Initialized {
self.send_error_response(request_id, jsonrpc::Error {
code: jsonrpc::ErrorCode::InvalidRequest,
message: "Server is shutting down".into(),
data: None,
});
return;
}
self.state.set(State::ShutDown);
self.shutdown_flag.set();
let response =
Response::from_ok(request_id.clone(), serde_json::Value::Null);
if let Err(e) = self.conn.sender.try_send(Message::Response(response)) {
otel::error!(
"shutdown_response_send_failed",
format!("Failed to send shutdown response: {:?}", e)
);
}
}
}
use std::{
fmt::{
self,
Debug,
Formatter,
},
sync::atomic::AtomicU8,
};
/// LSP lifecycle phase of one client connection.
///
/// The explicit discriminants are the raw values stored in [`ServerState`]'s
/// atomic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum State {
    /// No `initialize` request has succeeded yet (the initial state).
    Uninitialized = 0,
    /// An `initialize` request is in flight.
    Initializing = 1,
    /// `initialize` has succeeded.
    Initialized = 2,
    /// A `shutdown` request has been accepted.
    ShutDown = 3,
    /// An `exit` notification has been received.
    Exited = 4,
}

/// Atomically updated [`State`] behind an `Arc`.
///
/// Cloning produces another handle to the *same* atomic, so a `set` through
/// one clone is observed by `get` on every other clone.
#[derive(Clone)]
pub struct ServerState(Arc<AtomicU8>);

impl ServerState {
    /// Creates a state starting at [`State::Uninitialized`].
    pub fn new() -> Self {
        ServerState(Arc::new(AtomicU8::new(State::Uninitialized as u8)))
    }

    /// Stores `state` with `SeqCst` ordering.
    pub fn set(&self, state: State) {
        self.0.store(state as u8, Ordering::SeqCst);
    }

    /// Loads the current state with `SeqCst` ordering.
    ///
    /// Only `set` writes the atomic, and it only stores valid discriminants.
    /// The fallback arm is defensive and maps unknown values to
    /// [`State::Uninitialized`].
    pub fn get(&self) -> State {
        match self.0.load(Ordering::SeqCst) {
            | 0 => State::Uninitialized,
            | 1 => State::Initializing,
            | 2 => State::Initialized,
            | 3 => State::ShutDown,
            | 4 => State::Exited,
            | _ => State::Uninitialized,
        }
    }
}

impl Default for ServerState {
    /// Same as [`ServerState::new`]: a fresh state at `Uninitialized`.
    fn default() -> Self {
        Self::new()
    }
}

impl Debug for ServerState {
    /// Formats as the current [`State`] variant name, e.g. `Initialized`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        self.get().fmt(f)
    }
}