Skip to main content

dynamo_runtime/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9    io::AsyncWriteExt,
10    net::TcpStream,
11    time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use prometheus::IntCounter;
16
17use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
18use crate::engine::AsyncEngineContext;
19use crate::pipeline::network::{
20    ConnectionInfo, ResponseStreamPrologue, StreamSender,
21    codec::{TwoPartCodec, TwoPartMessage},
22    tcp::StreamType,
23};
24use anyhow::{Context, Result, anyhow as error}; // Import SinkExt to use the `send` method
25
/// Client side of the TCP response-stream transport.
///
/// Response streams are opened through the associated function
/// `TcpClient::create_response_stream`, which does not need an instance.
// `worker_id` is stored but never read in this file, so `dead_code` is
// allowed to keep the field without a compiler warning.
#[allow(dead_code)]
#[derive(Debug)]
pub struct TcpClient {
    /// Identifier of this client; `Default` fills it with a random UUID (v4) string.
    worker_id: String,
}
30
31impl Default for TcpClient {
32    fn default() -> Self {
33        TcpClient {
34            worker_id: uuid::Uuid::new_v4().to_string(),
35        }
36    }
37}
38
39impl TcpClient {
40    pub fn new(worker_id: String) -> Self {
41        TcpClient { worker_id }
42    }
43
44    async fn connect(address: &str) -> std::io::Result<TcpStream> {
45        // try to connect to the address; retry with linear backoff if AddrNotAvailable
46        let backoff = std::time::Duration::from_millis(200);
47        loop {
48            match TcpStream::connect(address).await {
49                Ok(socket) => {
50                    socket.set_nodelay(true)?;
51                    return Ok(socket);
52                }
53                Err(e) => {
54                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
55                        tracing::warn!("retry warning: failed to connect: {:?}", e);
56                        tokio::time::sleep(backoff).await;
57                    } else {
58                        return Err(e);
59                    }
60                }
61            }
62        }
63    }
64
65    pub async fn create_response_stream(
66        context: Arc<dyn AsyncEngineContext>,
67        info: ConnectionInfo,
68        cancellation_counter: Option<IntCounter>,
69    ) -> Result<StreamSender> {
70        let info =
71            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
72        tracing::trace!("Creating response stream for {:?}", info);
73
74        if info.stream_type != StreamType::Response {
75            return Err(error!(
76                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
77                info.stream_type
78            ));
79        }
80
81        if info.context != context.id() {
82            return Err(error!(
83                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
84                context.id(),
85                info.context
86            ));
87        }
88
89        let stream = TcpClient::connect(&info.address).await?;
90        let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
91        let (read_half, write_half) = tokio::io::split(stream);
92
93        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
94        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
95
96        // this is a oneshot channel that will be used to signal when the stream is closed
97        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
98        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
99        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
100        // captured by the monitor task
101        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
102
103        let reader_task = tokio::spawn(handle_reader(
104            framed_reader,
105            context.clone(),
106            alive_tx,
107            cancellation_counter,
108        ));
109
110        // transport specific handshake message
111        let handshake = CallHomeHandshake {
112            subject: info.subject.clone(),
113            stream_type: StreamType::Response,
114        };
115
116        let handshake_bytes = match serde_json::to_vec(&handshake) {
117            Ok(hb) => hb,
118            Err(err) => {
119                return Err(error!(
120                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
121                ));
122            }
123        };
124        let msg = TwoPartMessage::from_header(handshake_bytes.into());
125
126        // issue the the first tcp handshake message
127        framed_writer
128            .send(msg)
129            .await
130            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
131
132        // set up the channel to send bytes to the transport layer
133        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
134
135        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
136        let writer_context = context.clone();
137        let writer_task = tokio::spawn(handle_writer(
138            framed_writer,
139            bytes_rx,
140            alive_rx,
141            writer_context,
142        ));
143
144        let subject = info.subject.clone();
145        let monitor_context = context;
146        // Spawn the connection monitor; errors are already logged inside
147        // wait_for_connection_tasks, so the Result is intentionally dropped.
148        tokio::spawn(async move {
149            let _ = wait_for_connection_tasks(
150                reader_task,
151                writer_task,
152                monitor_context,
153                peer_port,
154                subject,
155            )
156            .await;
157        });
158
159        // set up the prologue for the stream
160        // this might have transport specific metadata in the future
161        let prologue = Some(ResponseStreamPrologue { error: None });
162
163        // create the stream sender
164        let stream_sender = StreamSender {
165            tx: bytes_tx,
166            prologue,
167        };
168
169        Ok(stream_sender)
170    }
171}
172
173async fn wait_for_connection_tasks(
174    reader_task: tokio::task::JoinHandle<FramedRead<ReadHalf<TcpStream>, TwoPartCodec>>,
175    writer_task: tokio::task::JoinHandle<Result<FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>>>,
176    context: Arc<dyn AsyncEngineContext>,
177    peer_port: Option<u16>,
178    subject: String,
179) -> Result<()> {
180    let (reader, writer) = tokio::join!(reader_task, writer_task);
181
182    match (reader, writer) {
183        (Ok(reader), Ok(writer)) => {
184            let reader = reader.into_inner();
185
186            let writer = match writer {
187                Ok(writer) => writer.into_inner(),
188                Err(e) => {
189                    tracing::error!(
190                        subject = %subject,
191                        peer_port = ?peer_port,
192                        err = ?e,
193                        "writer task returned error"
194                    );
195                    return Err(e);
196                }
197            };
198
199            let stream = reader.unsplit(writer);
200            wait_for_server_shutdown(stream, context).await
201        }
202        (Err(reader_err), Ok(_)) => {
203            tracing::error!(
204                subject = %subject,
205                peer_port = ?peer_port,
206                err = ?reader_err,
207                "reader task failed to join"
208            );
209            Err(reader_err.into())
210        }
211        (Ok(_), Err(writer_err)) => {
212            tracing::error!(
213                subject = %subject,
214                peer_port = ?peer_port,
215                err = ?writer_err,
216                "writer task failed to join"
217            );
218            Err(writer_err.into())
219        }
220        (Err(reader_err), Err(writer_err)) => {
221            tracing::error!(
222                subject = %subject,
223                peer_port = ?peer_port,
224                reader_err = ?reader_err,
225                writer_err = ?writer_err,
226                "both reader and writer tasks failed to join"
227            );
228            // Surface the reader error; the writer error is captured above.
229            Err(reader_err.into())
230        }
231    }
232}
233
234async fn wait_for_server_shutdown(
235    mut stream: TcpStream,
236    context: Arc<dyn AsyncEngineContext>,
237) -> Result<()> {
238    // `handle_writer` skips the closing sentinel on both `killed` and
239    // `stopped`, so the server has nothing to react to in either case;
240    // sitting in the read loop until the 10 s deadline would be dead time.
241    if context.is_killed() || context.is_stopped() {
242        tracing::debug!("stream context killed or stopped; skipping server FIN wait");
243        return Ok(());
244    }
245
246    // Await the tcp server to shutdown the socket connection, bounded by a
247    // timeout so normal sentinel shutdown cannot hang indefinitely.
248    let mut buf = [0u8; 1024];
249    let deadline = Instant::now() + Duration::from_secs(10);
250    loop {
251        let n = time::timeout_at(deadline, stream.read(&mut buf))
252            .await
253            .inspect_err(|_| {
254                tracing::debug!("server did not close socket within the deadline");
255            })?
256            .inspect_err(|e| {
257                tracing::debug!(err = ?e, "failed to read from stream");
258            })?;
259        if n == 0 {
260            // Server has closed (FIN)
261            break;
262        }
263    }
264
265    Ok(())
266}
267
268async fn handle_reader(
269    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
270    context: Arc<dyn AsyncEngineContext>,
271    alive_tx: tokio::sync::oneshot::Sender<()>,
272    cancellation_counter: Option<IntCounter>,
273) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
274    let mut framed_reader = framed_reader;
275    let mut alive_tx = alive_tx;
276    // Set on every cancellation arm; counted once after the loop.
277    let mut cancellation_seen = false;
278    loop {
279        tokio::select! {
280            msg = framed_reader.next() => {
281                match msg {
282                    Some(Ok(two_part_msg)) => {
283                        match two_part_msg.optional_parts() {
284                           (Some(bytes), None) => {
285                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
286                                    Ok(msg) => msg,
287                                    Err(e) => {
288                                        tracing::warn!(
289                                            err = ?e,
290                                            "invalid control message, closing connection"
291                                        );
292                                        cancellation_seen = true;
293                                        context.kill();
294                                        break;
295                                    }
296                                };
297
298                                // Stop/Kill intentionally do not `break`: the
299                                // reader keeps running so a later Kill can
300                                // upgrade an earlier Stop (and vice versa).
301                                // The loop still exits promptly via the
302                                // `alive_tx.closed()` arm once `handle_writer`
303                                // reacts to `context.stop()` / `context.kill()`.
304                                match msg {
305                                    ControlMessage::Stop => {
306                                        cancellation_seen = true;
307                                        context.stop();
308                                    }
309                                    ControlMessage::Kill => {
310                                        cancellation_seen = true;
311                                        context.kill();
312                                    }
313                                    ControlMessage::Sentinel => {
314                                        tracing::warn!(
315                                            "unexpected sentinel on client reader, closing connection"
316                                        );
317                                        cancellation_seen = true;
318                                        context.kill();
319                                        break;
320                                    }
321                                }
322                           }
323                           _ => {
324                                tracing::warn!(
325                                    "unexpected non-control message on client reader, closing connection"
326                                );
327                                cancellation_seen = true;
328                                context.kill();
329                                break;
330                           }
331                        }
332                    }
333                    Some(Err(e)) => {
334                        // Kill the engine context so the producer stops
335                        // generating responses that can no longer be delivered.
336                        tracing::warn!(err = ?e, "tcp stream read error, closing connection");
337                        cancellation_seen = true;
338                        context.kill();
339                        break;
340                    }
341                    None => {
342                        tracing::debug!("tcp stream closed by server");
343                        cancellation_seen = true;
344                        break;
345                    }
346                }
347            }
348            _ = alive_tx.closed() => {
349                break;
350            }
351        }
352    }
353    if cancellation_seen && let Some(counter) = &cancellation_counter {
354        counter.inc();
355    }
356    framed_reader
357}
358
359async fn handle_writer(
360    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
361    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
362    alive_rx: tokio::sync::oneshot::Receiver<()>,
363    context: Arc<dyn AsyncEngineContext>,
364) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
365    // Only send sentinel for normal channel closure
366    let mut send_sentinel = true;
367
368    loop {
369        let msg = tokio::select! {
370            biased;
371
372            _ = context.killed() => {
373                tracing::trace!("context kill signal received; shutting down");
374                send_sentinel = false;
375                break;
376            }
377
378            _ = context.stopped() => {
379                tracing::trace!("context stop signal received; shutting down");
380                send_sentinel = false;
381                break;
382            }
383
384            msg = bytes_rx.recv() => {
385                match msg {
386                    Some(msg) => msg,
387                    None => {
388                        tracing::trace!("response channel closed; shutting down");
389                        break;
390                    }
391                }
392            }
393        };
394
395        if let Err(e) = framed_writer.send(msg).await {
396            tracing::trace!(
397                "failed to send message to network; possible disconnect: {:?}",
398                e
399            );
400            send_sentinel = false;
401            break;
402        }
403    }
404
405    // Send sentinel only on normal closure
406    if send_sentinel {
407        let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
408        let msg = TwoPartMessage::from_header(message.into());
409        framed_writer.send(msg).await?;
410    }
411
412    drop(alive_rx);
413    Ok(framed_writer)
414}
415
416#[cfg(test)]
417mod tests {
418    use super::*;
419    use crate::pipeline::context::Controller;
420    use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
421    use bytes::Bytes;
422    use futures::StreamExt;
423    use std::sync::Arc;
424    use tokio::io::{AsyncReadExt, AsyncWriteExt};
425    use tokio::net::TcpStream;
426    use tokio::sync::{mpsc, oneshot};
427    use tokio_util::codec::FramedRead;
428
429    struct WriterHarness {
430        server: tokio::net::TcpStream,
431        framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
432        bytes_tx: mpsc::Sender<TwoPartMessage>,
433        bytes_rx: mpsc::Receiver<TwoPartMessage>,
434        alive_tx: oneshot::Sender<()>,
435        alive_rx: oneshot::Receiver<()>,
436        controller: Arc<Controller>,
437    }
438
439    /// Creates a reusable writer harness with paired TCP streams and test channels.
440    async fn writer_harness() -> WriterHarness {
441        let (client, server) = create_tcp_pair().await;
442        let (_, write_half) = tokio::io::split(client);
443        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
444
445        let (bytes_tx, bytes_rx) = mpsc::channel(64);
446        let (alive_tx, alive_rx) = oneshot::channel::<()>();
447        let controller = Arc::new(Controller::default());
448
449        WriterHarness {
450            server,
451            framed_writer,
452            bytes_tx,
453            bytes_rx,
454            alive_tx,
455            alive_rx,
456            controller,
457        }
458    }
459
460    async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
461        reader
462            .next()
463            .await
464            .expect("expected message")
465            .expect("failed to decode message")
466    }
467
468    fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
469        let (header, data) = msg.optional_parts();
470        assert!(header.is_none(), "data-only message should not have header");
471        assert_eq!(
472            data.expect("data payload missing").as_ref(),
473            expected,
474            "data payload should match"
475        );
476    }
477
478    fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
479        let (header, data) = msg.optional_parts();
480        assert!(data.is_none(), "header-only message should not carry data");
481        assert_eq!(
482            header.expect("header missing").as_ref(),
483            expected,
484            "header payload should match"
485        );
486    }
487
488    fn assert_header_and_data_message(
489        msg: TwoPartMessage,
490        expected_header: &[u8],
491        expected_data: &[u8],
492    ) {
493        let (header, data) = msg.optional_parts();
494        assert_eq!(
495            header.expect("header missing").as_ref(),
496            expected_header,
497            "header payload should match"
498        );
499        assert_eq!(
500            data.expect("data missing").as_ref(),
501            expected_data,
502            "data payload should match"
503        );
504    }
505
506    fn assert_sentinel_message(msg: TwoPartMessage) {
507        let (header, data) = msg.optional_parts();
508        assert!(data.is_none(), "sentinel should not include a data section");
509        let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
510        assert_eq!(
511            header.expect("sentinel header missing").as_ref(),
512            expected_sentinel.as_slice(),
513            "sentinel header should match serialized ControlMessage::Sentinel"
514        );
515    }
516
517    /// Test that handle_writer forwards messages from the channel to the framed writer
518    #[tokio::test]
519    async fn test_handle_writer_forwards_messages() {
520        let WriterHarness {
521            server,
522            framed_writer,
523            bytes_tx,
524            bytes_rx,
525            alive_rx,
526            controller,
527            ..
528        } = writer_harness().await;
529
530        // Send test messages
531        let test_msg = TwoPartMessage::from_data(Bytes::from("test data"));
532        bytes_tx.send(test_msg).await.unwrap();
533
534        // Close the sender to trigger normal termination
535        drop(bytes_tx);
536
537        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
538
539        assert!(result.is_ok());
540
541        // Decode from server side to verify data and sentinel were sent
542        let mut reader = FramedRead::new(server, TwoPartCodec::default());
543
544        let msg = recv_msg(&mut reader).await;
545        assert_data_only_message(msg, b"test data");
546
547        let sentinel = recv_msg(&mut reader).await;
548        assert_sentinel_message(sentinel);
549    }
550
551    /// Test that handle_writer sends sentinel on normal channel closure
552    #[tokio::test]
553    async fn test_handle_writer_sends_sentinel_on_normal_closure() {
554        let WriterHarness {
555            mut server,
556            framed_writer,
557            bytes_tx,
558            bytes_rx,
559            alive_rx,
560            controller,
561            ..
562        } = writer_harness().await;
563
564        // Close the sender immediately to trigger normal termination
565        drop(bytes_tx);
566
567        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
568
569        assert!(result.is_ok());
570
571        // Read from server side to verify sentinel was sent
572        let mut buffer = vec![0u8; 1024];
573        let n = server.read(&mut buffer).await.unwrap();
574
575        // Buffer should contain the sentinel message
576        assert!(n > 0, "Expected sentinel to be written to the TCP stream");
577
578        // Verify it contains the sentinel message by checking for the JSON
579        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
580        assert!(
581            buffer[..n]
582                .windows(sentinel_json.len())
583                .any(|w| w == sentinel_json.as_slice()),
584            "Buffer should contain sentinel message. Buffer: {:?}",
585            String::from_utf8_lossy(&buffer[..n])
586        );
587    }
588
589    /// Test that handle_writer does NOT send sentinel when context is killed
590    #[tokio::test]
591    async fn test_handle_writer_no_sentinel_on_context_killed() {
592        let WriterHarness {
593            mut server,
594            framed_writer,
595            bytes_rx,
596            alive_rx,
597            controller,
598            ..
599        } = writer_harness().await;
600
601        // Kill the context
602        controller.kill();
603
604        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
605
606        assert!(result.is_ok());
607
608        // Drop the writer to close the connection, then try to read. Otherwise,
609        // the test will hang on `server.read()`
610        drop(result);
611
612        // Read from server side - should get no sentinel
613        let mut buffer = vec![0u8; 1024];
614        let n = server.read(&mut buffer).await.unwrap();
615
616        // Buffer should be empty (no sentinel sent)
617        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
618        assert!(
619            n == 0
620                || !buffer[..n]
621                    .windows(sentinel_json.len())
622                    .any(|w| w == sentinel_json.as_slice()),
623            "Buffer should NOT contain sentinel message when context is killed"
624        );
625    }
626
627    /// Test that handle_writer does NOT send sentinel when context is stopped
628    #[tokio::test]
629    async fn test_handle_writer_no_sentinel_on_context_stopped() {
630        let WriterHarness {
631            mut server,
632            framed_writer,
633            bytes_rx,
634            alive_rx,
635            controller,
636            ..
637        } = writer_harness().await;
638
639        // Stop the context
640        controller.stop();
641
642        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
643
644        assert!(result.is_ok());
645
646        // Drop the writer to close the connection, then try to read. Otherwise,
647        // the test will hang on `server.read()`
648        drop(result);
649
650        // Read from server side - should get no sentinel
651        let mut buffer = vec![0u8; 1024];
652        let n = server.read(&mut buffer).await.unwrap();
653
654        // Buffer should be empty (no sentinel sent)
655        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
656        assert!(
657            n == 0
658                || !buffer[..n]
659                    .windows(sentinel_json.len())
660                    .any(|w| w == sentinel_json.as_slice()),
661            "Buffer should NOT contain sentinel message when context is stopped"
662        );
663    }
664
665    /// Test that handle_writer handles multiple messages correctly
666    #[tokio::test]
667    async fn test_handle_writer_multiple_messages() {
668        let WriterHarness {
669            server,
670            framed_writer,
671            bytes_tx,
672            bytes_rx,
673            alive_rx,
674            controller,
675            ..
676        } = writer_harness().await;
677
678        // Send multiple messages
679        for i in 0..5 {
680            let test_msg = TwoPartMessage::from_data(Bytes::from(format!("message {}", i)));
681            bytes_tx.send(test_msg).await.unwrap();
682        }
683
684        // Close the sender to trigger normal termination
685        drop(bytes_tx);
686
687        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
688
689        assert!(result.is_ok());
690
691        // Decode from server side to verify all messages plus sentinel
692        let mut reader = FramedRead::new(server, TwoPartCodec::default());
693        for i in 0..5 {
694            let msg = recv_msg(&mut reader).await;
695            assert_data_only_message(msg, format!("message {}", i).as_bytes());
696        }
697
698        let sentinel = recv_msg(&mut reader).await;
699        assert_sentinel_message(sentinel);
700    }
701
702    /// Test that alive_rx is dropped after handle_writer completes
703    #[tokio::test]
704    async fn test_handle_writer_drops_alive_rx() {
705        let WriterHarness {
706            framed_writer,
707            bytes_tx,
708            bytes_rx,
709            alive_tx,
710            alive_rx,
711            controller,
712            ..
713        } = writer_harness().await;
714
715        // Close the sender to trigger normal termination
716        drop(bytes_tx);
717
718        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
719
720        assert!(result.is_ok());
721
722        // alive_tx should now be closed because alive_rx was dropped
723        assert!(alive_tx.is_closed());
724    }
725
726    /// Test handle_writer with header-only messages (control messages)
727    #[tokio::test]
728    async fn test_handle_writer_header_only_messages() {
729        let WriterHarness {
730            server,
731            framed_writer,
732            bytes_tx,
733            bytes_rx,
734            alive_rx,
735            controller,
736            ..
737        } = writer_harness().await;
738
739        // Send a header-only message
740        let header_msg = TwoPartMessage::from_header(Bytes::from("header content"));
741        bytes_tx.send(header_msg).await.unwrap();
742
743        // Close the sender
744        drop(bytes_tx);
745
746        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
747
748        assert!(result.is_ok());
749
750        let mut reader = FramedRead::new(server, TwoPartCodec::default());
751
752        let header_msg = recv_msg(&mut reader).await;
753        assert_header_only_message(header_msg, b"header content");
754
755        let sentinel = recv_msg(&mut reader).await;
756        assert_sentinel_message(sentinel);
757    }
758
759    /// Test handle_writer with mixed header and data messages
760    #[tokio::test]
761    async fn test_handle_writer_mixed_messages() {
762        let WriterHarness {
763            server,
764            framed_writer,
765            bytes_tx,
766            bytes_rx,
767            alive_rx,
768            controller,
769            ..
770        } = writer_harness().await;
771
772        // Send mixed messages
773        bytes_tx
774            .send(TwoPartMessage::from_header(Bytes::from("header1")))
775            .await
776            .unwrap();
777        bytes_tx
778            .send(TwoPartMessage::from_data(Bytes::from("data1")))
779            .await
780            .unwrap();
781        bytes_tx
782            .send(TwoPartMessage::from_parts(
783                Bytes::from("header2"),
784                Bytes::from("data2"),
785            ))
786            .await
787            .unwrap();
788
789        // Close the sender
790        drop(bytes_tx);
791
792        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
793
794        assert!(result.is_ok());
795
796        let mut reader = FramedRead::new(server, TwoPartCodec::default());
797
798        let first = recv_msg(&mut reader).await;
799        assert_header_only_message(first, b"header1");
800
801        let second = recv_msg(&mut reader).await;
802        assert_data_only_message(second, b"data1");
803
804        let third = recv_msg(&mut reader).await;
805        assert_header_and_data_message(third, b"header2", b"data2");
806
807        let sentinel = recv_msg(&mut reader).await;
808        assert_sentinel_message(sentinel);
809    }
810
811    /// Killed or stopped contexts skip the server FIN deadline.
812    #[tokio::test]
813    async fn test_wait_for_server_shutdown_skips_terminal_context() {
814        for action in [Controller::kill as fn(&Controller), Controller::stop] {
815            let (client, _server) = create_tcp_pair().await;
816            let controller = Arc::new(Controller::default());
817            action(&controller);
818
819            let context: Arc<dyn AsyncEngineContext> = controller;
820            let result = tokio::time::timeout(
821                std::time::Duration::from_millis(50),
822                wait_for_server_shutdown(client, context),
823            )
824            .await;
825
826            assert!(result.is_ok(), "terminal context should not wait for FIN");
827            assert!(
828                result.unwrap().is_ok(),
829                "terminal context shutdown should succeed"
830            );
831        }
832    }
833
834    /// Read error in the connection monitor kills the context and skips the FIN wait.
835    #[tokio::test]
836    async fn test_connection_monitor_skips_fin_wait_after_read_error_kills_context() {
837        let (client, mut server) = create_tcp_pair().await;
838        let (read_half, write_half) = tokio::io::split(client);
839        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
840        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
841        let (_bytes_tx, bytes_rx) = mpsc::channel(64);
842        let (alive_tx, alive_rx) = oneshot::channel::<()>();
843        let controller = Arc::new(Controller::default());
844
845        let reader_context = controller.clone();
846        let reader_task = tokio::spawn(async move {
847            handle_reader(framed_reader, reader_context, alive_tx, None).await
848        });
849        let writer_context = controller.clone();
850        let writer_task = tokio::spawn(async move {
851            handle_writer(framed_writer, bytes_rx, alive_rx, writer_context).await
852        });
853
854        // Bypass the codec and write a complete but invalid TwoPartCodec
855        // header. This drives the client reader into Some(Err(_)) without
856        // closing the server side of the socket.
857        server.write_all(&[0xFF; 24]).await.unwrap();
858
859        let monitor_context: Arc<dyn AsyncEngineContext> = controller.clone();
860        let result = tokio::time::timeout(
861            std::time::Duration::from_millis(250),
862            wait_for_connection_tasks(
863                reader_task,
864                writer_task,
865                monitor_context,
866                None,
867                "test-subject".to_string(),
868            ),
869        )
870        .await;
871
872        assert!(
873            result.is_ok(),
874            "connection monitor should not wait for the FIN deadline after read error"
875        );
876        assert!(result.unwrap().is_ok(), "connection monitor should succeed");
877        assert!(
878            controller.is_killed(),
879            "read error should kill the stream context"
880        );
881    }
882
883    // ==================== handle_reader tests ====================
884
    /// Fixture for `handle_reader` tests, built by `reader_harness`: a framed pair of
    /// half-streams from one connected TCP pair, the alive oneshot pair, and the controller.
    ///
    /// Note: `framed_server` wraps the only server-side half that `reader_harness` keeps (the
    /// server's read half is dropped there). Destructuring with `..` drops any unbound fields
    /// at the end of the `let`. If `framed_server` is among them, that closes the server side
    /// of the socket.
    struct ReaderHarness {
        // Server's write half framed with `TwoPartCodec`; tests send messages to the client through it.
        framed_server: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
        // Client's read half framed with `TwoPartCodec`; this is what tests hand to `handle_reader`.
        framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
        // Sender side of the alive oneshot; passed to `handle_reader`.
        alive_tx: oneshot::Sender<()>,
        // Receiver side of the alive oneshot; tests keep it bound, or drop it to close the channel.
        alive_rx: oneshot::Receiver<()>,
        // Context given to `handle_reader`; tests inspect it via `is_stopped()` / `is_killed()`.
        controller: Arc<Controller>,
    }
892
893    /// Creates a reusable reader harness with paired TCP streams and test channels.
894    async fn reader_harness() -> ReaderHarness {
895        let (client, server) = create_tcp_pair().await;
896        let (read_half, _write_half) = tokio::io::split(client);
897        let (_server_read, server_write) = tokio::io::split(server);
898
899        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
900        let framed_server = FramedWrite::new(server_write, TwoPartCodec::default());
901        let (alive_tx, alive_rx) = oneshot::channel::<()>();
902        let controller = Arc::new(Controller::default());
903
904        ReaderHarness {
905            framed_server,
906            framed_reader,
907            alive_tx,
908            alive_rx,
909            controller,
910        }
911    }
912
913    fn control_message(msg: &ControlMessage) -> TwoPartMessage {
914        let msg_bytes = serde_json::to_vec(msg).unwrap();
915        TwoPartMessage::from_header(Bytes::from(msg_bytes))
916    }
917
918    /// Test that handle_reader handles Stop control message by calling context.stop()
919    #[tokio::test]
920    async fn test_handle_reader_stop_control_message() {
921        let ReaderHarness {
922            mut framed_server,
923            framed_reader,
924            alive_tx,
925            alive_rx: _alive_rx,
926            controller,
927        } = reader_harness().await;
928
929        // Spawn the reader task
930        let controller_clone = controller.clone();
931        let reader_handle = tokio::spawn(async move {
932            handle_reader(framed_reader, controller_clone, alive_tx, None).await
933        });
934
935        // Send Stop control message from server
936        framed_server
937            .send(control_message(&ControlMessage::Stop))
938            .await
939            .unwrap();
940
941        // Close the framed server to signal EOF to the client
942        framed_server.close().await.unwrap();
943
944        // Wait for reader to finish
945        let _ = reader_handle.await.unwrap();
946
947        // Verify that stop was called on the controller
948        assert!(
949            controller.is_stopped(),
950            "Controller should be stopped after receiving Stop message"
951        );
952    }
953
954    /// Test that handle_reader handles Kill control message by calling context.kill()
955    #[tokio::test]
956    async fn test_handle_reader_kill_control_message() {
957        let ReaderHarness {
958            mut framed_server,
959            framed_reader,
960            alive_tx,
961            alive_rx: _alive_rx,
962            controller,
963        } = reader_harness().await;
964
965        // Spawn the reader task
966        let controller_clone = controller.clone();
967        let reader_handle = tokio::spawn(async move {
968            handle_reader(framed_reader, controller_clone, alive_tx, None).await
969        });
970
971        // Send Kill control message from server
972        framed_server
973            .send(control_message(&ControlMessage::Kill))
974            .await
975            .unwrap();
976
977        // Close the framed server to signal EOF to the client
978        framed_server.close().await.unwrap();
979
980        // Wait for reader to finish
981        let _ = reader_handle.await.unwrap();
982
983        // Verify that kill was called on the controller
984        assert!(
985            controller.is_killed(),
986            "Controller should be killed after receiving Kill message"
987        );
988    }
989
990    /// Test that handle_reader exits when alive channel is closed
991    #[tokio::test]
992    async fn test_handle_reader_exits_on_alive_channel_closed() {
993        let ReaderHarness {
994            framed_reader,
995            alive_tx,
996            alive_rx,
997            controller,
998            ..
999        } = reader_harness().await;
1000
1001        // Spawn the reader task
1002        let reader_handle =
1003            tokio::spawn(
1004                async move { handle_reader(framed_reader, controller, alive_tx, None).await },
1005            );
1006
1007        // Drop the alive_rx to close the channel (simulating writer finishing)
1008        drop(alive_rx);
1009
1010        // Reader should exit due to alive channel closure
1011        let result = reader_handle.await;
1012
1013        assert!(
1014            result.is_ok(),
1015            "handle_reader should exit when alive channel is closed"
1016        );
1017    }
1018
1019    /// Test that handle_reader exits when TCP stream is closed
1020    #[tokio::test]
1021    async fn test_handle_reader_exits_on_stream_closed() {
1022        let ReaderHarness {
1023            mut framed_server,
1024            framed_reader,
1025            alive_tx,
1026            alive_rx: _alive_rx,
1027            controller,
1028        } = reader_harness().await;
1029
1030        // Spawn the reader task
1031        let reader_handle =
1032            tokio::spawn(
1033                async move { handle_reader(framed_reader, controller, alive_tx, None).await },
1034            );
1035
1036        // Close the framed server to signal EOF to the client
1037        framed_server.close().await.unwrap();
1038
1039        // Reader should exit due to stream closure
1040        let result = tokio::time::timeout(std::time::Duration::from_secs(1), reader_handle).await;
1041
1042        assert!(
1043            result.is_ok(),
1044            "handle_reader should exit when stream is closed"
1045        );
1046    }
1047
1048    /// Test that handle_reader handles multiple control messages in sequence
1049    #[tokio::test]
1050    async fn test_handle_reader_multiple_control_messages() {
1051        let ReaderHarness {
1052            mut framed_server,
1053            framed_reader,
1054            alive_tx,
1055            alive_rx: _alive_rx,
1056            controller,
1057        } = reader_harness().await;
1058
1059        // Spawn the reader task
1060        let controller_clone = controller.clone();
1061        let reader_handle = tokio::spawn(async move {
1062            handle_reader(framed_reader, controller_clone, alive_tx, None).await
1063        });
1064
1065        // Send multiple Stop messages (first one will stop, subsequent ones are no-ops)
1066        framed_server
1067            .send(control_message(&ControlMessage::Stop))
1068            .await
1069            .unwrap();
1070        framed_server
1071            .send(control_message(&ControlMessage::Stop))
1072            .await
1073            .unwrap();
1074
1075        // Close the framed server to signal EOF to the client
1076        framed_server.close().await.unwrap();
1077
1078        // Wait for reader to finish
1079        let _ = reader_handle.await.unwrap();
1080
1081        // Verify that stop was called
1082        assert!(
1083            controller.is_stopped(),
1084            "Controller should be stopped after receiving Stop messages"
1085        );
1086    }
1087
1088    /// Test handle_reader with Stop followed by Kill
1089    #[tokio::test]
1090    async fn test_handle_reader_stop_then_kill() {
1091        let ReaderHarness {
1092            mut framed_server,
1093            framed_reader,
1094            alive_tx,
1095            alive_rx: _alive_rx,
1096            controller,
1097        } = reader_harness().await;
1098
1099        // Spawn the reader task
1100        let controller_clone = controller.clone();
1101        let reader_handle = tokio::spawn(async move {
1102            handle_reader(framed_reader, controller_clone, alive_tx, None).await
1103        });
1104
1105        // Send Stop first, then Kill
1106        framed_server
1107            .send(control_message(&ControlMessage::Stop))
1108            .await
1109            .unwrap();
1110        framed_server
1111            .send(control_message(&ControlMessage::Kill))
1112            .await
1113            .unwrap();
1114
1115        // Close the framed server to signal EOF to the client
1116        framed_server.close().await.unwrap();
1117
1118        // Wait for reader to finish
1119        let _ = reader_handle.await.unwrap();
1120
1121        // Verify that kill was called (which sets killed state)
1122        assert!(
1123            controller.is_killed(),
1124            "Controller should be killed after receiving Kill message"
1125        );
1126    }
1127
1128    /// Read errors kill the context and are counted as cancellations.
1129    #[tokio::test]
1130    async fn test_handle_reader_increments_cancellation_counter_on_read_error() {
1131        let ReaderHarness {
1132            framed_server,
1133            framed_reader,
1134            alive_tx,
1135            alive_rx: _alive_rx,
1136            controller,
1137        } = reader_harness().await;
1138        let cancellation_counter = IntCounter::new(
1139            "tcp_client_reader_read_error_cancellations_test",
1140            "test cancellation counter",
1141        )
1142        .unwrap();
1143
1144        let counter_clone = cancellation_counter.clone();
1145        let controller_clone = controller.clone();
1146        let reader_handle = tokio::spawn(async move {
1147            handle_reader(
1148                framed_reader,
1149                controller_clone,
1150                alive_tx,
1151                Some(counter_clone),
1152            )
1153            .await
1154        });
1155
1156        let mut raw_writer = framed_server.into_inner();
1157        raw_writer.write_all(&[0u8; 8]).await.unwrap();
1158        raw_writer.shutdown().await.unwrap();
1159
1160        let _ = reader_handle.await.unwrap();
1161
1162        assert!(
1163            controller.is_killed(),
1164            "Controller should be killed after TCP stream read error"
1165        );
1166        assert_eq!(
1167            cancellation_counter.get(),
1168            1,
1169            "read-error close should increment cancellation metric once"
1170        );
1171    }
1172
1173    /// Drives `handle_reader` against a single message and returns the
1174    /// controller + cancellation counter for assertions.
1175    async fn run_reader_with(
1176        msg: TwoPartMessage,
1177        counter_name: &str,
1178    ) -> (Arc<Controller>, IntCounter) {
1179        let ReaderHarness {
1180            mut framed_server,
1181            framed_reader,
1182            alive_tx,
1183            alive_rx: _alive_rx,
1184            controller,
1185        } = reader_harness().await;
1186        let counter = IntCounter::new(counter_name, "test counter").unwrap();
1187
1188        let counter_clone = counter.clone();
1189        let controller_clone = controller.clone();
1190        let reader_handle = tokio::spawn(async move {
1191            handle_reader(
1192                framed_reader,
1193                controller_clone,
1194                alive_tx,
1195                Some(counter_clone),
1196            )
1197            .await
1198        });
1199
1200        framed_server.send(msg).await.unwrap();
1201        let _ = reader_handle.await.unwrap();
1202
1203        (controller, counter)
1204    }
1205
1206    /// Each protocol-violating message variant must kill only this stream
1207    /// (controller killed, cancellation counted once) and never panic the
1208    /// worker. Covers the three non-read-error panic arms in `handle_reader`:
1209    /// undecodable control bytes, server-sent Sentinel, and non-control
1210    /// (data-only) messages.
1211    #[tokio::test]
1212    async fn test_handle_reader_kills_on_protocol_violations() {
1213        let cases: Vec<(&str, TwoPartMessage)> = vec![
1214            (
1215                "invalid control bytes",
1216                TwoPartMessage::from_header(Bytes::from_static(b"not a valid control message")),
1217            ),
1218            (
1219                "sentinel from server",
1220                control_message(&ControlMessage::Sentinel),
1221            ),
1222            (
1223                "non-control (data-only)",
1224                TwoPartMessage::from_data(Bytes::from_static(b"unexpected payload")),
1225            ),
1226        ];
1227
1228        for (i, (label, msg)) in cases.into_iter().enumerate() {
1229            let counter_name = format!("tcp_client_reader_protocol_violation_test_{i}");
1230            let (controller, counter) = run_reader_with(msg, &counter_name).await;
1231            assert!(
1232                controller.is_killed(),
1233                "{label}: should kill stream context"
1234            );
1235            assert_eq!(counter.get(), 1, "{label}: should be counted once");
1236        }
1237    }
1238}