1use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::task::{Context, Poll};
23
24use bytes::{Buf, Bytes, BytesMut};
25use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
26
27use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
28use crate::status::StatusCode;
29use crate::subject::Subject;
30use crate::{ClientOp, ServerError, ServerOp};
31
/// Queued-but-unwritten byte count at which `Connection::is_write_buf_full`
/// starts reporting `true` (64 KiB minus one). Advisory only.
const SOFT_WRITE_BUF_LIMIT: usize = 64 * 1024 - 1;
/// Writes shorter than this many bytes are copied into the contiguous
/// `flattened_writes` buffer instead of being queued as separate chunks.
const WRITE_FLATTEN_THRESHOLD: usize = 4 * 1024;
/// Maximum number of `IoSlice`s handed to a single vectored write.
const WRITE_VECTORED_CHUNKS: usize = 64;
40
41pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
43
44impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
46
/// Connection state: pending, connected, or disconnected.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    Pending,
    Connected,
    Disconnected,
}

/// Result of `Connection::should_flush`: whether flushing the stream is
/// worthwhile right now.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShouldFlush {
    /// Bytes were written since the last flush and nothing is left queued.
    Yes,
    /// Bytes were written since the last flush but more are still queued.
    May,
    /// Nothing has been written since the last successful flush.
    No,
}

impl Display for State {
    /// Writes the lowercase name of the state (e.g. `connected`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Pending => "pending",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
        };
        f.write_str(name)
    }
}
74
75pub(crate) struct Connection {
77 pub(crate) stream: Box<dyn AsyncReadWrite>,
78 read_buf: BytesMut,
79 write_buf: VecDeque<Bytes>,
80 write_buf_len: usize,
81 flattened_writes: BytesMut,
82 can_flush: bool,
83}
84
85impl Connection {
88 pub(crate) fn new(stream: Box<dyn AsyncReadWrite>, read_buffer_capacity: usize) -> Self {
89 Self {
90 stream,
91 read_buf: BytesMut::with_capacity(read_buffer_capacity),
92 write_buf: VecDeque::new(),
93 write_buf_len: 0,
94 flattened_writes: BytesMut::new(),
95 can_flush: false,
96 }
97 }
98
99 pub(crate) fn is_write_buf_full(&self) -> bool {
101 self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
102 }
103
104 pub(crate) fn should_flush(&self) -> ShouldFlush {
106 match (
107 self.can_flush,
108 self.write_buf.is_empty() && self.flattened_writes.is_empty(),
109 ) {
110 (true, true) => ShouldFlush::Yes,
111 (true, false) => ShouldFlush::May,
112 (false, _) => ShouldFlush::No,
113 }
114 }
115
116 pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
119 let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
120 Some(len) => len,
121 None => return Ok(None),
122 };
123
124 if self.read_buf.starts_with(b"+OK") {
125 self.read_buf.advance(len + 2);
126 return Ok(Some(ServerOp::Ok));
127 }
128
129 if self.read_buf.starts_with(b"PING") {
130 self.read_buf.advance(len + 2);
131 return Ok(Some(ServerOp::Ping));
132 }
133
134 if self.read_buf.starts_with(b"PONG") {
135 self.read_buf.advance(len + 2);
136 return Ok(Some(ServerOp::Pong));
137 }
138
139 if self.read_buf.starts_with(b"-ERR") {
140 let description = str::from_utf8(&self.read_buf[5..len])
141 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
142 .trim_matches('\'')
143 .to_owned();
144
145 self.read_buf.advance(len + 2);
146
147 return Ok(Some(ServerOp::Error(ServerError::new(description))));
148 }
149
150 if self.read_buf.starts_with(b"INFO ") {
151 let info = serde_json::from_slice(&self.read_buf[4..len])
152 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
153
154 self.read_buf.advance(len + 2);
155
156 return Ok(Some(ServerOp::Info(Box::new(info))));
157 }
158
159 if self.read_buf.starts_with(b"MSG ") {
160 let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
161 let mut args = line.split(' ').filter(|s| !s.is_empty());
162
163 let (subject, sid, reply_to, payload_len) = match (
165 args.next(),
166 args.next(),
167 args.next(),
168 args.next(),
169 args.next(),
170 ) {
171 (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
172 (subject, sid, Some(reply_to), payload_len)
173 }
174 (Some(subject), Some(sid), Some(payload_len), None, None) => {
175 (subject, sid, None, payload_len)
176 }
177 _ => {
178 return Err(io::Error::new(
179 io::ErrorKind::InvalidInput,
180 "invalid number of arguments after MSG",
181 ))
182 }
183 };
184
185 let sid = sid
186 .parse::<u64>()
187 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
188
189 let payload_len = payload_len
191 .parse::<usize>()
192 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
193
194 if len + payload_len + 4 > self.read_buf.remaining() {
197 return Ok(None);
198 }
199
200 let length = payload_len
201 + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
202 + subject.len();
203
204 let subject = Subject::from(subject);
205 let reply = reply_to.map(Subject::from);
206
207 self.read_buf.advance(len + 2);
208 let payload = self.read_buf.split_to(payload_len).freeze();
209 self.read_buf.advance(2);
210
211 return Ok(Some(ServerOp::Message {
212 sid,
213 length,
214 reply,
215 headers: None,
216 subject,
217 payload,
218 status: None,
219 description: None,
220 }));
221 }
222
223 if self.read_buf.starts_with(b"HMSG ") {
224 let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
226 let mut args = line.split_whitespace().filter(|s| !s.is_empty());
227
228 let (subject, sid, reply_to, header_len, total_len) = match (
230 args.next(),
231 args.next(),
232 args.next(),
233 args.next(),
234 args.next(),
235 args.next(),
236 ) {
237 (
238 Some(subject),
239 Some(sid),
240 Some(reply_to),
241 Some(header_len),
242 Some(total_len),
243 None,
244 ) => (subject, sid, Some(reply_to), header_len, total_len),
245 (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
246 (subject, sid, None, header_len, total_len)
247 }
248 _ => {
249 return Err(io::Error::new(
250 io::ErrorKind::InvalidInput,
251 "invalid number of arguments after HMSG",
252 ))
253 }
254 };
255
256 let subject = Subject::from(subject);
258
259 let sid = sid.parse::<u64>().map_err(|_| {
261 io::Error::new(
262 io::ErrorKind::InvalidInput,
263 "cannot parse sid argument after HMSG",
264 )
265 })?;
266
267 let reply = reply_to.map(Subject::from);
269
270 let header_len = header_len.parse::<usize>().map_err(|_| {
272 io::Error::new(
273 io::ErrorKind::InvalidInput,
274 "cannot parse the number of header bytes argument after \
275 HMSG",
276 )
277 })?;
278
279 let total_len = total_len.parse::<usize>().map_err(|_| {
281 io::Error::new(
282 io::ErrorKind::InvalidInput,
283 "cannot parse the number of bytes argument after HMSG",
284 )
285 })?;
286
287 if total_len < header_len {
288 return Err(io::Error::new(
289 io::ErrorKind::InvalidInput,
290 "number of header bytes was greater than or equal to the \
291 total number of bytes after HMSG",
292 ));
293 }
294
295 if len + total_len + 4 > self.read_buf.remaining() {
296 return Ok(None);
297 }
298
299 self.read_buf.advance(len + 2);
300 let header = self.read_buf.split_to(header_len);
301 let payload = self.read_buf.split_to(total_len - header_len).freeze();
302 self.read_buf.advance(2);
303
304 let mut lines = std::str::from_utf8(&header)
305 .map_err(|_| {
306 io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
307 })?
308 .lines()
309 .peekable();
310 let version_line = lines.next().ok_or_else(|| {
311 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
312 })?;
313
314 let version_line_suffix = version_line
315 .strip_prefix("NATS/1.0")
316 .map(str::trim)
317 .ok_or_else(|| {
318 io::Error::new(
319 io::ErrorKind::InvalidInput,
320 "header version line does not begin with `NATS/1.0`",
321 )
322 })?;
323
324 let (status, description) = version_line_suffix
325 .split_once(' ')
326 .map(|(status, description)| (status.trim(), description.trim()))
327 .unwrap_or((version_line_suffix, ""));
328 let status = if !status.is_empty() {
329 Some(status.parse::<StatusCode>().map_err(|_| {
330 std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
331 })?)
332 } else {
333 None
334 };
335 let description = if !description.is_empty() {
336 Some(description.to_owned())
337 } else {
338 None
339 };
340
341 let mut headers = HeaderMap::new();
342 while let Some(line) = lines.next() {
343 if line.is_empty() {
344 continue;
345 }
346
347 let (name, value) = line.split_once(':').ok_or_else(|| {
348 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
349 })?;
350
351 let name = HeaderName::from_str(name)
352 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
353
354 let mut value = value.trim_start().to_owned();
357 while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
358 value.push_str(v);
359 }
360 value.truncate(value.trim_end().len());
361
362 headers.append(name, value.into_header_value());
363 }
364
365 return Ok(Some(ServerOp::Message {
366 length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
367 sid,
368 reply,
369 subject,
370 headers: Some(headers),
371 payload,
372 status,
373 description,
374 }));
375 }
376
377 let buffer = self.read_buf.split_to(len + 2);
378 let line = str::from_utf8(&buffer).map_err(|_| {
379 io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
380 })?;
381
382 Err(io::Error::new(
383 io::ErrorKind::InvalidInput,
384 format!("invalid server operation: '{line}'"),
385 ))
386 }
387
388 pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
389 future::poll_fn(|cx| self.poll_read_op(cx))
390 }
391
392 pub(crate) fn poll_read_op(
396 &mut self,
397 cx: &mut Context<'_>,
398 ) -> Poll<io::Result<Option<ServerOp>>> {
399 loop {
400 if let Some(op) = self.try_read_op()? {
401 return Poll::Ready(Ok(Some(op)));
402 }
403
404 let read_buf = self.stream.read_buf(&mut self.read_buf);
405 tokio::pin!(read_buf);
406 return match read_buf.poll(cx) {
407 Poll::Pending => Poll::Pending,
408 Poll::Ready(Ok(0)) if self.read_buf.is_empty() => Poll::Ready(Ok(None)),
409 Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::ConnectionReset.into())),
410 Poll::Ready(Ok(_n)) => continue,
411 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
412 };
413 }
414 }
415
416 pub(crate) async fn easy_write_and_flush<'a>(
417 &mut self,
418 items: impl Iterator<Item = &'a ClientOp>,
419 ) -> io::Result<()> {
420 for item in items {
421 self.enqueue_write_op(item);
422 }
423
424 future::poll_fn(|cx| self.poll_write(cx)).await?;
425 future::poll_fn(|cx| self.poll_flush(cx)).await?;
426 Ok(())
427 }
428
429 pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
431 macro_rules! small_write {
432 ($dst:expr) => {
433 write!(self.small_write(), $dst).expect("do small write to Connection");
434 };
435 }
436
437 match item {
438 ClientOp::Connect(connect_info) => {
439 let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
440
441 self.write("CONNECT ");
442 self.write(json);
443 self.write("\r\n");
444 }
445 ClientOp::Publish {
446 subject,
447 payload,
448 respond,
449 headers,
450 } => {
451 let verb = match headers.as_ref() {
452 Some(headers) if !headers.is_empty() => "HPUB",
453 _ => "PUB",
454 };
455
456 small_write!("{verb} {subject} ");
457
458 if let Some(respond) = respond {
459 small_write!("{respond} ");
460 }
461
462 match headers {
463 Some(headers) if !headers.is_empty() => {
464 let headers = headers.to_bytes();
465
466 let headers_len = headers.len();
467 let total_len = headers_len + payload.len();
468 small_write!("{headers_len} {total_len}\r\n");
469 self.write(headers);
470 }
471 _ => {
472 let payload_len = payload.len();
473 small_write!("{payload_len}\r\n");
474 }
475 }
476
477 self.write(Bytes::clone(payload));
478 self.write("\r\n");
479 }
480
481 ClientOp::Subscribe {
482 sid,
483 subject,
484 queue_group,
485 } => match queue_group {
486 Some(queue_group) => {
487 small_write!("SUB {subject} {queue_group} {sid}\r\n");
488 }
489 None => {
490 small_write!("SUB {subject} {sid}\r\n");
491 }
492 },
493
494 ClientOp::Unsubscribe { sid, max } => match max {
495 Some(max) => {
496 small_write!("UNSUB {sid} {max}\r\n");
497 }
498 None => {
499 small_write!("UNSUB {sid}\r\n");
500 }
501 },
502 ClientOp::Ping => {
503 self.write("PING\r\n");
504 }
505 ClientOp::Pong => {
506 self.write("PONG\r\n");
507 }
508 }
509 }
510
511 pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
524 if !self.stream.is_write_vectored() {
525 self.poll_write_sequential(cx)
526 } else {
527 self.poll_write_vectored(cx)
528 }
529 }
530
531 fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
535 loop {
536 let buf = match self.write_buf.front() {
537 Some(buf) => &**buf,
538 None if !self.flattened_writes.is_empty() => &self.flattened_writes,
539 None => return Poll::Ready(Ok(())),
540 };
541
542 debug_assert!(!buf.is_empty());
543
544 match Pin::new(&mut self.stream).poll_write(cx, buf) {
545 Poll::Pending => return Poll::Pending,
546 Poll::Ready(Ok(n)) => {
547 self.write_buf_len -= n;
548 self.can_flush = true;
549
550 match self.write_buf.front_mut() {
551 Some(buf) if n < buf.len() => {
552 buf.advance(n);
553 }
554 Some(_buf) => {
555 self.write_buf.pop_front();
556 }
557 None => {
558 self.flattened_writes.advance(n);
559 }
560 }
561 continue;
562 }
563 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
564 }
565 }
566 }
567
568 fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
573 'outer: loop {
574 let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
575 let mut writes_len = 0;
576
577 self.write_buf
578 .iter()
579 .take(WRITE_VECTORED_CHUNKS)
580 .enumerate()
581 .for_each(|(i, buf)| {
582 writes[i] = IoSlice::new(buf);
583 writes_len += 1;
584 });
585
586 if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
587 writes[writes_len] = IoSlice::new(&self.flattened_writes);
588 writes_len += 1;
589 }
590
591 if writes_len == 0 {
592 return Poll::Ready(Ok(()));
593 }
594
595 match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
596 Poll::Pending => return Poll::Pending,
597 Poll::Ready(Ok(mut n)) => {
598 self.write_buf_len -= n;
599 self.can_flush = true;
600
601 while let Some(buf) = self.write_buf.front_mut() {
602 if n < buf.len() {
603 buf.advance(n);
604 continue 'outer;
605 }
606
607 n -= buf.len();
608 self.write_buf.pop_front();
609 }
610
611 self.flattened_writes.advance(n);
612 }
613 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
614 }
615 }
616 }
617
618 fn write(&mut self, buf: impl Into<Bytes>) {
625 let buf = buf.into();
626 if buf.is_empty() {
627 return;
628 }
629
630 self.write_buf_len += buf.len();
631 if buf.len() < WRITE_FLATTEN_THRESHOLD {
632 self.flattened_writes.extend_from_slice(&buf);
633 } else {
634 if !self.flattened_writes.is_empty() {
635 let buf = self.flattened_writes.split().freeze();
636 self.write_buf.push_back(buf);
637 }
638
639 self.write_buf.push_back(buf);
640 }
641 }
642
643 fn small_write(&mut self) -> impl fmt::Write + '_ {
645 struct Writer<'a> {
646 this: &'a mut Connection,
647 }
648
649 impl<'a> fmt::Write for Writer<'a> {
650 fn write_str(&mut self, s: &str) -> fmt::Result {
651 self.this.write_buf_len += s.len();
652 self.this.flattened_writes.write_str(s)
653 }
654 }
655
656 Writer { this: self }
657 }
658
659 pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
663 match Pin::new(&mut self.stream).poll_flush(cx) {
664 Poll::Pending => Poll::Pending,
665 Poll::Ready(Ok(())) => {
666 self.can_flush = false;
667 Poll::Ready(Ok(()))
668 }
669 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
670 }
671 }
672}
673
#[cfg(test)]
mod read_op {
    //! Decoder tests: raw protocol bytes are written to the "server" end of an
    //! in-memory duplex pipe and decoded through `Connection::read_op`.

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, StatusCode};
    use tokio::io::{self, AsyncWriteExt, DuplexStream};

    /// Status-only `HMSG` (28 header bytes, empty payload) used by several tests.
    const NO_MESSAGES: &[u8] = b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n\
        NATS/1.0 404 No Messages\r\n\
        \r\n\
        \r\n";

    /// Connection over one end of a 128-byte pipe, plus the other end.
    fn pipe() -> (Connection, DuplexStream) {
        let (client_side, server_side) = io::duplex(128);
        (Connection::new(Box::new(client_side), 0), server_side)
    }

    /// Decodes the first operation found in `bytes` on a fresh connection.
    async fn decode_one(bytes: &[u8]) -> Option<ServerOp> {
        let (mut connection, mut server) = pipe();
        server.write_all(bytes).await.unwrap();
        connection.read_op().await.unwrap()
    }

    /// Sends `bytes` and asserts the connection rejects them.
    async fn assert_rejected(connection: &mut Connection, server: &mut DuplexStream, bytes: &[u8]) {
        server.write_all(bytes).await.unwrap();
        connection.read_op().await.unwrap_err();
    }

    /// Wire form of an `HMSG` with a single `Header: <value>` header and an
    /// 11-byte payload. `value` must be one byte long to keep 23 header bytes.
    fn hmsg_with_header(value: &str) -> Vec<u8> {
        format!(
            "HMSG FOO.BAR 10 INBOX.35 23 34\r\nNATS/1.0\r\nHeader: {value}\r\n\r\nHello World\r\n"
        )
        .into_bytes()
    }

    /// Expected decoding of [`hmsg_with_header`].
    fn header_message(value: &str) -> ServerOp {
        ServerOp::Message {
            sid: 10,
            subject: "FOO.BAR".into(),
            reply: Some("INBOX.35".into()),
            headers: Some(HeaderMap::from_iter([(
                "Header".parse().unwrap(),
                value.parse().unwrap(),
            )])),
            payload: "Hello World".into(),
            status: None,
            description: None,
            length: 7 + 8 + 34,
        }
    }

    /// Expected decoding of [`NO_MESSAGES`].
    fn no_messages() -> ServerOp {
        ServerOp::Message {
            sid: 10,
            subject: "FOO.BAR".into(),
            reply: Some("INBOX.35".into()),
            headers: Some(HeaderMap::default()),
            payload: "".into(),
            status: Some(StatusCode::NOT_FOUND),
            description: Some("No Messages".to_string()),
            length: 7 + 8 + 28,
        }
    }

    #[tokio::test]
    async fn ok() {
        assert_eq!(decode_one(b"+OK\r\n").await, Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        assert_eq!(decode_one(b"PING\r\n").await, Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        assert_eq!(decode_one(b"PONG\r\n").await, Some(ServerOp::Pong));
    }

    #[tokio::test]
    async fn info() {
        let (mut connection, mut server) = pipe();

        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::default()))
        );

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();
        let expected = ServerInfo {
            version: "1.0.0".into(),
            ..Default::default()
        };
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::new(expected)))
        );
    }

    #[tokio::test]
    async fn error() {
        let (mut connection, mut server) = pipe();

        server.write_all(b"INFO {}\r\n").await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::default()))
        );

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        let expected = ServerError::Other("something went wrong".into());
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Error(expected))
        );
    }

    #[tokio::test]
    async fn message() {
        let (mut connection, mut server) = pipe();

        // Plain MSG without a reply subject.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        // Plain MSG with a reply subject.
        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // Two HMSGs in a row with different header values.
        for value in ["X", "Y"] {
            server.write_all(&hmsg_with_header(value)).await.unwrap();
            assert_eq!(
                connection.read_op().await.unwrap(),
                Some(header_message(value))
            );
        }

        // Status-only HMSG.
        server.write_all(NO_MESSAGES).await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(no_messages()));

        // The decoder is still in sync afterwards.
        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    #[tokio::test]
    async fn unknown() {
        let (mut connection, mut server) = pipe();

        // Unknown lines are rejected but consumed, so decoding recovers.
        assert_rejected(&mut connection, &mut server, b"ONE\r\n").await;
        assert_rejected(&mut connection, &mut server, b"TWO\r\n").await;

        server.write_all(b"PING\r\n").await.unwrap();
        connection.read_op().await.unwrap();

        assert_rejected(&mut connection, &mut server, b"THREE\r\n").await;

        server.write_all(NO_MESSAGES).await.unwrap();
        assert_eq!(connection.read_op().await.unwrap(), Some(no_messages()));

        assert_rejected(&mut connection, &mut server, b"FOUR\r\n").await;

        server.write_all(b"PONG\r\n").await.unwrap();
        connection.read_op().await.unwrap();
    }
}
956
#[cfg(test)]
mod write_op {
    //! Encoder tests: ops are enqueued, written and flushed through
    //! `Connection::easy_write_and_flush`, and the bytes are read back as
    //! text lines from the "server" end of an in-memory duplex pipe.

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol};
    use tokio::io::{self, AsyncBufReadExt, BufReader, DuplexStream};

    /// Connection over a duplex pipe of `capacity` bytes, plus a buffered
    /// reader on the far end standing in for the server.
    fn pipe(capacity: usize) -> (Connection, BufReader<DuplexStream>) {
        let (client_side, server_side) = io::duplex(capacity);
        (
            Connection::new(Box::new(client_side), 0),
            BufReader::new(server_side),
        )
    }

    /// Enqueues `op`, then writes and flushes it.
    async fn send(connection: &mut Connection, op: ClientOp) {
        connection.easy_write_and_flush([op].iter()).await.unwrap();
    }

    /// Reads `count` lines from the server end and returns them concatenated,
    /// line terminators included.
    async fn read_lines(reader: &mut BufReader<DuplexStream>, count: usize) -> String {
        let mut text = String::new();
        for _ in 0..count {
            reader.read_line(&mut text).await.unwrap();
        }
        text
    }

    #[tokio::test]
    async fn publish() {
        let (mut connection, mut reader) = pipe(128);

        let plain = ClientOp::Publish {
            subject: "FOO.BAR".into(),
            payload: "Hello World".into(),
            respond: None,
            headers: None,
        };
        send(&mut connection, plain).await;
        assert_eq!(
            read_lines(&mut reader, 2).await,
            "PUB FOO.BAR 11\r\nHello World\r\n"
        );

        let with_reply = ClientOp::Publish {
            subject: "FOO.BAR".into(),
            payload: "Hello World".into(),
            respond: Some("INBOX.67".into()),
            headers: None,
        };
        send(&mut connection, with_reply).await;
        assert_eq!(
            read_lines(&mut reader, 2).await,
            "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n"
        );

        let with_headers = ClientOp::Publish {
            subject: "FOO.BAR".into(),
            payload: "Hello World".into(),
            respond: Some("INBOX.67".into()),
            headers: Some(HeaderMap::from_iter([(
                "Header".parse().unwrap(),
                "X".parse().unwrap(),
            )])),
        };
        send(&mut connection, with_headers).await;
        // Only the control line and the header block are checked.
        assert_eq!(
            read_lines(&mut reader, 4).await,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn subscribe() {
        let (mut connection, mut reader) = pipe(128);

        send(
            &mut connection,
            ClientOp::Subscribe {
                sid: 11,
                subject: "FOO.BAR".into(),
                queue_group: None,
            },
        )
        .await;
        assert_eq!(read_lines(&mut reader, 1).await, "SUB FOO.BAR 11\r\n");

        send(
            &mut connection,
            ClientOp::Subscribe {
                sid: 11,
                subject: "FOO.BAR".into(),
                queue_group: Some("QUEUE.GROUP".into()),
            },
        )
        .await;
        assert_eq!(
            read_lines(&mut reader, 1).await,
            "SUB FOO.BAR QUEUE.GROUP 11\r\n"
        );
    }

    #[tokio::test]
    async fn unsubscribe() {
        let (mut connection, mut reader) = pipe(128);

        send(&mut connection, ClientOp::Unsubscribe { sid: 11, max: None }).await;
        assert_eq!(read_lines(&mut reader, 1).await, "UNSUB 11\r\n");

        send(
            &mut connection,
            ClientOp::Unsubscribe {
                sid: 11,
                max: Some(2),
            },
        )
        .await;
        assert_eq!(read_lines(&mut reader, 1).await, "UNSUB 11 2\r\n");
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut reader) = pipe(128);

        send(&mut connection, ClientOp::Ping).await;
        assert_eq!(read_lines(&mut reader, 1).await, "PING\r\n");
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut reader) = pipe(128);

        send(&mut connection, ClientOp::Pong).await;
        assert_eq!(read_lines(&mut reader, 1).await, "PONG\r\n");
    }

    #[tokio::test]
    async fn connect() {
        // The CONNECT line is longer than 128 bytes, so use a bigger pipe.
        let (mut connection, mut reader) = pipe(1024);

        let info = ConnectInfo {
            verbose: false,
            pedantic: false,
            user_jwt: None,
            nkey: None,
            signature: None,
            name: None,
            echo: false,
            lang: "Rust".into(),
            version: "1.0.0".into(),
            protocol: Protocol::Dynamic,
            tls_required: false,
            user: None,
            pass: None,
            auth_token: None,
            headers: false,
            no_responders: false,
        };
        send(&mut connection, ClientOp::Connect(info)).await;

        assert_eq!(
            read_lines(&mut reader, 1).await,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}