1mod error;
9mod io_thread;
10mod message;
11mod request;
13mod response;
14mod transporter;
15
16mod notification;
17#[cfg(test)]
18mod tests;
19pub use bsp_types as types;
20pub use error::{ErrorCode, ExtractError, ProtocolError};
21pub use io_thread::IoThreads;
22pub use message::Message;
23pub use notification::Notification;
24pub use request::{Request, RequestId};
26pub use response::{Response, ResponseError};
27pub(crate) use transporter::Transporter;
28
29use bsp_types::InitializeBuild;
30use crossbeam_channel::{unbounded, Receiver, SendError, SendTimeoutError, Sender, TrySendError};
31use serde::Serialize;
32use std::io;
33use std::net::{TcpListener, TcpStream, ToSocketAddrs};
34use std::time::{Duration, Instant};
35
36pub struct Connection {
38 pub sender: Sender<Message>,
39 pub receiver: Receiver<Message>,
40}
41
42impl Connection {
43 pub fn stdio() -> (Connection, IoThreads) {
47 let Transporter(sender, receiver, io_threads) = Transporter::stdio();
48 (Connection { sender, receiver }, io_threads)
49 }
50
51 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<(Connection, IoThreads)> {
54 let stream = TcpStream::connect(addr)?;
55 let Transporter(sender, receiver, io_threads) = Transporter::socket(stream);
56 Ok((Connection { sender, receiver }, io_threads))
57 }
58
59 pub fn listen<A: ToSocketAddrs>(addr: A) -> io::Result<(Connection, IoThreads)> {
62 let listener = TcpListener::bind(addr)?;
63 let (stream, _) = listener.accept()?;
64 let Transporter(sender, receiver, io_threads) = Transporter::socket(stream);
65 Ok((Connection { sender, receiver }, io_threads))
66 }
67
68 pub fn memory() -> (Connection, Connection) {
70 let ((s1, r1), (s2, r2)) = (unbounded(), unbounded());
71 (
72 Connection {
73 sender: s1,
74 receiver: r2,
75 },
76 Connection {
77 sender: s2,
78 receiver: r1,
79 },
80 )
81 }
82
83 #[tracing::instrument(skip_all)]
114 pub fn initialize<V: Serialize>(
115 &self,
116 process: impl FnOnce(&InitializeBuild) -> V,
117 ) -> Result<InitializeBuild, ProtocolError> {
118 let (id, params) = self.initialize_start()?;
119 self.initialize_finish(id, process(¶ms))?;
120 Ok(params)
121 }
122
123 #[tracing::instrument(skip(self))]
124 fn initialize_start(&self) -> Result<(RequestId, InitializeBuild), ProtocolError> {
125 loop {
126 match self.receiver.recv() {
127 Ok(Message::Request(Request::InitializeBuild(id, params))) => {
128 return Ok((id, params));
129 }
130 Ok(Message::Request(req)) => {
132 let msg = format!("expected initialize request, got {:?}", req);
133 tracing::error!("{}", msg);
134 self.sender
135 .send(Response::server_not_initialized(req.id().clone(), msg).into())
136 .unwrap();
137 }
138 Ok(msg) => {
139 let msg = format!("expected initialize request, got {:?}", msg);
140 tracing::error!("{}", msg);
141 return Err(ProtocolError(msg));
142 }
143 Err(e) => {
144 let msg = format!("expected initialize request, got error: {}", e);
145 tracing::error!("{}", msg);
146 return Err(ProtocolError(msg));
147 }
148 };
149 }
150 }
151
152 #[tracing::instrument(skip_all)]
154 fn initialize_finish<V: Serialize>(
155 &self,
156 initialize_id: RequestId,
157 initialize_result: V,
158 ) -> Result<(), ProtocolError> {
159 let resp = Response::ok(initialize_id, initialize_result);
160 self.sender.send(resp.into()).unwrap();
161 match &self.receiver.recv() {
162 Ok(Message::Notification(Notification::Initialized)) => (),
163 Ok(msg) => {
164 let msg = format!("expected Message::Notification, got: {:?}", msg,);
165 tracing::error!("{}", msg);
166 return Err(ProtocolError(msg));
167 }
168 Err(e) => {
169 let msg = format!("expected initialized notification, got error: {}", e,);
170 tracing::error!("{}", msg);
171 return Err(ProtocolError(msg));
172 }
173 }
174 Ok(())
175 }
176
177 pub fn handle_shutdown(&self, req: &Request) -> Result<bool, ProtocolError> {
179 if let Request::Shutdown(id) = req {
180 tracing::info!("processing shutdown server ...");
181 let resp = Response::ok(id.clone(), ());
182 let _ = self.sender.send(resp.into());
183 match &self.receiver.recv_timeout(Duration::from_secs(30)) {
184 Ok(Message::Notification(Notification::Exit)) => (),
185 Ok(msg) => {
186 let msg = format!("unexpected message during shutdown: {:?}", msg);
187 tracing::error!("{}", msg);
188
189 return Err(ProtocolError(msg));
190 }
191 Err(e) => {
192 let msg = format!("unexpected error during shutdown: {}", e);
193 return Err(ProtocolError(msg));
194 }
195 }
196 Ok(true)
197 } else {
198 Ok(false)
199 }
200 }
201
202 pub fn send<T: Into<Message>>(&self, msg: T) -> Result<(), SendError<Message>> {
204 self.sender.send(msg.into())
205 }
206
207 pub fn try_send<T: Into<Message>>(&self, msg: T) -> Result<(), TrySendError<Message>> {
209 self.sender.try_send(msg.into())
210 }
211
212 pub fn send_timeout<T: Into<Message>>(
214 &self,
215 msg: T,
216 timeout: Duration,
217 ) -> Result<(), SendTimeoutError<Message>> {
218 self.sender.send_timeout(msg.into(), timeout)
219 }
220
221 pub fn send_deadline<T: Into<Message>>(
223 &self,
224 msg: T,
225 deadline: Instant,
226 ) -> Result<(), SendTimeoutError<Message>> {
227 self.sender.send_deadline(msg.into(), deadline)
228 }
229}