Skip to main content

dynamo_runtime/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9    io::AsyncWriteExt,
10    net::TcpStream,
11    time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
16use crate::engine::AsyncEngineContext;
17use crate::pipeline::network::{
18    ConnectionInfo, ResponseStreamPrologue, StreamSender,
19    codec::{TwoPartCodec, TwoPartMessage},
20    tcp::StreamType,
21};
22use anyhow::{Context, Result, anyhow as error}; // Import SinkExt to use the `send` method
23
24#[allow(dead_code)]
25pub struct TcpClient {
26    worker_id: String,
27}
28
29impl Default for TcpClient {
30    fn default() -> Self {
31        TcpClient {
32            worker_id: uuid::Uuid::new_v4().to_string(),
33        }
34    }
35}
36
37impl TcpClient {
38    pub fn new(worker_id: String) -> Self {
39        TcpClient { worker_id }
40    }
41
42    async fn connect(address: &str) -> std::io::Result<TcpStream> {
43        // try to connect to the address; retry with linear backoff if AddrNotAvailable
44        let backoff = std::time::Duration::from_millis(200);
45        loop {
46            match TcpStream::connect(address).await {
47                Ok(socket) => {
48                    socket.set_nodelay(true)?;
49                    return Ok(socket);
50                }
51                Err(e) => {
52                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
53                        tracing::warn!("retry warning: failed to connect: {:?}", e);
54                        tokio::time::sleep(backoff).await;
55                    } else {
56                        return Err(e);
57                    }
58                }
59            }
60        }
61    }
62
63    pub async fn create_response_stream(
64        context: Arc<dyn AsyncEngineContext>,
65        info: ConnectionInfo,
66    ) -> Result<StreamSender> {
67        let info =
68            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
69        tracing::trace!("Creating response stream for {:?}", info);
70
71        if info.stream_type != StreamType::Response {
72            return Err(error!(
73                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
74                info.stream_type
75            ));
76        }
77
78        if info.context != context.id() {
79            return Err(error!(
80                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
81                context.id(),
82                info.context
83            ));
84        }
85
86        let stream = TcpClient::connect(&info.address).await?;
87        let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
88        let (read_half, write_half) = tokio::io::split(stream);
89
90        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
91        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
92
93        // this is a oneshot channel that will be used to signal when the stream is closed
94        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
95        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
96        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
97        // captured by the monitor task
98        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
99
100        let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
101
102        // transport specific handshake message
103        let handshake = CallHomeHandshake {
104            subject: info.subject.clone(),
105            stream_type: StreamType::Response,
106        };
107
108        let handshake_bytes = match serde_json::to_vec(&handshake) {
109            Ok(hb) => hb,
110            Err(err) => {
111                return Err(error!(
112                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
113                ));
114            }
115        };
116        let msg = TwoPartMessage::from_header(handshake_bytes.into());
117
118        // issue the the first tcp handshake message
119        framed_writer
120            .send(msg)
121            .await
122            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
123
124        // set up the channel to send bytes to the transport layer
125        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
126
127        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
128
129        let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
130
131        let subject = info.subject.clone();
132        tokio::spawn(async move {
133            // await both tasks
134            let (reader, writer) = tokio::join!(reader_task, writer_task);
135
136            match (reader, writer) {
137                (Ok(reader), Ok(writer)) => {
138                    let reader = reader.into_inner();
139
140                    let writer = match writer {
141                        Ok(writer) => writer.into_inner(),
142                        Err(e) => {
143                            tracing::error!("failed to join writer task: {:?}", e);
144                            return Err(e);
145                        }
146                    };
147
148                    let mut stream = reader.unsplit(writer);
149
150                    // await the tcp server to shutdown the socket connection
151                    // set a timeout for the server shutdown
152                    let mut buf = vec![0u8; 1024];
153                    let deadline = Instant::now() + Duration::from_secs(10);
154                    loop {
155                        let n = time::timeout_at(deadline, stream.read(&mut buf))
156                            .await
157                            .inspect_err(|_| {
158                                tracing::debug!("server did not close socket within the deadline");
159                            })?
160                            .inspect_err(|e| {
161                                tracing::debug!("failed to read from stream: {:?}", e);
162                            })?;
163                        if n == 0 {
164                            // Server has closed (FIN)
165                            break;
166                        }
167                    }
168
169                    Ok(())
170                }
171                (Err(reader_err), Ok(_)) => {
172                    tracing::error!(
173                        "reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
174                    );
175                    anyhow::bail!(
176                        "reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
177                    );
178                }
179                (Ok(_), Err(writer_err)) => {
180                    tracing::error!(
181                        "writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
182                    );
183                    anyhow::bail!(
184                        "writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
185                    );
186                }
187                (Err(reader_err), Err(writer_err)) => {
188                    tracing::error!(
189                        "both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
190                    );
191                    anyhow::bail!(
192                        "both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
193                    );
194                }
195            }
196        });
197
198        // set up the prologue for the stream
199        // this might have transport specific metadata in the future
200        let prologue = Some(ResponseStreamPrologue { error: None });
201
202        // create the stream sender
203        let stream_sender = StreamSender {
204            tx: bytes_tx,
205            prologue,
206        };
207
208        Ok(stream_sender)
209    }
210}
211
212async fn handle_reader(
213    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
214    context: Arc<dyn AsyncEngineContext>,
215    alive_tx: tokio::sync::oneshot::Sender<()>,
216) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
217    let mut framed_reader = framed_reader;
218    let mut alive_tx = alive_tx;
219    loop {
220        tokio::select! {
221            msg = framed_reader.next() => {
222                match msg {
223                    Some(Ok(two_part_msg)) => {
224                        match two_part_msg.optional_parts() {
225                           (Some(bytes), None) => {
226                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
227                                    Ok(msg) => msg,
228                                    Err(_) => {
229                                        // TODO(#171) - address fatal errors
230                                        panic!("fatal error - invalid control message detected");
231                                    }
232                                };
233
234                                match msg {
235                                    ControlMessage::Stop => {
236                                        context.stop();
237                                    }
238                                    ControlMessage::Kill => {
239                                        context.kill();
240                                    }
241                                    ControlMessage::Sentinel => {
242                                        // TODO(#171) - address fatal errors
243                                        panic!("received a sentinel message; this should never happen");
244                                    }
245                                }
246                           }
247                           _ => {
248                                panic!("received a non-control message; this should never happen");
249                           }
250                        }
251                    }
252                    Some(Err(e)) => {
253                        // TODO(#171) - address fatal errors
254                        // in this case the binary representation of the message is invalid
255                        panic!("fatal error - failed to decode message from stream; invalid line protocol: {e:?}");
256                    }
257                    None => {
258                        tracing::debug!("tcp stream closed by server");
259                        break;
260                    }
261                }
262            }
263            _ = alive_tx.closed() => {
264                break;
265            }
266        }
267    }
268    framed_reader
269}
270
271async fn handle_writer(
272    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
273    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
274    alive_rx: tokio::sync::oneshot::Receiver<()>,
275    context: Arc<dyn AsyncEngineContext>,
276) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
277    // Only send sentinel for normal channel closure
278    let mut send_sentinel = true;
279
280    loop {
281        let msg = tokio::select! {
282            biased;
283
284            _ = context.killed() => {
285                tracing::trace!("context kill signal received; shutting down");
286                send_sentinel = false;
287                break;
288            }
289
290            _ = context.stopped() => {
291                tracing::trace!("context stop signal received; shutting down");
292                send_sentinel = false;
293                break;
294            }
295
296            msg = bytes_rx.recv() => {
297                match msg {
298                    Some(msg) => msg,
299                    None => {
300                        tracing::trace!("response channel closed; shutting down");
301                        break;
302                    }
303                }
304            }
305        };
306
307        if let Err(e) = framed_writer.send(msg).await {
308            tracing::trace!(
309                "failed to send message to network; possible disconnect: {:?}",
310                e
311            );
312            send_sentinel = false;
313            break;
314        }
315    }
316
317    // Send sentinel only on normal closure
318    if send_sentinel {
319        let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
320        let msg = TwoPartMessage::from_header(message.into());
321        framed_writer.send(msg).await?;
322    }
323
324    drop(alive_rx);
325    Ok(framed_writer)
326}
327
328#[cfg(test)]
329mod tests {
330    use super::*;
331    use crate::pipeline::context::Controller;
332    use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
333    use bytes::Bytes;
334    use futures::StreamExt;
335    use std::sync::Arc;
336    use tokio::io::AsyncReadExt;
337    use tokio::net::TcpStream;
338    use tokio::sync::{mpsc, oneshot};
339    use tokio_util::codec::FramedRead;
340
341    struct WriterHarness {
342        server: tokio::net::TcpStream,
343        framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
344        bytes_tx: mpsc::Sender<TwoPartMessage>,
345        bytes_rx: mpsc::Receiver<TwoPartMessage>,
346        alive_tx: oneshot::Sender<()>,
347        alive_rx: oneshot::Receiver<()>,
348        controller: Arc<Controller>,
349    }
350
351    /// Creates a reusable writer harness with paired TCP streams and test channels.
352    async fn writer_harness() -> WriterHarness {
353        let (client, server) = create_tcp_pair().await;
354        let (_, write_half) = tokio::io::split(client);
355        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
356
357        let (bytes_tx, bytes_rx) = mpsc::channel(64);
358        let (alive_tx, alive_rx) = oneshot::channel::<()>();
359        let controller = Arc::new(Controller::default());
360
361        WriterHarness {
362            server,
363            framed_writer,
364            bytes_tx,
365            bytes_rx,
366            alive_tx,
367            alive_rx,
368            controller,
369        }
370    }
371
372    async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
373        reader
374            .next()
375            .await
376            .expect("expected message")
377            .expect("failed to decode message")
378    }
379
380    fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
381        let (header, data) = msg.optional_parts();
382        assert!(header.is_none(), "data-only message should not have header");
383        assert_eq!(
384            data.expect("data payload missing").as_ref(),
385            expected,
386            "data payload should match"
387        );
388    }
389
390    fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
391        let (header, data) = msg.optional_parts();
392        assert!(data.is_none(), "header-only message should not carry data");
393        assert_eq!(
394            header.expect("header missing").as_ref(),
395            expected,
396            "header payload should match"
397        );
398    }
399
400    fn assert_header_and_data_message(
401        msg: TwoPartMessage,
402        expected_header: &[u8],
403        expected_data: &[u8],
404    ) {
405        let (header, data) = msg.optional_parts();
406        assert_eq!(
407            header.expect("header missing").as_ref(),
408            expected_header,
409            "header payload should match"
410        );
411        assert_eq!(
412            data.expect("data missing").as_ref(),
413            expected_data,
414            "data payload should match"
415        );
416    }
417
418    fn assert_sentinel_message(msg: TwoPartMessage) {
419        let (header, data) = msg.optional_parts();
420        assert!(data.is_none(), "sentinel should not include a data section");
421        let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
422        assert_eq!(
423            header.expect("sentinel header missing").as_ref(),
424            expected_sentinel.as_slice(),
425            "sentinel header should match serialized ControlMessage::Sentinel"
426        );
427    }
428
429    /// Test that handle_writer forwards messages from the channel to the framed writer
430    #[tokio::test]
431    async fn test_handle_writer_forwards_messages() {
432        let WriterHarness {
433            server,
434            framed_writer,
435            bytes_tx,
436            bytes_rx,
437            alive_rx,
438            controller,
439            ..
440        } = writer_harness().await;
441
442        // Send test messages
443        let test_msg = TwoPartMessage::from_data(Bytes::from("test data"));
444        bytes_tx.send(test_msg).await.unwrap();
445
446        // Close the sender to trigger normal termination
447        drop(bytes_tx);
448
449        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
450
451        assert!(result.is_ok());
452
453        // Decode from server side to verify data and sentinel were sent
454        let mut reader = FramedRead::new(server, TwoPartCodec::default());
455
456        let msg = recv_msg(&mut reader).await;
457        assert_data_only_message(msg, b"test data");
458
459        let sentinel = recv_msg(&mut reader).await;
460        assert_sentinel_message(sentinel);
461    }
462
463    /// Test that handle_writer sends sentinel on normal channel closure
464    #[tokio::test]
465    async fn test_handle_writer_sends_sentinel_on_normal_closure() {
466        let WriterHarness {
467            mut server,
468            framed_writer,
469            bytes_tx,
470            bytes_rx,
471            alive_rx,
472            controller,
473            ..
474        } = writer_harness().await;
475
476        // Close the sender immediately to trigger normal termination
477        drop(bytes_tx);
478
479        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
480
481        assert!(result.is_ok());
482
483        // Read from server side to verify sentinel was sent
484        let mut buffer = vec![0u8; 1024];
485        let n = server.read(&mut buffer).await.unwrap();
486
487        // Buffer should contain the sentinel message
488        assert!(n > 0, "Expected sentinel to be written to the TCP stream");
489
490        // Verify it contains the sentinel message by checking for the JSON
491        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
492        assert!(
493            buffer[..n]
494                .windows(sentinel_json.len())
495                .any(|w| w == sentinel_json.as_slice()),
496            "Buffer should contain sentinel message. Buffer: {:?}",
497            String::from_utf8_lossy(&buffer[..n])
498        );
499    }
500
501    /// Test that handle_writer does NOT send sentinel when context is killed
502    #[tokio::test]
503    async fn test_handle_writer_no_sentinel_on_context_killed() {
504        let WriterHarness {
505            mut server,
506            framed_writer,
507            bytes_rx,
508            alive_rx,
509            controller,
510            ..
511        } = writer_harness().await;
512
513        // Kill the context
514        controller.kill();
515
516        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
517
518        assert!(result.is_ok());
519
520        // Drop the writer to close the connection, then try to read. Otherwise,
521        // the test will hang on `server.read()`
522        drop(result);
523
524        // Read from server side - should get no sentinel
525        let mut buffer = vec![0u8; 1024];
526        let n = server.read(&mut buffer).await.unwrap();
527
528        // Buffer should be empty (no sentinel sent)
529        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
530        assert!(
531            n == 0
532                || !buffer[..n]
533                    .windows(sentinel_json.len())
534                    .any(|w| w == sentinel_json.as_slice()),
535            "Buffer should NOT contain sentinel message when context is killed"
536        );
537    }
538
539    /// Test that handle_writer does NOT send sentinel when context is stopped
540    #[tokio::test]
541    async fn test_handle_writer_no_sentinel_on_context_stopped() {
542        let WriterHarness {
543            mut server,
544            framed_writer,
545            bytes_rx,
546            alive_rx,
547            controller,
548            ..
549        } = writer_harness().await;
550
551        // Stop the context
552        controller.stop();
553
554        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
555
556        assert!(result.is_ok());
557
558        // Drop the writer to close the connection, then try to read. Otherwise,
559        // the test will hang on `server.read()`
560        drop(result);
561
562        // Read from server side - should get no sentinel
563        let mut buffer = vec![0u8; 1024];
564        let n = server.read(&mut buffer).await.unwrap();
565
566        // Buffer should be empty (no sentinel sent)
567        let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
568        assert!(
569            n == 0
570                || !buffer[..n]
571                    .windows(sentinel_json.len())
572                    .any(|w| w == sentinel_json.as_slice()),
573            "Buffer should NOT contain sentinel message when context is stopped"
574        );
575    }
576
577    /// Test that handle_writer handles multiple messages correctly
578    #[tokio::test]
579    async fn test_handle_writer_multiple_messages() {
580        let WriterHarness {
581            server,
582            framed_writer,
583            bytes_tx,
584            bytes_rx,
585            alive_rx,
586            controller,
587            ..
588        } = writer_harness().await;
589
590        // Send multiple messages
591        for i in 0..5 {
592            let test_msg = TwoPartMessage::from_data(Bytes::from(format!("message {}", i)));
593            bytes_tx.send(test_msg).await.unwrap();
594        }
595
596        // Close the sender to trigger normal termination
597        drop(bytes_tx);
598
599        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
600
601        assert!(result.is_ok());
602
603        // Decode from server side to verify all messages plus sentinel
604        let mut reader = FramedRead::new(server, TwoPartCodec::default());
605        for i in 0..5 {
606            let msg = recv_msg(&mut reader).await;
607            assert_data_only_message(msg, format!("message {}", i).as_bytes());
608        }
609
610        let sentinel = recv_msg(&mut reader).await;
611        assert_sentinel_message(sentinel);
612    }
613
614    /// Test that alive_rx is dropped after handle_writer completes
615    #[tokio::test]
616    async fn test_handle_writer_drops_alive_rx() {
617        let WriterHarness {
618            framed_writer,
619            bytes_tx,
620            bytes_rx,
621            alive_tx,
622            alive_rx,
623            controller,
624            ..
625        } = writer_harness().await;
626
627        // Close the sender to trigger normal termination
628        drop(bytes_tx);
629
630        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
631
632        assert!(result.is_ok());
633
634        // alive_tx should now be closed because alive_rx was dropped
635        assert!(alive_tx.is_closed());
636    }
637
638    /// Test handle_writer with header-only messages (control messages)
639    #[tokio::test]
640    async fn test_handle_writer_header_only_messages() {
641        let WriterHarness {
642            server,
643            framed_writer,
644            bytes_tx,
645            bytes_rx,
646            alive_rx,
647            controller,
648            ..
649        } = writer_harness().await;
650
651        // Send a header-only message
652        let header_msg = TwoPartMessage::from_header(Bytes::from("header content"));
653        bytes_tx.send(header_msg).await.unwrap();
654
655        // Close the sender
656        drop(bytes_tx);
657
658        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
659
660        assert!(result.is_ok());
661
662        let mut reader = FramedRead::new(server, TwoPartCodec::default());
663
664        let header_msg = recv_msg(&mut reader).await;
665        assert_header_only_message(header_msg, b"header content");
666
667        let sentinel = recv_msg(&mut reader).await;
668        assert_sentinel_message(sentinel);
669    }
670
671    /// Test handle_writer with mixed header and data messages
672    #[tokio::test]
673    async fn test_handle_writer_mixed_messages() {
674        let WriterHarness {
675            server,
676            framed_writer,
677            bytes_tx,
678            bytes_rx,
679            alive_rx,
680            controller,
681            ..
682        } = writer_harness().await;
683
684        // Send mixed messages
685        bytes_tx
686            .send(TwoPartMessage::from_header(Bytes::from("header1")))
687            .await
688            .unwrap();
689        bytes_tx
690            .send(TwoPartMessage::from_data(Bytes::from("data1")))
691            .await
692            .unwrap();
693        bytes_tx
694            .send(TwoPartMessage::from_parts(
695                Bytes::from("header2"),
696                Bytes::from("data2"),
697            ))
698            .await
699            .unwrap();
700
701        // Close the sender
702        drop(bytes_tx);
703
704        let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
705
706        assert!(result.is_ok());
707
708        let mut reader = FramedRead::new(server, TwoPartCodec::default());
709
710        let first = recv_msg(&mut reader).await;
711        assert_header_only_message(first, b"header1");
712
713        let second = recv_msg(&mut reader).await;
714        assert_data_only_message(second, b"data1");
715
716        let third = recv_msg(&mut reader).await;
717        assert_header_and_data_message(third, b"header2", b"data2");
718
719        let sentinel = recv_msg(&mut reader).await;
720        assert_sentinel_message(sentinel);
721    }
722
723    // ==================== handle_reader tests ====================
724
725    struct ReaderHarness {
726        framed_server: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
727        framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
728        alive_tx: oneshot::Sender<()>,
729        alive_rx: oneshot::Receiver<()>,
730        controller: Arc<Controller>,
731    }
732
733    /// Creates a reusable reader harness with paired TCP streams and test channels.
734    async fn reader_harness() -> ReaderHarness {
735        let (client, server) = create_tcp_pair().await;
736        let (read_half, _write_half) = tokio::io::split(client);
737        let (_server_read, server_write) = tokio::io::split(server);
738
739        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
740        let framed_server = FramedWrite::new(server_write, TwoPartCodec::default());
741        let (alive_tx, alive_rx) = oneshot::channel::<()>();
742        let controller = Arc::new(Controller::default());
743
744        ReaderHarness {
745            framed_server,
746            framed_reader,
747            alive_tx,
748            alive_rx,
749            controller,
750        }
751    }
752
753    fn control_message(msg: &ControlMessage) -> TwoPartMessage {
754        let msg_bytes = serde_json::to_vec(msg).unwrap();
755        TwoPartMessage::from_header(Bytes::from(msg_bytes))
756    }
757
758    /// Test that handle_reader handles Stop control message by calling context.stop()
759    #[tokio::test]
760    async fn test_handle_reader_stop_control_message() {
761        let ReaderHarness {
762            mut framed_server,
763            framed_reader,
764            alive_tx,
765            alive_rx: _alive_rx,
766            controller,
767        } = reader_harness().await;
768
769        // Spawn the reader task
770        let controller_clone = controller.clone();
771        let reader_handle =
772            tokio::spawn(
773                async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
774            );
775
776        // Send Stop control message from server
777        framed_server
778            .send(control_message(&ControlMessage::Stop))
779            .await
780            .unwrap();
781
782        // Close the framed server to signal EOF to the client
783        framed_server.close().await.unwrap();
784
785        // Wait for reader to finish
786        let _ = reader_handle.await.unwrap();
787
788        // Verify that stop was called on the controller
789        assert!(
790            controller.is_stopped(),
791            "Controller should be stopped after receiving Stop message"
792        );
793    }
794
795    /// Test that handle_reader handles Kill control message by calling context.kill()
796    #[tokio::test]
797    async fn test_handle_reader_kill_control_message() {
798        let ReaderHarness {
799            mut framed_server,
800            framed_reader,
801            alive_tx,
802            alive_rx: _alive_rx,
803            controller,
804        } = reader_harness().await;
805
806        // Spawn the reader task
807        let controller_clone = controller.clone();
808        let reader_handle =
809            tokio::spawn(
810                async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
811            );
812
813        // Send Kill control message from server
814        framed_server
815            .send(control_message(&ControlMessage::Kill))
816            .await
817            .unwrap();
818
819        // Close the framed server to signal EOF to the client
820        framed_server.close().await.unwrap();
821
822        // Wait for reader to finish
823        let _ = reader_handle.await.unwrap();
824
825        // Verify that kill was called on the controller
826        assert!(
827            controller.is_killed(),
828            "Controller should be killed after receiving Kill message"
829        );
830    }
831
832    /// Test that handle_reader exits when alive channel is closed
833    #[tokio::test]
834    async fn test_handle_reader_exits_on_alive_channel_closed() {
835        let ReaderHarness {
836            framed_reader,
837            alive_tx,
838            alive_rx,
839            controller,
840            ..
841        } = reader_harness().await;
842
843        // Spawn the reader task
844        let reader_handle =
845            tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await });
846
847        // Drop the alive_rx to close the channel (simulating writer finishing)
848        drop(alive_rx);
849
850        // Reader should exit due to alive channel closure
851        let result = reader_handle.await;
852
853        assert!(
854            result.is_ok(),
855            "handle_reader should exit when alive channel is closed"
856        );
857    }
858
859    /// Test that handle_reader exits when TCP stream is closed
860    #[tokio::test]
861    async fn test_handle_reader_exits_on_stream_closed() {
862        let ReaderHarness {
863            mut framed_server,
864            framed_reader,
865            alive_tx,
866            alive_rx: _alive_rx,
867            controller,
868        } = reader_harness().await;
869
870        // Spawn the reader task
871        let reader_handle =
872            tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await });
873
874        // Close the framed server to signal EOF to the client
875        framed_server.close().await.unwrap();
876
877        // Reader should exit due to stream closure
878        let result = tokio::time::timeout(std::time::Duration::from_secs(1), reader_handle).await;
879
880        assert!(
881            result.is_ok(),
882            "handle_reader should exit when stream is closed"
883        );
884    }
885
886    /// Test that handle_reader handles multiple control messages in sequence
887    #[tokio::test]
888    async fn test_handle_reader_multiple_control_messages() {
889        let ReaderHarness {
890            mut framed_server,
891            framed_reader,
892            alive_tx,
893            alive_rx: _alive_rx,
894            controller,
895        } = reader_harness().await;
896
897        // Spawn the reader task
898        let controller_clone = controller.clone();
899        let reader_handle =
900            tokio::spawn(
901                async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
902            );
903
904        // Send multiple Stop messages (first one will stop, subsequent ones are no-ops)
905        framed_server
906            .send(control_message(&ControlMessage::Stop))
907            .await
908            .unwrap();
909        framed_server
910            .send(control_message(&ControlMessage::Stop))
911            .await
912            .unwrap();
913
914        // Close the framed server to signal EOF to the client
915        framed_server.close().await.unwrap();
916
917        // Wait for reader to finish
918        let _ = reader_handle.await.unwrap();
919
920        // Verify that stop was called
921        assert!(
922            controller.is_stopped(),
923            "Controller should be stopped after receiving Stop messages"
924        );
925    }
926
927    /// Test handle_reader with Stop followed by Kill
928    #[tokio::test]
929    async fn test_handle_reader_stop_then_kill() {
930        let ReaderHarness {
931            mut framed_server,
932            framed_reader,
933            alive_tx,
934            alive_rx: _alive_rx,
935            controller,
936        } = reader_harness().await;
937
938        // Spawn the reader task
939        let controller_clone = controller.clone();
940        let reader_handle =
941            tokio::spawn(
942                async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
943            );
944
945        // Send Stop first, then Kill
946        framed_server
947            .send(control_message(&ControlMessage::Stop))
948            .await
949            .unwrap();
950        framed_server
951            .send(control_message(&ControlMessage::Kill))
952            .await
953            .unwrap();
954
955        // Close the framed server to signal EOF to the client
956        framed_server.close().await.unwrap();
957
958        // Wait for reader to finish
959        let _ = reader_handle.await.unwrap();
960
961        // Verify that kill was called (which sets killed state)
962        assert!(
963            controller.is_killed(),
964            "Controller should be killed after receiving Kill message"
965        );
966    }
967}