1use std::{
3 future::Future,
4 sync::{
5 atomic::{AtomicU64, Ordering},
6 Arc,
7 },
8};
9
10use futures::{
11 channel::{
12 mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future,
16 io::{AsyncRead, AsyncReadExt, AsyncWrite},
17 lock::Mutex,
18 sink::SinkExt,
19 stream::StreamExt,
20 TryFutureExt,
21};
22
23use crate::{
24 create::Spawner,
25 error::{CallError, DecodeError, EncodeError, HandshakeError, LoopError},
26 rpc::{
27 handler::Handler,
28 model,
29 model::{IntoVal, RpcMessage},
30 },
31 uioptions::UiAttachOptions,
32};
33use rmpv::Value;
34
/// Builds the `Vec` of arguments for an RPC call.
///
/// Each argument is converted with `.into_val()`, so the `IntoVal` trait must be
/// in scope where the macro is invoked. A trailing comma is accepted, and
/// `call_args![]` yields an empty vector.
///
/// The trailing comma is matched by `$(,)?` instead of a recursive
/// `call_args!` call. A bare recursive call only resolves when the macro is
/// imported by name, so it broke invocations through a path such as
/// `nvim_rs::call_args!(a,)`.
///
/// ```ignore
/// let args = call_args![width, height, opts.to_value_map()];
/// ```
#[macro_export]
macro_rules! call_args {
    () => (Vec::new());
    ($($e:expr),+ $(,)?) => (vec![$($e.into_val()),+]);
}
48
49type ResponseResult = Result<Result<Value, Value>, Arc<DecodeError>>;
50
51type Queue = Arc<Mutex<Vec<(u64, oneshot::Sender<ResponseResult>)>>>;
52
53pub struct Neovim<W>
55where
56 W: AsyncWrite + Send + Unpin + 'static,
57{
58 pub(crate) writer: Arc<Mutex<W>>,
59 pub(crate) queue: Queue,
60 pub(crate) msgid_counter: Arc<AtomicU64>,
61}
62
63impl<W> Clone for Neovim<W>
64where
65 W: AsyncWrite + Send + Unpin + 'static,
66{
67 fn clone(&self) -> Self {
68 Neovim {
69 writer: self.writer.clone(),
70 queue: self.queue.clone(),
71 msgid_counter: self.msgid_counter.clone(),
72 }
73 }
74}
75
76impl<W> PartialEq for Neovim<W>
77where
78 W: AsyncWrite + Send + Unpin + 'static,
79{
80 fn eq(&self, other: &Self) -> bool {
81 Arc::ptr_eq(&self.writer, &other.writer)
82 }
83}
84impl<W> Eq for Neovim<W> where W: AsyncWrite + Send + Unpin + 'static {}
85
86impl<W> Neovim<W>
87where
88 W: AsyncWrite + Send + Unpin + 'static,
89{
90 #[allow(clippy::new_ret_no_self)]
91 pub fn new<H, R>(
92 reader: R,
93 writer: W,
94 handler: H,
95 ) -> (
96 Neovim<<H as Handler>::Writer>,
97 impl Future<Output = Result<(), Box<LoopError>>>,
98 )
99 where
100 R: AsyncRead + Send + Unpin + 'static,
101 H: Handler<Writer = W> + Spawner,
102 {
103 let req = Neovim {
104 writer: Arc::new(Mutex::new(writer)),
105 msgid_counter: Arc::new(AtomicU64::new(0)),
106 queue: Arc::new(Mutex::new(Vec::new())),
107 };
108
109 let (sender, receiver) = unbounded();
110 let fut = future::try_join(
111 req.clone().io_loop(reader, sender),
112 req.clone().handler_loop(handler, receiver),
113 )
114 .map_ok(|_| ());
115
116 (req, fut)
117 }
118
119 pub async fn handshake<H, R>(
128 mut reader: R,
129 writer: W,
130 handler: H,
131 message: &str,
132 ) -> Result<
133 (
134 Neovim<<H as Handler>::Writer>,
135 impl Future<Output = Result<(), Box<LoopError>>>,
136 ),
137 Box<HandshakeError>,
138 >
139 where
140 R: AsyncRead + Send + Unpin + 'static,
141 H: Handler<Writer = W> + Spawner,
142 {
143 let instance = Neovim {
144 writer: Arc::new(Mutex::new(writer)),
145 msgid_counter: Arc::new(AtomicU64::new(0)),
146 queue: Arc::new(Mutex::new(Vec::new())),
147 };
148
149 let msgid = instance.msgid_counter.fetch_add(1, Ordering::SeqCst);
150 let msg_len = message.len();
153 assert!(
154 !(20..=31).contains(&msg_len),
155 "The message should be less than 20 characters or more than 31 characters
156 long, but the length is {msg_len}."
157 );
158
159 let req = RpcMessage::RpcRequest {
160 msgid,
161 method: "nvim_exec_lua".to_owned(),
162 params: call_args![format!("return '{message}'"), Vec::<Value>::new()],
163 };
164 model::encode(instance.writer.clone(), req).await?;
165
166 let expected_resp = RpcMessage::RpcResponse {
167 msgid,
168 error: rmpv::Value::Nil,
169 result: rmpv::Value::String(message.into()),
170 };
171 let mut expected_data = Vec::new();
172 model::encode_sync(&mut expected_data, expected_resp)
173 .expect("Encoding static data can't fail");
174 let mut actual_data = Vec::new();
175 let mut start = 0;
176 let mut end = 0;
177 while end - start != expected_data.len() {
178 actual_data.resize(start + expected_data.len(), 0);
179
180 let bytes_read =
181 reader
182 .read(&mut actual_data[start..])
183 .await
184 .map_err(|err| {
185 (
186 err,
187 String::from_utf8_lossy(&actual_data[..end]).to_string(),
188 )
189 })?;
190 if bytes_read == 0 {
191 return Err(Box::new(HandshakeError::UnexpectedResponse(
194 String::from_utf8_lossy(&actual_data[..end]).to_string(),
195 )));
196 }
197 end += bytes_read;
198 while end - start > 0 {
199 if actual_data[start..end] == expected_data[..end - start] {
200 break;
201 }
202 start += 1;
203 }
204 }
205
206 let (sender, receiver) = unbounded();
207 let fut = future::try_join(
208 instance.clone().io_loop(reader, sender),
209 instance.clone().handler_loop(handler, receiver),
210 )
211 .map_ok(|_| ());
212
213 Ok((instance, fut))
214 }
215
216 async fn send_msg(
217 &self,
218 method: &str,
219 args: Vec<Value>,
220 ) -> Result<oneshot::Receiver<ResponseResult>, Box<EncodeError>> {
221 let msgid = self.msgid_counter.fetch_add(1, Ordering::SeqCst);
222
223 let req = RpcMessage::RpcRequest {
224 msgid,
225 method: method.to_owned(),
226 params: args,
227 };
228
229 let (sender, receiver) = oneshot::channel();
230
231 self.queue.lock().await.push((msgid, sender));
232
233 let writer = self.writer.clone();
234 model::encode(writer, req).await?;
235
236 Ok(receiver)
237 }
238
239 pub async fn call(
240 &self,
241 method: &str,
242 args: Vec<Value>,
243 ) -> Result<Result<Value, Value>, Box<CallError>> {
244 let receiver = self
245 .send_msg(method, args)
246 .await
247 .map_err(|e| CallError::SendError(*e, method.to_string()))?;
248
249 match receiver.await {
250 Ok(Ok(r)) => Ok(r), Ok(Err(err)) => {
253 Err(Box::new(CallError::DecodeError(err, method.to_string())))
255 }
256 Err(err) => {
257 Err(Box::new(CallError::InternalReceiveError(
259 err,
260 method.to_string(),
261 )))
262 }
263 }
264 }
265
266 async fn send_error_to_callers(
267 &self,
268 queue: &Queue,
269 err: DecodeError,
270 ) -> Result<Arc<DecodeError>, Box<LoopError>> {
271 let err = Arc::new(err);
272 let mut v: Vec<u64> = vec![];
273
274 let mut queue = queue.lock().await;
275 queue.drain(0..).for_each(|sender| {
276 let msgid = sender.0;
277 sender
278 .1
279 .send(Err(err.clone()))
280 .unwrap_or_else(|_| v.push(msgid));
281 });
282
283 if v.is_empty() {
284 Ok(err)
285 } else {
286 Err((err, v).into())
287 }
288 }
289
290 async fn handler_loop<H>(
291 self,
292 handler: H,
293 mut receiver: UnboundedReceiver<RpcMessage>,
294 ) -> Result<(), Box<LoopError>>
295 where
296 H: Handler<Writer = W> + Spawner,
297 {
298 loop {
299 let msg = match receiver.next().await {
300 Some(msg) => msg,
301 None => break Ok(()),
306 };
307
308 match msg {
309 RpcMessage::RpcRequest {
310 msgid,
311 method,
312 params,
313 } => {
314 let handler_c = handler.clone();
315 let neovim = self.clone();
316 let writer = self.writer.clone();
317
318 handler.spawn(async move {
319 let response = match handler_c
320 .handle_request(method, params, neovim)
321 .await
322 {
323 Ok(result) => RpcMessage::RpcResponse {
324 msgid,
325 result,
326 error: Value::Nil,
327 },
328 Err(error) => RpcMessage::RpcResponse {
329 msgid,
330 result: Value::Nil,
331 error,
332 },
333 };
334
335 model::encode(writer, response)
336 .await
337 .unwrap_or_else(|e| {
338 error!("Error sending response to request {}: '{}'", msgid, e);
339 });
340 });
341 },
342 RpcMessage::RpcNotification {
343 method,
344 params
345 } => handler.handle_notify(method, params, self.clone()).await,
346 RpcMessage::RpcResponse { .. } => unreachable!(),
347 }
348 }
349 }
350
351 async fn io_loop<R>(
352 self,
353 mut reader: R,
354 mut sender: UnboundedSender<RpcMessage>,
355 ) -> Result<(), Box<LoopError>>
356 where
357 R: AsyncRead + Send + Unpin + 'static,
358 {
359 let mut rest: Vec<u8> = vec![];
360
361 loop {
362 let msg = match model::decode(&mut reader, &mut rest).await {
363 Ok(msg) => msg,
364 Err(err) => {
365 let e = self.send_error_to_callers(&self.queue, *err).await?;
366 return Err(Box::new(LoopError::DecodeError(e, None)));
367 }
368 };
369
370 debug!("Get message {:?}", msg);
371 if let RpcMessage::RpcResponse { msgid, result, error, } = msg {
372 let sender = find_sender(&self.queue, msgid).await?;
373 if error == Value::Nil {
374 sender
375 .send(Ok(Ok(result)))
376 .map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
377 } else {
378 sender
379 .send(Ok(Err(error)))
380 .map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
381 }
382 } else {
383 sender.send(msg).await.unwrap();
385 }
386 }
387 }
388
389 pub async fn ui_attach(
393 &self,
394 width: i64,
395 height: i64,
396 opts: &UiAttachOptions,
397 ) -> Result<(), Box<CallError>> {
398 self
399 .call(
400 "nvim_ui_attach",
401 call_args!(width, height, opts.to_value_map()),
402 )
403 .await?
404 .map(|_| Ok(()))?
405 }
406
407 pub async fn quit_no_save(&self) -> Result<(), Box<CallError>> {
411 self.command("qa!").await
412 }
413}
414
415async fn find_sender(
420 queue: &Queue,
421 msgid: u64,
422) -> Result<oneshot::Sender<ResponseResult>, Box<LoopError>> {
423 let mut queue = queue.lock().await;
424
425 let pos = match queue.iter().position(|req| req.0 == msgid) {
426 Some(p) => p,
427 None => return Err(msgid.into()),
428 };
429 Ok(queue.remove(pos).1)
430}
431
#[cfg(all(test, feature = "use_tokio"))]
mod tests {
    use super::*;

    /// `find_sender` removes exactly the matching entry, and it reports
    /// unknown msgids as `MsgidNotFound`.
    #[tokio::test]
    async fn test_find_sender() {
        let queue = Arc::new(Mutex::new(Vec::new()));

        for msgid in 1..=3 {
            let (sender, _receiver) = oneshot::channel();
            queue.lock().await.push((msgid, sender));
        }

        for (msgid, remaining) in [(1, 2), (2, 1), (3, 0)] {
            find_sender(&queue, msgid).await.unwrap();
            assert_eq!(remaining, queue.lock().await.len());
        }

        let err = find_sender(&queue, 17).await.unwrap_err();
        assert!(matches!(*err, LoopError::MsgidNotFound(17)));
    }
}