async_nats_wrpc/
connection.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! This module provides a connection implementation for communicating with a NATS server.
15
16use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::task::{Context, Poll};
23
24use bytes::{Buf, Bytes, BytesMut};
25use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
26
27use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
28use crate::status::StatusCode;
29use crate::subject::Subject;
30use crate::{ClientOp, ServerError, ServerOp};
31
/// Soft cap, in bytes, on the data queued in [`Connection::write_buf`]
/// and [`Connection::flattened_writes`] combined.
///
/// It is "soft" because enqueueing is never refused; callers are expected to
/// check [`Connection::is_write_buf_full`] before enqueueing more.
const SOFT_WRITE_BUF_LIMIT: usize = 65_535;
/// Minimum size, in bytes, at which a buffer is queued as its own chunk
/// rather than being copied into the flattened small-writes buffer.
const WRITE_FLATTEN_THRESHOLD: usize = 4_096;
/// Maximum number of buffers handed to one vectored write call.
const WRITE_VECTORED_CHUNKS: usize = 64;
40
41/// Supertrait enabling trait object for containing both TLS and non TLS `TcpStream` connection.
42pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
43
44/// Blanked implementation that applies to both TLS and non-TLS `TcpStream`.
45impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
46
/// The state a connection can be in.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum State {
    Pending,
    Connected,
    Disconnected,
}

/// Whether flushing the connection is currently worthwhile.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum ShouldFlush {
    /// Write buffers are empty, but the connection hasn't been flushed yet
    Yes,
    /// The connection hasn't been flushed yet, but write buffers aren't empty
    May,
    /// Flushing would just be a no-op
    No,
}

impl Display for State {
    /// Writes the lowercase name of the state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            State::Pending => "pending",
            State::Connected => "connected",
            State::Disconnected => "disconnected",
        };
        f.write_str(label)
    }
}
74
/// A framed connection
///
/// Pairs a boxed byte stream with the buffers used to parse incoming server
/// operations and to batch outgoing client operations.
pub(crate) struct Connection {
    /// Underlying transport that operations are read from and written to.
    pub(crate) stream: Box<dyn AsyncReadWrite>,
    /// Bytes received from `stream` that have not been parsed into an operation yet.
    read_buf: BytesMut,
    /// Queue of chunks waiting to be written; drained before `flattened_writes`.
    write_buf: VecDeque<Bytes>,
    /// Number of enqueued bytes (across `write_buf` and `flattened_writes`)
    /// that have not been written to `stream` yet.
    write_buf_len: usize,
    /// Small writes coalesced into one contiguous buffer.
    flattened_writes: BytesMut,
    /// Set once bytes have been written to `stream`; cleared by a successful flush.
    can_flush: bool,
}
84
85/// Internal representation of the connection.
86/// Holds connection with NATS Server and communicates with `Client` via channels.
87impl Connection {
88    pub(crate) fn new(stream: Box<dyn AsyncReadWrite>, read_buffer_capacity: usize) -> Self {
89        Self {
90            stream,
91            read_buf: BytesMut::with_capacity(read_buffer_capacity),
92            write_buf: VecDeque::new(),
93            write_buf_len: 0,
94            flattened_writes: BytesMut::new(),
95            can_flush: false,
96        }
97    }
98
99    /// Returns `true` if no more calls to [`Self::enqueue_write_op`] _should_ be made.
100    pub(crate) fn is_write_buf_full(&self) -> bool {
101        self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
102    }
103
104    /// Returns `true` if [`Self::poll_flush`] should be polled.
105    pub(crate) fn should_flush(&self) -> ShouldFlush {
106        match (
107            self.can_flush,
108            self.write_buf.is_empty() && self.flattened_writes.is_empty(),
109        ) {
110            (true, true) => ShouldFlush::Yes,
111            (true, false) => ShouldFlush::May,
112            (false, _) => ShouldFlush::No,
113        }
114    }
115
116    /// Attempts to read a server operation from the read buffer.
117    /// Returns `None` if there is not enough data to parse an entire operation.
118    pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
119        let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
120            Some(len) => len,
121            None => return Ok(None),
122        };
123
124        if self.read_buf.starts_with(b"+OK") {
125            self.read_buf.advance(len + 2);
126            return Ok(Some(ServerOp::Ok));
127        }
128
129        if self.read_buf.starts_with(b"PING") {
130            self.read_buf.advance(len + 2);
131            return Ok(Some(ServerOp::Ping));
132        }
133
134        if self.read_buf.starts_with(b"PONG") {
135            self.read_buf.advance(len + 2);
136            return Ok(Some(ServerOp::Pong));
137        }
138
139        if self.read_buf.starts_with(b"-ERR") {
140            let description = str::from_utf8(&self.read_buf[5..len])
141                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
142                .trim_matches('\'')
143                .to_owned();
144
145            self.read_buf.advance(len + 2);
146
147            return Ok(Some(ServerOp::Error(ServerError::new(description))));
148        }
149
150        if self.read_buf.starts_with(b"INFO ") {
151            let info = serde_json::from_slice(&self.read_buf[4..len])
152                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
153
154            self.read_buf.advance(len + 2);
155
156            return Ok(Some(ServerOp::Info(Box::new(info))));
157        }
158
159        if self.read_buf.starts_with(b"MSG ") {
160            let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
161            let mut args = line.split(' ').filter(|s| !s.is_empty());
162
163            // Parse the operation syntax: MSG <subject> <sid> [reply-to] <#bytes>
164            let (subject, sid, reply_to, payload_len) = match (
165                args.next(),
166                args.next(),
167                args.next(),
168                args.next(),
169                args.next(),
170            ) {
171                (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
172                    (subject, sid, Some(reply_to), payload_len)
173                }
174                (Some(subject), Some(sid), Some(payload_len), None, None) => {
175                    (subject, sid, None, payload_len)
176                }
177                _ => {
178                    return Err(io::Error::new(
179                        io::ErrorKind::InvalidInput,
180                        "invalid number of arguments after MSG",
181                    ))
182                }
183            };
184
185            let sid = sid
186                .parse::<u64>()
187                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
188
189            // Parse the number of payload bytes.
190            let payload_len = payload_len
191                .parse::<usize>()
192                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
193
194            // Return early without advancing if there is not enough data read the entire
195            // message
196            if len + payload_len + 4 > self.read_buf.remaining() {
197                return Ok(None);
198            }
199
200            let length = payload_len
201                + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
202                + subject.len();
203
204            let subject = Subject::from(subject);
205            let reply = reply_to.map(Subject::from);
206
207            self.read_buf.advance(len + 2);
208            let payload = self.read_buf.split_to(payload_len).freeze();
209            self.read_buf.advance(2);
210
211            return Ok(Some(ServerOp::Message {
212                sid,
213                length,
214                reply,
215                headers: None,
216                subject,
217                payload,
218                status: None,
219                description: None,
220            }));
221        }
222
223        if self.read_buf.starts_with(b"HMSG ") {
224            // Extract whitespace-delimited arguments that come after "HMSG".
225            let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
226            let mut args = line.split_whitespace().filter(|s| !s.is_empty());
227
228            // <subject> <sid> [reply-to] <# header bytes><# total bytes>
229            let (subject, sid, reply_to, header_len, total_len) = match (
230                args.next(),
231                args.next(),
232                args.next(),
233                args.next(),
234                args.next(),
235                args.next(),
236            ) {
237                (
238                    Some(subject),
239                    Some(sid),
240                    Some(reply_to),
241                    Some(header_len),
242                    Some(total_len),
243                    None,
244                ) => (subject, sid, Some(reply_to), header_len, total_len),
245                (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
246                    (subject, sid, None, header_len, total_len)
247                }
248                _ => {
249                    return Err(io::Error::new(
250                        io::ErrorKind::InvalidInput,
251                        "invalid number of arguments after HMSG",
252                    ))
253                }
254            };
255
256            // Convert the slice into a subject
257            let subject = Subject::from(subject);
258
259            // Parse the subject ID.
260            let sid = sid.parse::<u64>().map_err(|_| {
261                io::Error::new(
262                    io::ErrorKind::InvalidInput,
263                    "cannot parse sid argument after HMSG",
264                )
265            })?;
266
267            // Convert the slice into a subject.
268            let reply = reply_to.map(Subject::from);
269
270            // Parse the number of payload bytes.
271            let header_len = header_len.parse::<usize>().map_err(|_| {
272                io::Error::new(
273                    io::ErrorKind::InvalidInput,
274                    "cannot parse the number of header bytes argument after \
275                     HMSG",
276                )
277            })?;
278
279            // Parse the number of payload bytes.
280            let total_len = total_len.parse::<usize>().map_err(|_| {
281                io::Error::new(
282                    io::ErrorKind::InvalidInput,
283                    "cannot parse the number of bytes argument after HMSG",
284                )
285            })?;
286
287            if total_len < header_len {
288                return Err(io::Error::new(
289                    io::ErrorKind::InvalidInput,
290                    "number of header bytes was greater than or equal to the \
291                 total number of bytes after HMSG",
292                ));
293            }
294
295            if len + total_len + 4 > self.read_buf.remaining() {
296                return Ok(None);
297            }
298
299            self.read_buf.advance(len + 2);
300            let header = self.read_buf.split_to(header_len);
301            let payload = self.read_buf.split_to(total_len - header_len).freeze();
302            self.read_buf.advance(2);
303
304            let mut lines = std::str::from_utf8(&header)
305                .map_err(|_| {
306                    io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
307                })?
308                .lines()
309                .peekable();
310            let version_line = lines.next().ok_or_else(|| {
311                io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
312            })?;
313
314            let version_line_suffix = version_line
315                .strip_prefix("NATS/1.0")
316                .map(str::trim)
317                .ok_or_else(|| {
318                    io::Error::new(
319                        io::ErrorKind::InvalidInput,
320                        "header version line does not begin with `NATS/1.0`",
321                    )
322                })?;
323
324            let (status, description) = version_line_suffix
325                .split_once(' ')
326                .map(|(status, description)| (status.trim(), description.trim()))
327                .unwrap_or((version_line_suffix, ""));
328            let status = if !status.is_empty() {
329                Some(status.parse::<StatusCode>().map_err(|_| {
330                    std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
331                })?)
332            } else {
333                None
334            };
335            let description = if !description.is_empty() {
336                Some(description.to_owned())
337            } else {
338                None
339            };
340
341            let mut headers = HeaderMap::new();
342            while let Some(line) = lines.next() {
343                if line.is_empty() {
344                    continue;
345                }
346
347                let (name, value) = line.split_once(':').ok_or_else(|| {
348                    io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
349                })?;
350
351                let name = HeaderName::from_str(name)
352                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
353
354                // Read the header value, which might have been split into multiple lines
355                // `trim_start` and `trim_end` do the same job as doing `value.trim().to_owned()` at the end, but without a reallocation
356                let mut value = value.trim_start().to_owned();
357                while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
358                    value.push_str(v);
359                }
360                value.truncate(value.trim_end().len());
361
362                headers.append(name, value.into_header_value());
363            }
364
365            return Ok(Some(ServerOp::Message {
366                length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
367                sid,
368                reply,
369                subject,
370                headers: Some(headers),
371                payload,
372                status,
373                description,
374            }));
375        }
376
377        let buffer = self.read_buf.split_to(len + 2);
378        let line = str::from_utf8(&buffer).map_err(|_| {
379            io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
380        })?;
381
382        Err(io::Error::new(
383            io::ErrorKind::InvalidInput,
384            format!("invalid server operation: '{line}'"),
385        ))
386    }
387
388    pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
389        future::poll_fn(|cx| self.poll_read_op(cx))
390    }
391
392    // TODO: do we want an custom error here?
393    /// Read a server operation from read buffer.
394    /// Blocks until an operation ca be parsed.
395    pub(crate) fn poll_read_op(
396        &mut self,
397        cx: &mut Context<'_>,
398    ) -> Poll<io::Result<Option<ServerOp>>> {
399        loop {
400            if let Some(op) = self.try_read_op()? {
401                return Poll::Ready(Ok(Some(op)));
402            }
403
404            let read_buf = self.stream.read_buf(&mut self.read_buf);
405            tokio::pin!(read_buf);
406            return match read_buf.poll(cx) {
407                Poll::Pending => Poll::Pending,
408                Poll::Ready(Ok(0)) if self.read_buf.is_empty() => Poll::Ready(Ok(None)),
409                Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::ConnectionReset.into())),
410                Poll::Ready(Ok(_n)) => continue,
411                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
412            };
413        }
414    }
415
416    pub(crate) async fn easy_write_and_flush<'a>(
417        &mut self,
418        items: impl Iterator<Item = &'a ClientOp>,
419    ) -> io::Result<()> {
420        for item in items {
421            self.enqueue_write_op(item);
422        }
423
424        future::poll_fn(|cx| self.poll_write(cx)).await?;
425        future::poll_fn(|cx| self.poll_flush(cx)).await?;
426        Ok(())
427    }
428
429    /// Writes a client operation to the write buffer.
430    pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
431        macro_rules! small_write {
432            ($dst:expr) => {
433                write!(self.small_write(), $dst).expect("do small write to Connection");
434            };
435        }
436
437        match item {
438            ClientOp::Connect(connect_info) => {
439                let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
440
441                self.write("CONNECT ");
442                self.write(json);
443                self.write("\r\n");
444            }
445            ClientOp::Publish {
446                subject,
447                payload,
448                respond,
449                headers,
450            } => {
451                let verb = match headers.as_ref() {
452                    Some(headers) if !headers.is_empty() => "HPUB",
453                    _ => "PUB",
454                };
455
456                small_write!("{verb} {subject} ");
457
458                if let Some(respond) = respond {
459                    small_write!("{respond} ");
460                }
461
462                match headers {
463                    Some(headers) if !headers.is_empty() => {
464                        let headers = headers.to_bytes();
465
466                        let headers_len = headers.len();
467                        let total_len = headers_len + payload.len();
468                        small_write!("{headers_len} {total_len}\r\n");
469                        self.write(headers);
470                    }
471                    _ => {
472                        let payload_len = payload.len();
473                        small_write!("{payload_len}\r\n");
474                    }
475                }
476
477                self.write(Bytes::clone(payload));
478                self.write("\r\n");
479            }
480
481            ClientOp::Subscribe {
482                sid,
483                subject,
484                queue_group,
485            } => match queue_group {
486                Some(queue_group) => {
487                    small_write!("SUB {subject} {queue_group} {sid}\r\n");
488                }
489                None => {
490                    small_write!("SUB {subject} {sid}\r\n");
491                }
492            },
493
494            ClientOp::Unsubscribe { sid, max } => match max {
495                Some(max) => {
496                    small_write!("UNSUB {sid} {max}\r\n");
497                }
498                None => {
499                    small_write!("UNSUB {sid}\r\n");
500                }
501            },
502            ClientOp::Ping => {
503                self.write("PING\r\n");
504            }
505            ClientOp::Pong => {
506                self.write("PONG\r\n");
507            }
508        }
509    }
510
511    /// Write the internal buffers into the write stream
512    ///
513    /// Returns one of the following:
514    ///
515    /// * `Poll::Pending` means that we weren't able to fully empty
516    ///   the internal buffers. Compared to [`AsyncWrite::poll_write`],
517    ///   this implementation may do a partial write before yielding.
518    /// * `Poll::Ready(Ok())` means that the internal write buffers have
519    ///   been emptied or were already empty.
520    /// * `Poll::Ready(Err(err))` means that writing to the stream failed.
521    ///   Compared to [`AsyncWrite::poll_write`], this implementation
522    ///   may do a partial write before failing.
523    pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
524        if !self.stream.is_write_vectored() {
525            self.poll_write_sequential(cx)
526        } else {
527            self.poll_write_vectored(cx)
528        }
529    }
530
531    /// Write the internal buffers into the write stream using sequential write operations
532    ///
533    /// Writes one chunk at a time. Less efficient.
534    fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
535        loop {
536            let buf = match self.write_buf.front() {
537                Some(buf) => &**buf,
538                None if !self.flattened_writes.is_empty() => &self.flattened_writes,
539                None => return Poll::Ready(Ok(())),
540            };
541
542            debug_assert!(!buf.is_empty());
543
544            match Pin::new(&mut self.stream).poll_write(cx, buf) {
545                Poll::Pending => return Poll::Pending,
546                Poll::Ready(Ok(n)) => {
547                    self.write_buf_len -= n;
548                    self.can_flush = true;
549
550                    match self.write_buf.front_mut() {
551                        Some(buf) if n < buf.len() => {
552                            buf.advance(n);
553                        }
554                        Some(_buf) => {
555                            self.write_buf.pop_front();
556                        }
557                        None => {
558                            self.flattened_writes.advance(n);
559                        }
560                    }
561                    continue;
562                }
563                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
564            }
565        }
566    }
567
568    /// Write the internal buffers into the write stream using vectored write operations
569    ///
570    /// Writes [`WRITE_VECTORED_CHUNKS`] at a time. More efficient _if_
571    /// the underlying writer supports it.
572    fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
573        'outer: loop {
574            let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
575            let mut writes_len = 0;
576
577            self.write_buf
578                .iter()
579                .take(WRITE_VECTORED_CHUNKS)
580                .enumerate()
581                .for_each(|(i, buf)| {
582                    writes[i] = IoSlice::new(buf);
583                    writes_len += 1;
584                });
585
586            if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
587                writes[writes_len] = IoSlice::new(&self.flattened_writes);
588                writes_len += 1;
589            }
590
591            if writes_len == 0 {
592                return Poll::Ready(Ok(()));
593            }
594
595            match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
596                Poll::Pending => return Poll::Pending,
597                Poll::Ready(Ok(mut n)) => {
598                    self.write_buf_len -= n;
599                    self.can_flush = true;
600
601                    while let Some(buf) = self.write_buf.front_mut() {
602                        if n < buf.len() {
603                            buf.advance(n);
604                            continue 'outer;
605                        }
606
607                        n -= buf.len();
608                        self.write_buf.pop_front();
609                    }
610
611                    self.flattened_writes.advance(n);
612                }
613                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
614            }
615        }
616    }
617
618    /// Write `buf` into the writes buffer
619    ///
620    /// If `buf` is smaller than [`WRITE_FLATTEN_THRESHOLD`]
621    /// flattens it, otherwise appends it to the chunks queue.
622    ///
623    /// Empty `buf`s are a no-op.
624    fn write(&mut self, buf: impl Into<Bytes>) {
625        let buf = buf.into();
626        if buf.is_empty() {
627            return;
628        }
629
630        self.write_buf_len += buf.len();
631        if buf.len() < WRITE_FLATTEN_THRESHOLD {
632            self.flattened_writes.extend_from_slice(&buf);
633        } else {
634            if !self.flattened_writes.is_empty() {
635                let buf = self.flattened_writes.split().freeze();
636                self.write_buf.push_back(buf);
637            }
638
639            self.write_buf.push_back(buf);
640        }
641    }
642
643    /// Obtain an [`fmt::Write`]r for the small writes buffer.
644    fn small_write(&mut self) -> impl fmt::Write + '_ {
645        struct Writer<'a> {
646            this: &'a mut Connection,
647        }
648
649        impl<'a> fmt::Write for Writer<'a> {
650            fn write_str(&mut self, s: &str) -> fmt::Result {
651                self.this.write_buf_len += s.len();
652                self.this.flattened_writes.write_str(s)
653            }
654        }
655
656        Writer { this: self }
657    }
658
659    /// Flush the write buffer, sending all pending data down the current write stream.
660    ///
661    /// no-op if the write stream didn't need to be flushed.
662    pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
663        match Pin::new(&mut self.stream).poll_flush(cx) {
664            Poll::Pending => Poll::Pending,
665            Poll::Ready(Ok(())) => {
666                self.can_flush = false;
667                Poll::Ready(Ok(()))
668            }
669            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
670        }
671    }
672}
673
674#[cfg(test)]
675mod read_op {
676    use super::Connection;
677    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, StatusCode};
678    use tokio::io::{self, AsyncWriteExt};
679
680    #[tokio::test]
681    async fn ok() {
682        let (stream, mut server) = io::duplex(128);
683        let mut connection = Connection::new(Box::new(stream), 0);
684
685        server.write_all(b"+OK\r\n").await.unwrap();
686        let result = connection.read_op().await.unwrap();
687        assert_eq!(result, Some(ServerOp::Ok));
688    }
689
690    #[tokio::test]
691    async fn ping() {
692        let (stream, mut server) = io::duplex(128);
693        let mut connection = Connection::new(Box::new(stream), 0);
694
695        server.write_all(b"PING\r\n").await.unwrap();
696        let result = connection.read_op().await.unwrap();
697        assert_eq!(result, Some(ServerOp::Ping));
698    }
699
700    #[tokio::test]
701    async fn pong() {
702        let (stream, mut server) = io::duplex(128);
703        let mut connection = Connection::new(Box::new(stream), 0);
704
705        server.write_all(b"PONG\r\n").await.unwrap();
706        let result = connection.read_op().await.unwrap();
707        assert_eq!(result, Some(ServerOp::Pong));
708    }
709
710    #[tokio::test]
711    async fn info() {
712        let (stream, mut server) = io::duplex(128);
713        let mut connection = Connection::new(Box::new(stream), 0);
714
715        server.write_all(b"INFO {}\r\n").await.unwrap();
716        server.flush().await.unwrap();
717
718        let result = connection.read_op().await.unwrap();
719        assert_eq!(result, Some(ServerOp::Info(Box::default())));
720
721        server
722            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
723            .await
724            .unwrap();
725        server.flush().await.unwrap();
726
727        let result = connection.read_op().await.unwrap();
728        assert_eq!(
729            result,
730            Some(ServerOp::Info(Box::new(ServerInfo {
731                version: "1.0.0".into(),
732                ..Default::default()
733            })))
734        );
735    }
736
737    #[tokio::test]
738    async fn error() {
739        let (stream, mut server) = io::duplex(128);
740        let mut connection = Connection::new(Box::new(stream), 0);
741
742        server.write_all(b"INFO {}\r\n").await.unwrap();
743        let result = connection.read_op().await.unwrap();
744        assert_eq!(result, Some(ServerOp::Info(Box::default())));
745
746        server
747            .write_all(b"-ERR something went wrong\r\n")
748            .await
749            .unwrap();
750        let result = connection.read_op().await.unwrap();
751        assert_eq!(
752            result,
753            Some(ServerOp::Error(ServerError::Other(
754                "something went wrong".into()
755            )))
756        );
757    }
758
759    #[tokio::test]
760    async fn message() {
761        let (stream, mut server) = io::duplex(128);
762        let mut connection = Connection::new(Box::new(stream), 0);
763
764        server
765            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
766            .await
767            .unwrap();
768
769        let result = connection.read_op().await.unwrap();
770        assert_eq!(
771            result,
772            Some(ServerOp::Message {
773                sid: 9,
774                subject: "FOO.BAR".into(),
775                reply: None,
776                headers: None,
777                payload: "Hello World".into(),
778                status: None,
779                description: None,
780                length: 7 + 11,
781            })
782        );
783
784        server
785            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
786            .await
787            .unwrap();
788
789        let result = connection.read_op().await.unwrap();
790        assert_eq!(
791            result,
792            Some(ServerOp::Message {
793                sid: 9,
794                subject: "FOO.BAR".into(),
795                reply: Some("INBOX.34".into()),
796                headers: None,
797                payload: "Hello World".into(),
798                status: None,
799                description: None,
800                length: 7 + 8 + 11,
801            })
802        );
803
804        server
805            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
806            .await
807            .unwrap();
808        server.write_all(b"NATS/1.0\r\n").await.unwrap();
809        server.write_all(b"Header: X\r\n").await.unwrap();
810        server.write_all(b"\r\n").await.unwrap();
811        server.write_all(b"Hello World\r\n").await.unwrap();
812
813        let result = connection.read_op().await.unwrap();
814
815        assert_eq!(
816            result,
817            Some(ServerOp::Message {
818                sid: 10,
819                subject: "FOO.BAR".into(),
820                reply: Some("INBOX.35".into()),
821                headers: Some(HeaderMap::from_iter([(
822                    "Header".parse().unwrap(),
823                    "X".parse().unwrap()
824                )])),
825                payload: "Hello World".into(),
826                status: None,
827                description: None,
828                length: 7 + 8 + 34
829            })
830        );
831
832        server
833            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
834            .await
835            .unwrap();
836        server.write_all(b"NATS/1.0\r\n").await.unwrap();
837        server.write_all(b"Header: Y\r\n").await.unwrap();
838        server.write_all(b"\r\n").await.unwrap();
839        server.write_all(b"Hello World\r\n").await.unwrap();
840
841        let result = connection.read_op().await.unwrap();
842        assert_eq!(
843            result,
844            Some(ServerOp::Message {
845                sid: 10,
846                subject: "FOO.BAR".into(),
847                reply: Some("INBOX.35".into()),
848                headers: Some(HeaderMap::from_iter([(
849                    "Header".parse().unwrap(),
850                    "Y".parse().unwrap()
851                )])),
852                payload: "Hello World".into(),
853                status: None,
854                description: None,
855                length: 7 + 8 + 34,
856            })
857        );
858
859        server
860            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
861            .await
862            .unwrap();
863        server
864            .write_all(b"NATS/1.0 404 No Messages\r\n")
865            .await
866            .unwrap();
867        server.write_all(b"\r\n").await.unwrap();
868        server.write_all(b"\r\n").await.unwrap();
869
870        let result = connection.read_op().await.unwrap();
871        assert_eq!(
872            result,
873            Some(ServerOp::Message {
874                sid: 10,
875                subject: "FOO.BAR".into(),
876                reply: Some("INBOX.35".into()),
877                headers: Some(HeaderMap::default()),
878                payload: "".into(),
879                status: Some(StatusCode::NOT_FOUND),
880                description: Some("No Messages".to_string()),
881                length: 7 + 8 + 28,
882            })
883        );
884
885        server
886            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
887            .await
888            .unwrap();
889
890        let result = connection.read_op().await.unwrap();
891        assert_eq!(
892            result,
893            Some(ServerOp::Message {
894                sid: 9,
895                subject: "FOO.BAR".into(),
896                reply: None,
897                headers: None,
898                payload: "Hello Again".into(),
899                status: None,
900                description: None,
901                length: 7 + 11,
902            })
903        );
904    }
905
906    #[tokio::test]
907    async fn unknown() {
908        let (stream, mut server) = io::duplex(128);
909        let mut connection = Connection::new(Box::new(stream), 0);
910
911        server.write_all(b"ONE\r\n").await.unwrap();
912        connection.read_op().await.unwrap_err();
913
914        server.write_all(b"TWO\r\n").await.unwrap();
915        connection.read_op().await.unwrap_err();
916
917        server.write_all(b"PING\r\n").await.unwrap();
918        connection.read_op().await.unwrap();
919
920        server.write_all(b"THREE\r\n").await.unwrap();
921        connection.read_op().await.unwrap_err();
922
923        server
924            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
925            .await
926            .unwrap();
927        server
928            .write_all(b"NATS/1.0 404 No Messages\r\n")
929            .await
930            .unwrap();
931        server.write_all(b"\r\n").await.unwrap();
932        server.write_all(b"\r\n").await.unwrap();
933
934        let result = connection.read_op().await.unwrap();
935        assert_eq!(
936            result,
937            Some(ServerOp::Message {
938                sid: 10,
939                subject: "FOO.BAR".into(),
940                reply: Some("INBOX.35".into()),
941                headers: Some(HeaderMap::default()),
942                payload: "".into(),
943                status: Some(StatusCode::NOT_FOUND),
944                description: Some("No Messages".to_string()),
945                length: 7 + 8 + 28,
946            })
947        );
948
949        server.write_all(b"FOUR\r\n").await.unwrap();
950        connection.read_op().await.unwrap_err();
951
952        server.write_all(b"PONG\r\n").await.unwrap();
953        connection.read_op().await.unwrap();
954    }
955}
956
#[cfg(test)]
mod write_op {
    //! Wire-format tests for outgoing operations.
    //!
    //! Each test writes one or more `ClientOp`s through a `Connection` backed
    //! by an in-memory duplex pipe and compares the exact bytes received on
    //! the other end of the pipe.

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol};
    use tokio::io::{self, AsyncBufReadExt, BufReader};

    /// Reads exactly `count` newline-terminated lines from `reader` and
    /// returns them concatenated, line terminators included.
    ///
    /// Taking the count as a parameter keeps the number of expected lines next
    /// to the expected string, instead of relying on hand-counted `read_line`
    /// calls.
    async fn read_lines(reader: &mut BufReader<io::DuplexStream>, count: usize) -> String {
        let mut lines = String::new();
        for _ in 0..count {
            reader.read_line(&mut lines).await.unwrap();
        }
        lines
    }

    /// `PUB` / `HPUB` framing: plain, with a reply subject, and with headers.
    #[tokio::test]
    async fn publish() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0);
        let mut reader = BufReader::new(server);

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 2).await,
            "PUB FOO.BAR 11\r\nHello World\r\n"
        );

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 2).await,
            "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n"
        );

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        // Five lines: the HPUB command, the three header-block lines (version
        // line, one header, terminating blank line) and the payload line.
        // Including the payload makes sure it is not dropped or mangled when
        // headers are present.
        assert_eq!(
            read_lines(&mut reader, 5).await,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\nHello World\r\n"
        );
    }

    /// `SUB` framing, without and with a queue group.
    #[tokio::test]
    async fn subscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0);
        let mut reader = BufReader::new(server);

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "SUB FOO.BAR 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 1).await,
            "SUB FOO.BAR QUEUE.GROUP 11\r\n"
        );
    }

    /// `UNSUB` framing, without and with a maximum message count.
    #[tokio::test]
    async fn unsubscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0);
        let mut reader = BufReader::new(server);

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "UNSUB 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "UNSUB 11 2\r\n");
    }

    /// `PING` is written as a bare verb line.
    #[tokio::test]
    async fn ping() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0);
        let mut reader = BufReader::new(server);

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "PING\r\n");
    }

    /// `PONG` is written as a bare verb line.
    #[tokio::test]
    async fn pong() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0);
        let mut reader = BufReader::new(server);

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "PONG\r\n");
    }

    /// `CONNECT` carries the `ConnectInfo` as a single-line JSON object.
    #[tokio::test]
    async fn connect() {
        // Larger pipe than the other tests: the CONNECT line alone exceeds
        // 128 bytes.
        let (stream, server) = io::duplex(1024);
        let mut connection = Connection::new(Box::new(stream), 0);
        let mut reader = BufReader::new(server);

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 1).await,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}