1use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28 futures_util::{SinkExt, StreamExt},
29 pin_project::pin_project,
30 tokio::io::ReadBuf,
31 tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36use tracing::trace;
37
38use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
39use crate::status::StatusCode;
40use crate::subject::Subject;
41use crate::{ClientOp, ServerError, ServerOp, Statistics};
42
/// Number of queued-but-unwritten bytes at which the write buffer is reported as full.
const SOFT_WRITE_BUF_LIMIT: usize = u16::MAX as usize;
/// Writes shorter than this many bytes are copied into a shared buffer instead of being
/// queued as separate chunks.
const WRITE_FLATTEN_THRESHOLD: usize = 4 * 1024;
/// Maximum number of buffers handed to a single vectored write.
const WRITE_VECTORED_CHUNKS: usize = 64;
51
52pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
54
55impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
57
/// Connection state as observed by the client.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Not yet connected.
    Pending,
    /// Connected to a server.
    Connected,
    /// The connection was lost or closed.
    Disconnected,
}
65
/// Answer returned by `Connection::should_flush`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShouldFlush {
    /// Bytes were written since the last flush and nothing else is queued.
    Yes,
    /// Bytes were written since the last flush, but more data is still queued.
    May,
    /// Nothing has been written since the last successful flush.
    No,
}
75
76impl Display for State {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 State::Pending => write!(f, "pending"),
80 State::Connected => write!(f, "connected"),
81 State::Disconnected => write!(f, "disconnected"),
82 }
83 }
84}
85
86pub(crate) struct Connection {
88 pub(crate) stream: Box<dyn AsyncReadWrite>,
89 read_buf: BytesMut,
90 write_buf: VecDeque<Bytes>,
91 write_buf_len: usize,
92 flattened_writes: BytesMut,
93 can_flush: bool,
94 statistics: Arc<Statistics>,
95}
96
97impl Connection {
100 pub(crate) fn new(
101 stream: Box<dyn AsyncReadWrite>,
102 read_buffer_capacity: usize,
103 statistics: Arc<Statistics>,
104 ) -> Self {
105 Self {
106 stream,
107 read_buf: BytesMut::with_capacity(read_buffer_capacity),
108 write_buf: VecDeque::new(),
109 write_buf_len: 0,
110 flattened_writes: BytesMut::new(),
111 can_flush: false,
112 statistics,
113 }
114 }
115
116 pub(crate) fn is_write_buf_full(&self) -> bool {
118 self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
119 }
120
121 pub(crate) fn should_flush(&self) -> ShouldFlush {
123 match (
124 self.can_flush,
125 self.write_buf.is_empty() && self.flattened_writes.is_empty(),
126 ) {
127 (true, true) => ShouldFlush::Yes,
128 (true, false) => ShouldFlush::May,
129 (false, _) => ShouldFlush::No,
130 }
131 }
132
133 pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
136 let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
137 Some(len) => len,
138 None => return Ok(None),
139 };
140
141 if self.read_buf.starts_with(b"+OK") {
142 self.read_buf.advance(len + 2);
143 trace!("read operation: OK");
144 return Ok(Some(ServerOp::Ok));
145 }
146
147 if self.read_buf.starts_with(b"PING") {
148 self.read_buf.advance(len + 2);
149 trace!("read operation: PING");
150 return Ok(Some(ServerOp::Ping));
151 }
152
153 if self.read_buf.starts_with(b"PONG") {
154 self.read_buf.advance(len + 2);
155 trace!("read operation: PONG");
156 return Ok(Some(ServerOp::Pong));
157 }
158
159 if self.read_buf.starts_with(b"-ERR") {
160 let description = str::from_utf8(&self.read_buf[5..len])
161 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
162 .trim_matches('\'')
163 .to_owned();
164
165 self.read_buf.advance(len + 2);
166 trace!(error = %description, "read operation: ERR");
167 return Ok(Some(ServerOp::Error(ServerError::new(description))));
168 }
169
170 if self.read_buf.starts_with(b"INFO ") {
171 let info = serde_json::from_slice(&self.read_buf[4..len])
172 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
173
174 self.read_buf.advance(len + 2);
175 trace!(?info, "read operation: INFO");
176 return Ok(Some(ServerOp::Info(Box::new(info))));
177 }
178
179 if self.read_buf.starts_with(b"MSG ") {
180 let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
181 let mut args = line.split(' ').filter(|s| !s.is_empty());
182
183 let (subject, sid, reply_to, payload_len) = match (
185 args.next(),
186 args.next(),
187 args.next(),
188 args.next(),
189 args.next(),
190 ) {
191 (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
192 (subject, sid, Some(reply_to), payload_len)
193 }
194 (Some(subject), Some(sid), Some(payload_len), None, None) => {
195 (subject, sid, None, payload_len)
196 }
197 _ => {
198 return Err(io::Error::new(
199 io::ErrorKind::InvalidInput,
200 "invalid number of arguments after MSG",
201 ))
202 }
203 };
204
205 let sid = sid
206 .parse::<u64>()
207 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
208
209 let payload_len = payload_len
211 .parse::<usize>()
212 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
213
214 if len + payload_len + 4 > self.read_buf.remaining() {
217 return Ok(None);
218 }
219
220 let length = payload_len
221 + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
222 + subject.len();
223
224 let subject = Subject::from(subject);
225 let reply = reply_to.map(Subject::from);
226
227 self.read_buf.advance(len + 2);
228 let payload = self.read_buf.split_to(payload_len).freeze();
229 self.read_buf.advance(2);
230
231 trace!(
232 subject = %subject,
233 sid = %sid,
234 reply = ?reply,
235 payload_len = %payload_len,
236 "read operation: MSG"
237 );
238
239 return Ok(Some(ServerOp::Message {
240 sid,
241 length,
242 reply,
243 headers: None,
244 subject,
245 payload,
246 status: None,
247 description: None,
248 }));
249 }
250
251 if self.read_buf.starts_with(b"HMSG ") {
252 let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
254 let mut args = line.split_whitespace().filter(|s| !s.is_empty());
255
256 let (subject, sid, reply_to, header_len, total_len) = match (
258 args.next(),
259 args.next(),
260 args.next(),
261 args.next(),
262 args.next(),
263 args.next(),
264 ) {
265 (
266 Some(subject),
267 Some(sid),
268 Some(reply_to),
269 Some(header_len),
270 Some(total_len),
271 None,
272 ) => (subject, sid, Some(reply_to), header_len, total_len),
273 (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
274 (subject, sid, None, header_len, total_len)
275 }
276 _ => {
277 return Err(io::Error::new(
278 io::ErrorKind::InvalidInput,
279 "invalid number of arguments after HMSG",
280 ))
281 }
282 };
283
284 let subject = Subject::from(subject);
286
287 let sid = sid.parse::<u64>().map_err(|_| {
289 io::Error::new(
290 io::ErrorKind::InvalidInput,
291 "cannot parse sid argument after HMSG",
292 )
293 })?;
294
295 let reply = reply_to.map(Subject::from);
297
298 let header_len = header_len.parse::<usize>().map_err(|_| {
300 io::Error::new(
301 io::ErrorKind::InvalidInput,
302 "cannot parse the number of header bytes argument after \
303 HMSG",
304 )
305 })?;
306
307 let total_len = total_len.parse::<usize>().map_err(|_| {
309 io::Error::new(
310 io::ErrorKind::InvalidInput,
311 "cannot parse the number of bytes argument after HMSG",
312 )
313 })?;
314
315 if total_len < header_len {
316 return Err(io::Error::new(
317 io::ErrorKind::InvalidInput,
318 "number of header bytes was greater than or equal to the \
319 total number of bytes after HMSG",
320 ));
321 }
322
323 if len + total_len + 4 > self.read_buf.remaining() {
324 return Ok(None);
325 }
326
327 self.read_buf.advance(len + 2);
328 let header = self.read_buf.split_to(header_len);
329 let payload = self.read_buf.split_to(total_len - header_len).freeze();
330 self.read_buf.advance(2);
331
332 let mut lines = std::str::from_utf8(&header)
333 .map_err(|_| {
334 io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
335 })?
336 .lines()
337 .peekable();
338 let version_line = lines.next().ok_or_else(|| {
339 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
340 })?;
341
342 let version_line_suffix = version_line
343 .strip_prefix("NATS/1.0")
344 .map(str::trim)
345 .ok_or_else(|| {
346 io::Error::new(
347 io::ErrorKind::InvalidInput,
348 "header version line does not begin with `NATS/1.0`",
349 )
350 })?;
351
352 let (status, description) = version_line_suffix
353 .split_once(' ')
354 .map(|(status, description)| (status.trim(), description.trim()))
355 .unwrap_or((version_line_suffix, ""));
356 let status = if !status.is_empty() {
357 Some(
358 status
359 .parse::<StatusCode>()
360 .map_err(|_| std::io::Error::other("could not parse status parameter"))?,
361 )
362 } else {
363 None
364 };
365 let description = if !description.is_empty() {
366 Some(description.to_owned())
367 } else {
368 None
369 };
370
371 let mut headers = HeaderMap::new();
372 while let Some(line) = lines.next() {
373 if line.is_empty() {
374 continue;
375 }
376
377 let (name, value) = line.split_once(':').ok_or_else(|| {
378 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
379 })?;
380
381 let name = HeaderName::from_str(name)
382 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
383
384 let mut value = value.trim_start().to_owned();
387 while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
388 value.push_str(v);
389 }
390 value.truncate(value.trim_end().len());
391
392 headers.append(name, value.into_header_value());
393 }
394
395 trace!(
396 subject = %subject,
397 sid = %sid,
398 reply = ?reply,
399 header_len = %header_len,
400 total_len = %total_len,
401 status = ?status,
402 description = ?description,
403 "read operation: HMSG"
404 );
405
406 return Ok(Some(ServerOp::Message {
407 length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
408 sid,
409 reply,
410 subject,
411 headers: Some(headers),
412 payload,
413 status,
414 description,
415 }));
416 }
417
418 let buffer = self.read_buf.split_to(len + 2);
419 let line = str::from_utf8(&buffer).map_err(|_| {
420 io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
421 })?;
422
423 trace!(line = %line, "read operation: unknown");
424 Err(io::Error::new(
425 io::ErrorKind::InvalidInput,
426 format!("invalid server operation: '{line}'"),
427 ))
428 }
429
430 pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
431 future::poll_fn(|cx| self.poll_read_op(cx))
432 }
433
434 pub(crate) fn poll_read_op(
438 &mut self,
439 cx: &mut Context<'_>,
440 ) -> Poll<io::Result<Option<ServerOp>>> {
441 loop {
442 if let Some(op) = self.try_read_op()? {
443 trace!(?op, "read operation completed");
444 return Poll::Ready(Ok(Some(op)));
445 }
446
447 let read_buf = self.stream.read_buf(&mut self.read_buf);
448 tokio::pin!(read_buf);
449 return match read_buf.poll(cx) {
450 Poll::Pending => {
451 trace!("read operation pending");
452 Poll::Pending
453 }
454 Poll::Ready(Ok(0)) if self.read_buf.is_empty() => {
455 trace!("read operation: empty buffer");
456 Poll::Ready(Ok(None))
457 }
458 Poll::Ready(Ok(0)) => {
459 trace!("read operation: connection reset");
460 Poll::Ready(Err(io::ErrorKind::ConnectionReset.into()))
461 }
462 Poll::Ready(Ok(n)) => {
463 self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
464 trace!(bytes = %n, "read operation: received bytes");
465 continue;
466 }
467 Poll::Ready(Err(err)) => {
468 trace!(error = %err, "read operation: error");
469 Poll::Ready(Err(err))
470 }
471 };
472 }
473 }
474
475 pub(crate) async fn easy_write_and_flush<'a>(
476 &mut self,
477 items: impl Iterator<Item = &'a ClientOp>,
478 ) -> io::Result<()> {
479 for item in items {
480 self.enqueue_write_op(item);
481 }
482
483 future::poll_fn(|cx| self.poll_write(cx)).await?;
484 future::poll_fn(|cx| self.poll_flush(cx)).await?;
485 Ok(())
486 }
487
488 pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
490 macro_rules! small_write {
491 ($dst:expr) => {
492 write!(self.small_write(), $dst).expect("do small write to Connection");
493 };
494 }
495
496 match item {
497 ClientOp::Connect(connect_info) => {
498 let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
499
500 self.write("CONNECT ");
501 self.write(json);
502 self.write("\r\n");
503 trace!(?connect_info, "write operation: CONNECT");
504 }
505 ClientOp::Publish {
506 subject,
507 payload,
508 respond,
509 headers,
510 } => {
511 let verb = match headers.as_ref() {
512 Some(headers) if !headers.is_empty() => "HPUB",
513 _ => "PUB",
514 };
515
516 small_write!("{verb} {subject} ");
517
518 if let Some(respond) = respond {
519 small_write!("{respond} ");
520 }
521
522 match headers {
523 Some(headers) if !headers.is_empty() => {
524 let headers = headers.to_bytes();
525
526 let headers_len = headers.len();
527 let total_len = headers_len + payload.len();
528 small_write!("{headers_len} {total_len}\r\n");
529 self.write(headers);
530 }
531 _ => {
532 let payload_len = payload.len();
533 small_write!("{payload_len}\r\n");
534 }
535 }
536
537 self.write(Bytes::clone(payload));
538 self.write("\r\n");
539
540 trace!(
541 verb = %verb,
542 subject = %subject,
543 reply = ?respond,
544 headers = ?headers,
545 payload_len = %payload.len(),
546 "write operation: PUB"
547 );
548 }
549
550 ClientOp::Subscribe {
551 sid,
552 subject,
553 queue_group,
554 } => {
555 match queue_group {
556 Some(queue_group) => {
557 small_write!("SUB {subject} {queue_group} {sid}\r\n");
558 }
559 None => {
560 small_write!("SUB {subject} {sid}\r\n");
561 }
562 }
563
564 trace!(
565 subject = %subject,
566 sid = %sid,
567 queue_group = ?queue_group,
568 "write operation: SUB"
569 );
570 }
571
572 ClientOp::Unsubscribe { sid, max } => {
573 match max {
574 Some(max) => {
575 small_write!("UNSUB {sid} {max}\r\n");
576 }
577 None => {
578 small_write!("UNSUB {sid}\r\n");
579 }
580 }
581
582 trace!(
583 sid = %sid,
584 max = ?max,
585 "write operation: UNSUB"
586 );
587 }
588 ClientOp::Ping => {
589 self.write("PING\r\n");
590 trace!("write operation: PING");
591 }
592 ClientOp::Pong => {
593 self.write("PONG\r\n");
594 trace!("write operation: PONG");
595 }
596 }
597 }
598
599 pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
612 if !self.stream.is_write_vectored() {
613 self.poll_write_sequential(cx)
614 } else {
615 self.poll_write_vectored(cx)
616 }
617 }
618
619 fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
623 loop {
624 let buf = match self.write_buf.front() {
625 Some(buf) => &**buf,
626 None if !self.flattened_writes.is_empty() => &self.flattened_writes,
627 None => return Poll::Ready(Ok(())),
628 };
629
630 debug_assert!(!buf.is_empty());
631
632 match Pin::new(&mut self.stream).poll_write(cx, buf) {
633 Poll::Pending => return Poll::Pending,
634 Poll::Ready(Ok(n)) => {
635 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
636 self.write_buf_len -= n;
637 self.can_flush = true;
638
639 match self.write_buf.front_mut() {
640 Some(buf) if n < buf.len() => {
641 buf.advance(n);
642 }
643 Some(_buf) => {
644 self.write_buf.pop_front();
645 }
646 None => {
647 self.flattened_writes.advance(n);
648 }
649 }
650 continue;
651 }
652 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
653 }
654 }
655 }
656 fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
661 'outer: loop {
662 let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
663 let mut writes_len = 0;
664
665 self.write_buf
666 .iter()
667 .take(WRITE_VECTORED_CHUNKS)
668 .enumerate()
669 .for_each(|(i, buf)| {
670 writes[i] = IoSlice::new(buf);
671 writes_len += 1;
672 });
673
674 if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
675 writes[writes_len] = IoSlice::new(&self.flattened_writes);
676 writes_len += 1;
677 }
678
679 if writes_len == 0 {
680 return Poll::Ready(Ok(()));
681 }
682
683 match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
684 Poll::Pending => return Poll::Pending,
685 Poll::Ready(Ok(mut n)) => {
686 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
687 self.write_buf_len -= n;
688 self.can_flush = true;
689
690 while let Some(buf) = self.write_buf.front_mut() {
691 if n < buf.len() {
692 buf.advance(n);
693 continue 'outer;
694 }
695
696 n -= buf.len();
697 self.write_buf.pop_front();
698 }
699
700 self.flattened_writes.advance(n);
701 }
702 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
703 }
704 }
705 }
706
707 fn write(&mut self, buf: impl Into<Bytes>) {
714 let buf = buf.into();
715 if buf.is_empty() {
716 return;
717 }
718
719 self.write_buf_len += buf.len();
720 if buf.len() < WRITE_FLATTEN_THRESHOLD {
721 self.flattened_writes.extend_from_slice(&buf);
722 } else {
723 if !self.flattened_writes.is_empty() {
724 let buf = self.flattened_writes.split().freeze();
725 self.write_buf.push_back(buf);
726 }
727
728 self.write_buf.push_back(buf);
729 }
730 }
731
732 fn small_write(&mut self) -> impl fmt::Write + '_ {
734 struct Writer<'a> {
735 this: &'a mut Connection,
736 }
737
738 impl fmt::Write for Writer<'_> {
739 fn write_str(&mut self, s: &str) -> fmt::Result {
740 self.this.write_buf_len += s.len();
741 self.this.flattened_writes.write_str(s)
742 }
743 }
744
745 Writer { this: self }
746 }
747
748 pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
752 match Pin::new(&mut self.stream).poll_flush(cx) {
753 Poll::Pending => Poll::Pending,
754 Poll::Ready(Ok(())) => {
755 self.can_flush = false;
756 Poll::Ready(Ok(()))
757 }
758 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
759 }
760 }
761}
762
/// Adapts a `WebSocketStream` to tokio's byte-oriented `AsyncRead`/`AsyncWrite` traits.
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// The wrapped WebSocket stream; structurally pinned so it can be polled via projection.
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Payload bytes from received frames that have not been handed to a reader yet.
    pub(crate) read_buf: BytesMut,
}
770
#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        Self { inner, read_buf }
    }
}
780
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Serves bytes from the payloads of received WebSocket messages.
    ///
    /// Messages are pulled from the stream until some payload is buffered. A closed stream is
    /// reported as an `UnexpectedEof` error, and stream errors are wrapped as `Other` errors.
    ///
    /// NOTE(review): every message's payload is appended, whatever its kind. Confirm that
    /// control frames (close/ping/pong) can never reach this point with a non-empty payload.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        while this.read_buf.is_empty() {
            match this.inner.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )))
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(std::io::Error::other(e))),
                Poll::Ready(Some(Ok(message))) => {
                    this.read_buf.extend_from_slice(message.as_payload());
                }
            }
        }

        // Hand over as much buffered payload as fits into the caller's buffer.
        let take = buf.remaining().min(this.read_buf.len());
        let chunk = this.read_buf.split_to(take);
        buf.put_slice(&chunk);
        Poll::Ready(Ok(()))
    }
}
821
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends all of `buf` as a single binary WebSocket message.
    ///
    /// The bytes are copied only once the sink is ready to accept a message, so a
    /// `Poll::Pending` result no longer pays for an allocation that is then thrown away.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e))),
            Poll::Ready(Ok(())) => {
                let message = tokio_websockets::Message::binary(buf.to_vec());
                Poll::Ready(
                    this.inner
                        .start_send_unpin(message)
                        .map(|()| buf.len())
                        .map_err(std::io::Error::other),
                )
            }
        }
    }

    /// Flushes messages pending in the WebSocket sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(std::io::Error::other)
    }

    /// Closes the WebSocket sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(std::io::Error::other)
    }
}
862
// Parser tests. Each case writes raw protocol bytes into the server half of an in-memory
// duplex stream and checks the `ServerOp` produced by `Connection::read_op`. Several cases
// reuse one connection, so they also check that the read buffer stays in sync between ops.
#[cfg(test)]
mod read_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
    use tokio::io::{self, AsyncWriteExt};

    #[tokio::test]
    async fn ok() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"+OK\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"PING\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"PONG\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Pong));
    }

    // INFO payloads are JSON; missing fields fall back to `ServerInfo::default()`.
    #[tokio::test]
    async fn info() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                ..Default::default()
            })))
        );

        server
            .write_all(b"INFO { \"version\": \"1.0.0\", \"cluster\": \"test-cluster\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                cluster: Some("test-cluster".into()),
                ..Default::default()
            })))
        );
    }

    #[tokio::test]
    async fn error() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"INFO {}\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }

    // MSG and HMSG, with and without reply subjects. `length` is the subject length plus the
    // reply length plus the payload/total byte count.
    #[tokio::test]
    async fn message() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // The HMSG frame is delivered in several writes to exercise partial buffering.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: X\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();

        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34
            })
        );

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: Y\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // Status line with a description and an empty payload.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    // Unknown operations are reported as errors but consumed, so later valid operations on the
    // same connection still parse.
    #[tokio::test]
    async fn unknown() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"ONE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"TWO\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PING\r\n").await.unwrap();
        connection.read_op().await.unwrap();

        server.write_all(b"THREE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server.write_all(b"FOUR\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PONG\r\n").await.unwrap();
        connection.read_op().await.unwrap();
    }
}
1163
// Serializer tests. Each case enqueues `ClientOp`s, writes and flushes them, then reads the
// exact protocol text back from the server half of an in-memory duplex stream.
#[cfg(test)]
mod write_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
    use tokio::io::{self, AsyncBufReadExt, BufReader};

    // PUB without and with a reply subject, then HPUB when headers are present. For HPUB only
    // the control line and the header block are read back and checked.
    #[tokio::test]
    async fn publish() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\n"
        );
    }

    // SUB without and with a queue group; the sid is always the last argument.
    #[tokio::test]
    async fn subscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR QUEUE.GROUP 11\r\n");
    }

    // UNSUB without and with a max-messages argument.
    #[tokio::test]
    async fn unsubscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11 2\r\n");
    }

    #[tokio::test]
    async fn ping() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PING\r\n");
    }

    #[tokio::test]
    async fn pong() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PONG\r\n");
    }

    // CONNECT carries the JSON-serialized `ConnectInfo`. A larger duplex buffer is used
    // because the line is longer than 128 bytes.
    #[tokio::test]
    async fn connect() {
        let (stream, server) = io::duplex(1024);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}