use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use parking_lot::Mutex;
/// Identifier correlating a request with its response, error or cancellation.
pub type RequestId = u64;
/// Identifier shared by every frame belonging to one stream.
pub type StreamId = u64;
/// Kind of a frame, encoded as a single byte in the frame header.
///
/// The explicit discriminants are the on-wire values; do not renumber them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    /// A request; `IpcServer::process` answers it via `RequestHandler::handle_request`.
    Request = 0,
    /// Successful reply to a `Request`, carrying the same id.
    Response = 1,
    /// Opens a stream; `IpcServer::process` hands it to `RequestHandler::handle_stream`.
    StreamStart = 2,
    /// One chunk of stream payload.
    StreamData = 3,
    /// Marks the end of a stream; carries no payload.
    StreamEnd = 4,
    /// Failure reply: payload is a little-endian `u32` code followed by a UTF-8 message.
    Error = 5,
    /// Flow-control signal; currently ignored by the server and the multiplexer.
    FlowPause = 6,
    /// Flow-control signal; currently ignored by the server and the multiplexer.
    FlowResume = 7,
    /// Liveness probe; the server replies with a `Pong` carrying the same id.
    Ping = 8,
    /// Reply to a `Ping`.
    Pong = 9,
    /// Cancellation notice; currently ignored by `IpcServer::process`.
    Cancel = 10,
}
impl TryFrom<u8> for MessageType {
type Error = IpcError;
fn try_from(value: u8) -> Result<Self, <Self as TryFrom<u8>>::Error> {
match value {
0 => Ok(MessageType::Request),
1 => Ok(MessageType::Response),
2 => Ok(MessageType::StreamStart),
3 => Ok(MessageType::StreamData),
4 => Ok(MessageType::StreamEnd),
5 => Ok(MessageType::Error),
6 => Ok(MessageType::FlowPause),
7 => Ok(MessageType::FlowResume),
8 => Ok(MessageType::Ping),
9 => Ok(MessageType::Pong),
10 => Ok(MessageType::Cancel),
_ => Err(IpcError::InvalidMessageType(value)),
}
}
}
/// Fixed-size header that precedes every frame's payload on the wire.
///
/// Encoded layout ([`FrameHeader::SIZE`] = 14 bytes, integers little-endian):
///
/// | bytes  | field      |
/// |--------|------------|
/// | 0..4   | `length`   |
/// | 4..12  | `id`       |
/// | 12     | `msg_type` |
/// | 13     | `flags`    |
#[derive(Debug, Clone, Copy)]
pub struct FrameHeader {
    /// Number of payload bytes that follow the header.
    pub length: u32,
    /// Request or stream identifier the frame belongs to.
    pub id: u64,
    /// Kind of message carried by the frame.
    pub msg_type: MessageType,
    /// Flag byte; [`FrameHeader::new`] always sets it to 0.
    pub flags: u8,
}

impl FrameHeader {
    /// Encoded size of a header in bytes.
    pub const SIZE: usize = 14;
    /// Largest payload length [`FrameHeader::from_bytes`] accepts (16 MiB).
    pub const MAX_PAYLOAD: u32 = 16 * 1024 * 1024;

    /// Builds a header for a payload of `payload_len` bytes with `flags` cleared.
    ///
    /// A `payload_len` that does not fit in a `u32` saturates to `u32::MAX`
    /// (which exceeds [`FrameHeader::MAX_PAYLOAD`]) instead of wrapping.
    /// Wrapping would make an oversized frame look like a small valid one
    /// and desynchronise the reader.
    pub fn new(id: u64, msg_type: MessageType, payload_len: usize) -> Self {
        Self {
            length: u32::try_from(payload_len).unwrap_or(u32::MAX),
            id,
            msg_type,
            flags: 0,
        }
    }

    /// Encodes the header into its fixed 14-byte wire form.
    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
        let mut buf = [0u8; Self::SIZE];
        buf[0..4].copy_from_slice(&self.length.to_le_bytes());
        buf[4..12].copy_from_slice(&self.id.to_le_bytes());
        buf[12] = self.msg_type as u8;
        buf[13] = self.flags;
        buf
    }

    /// Decodes a header from its 14-byte wire form.
    ///
    /// # Errors
    /// - [`IpcError::InvalidMessageType`] if byte 12 is not a known [`MessageType`]
    ///   (checked first).
    /// - [`IpcError::PayloadTooLarge`] if the length exceeds [`FrameHeader::MAX_PAYLOAD`].
    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self, IpcError> {
        let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
        let id = u64::from_le_bytes([
            buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
        ]);
        let msg_type = MessageType::try_from(buf[12])?;
        let flags = buf[13];
        if length > Self::MAX_PAYLOAD {
            return Err(IpcError::PayloadTooLarge(length as usize));
        }
        Ok(Self {
            length,
            id,
            msg_type,
            flags,
        })
    }
}
#[derive(Debug, Clone)]
pub struct Frame {
pub header: FrameHeader,
pub payload: Vec<u8>,
}
impl Frame {
pub fn request(id: RequestId, payload: Vec<u8>) -> Self {
Self {
header: FrameHeader::new(id, MessageType::Request, payload.len()),
payload,
}
}
pub fn response(id: RequestId, payload: Vec<u8>) -> Self {
Self {
header: FrameHeader::new(id, MessageType::Response, payload.len()),
payload,
}
}
pub fn stream_start(id: StreamId, payload: Vec<u8>) -> Self {
Self {
header: FrameHeader::new(id, MessageType::StreamStart, payload.len()),
payload,
}
}
pub fn stream_data(id: StreamId, payload: Vec<u8>) -> Self {
Self {
header: FrameHeader::new(id, MessageType::StreamData, payload.len()),
payload,
}
}
pub fn stream_end(id: StreamId) -> Self {
Self {
header: FrameHeader::new(id, MessageType::StreamEnd, 0),
payload: Vec::new(),
}
}
pub fn error(id: RequestId, error_code: u32, message: &str) -> Self {
let mut payload = Vec::with_capacity(4 + message.len());
payload.extend_from_slice(&error_code.to_le_bytes());
payload.extend_from_slice(message.as_bytes());
Self {
header: FrameHeader::new(id, MessageType::Error, payload.len()),
payload,
}
}
pub fn ping(id: RequestId) -> Self {
Self {
header: FrameHeader::new(id, MessageType::Ping, 0),
payload: Vec::new(),
}
}
pub fn pong(id: RequestId) -> Self {
Self {
header: FrameHeader::new(id, MessageType::Pong, 0),
payload: Vec::new(),
}
}
pub fn cancel(id: RequestId) -> Self {
Self {
header: FrameHeader::new(id, MessageType::Cancel, 0),
payload: Vec::new(),
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(FrameHeader::SIZE + self.payload.len());
buf.extend_from_slice(&self.header.to_bytes());
buf.extend_from_slice(&self.payload);
buf
}
}
/// Errors produced by the framing, multiplexing and serving layers.
#[derive(Debug)]
pub enum IpcError {
    /// Underlying I/O failure.
    Io(io::Error),
    /// A header carried a message-type byte with no matching [`MessageType`].
    InvalidMessageType(u8),
    /// A payload length exceeded the allowed maximum; carries the offending length.
    PayloadTooLarge(usize),
    /// End of input reached where more data was expected.
    UnexpectedEof,
    /// The request with this id was cancelled before a reply arrived.
    RequestCancelled(RequestId),
    /// The stream with this id is closed.
    StreamClosed(StreamId),
    /// An operation did not complete in time.
    Timeout,
}

impl From<io::Error> for IpcError {
    fn from(e: io::Error) -> Self {
        IpcError::Io(e)
    }
}

impl std::fmt::Display for IpcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            IpcError::Io(e) => write!(f, "IO error: {}", e),
            IpcError::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
            IpcError::PayloadTooLarge(size) => write!(f, "Payload too large: {} bytes", size),
            IpcError::UnexpectedEof => write!(f, "Unexpected end of stream"),
            IpcError::RequestCancelled(id) => write!(f, "Request {} cancelled", id),
            IpcError::StreamClosed(id) => write!(f, "Stream {} closed", id),
            IpcError::Timeout => write!(f, "Operation timed out"),
        }
    }
}

impl std::error::Error for IpcError {
    /// Exposes the wrapped `io::Error` so error-chain walkers can reach the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            IpcError::Io(e) => Some(e),
            _ => None,
        }
    }
}
pub struct FrameReader<R: Read> {
reader: R,
header_buf: [u8; FrameHeader::SIZE],
}
impl<R: Read> FrameReader<R> {
pub fn new(reader: R) -> Self {
Self {
reader,
header_buf: [0u8; FrameHeader::SIZE],
}
}
pub fn read_frame(&mut self) -> Result<Frame, IpcError> {
self.reader.read_exact(&mut self.header_buf)?;
let header = FrameHeader::from_bytes(&self.header_buf)?;
let mut payload = vec![0u8; header.length as usize];
self.reader.read_exact(&mut payload)?;
Ok(Frame { header, payload })
}
pub fn into_inner(self) -> R {
self.reader
}
}
pub struct FrameWriter<W: Write> {
writer: W,
buffer: Vec<u8>,
max_buffer: usize,
}
impl<W: Write> FrameWriter<W> {
const DEFAULT_BUFFER: usize = 64 * 1024;
pub fn new(writer: W) -> Self {
Self {
writer,
buffer: Vec::with_capacity(Self::DEFAULT_BUFFER),
max_buffer: Self::DEFAULT_BUFFER,
}
}
pub fn write_frame(&mut self, frame: &Frame) -> Result<(), IpcError> {
let bytes = frame.to_bytes();
if bytes.len() > self.max_buffer {
self.flush()?;
self.writer.write_all(&bytes)?;
return Ok(());
}
if self.buffer.len() + bytes.len() > self.max_buffer {
self.flush()?;
}
self.buffer.extend_from_slice(&bytes);
Ok(())
}
pub fn flush(&mut self) -> Result<(), IpcError> {
if !self.buffer.is_empty() {
self.writer.write_all(&self.buffer)?;
self.buffer.clear();
}
self.writer.flush()?;
Ok(())
}
pub fn into_inner(self) -> W {
self.writer
}
}
/// One-shot completion handler for an outstanding request.
struct PendingRequest {
    /// Called once: with `Ok(frame)` for a matching `Response`/`Error` frame,
    /// or with `Err(IpcError::RequestCancelled(id))` on cancellation.
    callback: Box<dyn FnOnce(Result<Frame, IpcError>) + Send>,
}

/// Routes incoming frames to callbacks registered per request or stream id.
pub struct RequestMultiplexer {
    /// Counter behind `next_id`; the first id handed out is 1.
    next_id: AtomicU64,
    /// Outstanding request callbacks, keyed by request id.
    pending: Mutex<HashMap<RequestId, PendingRequest>>,
    /// Open stream handlers, keyed by stream id.
    streams: Mutex<HashMap<StreamId, StreamState>>,
}

/// Handlers for one open stream.
struct StreamState {
    /// Called with each `StreamData` payload.
    on_data: Box<dyn Fn(Vec<u8>) + Send>,
    /// Called once when the stream ends or is cancelled.
    on_end: Box<dyn FnOnce() + Send>,
    // Initialised to `false` but never read yet; kept for future
    // per-stream flow control.
    #[allow(dead_code)]
    paused: bool,
}

impl Default for RequestMultiplexer {
    fn default() -> Self {
        RequestMultiplexer::new()
    }
}
impl RequestMultiplexer {
    /// Creates an empty multiplexer whose first allocated id is 1.
    pub fn new() -> Self {
        Self {
            next_id: AtomicU64::new(1),
            pending: Mutex::new(HashMap::new()),
            streams: Mutex::new(HashMap::new()),
        }
    }

    /// Allocates a fresh id; successive calls return increasing values.
    pub fn next_id(&self) -> RequestId {
        self.next_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Registers `callback` for request `id`, replacing (and dropping without
    /// calling) any callback already registered under that id.
    pub fn register_request<F>(&self, id: RequestId, callback: F)
    where
        F: FnOnce(Result<Frame, IpcError>) + Send + 'static,
    {
        self.pending.lock().insert(
            id,
            PendingRequest {
                callback: Box::new(callback),
            },
        );
    }

    /// Registers data/end handlers for stream `id`, replacing any existing ones.
    pub fn register_stream<D, E>(&self, id: StreamId, on_data: D, on_end: E)
    where
        D: Fn(Vec<u8>) + Send + 'static,
        E: FnOnce() + Send + 'static,
    {
        self.streams.lock().insert(
            id,
            StreamState {
                on_data: Box::new(on_data),
                on_end: Box::new(on_end),
                paused: false,
            },
        );
    }

    /// Dispatches an incoming frame to the handler registered for its id.
    ///
    /// `Response` and `Error` frames complete the pending request as
    /// `Ok(frame)`. `StreamData` feeds `on_data`, and `StreamEnd` removes the
    /// stream and calls `on_end`. Frames without a registered handler, and
    /// all other message types, are ignored.
    ///
    /// Request callbacks and `on_end` run with no lock held, so they may call
    /// back into this multiplexer. `on_data` is borrowed from the stream table
    /// and therefore runs with the stream lock held. It must not call
    /// `register_stream`, `handle_frame` or `cancel` on this multiplexer
    /// (the mutex is not reentrant), or it will deadlock.
    pub fn handle_frame(&self, frame: Frame) {
        let id = frame.header.id;
        match frame.header.msg_type {
            MessageType::Response | MessageType::Error => {
                // Remove in its own statement so the guard is released before
                // the callback runs; inside an `if let` scrutinee it would
                // stay alive for the whole body.
                let pending = self.pending.lock().remove(&id);
                if let Some(pending) = pending {
                    (pending.callback)(Ok(frame));
                }
            }
            MessageType::StreamData => {
                let streams = self.streams.lock();
                if let Some(state) = streams.get(&id) {
                    (state.on_data)(frame.payload);
                }
            }
            MessageType::StreamEnd => {
                let ended = self.streams.lock().remove(&id);
                if let Some(state) = ended {
                    (state.on_end)();
                }
            }
            MessageType::Request
            | MessageType::StreamStart
            | MessageType::FlowPause
            | MessageType::FlowResume
            | MessageType::Ping
            | MessageType::Pong
            | MessageType::Cancel => {}
        }
    }

    /// Cancels `id` locally.
    ///
    /// A pending request callback receives `Err(IpcError::RequestCancelled(id))`.
    /// Then an open stream with the same id is removed and its `on_end` runs.
    /// Both callbacks run without any multiplexer lock held.
    pub fn cancel(&self, id: RequestId) {
        let pending = self.pending.lock().remove(&id);
        if let Some(pending) = pending {
            (pending.callback)(Err(IpcError::RequestCancelled(id)));
        }
        let stream = self.streams.lock().remove(&id);
        if let Some(state) = stream {
            (state.on_end)();
        }
    }

    /// Number of requests still awaiting a reply.
    pub fn pending_count(&self) -> usize {
        self.pending.lock().len()
    }
}
/// Collects `(id, payload)` pairs and turns them into `Request` frames,
/// preserving insertion order.
pub struct BatchRequest {
    requests: Vec<(RequestId, Vec<u8>)>,
}

impl Default for BatchRequest {
    fn default() -> Self {
        BatchRequest {
            requests: Vec::new(),
        }
    }
}

impl BatchRequest {
    /// Starts an empty batch.
    pub fn new() -> Self {
        BatchRequest::default()
    }

    /// Appends one request and returns `self` for chaining.
    pub fn add(&mut self, id: RequestId, payload: Vec<u8>) -> &mut Self {
        let entry = (id, payload);
        self.requests.push(entry);
        self
    }

    /// Consumes the batch, producing one `Request` frame per entry in order.
    pub fn build(self) -> Vec<Frame> {
        let mut frames = Vec::with_capacity(self.requests.len());
        for (id, payload) in self.requests {
            frames.push(Frame::request(id, payload));
        }
        frames
    }

    /// Number of queued requests.
    pub fn len(&self) -> usize {
        self.requests.len()
    }

    /// `true` when nothing has been added.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Byte-count window with hysteresis.
///
/// Sending stops (`paused`) once `outstanding` reaches `window_size`, and
/// resumes only after acknowledgements bring it below half the window.
#[derive(Debug, Clone)]
pub struct FlowControl {
    /// Maximum bytes allowed outstanding before sending pauses.
    pub window_size: usize,
    /// Bytes recorded as sent but not yet acknowledged.
    pub outstanding: usize,
    /// Whether sending is currently paused.
    pub paused: bool,
}

impl Default for FlowControl {
    fn default() -> Self {
        Self::new(Self::DEFAULT_WINDOW)
    }
}

impl FlowControl {
    /// Window size used by [`FlowControl::default`] (64 KiB).
    pub const DEFAULT_WINDOW: usize = 64 * 1024;

    /// Creates an unpaused window of `window_size` bytes with nothing outstanding.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            outstanding: 0,
            paused: false,
        }
    }

    /// `true` when not paused and the window still has room.
    pub fn can_send(&self) -> bool {
        !self.paused && self.outstanding < self.window_size
    }

    /// Records `bytes` as sent, pausing once the window is full.
    ///
    /// Saturates rather than overflowing, matching `record_acked`.
    pub fn record_sent(&mut self, bytes: usize) {
        self.outstanding = self.outstanding.saturating_add(bytes);
        if self.outstanding >= self.window_size {
            self.paused = true;
        }
    }

    /// Records `bytes` as acknowledged, resuming once under half the window.
    pub fn record_acked(&mut self, bytes: usize) {
        self.outstanding = self.outstanding.saturating_sub(bytes);
        if self.outstanding < self.window_size / 2 {
            self.paused = false;
        }
    }

    /// Pauses sending regardless of the window.
    pub fn pause(&mut self) {
        self.paused = true;
    }

    /// Clears the paused flag regardless of the window.
    pub fn resume(&mut self) {
        self.paused = false;
    }
}
/// Writes the `StreamData`/`StreamEnd` frames of one stream through a shared
/// [`FrameWriter`].
///
/// Each `StreamWriter` owns a `FlowControl::default()` window that bounds
/// how many bytes it queues between flushes of the shared writer.
pub struct StreamWriter<W: Write> {
    writer: Arc<Mutex<FrameWriter<W>>>,
    stream_id: StreamId,
    flow_control: FlowControl,
}

impl<W: Write> StreamWriter<W> {
    /// Creates a writer for `stream_id` sharing `writer`.
    pub fn new(writer: Arc<Mutex<FrameWriter<W>>>, stream_id: StreamId) -> Self {
        Self {
            writer,
            stream_id,
            flow_control: FlowControl::default(),
        }
    }

    /// Sends `data` as one `StreamData` frame.
    ///
    /// The flow-control window is private to this writer and nothing feeds
    /// acknowledgements into it. Waiting for it to reopen would therefore
    /// never end; the previous busy-wait hung once the window filled.
    /// Instead, when the window is full, the shared writer is flushed and
    /// the window is reset, which bounds the amount of unflushed stream data.
    ///
    /// # Errors
    /// Any error from flushing or from `FrameWriter::write_frame`.
    pub fn write_chunk(&mut self, data: Vec<u8>) -> Result<(), IpcError> {
        let frame = Frame::stream_data(self.stream_id, data);
        let size = frame.payload.len();
        let mut writer = self.writer.lock();
        if !self.flow_control.can_send() {
            writer.flush()?;
            let drained = self.flow_control.outstanding;
            self.flow_control.record_acked(drained);
            // Everything is flushed, so the window is fully open again even
            // for tiny windows where `record_acked`'s half-window rule would
            // not clear `paused`.
            self.flow_control.resume();
        }
        writer.write_frame(&frame)?;
        self.flow_control.record_sent(size);
        Ok(())
    }

    /// Sends `StreamEnd` and flushes the shared writer.
    pub fn finish(self) -> Result<(), IpcError> {
        let frame = Frame::stream_end(self.stream_id);
        let mut writer = self.writer.lock();
        writer.write_frame(&frame)?;
        writer.flush()
    }
}
/// Application hooks that `IpcServer::process` calls for incoming frames.
pub trait RequestHandler: Send + Sync {
    /// Produces the reply payload for a `Request` frame.
    ///
    /// `Ok(bytes)` is sent back as a `Response` frame with the same id.
    /// `Err(e)` is sent back as an `Error` frame with code 1 and
    /// `e.to_string()` as the message.
    fn handle_request(&self, request_id: RequestId, payload: &[u8]) -> Result<Vec<u8>, IpcError>;

    /// Services a `StreamStart` frame, emitting data through `writer`.
    ///
    /// Returning `Err(e)` makes the server send an `Error` frame with code 2
    /// and `e.to_string()` as the message.
    fn handle_stream<W: Write>(
        &self,
        stream_id: StreamId,
        payload: &[u8],
        writer: StreamWriter<W>,
    ) -> Result<(), IpcError>;
}
pub struct IpcServer<H: RequestHandler> {
handler: Arc<H>,
}
impl<H: RequestHandler> IpcServer<H> {
pub fn new(handler: H) -> Self {
Self {
handler: Arc::new(handler),
}
}
pub fn process<R: Read, W: Write>(
&self,
reader: &mut FrameReader<R>,
writer: Arc<Mutex<FrameWriter<W>>>,
) -> Result<(), IpcError> {
loop {
let frame = match reader.read_frame() {
Ok(f) => f,
Err(IpcError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
return Ok(()); }
Err(e) => return Err(e),
};
match frame.header.msg_type {
MessageType::Request => {
let response =
match self.handler.handle_request(frame.header.id, &frame.payload) {
Ok(data) => Frame::response(frame.header.id, data),
Err(e) => Frame::error(frame.header.id, 1, &e.to_string()),
};
writer.lock().write_frame(&response)?;
}
MessageType::StreamStart => {
let stream_writer = StreamWriter::new(Arc::clone(&writer), frame.header.id);
if let Err(e) =
self.handler
.handle_stream(frame.header.id, &frame.payload, stream_writer)
{
let err = Frame::error(frame.header.id, 2, &e.to_string());
writer.lock().write_frame(&err)?;
}
}
MessageType::Ping => {
let pong = Frame::pong(frame.header.id);
writer.lock().write_frame(&pong)?;
}
MessageType::Cancel => {
}
_ => {
}
}
writer.lock().flush()?;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::atomic::AtomicBool;

    #[test]
    fn test_frame_header_roundtrip() {
        let header = FrameHeader::new(12345, MessageType::Request, 100);
        let bytes = header.to_bytes();
        let parsed = FrameHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.id, 12345);
        assert_eq!(parsed.msg_type, MessageType::Request);
        assert_eq!(parsed.length, 100);
    }

    #[test]
    fn test_frame_roundtrip() {
        let original = Frame::request(1, b"hello world".to_vec());
        let bytes = original.to_bytes();
        let mut reader = FrameReader::new(Cursor::new(bytes));
        let parsed = reader.read_frame().unwrap();
        assert_eq!(parsed.header.id, 1);
        assert_eq!(parsed.header.msg_type, MessageType::Request);
        assert_eq!(parsed.payload, b"hello world");
    }

    #[test]
    fn test_message_type_discriminants_roundtrip() {
        for raw in 0u8..=10 {
            let decoded = MessageType::try_from(raw).unwrap();
            assert_eq!(decoded as u8, raw);
        }
        assert!(matches!(
            MessageType::try_from(11),
            Err(IpcError::InvalidMessageType(11))
        ));
    }

    #[test]
    fn test_header_rejects_invalid_message_type() {
        let mut bytes = FrameHeader::new(7, MessageType::Ping, 0).to_bytes();
        bytes[12] = 200;
        assert!(matches!(
            FrameHeader::from_bytes(&bytes),
            Err(IpcError::InvalidMessageType(200))
        ));
    }

    #[test]
    fn test_header_rejects_oversized_length() {
        let mut bytes = FrameHeader::new(7, MessageType::Request, 0).to_bytes();
        let too_big = FrameHeader::MAX_PAYLOAD + 1;
        bytes[0..4].copy_from_slice(&too_big.to_le_bytes());
        match FrameHeader::from_bytes(&bytes) {
            Err(IpcError::PayloadTooLarge(len)) => assert_eq!(len, too_big as usize),
            other => panic!("expected PayloadTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn test_clean_eof_reports_io_unexpected_eof() {
        let mut reader = FrameReader::new(Cursor::new(Vec::<u8>::new()));
        match reader.read_frame() {
            Err(IpcError::Io(e)) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof),
            other => panic!("expected Io(UnexpectedEof), got {:?}", other),
        }
    }

    #[test]
    fn test_writer_flush_roundtrip() {
        let mut writer = FrameWriter::new(Vec::new());
        writer.write_frame(&Frame::ping(3)).unwrap();
        writer
            .write_frame(&Frame::request(4, b"abc".to_vec()))
            .unwrap();
        writer.flush().unwrap();
        let bytes = writer.into_inner();
        let mut reader = FrameReader::new(Cursor::new(bytes));
        let ping = reader.read_frame().unwrap();
        assert_eq!(ping.header.msg_type, MessageType::Ping);
        assert_eq!(ping.header.id, 3);
        let request = reader.read_frame().unwrap();
        assert_eq!(request.header.id, 4);
        assert_eq!(request.payload, b"abc");
    }

    #[test]
    fn test_server_replies_to_request_and_ping() {
        struct Echo;
        impl RequestHandler for Echo {
            fn handle_request(
                &self,
                _request_id: RequestId,
                payload: &[u8],
            ) -> Result<Vec<u8>, IpcError> {
                Ok(payload.to_vec())
            }
            fn handle_stream<W: Write>(
                &self,
                _stream_id: StreamId,
                _payload: &[u8],
                writer: StreamWriter<W>,
            ) -> Result<(), IpcError> {
                writer.finish()
            }
        }

        let mut input = Frame::request(1, b"hi".to_vec()).to_bytes();
        input.extend_from_slice(&Frame::ping(2).to_bytes());
        let mut reader = FrameReader::new(Cursor::new(input));
        let writer = Arc::new(Mutex::new(FrameWriter::new(Vec::new())));
        IpcServer::new(Echo)
            .process(&mut reader, Arc::clone(&writer))
            .unwrap();

        let output = Arc::try_unwrap(writer)
            .ok()
            .expect("server no longer holds the writer")
            .into_inner()
            .into_inner();
        let mut replies = FrameReader::new(Cursor::new(output));
        let response = replies.read_frame().unwrap();
        assert_eq!(response.header.msg_type, MessageType::Response);
        assert_eq!(response.header.id, 1);
        assert_eq!(response.payload, b"hi");
        let pong = replies.read_frame().unwrap();
        assert_eq!(pong.header.msg_type, MessageType::Pong);
        assert_eq!(pong.header.id, 2);
    }

    #[test]
    fn test_batch_request() {
        let mut batch = BatchRequest::new();
        batch.add(1, b"request1".to_vec());
        batch.add(2, b"request2".to_vec());
        batch.add(3, b"request3".to_vec());
        let frames = batch.build();
        assert_eq!(frames.len(), 3);
        assert_eq!(frames[0].header.id, 1);
        assert_eq!(frames[1].header.id, 2);
        assert_eq!(frames[2].header.id, 3);
    }

    #[test]
    fn test_multiplexer() {
        let mux = RequestMultiplexer::new();
        let id1 = mux.next_id();
        let id2 = mux.next_id();
        assert_ne!(id1, id2);
        let received1 = Arc::new(AtomicBool::new(false));
        let received2 = Arc::new(AtomicBool::new(false));
        {
            let r1 = Arc::clone(&received1);
            mux.register_request(id1, move |_| {
                r1.store(true, Ordering::SeqCst);
            });
        }
        {
            let r2 = Arc::clone(&received2);
            mux.register_request(id2, move |_| {
                r2.store(true, Ordering::SeqCst);
            });
        }
        mux.handle_frame(Frame::response(id2, b"resp2".to_vec()));
        assert!(!received1.load(Ordering::SeqCst));
        assert!(received2.load(Ordering::SeqCst));
        mux.handle_frame(Frame::response(id1, b"resp1".to_vec()));
        assert!(received1.load(Ordering::SeqCst));
    }

    #[test]
    fn test_multiplexer_cancel_reports_cancellation() {
        let mux = RequestMultiplexer::new();
        let id = mux.next_id();
        let cancelled = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&cancelled);
        mux.register_request(id, move |result| {
            if let Err(IpcError::RequestCancelled(got)) = result {
                assert_eq!(got, id);
                flag.store(true, Ordering::SeqCst);
            }
        });
        assert_eq!(mux.pending_count(), 1);
        mux.cancel(id);
        assert!(cancelled.load(Ordering::SeqCst));
        assert_eq!(mux.pending_count(), 0);
    }

    #[test]
    fn test_multiplexer_stream_delivery() {
        let mux = RequestMultiplexer::new();
        let received = Arc::new(Mutex::new(Vec::<u8>::new()));
        let ended = Arc::new(AtomicBool::new(false));
        let sink = Arc::clone(&received);
        let done = Arc::clone(&ended);
        mux.register_stream(
            9,
            move |chunk| sink.lock().extend_from_slice(&chunk),
            move || done.store(true, Ordering::SeqCst),
        );
        mux.handle_frame(Frame::stream_data(9, b"ab".to_vec()));
        mux.handle_frame(Frame::stream_data(9, b"cd".to_vec()));
        assert!(!ended.load(Ordering::SeqCst));
        mux.handle_frame(Frame::stream_end(9));
        assert!(ended.load(Ordering::SeqCst));
        assert_eq!(*received.lock(), b"abcd".to_vec());
    }

    #[test]
    fn test_flow_control() {
        let mut fc = FlowControl::new(100);
        assert!(fc.can_send());
        fc.record_sent(50);
        assert!(fc.can_send());
        assert_eq!(fc.outstanding, 50);
        fc.record_sent(60);
        assert!(!fc.can_send());
        assert!(fc.paused);
        fc.record_acked(80);
        assert!(fc.can_send());
        assert!(!fc.paused);
    }

    #[test]
    fn test_error_frame() {
        let frame = Frame::error(42, 500, "Internal error");
        assert_eq!(frame.header.id, 42);
        assert_eq!(frame.header.msg_type, MessageType::Error);
        let error_code = u32::from_le_bytes([
            frame.payload[0],
            frame.payload[1],
            frame.payload[2],
            frame.payload[3],
        ]);
        let message = std::str::from_utf8(&frame.payload[4..]).unwrap();
        assert_eq!(error_code, 500);
        assert_eq!(message, "Internal error");
    }
}