use std::io;
use prost::Message;
use crate::broker::protocol::{read_frame, write_frame, Endpoint, Frame, FrameKind, FramingError};
/// Synchronous client that sends protobuf-encoded [`Frame`] requests over a
/// local socket and waits for the correlated response to each one.
///
/// Requests are strictly sequential: each call sends one frame and then reads
/// one frame back before returning.
pub struct FrameClient {
    /// Id stamped on the next outgoing request. Starts at 1 and skips 0 when
    /// the counter wraps.
    next_request_id: u64,
    /// Connected local-socket stream. Only reads are buffered; requests are
    /// written to the inner stream directly.
    stream: io::BufReader<interprocess::local_socket::Stream>,
}
impl FrameClient {
    /// Connects to the broker listening at `endpoint`.
    ///
    /// # Errors
    ///
    /// Returns [`FrameClientError::Connect`] if the underlying connection
    /// cannot be established.
    pub fn connect(endpoint: &Endpoint) -> Result<Self, FrameClientError> {
        let connection = crate::broker::backend_handle::Connection::connect(endpoint)
            .map_err(FrameClientError::Connect)?;
        Ok(Self::from_stream(connection.into_inner()))
    }

    /// Wraps an already-connected local-socket stream.
    ///
    /// The first request sent through the returned client uses request id 1.
    pub fn from_stream(stream: interprocess::local_socket::Stream) -> Self {
        Self {
            stream: io::BufReader::new(stream),
            next_request_id: 1,
        }
    }

    /// Sends one request frame and blocks until its response frame is read.
    ///
    /// The request is stamped with the next request id. The response must
    /// carry that same id and the same `payload_protocol`, and it must be of
    /// kind [`FrameKind::Response`]. It is checked in that order.
    ///
    /// # Errors
    ///
    /// - [`FrameClientError::Framing`] if writing the request or reading the
    ///   response frame fails.
    /// - [`FrameClientError::Decode`] if the response bytes are not a valid
    ///   [`Frame`].
    /// - [`FrameClientError::RequestIdMismatch`],
    ///   [`FrameClientError::PayloadProtocolMismatch`], or
    ///   [`FrameClientError::NotAResponse`] if the decoded frame does not
    ///   answer this request. The response has already been consumed from the
    ///   stream in these cases.
    pub fn request(
        &mut self,
        payload_protocol: u32,
        payload: Vec<u8>,
    ) -> Result<Frame, FrameClientError> {
        let request_id = self.next_request_id;
        // Skip 0 on wrap-around so a live request never carries id 0.
        // NOTE(review): presumably 0 is reserved as "unset" (proto3 default) — confirm.
        self.next_request_id = self.next_request_id.wrapping_add(1).max(1);

        let body = Frame::request(payload_protocol, payload)
            .with_request_id(request_id)
            .encode_to_vec();
        // `BufReader` only buffers reads, so writing to the inner stream is safe.
        write_frame(self.stream.get_mut(), &body)?;

        let response_bytes = read_frame(&mut self.stream)?;
        let response =
            Frame::decode(response_bytes.as_slice()).map_err(FrameClientError::Decode)?;
        Self::validate_response(&response, request_id, payload_protocol)?;
        Ok(response)
    }

    /// Checks that `response` answers the request with the given
    /// `request_id` and `payload_protocol` and is a `RESPONSE` frame.
    fn validate_response(
        response: &Frame,
        request_id: u64,
        payload_protocol: u32,
    ) -> Result<(), FrameClientError> {
        if response.request_id != request_id {
            return Err(FrameClientError::RequestIdMismatch {
                expected: request_id,
                got: response.request_id,
            });
        }
        if response.payload_protocol != payload_protocol {
            return Err(FrameClientError::PayloadProtocolMismatch {
                expected: payload_protocol,
                got: response.payload_protocol,
            });
        }
        // Unknown wire values fail `try_from` and are rejected like any other
        // non-RESPONSE kind.
        if !matches!(FrameKind::try_from(response.kind), Ok(FrameKind::Response)) {
            return Err(FrameClientError::NotAResponse {
                kind: response.kind,
            });
        }
        Ok(())
    }

    /// Returns the request id that the next call to [`Self::request`] will use.
    pub fn next_request_id(&self) -> u64 {
        self.next_request_id
    }

    /// Returns how many bytes have been read from the socket but not yet consumed.
    pub fn buffered_len(&self) -> usize {
        self.stream.buffer().len()
    }

    /// Unwraps the client and returns the underlying stream.
    ///
    /// Any bytes still held in the read buffer (see [`Self::buffered_len`])
    /// are discarded.
    pub fn into_stream(self) -> interprocess::local_socket::Stream {
        self.stream.into_inner()
    }
}
/// Errors produced by [`FrameClient`].
#[derive(Debug, thiserror::Error)]
pub enum FrameClientError {
    /// Opening the connection to the endpoint failed.
    #[error("frame client connect failed: {0}")]
    Connect(io::Error),

    /// An error from `read_frame`/`write_frame`. Both `Display` and `source`
    /// are forwarded to the wrapped error unchanged.
    #[error(transparent)]
    Framing(#[from] FramingError),

    /// The bytes read back could not be decoded as a `Frame`.
    #[error("failed to decode response Frame: {0}")]
    Decode(prost::DecodeError),

    /// The response carries a different request id than the request that was sent.
    #[error("response request_id {got} does not match request {expected}")]
    RequestIdMismatch { expected: u64, got: u64 },

    /// The response carries a different payload protocol than the request.
    /// Both values are displayed in `0x`-prefixed uppercase hex.
    #[error("response payload_protocol {got:#06X} does not match request {expected:#06X}")]
    PayloadProtocolMismatch { expected: u32, got: u32 },

    /// The correlated frame is not a `RESPONSE`. `kind` is the raw wire value,
    /// which may not map to any known `FrameKind`.
    #[error("correlated frame kind {kind} is not RESPONSE")]
    NotAResponse { kind: i32 },
}