mod error;
mod io_thread;
mod message;
mod request;
mod response;
mod transporter;
mod notification;
#[cfg(test)]
mod tests;
pub use bsp_types as types;
pub use error::{ErrorCode, ExtractError, ProtocolError};
pub use io_thread::IoThreads;
pub use message::Message;
pub use notification::Notification;
pub use request::{Request, RequestId};
pub use response::{Response, ResponseError};
pub(crate) use transporter::Transporter;
use bsp_types::InitializeBuild;
use crossbeam_channel::{unbounded, Receiver, SendError, SendTimeoutError, Sender, TrySendError};
use serde::Serialize;
use std::io;
use std::net::{TcpListener, TcpStream, ToSocketAddrs};
use std::time::{Duration, Instant};
/// A bidirectional message channel to the peer.
///
/// Outgoing messages are pushed into `sender` and incoming messages are read
/// from `receiver`. Construct one with [`Connection::stdio`],
/// [`Connection::connect`], [`Connection::listen`] or [`Connection::memory`].
#[derive(Debug)]
pub struct Connection {
    /// Channel for messages destined for the peer.
    pub sender: Sender<Message>,
    /// Channel carrying messages received from the peer.
    pub receiver: Receiver<Message>,
}
impl Connection {
    /// How long [`Connection::handle_shutdown`] waits for the client's `Exit`
    /// notification after acknowledging a shutdown request.
    const SHUTDOWN_EXIT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a connection that communicates through the stdio transporter.
    ///
    /// Returns the connection together with the [`IoThreads`] handle produced
    /// by the transporter. The caller owns that handle (presumably it should be
    /// joined once the connection is finished; confirm against `IoThreads`).
    pub fn stdio() -> (Connection, IoThreads) {
        let Transporter(sender, receiver, io_threads) = Transporter::stdio();
        (Connection { sender, receiver }, io_threads)
    }

    /// Opens a TCP connection to `addr` and wraps it in a [`Connection`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error produced by [`TcpStream::connect`].
    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<(Connection, IoThreads)> {
        let stream = TcpStream::connect(addr)?;
        Ok(Self::from_tcp_stream(stream))
    }

    /// Binds a TCP listener on `addr` and waits for a single client to
    /// connect, then wraps that client's stream in a [`Connection`].
    ///
    /// The listener is dropped after the first accepted connection, so only
    /// one client is ever served.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from [`TcpListener::bind`] or
    /// [`TcpListener::accept`].
    pub fn listen<A: ToSocketAddrs>(addr: A) -> io::Result<(Connection, IoThreads)> {
        let listener = TcpListener::bind(addr)?;
        let (stream, _peer_addr) = listener.accept()?;
        Ok(Self::from_tcp_stream(stream))
    }

    /// Wraps an established TCP stream using the socket transporter.
    fn from_tcp_stream(stream: TcpStream) -> (Connection, IoThreads) {
        let Transporter(sender, receiver, io_threads) = Transporter::socket(stream);
        (Connection { sender, receiver }, io_threads)
    }

    /// Creates two in-memory connections wired back to back.
    ///
    /// Messages sent on one connection are received by the other. Both
    /// directions use unbounded channels, and no I/O threads are created.
    pub fn memory() -> (Connection, Connection) {
        let ((s1, r1), (s2, r2)) = (unbounded(), unbounded());
        (
            Connection {
                sender: s1,
                receiver: r2,
            },
            Connection {
                sender: s2,
                receiver: r1,
            },
        )
    }

    /// Performs the server side of the initialization handshake.
    ///
    /// 1. Waits for an `InitializeBuild` request. Any other request received
    ///    first is answered with a "server not initialized" error and skipped.
    /// 2. Calls `process` with the request parameters and sends its return
    ///    value back as the successful response.
    /// 3. Waits for the client's `Initialized` notification.
    ///
    /// Returns the parameters of the initialize request.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] (also logged at error level) in these cases:
    /// - a non-request message arrives before the initialize request;
    /// - anything other than the `Initialized` notification follows the
    ///   response;
    /// - the channel is disconnected.
    #[tracing::instrument(skip_all)]
    pub fn initialize<V: Serialize>(
        &self,
        process: impl FnOnce(&InitializeBuild) -> V,
    ) -> Result<InitializeBuild, ProtocolError> {
        let (id, params) = self.initialize_start()?;
        // `process` only borrows the params so they can be returned to the caller.
        self.initialize_finish(id, process(&params))?;
        Ok(params)
    }

    /// Blocks until an `InitializeBuild` request arrives and returns its id
    /// and parameters.
    ///
    /// Other requests are rejected with a "server not initialized" response
    /// and do not end the wait. Any other message, or a receive error, aborts
    /// the handshake.
    #[tracing::instrument(skip(self))]
    fn initialize_start(&self) -> Result<(RequestId, InitializeBuild), ProtocolError> {
        loop {
            match self.receiver.recv() {
                Ok(Message::Request(Request::InitializeBuild(id, params))) => {
                    return Ok((id, params));
                }
                Ok(Message::Request(req)) => {
                    let msg = format!("expected initialize request, got {:?}", req);
                    tracing::error!("{}", msg);
                    let resp = Response::server_not_initialized(req.id().clone(), msg);
                    // A disconnected peer is a protocol failure, not a reason to panic.
                    self.sender.send(resp.into()).map_err(|e| {
                        Self::log_protocol_error(format!(
                            "failed to reject request received before initialization: {}",
                            e
                        ))
                    })?;
                }
                Ok(msg) => {
                    return Err(Self::log_protocol_error(format!(
                        "expected initialize request, got {:?}",
                        msg
                    )));
                }
                Err(e) => {
                    return Err(Self::log_protocol_error(format!(
                        "expected initialize request, got error: {}",
                        e
                    )));
                }
            }
        }
    }

    /// Sends `initialize_result` as the successful response to the request
    /// `initialize_id`, then waits for the `Initialized` notification.
    #[tracing::instrument(skip_all)]
    fn initialize_finish<V: Serialize>(
        &self,
        initialize_id: RequestId,
        initialize_result: V,
    ) -> Result<(), ProtocolError> {
        let resp = Response::ok(initialize_id, initialize_result);
        self.sender.send(resp.into()).map_err(|e| {
            Self::log_protocol_error(format!("failed to send initialize response: {}", e))
        })?;
        match self.receiver.recv() {
            Ok(Message::Notification(Notification::Initialized)) => Ok(()),
            Ok(msg) => Err(Self::log_protocol_error(format!(
                "expected Message::Notification, got: {:?}",
                msg
            ))),
            Err(e) => Err(Self::log_protocol_error(format!(
                "expected initialized notification, got error: {}",
                e
            ))),
        }
    }

    /// Handles `req` if it is a shutdown request.
    ///
    /// For a `Shutdown` request, acknowledges it with an empty successful
    /// response and then waits up to 30 seconds for the client's `Exit`
    /// notification. Returns `Ok(true)` once `Exit` arrives. For any other
    /// request, returns `Ok(false)` without touching the connection.
    ///
    /// A failure to send the acknowledgement is ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] (also logged at error level) if a message
    /// other than `Exit` arrives, if nothing arrives within the timeout, or if
    /// the channel is disconnected.
    pub fn handle_shutdown(&self, req: &Request) -> Result<bool, ProtocolError> {
        let id = match req {
            Request::Shutdown(id) => id,
            _ => return Ok(false),
        };
        tracing::info!("processing shutdown server ...");
        let resp = Response::ok(id.clone(), ());
        // Best effort: a failed acknowledgement does not abort waiting for `Exit`.
        let _ = self.sender.send(resp.into());
        match self.receiver.recv_timeout(Self::SHUTDOWN_EXIT_TIMEOUT) {
            Ok(Message::Notification(Notification::Exit)) => Ok(true),
            Ok(msg) => Err(Self::log_protocol_error(format!(
                "unexpected message during shutdown: {:?}",
                msg
            ))),
            Err(e) => Err(Self::log_protocol_error(format!(
                "unexpected error during shutdown: {}",
                e
            ))),
        }
    }

    /// Logs `msg` at error level and wraps it in a [`ProtocolError`].
    fn log_protocol_error(msg: String) -> ProtocolError {
        tracing::error!("{}", msg);
        ProtocolError(msg)
    }

    /// Sends `msg` through [`Connection::sender`].
    ///
    /// See [`Sender::send`] for blocking and error semantics.
    pub fn send<T: Into<Message>>(&self, msg: T) -> Result<(), SendError<Message>> {
        self.sender.send(msg.into())
    }

    /// Attempts to send `msg` without blocking.
    ///
    /// See [`Sender::try_send`].
    pub fn try_send<T: Into<Message>>(&self, msg: T) -> Result<(), TrySendError<Message>> {
        self.sender.try_send(msg.into())
    }

    /// Sends `msg`, giving up after `timeout` if it cannot be sent.
    ///
    /// See [`Sender::send_timeout`].
    pub fn send_timeout<T: Into<Message>>(
        &self,
        msg: T,
        timeout: Duration,
    ) -> Result<(), SendTimeoutError<Message>> {
        self.sender.send_timeout(msg.into(), timeout)
    }

    /// Sends `msg`, giving up at `deadline` if it cannot be sent.
    ///
    /// See [`Sender::send_deadline`].
    pub fn send_deadline<T: Into<Message>>(
        &self,
        msg: T,
        deadline: Instant,
    ) -> Result<(), SendTimeoutError<Message>> {
        self.sender.send_deadline(msg.into(), deadline)
    }
}