1use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28 futures_util::{SinkExt, StreamExt},
29 pin_project::pin_project,
30 tokio::io::ReadBuf,
31 tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36use tracing::trace;
37
38use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
39use crate::status::StatusCode;
40use crate::subject::Subject;
41use crate::{ClientOp, ServerError, ServerOp, Statistics};
42
/// Soft cap, in bytes, on queued outgoing data: `Connection::is_write_buf_full`
/// reports `true` once `write_buf_len` reaches this value.
const SOFT_WRITE_BUF_LIMIT: usize = 65535;
/// Writes shorter than this many bytes are copied into the contiguous
/// `flattened_writes` buffer instead of being queued as a chunk of their own.
const WRITE_FLATTEN_THRESHOLD: usize = 4096;
/// Maximum number of `IoSlice`s handed to a single vectored write call.
const WRITE_VECTORED_CHUNKS: usize = 64;
51
/// Marker trait bundling `AsyncRead + AsyncWrite + Send + Unpin` so that any
/// such stream can be stored behind a single `Box<dyn AsyncReadWrite>`.
pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}

/// Blanket implementation: every type satisfying the bounds is an
/// `AsyncReadWrite`.
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
57
/// Connection state. Each variant is rendered in lowercase by the `Display`
/// implementation (`pending`, `connected`, `disconnected`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Displayed as `pending`.
    Pending,
    /// Displayed as `connected`.
    Connected,
    /// Displayed as `disconnected`.
    Disconnected,
}
65
/// Flush recommendation returned by `Connection::should_flush`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShouldFlush {
    /// Bytes were written since the last successful flush and both write
    /// buffers are now empty.
    Yes,
    /// Bytes were written since the last successful flush, but the write
    /// buffers still hold data.
    May,
    /// Nothing has been written since the last successful flush.
    No,
}
75
76impl Display for State {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 State::Pending => write!(f, "pending"),
80 State::Connected => write!(f, "connected"),
81 State::Disconnected => write!(f, "disconnected"),
82 }
83 }
84}
85
/// A buffered connection over a boxed async stream: decodes server
/// operations from the read side and queues client operations for writing.
pub(crate) struct Connection {
    /// Underlying transport that all reads, writes and flushes go through.
    pub(crate) stream: Box<dyn AsyncReadWrite>,
    /// Bytes received from `stream` that `try_read_op` has not consumed yet.
    read_buf: BytesMut,
    /// Queue of outgoing chunks: writes of at least `WRITE_FLATTEN_THRESHOLD`
    /// bytes, plus `flattened_writes` contents split off ahead of them.
    write_buf: VecDeque<Bytes>,
    /// Total number of bytes queued in `write_buf` and `flattened_writes`.
    write_buf_len: usize,
    /// Contiguous buffer that small writes are appended to.
    flattened_writes: BytesMut,
    /// Set after bytes are written to `stream`; cleared by a successful flush.
    can_flush: bool,
    /// Shared counters; `in_bytes` / `out_bytes` are bumped on read / write.
    statistics: Arc<Statistics>,
}
96
97impl Connection {
100 pub(crate) fn new(
101 stream: Box<dyn AsyncReadWrite>,
102 read_buffer_capacity: usize,
103 statistics: Arc<Statistics>,
104 ) -> Self {
105 Self {
106 stream,
107 read_buf: BytesMut::with_capacity(read_buffer_capacity),
108 write_buf: VecDeque::new(),
109 write_buf_len: 0,
110 flattened_writes: BytesMut::new(),
111 can_flush: false,
112 statistics,
113 }
114 }
115
116 pub(crate) fn is_write_buf_full(&self) -> bool {
118 self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
119 }
120
121 pub(crate) fn should_flush(&self) -> ShouldFlush {
123 match (
124 self.can_flush,
125 self.write_buf.is_empty() && self.flattened_writes.is_empty(),
126 ) {
127 (true, true) => ShouldFlush::Yes,
128 (true, false) => ShouldFlush::May,
129 (false, _) => ShouldFlush::No,
130 }
131 }
132
133 pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
136 let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
137 Some(len) => len,
138 None => return Ok(None),
139 };
140
141 if self.read_buf.starts_with(b"+OK") {
142 self.read_buf.advance(len + 2);
143 trace!("read operation: OK");
144 return Ok(Some(ServerOp::Ok));
145 }
146
147 if self.read_buf.starts_with(b"PING") {
148 self.read_buf.advance(len + 2);
149 trace!("read operation: PING");
150 return Ok(Some(ServerOp::Ping));
151 }
152
153 if self.read_buf.starts_with(b"PONG") {
154 self.read_buf.advance(len + 2);
155 trace!("read operation: PONG");
156 return Ok(Some(ServerOp::Pong));
157 }
158
159 if self.read_buf.starts_with(b"-ERR") {
160 let description = str::from_utf8(&self.read_buf[5..len])
161 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
162 .trim_matches('\'')
163 .to_owned();
164
165 self.read_buf.advance(len + 2);
166 trace!(error = %description, "read operation: ERR");
167 return Ok(Some(ServerOp::Error(ServerError::new(description))));
168 }
169
170 if self.read_buf.starts_with(b"INFO ") {
171 let info = serde_json::from_slice(&self.read_buf[4..len])
172 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
173
174 self.read_buf.advance(len + 2);
175 trace!(?info, "read operation: INFO");
176 return Ok(Some(ServerOp::Info(Box::new(info))));
177 }
178
179 if self.read_buf.starts_with(b"MSG ") {
180 let line = str::from_utf8(&self.read_buf[4..len])
181 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
182 let mut args = line.split(' ').filter(|s| !s.is_empty());
183
184 let (subject, sid, reply_to, payload_len) = match (
186 args.next(),
187 args.next(),
188 args.next(),
189 args.next(),
190 args.next(),
191 ) {
192 (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
193 (subject, sid, Some(reply_to), payload_len)
194 }
195 (Some(subject), Some(sid), Some(payload_len), None, None) => {
196 (subject, sid, None, payload_len)
197 }
198 _ => {
199 return Err(io::Error::new(
200 io::ErrorKind::InvalidInput,
201 "invalid number of arguments after MSG",
202 ))
203 }
204 };
205
206 let sid = sid
207 .parse::<u64>()
208 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
209
210 let payload_len = payload_len
212 .parse::<usize>()
213 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
214
215 if len + payload_len + 4 > self.read_buf.remaining() {
218 return Ok(None);
219 }
220
221 let length = payload_len
222 + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
223 + subject.len();
224
225 let subject = Subject::from(subject);
226 let reply = reply_to.map(Subject::from);
227
228 self.read_buf.advance(len + 2);
229 let payload = self.read_buf.split_to(payload_len).freeze();
230 self.read_buf.advance(2);
231
232 trace!(
233 subject = %subject,
234 sid = %sid,
235 reply = ?reply,
236 payload_len = %payload_len,
237 "read operation: MSG"
238 );
239
240 return Ok(Some(ServerOp::Message {
241 sid,
242 length,
243 reply,
244 headers: None,
245 subject,
246 payload,
247 status: None,
248 description: None,
249 }));
250 }
251
252 if self.read_buf.starts_with(b"HMSG ") {
253 let line = std::str::from_utf8(&self.read_buf[5..len])
255 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
256 let mut args = line.split_whitespace().filter(|s| !s.is_empty());
257
258 let (subject, sid, reply_to, header_len, total_len) = match (
260 args.next(),
261 args.next(),
262 args.next(),
263 args.next(),
264 args.next(),
265 args.next(),
266 ) {
267 (
268 Some(subject),
269 Some(sid),
270 Some(reply_to),
271 Some(header_len),
272 Some(total_len),
273 None,
274 ) => (subject, sid, Some(reply_to), header_len, total_len),
275 (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
276 (subject, sid, None, header_len, total_len)
277 }
278 _ => {
279 return Err(io::Error::new(
280 io::ErrorKind::InvalidInput,
281 "invalid number of arguments after HMSG",
282 ))
283 }
284 };
285
286 let subject = Subject::from(subject);
288
289 let sid = sid.parse::<u64>().map_err(|_| {
291 io::Error::new(
292 io::ErrorKind::InvalidInput,
293 "cannot parse sid argument after HMSG",
294 )
295 })?;
296
297 let reply = reply_to.map(Subject::from);
299
300 let header_len = header_len.parse::<usize>().map_err(|_| {
302 io::Error::new(
303 io::ErrorKind::InvalidInput,
304 "cannot parse the number of header bytes argument after \
305 HMSG",
306 )
307 })?;
308
309 let total_len = total_len.parse::<usize>().map_err(|_| {
311 io::Error::new(
312 io::ErrorKind::InvalidInput,
313 "cannot parse the number of bytes argument after HMSG",
314 )
315 })?;
316
317 if total_len < header_len {
318 return Err(io::Error::new(
319 io::ErrorKind::InvalidInput,
320 "number of header bytes was greater than or equal to the \
321 total number of bytes after HMSG",
322 ));
323 }
324
325 if len + total_len + 4 > self.read_buf.remaining() {
326 return Ok(None);
327 }
328
329 self.read_buf.advance(len + 2);
330 let header = self.read_buf.split_to(header_len);
331 let payload = self.read_buf.split_to(total_len - header_len).freeze();
332 self.read_buf.advance(2);
333
334 let mut lines = std::str::from_utf8(&header)
335 .map_err(|_| {
336 io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
337 })?
338 .lines()
339 .peekable();
340 let version_line = lines.next().ok_or_else(|| {
341 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
342 })?;
343
344 let version_line_suffix = version_line
345 .strip_prefix("NATS/1.0")
346 .map(str::trim)
347 .ok_or_else(|| {
348 io::Error::new(
349 io::ErrorKind::InvalidInput,
350 "header version line does not begin with `NATS/1.0`",
351 )
352 })?;
353
354 let (status, description) = version_line_suffix
355 .split_once(' ')
356 .map(|(status, description)| (status.trim(), description.trim()))
357 .unwrap_or((version_line_suffix, ""));
358 let status = if !status.is_empty() {
359 Some(
360 status
361 .parse::<StatusCode>()
362 .map_err(|_| std::io::Error::other("could not parse status parameter"))?,
363 )
364 } else {
365 None
366 };
367 let description = if !description.is_empty() {
368 Some(description.to_owned())
369 } else {
370 None
371 };
372
373 let mut headers = HeaderMap::new();
374 while let Some(line) = lines.next() {
375 if line.is_empty() {
376 continue;
377 }
378
379 let (name, value) = line.split_once(':').ok_or_else(|| {
380 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
381 })?;
382
383 let name = HeaderName::from_str(name)
384 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
385
386 let mut value = value.trim_start().to_owned();
389 while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
390 value.push_str(v);
391 }
392 value.truncate(value.trim_end().len());
393
394 headers.append(name, value.into_header_value());
395 }
396
397 trace!(
398 subject = %subject,
399 sid = %sid,
400 reply = ?reply,
401 header_len = %header_len,
402 total_len = %total_len,
403 status = ?status,
404 description = ?description,
405 "read operation: HMSG"
406 );
407
408 return Ok(Some(ServerOp::Message {
409 length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
410 sid,
411 reply,
412 subject,
413 headers: Some(headers),
414 payload,
415 status,
416 description,
417 }));
418 }
419
420 let buffer = self.read_buf.split_to(len + 2);
421 let line = str::from_utf8(&buffer).map_err(|_| {
422 io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
423 })?;
424
425 trace!(line = %line, "read operation: unknown");
426 Err(io::Error::new(
427 io::ErrorKind::InvalidInput,
428 format!("invalid server operation: '{line}'"),
429 ))
430 }
431
432 pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
433 future::poll_fn(|cx| self.poll_read_op(cx))
434 }
435
436 pub(crate) fn poll_read_op(
440 &mut self,
441 cx: &mut Context<'_>,
442 ) -> Poll<io::Result<Option<ServerOp>>> {
443 loop {
444 if let Some(op) = self.try_read_op()? {
445 trace!(?op, "read operation completed");
446 return Poll::Ready(Ok(Some(op)));
447 }
448
449 let read_buf = self.stream.read_buf(&mut self.read_buf);
450 tokio::pin!(read_buf);
451 return match read_buf.poll(cx) {
452 Poll::Pending => {
453 trace!("read operation pending");
454 Poll::Pending
455 }
456 Poll::Ready(Ok(0)) if self.read_buf.is_empty() => {
457 trace!("read operation: empty buffer");
458 Poll::Ready(Ok(None))
459 }
460 Poll::Ready(Ok(0)) => {
461 trace!("read operation: connection reset");
462 Poll::Ready(Err(io::ErrorKind::ConnectionReset.into()))
463 }
464 Poll::Ready(Ok(n)) => {
465 self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
466 trace!(bytes = %n, "read operation: received bytes");
467 continue;
468 }
469 Poll::Ready(Err(err)) => {
470 trace!(error = %err, "read operation: error");
471 Poll::Ready(Err(err))
472 }
473 };
474 }
475 }
476
477 pub(crate) async fn easy_write_and_flush<'a>(
478 &mut self,
479 items: impl Iterator<Item = &'a ClientOp>,
480 ) -> io::Result<()> {
481 for item in items {
482 self.enqueue_write_op(item);
483 }
484
485 future::poll_fn(|cx| self.poll_write(cx)).await?;
486 future::poll_fn(|cx| self.poll_flush(cx)).await?;
487 Ok(())
488 }
489
490 pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
492 macro_rules! small_write {
493 ($dst:expr) => {
494 write!(self.small_write(), $dst).expect("do small write to Connection");
495 };
496 }
497
498 match item {
499 ClientOp::Connect(connect_info) => {
500 let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
501
502 self.write("CONNECT ");
503 self.write(json);
504 self.write("\r\n");
505 trace!(?connect_info, "write operation: CONNECT");
506 }
507 ClientOp::Publish {
508 subject,
509 payload,
510 respond,
511 headers,
512 } => {
513 let verb = match headers.as_ref() {
514 Some(headers) if !headers.is_empty() => "HPUB",
515 _ => "PUB",
516 };
517
518 small_write!("{verb} {subject} ");
519
520 if let Some(respond) = respond {
521 small_write!("{respond} ");
522 }
523
524 match headers {
525 Some(headers) if !headers.is_empty() => {
526 let headers = headers.to_bytes();
527
528 let headers_len = headers.len();
529 let total_len = headers_len + payload.len();
530 small_write!("{headers_len} {total_len}\r\n");
531 self.write(headers);
532 }
533 _ => {
534 let payload_len = payload.len();
535 small_write!("{payload_len}\r\n");
536 }
537 }
538
539 self.write(Bytes::clone(payload));
540 self.write("\r\n");
541
542 trace!(
543 verb = %verb,
544 subject = %subject,
545 reply = ?respond,
546 headers = ?headers,
547 payload_len = %payload.len(),
548 "write operation: PUB"
549 );
550 }
551
552 ClientOp::Subscribe {
553 sid,
554 subject,
555 queue_group,
556 } => {
557 match queue_group {
558 Some(queue_group) => {
559 small_write!("SUB {subject} {queue_group} {sid}\r\n");
560 }
561 None => {
562 small_write!("SUB {subject} {sid}\r\n");
563 }
564 }
565
566 trace!(
567 subject = %subject,
568 sid = %sid,
569 queue_group = ?queue_group,
570 "write operation: SUB"
571 );
572 }
573
574 ClientOp::Unsubscribe { sid, max } => {
575 match max {
576 Some(max) => {
577 small_write!("UNSUB {sid} {max}\r\n");
578 }
579 None => {
580 small_write!("UNSUB {sid}\r\n");
581 }
582 }
583
584 trace!(
585 sid = %sid,
586 max = ?max,
587 "write operation: UNSUB"
588 );
589 }
590 ClientOp::Ping => {
591 self.write("PING\r\n");
592 trace!("write operation: PING");
593 }
594 ClientOp::Pong => {
595 self.write("PONG\r\n");
596 trace!("write operation: PONG");
597 }
598 }
599 }
600
601 pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
614 if !self.stream.is_write_vectored() {
615 self.poll_write_sequential(cx)
616 } else {
617 self.poll_write_vectored(cx)
618 }
619 }
620
621 fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
625 loop {
626 let buf = match self.write_buf.front() {
627 Some(buf) => &**buf,
628 None if !self.flattened_writes.is_empty() => &self.flattened_writes,
629 None => return Poll::Ready(Ok(())),
630 };
631
632 debug_assert!(!buf.is_empty());
633
634 match Pin::new(&mut self.stream).poll_write(cx, buf) {
635 Poll::Pending => return Poll::Pending,
636 Poll::Ready(Ok(n)) => {
637 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
638 self.write_buf_len -= n;
639 self.can_flush = true;
640
641 match self.write_buf.front_mut() {
642 Some(buf) if n < buf.len() => {
643 buf.advance(n);
644 }
645 Some(_buf) => {
646 self.write_buf.pop_front();
647 }
648 None => {
649 self.flattened_writes.advance(n);
650 }
651 }
652 continue;
653 }
654 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
655 }
656 }
657 }
658 fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
663 'outer: loop {
664 let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
665 let mut writes_len = 0;
666
667 self.write_buf
668 .iter()
669 .take(WRITE_VECTORED_CHUNKS)
670 .enumerate()
671 .for_each(|(i, buf)| {
672 writes[i] = IoSlice::new(buf);
673 writes_len += 1;
674 });
675
676 if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
677 writes[writes_len] = IoSlice::new(&self.flattened_writes);
678 writes_len += 1;
679 }
680
681 if writes_len == 0 {
682 return Poll::Ready(Ok(()));
683 }
684
685 match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
686 Poll::Pending => return Poll::Pending,
687 Poll::Ready(Ok(mut n)) => {
688 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
689 self.write_buf_len -= n;
690 self.can_flush = true;
691
692 while let Some(buf) = self.write_buf.front_mut() {
693 if n < buf.len() {
694 buf.advance(n);
695 continue 'outer;
696 }
697
698 n -= buf.len();
699 self.write_buf.pop_front();
700 }
701
702 self.flattened_writes.advance(n);
703 }
704 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
705 }
706 }
707 }
708
709 fn write(&mut self, buf: impl Into<Bytes>) {
716 let buf = buf.into();
717 if buf.is_empty() {
718 return;
719 }
720
721 self.write_buf_len += buf.len();
722 if buf.len() < WRITE_FLATTEN_THRESHOLD {
723 self.flattened_writes.extend_from_slice(&buf);
724 } else {
725 if !self.flattened_writes.is_empty() {
726 let buf = self.flattened_writes.split().freeze();
727 self.write_buf.push_back(buf);
728 }
729
730 self.write_buf.push_back(buf);
731 }
732 }
733
734 fn small_write(&mut self) -> impl fmt::Write + '_ {
736 struct Writer<'a> {
737 this: &'a mut Connection,
738 }
739
740 impl fmt::Write for Writer<'_> {
741 fn write_str(&mut self, s: &str) -> fmt::Result {
742 self.this.write_buf_len += s.len();
743 self.this.flattened_writes.write_str(s)
744 }
745 }
746
747 Writer { this: self }
748 }
749
750 pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
754 match Pin::new(&mut self.stream).poll_flush(cx) {
755 Poll::Pending => Poll::Pending,
756 Poll::Ready(Ok(())) => {
757 self.can_flush = false;
758 Poll::Ready(Ok(()))
759 }
760 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
761 }
762 }
763}
764
/// Adapter exposing a `WebSocketStream` through the `AsyncRead` /
/// `AsyncWrite` byte-stream interfaces.
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// The wrapped WebSocket stream that messages are received from and sent to.
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Payload bytes of received messages not yet handed to the reader.
    pub(crate) read_buf: BytesMut,
}
772
#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        Self { inner, read_buf }
    }
}
782
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Copies buffered message payload into `buf`, receiving further
    /// WebSocket messages whenever the internal buffer is empty.
    ///
    /// A closed WebSocket stream is reported as an `UnexpectedEof` error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        // Keep receiving messages until there is at least one byte to hand out.
        while this.read_buf.is_empty() {
            match this.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    return Poll::Pending;
                }
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(std::io::Error::other(e)));
                }
                Poll::Ready(Some(Ok(message))) => {
                    this.read_buf.extend_from_slice(message.as_payload());
                }
            }
        }

        // Hand out as much as fits; the rest stays buffered for the next call.
        let len = buf.remaining().min(this.read_buf.len());
        let chunk = this.read_buf.split_to(len);
        buf.put_slice(&chunk);
        Poll::Ready(Ok(()))
    }
}
823
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends `buf` as a single binary WebSocket message once the sink is
    /// ready, and reports the whole buffer as written.
    ///
    /// Sink errors are wrapped with [`std::io::Error::other`].
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Ready(Ok(())) => {
                // Copy the bytes only now that the sink can accept a message,
                // so a `Pending` or failed readiness check costs no allocation.
                let message = tokio_websockets::Message::binary(buf.to_vec());
                match this.inner.start_send_unpin(message) {
                    Ok(()) => Poll::Ready(Ok(buf.len())),
                    Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
                }
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e))),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Flushes the WebSocket sink, wrapping any error in an `io::Error`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(std::io::Error::other)
    }

    /// Closes the WebSocket sink, wrapping any error in an `io::Error`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(std::io::Error::other)
    }
}
864
865#[cfg(test)]
866mod read_op {
867 use std::sync::Arc;
868
869 use super::Connection;
870 use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
871 use tokio::io::{self, AsyncWriteExt};
872
    /// A `+OK` line is decoded as `ServerOp::Ok`.
    #[tokio::test]
    async fn ok() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"+OK\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ok));
    }
882
    /// A `PING` line is decoded as `ServerOp::Ping`.
    #[tokio::test]
    async fn ping() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"PING\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ping));
    }
892
    /// A `PONG` line is decoded as `ServerOp::Pong`.
    #[tokio::test]
    async fn pong() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"PONG\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Pong));
    }
902
    /// `INFO` lines are decoded into `ServerInfo`: an empty JSON object gives
    /// the default value, and present fields (`version`, `cluster`) are filled in.
    #[tokio::test]
    async fn info() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                ..Default::default()
            })))
        );

        server
            .write_all(b"INFO { \"version\": \"1.0.0\", \"cluster\": \"test-cluster\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                cluster: Some("test-cluster".into()),
                ..Default::default()
            })))
        );
    }
945
    /// After an initial `INFO`, an `-ERR <text>` line is decoded as
    /// `ServerOp::Error` carrying the text as `ServerError::Other`.
    #[tokio::test]
    async fn error() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"INFO {}\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }
967
    /// Decodes a sequence of messages on one connection: `MSG` without and
    /// with a reply subject, `HMSG` with a header, `HMSG` with an inline
    /// status and description, and finally a plain `MSG` again.
    #[tokio::test]
    async fn message() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        // MSG without a reply subject.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        // MSG with a reply subject.
        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // HMSG delivered in several writes: 23 header bytes, 34 bytes in total.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: X\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();

        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34
            })
        );

        // Same framing with a different header value.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: Y\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // Header-only HMSG (header length == total length) with an inline status.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        // A plain MSG still decodes after the HMSG frames.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }
1114
    /// Unknown operations yield an error, and the connection keeps decoding
    /// the valid operations that follow them.
    #[tokio::test]
    async fn unknown() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"ONE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"TWO\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PING\r\n").await.unwrap();
        connection.read_op().await.unwrap();

        server.write_all(b"THREE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        // A valid HMSG in between unknown lines is still decoded correctly.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server.write_all(b"FOUR\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PONG\r\n").await.unwrap();
        connection.read_op().await.unwrap();
    }
1164
    /// A `MSG` control line containing a non-UTF-8 byte in the subject is
    /// rejected with `InvalidInput`.
    #[tokio::test]
    async fn msg_with_non_utf8_subject_returns_error() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"MSG balance.\xFF 1 2\r\n").await.unwrap();
        server.write_all(b"hi\r\n").await.unwrap();

        let err = connection.read_op().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
1181
    /// An `HMSG` control line containing a non-UTF-8 byte in the subject is
    /// rejected with `InvalidInput`.
    #[tokio::test]
    async fn hmsg_with_non_utf8_subject_returns_error() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server
            .write_all(b"HMSG balance.\xFF 1 12 14\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n\r\nhi\r\n").await.unwrap();

        let err = connection.read_op().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
1196}
1197
1198#[cfg(test)]
1199mod write_op {
1200 use std::sync::Arc;
1201
1202 use super::Connection;
1203 use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
1204 use tokio::io::{self, AsyncBufReadExt, BufReader};
1205
    /// `ClientOp::Publish` is encoded as `PUB` (without and with a reply
    /// subject) and as `HPUB` when headers are present.
    #[tokio::test]
    async fn publish() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        // Only the control line and the header block (four lines) are read
        // back here; the payload that follows is not asserted.
        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\n"
        );
    }
1274
    /// `ClientOp::Subscribe` is encoded as `SUB <subject> [<queue group>] <sid>`.
    #[tokio::test]
    async fn subscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR QUEUE.GROUP 11\r\n");
    }
1313
    /// `ClientOp::Unsubscribe` is encoded as `UNSUB <sid> [<max>]`.
    #[tokio::test]
    async fn unsubscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11 2\r\n");
    }
1344
    /// `ClientOp::Ping` is encoded as `PING\r\n`.
    #[tokio::test]
    async fn ping() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PING\r\n");
    }
1362
    /// `ClientOp::Pong` is encoded as `PONG\r\n`.
    #[tokio::test]
    async fn pong() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PONG\r\n");
    }
1380
    /// `ClientOp::Connect` is encoded as `CONNECT ` followed by the JSON form
    /// of `ConnectInfo` and a CRLF.
    #[tokio::test]
    async fn connect() {
        // Larger duplex buffer than the other tests: the CONNECT line is long.
        let (stream, server) = io::duplex(1024);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
1420}