nvim_rs/
neovim.rs

1//! An active neovim session.
2use std::{
3  future::Future,
4  sync::{
5    atomic::{AtomicU64, Ordering},
6    Arc,
7  },
8};
9
10use futures::{
11  channel::{
12    mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
13    oneshot,
14  },
15  future,
16  io::{AsyncRead, AsyncReadExt, AsyncWrite},
17  lock::Mutex,
18  sink::SinkExt,
19  stream::StreamExt,
20  TryFutureExt,
21};
22
23use crate::{
24  create::Spawner,
25  error::{CallError, DecodeError, EncodeError, HandshakeError, LoopError},
26  rpc::{
27    handler::Handler,
28    model,
29    model::{IntoVal, RpcMessage},
30  },
31  uioptions::UiAttachOptions,
32};
33use rmpv::Value;
34
/// Pack the given arguments into a `Vec<Value>`, suitable for using it for a
/// [`call`](crate::neovim::Neovim::call) to neovim.
///
/// Every argument is converted with `.into_val()`, so the `IntoVal` trait has
/// to be in scope where the macro is invoked. A single trailing comma is
/// accepted.
#[macro_export]
macro_rules! call_args {
  () => (Vec::new());
  // One arm handles both the plain and the trailing-comma form. The macro
  // therefore never calls itself by name, which would fail to resolve when
  // it is invoked by path (`nvim_rs::call_args![..]`) without being imported.
  ($($e:expr),+ $(,)?) => {
    vec![$($e.into_val()),+]
  };
}
48
49type ResponseResult = Result<Result<Value, Value>, Arc<DecodeError>>;
50
51type Queue = Arc<Mutex<Vec<(u64, oneshot::Sender<ResponseResult>)>>>;
52
53/// An active Neovim session.
54pub struct Neovim<W>
55where
56  W: AsyncWrite + Send + Unpin + 'static,
57{
58  pub(crate) writer: Arc<Mutex<W>>,
59  pub(crate) queue: Queue,
60  pub(crate) msgid_counter: Arc<AtomicU64>,
61}
62
63impl<W> Clone for Neovim<W>
64where
65  W: AsyncWrite + Send + Unpin + 'static,
66{
67  fn clone(&self) -> Self {
68    Neovim {
69      writer: self.writer.clone(),
70      queue: self.queue.clone(),
71      msgid_counter: self.msgid_counter.clone(),
72    }
73  }
74}
75
76impl<W> PartialEq for Neovim<W>
77where
78  W: AsyncWrite + Send + Unpin + 'static,
79{
80  fn eq(&self, other: &Self) -> bool {
81    Arc::ptr_eq(&self.writer, &other.writer)
82  }
83}
84impl<W> Eq for Neovim<W> where W: AsyncWrite + Send + Unpin + 'static {}
85
86impl<W> Neovim<W>
87where
88  W: AsyncWrite + Send + Unpin + 'static,
89{
90  #[allow(clippy::new_ret_no_self)]
91  pub fn new<H, R>(
92    reader: R,
93    writer: W,
94    handler: H,
95  ) -> (
96    Neovim<<H as Handler>::Writer>,
97    impl Future<Output = Result<(), Box<LoopError>>>,
98  )
99  where
100    R: AsyncRead + Send + Unpin + 'static,
101    H: Handler<Writer = W> + Spawner,
102  {
103    let req = Neovim {
104      writer: Arc::new(Mutex::new(writer)),
105      msgid_counter: Arc::new(AtomicU64::new(0)),
106      queue: Arc::new(Mutex::new(Vec::new())),
107    };
108
109    let (sender, receiver) = unbounded();
110    let fut = future::try_join(
111      req.clone().io_loop(reader, sender),
112      req.clone().handler_loop(handler, receiver),
113    )
114    .map_ok(|_| ());
115
116    (req, fut)
117  }
118
119  /// Create a new instance, immediately send a handshake message and
120  /// wait for the response. Unlike `new`, this function is tolerant to extra
121  /// data in the reader before the handshake response is received.
122  ///
123  /// `message` should be a unique string that is normally not found in the
124  /// stdout. Due to the way Neovim packs strings, the length has to be either
125  /// less than 20 characters or more than 31 characters long.
126  /// See https://github.com/neovim/neovim/issues/32784 for more information.
127  pub async fn handshake<H, R>(
128    mut reader: R,
129    writer: W,
130    handler: H,
131    message: &str,
132  ) -> Result<
133    (
134      Neovim<<H as Handler>::Writer>,
135      impl Future<Output = Result<(), Box<LoopError>>>,
136    ),
137    Box<HandshakeError>,
138  >
139  where
140    R: AsyncRead + Send + Unpin + 'static,
141    H: Handler<Writer = W> + Spawner,
142  {
143    let instance = Neovim {
144      writer: Arc::new(Mutex::new(writer)),
145      msgid_counter: Arc::new(AtomicU64::new(0)),
146      queue: Arc::new(Mutex::new(Vec::new())),
147    };
148
149    let msgid = instance.msgid_counter.fetch_add(1, Ordering::SeqCst);
150    // Nvim encodes fixed size strings with a length of 20-31 bytes wrong, so
151    // avoid that
152    let msg_len = message.len();
153    assert!(
154      !(20..=31).contains(&msg_len),
155      "The message should be less than 20 characters or more than 31 characters
156      long, but the length is {msg_len}."
157    );
158
159    let req = RpcMessage::RpcRequest {
160      msgid,
161      method: "nvim_exec_lua".to_owned(),
162      params: call_args![format!("return '{message}'"), Vec::<Value>::new()],
163    };
164    model::encode(instance.writer.clone(), req).await?;
165
166    let expected_resp = RpcMessage::RpcResponse {
167      msgid,
168      error: rmpv::Value::Nil,
169      result: rmpv::Value::String(message.into()),
170    };
171    let mut expected_data = Vec::new();
172    model::encode_sync(&mut expected_data, expected_resp)
173      .expect("Encoding static data can't fail");
174    let mut actual_data = Vec::new();
175    let mut start = 0;
176    let mut end = 0;
177    while end - start != expected_data.len() {
178      actual_data.resize(start + expected_data.len(), 0);
179
180      let bytes_read =
181        reader
182          .read(&mut actual_data[start..])
183          .await
184          .map_err(|err| {
185            (
186              err,
187              String::from_utf8_lossy(&actual_data[..end]).to_string(),
188            )
189          })?;
190      if bytes_read == 0 {
191        // The end of the stream has been reached when the reader returns Ok(0).
192        // Since we haven't detected a suitable response yet, return an error.
193        return Err(Box::new(HandshakeError::UnexpectedResponse(
194          String::from_utf8_lossy(&actual_data[..end]).to_string(),
195        )));
196      }
197      end += bytes_read;
198      while end - start > 0 {
199        if actual_data[start..end] == expected_data[..end - start] {
200          break;
201        }
202        start += 1;
203      }
204    }
205
206    let (sender, receiver) = unbounded();
207    let fut = future::try_join(
208      instance.clone().io_loop(reader, sender),
209      instance.clone().handler_loop(handler, receiver),
210    )
211    .map_ok(|_| ());
212
213    Ok((instance, fut))
214  }
215
216  async fn send_msg(
217    &self,
218    method: &str,
219    args: Vec<Value>,
220  ) -> Result<oneshot::Receiver<ResponseResult>, Box<EncodeError>> {
221    let msgid = self.msgid_counter.fetch_add(1, Ordering::SeqCst);
222
223    let req = RpcMessage::RpcRequest {
224      msgid,
225      method: method.to_owned(),
226      params: args,
227    };
228
229    let (sender, receiver) = oneshot::channel();
230
231    self.queue.lock().await.push((msgid, sender));
232
233    let writer = self.writer.clone();
234    model::encode(writer, req).await?;
235
236    Ok(receiver)
237  }
238
239  pub async fn call(
240    &self,
241    method: &str,
242    args: Vec<Value>,
243  ) -> Result<Result<Value, Value>, Box<CallError>> {
244    let receiver = self
245      .send_msg(method, args)
246      .await
247      .map_err(|e| CallError::SendError(*e, method.to_string()))?;
248
249    match receiver.await {
250      // Result<Result<Result<Value, Value>, Arc<DecodeError>>, RecvError>
251      Ok(Ok(r)) => Ok(r), // r is Result<Value, Value>, i.e. we got an answer
252      Ok(Err(err)) => {
253        // err is a Decode Error, i.e. the answer wasn't decodable
254        Err(Box::new(CallError::DecodeError(err, method.to_string())))
255      }
256      Err(err) => {
257        // err is RecvError
258        Err(Box::new(CallError::InternalReceiveError(
259          err,
260          method.to_string(),
261        )))
262      }
263    }
264  }
265
266  async fn send_error_to_callers(
267    &self,
268    queue: &Queue,
269    err: DecodeError,
270  ) -> Result<Arc<DecodeError>, Box<LoopError>> {
271    let err = Arc::new(err);
272    let mut v: Vec<u64> = vec![];
273
274    let mut queue = queue.lock().await;
275    queue.drain(0..).for_each(|sender| {
276      let msgid = sender.0;
277      sender
278        .1
279        .send(Err(err.clone()))
280        .unwrap_or_else(|_| v.push(msgid));
281    });
282
283    if v.is_empty() {
284      Ok(err)
285    } else {
286      Err((err, v).into())
287    }
288  }
289
290  async fn handler_loop<H>(
291    self,
292    handler: H,
293    mut receiver: UnboundedReceiver<RpcMessage>,
294  ) -> Result<(), Box<LoopError>>
295  where
296    H: Handler<Writer = W> + Spawner,
297  {
298    loop {
299      let msg = match receiver.next().await {
300        Some(msg) => msg,
301        /* If our receiver closes, that just means that io_handler started
302         * shutting down. This is normal, so shut down along with it and don't
303         * report an error
304         */
305        None => break Ok(()),
306      };
307
308      match msg {
309        RpcMessage::RpcRequest {
310          msgid,
311          method,
312          params,
313        } => {
314          let handler_c = handler.clone();
315          let neovim = self.clone();
316          let writer = self.writer.clone();
317
318          handler.spawn(async move {
319            let response = match handler_c
320              .handle_request(method, params, neovim)
321              .await
322              {
323                Ok(result) => RpcMessage::RpcResponse {
324                  msgid,
325                  result,
326                  error: Value::Nil,
327                },
328                Err(error) => RpcMessage::RpcResponse {
329                  msgid,
330                  result: Value::Nil,
331                  error,
332                },
333              };
334
335            model::encode(writer, response)
336              .await
337              .unwrap_or_else(|e| {
338                error!("Error sending response to request {}: '{}'", msgid, e);
339              });
340          });
341        },
342        RpcMessage::RpcNotification {
343          method,
344          params
345        } => handler.handle_notify(method, params, self.clone()).await,
346        RpcMessage::RpcResponse { .. } => unreachable!(),
347      }
348    }
349  }
350
351  async fn io_loop<R>(
352    self,
353    mut reader: R,
354    mut sender: UnboundedSender<RpcMessage>,
355  ) -> Result<(), Box<LoopError>>
356  where
357    R: AsyncRead + Send + Unpin + 'static,
358  {
359    let mut rest: Vec<u8> = vec![];
360
361    loop {
362      let msg = match model::decode(&mut reader, &mut rest).await {
363        Ok(msg) => msg,
364        Err(err) => {
365          let e = self.send_error_to_callers(&self.queue, *err).await?;
366          return Err(Box::new(LoopError::DecodeError(e, None)));
367        }
368      };
369
370      debug!("Get message {:?}", msg);
371      if let RpcMessage::RpcResponse { msgid, result, error, } = msg {
372        let sender = find_sender(&self.queue, msgid).await?;
373        if error == Value::Nil {
374          sender
375            .send(Ok(Ok(result)))
376            .map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
377        } else {
378          sender
379            .send(Ok(Err(error)))
380            .map_err(|r| (msgid, r.expect("This was an OK(_)")))?;
381        }
382      } else {
383        // Send message to handler_loop()
384        sender.send(msg).await.unwrap();
385      }
386    }
387  }
388
389  /// Register as a remote UI.
390  ///
391  /// After this method is called, the client will receive redraw notifications.
392  pub async fn ui_attach(
393    &self,
394    width: i64,
395    height: i64,
396    opts: &UiAttachOptions,
397  ) -> Result<(), Box<CallError>> {
398    self
399      .call(
400        "nvim_ui_attach",
401        call_args!(width, height, opts.to_value_map()),
402      )
403      .await?
404      .map(|_| Ok(()))?
405  }
406
407  /// Send a quit command to Nvim.
408  /// The quit command is 'qa!' which will make Nvim quit without
409  /// saving anything.
410  pub async fn quit_no_save(&self) -> Result<(), Box<CallError>> {
411    self.command("qa!").await
412  }
413}
414
415/* The idea to use Vec here instead of HashMap
416 * is that Vec is faster on small queue sizes
417 * in most cases Vec.len = 1 so we just take first item in iteration.
418 */
419async fn find_sender(
420  queue: &Queue,
421  msgid: u64,
422) -> Result<oneshot::Sender<ResponseResult>, Box<LoopError>> {
423  let mut queue = queue.lock().await;
424
425  let pos = match queue.iter().position(|req| req.0 == msgid) {
426    Some(p) => p,
427    None => return Err(msgid.into()),
428  };
429  Ok(queue.remove(pos).1)
430}
431
#[cfg(all(test, feature = "use_tokio"))]
mod tests {
  use super::*;

  /// `find_sender` hands out each registered sender exactly once and reports
  /// unknown message ids as `MsgidNotFound`.
  #[tokio::test]
  async fn test_find_sender() {
    let queue = Arc::new(Mutex::new(Vec::new()));

    for msgid in 1..=3 {
      // The receiver is dropped right away; only the queue entry matters.
      let (sender, _receiver) = oneshot::channel();
      queue.lock().await.push((msgid, sender));
    }

    for (msgid, remaining) in [(1, 2), (2, 1), (3, 0)] {
      find_sender(&queue, msgid).await.unwrap();
      assert_eq!(remaining, queue.lock().await.len());
    }

    let err = find_sender(&queue, 17).await.unwrap_err();
    assert!(matches!(*err, LoopError::MsgidNotFound(17)));
  }
}