1use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9 io::AsyncWriteExt,
10 net::TcpStream,
11 time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use prometheus::IntCounter;
16
17use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
18use crate::engine::AsyncEngineContext;
19use crate::pipeline::network::{
20 ConnectionInfo, ResponseStreamPrologue, StreamSender,
21 codec::{TwoPartCodec, TwoPartMessage},
22 tcp::StreamType,
23};
24use anyhow::{Context, Result, anyhow as error}; #[allow(dead_code)]
27pub struct TcpClient {
28 worker_id: String,
29}
30
31impl Default for TcpClient {
32 fn default() -> Self {
33 TcpClient {
34 worker_id: uuid::Uuid::new_v4().to_string(),
35 }
36 }
37}
38
39impl TcpClient {
40 pub fn new(worker_id: String) -> Self {
41 TcpClient { worker_id }
42 }
43
44 async fn connect(address: &str) -> std::io::Result<TcpStream> {
45 let backoff = std::time::Duration::from_millis(200);
47 loop {
48 match TcpStream::connect(address).await {
49 Ok(socket) => {
50 socket.set_nodelay(true)?;
51 return Ok(socket);
52 }
53 Err(e) => {
54 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
55 tracing::warn!("retry warning: failed to connect: {:?}", e);
56 tokio::time::sleep(backoff).await;
57 } else {
58 return Err(e);
59 }
60 }
61 }
62 }
63 }
64
65 pub async fn create_response_stream(
66 context: Arc<dyn AsyncEngineContext>,
67 info: ConnectionInfo,
68 cancellation_counter: Option<IntCounter>,
69 ) -> Result<StreamSender> {
70 let info =
71 TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
72 tracing::trace!("Creating response stream for {:?}", info);
73
74 if info.stream_type != StreamType::Response {
75 return Err(error!(
76 "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
77 info.stream_type
78 ));
79 }
80
81 if info.context != context.id() {
82 return Err(error!(
83 "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
84 context.id(),
85 info.context
86 ));
87 }
88
89 let stream = TcpClient::connect(&info.address).await?;
90 let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
91 let (read_half, write_half) = tokio::io::split(stream);
92
93 let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
94 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
95
96 let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
102
103 let reader_task = tokio::spawn(handle_reader(
104 framed_reader,
105 context.clone(),
106 alive_tx,
107 cancellation_counter,
108 ));
109
110 let handshake = CallHomeHandshake {
112 subject: info.subject.clone(),
113 stream_type: StreamType::Response,
114 };
115
116 let handshake_bytes = match serde_json::to_vec(&handshake) {
117 Ok(hb) => hb,
118 Err(err) => {
119 return Err(error!(
120 "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
121 ));
122 }
123 };
124 let msg = TwoPartMessage::from_header(handshake_bytes.into());
125
126 framed_writer
128 .send(msg)
129 .await
130 .map_err(|e| error!("failed to send handshake: {:?}", e))?;
131
132 let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
134
135 let writer_context = context.clone();
137 let writer_task = tokio::spawn(handle_writer(
138 framed_writer,
139 bytes_rx,
140 alive_rx,
141 writer_context,
142 ));
143
144 let subject = info.subject.clone();
145 let monitor_context = context;
146 tokio::spawn(async move {
149 let _ = wait_for_connection_tasks(
150 reader_task,
151 writer_task,
152 monitor_context,
153 peer_port,
154 subject,
155 )
156 .await;
157 });
158
159 let prologue = Some(ResponseStreamPrologue { error: None });
162
163 let stream_sender = StreamSender {
165 tx: bytes_tx,
166 prologue,
167 };
168
169 Ok(stream_sender)
170 }
171}
172
173async fn wait_for_connection_tasks(
174 reader_task: tokio::task::JoinHandle<FramedRead<ReadHalf<TcpStream>, TwoPartCodec>>,
175 writer_task: tokio::task::JoinHandle<Result<FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>>>,
176 context: Arc<dyn AsyncEngineContext>,
177 peer_port: Option<u16>,
178 subject: String,
179) -> Result<()> {
180 let (reader, writer) = tokio::join!(reader_task, writer_task);
181
182 match (reader, writer) {
183 (Ok(reader), Ok(writer)) => {
184 let reader = reader.into_inner();
185
186 let writer = match writer {
187 Ok(writer) => writer.into_inner(),
188 Err(e) => {
189 tracing::error!(
190 subject = %subject,
191 peer_port = ?peer_port,
192 err = ?e,
193 "writer task returned error"
194 );
195 return Err(e);
196 }
197 };
198
199 let stream = reader.unsplit(writer);
200 wait_for_server_shutdown(stream, context).await
201 }
202 (Err(reader_err), Ok(_)) => {
203 tracing::error!(
204 subject = %subject,
205 peer_port = ?peer_port,
206 err = ?reader_err,
207 "reader task failed to join"
208 );
209 Err(reader_err.into())
210 }
211 (Ok(_), Err(writer_err)) => {
212 tracing::error!(
213 subject = %subject,
214 peer_port = ?peer_port,
215 err = ?writer_err,
216 "writer task failed to join"
217 );
218 Err(writer_err.into())
219 }
220 (Err(reader_err), Err(writer_err)) => {
221 tracing::error!(
222 subject = %subject,
223 peer_port = ?peer_port,
224 reader_err = ?reader_err,
225 writer_err = ?writer_err,
226 "both reader and writer tasks failed to join"
227 );
228 Err(reader_err.into())
230 }
231 }
232}
233
234async fn wait_for_server_shutdown(
235 mut stream: TcpStream,
236 context: Arc<dyn AsyncEngineContext>,
237) -> Result<()> {
238 if context.is_killed() || context.is_stopped() {
242 tracing::debug!("stream context killed or stopped; skipping server FIN wait");
243 return Ok(());
244 }
245
246 let mut buf = [0u8; 1024];
249 let deadline = Instant::now() + Duration::from_secs(10);
250 loop {
251 let n = time::timeout_at(deadline, stream.read(&mut buf))
252 .await
253 .inspect_err(|_| {
254 tracing::debug!("server did not close socket within the deadline");
255 })?
256 .inspect_err(|e| {
257 tracing::debug!(err = ?e, "failed to read from stream");
258 })?;
259 if n == 0 {
260 break;
262 }
263 }
264
265 Ok(())
266}
267
268async fn handle_reader(
269 framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
270 context: Arc<dyn AsyncEngineContext>,
271 alive_tx: tokio::sync::oneshot::Sender<()>,
272 cancellation_counter: Option<IntCounter>,
273) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
274 let mut framed_reader = framed_reader;
275 let mut alive_tx = alive_tx;
276 let mut cancellation_seen = false;
278 loop {
279 tokio::select! {
280 msg = framed_reader.next() => {
281 match msg {
282 Some(Ok(two_part_msg)) => {
283 match two_part_msg.optional_parts() {
284 (Some(bytes), None) => {
285 let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
286 Ok(msg) => msg,
287 Err(e) => {
288 tracing::warn!(
289 err = ?e,
290 "invalid control message, closing connection"
291 );
292 cancellation_seen = true;
293 context.kill();
294 break;
295 }
296 };
297
298 match msg {
305 ControlMessage::Stop => {
306 cancellation_seen = true;
307 context.stop();
308 }
309 ControlMessage::Kill => {
310 cancellation_seen = true;
311 context.kill();
312 }
313 ControlMessage::Sentinel => {
314 tracing::warn!(
315 "unexpected sentinel on client reader, closing connection"
316 );
317 cancellation_seen = true;
318 context.kill();
319 break;
320 }
321 }
322 }
323 _ => {
324 tracing::warn!(
325 "unexpected non-control message on client reader, closing connection"
326 );
327 cancellation_seen = true;
328 context.kill();
329 break;
330 }
331 }
332 }
333 Some(Err(e)) => {
334 tracing::warn!(err = ?e, "tcp stream read error, closing connection");
337 cancellation_seen = true;
338 context.kill();
339 break;
340 }
341 None => {
342 tracing::debug!("tcp stream closed by server");
343 cancellation_seen = true;
344 break;
345 }
346 }
347 }
348 _ = alive_tx.closed() => {
349 break;
350 }
351 }
352 }
353 if cancellation_seen && let Some(counter) = &cancellation_counter {
354 counter.inc();
355 }
356 framed_reader
357}
358
359async fn handle_writer(
360 mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
361 mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
362 alive_rx: tokio::sync::oneshot::Receiver<()>,
363 context: Arc<dyn AsyncEngineContext>,
364) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
365 let mut send_sentinel = true;
367
368 loop {
369 let msg = tokio::select! {
370 biased;
371
372 _ = context.killed() => {
373 tracing::trace!("context kill signal received; shutting down");
374 send_sentinel = false;
375 break;
376 }
377
378 _ = context.stopped() => {
379 tracing::trace!("context stop signal received; shutting down");
380 send_sentinel = false;
381 break;
382 }
383
384 msg = bytes_rx.recv() => {
385 match msg {
386 Some(msg) => msg,
387 None => {
388 tracing::trace!("response channel closed; shutting down");
389 break;
390 }
391 }
392 }
393 };
394
395 if let Err(e) = framed_writer.send(msg).await {
396 tracing::trace!(
397 "failed to send message to network; possible disconnect: {:?}",
398 e
399 );
400 send_sentinel = false;
401 break;
402 }
403 }
404
405 if send_sentinel {
407 let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
408 let msg = TwoPartMessage::from_header(message.into());
409 framed_writer.send(msg).await?;
410 }
411
412 drop(alive_rx);
413 Ok(framed_writer)
414}
415
#[cfg(test)]
mod tests {
    //! Tests for the client-side reader/writer tasks, driven over real loopback
    //! TCP pairs from `create_tcp_pair`.

    use super::*;
    use crate::pipeline::context::Controller;
    use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
    use bytes::Bytes;
    use futures::StreamExt;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::sync::{mpsc, oneshot};
    use tokio_util::codec::FramedRead;

    /// Everything needed to call `handle_writer` against a loopback socket.
    struct WriterHarness {
        // Peer end of the pair; tests read what the writer sent from here.
        server: tokio::net::TcpStream,
        framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
        bytes_tx: mpsc::Sender<TwoPartMessage>,
        bytes_rx: mpsc::Receiver<TwoPartMessage>,
        alive_tx: oneshot::Sender<()>,
        alive_rx: oneshot::Receiver<()>,
        controller: Arc<Controller>,
    }

    /// Builds a `WriterHarness` around a fresh loopback pair and a default `Controller`.
    async fn writer_harness() -> WriterHarness {
        let (client, server) = create_tcp_pair().await;
        // The client's read half is discarded right away, so the write half inside
        // `framed_writer` is the last owner of the client socket. The no-sentinel
        // tests appear to rely on this to observe EOF after dropping the writer.
        let (_, write_half) = tokio::io::split(client);
        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());

        let (bytes_tx, bytes_rx) = mpsc::channel(64);
        let (alive_tx, alive_rx) = oneshot::channel::<()>();
        let controller = Arc::new(Controller::default());

        WriterHarness {
            server,
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_tx,
            alive_rx,
            controller,
        }
    }

    /// Reads the next frame, panicking on end of stream or a decode error.
    async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
        reader
            .next()
            .await
            .expect("expected message")
            .expect("failed to decode message")
    }

    /// Asserts `msg` has no header and a data section equal to `expected`.
    fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
        let (header, data) = msg.optional_parts();
        assert!(header.is_none(), "data-only message should not have header");
        assert_eq!(
            data.expect("data payload missing").as_ref(),
            expected,
            "data payload should match"
        );
    }

    /// Asserts `msg` has no data section and a header equal to `expected`.
    fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
        let (header, data) = msg.optional_parts();
        assert!(data.is_none(), "header-only message should not carry data");
        assert_eq!(
            header.expect("header missing").as_ref(),
            expected,
            "header payload should match"
        );
    }

    /// Asserts `msg` carries both the expected header and the expected data.
    fn assert_header_and_data_message(
        msg: TwoPartMessage,
        expected_header: &[u8],
        expected_data: &[u8],
    ) {
        let (header, data) = msg.optional_parts();
        assert_eq!(
            header.expect("header missing").as_ref(),
            expected_header,
            "header payload should match"
        );
        assert_eq!(
            data.expect("data missing").as_ref(),
            expected_data,
            "data payload should match"
        );
    }

    /// Asserts `msg` is a header-only frame holding serialized `ControlMessage::Sentinel`.
    fn assert_sentinel_message(msg: TwoPartMessage) {
        let (header, data) = msg.optional_parts();
        assert!(data.is_none(), "sentinel should not include a data section");
        let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert_eq!(
            header.expect("sentinel header missing").as_ref(),
            expected_sentinel.as_slice(),
            "sentinel header should match serialized ControlMessage::Sentinel"
        );
    }

    /// A queued data frame is forwarded, then the sentinel follows once the channel closes.
    #[tokio::test]
    async fn test_handle_writer_forwards_messages() {
        let WriterHarness {
            server,
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        let test_msg = TwoPartMessage::from_data(Bytes::from("test data"));
        bytes_tx.send(test_msg).await.unwrap();

        // Closing the channel is the writer's normal-completion signal.
        drop(bytes_tx);

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        let mut reader = FramedRead::new(server, TwoPartCodec::default());

        let msg = recv_msg(&mut reader).await;
        assert_data_only_message(msg, b"test data");

        let sentinel = recv_msg(&mut reader).await;
        assert_sentinel_message(sentinel);
    }

    /// With no data at all, closing the channel still puts a sentinel on the wire.
    #[tokio::test]
    async fn test_handle_writer_sends_sentinel_on_normal_closure() {
        let WriterHarness {
            mut server,
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        drop(bytes_tx);

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        // Read raw bytes rather than frames: only the presence of the sentinel
        // payload is checked here.
        let mut buffer = vec![0u8; 1024];
        let n = server.read(&mut buffer).await.unwrap();

        assert!(n > 0, "Expected sentinel to be written to the TCP stream");

        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert!(
            buffer[..n]
                .windows(sentinel_json.len())
                .any(|w| w == sentinel_json.as_slice()),
            "Buffer should contain sentinel message. Buffer: {:?}",
            String::from_utf8_lossy(&buffer[..n])
        );
    }

    /// A killed context makes the writer exit without sending a sentinel.
    #[tokio::test]
    async fn test_handle_writer_no_sentinel_on_context_killed() {
        let WriterHarness {
            mut server,
            framed_writer,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        controller.kill();

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        // Dropping the returned writer releases the client socket, so the read
        // below is expected to return (EOF) rather than block.
        drop(result);

        let mut buffer = vec![0u8; 1024];
        let n = server.read(&mut buffer).await.unwrap();

        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert!(
            n == 0
                || !buffer[..n]
                    .windows(sentinel_json.len())
                    .any(|w| w == sentinel_json.as_slice()),
            "Buffer should NOT contain sentinel message when context is killed"
        );
    }

    /// A stopped context also makes the writer exit without sending a sentinel.
    #[tokio::test]
    async fn test_handle_writer_no_sentinel_on_context_stopped() {
        let WriterHarness {
            mut server,
            framed_writer,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        controller.stop();

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        // Same as the kill case: release the client socket so the read can finish.
        drop(result);

        let mut buffer = vec![0u8; 1024];
        let n = server.read(&mut buffer).await.unwrap();

        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
        assert!(
            n == 0
                || !buffer[..n]
                    .windows(sentinel_json.len())
                    .any(|w| w == sentinel_json.as_slice()),
            "Buffer should NOT contain sentinel message when context is stopped"
        );
    }

    /// Several queued frames arrive in order, followed by a single sentinel.
    #[tokio::test]
    async fn test_handle_writer_multiple_messages() {
        let WriterHarness {
            server,
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        for i in 0..5 {
            let test_msg = TwoPartMessage::from_data(Bytes::from(format!("message {}", i)));
            bytes_tx.send(test_msg).await.unwrap();
        }

        drop(bytes_tx);

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        let mut reader = FramedRead::new(server, TwoPartCodec::default());
        for i in 0..5 {
            let msg = recv_msg(&mut reader).await;
            assert_data_only_message(msg, format!("message {}", i).as_bytes());
        }

        let sentinel = recv_msg(&mut reader).await;
        assert_sentinel_message(sentinel);
    }

    /// When the writer returns, `alive_rx` is gone, so the paired sender reports closed.
    #[tokio::test]
    async fn test_handle_writer_drops_alive_rx() {
        let WriterHarness {
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_tx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        drop(bytes_tx);

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        assert!(alive_tx.is_closed());
    }

    /// Header-only frames pass through unchanged before the sentinel.
    #[tokio::test]
    async fn test_handle_writer_header_only_messages() {
        let WriterHarness {
            server,
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        let header_msg = TwoPartMessage::from_header(Bytes::from("header content"));
        bytes_tx.send(header_msg).await.unwrap();

        drop(bytes_tx);

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        let mut reader = FramedRead::new(server, TwoPartCodec::default());

        let header_msg = recv_msg(&mut reader).await;
        assert_header_only_message(header_msg, b"header content");

        let sentinel = recv_msg(&mut reader).await;
        assert_sentinel_message(sentinel);
    }

    /// Header-only, data-only, and header+data frames keep their shape and order.
    #[tokio::test]
    async fn test_handle_writer_mixed_messages() {
        let WriterHarness {
            server,
            framed_writer,
            bytes_tx,
            bytes_rx,
            alive_rx,
            controller,
            ..
        } = writer_harness().await;

        bytes_tx
            .send(TwoPartMessage::from_header(Bytes::from("header1")))
            .await
            .unwrap();
        bytes_tx
            .send(TwoPartMessage::from_data(Bytes::from("data1")))
            .await
            .unwrap();
        bytes_tx
            .send(TwoPartMessage::from_parts(
                Bytes::from("header2"),
                Bytes::from("data2"),
            ))
            .await
            .unwrap();

        drop(bytes_tx);

        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;

        assert!(result.is_ok());

        let mut reader = FramedRead::new(server, TwoPartCodec::default());

        let first = recv_msg(&mut reader).await;
        assert_header_only_message(first, b"header1");

        let second = recv_msg(&mut reader).await;
        assert_data_only_message(second, b"data1");

        let third = recv_msg(&mut reader).await;
        assert_header_and_data_message(third, b"header2", b"data2");

        let sentinel = recv_msg(&mut reader).await;
        assert_sentinel_message(sentinel);
    }

    /// For both kill and stop, `wait_for_server_shutdown` returns `Ok` quickly
    /// even though the server never closes its end (`_server` stays alive).
    #[tokio::test]
    async fn test_wait_for_server_shutdown_skips_terminal_context() {
        for action in [Controller::kill as fn(&Controller), Controller::stop] {
            let (client, _server) = create_tcp_pair().await;
            let controller = Arc::new(Controller::default());
            action(&controller);

            let context: Arc<dyn AsyncEngineContext> = controller;
            // 50 ms is far below the 10 s FIN deadline, so a timeout here means
            // the early return was not taken.
            let result = tokio::time::timeout(
                std::time::Duration::from_millis(50),
                wait_for_server_shutdown(client, context),
            )
            .await;

            assert!(result.is_ok(), "terminal context should not wait for FIN");
            assert!(
                result.unwrap().is_ok(),
                "terminal context shutdown should succeed"
            );
        }
    }

    /// End-to-end monitor check: garbage from the server kills the context in the
    /// reader. The writer then exits, and the monitor skips the FIN wait instead
    /// of blocking until the deadline.
    #[tokio::test]
    async fn test_connection_monitor_skips_fin_wait_after_read_error_kills_context() {
        let (client, mut server) = create_tcp_pair().await;
        let (read_half, write_half) = tokio::io::split(client);
        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
        // `_bytes_tx` stays bound so the writer never sees a closed channel.
        let (_bytes_tx, bytes_rx) = mpsc::channel(64);
        let (alive_tx, alive_rx) = oneshot::channel::<()>();
        let controller = Arc::new(Controller::default());

        let reader_context = controller.clone();
        let reader_task = tokio::spawn(async move {
            handle_reader(framed_reader, reader_context, alive_tx, None).await
        });
        let writer_context = controller.clone();
        let writer_task = tokio::spawn(async move {
            handle_writer(framed_writer, bytes_rx, alive_rx, writer_context).await
        });

        // Bytes the codec is expected to reject, surfacing as a read error in the
        // reader. The exact rejection reason is not visible from here.
        server.write_all(&[0xFF; 24]).await.unwrap();

        let monitor_context: Arc<dyn AsyncEngineContext> = controller.clone();
        let result = tokio::time::timeout(
            std::time::Duration::from_millis(250),
            wait_for_connection_tasks(
                reader_task,
                writer_task,
                monitor_context,
                None,
                "test-subject".to_string(),
            ),
        )
        .await;

        assert!(
            result.is_ok(),
            "connection monitor should not wait for the FIN deadline after read error"
        );
        assert!(result.unwrap().is_ok(), "connection monitor should succeed");
        assert!(
            controller.is_killed(),
            "read error should kill the stream context"
        );
    }

    /// Everything needed to call `handle_reader`, with a framed server-side writer
    /// for injecting frames.
    struct ReaderHarness {
        framed_server: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
        framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
        alive_tx: oneshot::Sender<()>,
        alive_rx: oneshot::Receiver<()>,
        controller: Arc<Controller>,
    }

    /// Builds a `ReaderHarness` around a fresh loopback pair. Only the client read
    /// half and the server write half are kept.
    async fn reader_harness() -> ReaderHarness {
        let (client, server) = create_tcp_pair().await;
        let (read_half, _write_half) = tokio::io::split(client);
        let (_server_read, server_write) = tokio::io::split(server);

        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
        let framed_server = FramedWrite::new(server_write, TwoPartCodec::default());
        let (alive_tx, alive_rx) = oneshot::channel::<()>();
        let controller = Arc::new(Controller::default());

        ReaderHarness {
            framed_server,
            framed_reader,
            alive_tx,
            alive_rx,
            controller,
        }
    }

    /// Wraps a serialized `ControlMessage` in a header-only frame, the shape
    /// `handle_reader` accepts as a control message.
    fn control_message(msg: &ControlMessage) -> TwoPartMessage {
        let msg_bytes = serde_json::to_vec(msg).unwrap();
        TwoPartMessage::from_header(Bytes::from(msg_bytes))
    }

    /// A `Stop` control message stops the context.
    #[tokio::test]
    async fn test_handle_reader_stop_control_message() {
        let ReaderHarness {
            mut framed_server,
            framed_reader,
            alive_tx,
            // Kept bound so the reader's alive branch does not fire.
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let controller_clone = controller.clone();
        let reader_handle = tokio::spawn(async move {
            handle_reader(framed_reader, controller_clone, alive_tx, None).await
        });

        framed_server
            .send(control_message(&ControlMessage::Stop))
            .await
            .unwrap();

        // Closing the server side lets the reader finish on end of stream.
        framed_server.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_stopped(),
            "Controller should be stopped after receiving Stop message"
        );
    }

    /// A `Kill` control message kills the context.
    #[tokio::test]
    async fn test_handle_reader_kill_control_message() {
        let ReaderHarness {
            mut framed_server,
            framed_reader,
            alive_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let controller_clone = controller.clone();
        let reader_handle = tokio::spawn(async move {
            handle_reader(framed_reader, controller_clone, alive_tx, None).await
        });

        framed_server
            .send(control_message(&ControlMessage::Kill))
            .await
            .unwrap();

        framed_server.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_killed(),
            "Controller should be killed after receiving Kill message"
        );
    }

    /// Dropping the alive receiver (as the writer does when it finishes) ends the reader.
    #[tokio::test]
    async fn test_handle_reader_exits_on_alive_channel_closed() {
        let ReaderHarness {
            framed_reader,
            alive_tx,
            alive_rx,
            controller,
            ..
        } = reader_harness().await;

        let reader_handle =
            tokio::spawn(
                async move { handle_reader(framed_reader, controller, alive_tx, None).await },
            );

        drop(alive_rx);

        let result = reader_handle.await;

        assert!(
            result.is_ok(),
            "handle_reader should exit when alive channel is closed"
        );
    }

    /// The server closing its side ends the reader within one second.
    #[tokio::test]
    async fn test_handle_reader_exits_on_stream_closed() {
        let ReaderHarness {
            mut framed_server,
            framed_reader,
            alive_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let reader_handle =
            tokio::spawn(
                async move { handle_reader(framed_reader, controller, alive_tx, None).await },
            );

        framed_server.close().await.unwrap();

        let result = tokio::time::timeout(std::time::Duration::from_secs(1), reader_handle).await;

        assert!(
            result.is_ok(),
            "handle_reader should exit when stream is closed"
        );
    }

    /// Repeated `Stop` messages are harmless and leave the context stopped.
    #[tokio::test]
    async fn test_handle_reader_multiple_control_messages() {
        let ReaderHarness {
            mut framed_server,
            framed_reader,
            alive_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let controller_clone = controller.clone();
        let reader_handle = tokio::spawn(async move {
            handle_reader(framed_reader, controller_clone, alive_tx, None).await
        });

        framed_server
            .send(control_message(&ControlMessage::Stop))
            .await
            .unwrap();
        framed_server
            .send(control_message(&ControlMessage::Stop))
            .await
            .unwrap();

        framed_server.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_stopped(),
            "Controller should be stopped after receiving Stop messages"
        );
    }

    /// A `Kill` after a `Stop` is still applied; the reader keeps reading after `Stop`.
    #[tokio::test]
    async fn test_handle_reader_stop_then_kill() {
        let ReaderHarness {
            mut framed_server,
            framed_reader,
            alive_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;

        let controller_clone = controller.clone();
        let reader_handle = tokio::spawn(async move {
            handle_reader(framed_reader, controller_clone, alive_tx, None).await
        });

        framed_server
            .send(control_message(&ControlMessage::Stop))
            .await
            .unwrap();
        framed_server
            .send(control_message(&ControlMessage::Kill))
            .await
            .unwrap();

        framed_server.close().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_killed(),
            "Controller should be killed after receiving Kill message"
        );
    }

    /// Raw bytes that are not a valid control frame kill the context and bump
    /// the cancellation counter exactly once.
    #[tokio::test]
    async fn test_handle_reader_increments_cancellation_counter_on_read_error() {
        let ReaderHarness {
            framed_server,
            framed_reader,
            alive_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;
        let cancellation_counter = IntCounter::new(
            "tcp_client_reader_read_error_cancellations_test",
            "test cancellation counter",
        )
        .unwrap();

        let counter_clone = cancellation_counter.clone();
        let controller_clone = controller.clone();
        let reader_handle = tokio::spawn(async move {
            handle_reader(
                framed_reader,
                controller_clone,
                alive_tx,
                Some(counter_clone),
            )
            .await
        });

        // Bypass the codec: write 8 raw zero bytes, then shut down the write side.
        // Whether the reader sees a decode error or an unexpected frame is not
        // visible from here; both paths kill the context.
        let mut raw_writer = framed_server.into_inner();
        raw_writer.write_all(&[0u8; 8]).await.unwrap();
        raw_writer.shutdown().await.unwrap();

        let _ = reader_handle.await.unwrap();

        assert!(
            controller.is_killed(),
            "Controller should be killed after TCP stream read error"
        );
        assert_eq!(
            cancellation_counter.get(),
            1,
            "read-error close should increment cancellation metric once"
        );
    }

    /// Sends one frame to a fresh reader and waits for the reader to exit.
    /// Returns the controller and the (unregistered) counter named `counter_name`.
    async fn run_reader_with(
        msg: TwoPartMessage,
        counter_name: &str,
    ) -> (Arc<Controller>, IntCounter) {
        let ReaderHarness {
            mut framed_server,
            framed_reader,
            alive_tx,
            alive_rx: _alive_rx,
            controller,
        } = reader_harness().await;
        let counter = IntCounter::new(counter_name, "test counter").unwrap();

        let counter_clone = counter.clone();
        let controller_clone = controller.clone();
        let reader_handle = tokio::spawn(async move {
            handle_reader(
                framed_reader,
                controller_clone,
                alive_tx,
                Some(counter_clone),
            )
            .await
        });

        framed_server.send(msg).await.unwrap();
        // The server side is not closed, so this only returns if the frame itself
        // made the reader break out of its loop.
        let _ = reader_handle.await.unwrap();

        (controller, counter)
    }

    /// Each protocol violation (bad control JSON, a sentinel from the server, or
    /// a data-only frame) kills the context and is counted exactly once.
    #[tokio::test]
    async fn test_handle_reader_kills_on_protocol_violations() {
        let cases: Vec<(&str, TwoPartMessage)> = vec![
            (
                "invalid control bytes",
                TwoPartMessage::from_header(Bytes::from_static(b"not a valid control message")),
            ),
            (
                "sentinel from server",
                control_message(&ControlMessage::Sentinel),
            ),
            (
                "non-control (data-only)",
                TwoPartMessage::from_data(Bytes::from_static(b"unexpected payload")),
            ),
        ];

        for (i, (label, msg)) in cases.into_iter().enumerate() {
            // Distinct names per case; the counters are not registered anywhere.
            let counter_name = format!("tcp_client_reader_protocol_violation_test_{i}");
            let (controller, counter) = run_reader_with(msg, &counter_name).await;
            assert!(
                controller.is_killed(),
                "{label}: should kill stream context"
            );
            assert_eq!(counter.get(), 1, "{label}: should be counted once");
        }
    }
}