Skip to main content

dynamo_runtime/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9    io::AsyncWriteExt,
10    net::TcpStream,
11    time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use prometheus::IntCounter;
16
17use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
18use crate::engine::AsyncEngineContext;
19use crate::pipeline::network::{
20    ConnectionInfo, ResponseStreamPrologue, StreamSender,
21    codec::{TwoPartCodec, TwoPartMessage},
22    tcp::StreamType,
23};
24use anyhow::{Context, Result, anyhow as error}; // Import SinkExt to use the `send` method
25
26#[allow(dead_code)]
27pub struct TcpClient {
28    worker_id: String,
29}
30
31impl Default for TcpClient {
32    fn default() -> Self {
33        TcpClient {
34            worker_id: uuid::Uuid::new_v4().to_string(),
35        }
36    }
37}
38
39impl TcpClient {
40    pub fn new(worker_id: String) -> Self {
41        TcpClient { worker_id }
42    }
43
44    async fn connect(address: &str) -> std::io::Result<TcpStream> {
45        // try to connect to the address; retry with linear backoff if AddrNotAvailable
46        let backoff = std::time::Duration::from_millis(200);
47        loop {
48            match TcpStream::connect(address).await {
49                Ok(socket) => {
50                    socket.set_nodelay(true)?;
51                    return Ok(socket);
52                }
53                Err(e) => {
54                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
55                        tracing::warn!("retry warning: failed to connect: {:?}", e);
56                        tokio::time::sleep(backoff).await;
57                    } else {
58                        return Err(e);
59                    }
60                }
61            }
62        }
63    }
64
65    pub async fn create_response_stream(
66        context: Arc<dyn AsyncEngineContext>,
67        info: ConnectionInfo,
68        cancellation_counter: Option<IntCounter>,
69    ) -> Result<StreamSender> {
70        let info =
71            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
72        tracing::trace!("Creating response stream for {:?}", info);
73
74        if info.stream_type != StreamType::Response {
75            return Err(error!(
76                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
77                info.stream_type
78            ));
79        }
80
81        if info.context != context.id() {
82            return Err(error!(
83                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
84                context.id(),
85                info.context
86            ));
87        }
88
89        let stream = TcpClient::connect(&info.address).await?;
90        let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
91        let (read_half, write_half) = tokio::io::split(stream);
92
93        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
94        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
95
96        // this is a oneshot channel that will be used to signal when the stream is closed
97        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
98        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
99        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
100        // captured by the monitor task
101        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
102
103        let reader_task = tokio::spawn(handle_reader(
104            framed_reader,
105            context.clone(),
106            alive_tx,
107            cancellation_counter,
108        ));
109
110        // transport specific handshake message
111        let handshake = CallHomeHandshake {
112            subject: info.subject.clone(),
113            stream_type: StreamType::Response,
114        };
115
116        let handshake_bytes = match serde_json::to_vec(&handshake) {
117            Ok(hb) => hb,
118            Err(err) => {
119                return Err(error!(
120                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
121                ));
122            }
123        };
124        let msg = TwoPartMessage::from_header(handshake_bytes.into());
125
126        // issue the the first tcp handshake message
127        framed_writer
128            .send(msg)
129            .await
130            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
131
132        // set up the channel to send bytes to the transport layer
133        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
134
135        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
136        let writer_context = context.clone();
137        let writer_task = tokio::spawn(handle_writer(
138            framed_writer,
139            bytes_rx,
140            alive_rx,
141            writer_context,
142        ));
143
144        let subject = info.subject.clone();
145        let monitor_context = context;
146        // Spawn the connection monitor; errors are already logged inside
147        // wait_for_connection_tasks, so the Result is intentionally dropped.
148        tokio::spawn(async move {
149            let _ = wait_for_connection_tasks(
150                reader_task,
151                writer_task,
152                monitor_context,
153                peer_port,
154                subject,
155            )
156            .await;
157        });
158
159        // set up the prologue for the stream
160        // this might have transport specific metadata in the future
161        let prologue = Some(ResponseStreamPrologue { error: None });
162
163        // create the stream sender
164        let stream_sender = StreamSender {
165            tx: bytes_tx,
166            prologue,
167        };
168
169        Ok(stream_sender)
170    }
171}
172
173async fn wait_for_connection_tasks(
174    reader_task: tokio::task::JoinHandle<FramedRead<ReadHalf<TcpStream>, TwoPartCodec>>,
175    writer_task: tokio::task::JoinHandle<Result<FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>>>,
176    context: Arc<dyn AsyncEngineContext>,
177    peer_port: Option<u16>,
178    subject: String,
179) -> Result<()> {
180    // Await the reader first and abort the writer on reader Err — the
181    // writer parks on `bytes_rx.recv()` and won't wake on its own.
182    let reader = match reader_task.await {
183        Ok(reader) => reader,
184        Err(reader_err) => {
185            writer_task.abort();
186            let _ = writer_task.await;
187            tracing::error!(
188                subject = %subject,
189                peer_port = ?peer_port,
190                err = ?reader_err,
191                "reader task failed to join"
192            );
193            return Err(reader_err.into());
194        }
195    };
196
197    let writer = match writer_task.await {
198        Ok(writer) => writer,
199        Err(writer_err) => {
200            tracing::error!(
201                subject = %subject,
202                peer_port = ?peer_port,
203                err = ?writer_err,
204                "writer task failed to join"
205            );
206            return Err(writer_err.into());
207        }
208    };
209
210    let reader = reader.into_inner();
211    let writer = match writer {
212        Ok(writer) => writer.into_inner(),
213        Err(e) => {
214            tracing::error!(
215                subject = %subject,
216                peer_port = ?peer_port,
217                err = ?e,
218                "writer task returned error"
219            );
220            return Err(e);
221        }
222    };
223
224    let stream = reader.unsplit(writer);
225    wait_for_server_shutdown(stream, context).await
226}
227
228async fn wait_for_server_shutdown(
229    mut stream: TcpStream,
230    context: Arc<dyn AsyncEngineContext>,
231) -> Result<()> {
232    // `handle_writer` skips the closing sentinel on both `killed` and
233    // `stopped`, so the server has nothing to react to in either case;
234    // sitting in the read loop until the 10 s deadline would be dead time.
235    if context.is_killed() || context.is_stopped() {
236        tracing::debug!("stream context killed or stopped; skipping server FIN wait");
237        return Ok(());
238    }
239
240    // Await the tcp server to shutdown the socket connection, bounded by a
241    // timeout so normal sentinel shutdown cannot hang indefinitely.
242    let mut buf = [0u8; 1024];
243    let deadline = Instant::now() + Duration::from_secs(10);
244    loop {
245        let n = time::timeout_at(deadline, stream.read(&mut buf))
246            .await
247            .inspect_err(|_| {
248                tracing::debug!("server did not close socket within the deadline");
249            })?
250            .inspect_err(|e| {
251                tracing::debug!(err = ?e, "failed to read from stream");
252            })?;
253        if n == 0 {
254            // Server has closed (FIN)
255            break;
256        }
257    }
258
259    Ok(())
260}
261
262async fn handle_reader(
263    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
264    context: Arc<dyn AsyncEngineContext>,
265    alive_tx: tokio::sync::oneshot::Sender<()>,
266    cancellation_counter: Option<IntCounter>,
267) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
268    let mut framed_reader = framed_reader;
269    let mut alive_tx = alive_tx;
270    // Set on every cancellation arm; counted once after the loop.
271    let mut cancellation_seen = false;
272    loop {
273        tokio::select! {
274            msg = framed_reader.next() => {
275                match msg {
276                    Some(Ok(two_part_msg)) => {
277                        match two_part_msg.optional_parts() {
278                           (Some(bytes), None) => {
279                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
280                                    Ok(msg) => msg,
281                                    Err(e) => {
282                                        tracing::warn!(
283                                            err = ?e,
284                                            "invalid control message, closing connection"
285                                        );
286                                        cancellation_seen = true;
287                                        context.kill();
288                                        break;
289                                    }
290                                };
291
292                                // Stop/Kill intentionally do not `break`: the
293                                // reader keeps running so a later Kill can
294                                // upgrade an earlier Stop (and vice versa).
295                                // The loop still exits promptly via the
296                                // `alive_tx.closed()` arm once `handle_writer`
297                                // reacts to `context.stop()` / `context.kill()`.
298                                match msg {
299                                    ControlMessage::Stop => {
300                                        cancellation_seen = true;
301                                        context.stop();
302                                    }
303                                    ControlMessage::Kill => {
304                                        cancellation_seen = true;
305                                        context.kill();
306                                    }
307                                    ControlMessage::Sentinel => {
308                                        tracing::warn!(
309                                            "unexpected sentinel on client reader, closing connection"
310                                        );
311                                        cancellation_seen = true;
312                                        context.kill();
313                                        break;
314                                    }
315                                }
316                           }
317                           _ => {
318                                tracing::warn!(
319                                    "unexpected non-control message on client reader, closing connection"
320                                );
321                                cancellation_seen = true;
322                                context.kill();
323                                break;
324                           }
325                        }
326                    }
327                    Some(Err(e)) => {
328                        // Kill the engine context so the producer stops
329                        // generating responses that can no longer be delivered.
330                        tracing::warn!(err = ?e, "tcp stream read error, closing connection");
331                        cancellation_seen = true;
332                        context.kill();
333                        break;
334                    }
335                    None => {
336                        tracing::debug!("tcp stream closed by server");
337                        cancellation_seen = true;
338                        break;
339                    }
340                }
341            }
342            _ = alive_tx.closed() => {
343                break;
344            }
345        }
346    }
347    if cancellation_seen && let Some(counter) = &cancellation_counter {
348        counter.inc();
349    }
350    framed_reader
351}
352
353async fn handle_writer(
354    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
355    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
356    alive_rx: tokio::sync::oneshot::Receiver<()>,
357    context: Arc<dyn AsyncEngineContext>,
358) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
359    // Only send sentinel for normal channel closure
360    let mut send_sentinel = true;
361
362    loop {
363        let msg = tokio::select! {
364            biased;
365
366            _ = context.killed() => {
367                tracing::trace!("context kill signal received; shutting down");
368                send_sentinel = false;
369                break;
370            }
371
372            _ = context.stopped() => {
373                tracing::trace!("context stop signal received; shutting down");
374                send_sentinel = false;
375                break;
376            }
377
378            msg = bytes_rx.recv() => {
379                match msg {
380                    Some(msg) => msg,
381                    None => {
382                        tracing::trace!("response channel closed; shutting down");
383                        break;
384                    }
385                }
386            }
387        };
388
389        if let Err(e) = framed_writer.send(msg).await {
390            tracing::trace!(
391                "failed to send message to network; possible disconnect: {:?}",
392                e
393            );
394            send_sentinel = false;
395            break;
396        }
397    }
398
399    // Send sentinel only on normal closure
400    if send_sentinel {
401        let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
402        let msg = TwoPartMessage::from_header(message.into());
403        framed_writer.send(msg).await?;
404    }
405
406    drop(alive_rx);
407    Ok(framed_writer)
408}
409
410#[cfg(test)]
411mod tests {
412    use super::*;
413    use crate::pipeline::context::Controller;
414    use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
415    use bytes::Bytes;
416    use futures::StreamExt;
417    use std::sync::Arc;
418    use tokio::io::{AsyncReadExt, AsyncWriteExt};
419    use tokio::net::TcpStream;
420    use tokio::sync::{mpsc, oneshot};
421    use tokio_util::codec::FramedRead;
422
423    struct WriterHarness {
424        server: tokio::net::TcpStream,
425        framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
426        bytes_tx: mpsc::Sender<TwoPartMessage>,
427        bytes_rx: mpsc::Receiver<TwoPartMessage>,
428        alive_tx: oneshot::Sender<()>,
429        alive_rx: oneshot::Receiver<()>,
430        controller: Arc<Controller>,
431    }
432
433    /// Creates a reusable writer harness with paired TCP streams and test channels.
434    async fn writer_harness() -> WriterHarness {
435        let (client, server) = create_tcp_pair().await;
436        let (_, write_half) = tokio::io::split(client);
437        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
438
439        let (bytes_tx, bytes_rx) = mpsc::channel(64);
440        let (alive_tx, alive_rx) = oneshot::channel::<()>();
441        let controller = Arc::new(Controller::default());
442
443        WriterHarness {
444            server,
445            framed_writer,
446            bytes_tx,
447            bytes_rx,
448            alive_tx,
449            alive_rx,
450            controller,
451        }
452    }
453
454    async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
455        reader
456            .next()
457            .await
458            .expect("expected message")
459            .expect("failed to decode message")
460    }
461
462    fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
463        let (header, data) = msg.optional_parts();
464        assert!(header.is_none(), "data-only message should not have header");
465        assert_eq!(
466            data.expect("data payload missing").as_ref(),
467            expected,
468            "data payload should match"
469        );
470    }
471
472    fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
473        let (header, data) = msg.optional_parts();
474        assert!(data.is_none(), "header-only message should not carry data");
475        assert_eq!(
476            header.expect("header missing").as_ref(),
477            expected,
478            "header payload should match"
479        );
480    }
481
482    fn assert_header_and_data_message(
483        msg: TwoPartMessage,
484        expected_header: &[u8],
485        expected_data: &[u8],
486    ) {
487        let (header, data) = msg.optional_parts();
488        assert_eq!(
489            header.expect("header missing").as_ref(),
490            expected_header,
491            "header payload should match"
492        );
493        assert_eq!(
494            data.expect("data missing").as_ref(),
495            expected_data,
496            "data payload should match"
497        );
498    }
499
500    fn assert_sentinel_message(msg: TwoPartMessage) {
501        let (header, data) = msg.optional_parts();
502        assert!(data.is_none(), "sentinel should not include a data section");
503        let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
504        assert_eq!(
505            header.expect("sentinel header missing").as_ref(),
506            expected_sentinel.as_slice(),
507            "sentinel header should match serialized ControlMessage::Sentinel"
508        );
509    }
510
511    /// Test that handle_writer forwards messages from the channel to the framed writer
512    #[tokio::test]
513    async fn test_handle_writer_forwards_messages() {
514        let WriterHarness {
515            server,
516            framed_writer,
517            bytes_tx,
518            bytes_rx,
519            alive_rx,
520            controller,
521            ..
522        } = writer_harness().await;
523
524        // Send test messages
525        let test_msg = TwoPartMessage::from_data(Bytes::from("test data"));
526        bytes_tx.send(test_msg).await.unwrap();
527
528        // Close the sender to trigger normal termination
529        drop(bytes_tx);
530
531        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
532
533        assert!(result.is_ok());
534
535        // Decode from server side to verify data and sentinel were sent
536        let mut reader = FramedRead::new(server, TwoPartCodec::default());
537
538        let msg = recv_msg(&mut reader).await;
539        assert_data_only_message(msg, b"test data");
540
541        let sentinel = recv_msg(&mut reader).await;
542        assert_sentinel_message(sentinel);
543    }
544
545    /// Test that handle_writer sends sentinel on normal channel closure
546    #[tokio::test]
547    async fn test_handle_writer_sends_sentinel_on_normal_closure() {
548        let WriterHarness {
549            mut server,
550            framed_writer,
551            bytes_tx,
552            bytes_rx,
553            alive_rx,
554            controller,
555            ..
556        } = writer_harness().await;
557
558        // Close the sender immediately to trigger normal termination
559        drop(bytes_tx);
560
561        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
562
563        assert!(result.is_ok());
564
565        // Read from server side to verify sentinel was sent
566        let mut buffer = vec![0u8; 1024];
567        let n = server.read(&mut buffer).await.unwrap();
568
569        // Buffer should contain the sentinel message
570        assert!(n > 0, "Expected sentinel to be written to the TCP stream");
571
572        // Verify it contains the sentinel message by checking for the JSON
573        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
574        assert!(
575            buffer[..n]
576                .windows(sentinel_json.len())
577                .any(|w| w == sentinel_json.as_slice()),
578            "Buffer should contain sentinel message. Buffer: {:?}",
579            String::from_utf8_lossy(&buffer[..n])
580        );
581    }
582
583    /// Test that handle_writer does NOT send sentinel when context is killed
584    #[tokio::test]
585    async fn test_handle_writer_no_sentinel_on_context_killed() {
586        let WriterHarness {
587            mut server,
588            framed_writer,
589            bytes_rx,
590            alive_rx,
591            controller,
592            ..
593        } = writer_harness().await;
594
595        // Kill the context
596        controller.kill();
597
598        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
599
600        assert!(result.is_ok());
601
602        // Drop the writer to close the connection, then try to read. Otherwise,
603        // the test will hang on `server.read()`
604        drop(result);
605
606        // Read from server side - should get no sentinel
607        let mut buffer = vec![0u8; 1024];
608        let n = server.read(&mut buffer).await.unwrap();
609
610        // Buffer should be empty (no sentinel sent)
611        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
612        assert!(
613            n == 0
614                || !buffer[..n]
615                    .windows(sentinel_json.len())
616                    .any(|w| w == sentinel_json.as_slice()),
617            "Buffer should NOT contain sentinel message when context is killed"
618        );
619    }
620
621    /// Test that handle_writer does NOT send sentinel when context is stopped
622    #[tokio::test]
623    async fn test_handle_writer_no_sentinel_on_context_stopped() {
624        let WriterHarness {
625            mut server,
626            framed_writer,
627            bytes_rx,
628            alive_rx,
629            controller,
630            ..
631        } = writer_harness().await;
632
633        // Stop the context
634        controller.stop();
635
636        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
637
638        assert!(result.is_ok());
639
640        // Drop the writer to close the connection, then try to read. Otherwise,
641        // the test will hang on `server.read()`
642        drop(result);
643
644        // Read from server side - should get no sentinel
645        let mut buffer = vec![0u8; 1024];
646        let n = server.read(&mut buffer).await.unwrap();
647
648        // Buffer should be empty (no sentinel sent)
649        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
650        assert!(
651            n == 0
652                || !buffer[..n]
653                    .windows(sentinel_json.len())
654                    .any(|w| w == sentinel_json.as_slice()),
655            "Buffer should NOT contain sentinel message when context is stopped"
656        );
657    }
658
659    /// Test that handle_writer handles multiple messages correctly
660    #[tokio::test]
661    async fn test_handle_writer_multiple_messages() {
662        let WriterHarness {
663            server,
664            framed_writer,
665            bytes_tx,
666            bytes_rx,
667            alive_rx,
668            controller,
669            ..
670        } = writer_harness().await;
671
672        // Send multiple messages
673        for i in 0..5 {
674            let test_msg = TwoPartMessage::from_data(Bytes::from(format!("message {}", i)));
675            bytes_tx.send(test_msg).await.unwrap();
676        }
677
678        // Close the sender to trigger normal termination
679        drop(bytes_tx);
680
681        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
682
683        assert!(result.is_ok());
684
685        // Decode from server side to verify all messages plus sentinel
686        let mut reader = FramedRead::new(server, TwoPartCodec::default());
687        for i in 0..5 {
688            let msg = recv_msg(&mut reader).await;
689            assert_data_only_message(msg, format!("message {}", i).as_bytes());
690        }
691
692        let sentinel = recv_msg(&mut reader).await;
693        assert_sentinel_message(sentinel);
694    }
695
696    /// Test that alive_rx is dropped after handle_writer completes
697    #[tokio::test]
698    async fn test_handle_writer_drops_alive_rx() {
699        let WriterHarness {
700            framed_writer,
701            bytes_tx,
702            bytes_rx,
703            alive_tx,
704            alive_rx,
705            controller,
706            ..
707        } = writer_harness().await;
708
709        // Close the sender to trigger normal termination
710        drop(bytes_tx);
711
712        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
713
714        assert!(result.is_ok());
715
716        // alive_tx should now be closed because alive_rx was dropped
717        assert!(alive_tx.is_closed());
718    }
719
720    /// Test handle_writer with header-only messages (control messages)
721    #[tokio::test]
722    async fn test_handle_writer_header_only_messages() {
723        let WriterHarness {
724            server,
725            framed_writer,
726            bytes_tx,
727            bytes_rx,
728            alive_rx,
729            controller,
730            ..
731        } = writer_harness().await;
732
733        // Send a header-only message
734        let header_msg = TwoPartMessage::from_header(Bytes::from("header content"));
735        bytes_tx.send(header_msg).await.unwrap();
736
737        // Close the sender
738        drop(bytes_tx);
739
740        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
741
742        assert!(result.is_ok());
743
744        let mut reader = FramedRead::new(server, TwoPartCodec::default());
745
746        let header_msg = recv_msg(&mut reader).await;
747        assert_header_only_message(header_msg, b"header content");
748
749        let sentinel = recv_msg(&mut reader).await;
750        assert_sentinel_message(sentinel);
751    }
752
753    /// Test handle_writer with mixed header and data messages
754    #[tokio::test]
755    async fn test_handle_writer_mixed_messages() {
756        let WriterHarness {
757            server,
758            framed_writer,
759            bytes_tx,
760            bytes_rx,
761            alive_rx,
762            controller,
763            ..
764        } = writer_harness().await;
765
766        // Send mixed messages
767        bytes_tx
768            .send(TwoPartMessage::from_header(Bytes::from("header1")))
769            .await
770            .unwrap();
771        bytes_tx
772            .send(TwoPartMessage::from_data(Bytes::from("data1")))
773            .await
774            .unwrap();
775        bytes_tx
776            .send(TwoPartMessage::from_parts(
777                Bytes::from("header2"),
778                Bytes::from("data2"),
779            ))
780            .await
781            .unwrap();
782
783        // Close the sender
784        drop(bytes_tx);
785
786        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
787
788        assert!(result.is_ok());
789
790        let mut reader = FramedRead::new(server, TwoPartCodec::default());
791
792        let first = recv_msg(&mut reader).await;
793        assert_header_only_message(first, b"header1");
794
795        let second = recv_msg(&mut reader).await;
796        assert_data_only_message(second, b"data1");
797
798        let third = recv_msg(&mut reader).await;
799        assert_header_and_data_message(third, b"header2", b"data2");
800
801        let sentinel = recv_msg(&mut reader).await;
802        assert_sentinel_message(sentinel);
803    }
804
805    /// Killed or stopped contexts skip the server FIN deadline.
806    #[tokio::test]
807    async fn test_wait_for_server_shutdown_skips_terminal_context() {
808        for action in [Controller::kill as fn(&Controller), Controller::stop] {
809            let (client, _server) = create_tcp_pair().await;
810            let controller = Arc::new(Controller::default());
811            action(&controller);
812
813            let context: Arc<dyn AsyncEngineContext> = controller;
814            let result = tokio::time::timeout(
815                std::time::Duration::from_millis(50),
816                wait_for_server_shutdown(client, context),
817            )
818            .await;
819
820            assert!(result.is_ok(), "terminal context should not wait for FIN");
821            assert!(
822                result.unwrap().is_ok(),
823                "terminal context shutdown should succeed"
824            );
825        }
826    }
827
828    /// Read error in the connection monitor kills the context and skips the FIN wait.
829    #[tokio::test]
830    async fn test_connection_monitor_skips_fin_wait_after_read_error_kills_context() {
831        let (client, mut server) = create_tcp_pair().await;
832        let (read_half, write_half) = tokio::io::split(client);
833        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
834        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
835        let (_bytes_tx, bytes_rx) = mpsc::channel(64);
836        let (alive_tx, alive_rx) = oneshot::channel::<()>();
837        let controller = Arc::new(Controller::default());
838
839        let reader_context = controller.clone();
840        let reader_task = tokio::spawn(async move {
841            handle_reader(framed_reader, reader_context, alive_tx, None).await
842        });
843        let writer_context = controller.clone();
844        let writer_task = tokio::spawn(async move {
845            handle_writer(framed_writer, bytes_rx, alive_rx, writer_context).await
846        });
847
848        // Bypass the codec and write a complete but invalid TwoPartCodec
849        // header. This drives the client reader into Some(Err(_)) without
850        // closing the server side of the socket.
851        server.write_all(&[0xFF; 24]).await.unwrap();
852
853        let monitor_context: Arc<dyn AsyncEngineContext> = controller.clone();
854        let result = tokio::time::timeout(
855            std::time::Duration::from_millis(250),
856            wait_for_connection_tasks(
857                reader_task,
858                writer_task,
859                monitor_context,
860                None,
861                "test-subject".to_string(),
862            ),
863        )
864        .await;
865
866        assert!(
867            result.is_ok(),
868            "connection monitor should not wait for the FIN deadline after read error"
869        );
870        assert!(result.unwrap().is_ok(), "connection monitor should succeed");
871        assert!(
872            controller.is_killed(),
873            "read error should kill the stream context"
874        );
875    }
876
877    /// Reader-side panic must abort the writer and return promptly rather than
878    /// hanging on `tokio::join!`. Locks in the fix added with this function's
879    /// sequential-await + writer-abort behavior.
880    ///
881    /// Setup: spawn a reader task that panics immediately (so
882    /// `reader_task.await` yields `Err(JoinError::panic)`), and a writer task
883    /// that parks indefinitely waiting for application bytes (so without the
884    /// abort, `tokio::join!` on the previous implementation would never wake).
885    /// Expect: `wait_for_connection_tasks` returns Err within the timeout.
886    #[tokio::test]
887    async fn test_connection_monitor_aborts_writer_when_reader_panics() {
888        // Reader task that panics immediately. The explicit JoinHandle type
889        // pins the inferred return type to the one wait_for_connection_tasks
890        // expects; `panic!` is type `!`, which coerces to that type.
891        let reader_task: tokio::task::JoinHandle<
892            FramedRead<ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
893        > = tokio::spawn(async {
894            panic!("simulated reader panic to trigger JoinError");
895        });
896
897        // Writer task that would block indefinitely waiting on application
898        // bytes. Under the pre-fix `tokio::join!` implementation, this would
899        // prevent the function from returning when the reader panicked.
900        // After the fix, the abort drives this task to completion promptly.
901        let writer_task: tokio::task::JoinHandle<
902            Result<FramedWrite<WriteHalf<tokio::net::TcpStream>, TwoPartCodec>>,
903        > = tokio::spawn(async {
904            std::future::pending::<()>().await;
905            unreachable!()
906        });
907
908        let controller = Arc::new(Controller::default());
909        let context: Arc<dyn AsyncEngineContext> = controller.clone();
910
911        // 250 ms is generous — the abort + JoinHandle resolution should fire
912        // sub-millisecond. We are checking for "doesn't hang", not "fast".
913        let result = tokio::time::timeout(
914            std::time::Duration::from_millis(250),
915            wait_for_connection_tasks(
916                reader_task,
917                writer_task,
918                context,
919                None,
920                "test-reader-panic".to_string(),
921            ),
922        )
923        .await;
924
925        // Outer timeout must not fire: the abort path must surface the reader
926        // JoinError before the writer would have produced any bytes.
927        assert!(
928            result.is_ok(),
929            "wait_for_connection_tasks must return after reader panic, \
930             not hang waiting on the writer"
931        );
932
933        // The inner result must be Err — the reader's JoinError propagates.
934        assert!(
935            result.unwrap().is_err(),
936            "reader panic should propagate as Err from wait_for_connection_tasks"
937        );
938    }
939
940    // ==================== handle_reader tests ====================
941
942    struct ReaderHarness {
943        framed_server: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
944        framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
945        alive_tx: oneshot::Sender<()>,
946        alive_rx: oneshot::Receiver<()>,
947        controller: Arc<Controller>,
948    }
949
950    /// Creates a reusable reader harness with paired TCP streams and test channels.
951    async fn reader_harness() -> ReaderHarness {
952        let (client, server) = create_tcp_pair().await;
953        let (read_half, _write_half) = tokio::io::split(client);
954        let (_server_read, server_write) = tokio::io::split(server);
955
956        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
957        let framed_server = FramedWrite::new(server_write, TwoPartCodec::default());
958        let (alive_tx, alive_rx) = oneshot::channel::<()>();
959        let controller = Arc::new(Controller::default());
960
961        ReaderHarness {
962            framed_server,
963            framed_reader,
964            alive_tx,
965            alive_rx,
966            controller,
967        }
968    }
969
970    fn control_message(msg: &ControlMessage) -> TwoPartMessage {
971        let msg_bytes = serde_json::to_vec(msg).unwrap();
972        TwoPartMessage::from_header(Bytes::from(msg_bytes))
973    }
974
975    /// Test that handle_reader handles Stop control message by calling context.stop()
976    #[tokio::test]
977    async fn test_handle_reader_stop_control_message() {
978        let ReaderHarness {
979            mut framed_server,
980            framed_reader,
981            alive_tx,
982            alive_rx: _alive_rx,
983            controller,
984        } = reader_harness().await;
985
986        // Spawn the reader task
987        let controller_clone = controller.clone();
988        let reader_handle = tokio::spawn(async move {
989            handle_reader(framed_reader, controller_clone, alive_tx, None).await
990        });
991
992        // Send Stop control message from server
993        framed_server
994            .send(control_message(&ControlMessage::Stop))
995            .await
996            .unwrap();
997
998        // Close the framed server to signal EOF to the client
999        framed_server.close().await.unwrap();
1000
1001        // Wait for reader to finish
1002        let _ = reader_handle.await.unwrap();
1003
1004        // Verify that stop was called on the controller
1005        assert!(
1006            controller.is_stopped(),
1007            "Controller should be stopped after receiving Stop message"
1008        );
1009    }
1010
1011    /// Test that handle_reader handles Kill control message by calling context.kill()
1012    #[tokio::test]
1013    async fn test_handle_reader_kill_control_message() {
1014        let ReaderHarness {
1015            mut framed_server,
1016            framed_reader,
1017            alive_tx,
1018            alive_rx: _alive_rx,
1019            controller,
1020        } = reader_harness().await;
1021
1022        // Spawn the reader task
1023        let controller_clone = controller.clone();
1024        let reader_handle = tokio::spawn(async move {
1025            handle_reader(framed_reader, controller_clone, alive_tx, None).await
1026        });
1027
1028        // Send Kill control message from server
1029        framed_server
1030            .send(control_message(&ControlMessage::Kill))
1031            .await
1032            .unwrap();
1033
1034        // Close the framed server to signal EOF to the client
1035        framed_server.close().await.unwrap();
1036
1037        // Wait for reader to finish
1038        let _ = reader_handle.await.unwrap();
1039
1040        // Verify that kill was called on the controller
1041        assert!(
1042            controller.is_killed(),
1043            "Controller should be killed after receiving Kill message"
1044        );
1045    }
1046
1047    /// Test that handle_reader exits when alive channel is closed
1048    #[tokio::test]
1049    async fn test_handle_reader_exits_on_alive_channel_closed() {
1050        let ReaderHarness {
1051            framed_reader,
1052            alive_tx,
1053            alive_rx,
1054            controller,
1055            ..
1056        } = reader_harness().await;
1057
1058        // Spawn the reader task
1059        let reader_handle =
1060            tokio::spawn(
1061                async move { handle_reader(framed_reader, controller, alive_tx, None).await },
1062            );
1063
1064        // Drop the alive_rx to close the channel (simulating writer finishing)
1065        drop(alive_rx);
1066
1067        // Reader should exit due to alive channel closure
1068        let result = reader_handle.await;
1069
1070        assert!(
1071            result.is_ok(),
1072            "handle_reader should exit when alive channel is closed"
1073        );
1074    }
1075
1076    /// Test that handle_reader exits when TCP stream is closed
1077    #[tokio::test]
1078    async fn test_handle_reader_exits_on_stream_closed() {
1079        let ReaderHarness {
1080            mut framed_server,
1081            framed_reader,
1082            alive_tx,
1083            alive_rx: _alive_rx,
1084            controller,
1085        } = reader_harness().await;
1086
1087        // Spawn the reader task
1088        let reader_handle =
1089            tokio::spawn(
1090                async move { handle_reader(framed_reader, controller, alive_tx, None).await },
1091            );
1092
1093        // Close the framed server to signal EOF to the client
1094        framed_server.close().await.unwrap();
1095
1096        // Reader should exit due to stream closure
1097        let result = tokio::time::timeout(std::time::Duration::from_secs(1), reader_handle).await;
1098
1099        assert!(
1100            result.is_ok(),
1101            "handle_reader should exit when stream is closed"
1102        );
1103    }
1104
1105    /// Test that handle_reader handles multiple control messages in sequence
1106    #[tokio::test]
1107    async fn test_handle_reader_multiple_control_messages() {
1108        let ReaderHarness {
1109            mut framed_server,
1110            framed_reader,
1111            alive_tx,
1112            alive_rx: _alive_rx,
1113            controller,
1114        } = reader_harness().await;
1115
1116        // Spawn the reader task
1117        let controller_clone = controller.clone();
1118        let reader_handle = tokio::spawn(async move {
1119            handle_reader(framed_reader, controller_clone, alive_tx, None).await
1120        });
1121
1122        // Send multiple Stop messages (first one will stop, subsequent ones are no-ops)
1123        framed_server
1124            .send(control_message(&ControlMessage::Stop))
1125            .await
1126            .unwrap();
1127        framed_server
1128            .send(control_message(&ControlMessage::Stop))
1129            .await
1130            .unwrap();
1131
1132        // Close the framed server to signal EOF to the client
1133        framed_server.close().await.unwrap();
1134
1135        // Wait for reader to finish
1136        let _ = reader_handle.await.unwrap();
1137
1138        // Verify that stop was called
1139        assert!(
1140            controller.is_stopped(),
1141            "Controller should be stopped after receiving Stop messages"
1142        );
1143    }
1144
1145    /// Test handle_reader with Stop followed by Kill
1146    #[tokio::test]
1147    async fn test_handle_reader_stop_then_kill() {
1148        let ReaderHarness {
1149            mut framed_server,
1150            framed_reader,
1151            alive_tx,
1152            alive_rx: _alive_rx,
1153            controller,
1154        } = reader_harness().await;
1155
1156        // Spawn the reader task
1157        let controller_clone = controller.clone();
1158        let reader_handle = tokio::spawn(async move {
1159            handle_reader(framed_reader, controller_clone, alive_tx, None).await
1160        });
1161
1162        // Send Stop first, then Kill
1163        framed_server
1164            .send(control_message(&ControlMessage::Stop))
1165            .await
1166            .unwrap();
1167        framed_server
1168            .send(control_message(&ControlMessage::Kill))
1169            .await
1170            .unwrap();
1171
1172        // Close the framed server to signal EOF to the client
1173        framed_server.close().await.unwrap();
1174
1175        // Wait for reader to finish
1176        let _ = reader_handle.await.unwrap();
1177
1178        // Verify that kill was called (which sets killed state)
1179        assert!(
1180            controller.is_killed(),
1181            "Controller should be killed after receiving Kill message"
1182        );
1183    }
1184
1185    /// Read errors kill the context and are counted as cancellations.
1186    #[tokio::test]
1187    async fn test_handle_reader_increments_cancellation_counter_on_read_error() {
1188        let ReaderHarness {
1189            framed_server,
1190            framed_reader,
1191            alive_tx,
1192            alive_rx: _alive_rx,
1193            controller,
1194        } = reader_harness().await;
1195        let cancellation_counter = IntCounter::new(
1196            "tcp_client_reader_read_error_cancellations_test",
1197            "test cancellation counter",
1198        )
1199        .unwrap();
1200
1201        let counter_clone = cancellation_counter.clone();
1202        let controller_clone = controller.clone();
1203        let reader_handle = tokio::spawn(async move {
1204            handle_reader(
1205                framed_reader,
1206                controller_clone,
1207                alive_tx,
1208                Some(counter_clone),
1209            )
1210            .await
1211        });
1212
1213        let mut raw_writer = framed_server.into_inner();
1214        raw_writer.write_all(&[0u8; 8]).await.unwrap();
1215        raw_writer.shutdown().await.unwrap();
1216
1217        let _ = reader_handle.await.unwrap();
1218
1219        assert!(
1220            controller.is_killed(),
1221            "Controller should be killed after TCP stream read error"
1222        );
1223        assert_eq!(
1224            cancellation_counter.get(),
1225            1,
1226            "read-error close should increment cancellation metric once"
1227        );
1228    }
1229
1230    /// Drives `handle_reader` against a single message and returns the
1231    /// controller + cancellation counter for assertions.
1232    async fn run_reader_with(
1233        msg: TwoPartMessage,
1234        counter_name: &str,
1235    ) -> (Arc<Controller>, IntCounter) {
1236        let ReaderHarness {
1237            mut framed_server,
1238            framed_reader,
1239            alive_tx,
1240            alive_rx: _alive_rx,
1241            controller,
1242        } = reader_harness().await;
1243        let counter = IntCounter::new(counter_name, "test counter").unwrap();
1244
1245        let counter_clone = counter.clone();
1246        let controller_clone = controller.clone();
1247        let reader_handle = tokio::spawn(async move {
1248            handle_reader(
1249                framed_reader,
1250                controller_clone,
1251                alive_tx,
1252                Some(counter_clone),
1253            )
1254            .await
1255        });
1256
1257        framed_server.send(msg).await.unwrap();
1258        let _ = reader_handle.await.unwrap();
1259
1260        (controller, counter)
1261    }
1262
1263    /// Each protocol-violating message variant must kill only this stream
1264    /// (controller killed, cancellation counted once) and never panic the
1265    /// worker. Covers the three non-read-error panic arms in `handle_reader`:
1266    /// undecodable control bytes, server-sent Sentinel, and non-control
1267    /// (data-only) messages.
1268    #[tokio::test]
1269    async fn test_handle_reader_kills_on_protocol_violations() {
1270        let cases: Vec<(&str, TwoPartMessage)> = vec![
1271            (
1272                "invalid control bytes",
1273                TwoPartMessage::from_header(Bytes::from_static(b"not a valid control message")),
1274            ),
1275            (
1276                "sentinel from server",
1277                control_message(&ControlMessage::Sentinel),
1278            ),
1279            (
1280                "non-control (data-only)",
1281                TwoPartMessage::from_data(Bytes::from_static(b"unexpected payload")),
1282            ),
1283        ];
1284
1285        for (i, (label, msg)) in cases.into_iter().enumerate() {
1286            let counter_name = format!("tcp_client_reader_protocol_violation_test_{i}");
1287            let (controller, counter) = run_reader_with(msg, &counter_name).await;
1288            assert!(
1289                controller.is_killed(),
1290                "{label}: should kill stream context"
1291            );
1292            assert_eq!(counter.get(), 1, "{label}: should be counted once");
1293        }
1294    }
1295}