Skip to main content

async_nats/
connection.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! This module provides a connection implementation for communicating with a NATS server.
15
16use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28    futures_util::{SinkExt, StreamExt},
29    pin_project::pin_project,
30    tokio::io::ReadBuf,
31    tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36use tracing::trace;
37
38use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
39use crate::status::StatusCode;
40use crate::subject::Subject;
41use crate::{ClientOp, ServerError, ServerOp, Statistics};
42
/// Soft cap, in bytes, on the data pending in [`Connection::write_buf`] and
/// [`Connection::flattened_writes`] together (64 KiB minus one byte).
const SOFT_WRITE_BUF_LIMIT: usize = 64 * 1024 - 1;
/// Buffers at least this many bytes long are queued as their own chunk;
/// anything shorter is copied into the flattened buffer.
const WRITE_FLATTEN_THRESHOLD: usize = 4 * 1024;
/// Upper bound on the number of buffers handed to one vectored write call.
const WRITE_VECTORED_CHUNKS: usize = 64;
51
52/// Supertrait enabling trait object for containing both TLS and non TLS `TcpStream` connection.
53pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
54
55/// Blanked implementation that applies to both TLS and non-TLS `TcpStream`.
56impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
57
/// Lifecycle state of the connection to the server.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    Pending,
    Connected,
    Disconnected,
}

/// Whether flushing the underlying stream is worthwhile right now.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShouldFlush {
    /// Write buffers are empty, but the connection hasn't been flushed yet
    Yes,
    /// The connection hasn't been flushed yet, but write buffers aren't empty
    May,
    /// Flushing would just be a no-op
    No,
}

/// Renders the state as a lowercase word (`pending`, `connected`, `disconnected`).
impl Display for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Pending => "pending",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
        };
        f.write_str(label)
    }
}
85
86/// A framed connection
87pub(crate) struct Connection {
88    pub(crate) stream: Box<dyn AsyncReadWrite>,
89    read_buf: BytesMut,
90    write_buf: VecDeque<Bytes>,
91    write_buf_len: usize,
92    flattened_writes: BytesMut,
93    can_flush: bool,
94    statistics: Arc<Statistics>,
95}
96
97/// Internal representation of the connection.
98/// Holds connection with NATS Server and communicates with `Client` via channels.
99impl Connection {
100    pub(crate) fn new(
101        stream: Box<dyn AsyncReadWrite>,
102        read_buffer_capacity: usize,
103        statistics: Arc<Statistics>,
104    ) -> Self {
105        Self {
106            stream,
107            read_buf: BytesMut::with_capacity(read_buffer_capacity),
108            write_buf: VecDeque::new(),
109            write_buf_len: 0,
110            flattened_writes: BytesMut::new(),
111            can_flush: false,
112            statistics,
113        }
114    }
115
116    /// Returns `true` if no more calls to [`Self::enqueue_write_op`] _should_ be made.
117    pub(crate) fn is_write_buf_full(&self) -> bool {
118        self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
119    }
120
121    /// Returns `true` if [`Self::poll_flush`] should be polled.
122    pub(crate) fn should_flush(&self) -> ShouldFlush {
123        match (
124            self.can_flush,
125            self.write_buf.is_empty() && self.flattened_writes.is_empty(),
126        ) {
127            (true, true) => ShouldFlush::Yes,
128            (true, false) => ShouldFlush::May,
129            (false, _) => ShouldFlush::No,
130        }
131    }
132
133    /// Attempts to read a server operation from the read buffer.
134    /// Returns `None` if there is not enough data to parse an entire operation.
135    pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
136        let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
137            Some(len) => len,
138            None => return Ok(None),
139        };
140
141        if self.read_buf.starts_with(b"+OK") {
142            self.read_buf.advance(len + 2);
143            trace!("read operation: OK");
144            return Ok(Some(ServerOp::Ok));
145        }
146
147        if self.read_buf.starts_with(b"PING") {
148            self.read_buf.advance(len + 2);
149            trace!("read operation: PING");
150            return Ok(Some(ServerOp::Ping));
151        }
152
153        if self.read_buf.starts_with(b"PONG") {
154            self.read_buf.advance(len + 2);
155            trace!("read operation: PONG");
156            return Ok(Some(ServerOp::Pong));
157        }
158
159        if self.read_buf.starts_with(b"-ERR") {
160            let description = str::from_utf8(&self.read_buf[5..len])
161                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
162                .trim_matches('\'')
163                .to_owned();
164
165            self.read_buf.advance(len + 2);
166            trace!(error = %description, "read operation: ERR");
167            return Ok(Some(ServerOp::Error(ServerError::new(description))));
168        }
169
170        if self.read_buf.starts_with(b"INFO ") {
171            let info = serde_json::from_slice(&self.read_buf[4..len])
172                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
173
174            self.read_buf.advance(len + 2);
175            trace!(?info, "read operation: INFO");
176            return Ok(Some(ServerOp::Info(Box::new(info))));
177        }
178
179        if self.read_buf.starts_with(b"MSG ") {
180            let line = str::from_utf8(&self.read_buf[4..len])
181                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
182            let mut args = line.split(' ').filter(|s| !s.is_empty());
183
184            // Parse the operation syntax: MSG <subject> <sid> [reply-to] <#bytes>
185            let (subject, sid, reply_to, payload_len) = match (
186                args.next(),
187                args.next(),
188                args.next(),
189                args.next(),
190                args.next(),
191            ) {
192                (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
193                    (subject, sid, Some(reply_to), payload_len)
194                }
195                (Some(subject), Some(sid), Some(payload_len), None, None) => {
196                    (subject, sid, None, payload_len)
197                }
198                _ => {
199                    return Err(io::Error::new(
200                        io::ErrorKind::InvalidInput,
201                        "invalid number of arguments after MSG",
202                    ))
203                }
204            };
205
206            let sid = sid
207                .parse::<u64>()
208                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
209
210            // Parse the number of payload bytes.
211            let payload_len = payload_len
212                .parse::<usize>()
213                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
214
215            // Return early without advancing if there is not enough data read the entire
216            // message
217            if len + payload_len + 4 > self.read_buf.remaining() {
218                return Ok(None);
219            }
220
221            let length = payload_len
222                + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
223                + subject.len();
224
225            let subject = Subject::from(subject);
226            let reply = reply_to.map(Subject::from);
227
228            self.read_buf.advance(len + 2);
229            let payload = self.read_buf.split_to(payload_len).freeze();
230            self.read_buf.advance(2);
231
232            trace!(
233                subject = %subject,
234                sid = %sid,
235                reply = ?reply,
236                payload_len = %payload_len,
237                "read operation: MSG"
238            );
239
240            return Ok(Some(ServerOp::Message {
241                sid,
242                length,
243                reply,
244                headers: None,
245                subject,
246                payload,
247                status: None,
248                description: None,
249            }));
250        }
251
252        if self.read_buf.starts_with(b"HMSG ") {
253            // Extract whitespace-delimited arguments that come after "HMSG".
254            let line = std::str::from_utf8(&self.read_buf[5..len])
255                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
256            let mut args = line.split_whitespace().filter(|s| !s.is_empty());
257
258            // <subject> <sid> [reply-to] <# header bytes><# total bytes>
259            let (subject, sid, reply_to, header_len, total_len) = match (
260                args.next(),
261                args.next(),
262                args.next(),
263                args.next(),
264                args.next(),
265                args.next(),
266            ) {
267                (
268                    Some(subject),
269                    Some(sid),
270                    Some(reply_to),
271                    Some(header_len),
272                    Some(total_len),
273                    None,
274                ) => (subject, sid, Some(reply_to), header_len, total_len),
275                (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
276                    (subject, sid, None, header_len, total_len)
277                }
278                _ => {
279                    return Err(io::Error::new(
280                        io::ErrorKind::InvalidInput,
281                        "invalid number of arguments after HMSG",
282                    ))
283                }
284            };
285
286            // Convert the slice into a subject
287            let subject = Subject::from(subject);
288
289            // Parse the subject ID.
290            let sid = sid.parse::<u64>().map_err(|_| {
291                io::Error::new(
292                    io::ErrorKind::InvalidInput,
293                    "cannot parse sid argument after HMSG",
294                )
295            })?;
296
297            // Convert the slice into a subject.
298            let reply = reply_to.map(Subject::from);
299
300            // Parse the number of payload bytes.
301            let header_len = header_len.parse::<usize>().map_err(|_| {
302                io::Error::new(
303                    io::ErrorKind::InvalidInput,
304                    "cannot parse the number of header bytes argument after \
305                     HMSG",
306                )
307            })?;
308
309            // Parse the number of payload bytes.
310            let total_len = total_len.parse::<usize>().map_err(|_| {
311                io::Error::new(
312                    io::ErrorKind::InvalidInput,
313                    "cannot parse the number of bytes argument after HMSG",
314                )
315            })?;
316
317            if total_len < header_len {
318                return Err(io::Error::new(
319                    io::ErrorKind::InvalidInput,
320                    "number of header bytes was greater than or equal to the \
321                 total number of bytes after HMSG",
322                ));
323            }
324
325            if len + total_len + 4 > self.read_buf.remaining() {
326                return Ok(None);
327            }
328
329            self.read_buf.advance(len + 2);
330            let header = self.read_buf.split_to(header_len);
331            let payload = self.read_buf.split_to(total_len - header_len).freeze();
332            self.read_buf.advance(2);
333
334            let mut lines = std::str::from_utf8(&header)
335                .map_err(|_| {
336                    io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
337                })?
338                .lines()
339                .peekable();
340            let version_line = lines.next().ok_or_else(|| {
341                io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
342            })?;
343
344            let version_line_suffix = version_line
345                .strip_prefix("NATS/1.0")
346                .map(str::trim)
347                .ok_or_else(|| {
348                    io::Error::new(
349                        io::ErrorKind::InvalidInput,
350                        "header version line does not begin with `NATS/1.0`",
351                    )
352                })?;
353
354            let (status, description) = version_line_suffix
355                .split_once(' ')
356                .map(|(status, description)| (status.trim(), description.trim()))
357                .unwrap_or((version_line_suffix, ""));
358            let status = if !status.is_empty() {
359                Some(
360                    status
361                        .parse::<StatusCode>()
362                        .map_err(|_| std::io::Error::other("could not parse status parameter"))?,
363                )
364            } else {
365                None
366            };
367            let description = if !description.is_empty() {
368                Some(description.to_owned())
369            } else {
370                None
371            };
372
373            let mut headers = HeaderMap::new();
374            while let Some(line) = lines.next() {
375                if line.is_empty() {
376                    continue;
377                }
378
379                let (name, value) = line.split_once(':').ok_or_else(|| {
380                    io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
381                })?;
382
383                let name = HeaderName::from_str(name)
384                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
385
386                // Read the header value, which might have been split into multiple lines
387                // `trim_start` and `trim_end` do the same job as doing `value.trim().to_owned()` at the end, but without a reallocation
388                let mut value = value.trim_start().to_owned();
389                while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
390                    value.push_str(v);
391                }
392                value.truncate(value.trim_end().len());
393
394                headers.append(name, value.into_header_value());
395            }
396
397            trace!(
398                subject = %subject,
399                sid = %sid,
400                reply = ?reply,
401                header_len = %header_len,
402                total_len = %total_len,
403                status = ?status,
404                description = ?description,
405                "read operation: HMSG"
406            );
407
408            return Ok(Some(ServerOp::Message {
409                length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
410                sid,
411                reply,
412                subject,
413                headers: Some(headers),
414                payload,
415                status,
416                description,
417            }));
418        }
419
420        let buffer = self.read_buf.split_to(len + 2);
421        let line = str::from_utf8(&buffer).map_err(|_| {
422            io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
423        })?;
424
425        trace!(line = %line, "read operation: unknown");
426        Err(io::Error::new(
427            io::ErrorKind::InvalidInput,
428            format!("invalid server operation: '{line}'"),
429        ))
430    }
431
432    pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
433        future::poll_fn(|cx| self.poll_read_op(cx))
434    }
435
436    // TODO: do we want an custom error here?
437    /// Read a server operation from read buffer.
438    /// Blocks until an operation ca be parsed.
439    pub(crate) fn poll_read_op(
440        &mut self,
441        cx: &mut Context<'_>,
442    ) -> Poll<io::Result<Option<ServerOp>>> {
443        loop {
444            if let Some(op) = self.try_read_op()? {
445                trace!(?op, "read operation completed");
446                return Poll::Ready(Ok(Some(op)));
447            }
448
449            let read_buf = self.stream.read_buf(&mut self.read_buf);
450            tokio::pin!(read_buf);
451            return match read_buf.poll(cx) {
452                Poll::Pending => {
453                    trace!("read operation pending");
454                    Poll::Pending
455                }
456                Poll::Ready(Ok(0)) if self.read_buf.is_empty() => {
457                    trace!("read operation: empty buffer");
458                    Poll::Ready(Ok(None))
459                }
460                Poll::Ready(Ok(0)) => {
461                    trace!("read operation: connection reset");
462                    Poll::Ready(Err(io::ErrorKind::ConnectionReset.into()))
463                }
464                Poll::Ready(Ok(n)) => {
465                    self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
466                    trace!(bytes = %n, "read operation: received bytes");
467                    continue;
468                }
469                Poll::Ready(Err(err)) => {
470                    trace!(error = %err, "read operation: error");
471                    Poll::Ready(Err(err))
472                }
473            };
474        }
475    }
476
477    pub(crate) async fn easy_write_and_flush<'a>(
478        &mut self,
479        items: impl Iterator<Item = &'a ClientOp>,
480    ) -> io::Result<()> {
481        for item in items {
482            self.enqueue_write_op(item);
483        }
484
485        future::poll_fn(|cx| self.poll_write(cx)).await?;
486        future::poll_fn(|cx| self.poll_flush(cx)).await?;
487        Ok(())
488    }
489
490    /// Writes a client operation to the write buffer.
491    pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
492        macro_rules! small_write {
493            ($dst:expr) => {
494                write!(self.small_write(), $dst).expect("do small write to Connection");
495            };
496        }
497
498        match item {
499            ClientOp::Connect(connect_info) => {
500                let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
501
502                self.write("CONNECT ");
503                self.write(json);
504                self.write("\r\n");
505                trace!(?connect_info, "write operation: CONNECT");
506            }
507            ClientOp::Publish {
508                subject,
509                payload,
510                respond,
511                headers,
512            } => {
513                let verb = match headers.as_ref() {
514                    Some(headers) if !headers.is_empty() => "HPUB",
515                    _ => "PUB",
516                };
517
518                small_write!("{verb} {subject} ");
519
520                if let Some(respond) = respond {
521                    small_write!("{respond} ");
522                }
523
524                match headers {
525                    Some(headers) if !headers.is_empty() => {
526                        let headers = headers.to_bytes();
527
528                        let headers_len = headers.len();
529                        let total_len = headers_len + payload.len();
530                        small_write!("{headers_len} {total_len}\r\n");
531                        self.write(headers);
532                    }
533                    _ => {
534                        let payload_len = payload.len();
535                        small_write!("{payload_len}\r\n");
536                    }
537                }
538
539                self.write(Bytes::clone(payload));
540                self.write("\r\n");
541
542                trace!(
543                    verb = %verb,
544                    subject = %subject,
545                    reply = ?respond,
546                    headers = ?headers,
547                    payload_len = %payload.len(),
548                    "write operation: PUB"
549                );
550            }
551
552            ClientOp::Subscribe {
553                sid,
554                subject,
555                queue_group,
556            } => {
557                match queue_group {
558                    Some(queue_group) => {
559                        small_write!("SUB {subject} {queue_group} {sid}\r\n");
560                    }
561                    None => {
562                        small_write!("SUB {subject} {sid}\r\n");
563                    }
564                }
565
566                trace!(
567                    subject = %subject,
568                    sid = %sid,
569                    queue_group = ?queue_group,
570                    "write operation: SUB"
571                );
572            }
573
574            ClientOp::Unsubscribe { sid, max } => {
575                match max {
576                    Some(max) => {
577                        small_write!("UNSUB {sid} {max}\r\n");
578                    }
579                    None => {
580                        small_write!("UNSUB {sid}\r\n");
581                    }
582                }
583
584                trace!(
585                    sid = %sid,
586                    max = ?max,
587                    "write operation: UNSUB"
588                );
589            }
590            ClientOp::Ping => {
591                self.write("PING\r\n");
592                trace!("write operation: PING");
593            }
594            ClientOp::Pong => {
595                self.write("PONG\r\n");
596                trace!("write operation: PONG");
597            }
598        }
599    }
600
601    /// Write the internal buffers into the write stream
602    ///
603    /// Returns one of the following:
604    ///
605    /// * `Poll::Pending` means that we weren't able to fully empty
606    ///   the internal buffers. Compared to [`AsyncWrite::poll_write`],
607    ///   this implementation may do a partial write before yielding.
608    /// * `Poll::Ready(Ok())` means that the internal write buffers have
609    ///   been emptied or were already empty.
610    /// * `Poll::Ready(Err(err))` means that writing to the stream failed.
611    ///   Compared to [`AsyncWrite::poll_write`], this implementation
612    ///   may do a partial write before failing.
613    pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
614        if !self.stream.is_write_vectored() {
615            self.poll_write_sequential(cx)
616        } else {
617            self.poll_write_vectored(cx)
618        }
619    }
620
621    /// Write the internal buffers into the write stream using sequential write operations
622    ///
623    /// Writes one chunk at a time. Less efficient.
624    fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
625        loop {
626            let buf = match self.write_buf.front() {
627                Some(buf) => &**buf,
628                None if !self.flattened_writes.is_empty() => &self.flattened_writes,
629                None => return Poll::Ready(Ok(())),
630            };
631
632            debug_assert!(!buf.is_empty());
633
634            match Pin::new(&mut self.stream).poll_write(cx, buf) {
635                Poll::Pending => return Poll::Pending,
636                Poll::Ready(Ok(n)) => {
637                    self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
638                    self.write_buf_len -= n;
639                    self.can_flush = true;
640
641                    match self.write_buf.front_mut() {
642                        Some(buf) if n < buf.len() => {
643                            buf.advance(n);
644                        }
645                        Some(_buf) => {
646                            self.write_buf.pop_front();
647                        }
648                        None => {
649                            self.flattened_writes.advance(n);
650                        }
651                    }
652                    continue;
653                }
654                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
655            }
656        }
657    }
658    /// Write the internal buffers into the write stream using vectored write operations
659    ///
660    /// Writes [`WRITE_VECTORED_CHUNKS`] at a time. More efficient _if_
661    /// the underlying writer supports it.
662    fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
663        'outer: loop {
664            let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
665            let mut writes_len = 0;
666
667            self.write_buf
668                .iter()
669                .take(WRITE_VECTORED_CHUNKS)
670                .enumerate()
671                .for_each(|(i, buf)| {
672                    writes[i] = IoSlice::new(buf);
673                    writes_len += 1;
674                });
675
676            if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
677                writes[writes_len] = IoSlice::new(&self.flattened_writes);
678                writes_len += 1;
679            }
680
681            if writes_len == 0 {
682                return Poll::Ready(Ok(()));
683            }
684
685            match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
686                Poll::Pending => return Poll::Pending,
687                Poll::Ready(Ok(mut n)) => {
688                    self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
689                    self.write_buf_len -= n;
690                    self.can_flush = true;
691
692                    while let Some(buf) = self.write_buf.front_mut() {
693                        if n < buf.len() {
694                            buf.advance(n);
695                            continue 'outer;
696                        }
697
698                        n -= buf.len();
699                        self.write_buf.pop_front();
700                    }
701
702                    self.flattened_writes.advance(n);
703                }
704                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
705            }
706        }
707    }
708
709    /// Write `buf` into the writes buffer
710    ///
711    /// If `buf` is smaller than [`WRITE_FLATTEN_THRESHOLD`]
712    /// flattens it, otherwise appends it to the chunks queue.
713    ///
714    /// Empty `buf`s are a no-op.
715    fn write(&mut self, buf: impl Into<Bytes>) {
716        let buf = buf.into();
717        if buf.is_empty() {
718            return;
719        }
720
721        self.write_buf_len += buf.len();
722        if buf.len() < WRITE_FLATTEN_THRESHOLD {
723            self.flattened_writes.extend_from_slice(&buf);
724        } else {
725            if !self.flattened_writes.is_empty() {
726                let buf = self.flattened_writes.split().freeze();
727                self.write_buf.push_back(buf);
728            }
729
730            self.write_buf.push_back(buf);
731        }
732    }
733
734    /// Obtain an [`fmt::Write`]r for the small writes buffer.
735    fn small_write(&mut self) -> impl fmt::Write + '_ {
736        struct Writer<'a> {
737            this: &'a mut Connection,
738        }
739
740        impl fmt::Write for Writer<'_> {
741            fn write_str(&mut self, s: &str) -> fmt::Result {
742                self.this.write_buf_len += s.len();
743                self.this.flattened_writes.write_str(s)
744            }
745        }
746
747        Writer { this: self }
748    }
749
750    /// Flush the write buffer, sending all pending data down the current write stream.
751    ///
752    /// no-op if the write stream didn't need to be flushed.
753    pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
754        match Pin::new(&mut self.stream).poll_flush(cx) {
755            Poll::Pending => Poll::Pending,
756            Poll::Ready(Ok(())) => {
757                self.can_flush = false;
758                Poll::Ready(Ok(()))
759            }
760            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
761        }
762    }
763}
764
/// Adapts a message-oriented [`WebSocketStream`] into a byte stream.
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// The wrapped WebSocket stream (structurally pinned).
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Received payload bytes not yet handed out to a reader.
    pub(crate) read_buf: BytesMut,
}

#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        Self { inner, read_buf }
    }
}
782
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Fills `buf` from buffered payload bytes, pulling messages from the
    /// WebSocket until at least one byte is available.
    ///
    /// A closed WebSocket is reported as an `UnexpectedEof` error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        // Keep pulling messages until something is buffered; messages with an
        // empty payload simply cause another poll.
        while this.read_buf.is_empty() {
            let message = match this.inner.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(std::io::Error::other(e))),
                Poll::Ready(Some(Ok(message))) => message,
            };
            // NOTE(review): the payload of every yielded message is appended
            // regardless of its type; confirm whether non-data frames can be
            // yielded by `inner` here.
            this.read_buf.extend_from_slice(message.as_payload());
        }

        let chunk_len = buf.remaining().min(this.read_buf.len());
        buf.put_slice(&this.read_buf.split_to(chunk_len));
        Poll::Ready(Ok(()))
    }
}
823
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends `buf` as one binary WebSocket message.
    ///
    /// Reports all of `buf` as written once the message has been handed to
    /// the sink. The payload is copied only after the sink is ready, so a
    /// `Pending` poll does not allocate.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Ready(Ok(())) => {
                // Copy only once the sink can actually accept a message.
                let message = tokio_websockets::Message::binary(buf.to_vec());
                match this.inner.start_send_unpin(message) {
                    Ok(()) => Poll::Ready(Ok(buf.len())),
                    Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
                }
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e))),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Flushes messages queued in the WebSocket sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(std::io::Error::other)
    }

    /// Closes the WebSocket sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(std::io::Error::other)
    }
}
864
#[cfg(test)]
mod read_op {
    //! Parser tests. Raw protocol bytes are written to one end of an in-memory
    //! duplex pipe, and `Connection::read_op` decodes them on the other end.

    use std::sync::Arc;

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
    use tokio::io::{self, AsyncWriteExt};

    /// HMSG frame with a `404 No Messages` status line and no payload, split
    /// into the same writes a server could emit.
    const NO_MESSAGES_HMSG: &[&[u8]] = &[
        b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n",
        b"NATS/1.0 404 No Messages\r\n",
        b"\r\n",
        b"\r\n",
    ];

    /// Builds a `Connection` over one half of a 128-byte duplex pipe. Returns
    /// it together with the other half, which plays the role of the server.
    fn connection_with_server() -> (Connection, io::DuplexStream) {
        let (client_end, server_end) = io::duplex(128);
        let connection =
            Connection::new(Box::new(client_end), 0, Arc::new(Statistics::default()));
        (connection, server_end)
    }

    /// Writes every chunk to `server`, in order, with one `write_all` per chunk.
    async fn write_chunks(server: &mut io::DuplexStream, chunks: &[&[u8]]) {
        for chunk in chunks {
            server.write_all(chunk).await.unwrap();
        }
    }

    /// The op that `NO_MESSAGES_HMSG` is expected to decode into.
    fn no_messages_reply() -> ServerOp {
        ServerOp::Message {
            sid: 10,
            subject: "FOO.BAR".into(),
            reply: Some("INBOX.35".into()),
            headers: Some(HeaderMap::default()),
            payload: "".into(),
            status: Some(StatusCode::NOT_FOUND),
            description: Some("No Messages".to_string()),
            length: 7 + 8 + 28,
        }
    }

    #[tokio::test]
    async fn ok() {
        let (mut connection, mut server) = connection_with_server();

        server.write_all(b"+OK\r\n").await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut server) = connection_with_server();

        server.write_all(b"PING\r\n").await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut server) = connection_with_server();

        server.write_all(b"PONG\r\n").await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Pong));
    }

    #[tokio::test]
    async fn info() {
        let (mut connection, mut server) = connection_with_server();

        // An empty JSON object decodes to the default server info.
        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::default()))
        );

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                ..Default::default()
            })))
        );

        server
            .write_all(b"INFO { \"version\": \"1.0.0\", \"cluster\": \"test-cluster\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                cluster: Some("test-cluster".into()),
                ..Default::default()
            })))
        );
    }

    #[tokio::test]
    async fn error() {
        let (mut connection, mut server) = connection_with_server();

        server.write_all(b"INFO {}\r\n").await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::default()))
        );

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }

    #[tokio::test]
    async fn message() {
        let (mut connection, mut server) = connection_with_server();

        // MSG without a reply subject.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        // MSG carrying a reply subject.
        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // HMSG with one header, delivered over several writes.
        write_chunks(
            &mut server,
            &[
                b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n",
                b"NATS/1.0\r\n",
                b"Header: X\r\n",
                b"\r\n",
                b"Hello World\r\n",
            ],
        )
        .await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // The same shape again with a different header value.
        write_chunks(
            &mut server,
            &[
                b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n",
                b"NATS/1.0\r\n",
                b"Header: Y\r\n",
                b"\r\n",
                b"Hello World\r\n",
            ],
        )
        .await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // Status-only HMSG (404 No Messages).
        write_chunks(&mut server, NO_MESSAGES_HMSG).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(no_messages_reply())
        );

        // A plain MSG still parses after the status message.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    /// Unknown verbs yield an error but leave the connection usable for the
    /// ops that follow.
    #[tokio::test]
    async fn unknown() {
        let (mut connection, mut server) = connection_with_server();

        server.write_all(b"ONE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"TWO\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PING\r\n").await.unwrap();
        connection.read_op().await.unwrap();

        server.write_all(b"THREE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        write_chunks(&mut server, NO_MESSAGES_HMSG).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(no_messages_reply())
        );

        server.write_all(b"FOUR\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PONG\r\n").await.unwrap();
        connection.read_op().await.unwrap();
    }

    // Regression test for https://github.com/nats-io/nats.rs/issues/1572.
    // nats-server forwards subjects containing arbitrary bytes (it does not
    // enforce UTF-8), so a publish on `balance.\xFF` reaches a `balance.*`
    // subscriber. The parser must return an io::Error instead of panicking.
    // A panic would kill the ConnectionHandler task, and the client would
    // never reconnect.
    #[tokio::test]
    async fn msg_with_non_utf8_subject_returns_error() {
        let (mut connection, mut server) = connection_with_server();

        write_chunks(&mut server, &[b"MSG balance.\xFF 1 2\r\n", b"hi\r\n"]).await;

        let failure = connection.read_op().await.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidInput);
    }

    #[tokio::test]
    async fn hmsg_with_non_utf8_subject_returns_error() {
        let (mut connection, mut server) = connection_with_server();

        write_chunks(
            &mut server,
            &[b"HMSG balance.\xFF 1 12 14\r\n", b"NATS/1.0\r\n\r\nhi\r\n"],
        )
        .await;

        let failure = connection.read_op().await.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidInput);
    }
}
1197
#[cfg(test)]
mod write_op {
    //! Serializer tests. `ClientOp`s are written through a `Connection`, and the
    //! raw bytes are read back line by line from the other end of a duplex pipe.

    use std::sync::Arc;

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
    use tokio::io::{self, AsyncBufReadExt, BufReader};

    #[tokio::test]
    async fn publish() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        // HPUB spans five lines: the control line, the `NATS/1.0` version line,
        // the header, the blank line ending the header block, and the payload.
        // Read all five so the payload framing is checked along with the
        // headers.
        buffer.clear();
        for _ in 0..5 {
            reader.read_line(&mut buffer).await.unwrap();
        }
        assert_eq!(
            buffer,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\nHello World\r\n"
        );
    }

    #[tokio::test]
    async fn subscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR 11\r\n");

        // The queue group goes between the subject and the sid.
        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR QUEUE.GROUP 11\r\n");
    }

    #[tokio::test]
    async fn unsubscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11\r\n");

        // An optional max-message count is appended after the sid.
        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11 2\r\n");
    }

    #[tokio::test]
    async fn ping() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PING\r\n");
    }

    #[tokio::test]
    async fn pong() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PONG\r\n");
    }

    #[tokio::test]
    async fn connect() {
        // CONNECT carries a JSON body longer than 128 bytes, so use a larger pipe.
        let (stream, server) = io::duplex(1024);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}