1use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28 futures::{SinkExt, StreamExt},
29 pin_project::pin_project,
30 tokio::io::ReadBuf,
31 tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36use tracing::trace;
37
38use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
39use crate::status::StatusCode;
40use crate::subject::Subject;
41use crate::{ClientOp, ServerError, ServerOp, Statistics};
42
/// Number of queued outgoing bytes at which `Connection::is_write_buf_full` starts
/// returning `true`. The limit is soft: `Connection::write` does not check it.
const SOFT_WRITE_BUF_LIMIT: usize = 65535;
/// Chunks shorter than this many bytes are copied into the flattened write buffer
/// instead of being queued as a separate `Bytes` in the write queue.
const WRITE_FLATTEN_THRESHOLD: usize = 4096;
/// Maximum number of buffers passed to a single vectored write.
const WRITE_VECTORED_CHUNKS: usize = 64;
51
/// Combined bound for the transports a `Connection` runs over: a readable and writable
/// stream that is also `Send + Unpin`, usable as `Box<dyn AsyncReadWrite>`.
pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}

// Blanket implementation: every type meeting the bounds is an `AsyncReadWrite`.
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
57
/// State of a connection; its `Display` implementation renders the variant name in
/// lowercase (`pending`, `connected`, `disconnected`).
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum State {
    Pending,
    Connected,
    Disconnected,
}
65
/// Answer returned by `Connection::should_flush`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum ShouldFlush {
    /// Data has been written to the stream since the last successful flush and no
    /// further output is queued.
    Yes,
    /// Data has been written to the stream since the last successful flush, but more
    /// output is still queued.
    May,
    /// Nothing has been written to the stream since the last successful flush (or
    /// since the connection was created).
    No,
}
75
76impl Display for State {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 State::Pending => write!(f, "pending"),
80 State::Connected => write!(f, "connected"),
81 State::Disconnected => write!(f, "disconnected"),
82 }
83 }
84}
85
/// A NATS protocol connection over an arbitrary byte stream.
///
/// Incoming bytes accumulate in `read_buf` until `try_read_op` can parse a complete
/// server operation. Outgoing client operations are queued by `enqueue_write_op` and
/// drained into the stream by `poll_write`.
pub(crate) struct Connection {
    /// Underlying transport.
    pub(crate) stream: Box<dyn AsyncReadWrite>,
    /// Bytes received from the stream that have not been parsed yet.
    read_buf: BytesMut,
    /// Queued outgoing chunks; written in order, before `flattened_writes`.
    write_buf: VecDeque<Bytes>,
    /// Number of queued outgoing bytes across `write_buf` and `flattened_writes`.
    write_buf_len: usize,
    /// Small writes coalesced into one contiguous buffer; its bytes logically follow
    /// every chunk currently in `write_buf`.
    flattened_writes: BytesMut,
    /// Set after bytes are written to `stream`, cleared after a successful flush.
    can_flush: bool,
    /// Shared counters updated with the number of bytes read and written.
    statistics: Arc<Statistics>,
}
96
97impl Connection {
100 pub(crate) fn new(
101 stream: Box<dyn AsyncReadWrite>,
102 read_buffer_capacity: usize,
103 statistics: Arc<Statistics>,
104 ) -> Self {
105 Self {
106 stream,
107 read_buf: BytesMut::with_capacity(read_buffer_capacity),
108 write_buf: VecDeque::new(),
109 write_buf_len: 0,
110 flattened_writes: BytesMut::new(),
111 can_flush: false,
112 statistics,
113 }
114 }
115
116 pub(crate) fn is_write_buf_full(&self) -> bool {
118 self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
119 }
120
121 pub(crate) fn should_flush(&self) -> ShouldFlush {
123 match (
124 self.can_flush,
125 self.write_buf.is_empty() && self.flattened_writes.is_empty(),
126 ) {
127 (true, true) => ShouldFlush::Yes,
128 (true, false) => ShouldFlush::May,
129 (false, _) => ShouldFlush::No,
130 }
131 }
132
133 pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
136 let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
137 Some(len) => len,
138 None => return Ok(None),
139 };
140
141 if self.read_buf.starts_with(b"+OK") {
142 self.read_buf.advance(len + 2);
143 trace!("read operation: OK");
144 return Ok(Some(ServerOp::Ok));
145 }
146
147 if self.read_buf.starts_with(b"PING") {
148 self.read_buf.advance(len + 2);
149 trace!("read operation: PING");
150 return Ok(Some(ServerOp::Ping));
151 }
152
153 if self.read_buf.starts_with(b"PONG") {
154 self.read_buf.advance(len + 2);
155 trace!("read operation: PONG");
156 return Ok(Some(ServerOp::Pong));
157 }
158
159 if self.read_buf.starts_with(b"-ERR") {
160 let description = str::from_utf8(&self.read_buf[5..len])
161 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
162 .trim_matches('\'')
163 .to_owned();
164
165 self.read_buf.advance(len + 2);
166 trace!(error = %description, "read operation: ERR");
167 return Ok(Some(ServerOp::Error(ServerError::new(description))));
168 }
169
170 if self.read_buf.starts_with(b"INFO ") {
171 let info = serde_json::from_slice(&self.read_buf[4..len])
172 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
173
174 self.read_buf.advance(len + 2);
175 trace!(?info, "read operation: INFO");
176 return Ok(Some(ServerOp::Info(Box::new(info))));
177 }
178
179 if self.read_buf.starts_with(b"MSG ") {
180 let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
181 let mut args = line.split(' ').filter(|s| !s.is_empty());
182
183 let (subject, sid, reply_to, payload_len) = match (
185 args.next(),
186 args.next(),
187 args.next(),
188 args.next(),
189 args.next(),
190 ) {
191 (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
192 (subject, sid, Some(reply_to), payload_len)
193 }
194 (Some(subject), Some(sid), Some(payload_len), None, None) => {
195 (subject, sid, None, payload_len)
196 }
197 _ => {
198 return Err(io::Error::new(
199 io::ErrorKind::InvalidInput,
200 "invalid number of arguments after MSG",
201 ))
202 }
203 };
204
205 let sid = sid
206 .parse::<u64>()
207 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
208
209 let payload_len = payload_len
211 .parse::<usize>()
212 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
213
214 if len + payload_len + 4 > self.read_buf.remaining() {
217 return Ok(None);
218 }
219
220 let length = payload_len
221 + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
222 + subject.len();
223
224 let subject = Subject::from(subject);
225 let reply = reply_to.map(Subject::from);
226
227 self.read_buf.advance(len + 2);
228 let payload = self.read_buf.split_to(payload_len).freeze();
229 self.read_buf.advance(2);
230
231 trace!(
232 subject = %subject,
233 sid = %sid,
234 reply = ?reply,
235 payload_len = %payload_len,
236 "read operation: MSG"
237 );
238
239 return Ok(Some(ServerOp::Message {
240 sid,
241 length,
242 reply,
243 headers: None,
244 subject,
245 payload,
246 status: None,
247 description: None,
248 }));
249 }
250
251 if self.read_buf.starts_with(b"HMSG ") {
252 let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
254 let mut args = line.split_whitespace().filter(|s| !s.is_empty());
255
256 let (subject, sid, reply_to, header_len, total_len) = match (
258 args.next(),
259 args.next(),
260 args.next(),
261 args.next(),
262 args.next(),
263 args.next(),
264 ) {
265 (
266 Some(subject),
267 Some(sid),
268 Some(reply_to),
269 Some(header_len),
270 Some(total_len),
271 None,
272 ) => (subject, sid, Some(reply_to), header_len, total_len),
273 (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
274 (subject, sid, None, header_len, total_len)
275 }
276 _ => {
277 return Err(io::Error::new(
278 io::ErrorKind::InvalidInput,
279 "invalid number of arguments after HMSG",
280 ))
281 }
282 };
283
284 let subject = Subject::from(subject);
286
287 let sid = sid.parse::<u64>().map_err(|_| {
289 io::Error::new(
290 io::ErrorKind::InvalidInput,
291 "cannot parse sid argument after HMSG",
292 )
293 })?;
294
295 let reply = reply_to.map(Subject::from);
297
298 let header_len = header_len.parse::<usize>().map_err(|_| {
300 io::Error::new(
301 io::ErrorKind::InvalidInput,
302 "cannot parse the number of header bytes argument after \
303 HMSG",
304 )
305 })?;
306
307 let total_len = total_len.parse::<usize>().map_err(|_| {
309 io::Error::new(
310 io::ErrorKind::InvalidInput,
311 "cannot parse the number of bytes argument after HMSG",
312 )
313 })?;
314
315 if total_len < header_len {
316 return Err(io::Error::new(
317 io::ErrorKind::InvalidInput,
318 "number of header bytes was greater than or equal to the \
319 total number of bytes after HMSG",
320 ));
321 }
322
323 if len + total_len + 4 > self.read_buf.remaining() {
324 return Ok(None);
325 }
326
327 self.read_buf.advance(len + 2);
328 let header = self.read_buf.split_to(header_len);
329 let payload = self.read_buf.split_to(total_len - header_len).freeze();
330 self.read_buf.advance(2);
331
332 let mut lines = std::str::from_utf8(&header)
333 .map_err(|_| {
334 io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
335 })?
336 .lines()
337 .peekable();
338 let version_line = lines.next().ok_or_else(|| {
339 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
340 })?;
341
342 let version_line_suffix = version_line
343 .strip_prefix("NATS/1.0")
344 .map(str::trim)
345 .ok_or_else(|| {
346 io::Error::new(
347 io::ErrorKind::InvalidInput,
348 "header version line does not begin with `NATS/1.0`",
349 )
350 })?;
351
352 let (status, description) = version_line_suffix
353 .split_once(' ')
354 .map(|(status, description)| (status.trim(), description.trim()))
355 .unwrap_or((version_line_suffix, ""));
356 let status = if !status.is_empty() {
357 Some(status.parse::<StatusCode>().map_err(|_| {
358 std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
359 })?)
360 } else {
361 None
362 };
363 let description = if !description.is_empty() {
364 Some(description.to_owned())
365 } else {
366 None
367 };
368
369 let mut headers = HeaderMap::new();
370 while let Some(line) = lines.next() {
371 if line.is_empty() {
372 continue;
373 }
374
375 let (name, value) = line.split_once(':').ok_or_else(|| {
376 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
377 })?;
378
379 let name = HeaderName::from_str(name)
380 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
381
382 let mut value = value.trim_start().to_owned();
385 while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
386 value.push_str(v);
387 }
388 value.truncate(value.trim_end().len());
389
390 headers.append(name, value.into_header_value());
391 }
392
393 trace!(
394 subject = %subject,
395 sid = %sid,
396 reply = ?reply,
397 header_len = %header_len,
398 total_len = %total_len,
399 status = ?status,
400 description = ?description,
401 "read operation: HMSG"
402 );
403
404 return Ok(Some(ServerOp::Message {
405 length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
406 sid,
407 reply,
408 subject,
409 headers: Some(headers),
410 payload,
411 status,
412 description,
413 }));
414 }
415
416 let buffer = self.read_buf.split_to(len + 2);
417 let line = str::from_utf8(&buffer).map_err(|_| {
418 io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
419 })?;
420
421 trace!(line = %line, "read operation: unknown");
422 Err(io::Error::new(
423 io::ErrorKind::InvalidInput,
424 format!("invalid server operation: '{line}'"),
425 ))
426 }
427
428 pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
429 future::poll_fn(|cx| self.poll_read_op(cx))
430 }
431
432 pub(crate) fn poll_read_op(
436 &mut self,
437 cx: &mut Context<'_>,
438 ) -> Poll<io::Result<Option<ServerOp>>> {
439 loop {
440 if let Some(op) = self.try_read_op()? {
441 trace!(?op, "read operation completed");
442 return Poll::Ready(Ok(Some(op)));
443 }
444
445 let read_buf = self.stream.read_buf(&mut self.read_buf);
446 tokio::pin!(read_buf);
447 return match read_buf.poll(cx) {
448 Poll::Pending => {
449 trace!("read operation pending");
450 Poll::Pending
451 }
452 Poll::Ready(Ok(0)) if self.read_buf.is_empty() => {
453 trace!("read operation: empty buffer");
454 Poll::Ready(Ok(None))
455 }
456 Poll::Ready(Ok(0)) => {
457 trace!("read operation: connection reset");
458 Poll::Ready(Err(io::ErrorKind::ConnectionReset.into()))
459 }
460 Poll::Ready(Ok(n)) => {
461 self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
462 trace!(bytes = %n, "read operation: received bytes");
463 continue;
464 }
465 Poll::Ready(Err(err)) => {
466 trace!(error = %err, "read operation: error");
467 Poll::Ready(Err(err))
468 }
469 };
470 }
471 }
472
473 pub(crate) async fn easy_write_and_flush<'a>(
474 &mut self,
475 items: impl Iterator<Item = &'a ClientOp>,
476 ) -> io::Result<()> {
477 for item in items {
478 self.enqueue_write_op(item);
479 }
480
481 future::poll_fn(|cx| self.poll_write(cx)).await?;
482 future::poll_fn(|cx| self.poll_flush(cx)).await?;
483 Ok(())
484 }
485
486 pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
488 macro_rules! small_write {
489 ($dst:expr) => {
490 write!(self.small_write(), $dst).expect("do small write to Connection");
491 };
492 }
493
494 match item {
495 ClientOp::Connect(connect_info) => {
496 let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
497
498 self.write("CONNECT ");
499 self.write(json);
500 self.write("\r\n");
501 trace!(?connect_info, "write operation: CONNECT");
502 }
503 ClientOp::Publish {
504 subject,
505 payload,
506 respond,
507 headers,
508 } => {
509 let verb = match headers.as_ref() {
510 Some(headers) if !headers.is_empty() => "HPUB",
511 _ => "PUB",
512 };
513
514 small_write!("{verb} {subject} ");
515
516 if let Some(respond) = respond {
517 small_write!("{respond} ");
518 }
519
520 match headers {
521 Some(headers) if !headers.is_empty() => {
522 let headers = headers.to_bytes();
523
524 let headers_len = headers.len();
525 let total_len = headers_len + payload.len();
526 small_write!("{headers_len} {total_len}\r\n");
527 self.write(headers);
528 }
529 _ => {
530 let payload_len = payload.len();
531 small_write!("{payload_len}\r\n");
532 }
533 }
534
535 self.write(Bytes::clone(payload));
536 self.write("\r\n");
537
538 trace!(
539 verb = %verb,
540 subject = %subject,
541 reply = ?respond,
542 headers = ?headers,
543 payload_len = %payload.len(),
544 "write operation: PUB"
545 );
546 }
547
548 ClientOp::Subscribe {
549 sid,
550 subject,
551 queue_group,
552 } => {
553 match queue_group {
554 Some(queue_group) => {
555 small_write!("SUB {subject} {queue_group} {sid}\r\n");
556 }
557 None => {
558 small_write!("SUB {subject} {sid}\r\n");
559 }
560 }
561
562 trace!(
563 subject = %subject,
564 sid = %sid,
565 queue_group = ?queue_group,
566 "write operation: SUB"
567 );
568 }
569
570 ClientOp::Unsubscribe { sid, max } => {
571 match max {
572 Some(max) => {
573 small_write!("UNSUB {sid} {max}\r\n");
574 }
575 None => {
576 small_write!("UNSUB {sid}\r\n");
577 }
578 }
579
580 trace!(
581 sid = %sid,
582 max = ?max,
583 "write operation: UNSUB"
584 );
585 }
586 ClientOp::Ping => {
587 self.write("PING\r\n");
588 trace!("write operation: PING");
589 }
590 ClientOp::Pong => {
591 self.write("PONG\r\n");
592 trace!("write operation: PONG");
593 }
594 }
595 }
596
597 pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
610 if !self.stream.is_write_vectored() {
611 self.poll_write_sequential(cx)
612 } else {
613 self.poll_write_vectored(cx)
614 }
615 }
616
617 fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
621 loop {
622 let buf = match self.write_buf.front() {
623 Some(buf) => &**buf,
624 None if !self.flattened_writes.is_empty() => &self.flattened_writes,
625 None => return Poll::Ready(Ok(())),
626 };
627
628 debug_assert!(!buf.is_empty());
629
630 match Pin::new(&mut self.stream).poll_write(cx, buf) {
631 Poll::Pending => return Poll::Pending,
632 Poll::Ready(Ok(n)) => {
633 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
634 self.write_buf_len -= n;
635 self.can_flush = true;
636
637 match self.write_buf.front_mut() {
638 Some(buf) if n < buf.len() => {
639 buf.advance(n);
640 }
641 Some(_buf) => {
642 self.write_buf.pop_front();
643 }
644 None => {
645 self.flattened_writes.advance(n);
646 }
647 }
648 continue;
649 }
650 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
651 }
652 }
653 }
654 fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
659 'outer: loop {
660 let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
661 let mut writes_len = 0;
662
663 self.write_buf
664 .iter()
665 .take(WRITE_VECTORED_CHUNKS)
666 .enumerate()
667 .for_each(|(i, buf)| {
668 writes[i] = IoSlice::new(buf);
669 writes_len += 1;
670 });
671
672 if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
673 writes[writes_len] = IoSlice::new(&self.flattened_writes);
674 writes_len += 1;
675 }
676
677 if writes_len == 0 {
678 return Poll::Ready(Ok(()));
679 }
680
681 match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
682 Poll::Pending => return Poll::Pending,
683 Poll::Ready(Ok(mut n)) => {
684 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
685 self.write_buf_len -= n;
686 self.can_flush = true;
687
688 while let Some(buf) = self.write_buf.front_mut() {
689 if n < buf.len() {
690 buf.advance(n);
691 continue 'outer;
692 }
693
694 n -= buf.len();
695 self.write_buf.pop_front();
696 }
697
698 self.flattened_writes.advance(n);
699 }
700 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
701 }
702 }
703 }
704
705 fn write(&mut self, buf: impl Into<Bytes>) {
712 let buf = buf.into();
713 if buf.is_empty() {
714 return;
715 }
716
717 self.write_buf_len += buf.len();
718 if buf.len() < WRITE_FLATTEN_THRESHOLD {
719 self.flattened_writes.extend_from_slice(&buf);
720 } else {
721 if !self.flattened_writes.is_empty() {
722 let buf = self.flattened_writes.split().freeze();
723 self.write_buf.push_back(buf);
724 }
725
726 self.write_buf.push_back(buf);
727 }
728 }
729
730 fn small_write(&mut self) -> impl fmt::Write + '_ {
732 struct Writer<'a> {
733 this: &'a mut Connection,
734 }
735
736 impl fmt::Write for Writer<'_> {
737 fn write_str(&mut self, s: &str) -> fmt::Result {
738 self.this.write_buf_len += s.len();
739 self.this.flattened_writes.write_str(s)
740 }
741 }
742
743 Writer { this: self }
744 }
745
746 pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
750 match Pin::new(&mut self.stream).poll_flush(cx) {
751 Poll::Pending => Poll::Pending,
752 Poll::Ready(Ok(())) => {
753 self.can_flush = false;
754 Poll::Ready(Ok(()))
755 }
756 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
757 }
758 }
759}
760
/// Presents a `WebSocketStream` as a plain byte stream (`AsyncRead` + `AsyncWrite`).
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// Underlying WebSocket connection.
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Bytes from received messages that have not been handed to a reader yet.
    pub(crate) read_buf: BytesMut,
}
768
#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        Self { inner, read_buf }
    }
}
778
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads bytes from the sequence of incoming WebSocket messages.
    ///
    /// Buffered bytes are handed out first. When the buffer is empty, the next message
    /// is received and its payload appended. A stream error is mapped to
    /// `ErrorKind::Other`, and the end of the stream to `ErrorKind::UnexpectedEof`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        loop {
            if !this.read_buf.is_empty() {
                // Copy at most as many buffered bytes as the caller has room for; the
                // rest stays buffered for the next call.
                let len = std::cmp::min(buf.remaining(), this.read_buf.len());
                buf.put_slice(&this.read_buf.split_to(len));
                return Poll::Ready(Ok(()));
            }

            match this.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(message))) => {
                    // NOTE(review): the payload is appended whatever the message kind.
                    // Confirm that ping/pong/close frames never reach this point, or
                    // their payloads would be mixed into the protocol byte stream.
                    // An empty payload simply loops to the next message.
                    this.read_buf.extend_from_slice(message.as_payload());
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
                }
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
819
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends `buf` as one binary WebSocket message and reports the whole buffer as
    /// written once the message is queued in the sink.
    ///
    /// The payload is copied only after the sink is ready, so polls that return
    /// `Pending` do not allocate a throw-away copy.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Ready(Ok(())) => {
                let message = tokio_websockets::Message::binary(buf.to_vec());
                match this.inner.start_send_unpin(message) {
                    Ok(()) => Poll::Ready(Ok(buf.len())),
                    Err(e) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))),
                }
            }
            Poll::Ready(Err(e)) => {
                Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Flushes queued messages to the underlying stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }

    /// Closes the WebSocket sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }
}
862
/// Tests for parsing server operations through `Connection::read_op`.
///
/// Each test writes raw protocol bytes into one end of an in-memory duplex pipe and
/// reads operations from a `Connection` wrapping the other end.
#[cfg(test)]
mod read_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
    use tokio::io::{self, AsyncWriteExt};

    #[tokio::test]
    async fn ok() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"+OK\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"PING\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"PONG\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Pong));
    }

    #[tokio::test]
    async fn info() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        // An empty JSON object parses into the default `ServerInfo`.
        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                ..Default::default()
            })))
        );
    }

    #[tokio::test]
    async fn error() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"INFO {}\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }

    #[tokio::test]
    async fn message() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        // `length` is subject (7) + reply-to (if any) + payload/total bytes.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // 23 header bytes: `NATS/1.0\r\n` (10) + `Header: X\r\n` (11) + `\r\n` (2);
        // 34 total bytes = 23 + 11 payload bytes. The message arrives in pieces.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: X\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();

        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34
            })
        );

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: Y\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // Header-only status message: `NATS/1.0 404 No Messages\r\n` (26) + `\r\n` (2)
        // = 28 header bytes = 28 total bytes, so the payload is empty.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        // A plain MSG after HMSGs still parses, i.e. the buffer stayed in sync.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    /// Unknown operations produce errors but are consumed, so valid operations that
    /// follow them still parse.
    #[tokio::test]
    async fn unknown() {
        let (stream, mut server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        server.write_all(b"ONE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"TWO\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PING\r\n").await.unwrap();
        connection.read_op().await.unwrap();

        server.write_all(b"THREE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server.write_all(b"FOUR\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PONG\r\n").await.unwrap();
        connection.read_op().await.unwrap();
    }
}
1147
/// Tests for serializing client operations through `Connection::easy_write_and_flush`.
///
/// Each test writes operations through a `Connection` over an in-memory duplex pipe
/// and reads the raw protocol lines back from the other end.
#[cfg(test)]
mod write_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
    use tokio::io::{self, AsyncBufReadExt, BufReader};

    #[tokio::test]
    async fn publish() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n");

        // Non-empty headers switch the verb to HPUB. Only the control line and the
        // header block are checked: 23 header bytes, 23 + 11 = 34 total bytes.
        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn subscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR QUEUE.GROUP 11\r\n");
    }

    #[tokio::test]
    async fn unsubscribe() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11 2\r\n");
    }

    #[tokio::test]
    async fn ping() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PING\r\n");
    }

    #[tokio::test]
    async fn pong() {
        let (stream, server) = io::duplex(128);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PONG\r\n");
    }

    // The CONNECT JSON exceeds 128 bytes, hence the larger duplex buffer.
    #[tokio::test]
    async fn connect() {
        let (stream, server) = io::duplex(1024);
        let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}