1use alloc::format;
23use alloc::string::ToString;
24use alloc::sync::Arc;
25use alloc::vec;
26use alloc::vec::Vec;
27use core::time::Duration;
28
29use std::io::{Read, Write};
30use std::net::TcpStream;
31use std::sync::Mutex;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
34use std::thread::JoinHandle;
35
36use liminal::protocol::{
37 Frame, ProtocolError, ProtocolVersion, WorkerRegisterOutcome, WorkerRegistration, decode,
38 encode, encoded_len,
39};
40
41use crate::SdkError;
42
43const CLIENT_MIN_VERSION: ProtocolVersion = ProtocolVersion::new(1, 0);
45const CLIENT_MAX_VERSION: ProtocolVersion = ProtocolVersion::new(1, 0);
/// Longest a single socket write may block before the write fails.
const WRITE_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Socket read timeout; the reader thread re-checks its stop flag roughly this often.
const READER_POLL_TIMEOUT: Duration = Duration::from_millis(100);
/// Size of the stack buffer handed to each `read` call.
const READ_CHUNK_BYTES: usize = 4 * 1024;
/// Most inbound bytes buffered without completing a frame before giving up (16 MiB).
const MAX_FRAME_BYTES: usize = 16 << 20;
/// Stream identifier stamped on outgoing push replies.
const APPLICATION_STREAM_ID: u32 = 1;
59
/// A server-initiated push, as delivered by the push reader thread.
///
/// Answer it with `PushClient::reply`, passing back [`PushedFrame::correlation_id`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PushedFrame {
    /// Correlation identifier carried by the server's `Push` frame.
    correlation_id: u64,
    /// Application bytes carried by the push.
    payload: Vec<u8>,
}

impl PushedFrame {
    /// The correlation identifier to echo back when replying to this push.
    #[must_use]
    pub const fn correlation_id(&self) -> u64 {
        let Self { correlation_id, .. } = self;
        *correlation_id
    }

    /// Borrows the push payload.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Consumes the push and hands back ownership of its payload.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let Self { payload, .. } = self;
        payload
    }
}
88
/// Client connection that receives server pushes and sends replies to them.
///
/// A background reader thread decodes `Frame::Push` frames from the socket and
/// forwards them over a channel; replies are written from the caller's thread.
#[derive(Debug)]
pub struct PushClient {
    /// Socket handle used for writes; behind a mutex so `reply` can take `&self`.
    writer: Arc<Mutex<TcpStream>>,
    /// Receiving end of the channel the reader thread sends decoded pushes into.
    inbound: Receiver<PushedFrame>,
    /// Set on drop to ask the reader thread to exit at its next loop iteration.
    stop: Arc<AtomicBool>,
    /// Reader thread handle; taken and joined on drop.
    reader: Option<JoinHandle<()>>,
}
106
107impl PushClient {
108 pub fn connect(address: &str) -> Result<Self, SdkError> {
117 let mut stream = connect_socket(address)?;
118 handshake(&mut stream)?;
119 Self::start_reader(stream)
120 }
121
122 pub fn connect_with_registration(
141 address: &str,
142 registration: WorkerRegistration,
143 ) -> Result<Self, SdkError> {
144 let mut stream = connect_socket(address)?;
145 handshake(&mut stream)?;
146 register(&mut stream, registration)?;
147 Self::start_reader(stream)
148 }
149
150 fn start_reader(stream: TcpStream) -> Result<Self, SdkError> {
153 let read_stream = stream.try_clone().map_err(|source| SdkError::Protocol {
156 description: format!("failed to clone push socket for reader thread: {source}"),
157 })?;
158
159 let stop = Arc::new(AtomicBool::new(false));
160 let (sender, inbound) = channel();
161 let reader_stop = Arc::clone(&stop);
162 let reader = std::thread::Builder::new()
163 .name("liminal-push-reader".to_string())
164 .spawn(move || run_reader(read_stream, &sender, &reader_stop))
165 .map_err(|source| SdkError::Protocol {
166 description: format!("failed to start push reader thread: {source}"),
167 })?;
168
169 Ok(Self {
170 writer: Arc::new(Mutex::new(stream)),
171 inbound,
172 stop,
173 reader: Some(reader),
174 })
175 }
176
177 pub fn recv_timeout(&self, timeout: Duration) -> Result<PushedFrame, SdkError> {
184 self.inbound.recv_timeout(timeout).map_err(|error| {
185 let detail = match error {
186 RecvTimeoutError::Timeout => "no server push arrived within the timeout",
187 RecvTimeoutError::Disconnected => {
188 "the push reader stopped before a server push arrived"
189 }
190 };
191 SdkError::Connection {
192 description: format!("push receive failed: {detail}"),
193 }
194 })
195 }
196
197 pub fn reply(&self, correlation_id: u64, payload: Vec<u8>) -> Result<(), SdkError> {
206 let frame = Frame::new_push_reply(APPLICATION_STREAM_ID, correlation_id, payload)
207 .map_err(|error| protocol_error(&error))?;
208 let mut writer = self.writer.lock().map_err(|error| SdkError::Connection {
209 description: format!("push writer lock poisoned: {error}"),
210 })?;
211 write_frame(&mut writer, &frame)
212 }
213}
214
215impl Drop for PushClient {
216 fn drop(&mut self) {
217 self.stop.store(true, Ordering::SeqCst);
218 if let Some(reader) = self.reader.take() {
219 reader.join().ok();
222 }
223 }
224}
225
226fn connect_socket(address: &str) -> Result<TcpStream, SdkError> {
229 let stream = TcpStream::connect(address).map_err(|source| SdkError::Connection {
230 description: format!("failed to connect push client to {address}: {source}"),
231 })?;
232 stream
233 .set_nodelay(true)
234 .map_err(|source| SdkError::Connection {
235 description: format!("failed to disable Nagle for {address}: {source}"),
236 })?;
237 stream
241 .set_read_timeout(Some(READER_POLL_TIMEOUT))
242 .map_err(|source| SdkError::Connection {
243 description: format!("failed to set push read timeout for {address}: {source}"),
244 })?;
245 stream
246 .set_write_timeout(Some(WRITE_TIMEOUT))
247 .map_err(|source| SdkError::Connection {
248 description: format!("failed to set push write timeout for {address}: {source}"),
249 })?;
250 Ok(stream)
251}
252
253fn register(stream: &mut TcpStream, registration: WorkerRegistration) -> Result<(), SdkError> {
260 let frame = Frame::WorkerRegister {
261 flags: 0,
262 registration,
263 };
264 write_frame(stream, &frame)?;
265 let mut buffer = Vec::new();
266 match read_one_frame(stream, &mut buffer)? {
267 Frame::WorkerRegisterAck {
268 outcome: WorkerRegisterOutcome::Accepted,
269 ..
270 } => Ok(()),
271 Frame::WorkerRegisterAck {
272 outcome: WorkerRegisterOutcome::Rejected { reason },
273 ..
274 } => Err(SdkError::Protocol {
275 description: format!("server rejected worker registration: {reason}"),
276 }),
277 other => Err(SdkError::Protocol {
278 description: format!(
279 "expected WorkerRegisterAck during registration, received {:?}",
280 other.frame_type()
281 ),
282 }),
283 }
284}
285
286fn handshake(stream: &mut TcpStream) -> Result<(), SdkError> {
288 let connect = Frame::Connect {
289 flags: 0,
290 min_version: CLIENT_MIN_VERSION,
291 max_version: CLIENT_MAX_VERSION,
292 auth_token: Vec::new(),
293 };
294 write_frame(stream, &connect)?;
295 let mut buffer = Vec::new();
296 match read_one_frame(stream, &mut buffer)? {
297 Frame::ConnectAck { .. } => Ok(()),
298 Frame::ConnectError {
299 reason_code,
300 message,
301 ..
302 } => Err(SdkError::Connection {
303 description: format!(
304 "server rejected push connection (reason {reason_code}): {}",
305 message.unwrap_or_else(|| "no detail".to_string())
306 ),
307 }),
308 other => Err(SdkError::Protocol {
309 description: format!(
310 "expected ConnectAck during push handshake, received {:?}",
311 other.frame_type()
312 ),
313 }),
314 }
315}
316
317fn run_reader(mut stream: TcpStream, sender: &Sender<PushedFrame>, stop: &AtomicBool) {
323 let mut buffer = Vec::new();
324 while !stop.load(Ordering::SeqCst) {
325 match next_frame(&mut stream, &mut buffer) {
326 Ok(Some(Frame::Push {
327 correlation_id,
328 payload,
329 ..
330 })) => {
331 if sender
332 .send(PushedFrame {
333 correlation_id,
334 payload,
335 })
336 .is_err()
337 {
338 return;
341 }
342 }
343 Ok(Some(_) | None) => {}
348 Err(_) => return,
351 }
352 }
353}
354
355fn next_frame(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<Option<Frame>, SdkError> {
358 loop {
359 match decode(buffer) {
360 Ok((frame, consumed)) => {
361 buffer.drain(..consumed);
362 return Ok(Some(frame));
363 }
364 Err(
365 ProtocolError::IncompleteHeader { .. } | ProtocolError::TruncatedPayload { .. },
366 ) => match fill_buffer(stream, buffer)? {
367 FillOutcome::Read => {}
368 FillOutcome::TimedOut => return Ok(None),
369 },
370 Err(error) => return Err(protocol_error(&error)),
371 }
372 }
373}
374
375fn read_one_frame(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<Frame, SdkError> {
379 loop {
380 match decode(buffer) {
381 Ok((frame, consumed)) => {
382 buffer.drain(..consumed);
383 return Ok(frame);
384 }
385 Err(
386 ProtocolError::IncompleteHeader { .. } | ProtocolError::TruncatedPayload { .. },
387 ) => match fill_buffer(stream, buffer)? {
388 FillOutcome::Read => {}
389 FillOutcome::TimedOut => {
390 return Err(SdkError::Connection {
391 description: "push connection timed out waiting for a control-frame reply"
392 .to_string(),
393 });
394 }
395 },
396 Err(error) => return Err(protocol_error(&error)),
397 }
398 }
399}
400
401fn fill_buffer(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<FillOutcome, SdkError> {
404 if buffer.len() > MAX_FRAME_BYTES {
405 return Err(SdkError::Protocol {
406 description: format!(
407 "push frame exceeded {MAX_FRAME_BYTES} bytes without a complete frame"
408 ),
409 });
410 }
411 let mut chunk = [0_u8; READ_CHUNK_BYTES];
412 match stream.read(&mut chunk) {
413 Ok(0) => Err(SdkError::Connection {
414 description: "server closed the push connection".to_string(),
415 }),
416 Ok(read) => {
417 let Some(received) = chunk.get(..read) else {
418 return Err(SdkError::Protocol {
419 description: "push socket read reported more bytes than the buffer holds"
420 .to_string(),
421 });
422 };
423 buffer.extend_from_slice(received);
424 Ok(FillOutcome::Read)
425 }
426 Err(error)
427 if matches!(
428 error.kind(),
429 std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
430 ) =>
431 {
432 Ok(FillOutcome::TimedOut)
433 }
434 Err(error) => Err(SdkError::Connection {
435 description: format!("failed to read from push connection: {error}"),
436 }),
437 }
438}
439
/// Result of a single `fill_buffer` call that did not fail.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FillOutcome {
    /// At least one byte was appended to the buffer.
    Read,
    /// The socket read timeout elapsed before any byte arrived.
    TimedOut,
}
446
447fn write_frame(stream: &mut TcpStream, frame: &Frame) -> Result<(), SdkError> {
449 let len = encoded_len(frame).map_err(|error| protocol_error(&error))?;
450 let mut bytes = vec![0_u8; len];
451 let written = encode(frame, &mut bytes).map_err(|error| protocol_error(&error))?;
452 let encoded = bytes.get(..written).ok_or_else(|| SdkError::Protocol {
453 description: "push wire encoder reported an invalid byte count".to_string(),
454 })?;
455 stream
456 .write_all(encoded)
457 .map_err(|source| SdkError::Connection {
458 description: format!("failed to write push frame: {source}"),
459 })?;
460 stream.flush().map_err(|source| SdkError::Connection {
461 description: format!("failed to flush push frame: {source}"),
462 })
463}
464
465fn protocol_error(error: &ProtocolError) -> SdkError {
467 SdkError::Protocol {
468 description: format!("push wire codec error: {error}"),
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use liminal::protocol::FrameType;

    #[test]
    fn pushed_frame_exposes_correlation_and_payload() {
        let pushed = PushedFrame {
            correlation_id: 7,
            payload: vec![1, 2, 3],
        };
        let expected: &[u8] = &[1, 2, 3];
        assert_eq!(pushed.correlation_id(), 7);
        assert_eq!(pushed.payload(), expected);
        assert_eq!(pushed.into_payload(), expected.to_vec());
    }

    #[test]
    fn reply_frame_round_trips_through_codec() -> Result<(), SdkError> {
        // Non-capturing closures are `Copy`, so one adapter serves every `map_err`.
        let codec = |error: ProtocolError| protocol_error(&error);
        let reply = Frame::new_push_reply(APPLICATION_STREAM_ID, 9, vec![4, 5]).map_err(codec)?;
        let mut wire = vec![0_u8; encoded_len(&reply).map_err(codec)?];
        let written = encode(&reply, &mut wire).map_err(codec)?;
        let (decoded, consumed) = decode(&wire[..written]).map_err(codec)?;
        assert_eq!(consumed, written);
        assert_eq!(decoded.frame_type(), FrameType::PushReply);
        Ok(())
    }
}