async_nats/
connection.rs

// Copyright 2020-2022 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module provides a connection implementation for communicating with a NATS server.
15
16use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28    futures::{SinkExt, StreamExt},
29    pin_project::pin_project,
30    tokio::io::ReadBuf,
31    tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36use tracing::trace;
37
38use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
39use crate::status::StatusCode;
40use crate::subject::Subject;
41use crate::{ClientOp, ServerError, ServerOp, Statistics};
42
/// Soft cap, in bytes, on the data queued across [`Connection::write_buf`] and
/// [`Connection::flattened_writes`]; once reached, callers should stop enqueueing.
const SOFT_WRITE_BUF_LIMIT: usize = u16::MAX as usize;
/// Buffers at least this many bytes long are queued as their own chunk instead
/// of being copied into the flattened small-writes buffer.
const WRITE_FLATTEN_THRESHOLD: usize = 4 * 1024;
/// Upper bound on the number of `IoSlice`s passed to one vectored write call.
const WRITE_VECTORED_CHUNKS: usize = 64;
51
52/// Supertrait enabling trait object for containing both TLS and non TLS `TcpStream` connection.
53pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
54
55/// Blanked implementation that applies to both TLS and non-TLS `TcpStream`.
56impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
57
/// Lifecycle state of the connection, rendered in lowercase by its `Display` impl.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Rendered as `pending`.
    Pending,
    /// Rendered as `connected`.
    Connected,
    /// Rendered as `disconnected`.
    Disconnected,
}
65
/// Answer returned by `Connection::should_flush`: whether polling a flush is
/// worthwhile given the current write-buffer and flush state.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShouldFlush {
    /// Write buffers are empty, but the connection hasn't been flushed yet
    Yes,
    /// The connection hasn't been flushed yet, but write buffers aren't empty
    May,
    /// Flushing would just be a no-op
    No,
}
75
76impl Display for State {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            State::Pending => write!(f, "pending"),
80            State::Connected => write!(f, "connected"),
81            State::Disconnected => write!(f, "disconnected"),
82        }
83    }
84}
85
86/// A framed connection
87pub(crate) struct Connection {
88    pub(crate) stream: Box<dyn AsyncReadWrite>,
89    read_buf: BytesMut,
90    write_buf: VecDeque<Bytes>,
91    write_buf_len: usize,
92    flattened_writes: BytesMut,
93    can_flush: bool,
94    statistics: Arc<Statistics>,
95}
96
97/// Internal representation of the connection.
98/// Holds connection with NATS Server and communicates with `Client` via channels.
99impl Connection {
100    pub(crate) fn new(
101        stream: Box<dyn AsyncReadWrite>,
102        read_buffer_capacity: usize,
103        statistics: Arc<Statistics>,
104    ) -> Self {
105        Self {
106            stream,
107            read_buf: BytesMut::with_capacity(read_buffer_capacity),
108            write_buf: VecDeque::new(),
109            write_buf_len: 0,
110            flattened_writes: BytesMut::new(),
111            can_flush: false,
112            statistics,
113        }
114    }
115
116    /// Returns `true` if no more calls to [`Self::enqueue_write_op`] _should_ be made.
117    pub(crate) fn is_write_buf_full(&self) -> bool {
118        self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
119    }
120
121    /// Returns `true` if [`Self::poll_flush`] should be polled.
122    pub(crate) fn should_flush(&self) -> ShouldFlush {
123        match (
124            self.can_flush,
125            self.write_buf.is_empty() && self.flattened_writes.is_empty(),
126        ) {
127            (true, true) => ShouldFlush::Yes,
128            (true, false) => ShouldFlush::May,
129            (false, _) => ShouldFlush::No,
130        }
131    }
132
133    /// Attempts to read a server operation from the read buffer.
134    /// Returns `None` if there is not enough data to parse an entire operation.
135    pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
136        let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
137            Some(len) => len,
138            None => return Ok(None),
139        };
140
141        if self.read_buf.starts_with(b"+OK") {
142            self.read_buf.advance(len + 2);
143            trace!("read operation: OK");
144            return Ok(Some(ServerOp::Ok));
145        }
146
147        if self.read_buf.starts_with(b"PING") {
148            self.read_buf.advance(len + 2);
149            trace!("read operation: PING");
150            return Ok(Some(ServerOp::Ping));
151        }
152
153        if self.read_buf.starts_with(b"PONG") {
154            self.read_buf.advance(len + 2);
155            trace!("read operation: PONG");
156            return Ok(Some(ServerOp::Pong));
157        }
158
159        if self.read_buf.starts_with(b"-ERR") {
160            let description = str::from_utf8(&self.read_buf[5..len])
161                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
162                .trim_matches('\'')
163                .to_owned();
164
165            self.read_buf.advance(len + 2);
166            trace!(error = %description, "read operation: ERR");
167            return Ok(Some(ServerOp::Error(ServerError::new(description))));
168        }
169
170        if self.read_buf.starts_with(b"INFO ") {
171            let info = serde_json::from_slice(&self.read_buf[4..len])
172                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
173
174            self.read_buf.advance(len + 2);
175            trace!(?info, "read operation: INFO");
176            return Ok(Some(ServerOp::Info(Box::new(info))));
177        }
178
179        if self.read_buf.starts_with(b"MSG ") {
180            let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
181            let mut args = line.split(' ').filter(|s| !s.is_empty());
182
183            // Parse the operation syntax: MSG <subject> <sid> [reply-to] <#bytes>
184            let (subject, sid, reply_to, payload_len) = match (
185                args.next(),
186                args.next(),
187                args.next(),
188                args.next(),
189                args.next(),
190            ) {
191                (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
192                    (subject, sid, Some(reply_to), payload_len)
193                }
194                (Some(subject), Some(sid), Some(payload_len), None, None) => {
195                    (subject, sid, None, payload_len)
196                }
197                _ => {
198                    return Err(io::Error::new(
199                        io::ErrorKind::InvalidInput,
200                        "invalid number of arguments after MSG",
201                    ))
202                }
203            };
204
205            let sid = sid
206                .parse::<u64>()
207                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
208
209            // Parse the number of payload bytes.
210            let payload_len = payload_len
211                .parse::<usize>()
212                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
213
214            // Return early without advancing if there is not enough data read the entire
215            // message
216            if len + payload_len + 4 > self.read_buf.remaining() {
217                return Ok(None);
218            }
219
220            let length = payload_len
221                + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
222                + subject.len();
223
224            let subject = Subject::from(subject);
225            let reply = reply_to.map(Subject::from);
226
227            self.read_buf.advance(len + 2);
228            let payload = self.read_buf.split_to(payload_len).freeze();
229            self.read_buf.advance(2);
230
231            trace!(
232                subject = %subject,
233                sid = %sid,
234                reply = ?reply,
235                payload_len = %payload_len,
236                "read operation: MSG"
237            );
238
239            return Ok(Some(ServerOp::Message {
240                sid,
241                length,
242                reply,
243                headers: None,
244                subject,
245                payload,
246                status: None,
247                description: None,
248            }));
249        }
250
251        if self.read_buf.starts_with(b"HMSG ") {
252            // Extract whitespace-delimited arguments that come after "HMSG".
253            let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
254            let mut args = line.split_whitespace().filter(|s| !s.is_empty());
255
256            // <subject> <sid> [reply-to] <# header bytes><# total bytes>
257            let (subject, sid, reply_to, header_len, total_len) = match (
258                args.next(),
259                args.next(),
260                args.next(),
261                args.next(),
262                args.next(),
263                args.next(),
264            ) {
265                (
266                    Some(subject),
267                    Some(sid),
268                    Some(reply_to),
269                    Some(header_len),
270                    Some(total_len),
271                    None,
272                ) => (subject, sid, Some(reply_to), header_len, total_len),
273                (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
274                    (subject, sid, None, header_len, total_len)
275                }
276                _ => {
277                    return Err(io::Error::new(
278                        io::ErrorKind::InvalidInput,
279                        "invalid number of arguments after HMSG",
280                    ))
281                }
282            };
283
284            // Convert the slice into a subject
285            let subject = Subject::from(subject);
286
287            // Parse the subject ID.
288            let sid = sid.parse::<u64>().map_err(|_| {
289                io::Error::new(
290                    io::ErrorKind::InvalidInput,
291                    "cannot parse sid argument after HMSG",
292                )
293            })?;
294
295            // Convert the slice into a subject.
296            let reply = reply_to.map(Subject::from);
297
298            // Parse the number of payload bytes.
299            let header_len = header_len.parse::<usize>().map_err(|_| {
300                io::Error::new(
301                    io::ErrorKind::InvalidInput,
302                    "cannot parse the number of header bytes argument after \
303                     HMSG",
304                )
305            })?;
306
307            // Parse the number of payload bytes.
308            let total_len = total_len.parse::<usize>().map_err(|_| {
309                io::Error::new(
310                    io::ErrorKind::InvalidInput,
311                    "cannot parse the number of bytes argument after HMSG",
312                )
313            })?;
314
315            if total_len < header_len {
316                return Err(io::Error::new(
317                    io::ErrorKind::InvalidInput,
318                    "number of header bytes was greater than or equal to the \
319                 total number of bytes after HMSG",
320                ));
321            }
322
323            if len + total_len + 4 > self.read_buf.remaining() {
324                return Ok(None);
325            }
326
327            self.read_buf.advance(len + 2);
328            let header = self.read_buf.split_to(header_len);
329            let payload = self.read_buf.split_to(total_len - header_len).freeze();
330            self.read_buf.advance(2);
331
332            let mut lines = std::str::from_utf8(&header)
333                .map_err(|_| {
334                    io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
335                })?
336                .lines()
337                .peekable();
338            let version_line = lines.next().ok_or_else(|| {
339                io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
340            })?;
341
342            let version_line_suffix = version_line
343                .strip_prefix("NATS/1.0")
344                .map(str::trim)
345                .ok_or_else(|| {
346                    io::Error::new(
347                        io::ErrorKind::InvalidInput,
348                        "header version line does not begin with `NATS/1.0`",
349                    )
350                })?;
351
352            let (status, description) = version_line_suffix
353                .split_once(' ')
354                .map(|(status, description)| (status.trim(), description.trim()))
355                .unwrap_or((version_line_suffix, ""));
356            let status = if !status.is_empty() {
357                Some(status.parse::<StatusCode>().map_err(|_| {
358                    std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
359                })?)
360            } else {
361                None
362            };
363            let description = if !description.is_empty() {
364                Some(description.to_owned())
365            } else {
366                None
367            };
368
369            let mut headers = HeaderMap::new();
370            while let Some(line) = lines.next() {
371                if line.is_empty() {
372                    continue;
373                }
374
375                let (name, value) = line.split_once(':').ok_or_else(|| {
376                    io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
377                })?;
378
379                let name = HeaderName::from_str(name)
380                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
381
382                // Read the header value, which might have been split into multiple lines
383                // `trim_start` and `trim_end` do the same job as doing `value.trim().to_owned()` at the end, but without a reallocation
384                let mut value = value.trim_start().to_owned();
385                while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
386                    value.push_str(v);
387                }
388                value.truncate(value.trim_end().len());
389
390                headers.append(name, value.into_header_value());
391            }
392
393            trace!(
394                subject = %subject,
395                sid = %sid,
396                reply = ?reply,
397                header_len = %header_len,
398                total_len = %total_len,
399                status = ?status,
400                description = ?description,
401                "read operation: HMSG"
402            );
403
404            return Ok(Some(ServerOp::Message {
405                length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
406                sid,
407                reply,
408                subject,
409                headers: Some(headers),
410                payload,
411                status,
412                description,
413            }));
414        }
415
416        let buffer = self.read_buf.split_to(len + 2);
417        let line = str::from_utf8(&buffer).map_err(|_| {
418            io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
419        })?;
420
421        trace!(line = %line, "read operation: unknown");
422        Err(io::Error::new(
423            io::ErrorKind::InvalidInput,
424            format!("invalid server operation: '{line}'"),
425        ))
426    }
427
428    pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
429        future::poll_fn(|cx| self.poll_read_op(cx))
430    }
431
432    // TODO: do we want an custom error here?
433    /// Read a server operation from read buffer.
434    /// Blocks until an operation ca be parsed.
435    pub(crate) fn poll_read_op(
436        &mut self,
437        cx: &mut Context<'_>,
438    ) -> Poll<io::Result<Option<ServerOp>>> {
439        loop {
440            if let Some(op) = self.try_read_op()? {
441                trace!(?op, "read operation completed");
442                return Poll::Ready(Ok(Some(op)));
443            }
444
445            let read_buf = self.stream.read_buf(&mut self.read_buf);
446            tokio::pin!(read_buf);
447            return match read_buf.poll(cx) {
448                Poll::Pending => {
449                    trace!("read operation pending");
450                    Poll::Pending
451                }
452                Poll::Ready(Ok(0)) if self.read_buf.is_empty() => {
453                    trace!("read operation: empty buffer");
454                    Poll::Ready(Ok(None))
455                }
456                Poll::Ready(Ok(0)) => {
457                    trace!("read operation: connection reset");
458                    Poll::Ready(Err(io::ErrorKind::ConnectionReset.into()))
459                }
460                Poll::Ready(Ok(n)) => {
461                    self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
462                    trace!(bytes = %n, "read operation: received bytes");
463                    continue;
464                }
465                Poll::Ready(Err(err)) => {
466                    trace!(error = %err, "read operation: error");
467                    Poll::Ready(Err(err))
468                }
469            };
470        }
471    }
472
473    pub(crate) async fn easy_write_and_flush<'a>(
474        &mut self,
475        items: impl Iterator<Item = &'a ClientOp>,
476    ) -> io::Result<()> {
477        for item in items {
478            self.enqueue_write_op(item);
479        }
480
481        future::poll_fn(|cx| self.poll_write(cx)).await?;
482        future::poll_fn(|cx| self.poll_flush(cx)).await?;
483        Ok(())
484    }
485
486    /// Writes a client operation to the write buffer.
487    pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
488        macro_rules! small_write {
489            ($dst:expr) => {
490                write!(self.small_write(), $dst).expect("do small write to Connection");
491            };
492        }
493
494        match item {
495            ClientOp::Connect(connect_info) => {
496                let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
497
498                self.write("CONNECT ");
499                self.write(json);
500                self.write("\r\n");
501                trace!(?connect_info, "write operation: CONNECT");
502            }
503            ClientOp::Publish {
504                subject,
505                payload,
506                respond,
507                headers,
508            } => {
509                let verb = match headers.as_ref() {
510                    Some(headers) if !headers.is_empty() => "HPUB",
511                    _ => "PUB",
512                };
513
514                small_write!("{verb} {subject} ");
515
516                if let Some(respond) = respond {
517                    small_write!("{respond} ");
518                }
519
520                match headers {
521                    Some(headers) if !headers.is_empty() => {
522                        let headers = headers.to_bytes();
523
524                        let headers_len = headers.len();
525                        let total_len = headers_len + payload.len();
526                        small_write!("{headers_len} {total_len}\r\n");
527                        self.write(headers);
528                    }
529                    _ => {
530                        let payload_len = payload.len();
531                        small_write!("{payload_len}\r\n");
532                    }
533                }
534
535                self.write(Bytes::clone(payload));
536                self.write("\r\n");
537
538                trace!(
539                    verb = %verb,
540                    subject = %subject,
541                    reply = ?respond,
542                    headers = ?headers,
543                    payload_len = %payload.len(),
544                    "write operation: PUB"
545                );
546            }
547
548            ClientOp::Subscribe {
549                sid,
550                subject,
551                queue_group,
552            } => {
553                match queue_group {
554                    Some(queue_group) => {
555                        small_write!("SUB {subject} {queue_group} {sid}\r\n");
556                    }
557                    None => {
558                        small_write!("SUB {subject} {sid}\r\n");
559                    }
560                }
561
562                trace!(
563                    subject = %subject,
564                    sid = %sid,
565                    queue_group = ?queue_group,
566                    "write operation: SUB"
567                );
568            }
569
570            ClientOp::Unsubscribe { sid, max } => {
571                match max {
572                    Some(max) => {
573                        small_write!("UNSUB {sid} {max}\r\n");
574                    }
575                    None => {
576                        small_write!("UNSUB {sid}\r\n");
577                    }
578                }
579
580                trace!(
581                    sid = %sid,
582                    max = ?max,
583                    "write operation: UNSUB"
584                );
585            }
586            ClientOp::Ping => {
587                self.write("PING\r\n");
588                trace!("write operation: PING");
589            }
590            ClientOp::Pong => {
591                self.write("PONG\r\n");
592                trace!("write operation: PONG");
593            }
594        }
595    }
596
597    /// Write the internal buffers into the write stream
598    ///
599    /// Returns one of the following:
600    ///
601    /// * `Poll::Pending` means that we weren't able to fully empty
602    ///   the internal buffers. Compared to [`AsyncWrite::poll_write`],
603    ///   this implementation may do a partial write before yielding.
604    /// * `Poll::Ready(Ok())` means that the internal write buffers have
605    ///   been emptied or were already empty.
606    /// * `Poll::Ready(Err(err))` means that writing to the stream failed.
607    ///   Compared to [`AsyncWrite::poll_write`], this implementation
608    ///   may do a partial write before failing.
609    pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
610        if !self.stream.is_write_vectored() {
611            self.poll_write_sequential(cx)
612        } else {
613            self.poll_write_vectored(cx)
614        }
615    }
616
617    /// Write the internal buffers into the write stream using sequential write operations
618    ///
619    /// Writes one chunk at a time. Less efficient.
620    fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
621        loop {
622            let buf = match self.write_buf.front() {
623                Some(buf) => &**buf,
624                None if !self.flattened_writes.is_empty() => &self.flattened_writes,
625                None => return Poll::Ready(Ok(())),
626            };
627
628            debug_assert!(!buf.is_empty());
629
630            match Pin::new(&mut self.stream).poll_write(cx, buf) {
631                Poll::Pending => return Poll::Pending,
632                Poll::Ready(Ok(n)) => {
633                    self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
634                    self.write_buf_len -= n;
635                    self.can_flush = true;
636
637                    match self.write_buf.front_mut() {
638                        Some(buf) if n < buf.len() => {
639                            buf.advance(n);
640                        }
641                        Some(_buf) => {
642                            self.write_buf.pop_front();
643                        }
644                        None => {
645                            self.flattened_writes.advance(n);
646                        }
647                    }
648                    continue;
649                }
650                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
651            }
652        }
653    }
654    /// Write the internal buffers into the write stream using vectored write operations
655    ///
656    /// Writes [`WRITE_VECTORED_CHUNKS`] at a time. More efficient _if_
657    /// the underlying writer supports it.
658    fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
659        'outer: loop {
660            let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
661            let mut writes_len = 0;
662
663            self.write_buf
664                .iter()
665                .take(WRITE_VECTORED_CHUNKS)
666                .enumerate()
667                .for_each(|(i, buf)| {
668                    writes[i] = IoSlice::new(buf);
669                    writes_len += 1;
670                });
671
672            if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
673                writes[writes_len] = IoSlice::new(&self.flattened_writes);
674                writes_len += 1;
675            }
676
677            if writes_len == 0 {
678                return Poll::Ready(Ok(()));
679            }
680
681            match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
682                Poll::Pending => return Poll::Pending,
683                Poll::Ready(Ok(mut n)) => {
684                    self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
685                    self.write_buf_len -= n;
686                    self.can_flush = true;
687
688                    while let Some(buf) = self.write_buf.front_mut() {
689                        if n < buf.len() {
690                            buf.advance(n);
691                            continue 'outer;
692                        }
693
694                        n -= buf.len();
695                        self.write_buf.pop_front();
696                    }
697
698                    self.flattened_writes.advance(n);
699                }
700                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
701            }
702        }
703    }
704
705    /// Write `buf` into the writes buffer
706    ///
707    /// If `buf` is smaller than [`WRITE_FLATTEN_THRESHOLD`]
708    /// flattens it, otherwise appends it to the chunks queue.
709    ///
710    /// Empty `buf`s are a no-op.
711    fn write(&mut self, buf: impl Into<Bytes>) {
712        let buf = buf.into();
713        if buf.is_empty() {
714            return;
715        }
716
717        self.write_buf_len += buf.len();
718        if buf.len() < WRITE_FLATTEN_THRESHOLD {
719            self.flattened_writes.extend_from_slice(&buf);
720        } else {
721            if !self.flattened_writes.is_empty() {
722                let buf = self.flattened_writes.split().freeze();
723                self.write_buf.push_back(buf);
724            }
725
726            self.write_buf.push_back(buf);
727        }
728    }
729
730    /// Obtain an [`fmt::Write`]r for the small writes buffer.
731    fn small_write(&mut self) -> impl fmt::Write + '_ {
732        struct Writer<'a> {
733            this: &'a mut Connection,
734        }
735
736        impl fmt::Write for Writer<'_> {
737            fn write_str(&mut self, s: &str) -> fmt::Result {
738                self.this.write_buf_len += s.len();
739                self.this.flattened_writes.write_str(s)
740            }
741        }
742
743        Writer { this: self }
744    }
745
746    /// Flush the write buffer, sending all pending data down the current write stream.
747    ///
748    /// no-op if the write stream didn't need to be flushed.
749    pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
750        match Pin::new(&mut self.stream).poll_flush(cx) {
751            Poll::Pending => Poll::Pending,
752            Poll::Ready(Ok(())) => {
753                self.can_flush = false;
754                Poll::Ready(Ok(()))
755            }
756            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
757        }
758    }
759}
760
/// Byte-stream view over a [`WebSocketStream`]: reads yield frame payload bytes,
/// and writes are sent as binary frames.
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// The wrapped WebSocket connection (structurally pinned).
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Payload bytes received from `inner` but not yet handed to a reader.
    pub(crate) read_buf: BytesMut,
}
768
#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        Self { inner, read_buf }
    }
}
778
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Copies buffered payload bytes into `buf`, pulling the next frame from the
    /// inner stream whenever the local buffer is drained.
    ///
    /// Only text and binary frames carry NATS protocol bytes. Control frames
    /// (ping/pong/close) are skipped so their payloads never reach the protocol
    /// parser. A stream that ends yields an `UnexpectedEof` error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        loop {
            // If we have data in the read buffer, move as much as fits to the output buffer.
            if !this.read_buf.is_empty() {
                let len = std::cmp::min(buf.remaining(), this.read_buf.len());
                buf.put_slice(&this.read_buf.split_to(len));
                return Poll::Ready(Ok(()));
            }

            match this.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(message))) => {
                    if message.is_binary() || message.is_text() {
                        this.read_buf.extend_from_slice(message.as_payload());
                    }
                    // Otherwise: control frame, nothing for the NATS parser; poll again.
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
                }
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
819
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends all of `buf` as a single binary frame once the sink is ready.
    ///
    /// The frame is only queued on the sink here; `poll_flush` pushes it out.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => {
                Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
            }
            Poll::Ready(Ok(())) => {
                // Copy only once the sink can take a frame, so a `Pending` poll
                // doesn't allocate a buffer that is immediately thrown away.
                let message = tokio_websockets::Message::binary(buf.to_vec());
                match this.inner.start_send_unpin(message) {
                    Ok(()) => Poll::Ready(Ok(buf.len())),
                    Err(e) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))),
                }
            }
        }
    }

    /// Flushes queued frames on the inner WebSocket sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }

    /// Closes the inner WebSocket sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }
}
862
#[cfg(test)]
mod read_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
    use tokio::io::{self, AsyncWriteExt};

    /// Creates a [`Connection`] over a 128-byte in-memory duplex pipe.
    ///
    /// Returns the connection together with the server end of the pipe, into
    /// which tests write raw protocol bytes for the connection to parse.
    fn connection_pair() -> (Connection, io::DuplexStream) {
        let (stream, server) = io::duplex(128);
        let connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));
        (connection, server)
    }

    #[tokio::test]
    async fn ok() {
        let (mut connection, mut server) = connection_pair();

        server.write_all(b"+OK\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut server) = connection_pair();

        server.write_all(b"PING\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut server) = connection_pair();

        server.write_all(b"PONG\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Pong));
    }

    #[tokio::test]
    async fn info() {
        let (mut connection, mut server) = connection_pair();

        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                ..Default::default()
            })))
        );
    }

    #[tokio::test]
    async fn error() {
        let (mut connection, mut server) = connection_pair();

        server.write_all(b"INFO {}\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }

    #[tokio::test]
    async fn message() {
        let (mut connection, mut server) = connection_pair();

        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: X\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();

        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: Y\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    /// Unrecognised operations must surface as errors without desynchronising
    /// the stream: the op written after each bad line must still be parsed as
    /// exactly that op.
    #[tokio::test]
    async fn unknown() {
        let (mut connection, mut server) = connection_pair();

        server.write_all(b"ONE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"TWO\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PING\r\n").await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Ping));

        server.write_all(b"THREE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server.write_all(b"FOUR\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PONG\r\n").await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Pong));
    }
}
1147
#[cfg(test)]
mod write_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
    use tokio::io::{self, AsyncBufReadExt, BufReader};

    /// Creates a [`Connection`] over an in-memory duplex pipe with `capacity`
    /// bytes of buffering.
    ///
    /// Returns the connection together with a buffered reader over the server
    /// end, so tests can inspect exactly what the client wrote.
    fn connection_pair(capacity: usize) -> (Connection, BufReader<io::DuplexStream>) {
        let (stream, server) = io::duplex(capacity);
        let connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));
        (connection, BufReader::new(server))
    }

    /// Reads `count` newline-terminated lines from `reader` and returns them
    /// concatenated, with their `\r\n` terminators preserved.
    async fn read_lines(reader: &mut BufReader<io::DuplexStream>, count: usize) -> String {
        let mut buffer = String::new();
        for _ in 0..count {
            reader.read_line(&mut buffer).await.unwrap();
        }
        buffer
    }

    #[tokio::test]
    async fn publish() {
        let (mut connection, mut reader) = connection_pair(128);

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 2).await,
            "PUB FOO.BAR 11\r\nHello World\r\n"
        );

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 2).await,
            "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n"
        );

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        // HPUB frame: control line, header block (version line, one header,
        // blank terminator), then the payload. Checking the payload too covers
        // the `34` total length declared in the control line.
        assert_eq!(
            read_lines(&mut reader, 5).await,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\nHello World\r\n"
        );
    }

    #[tokio::test]
    async fn subscribe() {
        let (mut connection, mut reader) = connection_pair(128);

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "SUB FOO.BAR 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 1).await,
            "SUB FOO.BAR QUEUE.GROUP 11\r\n"
        );
    }

    #[tokio::test]
    async fn unsubscribe() {
        let (mut connection, mut reader) = connection_pair(128);

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "UNSUB 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "UNSUB 11 2\r\n");
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut reader) = connection_pair(128);

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "PING\r\n");
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut reader) = connection_pair(128);

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        assert_eq!(read_lines(&mut reader, 1).await, "PONG\r\n");
    }

    #[tokio::test]
    async fn connect() {
        // The serialized CONNECT line is longer than 128 bytes, so this test
        // uses a larger pipe.
        let (mut connection, mut reader) = connection_pair(1024);

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        assert_eq!(
            read_lines(&mut reader, 1).await,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}