Skip to main content

liminal_sdk/remote/tcp/
push_client.rs

1//! Client-side background reader for server-initiated pushes.
2//!
3//! Every other SDK transport call is request/response: the client writes a frame
4//! and reads exactly one reply to its own request ([`Connection::round_trip`]). A
5//! server PUSH inverts that — the server writes a [`Frame::Push`] on the client's
6//! existing connection at a time of the server's choosing, with no outstanding
7//! client request to read it. [`PushClient`] is the piece that consumes those
8//! inbound frames: it owns a connection whose socket is drained by a dedicated
9//! background reader thread, surfaces each pushed frame on a channel, and lets the
10//! caller send back a correlated [`Frame::PushReply`] on the same socket.
11//!
12//! # Read/write split
13//!
14//! A push connection is read concurrently (the background thread blocks on the
15//! socket) and written concurrently (the caller replies). `TcpStream` is cloned so
16//! the reader thread owns one handle and the writer holds the other behind a
17//! `Mutex`; the two handles share the same underlying socket, so a reply written
18//! by the caller travels the connection the server is pushing on. This keeps the
19//! request/reply [`Connection`] (which couples a single read to a single write)
20//! completely untouched — the push path is additive, not a rewrite.
21
22use alloc::format;
23use alloc::string::ToString;
24use alloc::sync::Arc;
25use alloc::vec;
26use alloc::vec::Vec;
27use core::time::Duration;
28
29use std::io::{Read, Write};
30use std::net::TcpStream;
31use std::sync::Mutex;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
34use std::thread::JoinHandle;
35
36use liminal::protocol::{
37    Frame, ProtocolError, ProtocolVersion, WorkerRegisterOutcome, WorkerRegistration, decode,
38    encode, encoded_len,
39};
40
41use crate::SdkError;
42
/// Minimum protocol version this client advertises in its `Connect` frame.
const CLIENT_MIN_VERSION: ProtocolVersion = ProtocolVersion::new(1, 0);
/// Maximum protocol version this client advertises in its `Connect` frame.
const CLIENT_MAX_VERSION: ProtocolVersion = ProtocolVersion::new(1, 0);
/// Write timeout installed on the socket; bounds a single socket write.
const WRITE_TIMEOUT: Duration = Duration::from_secs(5);
/// Read timeout installed on the socket. It is the poll cadence the reader
/// thread uses so it can observe the stop flag promptly between reads while
/// still blocking efficiently on the socket the rest of the time.
const READER_POLL_TIMEOUT: Duration = Duration::from_millis(100);
/// Size of the stack buffer used for each socket read when draining the
/// socket into the frame buffer.
const READ_CHUNK_BYTES: usize = 4096;
/// Upper bound on buffered-but-undecoded bytes, guarding against runaway
/// buffering. It is checked before each socket read.
const MAX_FRAME_BYTES: usize = 16 * 1024 * 1024;
/// Application stream id used for the client's push reply frames.
const APPLICATION_STREAM_ID: u32 = 1;
59
/// One server-initiated push as surfaced to the caller.
///
/// Pairs the payload with the correlation id that must be echoed on the reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PushedFrame {
    /// Server-assigned id that ties a reply back to this push.
    correlation_id: u64,
    /// Payload bytes exactly as pushed; this type does not interpret them.
    payload: Vec<u8>,
}

impl PushedFrame {
    /// Returns the correlation id the reply must carry so the server can match
    /// it to this push.
    #[must_use]
    pub const fn correlation_id(&self) -> u64 {
        self.correlation_id
    }

    /// Borrows the pushed payload bytes.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Takes ownership of the payload bytes, discarding the rest of the frame.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let Self { payload, .. } = self;
        payload
    }
}
88
/// A connected client that consumes server pushes and sends correlated replies.
///
/// Construct with [`PushClient::connect`] (or
/// [`PushClient::connect_with_registration`] for a worker); the background
/// reader starts immediately and runs until the client is dropped, the
/// connection closes, or a fatal read/decode error occurs. Pull pushed frames
/// with [`PushClient::recv_timeout`] and answer them with [`PushClient::reply`].
#[derive(Debug)]
pub struct PushClient {
    /// Write half of the shared socket, guarded so the caller's reply does not
    /// interleave bytes with any other writer.
    writer: Arc<Mutex<TcpStream>>,
    /// Receiving end of the channel the background reader sends each pushed
    /// frame on.
    inbound: Receiver<PushedFrame>,
    /// Stop flag shared with the reader thread; set to `true` on drop.
    stop: Arc<AtomicBool>,
    /// Background reader handle. It is taken and joined on drop, which is why
    /// it is held in an `Option`.
    reader: Option<JoinHandle<()>>,
}
106
107impl PushClient {
108    /// Connects to `address`, performs the protocol handshake, and starts the
109    /// background reader that drains inbound server pushes.
110    ///
111    /// # Errors
112    ///
113    /// Returns [`SdkError::Connection`] when the TCP connection or socket
114    /// configuration fails, and [`SdkError::Protocol`] when the handshake is
115    /// rejected or the socket cannot be cloned for the reader thread.
116    pub fn connect(address: &str) -> Result<Self, SdkError> {
117        let mut stream = connect_socket(address)?;
118        handshake(&mut stream)?;
119        Self::start_reader(stream)
120    }
121
122    /// Connects, performs the handshake, then synchronously registers this client
123    /// as a worker before starting the background reader.
124    ///
125    /// This mirrors the synchronous `Connect`/`ConnectAck` pattern: the
126    /// `WorkerRegister` frame is written and its [`Frame::WorkerRegisterAck`] read
127    /// on the calling thread, BEFORE the Push-only background reader is spawned, so
128    /// the ack is never swallowed by the reader. A connect-variant (rather than a
129    /// `register()` method on a connected client) is the cleanest fit: `connect`
130    /// spawns the reader as its last step, so registration must be threaded into
131    /// the connect sequence to land before that spawn; a post-connect method would
132    /// race the already-running reader for the ack frame.
133    ///
134    /// # Errors
135    ///
136    /// Returns [`SdkError::Connection`] when the TCP connection or socket
137    /// configuration fails, and [`SdkError::Protocol`] when the handshake is
138    /// rejected, the server rejects the registration (the rejection reason is
139    /// carried in the error), or the socket cannot be cloned for the reader thread.
140    pub fn connect_with_registration(
141        address: &str,
142        registration: WorkerRegistration,
143    ) -> Result<Self, SdkError> {
144        let mut stream = connect_socket(address)?;
145        handshake(&mut stream)?;
146        register(&mut stream, registration)?;
147        Self::start_reader(stream)
148    }
149
150    /// Spawns the Push-only background reader over a handshaken (and, for a worker,
151    /// already-registered) stream and returns the running client.
152    fn start_reader(stream: TcpStream) -> Result<Self, SdkError> {
153        // Clone the socket so the reader thread owns one handle and the writer
154        // holds the other; both refer to the same underlying connection.
155        let read_stream = stream.try_clone().map_err(|source| SdkError::Protocol {
156            description: format!("failed to clone push socket for reader thread: {source}"),
157        })?;
158
159        let stop = Arc::new(AtomicBool::new(false));
160        let (sender, inbound) = channel();
161        let reader_stop = Arc::clone(&stop);
162        let reader = std::thread::Builder::new()
163            .name("liminal-push-reader".to_string())
164            .spawn(move || run_reader(read_stream, &sender, &reader_stop))
165            .map_err(|source| SdkError::Protocol {
166                description: format!("failed to start push reader thread: {source}"),
167            })?;
168
169        Ok(Self {
170            writer: Arc::new(Mutex::new(stream)),
171            inbound,
172            stop,
173            reader: Some(reader),
174        })
175    }
176
177    /// Blocks up to `timeout` for the next pushed frame from the server.
178    ///
179    /// # Errors
180    ///
181    /// Returns [`SdkError::Connection`] when no push arrives within `timeout` or
182    /// the background reader has stopped (e.g. the server closed the connection).
183    pub fn recv_timeout(&self, timeout: Duration) -> Result<PushedFrame, SdkError> {
184        self.inbound.recv_timeout(timeout).map_err(|error| {
185            let detail = match error {
186                RecvTimeoutError::Timeout => "no server push arrived within the timeout",
187                RecvTimeoutError::Disconnected => {
188                    "the push reader stopped before a server push arrived"
189                }
190            };
191            SdkError::Connection {
192                description: format!("push receive failed: {detail}"),
193            }
194        })
195    }
196
197    /// Sends a correlated reply to a pushed frame, echoing its correlation id so
198    /// the server matches the reply back to the originating push.
199    ///
200    /// # Errors
201    ///
202    /// Returns [`SdkError::Protocol`] when the reply frame cannot be encoded and
203    /// [`SdkError::Connection`] when it cannot be written to the socket or the
204    /// writer lock is poisoned.
205    pub fn reply(&self, correlation_id: u64, payload: Vec<u8>) -> Result<(), SdkError> {
206        let frame = Frame::new_push_reply(APPLICATION_STREAM_ID, correlation_id, payload)
207            .map_err(|error| protocol_error(&error))?;
208        let mut writer = self.writer.lock().map_err(|error| SdkError::Connection {
209            description: format!("push writer lock poisoned: {error}"),
210        })?;
211        write_frame(&mut writer, &frame)
212    }
213}
214
215impl Drop for PushClient {
216    fn drop(&mut self) {
217        self.stop.store(true, Ordering::SeqCst);
218        if let Some(reader) = self.reader.take() {
219            // The reader wakes within READER_POLL_TIMEOUT to observe the stop flag,
220            // so this join does not hang on a quiet connection.
221            reader.join().ok();
222        }
223    }
224}
225
226/// Opens and configures the push-client socket (Nagle off, bounded read/write
227/// timeouts) before any framing.
228fn connect_socket(address: &str) -> Result<TcpStream, SdkError> {
229    let stream = TcpStream::connect(address).map_err(|source| SdkError::Connection {
230        description: format!("failed to connect push client to {address}: {source}"),
231    })?;
232    stream
233        .set_nodelay(true)
234        .map_err(|source| SdkError::Connection {
235            description: format!("failed to disable Nagle for {address}: {source}"),
236        })?;
237    // A bounded read timeout lets the reader thread wake to check the stop flag
238    // even when the server is silent; without it the thread would block forever
239    // on a quiet connection and never observe drop.
240    stream
241        .set_read_timeout(Some(READER_POLL_TIMEOUT))
242        .map_err(|source| SdkError::Connection {
243            description: format!("failed to set push read timeout for {address}: {source}"),
244        })?;
245    stream
246        .set_write_timeout(Some(WRITE_TIMEOUT))
247        .map_err(|source| SdkError::Connection {
248            description: format!("failed to set push write timeout for {address}: {source}"),
249        })?;
250    Ok(stream)
251}
252
253/// Drives the synchronous worker-registration round trip
254/// (`WorkerRegister` -> `WorkerRegisterAck`) on a handshaken socket, before the
255/// background reader is spawned.
256///
257/// A `Rejected` ack maps to a typed [`SdkError::Protocol`] carrying the server's
258/// reason; any non-ack reply is a protocol error.
259fn register(stream: &mut TcpStream, registration: WorkerRegistration) -> Result<(), SdkError> {
260    let frame = Frame::WorkerRegister {
261        flags: 0,
262        registration,
263    };
264    write_frame(stream, &frame)?;
265    let mut buffer = Vec::new();
266    match read_one_frame(stream, &mut buffer)? {
267        Frame::WorkerRegisterAck {
268            outcome: WorkerRegisterOutcome::Accepted,
269            ..
270        } => Ok(()),
271        Frame::WorkerRegisterAck {
272            outcome: WorkerRegisterOutcome::Rejected { reason },
273            ..
274        } => Err(SdkError::Protocol {
275            description: format!("server rejected worker registration: {reason}"),
276        }),
277        other => Err(SdkError::Protocol {
278            description: format!(
279                "expected WorkerRegisterAck during registration, received {:?}",
280                other.frame_type()
281            ),
282        }),
283    }
284}
285
286/// Drives the client handshake (`Connect` -> `ConnectAck`) on a fresh socket.
287fn handshake(stream: &mut TcpStream) -> Result<(), SdkError> {
288    let connect = Frame::Connect {
289        flags: 0,
290        min_version: CLIENT_MIN_VERSION,
291        max_version: CLIENT_MAX_VERSION,
292        auth_token: Vec::new(),
293    };
294    write_frame(stream, &connect)?;
295    let mut buffer = Vec::new();
296    match read_one_frame(stream, &mut buffer)? {
297        Frame::ConnectAck { .. } => Ok(()),
298        Frame::ConnectError {
299            reason_code,
300            message,
301            ..
302        } => Err(SdkError::Connection {
303            description: format!(
304                "server rejected push connection (reason {reason_code}): {}",
305                message.unwrap_or_else(|| "no detail".to_string())
306            ),
307        }),
308        other => Err(SdkError::Protocol {
309            description: format!(
310                "expected ConnectAck during push handshake, received {:?}",
311                other.frame_type()
312            ),
313        }),
314    }
315}
316
317/// Background loop: drains the socket, surfacing each `Push` frame on `sender`.
318///
319/// Returns (ending the thread) when the stop flag is set, the connection closes,
320/// or a fatal decode/IO error occurs. A read timeout is non-fatal: it just lets
321/// the loop re-check the stop flag.
322fn run_reader(mut stream: TcpStream, sender: &Sender<PushedFrame>, stop: &AtomicBool) {
323    let mut buffer = Vec::new();
324    while !stop.load(Ordering::SeqCst) {
325        match next_frame(&mut stream, &mut buffer) {
326            Ok(Some(Frame::Push {
327                correlation_id,
328                payload,
329                ..
330            })) => {
331                if sender
332                    .send(PushedFrame {
333                        correlation_id,
334                        payload,
335                    })
336                    .is_err()
337                {
338                    // The receiver was dropped; nothing will consume further
339                    // pushes, so stop reading.
340                    return;
341                }
342            }
343            // `Some(_)`: any non-Push frame on a push connection is unexpected for
344            // this spike — ignore it rather than tearing the reader down so a stray
345            // frame cannot silently drop subsequent pushes. `None`: a read timeout
346            // with no complete frame. Both just loop to re-check the stop flag.
347            Ok(Some(_) | None) => {}
348            // Connection closed or a fatal read/decode error: end the thread. The
349            // dropped `sender` surfaces as a `Disconnected` on the receiver side.
350            Err(_) => return,
351        }
352    }
353}
354
355/// Reads until one complete frame decodes, treating a read timeout as
356/// `Ok(None)` so the caller can re-check the stop flag without ending the loop.
357fn next_frame(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<Option<Frame>, SdkError> {
358    loop {
359        match decode(buffer) {
360            Ok((frame, consumed)) => {
361                buffer.drain(..consumed);
362                return Ok(Some(frame));
363            }
364            Err(
365                ProtocolError::IncompleteHeader { .. } | ProtocolError::TruncatedPayload { .. },
366            ) => match fill_buffer(stream, buffer)? {
367                FillOutcome::Read => {}
368                FillOutcome::TimedOut => return Ok(None),
369            },
370            Err(error) => return Err(protocol_error(&error)),
371        }
372    }
373}
374
375/// Reads one complete frame, blocking (no timeout tolerance) — used for the
376/// synchronous handshake and worker-registration replies, before the background
377/// reader starts.
378fn read_one_frame(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<Frame, SdkError> {
379    loop {
380        match decode(buffer) {
381            Ok((frame, consumed)) => {
382                buffer.drain(..consumed);
383                return Ok(frame);
384            }
385            Err(
386                ProtocolError::IncompleteHeader { .. } | ProtocolError::TruncatedPayload { .. },
387            ) => match fill_buffer(stream, buffer)? {
388                FillOutcome::Read => {}
389                FillOutcome::TimedOut => {
390                    return Err(SdkError::Connection {
391                        description: "push connection timed out waiting for a control-frame reply"
392                            .to_string(),
393                    });
394                }
395            },
396            Err(error) => return Err(protocol_error(&error)),
397        }
398    }
399}
400
401/// Appends one socket read into `buffer`, mapping a read timeout to a non-fatal
402/// [`FillOutcome::TimedOut`] so the reader can poll the stop flag.
403fn fill_buffer(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<FillOutcome, SdkError> {
404    if buffer.len() > MAX_FRAME_BYTES {
405        return Err(SdkError::Protocol {
406            description: format!(
407                "push frame exceeded {MAX_FRAME_BYTES} bytes without a complete frame"
408            ),
409        });
410    }
411    let mut chunk = [0_u8; READ_CHUNK_BYTES];
412    match stream.read(&mut chunk) {
413        Ok(0) => Err(SdkError::Connection {
414            description: "server closed the push connection".to_string(),
415        }),
416        Ok(read) => {
417            let Some(received) = chunk.get(..read) else {
418                return Err(SdkError::Protocol {
419                    description: "push socket read reported more bytes than the buffer holds"
420                        .to_string(),
421                });
422            };
423            buffer.extend_from_slice(received);
424            Ok(FillOutcome::Read)
425        }
426        Err(error)
427            if matches!(
428                error.kind(),
429                std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
430            ) =>
431        {
432            Ok(FillOutcome::TimedOut)
433        }
434        Err(error) => Err(SdkError::Connection {
435            description: format!("failed to read from push connection: {error}"),
436        }),
437    }
438}
439
/// Outcome of one non-fatal socket read attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FillOutcome {
    /// At least one byte was read and appended to the buffer.
    Read,
    /// The socket's read timeout elapsed before any data arrived.
    TimedOut,
}
446
447/// Encodes and writes one frame to the socket, flushing it.
448fn write_frame(stream: &mut TcpStream, frame: &Frame) -> Result<(), SdkError> {
449    let len = encoded_len(frame).map_err(|error| protocol_error(&error))?;
450    let mut bytes = vec![0_u8; len];
451    let written = encode(frame, &mut bytes).map_err(|error| protocol_error(&error))?;
452    let encoded = bytes.get(..written).ok_or_else(|| SdkError::Protocol {
453        description: "push wire encoder reported an invalid byte count".to_string(),
454    })?;
455    stream
456        .write_all(encoded)
457        .map_err(|source| SdkError::Connection {
458            description: format!("failed to write push frame: {source}"),
459        })?;
460    stream.flush().map_err(|source| SdkError::Connection {
461        description: format!("failed to flush push frame: {source}"),
462    })
463}
464
465/// Maps a wire codec error into the SDK error taxonomy.
466fn protocol_error(error: &ProtocolError) -> SdkError {
467    SdkError::Protocol {
468        description: format!("push wire codec error: {error}"),
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use liminal::protocol::FrameType;

    /// The accessors hand back exactly what the frame was built with.
    #[test]
    fn pushed_frame_exposes_correlation_and_payload() {
        let pushed = PushedFrame {
            correlation_id: 7,
            payload: vec![1, 2, 3],
        };
        assert_eq!(pushed.correlation_id(), 7);
        assert_eq!(pushed.payload(), &[1, 2, 3]);
        assert_eq!(pushed.into_payload(), vec![1, 2, 3]);
    }

    /// A reply frame survives encode -> decode, consuming every encoded byte.
    #[test]
    fn reply_frame_round_trips_through_codec() -> Result<(), SdkError> {
        let to_sdk = |error| protocol_error(&error);

        let reply = Frame::new_push_reply(APPLICATION_STREAM_ID, 9, vec![4, 5]).map_err(to_sdk)?;
        let capacity = encoded_len(&reply).map_err(to_sdk)?;
        let mut wire = vec![0_u8; capacity];
        let written = encode(&reply, &mut wire).map_err(to_sdk)?;

        let (decoded, consumed) = decode(&wire[..written]).map_err(to_sdk)?;
        assert_eq!(consumed, written);
        assert_eq!(decoded.frame_type(), FrameType::PushReply);
        Ok(())
    }
}