use alloc::format;
use alloc::string::ToString;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::time::Duration;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
use std::thread::JoinHandle;
use liminal::protocol::{
Frame, ProtocolError, ProtocolVersion, WorkerRegisterOutcome, WorkerRegistration, decode,
encode, encoded_len,
};
use crate::SdkError;
/// Lowest protocol version this client advertises in its `Connect` frame.
const CLIENT_MIN_VERSION: ProtocolVersion = ProtocolVersion::new(1, 0);
/// Highest protocol version this client advertises in its `Connect` frame.
const CLIENT_MAX_VERSION: ProtocolVersion = ProtocolVersion::new(1, 0);
/// Socket write timeout applied to every push connection.
const WRITE_TIMEOUT: Duration = Duration::from_secs(5);
/// Socket read timeout applied to every push connection; while idle, the reader thread
/// wakes up this often to re-check its stop flag.
const READER_POLL_TIMEOUT: Duration = Duration::from_millis(100);
/// Maximum number of bytes requested from the socket per `read` call.
const READ_CHUNK_BYTES: usize = 4096;
/// Limit on buffered bytes for one not-yet-complete frame. It is checked before each read,
/// so the buffer may exceed it by at most one chunk before the error is raised.
const MAX_FRAME_BYTES: usize = 16 * 1024 * 1024;
/// Stream id used when building push-reply frames.
const APPLICATION_STREAM_ID: u32 = 1;
/// A server push that the reader thread decoded and handed to the client.
///
/// `PushClient::reply` takes the same correlation id when answering a push.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PushedFrame {
    /// Correlation id carried by the `Push` frame.
    correlation_id: u64,
    /// Application bytes carried by the `Push` frame.
    payload: Vec<u8>,
}

impl PushedFrame {
    /// Correlation id of this push.
    #[must_use]
    pub const fn correlation_id(&self) -> u64 {
        self.correlation_id
    }

    /// Borrows the push payload.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Consumes the push and returns its payload without copying.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let Self { payload, .. } = self;
        payload
    }
}
/// Client side of a push connection: the server pushes frames and the client answers with
/// replies.
///
/// A background thread named `liminal-push-reader` reads from a clone of the socket and
/// forwards decoded `Push` frames into `inbound`. Replies are written from the caller's thread
/// through `writer`. Dropping the client sets `stop` and joins the reader thread.
#[derive(Debug)]
pub struct PushClient {
    /// Socket used for outbound frames; guarded by a mutex because `reply` takes `&self`.
    writer: Arc<Mutex<TcpStream>>,
    /// Receiving end of the channel the reader thread sends pushes into.
    inbound: Receiver<PushedFrame>,
    /// Set on drop to ask the reader thread to exit.
    stop: Arc<AtomicBool>,
    /// Handle of the reader thread; taken and joined on drop.
    reader: Option<JoinHandle<()>>,
}
impl PushClient {
pub fn connect(address: &str) -> Result<Self, SdkError> {
let mut stream = connect_socket(address)?;
handshake(&mut stream)?;
Self::start_reader(stream)
}
pub fn connect_with_registration(
address: &str,
registration: WorkerRegistration,
) -> Result<Self, SdkError> {
let mut stream = connect_socket(address)?;
handshake(&mut stream)?;
register(&mut stream, registration)?;
Self::start_reader(stream)
}
fn start_reader(stream: TcpStream) -> Result<Self, SdkError> {
let read_stream = stream.try_clone().map_err(|source| SdkError::Protocol {
description: format!("failed to clone push socket for reader thread: {source}"),
})?;
let stop = Arc::new(AtomicBool::new(false));
let (sender, inbound) = channel();
let reader_stop = Arc::clone(&stop);
let reader = std::thread::Builder::new()
.name("liminal-push-reader".to_string())
.spawn(move || run_reader(read_stream, &sender, &reader_stop))
.map_err(|source| SdkError::Protocol {
description: format!("failed to start push reader thread: {source}"),
})?;
Ok(Self {
writer: Arc::new(Mutex::new(stream)),
inbound,
stop,
reader: Some(reader),
})
}
pub fn recv_timeout(&self, timeout: Duration) -> Result<PushedFrame, SdkError> {
self.inbound.recv_timeout(timeout).map_err(|error| {
let detail = match error {
RecvTimeoutError::Timeout => "no server push arrived within the timeout",
RecvTimeoutError::Disconnected => {
"the push reader stopped before a server push arrived"
}
};
SdkError::Connection {
description: format!("push receive failed: {detail}"),
}
})
}
pub fn reply(&self, correlation_id: u64, payload: Vec<u8>) -> Result<(), SdkError> {
let frame = Frame::new_push_reply(APPLICATION_STREAM_ID, correlation_id, payload)
.map_err(|error| protocol_error(&error))?;
let mut writer = self.writer.lock().map_err(|error| SdkError::Connection {
description: format!("push writer lock poisoned: {error}"),
})?;
write_frame(&mut writer, &frame)
}
}
impl Drop for PushClient {
    /// Asks the reader thread to exit, then blocks until it has.
    ///
    /// The reader only observes the stop flag between frames or after a socket read times out,
    /// so the join may wait for an in-progress read to return.
    fn drop(&mut self) {
        self.stop.store(true, Ordering::SeqCst);
        let Some(handle) = self.reader.take() else {
            return;
        };
        // A panic on the reader thread is intentionally not propagated out of drop.
        let _ = handle.join();
    }
}
/// Opens a TCP connection to `address` and applies the push-socket options.
///
/// Nagle's algorithm is disabled, the read timeout is set to `READER_POLL_TIMEOUT`, and the
/// write timeout to `WRITE_TIMEOUT`.
///
/// # Errors
///
/// Returns `SdkError::Connection` when connecting or setting any socket option fails.
fn connect_socket(address: &str) -> Result<TcpStream, SdkError> {
    let fail = |description: String| SdkError::Connection { description };
    let stream = TcpStream::connect(address).map_err(|source| {
        fail(format!("failed to connect push client to {address}: {source}"))
    })?;
    stream
        .set_nodelay(true)
        .map_err(|source| fail(format!("failed to disable Nagle for {address}: {source}")))?;
    stream.set_read_timeout(Some(READER_POLL_TIMEOUT)).map_err(|source| {
        fail(format!("failed to set push read timeout for {address}: {source}"))
    })?;
    stream.set_write_timeout(Some(WRITE_TIMEOUT)).map_err(|source| {
        fail(format!("failed to set push write timeout for {address}: {source}"))
    })?;
    Ok(stream)
}
/// Sends a `WorkerRegister` frame and waits for the matching `WorkerRegisterAck`.
///
/// NOTE(review): the bytes read while waiting for the ack live in a local buffer that is
/// dropped on return, and the reader thread later starts with a fresh buffer. If the server can
/// send a push in the same read as the ack, those bytes would be lost — confirm the server
/// never does this, or carry the buffer over to the reader.
///
/// # Errors
///
/// Returns `SdkError::Protocol` when the server rejects the registration (the reason is
/// included) or replies with a different frame. Write and read errors are propagated.
fn register(stream: &mut TcpStream, registration: WorkerRegistration) -> Result<(), SdkError> {
    let request = Frame::WorkerRegister {
        flags: 0,
        registration,
    };
    write_frame(stream, &request)?;
    let mut scratch = Vec::new();
    let reply = read_one_frame(stream, &mut scratch)?;
    match reply {
        Frame::WorkerRegisterAck {
            outcome: WorkerRegisterOutcome::Accepted,
            ..
        } => Ok(()),
        Frame::WorkerRegisterAck {
            outcome: WorkerRegisterOutcome::Rejected { reason },
            ..
        } => {
            let description = format!("server rejected worker registration: {reason}");
            Err(SdkError::Protocol { description })
        }
        unexpected => {
            let kind = unexpected.frame_type();
            Err(SdkError::Protocol {
                description: format!(
                    "expected WorkerRegisterAck during registration, received {kind:?}"
                ),
            })
        }
    }
}
/// Sends the `Connect` frame, offering versions `CLIENT_MIN_VERSION..=CLIENT_MAX_VERSION` with
/// an empty auth token, and waits for the server's answer.
///
/// # Errors
///
/// Returns `SdkError::Connection` when the server answers with `ConnectError` (reason code and
/// message included), and `SdkError::Protocol` for any other unexpected frame. Write and read
/// errors are propagated.
fn handshake(stream: &mut TcpStream) -> Result<(), SdkError> {
    write_frame(
        stream,
        &Frame::Connect {
            flags: 0,
            min_version: CLIENT_MIN_VERSION,
            max_version: CLIENT_MAX_VERSION,
            auth_token: Vec::new(),
        },
    )?;
    let mut scratch = Vec::new();
    let reply = read_one_frame(stream, &mut scratch)?;
    match reply {
        Frame::ConnectAck { .. } => Ok(()),
        Frame::ConnectError {
            reason_code,
            message,
            ..
        } => {
            let detail = message.unwrap_or_else(|| "no detail".to_string());
            Err(SdkError::Connection {
                description: format!(
                    "server rejected push connection (reason {reason_code}): {detail}"
                ),
            })
        }
        unexpected => {
            let kind = unexpected.frame_type();
            Err(SdkError::Protocol {
                description: format!(
                    "expected ConnectAck during push handshake, received {kind:?}"
                ),
            })
        }
    }
}
/// Body of the push reader thread.
///
/// Decodes frames from `stream` and forwards each `Push` frame to `sender`; frames of any other
/// type are dropped. The loop ends when `stop` is set (checked between frames and after each
/// read timeout), when the receiving end of the channel is gone, or on the first read/decode
/// error. The error itself is discarded, so the client only sees a disconnected channel.
fn run_reader(mut stream: TcpStream, sender: &Sender<PushedFrame>, stop: &AtomicBool) {
    let mut pending = Vec::new();
    loop {
        if stop.load(Ordering::SeqCst) {
            return;
        }
        let frame = match next_frame(&mut stream, &mut pending) {
            Ok(Some(frame)) => frame,
            // The read timed out without a complete frame; go re-check the stop flag.
            Ok(None) => continue,
            Err(_) => return,
        };
        if let Frame::Push {
            correlation_id,
            payload,
            ..
        } = frame
        {
            let pushed = PushedFrame {
                correlation_id,
                payload,
            };
            if sender.send(pushed).is_err() {
                return;
            }
        }
    }
}
/// Decodes the next frame from `buffer`, reading more bytes from `stream` as needed.
///
/// Returns `Ok(Some(frame))` and removes its bytes from the front of `buffer` once a complete
/// frame is available. Returns `Ok(None)` when a socket read times out before the frame is
/// complete; any partial bytes stay in `buffer` for the next call.
///
/// # Errors
///
/// Decode failures other than "need more bytes" become `SdkError::Protocol`; read failures
/// from `fill_buffer` are propagated.
fn next_frame(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<Option<Frame>, SdkError> {
    loop {
        match decode(buffer) {
            Ok((frame, used)) => {
                buffer.drain(..used);
                return Ok(Some(frame));
            }
            Err(
                ProtocolError::IncompleteHeader { .. } | ProtocolError::TruncatedPayload { .. },
            ) => {
                if fill_buffer(stream, buffer)? == FillOutcome::TimedOut {
                    return Ok(None);
                }
            }
            Err(other) => return Err(protocol_error(&other)),
        }
    }
}
fn read_one_frame(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<Frame, SdkError> {
loop {
match decode(buffer) {
Ok((frame, consumed)) => {
buffer.drain(..consumed);
return Ok(frame);
}
Err(
ProtocolError::IncompleteHeader { .. } | ProtocolError::TruncatedPayload { .. },
) => match fill_buffer(stream, buffer)? {
FillOutcome::Read => {}
FillOutcome::TimedOut => {
return Err(SdkError::Connection {
description: "push connection timed out waiting for a control-frame reply"
.to_string(),
});
}
},
Err(error) => return Err(protocol_error(&error)),
}
}
}
/// Performs one socket read and appends whatever arrived to `buffer`.
///
/// Callers only invoke this while the bytes in `buffer` do not yet form a complete frame.
///
/// Returns `FillOutcome::Read` once at least one byte has been appended. Returns
/// `FillOutcome::TimedOut` when the socket read timeout elapsed first; depending on the
/// platform, the OS reports this as `WouldBlock` or `TimedOut`.
///
/// A read interrupted by a signal (`ErrorKind::Interrupted`) transferred no data and is
/// retried. On Linux, a socket with a receive timeout reports `EINTR` even when the signal
/// handler was installed with `SA_RESTART`, so treating it as fatal would tear down a healthy
/// connection.
///
/// # Errors
///
/// - `SdkError::Protocol` when `buffer` already holds more than `MAX_FRAME_BYTES` bytes.
/// - `SdkError::Connection` when the server closed the connection or the read failed.
fn fill_buffer(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<FillOutcome, SdkError> {
    if buffer.len() > MAX_FRAME_BYTES {
        return Err(SdkError::Protocol {
            description: format!(
                "push frame exceeded {MAX_FRAME_BYTES} bytes without a complete frame"
            ),
        });
    }
    let mut chunk = [0_u8; READ_CHUNK_BYTES];
    let read = loop {
        match stream.read(&mut chunk) {
            Ok(0) => {
                return Err(SdkError::Connection {
                    description: "server closed the push connection".to_string(),
                });
            }
            Ok(read) => break read,
            // No data was transferred; retry instead of failing the connection.
            Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
            Err(error)
                if matches!(
                    error.kind(),
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                ) =>
            {
                return Ok(FillOutcome::TimedOut);
            }
            Err(error) => {
                return Err(SdkError::Connection {
                    description: format!("failed to read from push connection: {error}"),
                });
            }
        }
    };
    let Some(received) = chunk.get(..read) else {
        return Err(SdkError::Protocol {
            description: "push socket read reported more bytes than the buffer holds"
                .to_string(),
        });
    };
    buffer.extend_from_slice(received);
    Ok(FillOutcome::Read)
}
/// Non-error result of a single `fill_buffer` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FillOutcome {
    /// At least one byte was read and appended to the buffer.
    Read,
    /// The socket read timeout elapsed before any data arrived.
    TimedOut,
}
fn write_frame(stream: &mut TcpStream, frame: &Frame) -> Result<(), SdkError> {
let len = encoded_len(frame).map_err(|error| protocol_error(&error))?;
let mut bytes = vec![0_u8; len];
let written = encode(frame, &mut bytes).map_err(|error| protocol_error(&error))?;
let encoded = bytes.get(..written).ok_or_else(|| SdkError::Protocol {
description: "push wire encoder reported an invalid byte count".to_string(),
})?;
stream
.write_all(encoded)
.map_err(|source| SdkError::Connection {
description: format!("failed to write push frame: {source}"),
})?;
stream.flush().map_err(|source| SdkError::Connection {
description: format!("failed to flush push frame: {source}"),
})
}
/// Wraps a wire-codec failure as an `SdkError::Protocol` with a push-specific prefix.
fn protocol_error(error: &ProtocolError) -> SdkError {
    let description = format!("push wire codec error: {error}");
    SdkError::Protocol { description }
}
#[cfg(test)]
mod tests {
    use super::*;
    use liminal::protocol::FrameType;

    /// Accessors return the stored fields and `into_payload` hands back the owned bytes.
    #[test]
    fn pushed_frame_exposes_correlation_and_payload() {
        let pushed = PushedFrame {
            correlation_id: 7,
            payload: vec![1, 2, 3],
        };
        assert_eq!(pushed.correlation_id(), 7);
        assert_eq!(pushed.payload(), &[1, 2, 3]);
        let owned: Vec<u8> = pushed.into_payload();
        assert_eq!(owned, vec![1, 2, 3]);
    }

    /// A push reply built the way `PushClient::reply` builds it survives an encode/decode trip.
    #[test]
    fn reply_frame_round_trips_through_codec() -> Result<(), SdkError> {
        let original = Frame::new_push_reply(APPLICATION_STREAM_ID, 9, vec![4, 5])
            .map_err(|error| protocol_error(&error))?;
        let capacity = encoded_len(&original).map_err(|error| protocol_error(&error))?;
        let mut wire = vec![0_u8; capacity];
        let written = encode(&original, &mut wire).map_err(|error| protocol_error(&error))?;
        let encoded = &wire[..written];
        let (decoded, consumed) = decode(encoded).map_err(|error| protocol_error(&error))?;
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded.frame_type(), FrameType::PushReply);
        Ok(())
    }
}