1use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9 io::AsyncWriteExt,
10 net::TcpStream,
11 time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use prometheus::IntCounter;
16
17use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
18use crate::engine::AsyncEngineContext;
19use crate::pipeline::network::{
20 ConnectionInfo, ResponseStreamPrologue, StreamSender,
21 codec::{TwoPartCodec, TwoPartMessage},
22 tcp::StreamType,
23};
24use anyhow::{Context, Result, anyhow as error}; #[allow(dead_code)]
/// Client-side handle for the TCP response-stream transport.
///
/// The connection logic in this file lives in associated functions that do not take
/// `self`; the struct itself only stores an identifier.
pub struct TcpClient {
    /// Identifier supplied through `new`, or a random UUID v4 string when built through
    /// `Default`. Nothing in this file reads it back.
    worker_id: String,
}
30
31impl Default for TcpClient {
32 fn default() -> Self {
33 TcpClient {
34 worker_id: uuid::Uuid::new_v4().to_string(),
35 }
36 }
37}
38
39impl TcpClient {
40 pub fn new(worker_id: String) -> Self {
41 TcpClient { worker_id }
42 }
43
44 async fn connect(address: &str) -> std::io::Result<TcpStream> {
45 let backoff = std::time::Duration::from_millis(200);
47 loop {
48 match TcpStream::connect(address).await {
49 Ok(socket) => {
50 socket.set_nodelay(true)?;
51 return Ok(socket);
52 }
53 Err(e) => {
54 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
55 tracing::warn!("retry warning: failed to connect: {:?}", e);
56 tokio::time::sleep(backoff).await;
57 } else {
58 return Err(e);
59 }
60 }
61 }
62 }
63 }
64
65 pub async fn create_response_stream(
66 context: Arc<dyn AsyncEngineContext>,
67 info: ConnectionInfo,
68 cancellation_counter: Option<IntCounter>,
69 ) -> Result<StreamSender> {
70 let info =
71 TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
72 tracing::trace!("Creating response stream for {:?}", info);
73
74 if info.stream_type != StreamType::Response {
75 return Err(error!(
76 "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
77 info.stream_type
78 ));
79 }
80
81 if info.context != context.id() {
82 return Err(error!(
83 "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
84 context.id(),
85 info.context
86 ));
87 }
88
89 let stream = TcpClient::connect(&info.address).await?;
90 let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
91 let (read_half, write_half) = tokio::io::split(stream);
92
93 let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
94 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
95
96 let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
102
103 let reader_task = tokio::spawn(handle_reader(
104 framed_reader,
105 context.clone(),
106 alive_tx,
107 cancellation_counter,
108 ));
109
110 let handshake = CallHomeHandshake {
112 subject: info.subject.clone(),
113 stream_type: StreamType::Response,
114 };
115
116 let handshake_bytes = match serde_json::to_vec(&handshake) {
117 Ok(hb) => hb,
118 Err(err) => {
119 return Err(error!(
120 "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
121 ));
122 }
123 };
124 let msg = TwoPartMessage::from_header(handshake_bytes.into());
125
126 framed_writer
128 .send(msg)
129 .await
130 .map_err(|e| error!("failed to send handshake: {:?}", e))?;
131
132 let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
134
135 let writer_context = context.clone();
137 let writer_task = tokio::spawn(handle_writer(
138 framed_writer,
139 bytes_rx,
140 alive_rx,
141 writer_context,
142 ));
143
144 let subject = info.subject.clone();
145 let monitor_context = context;
146 tokio::spawn(async move {
149 let _ = wait_for_connection_tasks(
150 reader_task,
151 writer_task,
152 monitor_context,
153 peer_port,
154 subject,
155 )
156 .await;
157 });
158
159 let prologue = Some(ResponseStreamPrologue { error: None });
162
163 let stream_sender = StreamSender {
165 tx: bytes_tx,
166 prologue,
167 };
168
169 Ok(stream_sender)
170 }
171}
172
173async fn wait_for_connection_tasks(
174 reader_task: tokio::task::JoinHandle<FramedRead<ReadHalf<TcpStream>, TwoPartCodec>>,
175 writer_task: tokio::task::JoinHandle<Result<FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>>>,
176 context: Arc<dyn AsyncEngineContext>,
177 peer_port: Option<u16>,
178 subject: String,
179) -> Result<()> {
180 let reader = match reader_task.await {
183 Ok(reader) => reader,
184 Err(reader_err) => {
185 writer_task.abort();
186 let _ = writer_task.await;
187 tracing::error!(
188 subject = %subject,
189 peer_port = ?peer_port,
190 err = ?reader_err,
191 "reader task failed to join"
192 );
193 return Err(reader_err.into());
194 }
195 };
196
197 let writer = match writer_task.await {
198 Ok(writer) => writer,
199 Err(writer_err) => {
200 tracing::error!(
201 subject = %subject,
202 peer_port = ?peer_port,
203 err = ?writer_err,
204 "writer task failed to join"
205 );
206 return Err(writer_err.into());
207 }
208 };
209
210 let reader = reader.into_inner();
211 let writer = match writer {
212 Ok(writer) => writer.into_inner(),
213 Err(e) => {
214 tracing::error!(
215 subject = %subject,
216 peer_port = ?peer_port,
217 err = ?e,
218 "writer task returned error"
219 );
220 return Err(e);
221 }
222 };
223
224 let stream = reader.unsplit(writer);
225 wait_for_server_shutdown(stream, context).await
226}
227
228async fn wait_for_server_shutdown(
229 mut stream: TcpStream,
230 context: Arc<dyn AsyncEngineContext>,
231) -> Result<()> {
232 if context.is_killed() || context.is_stopped() {
236 tracing::debug!("stream context killed or stopped; skipping server FIN wait");
237 return Ok(());
238 }
239
240 let mut buf = [0u8; 1024];
243 let deadline = Instant::now() + Duration::from_secs(10);
244 loop {
245 let n = time::timeout_at(deadline, stream.read(&mut buf))
246 .await
247 .inspect_err(|_| {
248 tracing::debug!("server did not close socket within the deadline");
249 })?
250 .inspect_err(|e| {
251 tracing::debug!(err = ?e, "failed to read from stream");
252 })?;
253 if n == 0 {
254 break;
256 }
257 }
258
259 Ok(())
260}
261
262async fn handle_reader(
263 framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
264 context: Arc<dyn AsyncEngineContext>,
265 alive_tx: tokio::sync::oneshot::Sender<()>,
266 cancellation_counter: Option<IntCounter>,
267) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
268 let mut framed_reader = framed_reader;
269 let mut alive_tx = alive_tx;
270 let mut cancellation_seen = false;
272 loop {
273 tokio::select! {
274 msg = framed_reader.next() => {
275 match msg {
276 Some(Ok(two_part_msg)) => {
277 match two_part_msg.optional_parts() {
278 (Some(bytes), None) => {
279 let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
280 Ok(msg) => msg,
281 Err(e) => {
282 tracing::warn!(
283 err = ?e,
284 "invalid control message, closing connection"
285 );
286 cancellation_seen = true;
287 context.kill();
288 break;
289 }
290 };
291
292 match msg {
299 ControlMessage::Stop => {
300 cancellation_seen = true;
301 context.stop();
302 }
303 ControlMessage::Kill => {
304 cancellation_seen = true;
305 context.kill();
306 }
307 ControlMessage::Sentinel => {
308 tracing::warn!(
309 "unexpected sentinel on client reader, closing connection"
310 );
311 cancellation_seen = true;
312 context.kill();
313 break;
314 }
315 }
316 }
317 _ => {
318 tracing::warn!(
319 "unexpected non-control message on client reader, closing connection"
320 );
321 cancellation_seen = true;
322 context.kill();
323 break;
324 }
325 }
326 }
327 Some(Err(e)) => {
328 tracing::warn!(err = ?e, "tcp stream read error, closing connection");
331 cancellation_seen = true;
332 context.kill();
333 break;
334 }
335 None => {
336 tracing::debug!("tcp stream closed by server");
337 cancellation_seen = true;
338 break;
339 }
340 }
341 }
342 _ = alive_tx.closed() => {
343 break;
344 }
345 }
346 }
347 if cancellation_seen && let Some(counter) = &cancellation_counter {
348 counter.inc();
349 }
350 framed_reader
351}
352
353async fn handle_writer(
354 mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
355 mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
356 alive_rx: tokio::sync::oneshot::Receiver<()>,
357 context: Arc<dyn AsyncEngineContext>,
358) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
359 let mut send_sentinel = true;
361
362 loop {
363 let msg = tokio::select! {
364 biased;
365
366 _ = context.killed() => {
367 tracing::trace!("context kill signal received; shutting down");
368 send_sentinel = false;
369 break;
370 }
371
372 _ = context.stopped() => {
373 tracing::trace!("context stop signal received; shutting down");
374 send_sentinel = false;
375 break;
376 }
377
378 msg = bytes_rx.recv() => {
379 match msg {
380 Some(msg) => msg,
381 None => {
382 tracing::trace!("response channel closed; shutting down");
383 break;
384 }
385 }
386 }
387 };
388
389 if let Err(e) = framed_writer.send(msg).await {
390 tracing::trace!(
391 "failed to send message to network; possible disconnect: {:?}",
392 e
393 );
394 send_sentinel = false;
395 break;
396 }
397 }
398
399 if send_sentinel {
401 let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
402 let msg = TwoPartMessage::from_header(message.into());
403 framed_writer.send(msg).await?;
404 }
405
406 drop(alive_rx);
407 Ok(framed_writer)
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::context::Controller;
    use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
    use bytes::Bytes;
    use futures::StreamExt;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::sync::{mpsc, oneshot};
    use tokio_util::codec::FramedRead;

    /// Everything a `handle_writer` test needs: the peer socket, the framed write half of
    /// the client socket, both ends of the message channel and alive channel, and a
    /// controller to act as the context.
    struct WriterHarness {
        server: TcpStream,
        framed_writer: FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>,
        bytes_tx: mpsc::Sender<TwoPartMessage>,
        bytes_rx: mpsc::Receiver<TwoPartMessage>,
        alive_tx: oneshot::Sender<()>,
        alive_rx: oneshot::Receiver<()>,
        controller: Arc<Controller>,
    }

    /// Builds a connected socket pair and the channels used by the writer tests.
    /// The client's read half is discarded immediately.
    async fn writer_harness() -> WriterHarness {
        let (client, server) = create_tcp_pair().await;
        let (_, write_half) = tokio::io::split(client);

        let (bytes_tx, bytes_rx) = mpsc::channel(64);
        let (alive_tx, alive_rx) = oneshot::channel::<()>();

        WriterHarness {
            server,
            framed_writer: FramedWrite::new(write_half, TwoPartCodec::default()),
            bytes_tx,
            bytes_rx,
            alive_tx,
            alive_rx,
            controller: Arc::new(Controller::default()),
        }
    }

    /// Pulls the next decoded frame from `reader`, panicking on end of stream or on a
    /// decode error.
    async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
        let next = reader.next().await;
        next.expect("expected message")
            .expect("failed to decode message")
    }

    /// True when `needle` occurs as a contiguous run inside `haystack`.
    fn contains_subslice(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    /// Asserts that `msg` has no header and that its data equals `expected`.
    fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
        let (head, body) = msg.optional_parts();
        assert!(head.is_none(), "data-only message should not have header");
        let body = body.expect("data payload missing");
        assert_eq!(body.as_ref(), expected, "data payload should match");
    }

    /// Asserts that `msg` has no data and that its header equals `expected`.
    fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
        let (head, body) = msg.optional_parts();
        assert!(body.is_none(), "header-only message should not carry data");
        let head = head.expect("header missing");
        assert_eq!(head.as_ref(), expected, "header payload should match");
    }

    /// Asserts that `msg` carries both parts and that each equals its expected bytes.
    fn assert_header_and_data_message(
        msg: TwoPartMessage,
        expected_header: &[u8],
        expected_data: &[u8],
    ) {
        let (head, body) = msg.optional_parts();
        let head = head.expect("header missing");
        assert_eq!(head.as_ref(), expected_header, "header payload should match");
        let body = body.expect("data missing");
        assert_eq!(body.as_ref(), expected_data, "data payload should match");
    }

    /// Asserts that `msg` is a header-only frame holding the serialized sentinel.
    fn assert_sentinel_message(msg: TwoPartMessage) {
        let (head, body) = msg.optional_parts();
        assert!(body.is_none(), "sentinel should not include a data section");
        let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        let head = head.expect("sentinel header missing");
        assert_eq!(
            head.as_ref(),
            expected_sentinel.as_slice(),
            "sentinel header should match serialized ControlMessage::Sentinel"
        );
    }

    /// One queued data message is delivered, followed by the sentinel.
    #[tokio::test]
    async fn test_handle_writer_forwards_messages() {
        let WriterHarness {
            server: peer,
            framed_writer: sink,
            bytes_tx: outbound_tx,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        outbound_tx
            .send(TwoPartMessage::from_data(Bytes::from("test data")))
            .await
            .unwrap();
        // Closing the channel lets the writer drain and finish.
        drop(outbound_tx);

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        let mut frames = FramedRead::new(peer, TwoPartCodec::default());
        assert_data_only_message(recv_msg(&mut frames).await, b"test data");
        assert_sentinel_message(recv_msg(&mut frames).await);
    }

    /// Closing the channel with nothing queued still produces a sentinel on the wire.
    #[tokio::test]
    async fn test_handle_writer_sends_sentinel_on_normal_closure() {
        let WriterHarness {
            server: mut peer,
            framed_writer: sink,
            bytes_tx: outbound_tx,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        drop(outbound_tx);

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        let mut scratch = vec![0u8; 1024];
        let n = peer.read(&mut scratch).await.unwrap();
        assert!(n > 0, "Expected sentinel to be written to the TCP stream");

        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert!(
            contains_subslice(&scratch[..n], sentinel_json.as_slice()),
            "Buffer should contain sentinel message. Buffer: {:?}",
            String::from_utf8_lossy(&scratch[..n])
        );
    }

    /// A context killed before the writer starts suppresses the sentinel.
    #[tokio::test]
    async fn test_handle_writer_no_sentinel_on_context_killed() {
        let WriterHarness {
            server: mut peer,
            framed_writer: sink,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        controller.kill();

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        // Releasing the returned writer lets the peer's read complete.
        drop(outcome);

        let mut scratch = vec![0u8; 1024];
        let n = peer.read(&mut scratch).await.unwrap();

        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert!(
            n == 0 || !contains_subslice(&scratch[..n], sentinel_json.as_slice()),
            "Buffer should NOT contain sentinel message when context is killed"
        );
    }

    /// A context stopped before the writer starts suppresses the sentinel.
    #[tokio::test]
    async fn test_handle_writer_no_sentinel_on_context_stopped() {
        let WriterHarness {
            server: mut peer,
            framed_writer: sink,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        controller.stop();

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        // Releasing the returned writer lets the peer's read complete.
        drop(outcome);

        let mut scratch = vec![0u8; 1024];
        let n = peer.read(&mut scratch).await.unwrap();

        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert!(
            n == 0 || !contains_subslice(&scratch[..n], sentinel_json.as_slice()),
            "Buffer should NOT contain sentinel message when context is stopped"
        );
    }

    /// Five queued messages arrive in order, followed by the sentinel.
    #[tokio::test]
    async fn test_handle_writer_multiple_messages() {
        let WriterHarness {
            server: peer,
            framed_writer: sink,
            bytes_tx: outbound_tx,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        for i in 0..5 {
            let payload = Bytes::from(format!("message {}", i));
            outbound_tx
                .send(TwoPartMessage::from_data(payload))
                .await
                .unwrap();
        }
        drop(outbound_tx);

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        let mut frames = FramedRead::new(peer, TwoPartCodec::default());
        for i in 0..5 {
            let expected = format!("message {}", i);
            assert_data_only_message(recv_msg(&mut frames).await, expected.as_bytes());
        }
        assert_sentinel_message(recv_msg(&mut frames).await);
    }

    /// After the writer returns, the sender side of the alive channel reports closed.
    #[tokio::test]
    async fn test_handle_writer_drops_alive_rx() {
        let WriterHarness {
            framed_writer: sink,
            bytes_tx: outbound_tx,
            bytes_rx: outbound_rx,
            alive_tx: liveness_tx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        drop(outbound_tx);

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        assert!(liveness_tx.is_closed());
    }

    /// A header-only message is forwarded unchanged, followed by the sentinel.
    #[tokio::test]
    async fn test_handle_writer_header_only_messages() {
        let WriterHarness {
            server: peer,
            framed_writer: sink,
            bytes_tx: outbound_tx,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        outbound_tx
            .send(TwoPartMessage::from_header(Bytes::from("header content")))
            .await
            .unwrap();
        drop(outbound_tx);

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        let mut frames = FramedRead::new(peer, TwoPartCodec::default());
        assert_header_only_message(recv_msg(&mut frames).await, b"header content");
        assert_sentinel_message(recv_msg(&mut frames).await);
    }

    /// Header-only, data-only and two-part messages keep their shape and order.
    #[tokio::test]
    async fn test_handle_writer_mixed_messages() {
        let WriterHarness {
            server: peer,
            framed_writer: sink,
            bytes_tx: outbound_tx,
            bytes_rx: outbound_rx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = writer_harness().await;

        let header_only = TwoPartMessage::from_header(Bytes::from("header1"));
        outbound_tx.send(header_only).await.unwrap();

        let data_only = TwoPartMessage::from_data(Bytes::from("data1"));
        outbound_tx.send(data_only).await.unwrap();

        let both_parts = TwoPartMessage::from_parts(Bytes::from("header2"), Bytes::from("data2"));
        outbound_tx.send(both_parts).await.unwrap();

        drop(outbound_tx);

        let outcome = handle_writer(sink, outbound_rx, liveness_rx, controller).await;
        assert!(outcome.is_ok());

        let mut frames = FramedRead::new(peer, TwoPartCodec::default());
        assert_header_only_message(recv_msg(&mut frames).await, b"header1");
        assert_data_only_message(recv_msg(&mut frames).await, b"data1");
        assert_header_and_data_message(recv_msg(&mut frames).await, b"header2", b"data2");
        assert_sentinel_message(recv_msg(&mut frames).await);
    }

    /// With a killed or stopped context the shutdown wait returns right away, even
    /// though the peer socket is still open.
    #[tokio::test]
    async fn test_wait_for_server_shutdown_skips_terminal_context() {
        for terminate in [Controller::kill as fn(&Controller), Controller::stop] {
            let (client, _server) = create_tcp_pair().await;
            let controller = Arc::new(Controller::default());
            terminate(&controller);

            let context: Arc<dyn AsyncEngineContext> = controller;
            let waited = tokio::time::timeout(
                std::time::Duration::from_millis(50),
                wait_for_server_shutdown(client, context),
            )
            .await;

            assert!(waited.is_ok(), "terminal context should not wait for FIN");
            assert!(
                waited.unwrap().is_ok(),
                "terminal context shutdown should succeed"
            );
        }
    }

    /// Undecodable bytes from the peer kill the context, and the monitor then finishes
    /// well inside the timeout instead of waiting for the peer to close.
    #[tokio::test]
    async fn test_connection_monitor_skips_fin_wait_after_read_error_kills_context() {
        let (client, mut server) = create_tcp_pair().await;
        let (read_half, write_half) = tokio::io::split(client);
        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
        let (_bytes_tx, bytes_rx) = mpsc::channel(64);
        let (alive_tx, alive_rx) = oneshot::channel::<()>();
        let controller = Arc::new(Controller::default());

        let reader_task = tokio::spawn({
            let ctx = controller.clone();
            async move { handle_reader(framed_reader, ctx, alive_tx, None).await }
        });
        let writer_task = tokio::spawn({
            let ctx = controller.clone();
            async move { handle_writer(framed_writer, bytes_rx, alive_rx, ctx).await }
        });

        server.write_all(&[0xFF; 24]).await.unwrap();

        let monitor_context: Arc<dyn AsyncEngineContext> = controller.clone();
        let monitored = tokio::time::timeout(
            std::time::Duration::from_millis(250),
            wait_for_connection_tasks(
                reader_task,
                writer_task,
                monitor_context,
                None,
                "test-subject".to_string(),
            ),
        )
        .await;

        assert!(
            monitored.is_ok(),
            "connection monitor should not wait for the FIN deadline after read error"
        );
        assert!(
            monitored.unwrap().is_ok(),
            "connection monitor should succeed"
        );
        assert!(
            controller.is_killed(),
            "read error should kill the stream context"
        );
    }

    /// A panicking reader task makes the monitor abort a never-finishing writer and
    /// return an error rather than hang.
    #[tokio::test]
    async fn test_connection_monitor_aborts_writer_when_reader_panics() {
        let panicking_reader: tokio::task::JoinHandle<
            FramedRead<ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
        > = tokio::spawn(async {
            panic!("simulated reader panic to trigger JoinError");
        });

        let stuck_writer: tokio::task::JoinHandle<
            Result<FramedWrite<WriteHalf<tokio::net::TcpStream>, TwoPartCodec>>,
        > = tokio::spawn(async {
            std::future::pending::<()>().await;
            unreachable!()
        });

        let controller = Arc::new(Controller::default());
        let context: Arc<dyn AsyncEngineContext> = controller.clone();

        let monitored = tokio::time::timeout(
            std::time::Duration::from_millis(250),
            wait_for_connection_tasks(
                panicking_reader,
                stuck_writer,
                context,
                None,
                "test-reader-panic".to_string(),
            ),
        )
        .await;

        assert!(
            monitored.is_ok(),
            "wait_for_connection_tasks must return after reader panic, \
             not hang waiting on the writer"
        );

        assert!(
            monitored.unwrap().is_err(),
            "reader panic should propagate as Err from wait_for_connection_tasks"
        );
    }

    /// Everything a `handle_reader` test needs: a framed writer on the server side, the
    /// framed read half of the client socket, both ends of the alive channel, and a
    /// controller to act as the context.
    struct ReaderHarness {
        framed_server: FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>,
        framed_reader: FramedRead<ReadHalf<TcpStream>, TwoPartCodec>,
        alive_tx: oneshot::Sender<()>,
        alive_rx: oneshot::Receiver<()>,
        controller: Arc<Controller>,
    }

    /// Builds a connected socket pair and the channel used by the reader tests. The
    /// client's write half and the server's read half are dropped when this returns.
    async fn reader_harness() -> ReaderHarness {
        let (client, server) = create_tcp_pair().await;
        let (read_half, _write_half) = tokio::io::split(client);
        let (_server_read, server_write) = tokio::io::split(server);

        let (alive_tx, alive_rx) = oneshot::channel::<()>();

        ReaderHarness {
            framed_server: FramedWrite::new(server_write, TwoPartCodec::default()),
            framed_reader: FramedRead::new(read_half, TwoPartCodec::default()),
            alive_tx,
            alive_rx,
            controller: Arc::new(Controller::default()),
        }
    }

    /// Wraps the JSON encoding of `msg` in a header-only frame.
    fn control_message(msg: &ControlMessage) -> TwoPartMessage {
        let encoded = serde_json::to_vec(msg).unwrap();
        TwoPartMessage::from_header(Bytes::from(encoded))
    }

    /// A Stop control frame stops the controller.
    #[tokio::test]
    async fn test_handle_reader_stop_control_message() {
        let ReaderHarness {
            framed_server: mut server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let reader_handle = tokio::spawn({
            let ctx = controller.clone();
            async move { handle_reader(source, ctx, liveness_tx, None).await }
        });

        let stop = control_message(&ControlMessage::Stop);
        server_sink.send(stop).await.unwrap();

        server_sink.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_stopped(),
            "Controller should be stopped after receiving Stop message"
        );
    }

    /// A Kill control frame kills the controller.
    #[tokio::test]
    async fn test_handle_reader_kill_control_message() {
        let ReaderHarness {
            framed_server: mut server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let reader_handle = tokio::spawn({
            let ctx = controller.clone();
            async move { handle_reader(source, ctx, liveness_tx, None).await }
        });

        let kill = control_message(&ControlMessage::Kill);
        server_sink.send(kill).await.unwrap();

        server_sink.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_killed(),
            "Controller should be killed after receiving Kill message"
        );
    }

    /// Dropping the alive receiver lets the reader task run to completion.
    #[tokio::test]
    async fn test_handle_reader_exits_on_alive_channel_closed() {
        let ReaderHarness {
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: liveness_rx,
            controller,
            ..
        } = reader_harness().await;

        let reader_handle = tokio::spawn(async move {
            handle_reader(source, controller, liveness_tx, None).await
        });

        drop(liveness_rx);

        let joined = reader_handle.await;

        assert!(
            joined.is_ok(),
            "handle_reader should exit when alive channel is closed"
        );
    }

    /// Closing the server side ends the reader task within one second.
    #[tokio::test]
    async fn test_handle_reader_exits_on_stream_closed() {
        let ReaderHarness {
            framed_server: mut server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let reader_handle = tokio::spawn(async move {
            handle_reader(source, controller, liveness_tx, None).await
        });

        server_sink.close().await.unwrap();

        let finished =
            tokio::time::timeout(std::time::Duration::from_secs(1), reader_handle).await;

        assert!(
            finished.is_ok(),
            "handle_reader should exit when stream is closed"
        );
    }

    /// Two consecutive Stop frames leave the controller stopped.
    #[tokio::test]
    async fn test_handle_reader_multiple_control_messages() {
        let ReaderHarness {
            framed_server: mut server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let reader_handle = tokio::spawn({
            let ctx = controller.clone();
            async move { handle_reader(source, ctx, liveness_tx, None).await }
        });

        let first_stop = control_message(&ControlMessage::Stop);
        server_sink.send(first_stop).await.unwrap();
        let second_stop = control_message(&ControlMessage::Stop);
        server_sink.send(second_stop).await.unwrap();

        server_sink.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_stopped(),
            "Controller should be stopped after receiving Stop messages"
        );
    }

    /// A Stop frame followed by a Kill frame leaves the controller killed.
    #[tokio::test]
    async fn test_handle_reader_stop_then_kill() {
        let ReaderHarness {
            framed_server: mut server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let reader_handle = tokio::spawn({
            let ctx = controller.clone();
            async move { handle_reader(source, ctx, liveness_tx, None).await }
        });

        let stop = control_message(&ControlMessage::Stop);
        server_sink.send(stop).await.unwrap();
        let kill = control_message(&ControlMessage::Kill);
        server_sink.send(kill).await.unwrap();

        server_sink.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_killed(),
            "Controller should be killed after receiving Kill message"
        );
    }

    /// Raw bytes followed by a shutdown must kill the controller and bump the
    /// cancellation counter exactly once.
    #[tokio::test]
    async fn test_handle_reader_increments_cancellation_counter_on_read_error() {
        let ReaderHarness {
            framed_server: server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;
        let cancellation_counter = IntCounter::new(
            "tcp_client_reader_read_error_cancellations_test",
            "test cancellation counter",
        )
        .unwrap();

        let reader_handle = tokio::spawn({
            let ctx = controller.clone();
            let counter = cancellation_counter.clone();
            async move { handle_reader(source, ctx, liveness_tx, Some(counter)).await }
        });

        // Bypass the codec and write directly to the socket half.
        let mut raw_writer = server_sink.into_inner();
        raw_writer.write_all(&[0u8; 8]).await.unwrap();
        raw_writer.shutdown().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_killed(),
            "Controller should be killed after TCP stream read error"
        );
        assert_eq!(
            cancellation_counter.get(),
            1,
            "read-error close should increment cancellation metric once"
        );
    }

    /// Runs `handle_reader` against a server that sends exactly `msg`, waits for the
    /// reader to finish, and returns the controller plus a counter named `counter_name`
    /// that was passed to the reader.
    async fn run_reader_with(
        msg: TwoPartMessage,
        counter_name: &str,
    ) -> (Arc<Controller>, IntCounter) {
        let ReaderHarness {
            framed_server: mut server_sink,
            framed_reader: source,
            alive_tx: liveness_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;
        let counter = IntCounter::new(counter_name, "test counter").unwrap();

        let reader_handle = tokio::spawn({
            let ctx = controller.clone();
            let task_counter = counter.clone();
            async move { handle_reader(source, ctx, liveness_tx, Some(task_counter)).await }
        });

        server_sink.send(msg).await.unwrap();
        let _ = reader_handle.await.unwrap();

        (controller, counter)
    }

    /// Each kind of protocol violation kills the context and is counted once.
    #[tokio::test]
    async fn test_handle_reader_kills_on_protocol_violations() {
        let cases: Vec<(&str, TwoPartMessage)> = vec![
            (
                "invalid control bytes",
                TwoPartMessage::from_header(Bytes::from_static(b"not a valid control message")),
            ),
            (
                "sentinel from server",
                control_message(&ControlMessage::Sentinel),
            ),
            (
                "non-control (data-only)",
                TwoPartMessage::from_data(Bytes::from_static(b"unexpected payload")),
            ),
        ];

        for (i, (label, violation)) in cases.into_iter().enumerate() {
            // Each case gets its own counter name.
            let counter_name = format!("tcp_client_reader_protocol_violation_test_{i}");
            let (controller, counter) = run_reader_with(violation, &counter_name).await;
            assert!(
                controller.is_killed(),
                "{label}: should kill stream context"
            );
            assert_eq!(counter.get(), 1, "{label}: should be counted once");
        }
    }
}
1295}