1mod cancellation;
88mod message;
89
90pub use message::NotificationMessage;
91use tracing::warn;
92
93use std::{
94 collections::HashMap,
95 ffi::OsStr,
96 process::Stdio,
97 sync::{
98 atomic::{AtomicI64, Ordering},
99 Arc,
100 },
101};
102
103use cancellation::CancellationToken;
104use message::{send_message, Message};
105use serde_json::json;
106use tokio::{
107 process::{ChildStdin, ChildStdout, Command},
108 sync::{
109 mpsc::{self, Receiver, Sender},
110 Mutex, Notify,
111 },
112};
113use tower_lsp::{
114 jsonrpc::{self, Id, Request, Response},
115 lsp_types::{
116 self,
117 notification::{Exit, Initialized},
118 request::{Initialize, Shutdown},
119 InitializeParams, InitializeResult, InitializedParams,
120 },
121};
122
123pub struct LspServer {
124 count: AtomicI64,
125 state: Arc<Mutex<ClientState>>,
126 stdin: Arc<Mutex<ChildStdin>>,
127 channel_map: Arc<Mutex<HashMap<Id, Sender<Response>>>>,
128}
129
130impl Clone for LspServer {
131 fn clone(&self) -> Self {
132 Self {
133 count: AtomicI64::new(self.count.load(Ordering::Relaxed)),
134 state: self.state.clone(),
135 stdin: self.stdin.clone(),
136 channel_map: self.channel_map.clone(),
137 }
138 }
139}
140
141impl LspServer {
142 pub fn new<S, I>(program: S, args: I) -> (LspServer, Receiver<ServerMessage>)
143 where
144 S: AsRef<OsStr>,
145 I: IntoIterator<Item = S> + Clone,
146 {
147 let child = match Command::new(program)
148 .args(args.clone())
149 .stdin(Stdio::piped())
150 .stdout(Stdio::piped())
151 .spawn()
152 {
153 Err(err) => panic!(
154 "Couldn't spawn: {:?} in {:?}",
155 err,
156 args.into_iter()
157 .map(|v| v.as_ref().to_str().map(|v| v.to_string()))
158 .collect::<Vec<_>>()
159 ),
160 Ok(child) => child,
161 };
162 let stdin = child.stdin.unwrap();
163 let mut stdout = child.stdout.unwrap();
164
165 let channel_map = Arc::new(Mutex::new(HashMap::<Id, Sender<Response>>::new()));
166 let channel_map_ = Arc::clone(&channel_map);
167
168 let (tx, rx) = mpsc::channel(16);
169
170 tokio::spawn(async move { message_loop(&mut stdout, channel_map_, tx).await });
171 (
172 LspServer {
173 count: AtomicI64::new(0),
174 state: Arc::new(Mutex::new(ClientState::Uninitialized)),
175 stdin: Arc::new(Mutex::new(stdin)),
176 channel_map,
177 },
178 rx,
179 )
180 }
181
182 pub async fn initialize(
183 &self,
184 params: InitializeParams,
185 ) -> Result<InitializeResult, jsonrpc::Error> {
186 *self.state.lock().await = ClientState::Initializing;
187 let initialize_result = self.send_request::<Initialize>(params).await;
188 initialize_result
189 }
190
191 pub async fn initialized(&self) {
192 self.send_notification::<Initialized>(InitializedParams {})
193 .await;
194 *self.state.lock().await = ClientState::Initialized;
195 }
196
197 pub async fn send_request<R>(&self, params: R::Params) -> Result<R::Result, jsonrpc::Error>
198 where
199 R: lsp_types::request::Request,
200 {
201 let id = {
202 self.count.fetch_add(1, Ordering::Relaxed);
203 self.count.load(Ordering::Relaxed)
204 };
205 {
206 let mut stdin = self.stdin.lock().await;
207 send_message(
208 json!({
209 "jsonrpc": "2.0",
210 "id": id,
211 "method": R::METHOD,
212 "params": params
213 }),
214 &mut stdin,
215 )
216 .await;
217 }
218
219 let notify = Arc::new(Notify::new());
220 let mut token = CancellationToken::new(Arc::clone(¬ify));
221 let stdin = Arc::clone(&self.stdin);
222 let cancel = tokio::spawn(async move {
223 notify.notified().await;
224 let mut stdin = stdin.lock().await;
225 send_message(
226 json!({
227 "jsonrpc": "2.0",
228 "method": "$/cancelRequest",
229 "params": {
230 "id": id,
231 }
232 }),
233 &mut stdin,
234 )
235 .await;
236 });
237
238 let (tx, mut rx) = mpsc::channel::<Response>(1);
239
240 {
241 self.channel_map.lock().await.insert(Id::Number(id), tx);
242 }
243
244 let response = rx.recv().await.unwrap();
245
246 token.finish();
247 cancel.abort();
248
249 if response.is_ok() {
250 Ok(serde_json::from_value(response.result().unwrap().to_owned()).unwrap())
251 } else {
252 Err(response.error().unwrap().to_owned())
253 }
254 }
255
256 pub async fn send_response<R>(&self, id: Id, result: R::Result)
257 where
258 R: lsp_types::request::Request,
259 {
260 let mut stdin = self.stdin.lock().await;
261 send_message(
262 json!({
263 "jsonrpc": "2.0",
264 "id": id,
265 "result": result,
266 }),
267 &mut stdin,
268 )
269 .await;
270 }
271
272 pub async fn send_error_response(&self, id: Id, error: jsonrpc::Error) {
273 let mut stdin = self.stdin.lock().await;
274 send_message(
275 json!({
276 "jsonrpc": "2.0",
277 "id": id,
278 "error": error,
279 }),
280 &mut stdin,
281 )
282 .await;
283 }
284
285 pub async fn send_notification<N>(&self, params: N::Params)
286 where
287 N: lsp_types::notification::Notification,
288 {
289 let mut stdin = self.stdin.lock().await;
290 send_message(
291 json!({
292 "jsonrpc": "2.0",
293 "method": N::METHOD,
294 "params": params
295 }),
296 &mut stdin,
297 )
298 .await;
299 }
300
301 pub async fn shutdown(&self) -> Result<(), jsonrpc::Error> {
302 let result = self.send_request::<Shutdown>(()).await;
303 *self.state.lock().await = ClientState::ShutDown;
304 result
305 }
306
307 pub async fn exit(&self) {
308 self.send_notification::<Exit>(()).await;
309 *self.state.lock().await = ClientState::Exited;
310 }
311}
312
313async fn message_loop(
314 stdout: &mut ChildStdout,
315 channel_map: Arc<Mutex<HashMap<Id, Sender<Response>>>>,
316 tx: Sender<ServerMessage>,
317) {
318 loop {
319 let msg = message::get_message(stdout).await;
320 if let Some(msg) = msg {
321 match msg {
322 Message::Notification(msg) => {
323 tx.send(ServerMessage::Notification(msg)).await.unwrap();
324 }
325 Message::Request(req) => {
326 tx.send(ServerMessage::Request(req)).await.unwrap();
327 }
328 Message::Response(res) => {
329 let mut channel_map = channel_map.lock().await;
330 let id = res.id().clone();
331 if let Some(tx) = channel_map.get(&id) {
332 let result = tx.send(res).await;
333 if let Err(err) = result {
334 if cfg!(feature = "tracing") {
335 warn!("send error: {:?}", err);
336 }
337 }
338 channel_map.remove(&id);
339 }
340 }
341 }
342 } else {
343 break;
344 }
345 }
346}
347
/// Lifecycle phase of the client, as recorded by the `LspServer` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// No `initialize` request has been sent yet.
    Uninitialized = 0,
    /// The `initialize` request has been (or is being) sent.
    Initializing = 1,
    /// The `initialized` notification has been sent.
    Initialized = 2,
    /// The `shutdown` request has completed (successfully or not).
    ShutDown = 3,
    /// The `exit` notification has been sent.
    Exited = 4,
}
361
362pub enum ServerMessage {
363 Request(Request),
364 Notification(NotificationMessage),
365}