1#![cfg_attr(feature = "fail-on-warnings", deny(warnings))]
21#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
22#![allow(clippy::multiple_crate_versions)]
23
24use std::{collections::BTreeMap, task::Poll, time::SystemTime};
25
26use bytes::Bytes;
27use futures_util::{Future, Stream};
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30use strum_macros::EnumString;
31use switchy_async::sync::mpsc::Receiver;
32use switchy_async::util::CancellationToken;
33use switchy_http::models::Method;
34use thiserror::Error;
35
#[cfg(feature = "base64")]
/// Prefix every base64 (text) tunnel response frame must start with. It is stripped by
/// `TryFrom<&str> for TunnelResponse` before the remaining fields are parsed.
static BASE64_TUNNEL_RESPONSE_PREFIX: &str = "TUNNEL_RESPONSE:";
39
/// Encoding requested for tunnelled HTTP responses.
///
/// Serialized (serde) and parsed (`FromStr` via strum) as `SCREAMING_SNAKE_CASE`, i.e.
/// `"BINARY"` / `"BASE64"`.
///
/// NOTE(review): presumably `Binary` selects the frame format decoded by
/// `TryFrom<Bytes> for TunnelResponse` and `Base64` the text format decoded by
/// `TryFrom<&str> for TunnelResponse` — confirm against the sender.
#[derive(Debug, Serialize, Deserialize, EnumString, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum TunnelEncoding {
    /// Raw binary frames.
    Binary,
    /// Base64 text frames; only available with the `base64` feature.
    #[cfg(feature = "base64")]
    Base64,
}
51
/// A JSON message sent back over the tunnel in reply to a websocket request.
///
/// NOTE(review): presumably the counterpart of [`TunnelRequest::Ws`] — confirm against the
/// tunnel server.
#[derive(Debug, Serialize, Deserialize)]
pub struct TunnelWsResponse {
    /// Identifier of the request this message belongs to.
    pub request_id: u64,
    /// Arbitrary JSON message body.
    pub body: Value,
    /// Connection ids the message should presumably not be delivered to. Omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude_connection_ids: Option<Vec<u64>>,
    /// Connection ids the message should presumably be delivered to exclusively. Omitted from
    /// the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub to_connection_ids: Option<Vec<u64>>,
}
66
/// One packet of a tunnelled HTTP response.
///
/// A response may be split across several packets numbered from `1`. Only packet `1` carries
/// the status code and headers; both decoders leave them `None` for every other packet.
#[derive(Debug)]
pub struct TunnelResponse {
    /// Identifier of the request this packet answers.
    pub request_id: u64,
    /// 1-based position of this packet within the response.
    pub packet_id: u32,
    /// `true` if this is the final packet of the response.
    pub last: bool,
    /// The body chunk carried by this packet (may be empty).
    pub bytes: Bytes,
    /// HTTP status code; only set when `packet_id == 1`.
    pub status: Option<u16>,
    /// Response headers; only set when `packet_id == 1`.
    pub headers: Option<BTreeMap<String, String>>,
}
83
/// A request forwarded through the tunnel.
///
/// Serialized as an internally tagged object whose `"type"` field is `"HTTP"`, `"WS"` or
/// `"ABORT"`, with the variant's fields inlined alongside it.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
#[serde(tag = "type")]
pub enum TunnelRequest {
    /// An HTTP request.
    Http(TunnelHttpRequest),
    /// A websocket message.
    Ws(TunnelWsRequest),
    /// A request to abort an in-flight request.
    Abort(TunnelAbortRequest),
}
96
/// An HTTP request forwarded through the tunnel.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TunnelHttpRequest {
    /// Identifier used to correlate the response packets with this request.
    pub request_id: u64,
    /// HTTP method.
    pub method: Method,
    /// Request path.
    pub path: String,
    /// Query parameters as a JSON value.
    pub query: Value,
    /// Optional JSON request body; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<Value>,
    /// Optional request headers as a JSON value; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<Value>,
    /// Encoding the response packets should use.
    pub encoding: TunnelEncoding,
    /// Optional profile name (serialized as `null` when `None`, since it is not skipped).
    pub profile: Option<String>,
}
119
/// A websocket message forwarded through the tunnel.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TunnelWsRequest {
    /// Connection identifier. NOTE(review): how this relates to `connection_id` is not visible
    /// here — confirm.
    pub conn_id: u64,
    /// Identifier used to correlate responses with this request.
    pub request_id: u64,
    /// JSON message body.
    pub body: Value,
    /// Optional connection id as a raw JSON value; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connection_id: Option<Value>,
    /// Optional profile name (serialized as `null` when `None`).
    pub profile: Option<String>,
}
135
/// Asks the other side of the tunnel to abort an in-flight request.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TunnelAbortRequest {
    /// Identifier of the request to abort.
    pub request_id: u64,
}
142
/// Errors returned when decoding a binary frame with `TryFrom<Bytes> for TunnelResponse`.
#[derive(Debug, Error)]
pub enum TryFromBytesError {
    /// Converting a byte slice into a fixed-size integer array failed.
    #[error(transparent)]
    TryFromSlice(#[from] std::array::TryFromSliceError),
    /// The JSON headers of the first packet could not be deserialized.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
153
154impl TryFrom<Bytes> for TunnelResponse {
155 type Error = TryFromBytesError;
156
157 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
188 let mut data = bytes.slice(13..);
189 let request_id = u64::from_be_bytes(bytes[..8].try_into()?);
190 let packet_id = u32::from_be_bytes(bytes[8..12].try_into()?);
191 let last = u8::from_be_bytes(bytes[12..13].try_into()?) == 1;
192 let (status, headers) = if packet_id == 1 {
193 let status = u16::from_be_bytes(data[..2].try_into()?);
194 data = data.slice(2..);
195 let len = u32::from_be_bytes(data[..4].try_into()?) as usize;
196 let headers_bytes = &data.slice(4..(4 + len));
197 data = data.slice((4 + len)..);
198 (Some(status), Some(serde_json::from_slice(headers_bytes)?))
199 } else {
200 (None, None)
201 };
202
203 Ok(Self {
204 request_id,
205 packet_id,
206 last,
207 bytes: data,
208 status,
209 headers,
210 })
211 }
212}
213
#[cfg(feature = "base64")]
/// Errors returned when decoding a text frame with `TryFrom<&str>` / `TryFrom<String>` for
/// `TunnelResponse`.
#[derive(Debug, Error)]
pub enum Base64DecodeError {
    /// The frame is structurally malformed; the message describes which part is wrong.
    #[error("Invalid Content: {0:?}")]
    InvalidContent(String),
    /// The body is not valid base64.
    #[error(transparent)]
    Decode(#[from] base64::DecodeError),
}
225
#[cfg(feature = "base64")]
impl TryFrom<&str> for TunnelResponse {
    type Error = Base64DecodeError;

    /// Decodes a text tunnel frame of the form
    /// `TUNNEL_RESPONSE:{request_id}|{packet_id}|{last}{preamble}{base64 body}`.
    ///
    /// * `last` is a single `0` or `1` character directly after the second `|`.
    /// * `preamble` is present only when `packet_id == 1`. It is a 3-character status code, an
    ///   optional `|`, a JSON object of headers, then another optional `|`.
    /// * The remainder is decoded as standard base64 into [`TunnelResponse::bytes`].
    ///
    /// # Errors
    ///
    /// * [`Base64DecodeError::InvalidContent`] if any of these is missing or malformed: the
    ///   prefix, a `|` delimiter, `request_id`, `packet_id`, the `last` flag, the status code,
    ///   or the headers JSON
    /// * [`Base64DecodeError::Decode`] if the body is not valid standard base64
    fn try_from(base64: &str) -> Result<Self, Self::Error> {
        use base64::{Engine, engine::general_purpose};

        /// Splits off the next `|`-terminated field of `payload` and parses it as `T`.
        /// Returns the parsed value and the text after the delimiter.
        fn split_field<'s, T>(
            payload: &'s str,
            name: &str,
        ) -> Result<(T, &'s str), Base64DecodeError>
        where
            T: std::str::FromStr,
            T::Err: std::fmt::Display,
        {
            let (part, rest) = payload.split_once('|').ok_or_else(|| {
                Base64DecodeError::InvalidContent(format!(
                    "Missing {name}. Expected '|' delimiter"
                ))
            })?;
            let value = part.parse::<T>().map_err(|error| {
                Base64DecodeError::InvalidContent(format!("Invalid {name} '{part}': {error}"))
            })?;
            Ok((value, rest))
        }

        let payload = base64
            .strip_prefix(BASE64_TUNNEL_RESPONSE_PREFIX)
            .ok_or_else(|| {
                Base64DecodeError::InvalidContent(
                    "Invalid TunnelResponse base64 data string".into(),
                )
            })?;

        let (request_id, payload) = split_field::<u64>(payload, "request_id")?;
        let (packet_id, payload) = split_field::<u32>(payload, "packet_id")?;

        let mut chars = payload.chars();
        let Some(last_char) = chars.next() else {
            return Err(Base64DecodeError::InvalidContent(
                "Missing last flag (expected 0 or 1)".into(),
            ));
        };
        let last = match last_char {
            '0' => false,
            '1' => true,
            _ => {
                return Err(Base64DecodeError::InvalidContent(format!(
                    "Invalid last flag '{last_char}' (expected 0 or 1)"
                )));
            }
        };

        let mut payload = chars.as_str();

        let (status, headers) = if packet_id == 1 {
            if payload.len() < 3 {
                return Err(Base64DecodeError::InvalidContent(
                    "Missing status code for packet_id=1".into(),
                ));
            }

            // The input is untrusted. `str::get` returns `None` when byte 3 falls inside a
            // multi-byte character, where indexing (`payload[..3]`) would panic.
            let status_part = payload.get(..3).ok_or_else(|| {
                Base64DecodeError::InvalidContent(
                    "Invalid status code for packet_id=1: expected 3 ASCII digits".into(),
                )
            })?;
            let status = status_part.parse::<u16>().map_err(|error| {
                Base64DecodeError::InvalidContent(format!(
                    "Invalid status code '{status_part}' for packet_id=1: {error}"
                ))
            })?;
            // `get(..3)` succeeded, so byte 3 is a char boundary and this cannot panic.
            payload = &payload[3..];
            payload = payload.strip_prefix('|').unwrap_or(payload);

            let (headers_json, remaining) = parse_json_object_prefix(payload)?;
            payload = remaining.strip_prefix('|').unwrap_or(remaining);

            let headers = serde_json::from_str(headers_json).map_err(|error| {
                Base64DecodeError::InvalidContent(format!(
                    "Invalid headers JSON for packet_id=1: {error}"
                ))
            })?;
            (Some(status), Some(headers))
        } else {
            (None, None)
        };

        let bytes = Bytes::from(general_purpose::STANDARD.decode(payload)?);

        Ok(Self {
            request_id,
            packet_id,
            last,
            bytes,
            status,
            headers,
        })
    }
}
335
#[cfg(feature = "base64")]
/// Splits `payload` (after trimming leading whitespace) into a leading JSON object and the
/// text that follows it.
///
/// The object's extent is found by tracking `{`/`}` nesting depth. Braces inside string
/// literals and characters escaped with `\` inside strings are skipped. The object itself is
/// not validated as JSON.
///
/// # Errors
///
/// Returns [`Base64DecodeError::InvalidContent`] if the trimmed text does not start with `{`,
/// or if the matching closing brace is never found.
fn parse_json_object_prefix(payload: &str) -> Result<(&str, &str), Base64DecodeError> {
    let trimmed = payload.trim_start();
    if !trimmed.starts_with('{') {
        return Err(Base64DecodeError::InvalidContent(
            "Missing headers JSON object for packet_id=1".into(),
        ));
    }

    let mut nesting: usize = 0;
    let mut inside_string = false;
    let mut skip_next = false;

    for (offset, c) in trimmed.char_indices() {
        if skip_next {
            // The previous character was a backslash inside a string literal.
            skip_next = false;
            continue;
        }

        match c {
            '\\' if inside_string => skip_next = true,
            '"' => inside_string = !inside_string,
            _ if inside_string => {}
            '{' => nesting += 1,
            '}' => {
                nesting = nesting.saturating_sub(1);
                if nesting == 0 {
                    return Ok(trimmed.split_at(offset + c.len_utf8()));
                }
            }
            _ => {}
        }
    }

    Err(Base64DecodeError::InvalidContent(
        "Missing closing brace for headers JSON object".into(),
    ))
}
384
#[cfg(feature = "base64")]
impl TryFrom<String> for TunnelResponse {
    type Error = Base64DecodeError;

    /// Decodes an owned text frame by delegating to `TryFrom<&str> for TunnelResponse`.
    ///
    /// # Errors
    ///
    /// Same as the `&str` implementation.
    fn try_from(base64: String) -> Result<Self, Self::Error> {
        Self::try_from(base64.as_str())
    }
}
412
/// Errors yielded by [`TunnelStream`].
#[derive(Debug, Error)]
pub enum TunnelStreamError {
    /// The stream's abort token was cancelled.
    #[error("TunnelStream aborted")]
    Aborted,
    /// The stream was polled after its packet channel had closed, and no in-order packet was
    /// left in the queue.
    #[error("TunnelStream end of stream")]
    EndOfStream,
}
423
/// A [`Stream`] of response body chunks for one tunnelled request.
///
/// Packets are received from a channel, possibly out of order, and yielded strictly in
/// `packet_id` order.
pub struct TunnelStream<'a, F: Future<Output = Result<(), Box<dyn std::error::Error>>>> {
    /// When the stream was created (used for timing logs).
    start: SystemTime,
    /// Identifier of the request whose response is being streamed.
    request_id: u64,
    /// When the first packet was yielded, if any.
    time_to_first_byte: Option<SystemTime>,
    /// Number of packets yielded so far; the next expected `packet_id` is `packet_count + 1`.
    packet_count: u32,
    /// Total number of body bytes yielded so far.
    byte_count: usize,
    /// Set once the packet flagged `last` has been yielded.
    done: bool,
    /// Set once `rx` has been closed by all senders.
    end_of_stream: bool,
    /// Channel delivering decoded response packets for this request.
    rx: Receiver<TunnelResponse>,
    /// Callback invoked with `request_id` once the stream is done.
    on_end: &'a dyn Fn(u64) -> F,
    /// Packets that arrived ahead of the next expected one, kept sorted by `packet_id`.
    packet_queue: Vec<TunnelResponse>,
    /// Once cancelled, polling yields [`TunnelStreamError::Aborted`].
    abort_token: CancellationToken,
}
441
442impl<'a, F: Future<Output = Result<(), Box<dyn std::error::Error>>>> TunnelStream<'a, F> {
443 #[must_use]
468 pub fn new(
469 request_id: u64,
470 rx: Receiver<TunnelResponse>,
471 abort_token: CancellationToken,
472 on_end: &'a impl Fn(u64) -> F,
473 ) -> Self {
474 Self {
475 start: switchy_time::now(),
476 request_id,
477 time_to_first_byte: None,
478 packet_count: 0,
479 byte_count: 0,
480 done: false,
481 end_of_stream: false,
482 rx,
483 on_end,
484 packet_queue: vec![],
485 abort_token,
486 }
487 }
488
489 fn process_queued_packet(
493 &mut self,
494 ) -> Option<std::task::Poll<Option<Result<Bytes, TunnelStreamError>>>> {
495 if self
496 .packet_queue
497 .first()
498 .is_some_and(|x| x.packet_id == self.packet_count + 1)
499 {
500 let response = self.packet_queue.remove(0);
501 log::debug!(
502 "poll_next: Sending queued packet_id={} for request_id={}",
503 response.packet_id,
504 self.request_id,
505 );
506 Some(return_polled_bytes(self, response))
507 } else {
508 None
509 }
510 }
511}
512
513fn return_polled_bytes<F: Future<Output = Result<(), Box<dyn std::error::Error>>>>(
518 stream: &mut TunnelStream<F>,
519 response: TunnelResponse,
520) -> std::task::Poll<Option<Result<Bytes, TunnelStreamError>>> {
521 if stream.time_to_first_byte.is_none() {
522 stream.time_to_first_byte = Some(switchy_time::now());
523 }
524
525 stream.packet_count += 1;
526
527 log::debug!(
528 "return_polled_bytes: Received packet for request_id={} packet_count={} {} bytes last={}",
529 stream.request_id,
530 stream.packet_count,
531 response.bytes.len(),
532 response.last,
533 );
534
535 if response.last {
536 stream.done = true;
537 }
538
539 stream.byte_count += response.bytes.len();
540
541 Poll::Ready(Some(Ok(response.bytes)))
542}
543
544impl<F: Future<Output = Result<(), Box<dyn std::error::Error>>>> Stream for TunnelStream<'_, F> {
545 type Item = Result<Bytes, TunnelStreamError>;
546
547 #[allow(clippy::too_many_lines)]
548 fn poll_next(
549 mut self: std::pin::Pin<&mut Self>,
550 cx: &mut std::task::Context<'_>,
551 ) -> std::task::Poll<Option<Self::Item>> {
552 let request_id = {
553 let mut stream = self.as_mut();
554 let request_id = stream.request_id;
555
556 log::trace!(
557 "poll_next: TunnelStream poll for request_id={request_id} packet_count={}",
558 stream.packet_count,
559 );
560
561 if stream.end_of_stream {
562 log::trace!(
563 "poll_next: End of stream for request_id={request_id} packet_count={}",
564 stream.packet_count,
565 );
566 return stream
567 .process_queued_packet()
568 .unwrap_or(Poll::Ready(Some(Err(TunnelStreamError::EndOfStream))));
569 }
570
571 if stream.abort_token.is_cancelled() {
572 log::debug!("poll_next: Stream is cancelled for request_id={request_id}");
573 return Poll::Ready(Some(Err(TunnelStreamError::Aborted)));
574 }
575
576 if stream.done {
577 let end = switchy_time::now();
578
579 log::debug!(
580 "poll_next: Byte count: {} for request_id={request_id} (received {} packet{}, took {}ms total, {}ms to first byte)",
581 stream.byte_count,
582 stream.packet_count,
583 if stream.packet_count == 1 { "" } else { "s" },
584 end.duration_since(stream.start).unwrap().as_millis(),
585 stream
586 .time_to_first_byte
587 .map(|t| t.duration_since(stream.start).unwrap().as_millis())
588 .map_or_else(|| "N/A".into(), |t| t.to_string())
589 );
590
591 (stream.on_end)(stream.request_id);
592
593 return Poll::Ready(None);
594 }
595
596 log::debug!(
597 "poll_next: Waiting for next packet for request_id={request_id} packet_count={}",
598 stream.packet_count,
599 );
600 let response = match stream.rx.poll_recv(cx) {
601 Poll::Ready(Some(response)) => response,
602 Poll::Pending => {
603 log::debug!("poll_next: Pending for request_id={request_id}...");
604 return stream.process_queued_packet().unwrap_or(Poll::Pending);
605 }
606 Poll::Ready(None) => {
607 log::debug!("poll_next: Finished");
608 moosicbox_assert::assert!(
609 !stream.done,
610 "Stream is not finished for request_id={request_id}"
611 );
612 stream.end_of_stream = true;
613 return stream.process_queued_packet().unwrap_or(Poll::Ready(None));
614 }
615 };
616 log::debug!(
617 "poll_next: Received next packet for request_id={request_id} packet_count={}: packet_id={} status={:?} last={}",
618 stream.packet_count,
619 response.packet_id,
620 response.status,
621 response.last,
622 );
623
624 if response.packet_id == 1 && response.last {
625 log::debug!(
626 "poll_next: Received first and final packet for request_id={request_id}"
627 );
628 return return_polled_bytes(&mut stream, response);
629 }
630
631 if response.packet_id == stream.packet_count + 1 {
632 return return_polled_bytes(&mut stream, response);
633 }
634
635 log::debug!(
636 "poll_next: Received future packet_id={} for request_id={request_id}. Waiting for packet {} before continuing",
637 response.packet_id,
638 stream.packet_count + 1,
639 );
640
641 let queued_response = if stream
642 .packet_queue
643 .first()
644 .is_some_and(|x| x.packet_id == stream.packet_count + 1)
645 {
646 let response = stream.packet_queue.remove(0);
647 log::debug!(
648 "poll_next: Sending queued packet_id={} for request_id={request_id}",
649 response.packet_id,
650 );
651 Some(return_polled_bytes(&mut stream, response))
652 } else {
653 None
654 };
655
656 if let Some(pos) = stream
657 .packet_queue
658 .iter()
659 .position(|r| r.packet_id > response.packet_id)
660 {
661 stream.packet_queue.insert(pos, response);
662 } else {
663 stream.packet_queue.push(response);
664 }
665
666 if let Some(response) = queued_response {
667 log::debug!("poll_next: Sending queued response for request_id={request_id}");
668 return response;
669 }
670
671 request_id
672 };
673
674 log::debug!("poll_next: Re-polling for response for request_id={request_id}");
675 self.poll_next(cx)
676 }
677}
678
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt as _;
    use std::collections::BTreeMap;

    /// Builds a binary frame in the layout decoded by `TryFrom<Bytes> for TunnelResponse`.
    ///
    /// Panics if `packet_id == 1` and `status` or `headers` is `None`.
    fn create_binary_response(
        request_id: u64,
        packet_id: u32,
        last: bool,
        status: Option<u16>,
        headers: Option<BTreeMap<String, String>>,
        body: &[u8],
    ) -> Bytes {
        let mut data = Vec::new();

        // request_id: 8 bytes, big-endian
        data.extend_from_slice(&request_id.to_be_bytes());

        // packet_id: 4 bytes, big-endian
        data.extend_from_slice(&packet_id.to_be_bytes());

        // last flag: 1 byte
        data.push(u8::from(last));

        if packet_id == 1 {
            // First packet only: status (u16), headers length (u32), headers JSON
            let status = status.expect("First packet must have status");
            data.extend_from_slice(&status.to_be_bytes());

            let headers = headers.expect("First packet must have headers");
            let headers_json = serde_json::to_vec(&headers).unwrap();
            let headers_len = u32::try_from(headers_json.len()).unwrap();
            data.extend_from_slice(&headers_len.to_be_bytes());
            data.extend_from_slice(&headers_json);
        }

        // The rest of the frame is the body
        data.extend_from_slice(body);

        Bytes::from(data)
    }

    #[test_log::test]
    fn test_tunnel_response_from_bytes_first_packet() {
        let mut headers = BTreeMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());
        headers.insert("x-custom".to_string(), "test-value".to_string());

        let body = b"test response body";
        let bytes = create_binary_response(12345, 1, false, Some(200), Some(headers.clone()), body);

        let response = TunnelResponse::try_from(bytes).unwrap();

        assert_eq!(response.request_id, 12345);
        assert_eq!(response.packet_id, 1);
        assert!(!response.last);
        assert_eq!(response.status, Some(200));
        assert_eq!(response.headers, Some(headers));
        assert_eq!(response.bytes.as_ref(), body);
    }

    #[test_log::test]
    fn test_tunnel_response_from_bytes_subsequent_packet() {
        let body = b"more data";
        let bytes = create_binary_response(12345, 2, false, None, None, body);

        let response = TunnelResponse::try_from(bytes).unwrap();

        assert_eq!(response.request_id, 12345);
        assert_eq!(response.packet_id, 2);
        assert!(!response.last);
        assert_eq!(response.status, None);
        assert_eq!(response.headers, None);
        assert_eq!(response.bytes.as_ref(), body);
    }

    #[test_log::test]
    fn test_tunnel_response_from_bytes_final_packet() {
        let body = b"final chunk";
        let bytes = create_binary_response(12345, 3, true, None, None, body);

        let response = TunnelResponse::try_from(bytes).unwrap();

        assert_eq!(response.request_id, 12345);
        assert_eq!(response.packet_id, 3);
        assert!(response.last);
        assert_eq!(response.status, None);
        assert_eq!(response.headers, None);
        assert_eq!(response.bytes.as_ref(), body);
    }

    #[test_log::test]
    fn test_tunnel_response_from_bytes_empty_body() {
        let headers = BTreeMap::new();
        let bytes = create_binary_response(999, 1, true, Some(204), Some(headers.clone()), &[]);

        let response = TunnelResponse::try_from(bytes).unwrap();

        assert_eq!(response.request_id, 999);
        assert_eq!(response.packet_id, 1);
        assert!(response.last);
        assert_eq!(response.status, Some(204));
        assert_eq!(response.headers, Some(headers));
        assert!(response.bytes.is_empty());
    }

    #[test_log::test]
    fn test_tunnel_response_from_bytes_large_headers() {
        let mut headers = BTreeMap::new();
        for i in 0..50 {
            headers.insert(format!("header-{i}"), format!("value-{i}"));
        }

        let body = b"body";
        let bytes = create_binary_response(7777, 1, false, Some(200), Some(headers.clone()), body);

        let response = TunnelResponse::try_from(bytes).unwrap();

        assert_eq!(response.request_id, 7777);
        assert_eq!(response.headers, Some(headers));
        assert_eq!(response.bytes.as_ref(), body);
    }

    // Pins the current behavior: a frame shorter than the 13-byte header panics while slicing
    // instead of returning an error.
    #[test_log::test]
    #[should_panic(expected = "range start must not be greater than end")]
    fn test_tunnel_response_from_bytes_too_short() {
        let bytes = Bytes::from(vec![1, 2, 3, 4, 5]);
        let _response = TunnelResponse::try_from(bytes).unwrap();
    }

    #[test_log::test]
    fn test_tunnel_response_from_bytes_error_invalid_json_headers() {
        let mut data = Vec::new();
        // request_id
        data.extend_from_slice(&123_u64.to_be_bytes());
        // packet_id = 1, so a status code and headers follow
        data.extend_from_slice(&1_u32.to_be_bytes());
        // last = false
        data.push(0);
        // status code
        data.extend_from_slice(&200_u16.to_be_bytes());
        // headers length: 5 bytes
        data.extend_from_slice(&5_u32.to_be_bytes());
        // `{bad}` is not valid JSON
        data.extend_from_slice(b"{bad}");
        let bytes = Bytes::from(data);
        let result = TunnelResponse::try_from(bytes);

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), TryFromBytesError::Serde(_)));
    }

    #[cfg(feature = "base64")]
    #[test_log::test]
    fn test_tunnel_response_from_base64_missing_prefix() {
        let result = TunnelResponse::try_from("12345|1|0200{}|dGVzdA==");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            Base64DecodeError::InvalidContent(_)
        ));
    }

    #[cfg(feature = "base64")]
    #[test_log::test]
    fn test_tunnel_response_from_base64_missing_request_id_delimiter() {
        let invalid = format!("{BASE64_TUNNEL_RESPONSE_PREFIX}12345");
        let result = TunnelResponse::try_from(invalid.as_str());
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            Base64DecodeError::InvalidContent(_)
        ));
    }

    #[cfg(feature = "base64")]
    #[test_log::test]
    fn test_tunnel_response_from_base64_missing_packet_id_delimiter() {
        let invalid = format!("{BASE64_TUNNEL_RESPONSE_PREFIX}12345|1");
        let result = TunnelResponse::try_from(invalid.as_str());
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            Base64DecodeError::InvalidContent(_)
        ));
    }

    #[test_log::test]
    fn test_tunnel_request_http_serialization() {
        let request = TunnelRequest::Http(TunnelHttpRequest {
            request_id: 123,
            method: Method::Get,
            path: "/api/test".to_string(),
            query: serde_json::json!({"foo": "bar"}),
            payload: Some(serde_json::json!({"data": "value"})),
            headers: Some(serde_json::json!({"Authorization": "Bearer token"})),
            encoding: TunnelEncoding::Binary,
            profile: Some("test-profile".to_string()),
        });

        let json = serde_json::to_string(&request).unwrap();
        let deserialized: TunnelRequest = serde_json::from_str(&json).unwrap();

        match deserialized {
            TunnelRequest::Http(req) => {
                assert_eq!(req.request_id, 123);
                assert_eq!(req.method, Method::Get);
                assert_eq!(req.path, "/api/test");
                assert_eq!(req.encoding, TunnelEncoding::Binary);
            }
            _ => panic!("Expected HTTP request"),
        }
    }

    #[test_log::test]
    fn test_tunnel_request_ws_serialization() {
        let request = TunnelRequest::Ws(TunnelWsRequest {
            conn_id: 456,
            request_id: 789,
            body: serde_json::json!({"message": "hello"}),
            connection_id: Some(serde_json::json!(42)),
            profile: None,
        });

        let json = serde_json::to_string(&request).unwrap();
        let deserialized: TunnelRequest = serde_json::from_str(&json).unwrap();

        match deserialized {
            TunnelRequest::Ws(req) => {
                assert_eq!(req.conn_id, 456);
                assert_eq!(req.request_id, 789);
                assert_eq!(req.body, serde_json::json!({"message": "hello"}));
            }
            _ => panic!("Expected WS request"),
        }
    }

    #[test_log::test]
    fn test_tunnel_request_abort_serialization() {
        let request = TunnelRequest::Abort(TunnelAbortRequest { request_id: 999 });

        let json = serde_json::to_string(&request).unwrap();
        let deserialized: TunnelRequest = serde_json::from_str(&json).unwrap();

        match deserialized {
            TunnelRequest::Abort(req) => {
                assert_eq!(req.request_id, 999);
            }
            _ => panic!("Expected Abort request"),
        }
    }

    #[test_log::test]
    fn test_tunnel_ws_response_serialization() {
        let response = TunnelWsResponse {
            request_id: 123,
            body: serde_json::json!({"status": "ok"}),
            exclude_connection_ids: Some(vec![1, 2, 3]),
            to_connection_ids: Some(vec![4, 5, 6]),
        };

        let json = serde_json::to_string(&response).unwrap();
        let deserialized: TunnelWsResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.request_id, 123);
        assert_eq!(deserialized.exclude_connection_ids, Some(vec![1, 2, 3]));
        assert_eq!(deserialized.to_connection_ids, Some(vec![4, 5, 6]));
    }

    #[test_log::test]
    fn test_tunnel_ws_response_optional_fields_omitted() {
        let response = TunnelWsResponse {
            request_id: 456,
            body: serde_json::json!({"data": "test"}),
            exclude_connection_ids: None,
            to_connection_ids: None,
        };

        let json = serde_json::to_string(&response).unwrap();

        // `skip_serializing_if = "Option::is_none"` drops both keys entirely
        assert!(!json.contains("exclude_connection_ids"));
        assert!(!json.contains("to_connection_ids"));

        let deserialized: TunnelWsResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.request_id, 456);
        assert_eq!(deserialized.exclude_connection_ids, None);
        assert_eq!(deserialized.to_connection_ids, None);
    }

    #[test_log::test]
    fn test_tunnel_encoding_serialization() {
        let binary = TunnelEncoding::Binary;
        let json = serde_json::to_string(&binary).unwrap();
        assert_eq!(json, "\"BINARY\"");

        let deserialized: TunnelEncoding = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, TunnelEncoding::Binary);
    }

    #[cfg(feature = "base64")]
    #[test_log::test]
    fn test_tunnel_encoding_base64_serialization() {
        let base64 = TunnelEncoding::Base64;
        let json = serde_json::to_string(&base64).unwrap();
        assert_eq!(json, "\"BASE64\"");

        let deserialized: TunnelEncoding = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, TunnelEncoding::Base64);
    }

    #[test_log::test]
    fn test_tunnel_http_request_optional_fields() {
        let request = TunnelHttpRequest {
            request_id: 1,
            method: Method::Post,
            path: "/test".to_string(),
            query: serde_json::json!({}),
            payload: None,
            headers: None,
            encoding: TunnelEncoding::Binary,
            profile: None,
        };

        let json = serde_json::to_string(&request).unwrap();

        // `payload` and `headers` are skipped when `None`; `profile` is not skipped
        assert!(!json.contains("payload"));
        assert!(!json.contains("headers"));

        let deserialized: TunnelHttpRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.payload, None);
        assert_eq!(deserialized.headers, None);
        assert_eq!(deserialized.profile, None);
    }

    #[test_log::test]
    fn test_tunnel_request_tagged_enum_format() {
        let http_request = TunnelRequest::Http(TunnelHttpRequest {
            request_id: 1,
            method: Method::Get,
            path: "/".to_string(),
            query: serde_json::json!({}),
            payload: None,
            headers: None,
            encoding: TunnelEncoding::Binary,
            profile: None,
        });

        let json = serde_json::to_string(&http_request).unwrap();

        // Internally tagged: the variant name appears as a SCREAMING_SNAKE_CASE "type" field
        assert!(json.contains("\"type\":\"HTTP\""));

        let ws_request = TunnelRequest::Ws(TunnelWsRequest {
            conn_id: 1,
            request_id: 2,
            body: serde_json::json!({}),
            connection_id: None,
            profile: None,
        });

        let json = serde_json::to_string(&ws_request).unwrap();
        assert!(json.contains("\"type\":\"WS\""));

        let abort_request = TunnelRequest::Abort(TunnelAbortRequest { request_id: 3 });
        let json = serde_json::to_string(&abort_request).unwrap();
        assert!(json.contains("\"type\":\"ABORT\""));
    }

    /// Builds an already-decoded packet. Packet `1` gets status `200` and empty headers; all
    /// other packets get neither.
    fn create_tunnel_response(
        request_id: u64,
        packet_id: u32,
        last: bool,
        body: &[u8],
    ) -> TunnelResponse {
        TunnelResponse {
            request_id,
            packet_id,
            last,
            bytes: Bytes::from(body.to_vec()),
            status: if packet_id == 1 { Some(200) } else { None },
            headers: if packet_id == 1 {
                Some(BTreeMap::new())
            } else {
                None
            },
        }
    }

    /// `on_end` callback that does nothing.
    async fn noop_on_end(_request_id: u64) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_single_packet_first_and_last() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(123, rx, abort_token, &noop_on_end);

        tx.send(create_tunnel_response(123, 1, true, b"complete response"))
            .unwrap();

        // The single packet is both first and last
        let result = stream.next().await;
        assert!(result.is_some());
        let bytes = result.unwrap().unwrap();
        assert_eq!(bytes.as_ref(), b"complete response");

        // The stream is done after the last packet
        let result = stream.next().await;
        assert!(result.is_none());
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_in_order_packets() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(456, rx, abort_token, &noop_on_end);

        tx.send(create_tunnel_response(456, 1, false, b"packet1"))
            .unwrap();
        tx.send(create_tunnel_response(456, 2, false, b"packet2"))
            .unwrap();
        tx.send(create_tunnel_response(456, 3, true, b"packet3"))
            .unwrap();

        let bytes1 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes1.as_ref(), b"packet1");

        let bytes2 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes2.as_ref(), b"packet2");

        let bytes3 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes3.as_ref(), b"packet3");

        let result = stream.next().await;
        assert!(result.is_none());
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_out_of_order_packets() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(789, rx, abort_token, &noop_on_end);

        // Sent as 2, 3, 1 but must be yielded as 1, 2, 3
        tx.send(create_tunnel_response(789, 2, false, b"packet2"))
            .unwrap();
        tx.send(create_tunnel_response(789, 3, true, b"packet3"))
            .unwrap();
        tx.send(create_tunnel_response(789, 1, false, b"packet1"))
            .unwrap();

        let bytes1 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes1.as_ref(), b"packet1");

        let bytes2 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes2.as_ref(), b"packet2");

        let bytes3 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes3.as_ref(), b"packet3");

        let result = stream.next().await;
        assert!(result.is_none());
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_abort_cancellation() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(111, rx, abort_token.clone(), &noop_on_end);

        tx.send(create_tunnel_response(111, 1, false, b"packet1"))
            .unwrap();

        let bytes1 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes1.as_ref(), b"packet1");

        // Cancel mid-stream; the next poll reports the abort
        abort_token.cancel();

        let result = stream.next().await;
        assert!(result.is_some());
        let err = result.unwrap().unwrap_err();
        assert!(matches!(err, TunnelStreamError::Aborted));
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_end_of_stream_before_completion() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(222, rx, abort_token, &noop_on_end);

        tx.send(create_tunnel_response(222, 1, false, b"packet1"))
            .unwrap();

        let bytes1 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes1.as_ref(), b"packet1");

        // Close the channel before the packet flagged `last` arrives
        drop(tx);

        // The first poll after the channel closes ends the stream
        let result = stream.next().await;
        assert!(result.is_none());
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_queue_insertion_maintains_order() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(333, rx, abort_token, &noop_on_end);

        // Sent in fully reversed order to exercise sorted insertion into the queue
        tx.send(create_tunnel_response(333, 5, true, b"packet5"))
            .unwrap();
        tx.send(create_tunnel_response(333, 4, false, b"packet4"))
            .unwrap();
        tx.send(create_tunnel_response(333, 3, false, b"packet3"))
            .unwrap();
        tx.send(create_tunnel_response(333, 2, false, b"packet2"))
            .unwrap();
        tx.send(create_tunnel_response(333, 1, false, b"packet1"))
            .unwrap();

        for i in 1..=5 {
            let bytes = stream.next().await.unwrap().unwrap();
            assert_eq!(bytes.as_ref(), format!("packet{i}").as_bytes());
        }

        let result = stream.next().await;
        assert!(result.is_none());
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_empty_body_packets() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(444, rx, abort_token, &noop_on_end);

        // Empty bodies are still yielded as (empty) items
        tx.send(create_tunnel_response(444, 1, false, b"")).unwrap();
        tx.send(create_tunnel_response(444, 2, true, b"")).unwrap();

        let bytes1 = stream.next().await.unwrap().unwrap();
        assert!(bytes1.is_empty());

        let bytes2 = stream.next().await.unwrap().unwrap();
        assert!(bytes2.is_empty());

        let result = stream.next().await;
        assert!(result.is_none());
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_end_of_stream_with_queued_packets() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(555, rx, abort_token, &noop_on_end);

        // Packet 1 never arrives, so packets 2 and 3 stay queued
        tx.send(create_tunnel_response(555, 2, false, b"packet2"))
            .unwrap();
        tx.send(create_tunnel_response(555, 3, true, b"packet3"))
            .unwrap();

        drop(tx);

        // The poll that observes the closed channel ends the stream because the queued
        // packets are not the next expected one
        let result = stream.next().await;
        assert!(result.is_none());

        // Polling again after the channel closed reports EndOfStream
        let result = stream.next().await;
        assert!(result.is_some());
        let err = result.unwrap().unwrap_err();
        assert!(matches!(err, TunnelStreamError::EndOfStream));
    }

    #[test_log::test(switchy_async::test)]
    async fn test_tunnel_stream_processes_queued_packet_when_pending() {
        let (tx, rx) = switchy_async::sync::mpsc::unbounded();
        let abort_token = CancellationToken::new();

        let mut stream = TunnelStream::new(666, rx, abort_token, &noop_on_end);

        // Packet 2 arrives first and is queued
        tx.send(create_tunnel_response(666, 2, false, b"packet2"))
            .unwrap();

        // Packet 1 fills the gap
        tx.send(create_tunnel_response(666, 1, false, b"packet1"))
            .unwrap();

        // Packet 3 is the last one
        tx.send(create_tunnel_response(666, 3, true, b"packet3"))
            .unwrap();

        let bytes1 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes1.as_ref(), b"packet1");

        let bytes2 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes2.as_ref(), b"packet2");

        let bytes3 = stream.next().await.unwrap().unwrap();
        assert_eq!(bytes3.as_ref(), b"packet3");

        let result = stream.next().await;
        assert!(result.is_none());
    }
}