use std::marker::PhantomData;
use std::sync::Arc;
use crate::backpressure::{BackpressureController, PressureState};
use crate::buffer::{Consumer, Producer, RingBuffer};
use crate::error::{BridgeError, Result};
use crate::eventfd::WakePair;
/// A bidirectional bridge: requests flow control -> data over one ring, responses
/// flow data -> control over a second ring.
///
/// Created by [`BridgeChannel::new`]; move the two public handles out to use them
/// independently.
pub struct BridgeChannel<Req, Rsp> {
/// Control-side endpoint: sends `Req`, receives `Rsp`.
pub control: ControlHandle<Req, Rsp>,
/// Data-side endpoint: receives `Req`, sends `Rsp`.
pub data: DataHandle<Req, Rsp>,
}
/// Control-side endpoint of a [`BridgeChannel`]: pushes requests and pops responses.
///
/// NOTE(review): the fields are public, so pushing or popping through them directly
/// skips the wake notifications and backpressure sampling that the methods perform.
pub struct ControlHandle<Req, Rsp> {
/// Producer end of the request ring.
pub producer: Producer<Req>,
/// Consumer end of the response ring.
pub consumer: Consumer<Rsp>,
/// Backpressure controller shared with the paired [`DataHandle`]; fed from the
/// request ring's utilization as sampled by this handle.
pub backpressure: Arc<BackpressureController>,
/// Wake pair for the request ring, shared with the data side.
pub req_wake: Arc<WakePair>,
/// Wake pair for the response ring, shared with the data side.
pub rsp_wake: Arc<WakePair>,
}
/// Data-side endpoint of a [`BridgeChannel`]: pops requests and pushes responses.
///
/// NOTE(review): the fields are public, so pushing or popping through them directly
/// skips the wake notifications that the methods perform.
pub struct DataHandle<Req, Rsp> {
/// Consumer end of the request ring.
pub consumer: Consumer<Req>,
/// Producer end of the response ring.
pub producer: Producer<Rsp>,
/// Backpressure controller shared with the paired [`ControlHandle`]; this handle
/// only reads it.
pub backpressure: Arc<BackpressureController>,
/// Wake pair for the request ring, shared with the control side.
pub req_wake: Arc<WakePair>,
/// Wake pair for the response ring, shared with the control side.
pub rsp_wake: Arc<WakePair>,
}
/// A [`DataHandle`] wrapper that is neither `Send` nor `Sync`.
///
/// Its fields are private, so it is only built by `DataHandle::pin`; after that it
/// cannot be moved to, or shared with, another thread. Despite the name this has
/// nothing to do with `std::pin::Pin`.
pub struct PinnedDataHandle<Req, Rsp> {
inner: DataHandle<Req, Rsp>,
// Raw pointers are `!Send + !Sync`; this zero-sized marker propagates both to the wrapper.
_not_send: PhantomData<*const ()>,
}
impl<Req, Rsp> BridgeChannel<Req, Rsp> {
    /// Builds a connected control/data handle pair.
    ///
    /// Allocates a request ring of `req_capacity` slots (control -> data) and a
    /// response ring of `rsp_capacity` slots (data -> control). It also creates one
    /// wake pair per ring and one backpressure controller, all shared by both handles.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if either wake pair cannot be created.
    pub fn new(req_capacity: usize, rsp_capacity: usize) -> std::io::Result<Self> {
        let (req_tx, req_rx) = RingBuffer::channel::<Req>(req_capacity);
        let (rsp_tx, rsp_rx) = RingBuffer::channel::<Rsp>(rsp_capacity);
        let req_wake = Arc::new(WakePair::new()?);
        let rsp_wake = Arc::new(WakePair::new()?);
        let backpressure = Arc::new(BackpressureController::default());

        // The control side gets its own references to the shared state; the data side
        // takes ownership of the originals.
        let control = ControlHandle {
            producer: req_tx,
            consumer: rsp_rx,
            backpressure: Arc::clone(&backpressure),
            req_wake: Arc::clone(&req_wake),
            rsp_wake: Arc::clone(&rsp_wake),
        };
        let data = DataHandle {
            consumer: req_rx,
            producer: rsp_tx,
            backpressure,
            req_wake,
            rsp_wake,
        };

        Ok(Self { control, data })
    }
}
impl<Req, Rsp> ControlHandle<Req, Rsp> {
    /// Tries to enqueue `req` on the request ring without blocking.
    ///
    /// After every attempt, successful or not, the request ring's utilization is
    /// re-sampled into the shared backpressure controller, and any state transition
    /// is logged. On success the data side's request wake
    /// (`req_wake.consumer_wake`) is notified.
    ///
    /// # Errors
    ///
    /// Returns the error from the ring's `try_push` unchanged (for example
    /// `BridgeError::Full`). In that case no wake is sent.
    pub fn try_send_request(&mut self, req: Req) -> Result<()> {
        let result = self.producer.try_push(req);
        self.sample_pressure();
        match &result {
            Ok(()) => {
                // Best-effort wake: a failed notify is deliberately ignored.
                let _ = self.req_wake.consumer_wake.notify();
            }
            // Nothing was enqueued, so there is nothing new to signal; every earlier
            // successful push has already notified the consumer.
            Err(BridgeError::Full { .. }) => {}
            Err(_) => {}
        }
        result
    }

    /// Tries to pop one response without blocking.
    ///
    /// On success this notifies the response ring's producer wake
    /// (`rsp_wake.producer_wake`), because a slot was freed. It also re-samples
    /// request-ring pressure: a received response is a natural point at which the
    /// data side may have drained requests, and without re-sampling a
    /// `Throttled`/`Suspended` state could only clear on the next send.
    ///
    /// # Errors
    ///
    /// Returns the error from the ring's `try_pop` unchanged (presumably when no
    /// response is queued).
    pub fn try_recv_response(&mut self) -> Result<Rsp> {
        let result = self.consumer.try_pop();
        if result.is_ok() {
            let _ = self.rsp_wake.producer_wake.notify();
            self.sample_pressure();
        }
        result
    }

    /// Moves up to `max` queued responses into `buf` and returns how many were moved.
    ///
    /// If anything was moved, the response ring's producer wake is notified and
    /// request-ring pressure is re-sampled, as in [`Self::try_recv_response`].
    pub fn drain_responses(&mut self, buf: &mut Vec<Rsp>, max: usize) -> usize {
        let count = self.consumer.drain_into(buf, max);
        if count > 0 {
            let _ = self.rsp_wake.producer_wake.notify();
            self.sample_pressure();
        }
        count
    }

    /// Returns the most recently sampled pressure state without re-sampling.
    ///
    /// Use [`Self::refresh_pressure`] to take a fresh sample first.
    pub fn pressure(&self) -> PressureState {
        self.backpressure.state()
    }

    /// Re-samples request-ring utilization and returns the controller's current state.
    ///
    /// Within this module the shared state only changes when this handle samples
    /// utilization: on sends, on received responses, or here. A caller that stops
    /// sending while `Throttled`/`Suspended` can call this to observe that the data
    /// side has drained requests, e.g. once `request_space_fd` becomes readable.
    pub fn refresh_pressure(&mut self) -> PressureState {
        self.sample_pressure();
        self.backpressure.state()
    }

    /// Raw fd of the response ring's consumer wake.
    ///
    /// `DataHandle::try_send_response` notifies it after each successful push.
    pub fn response_wake_fd(&self) -> std::os::unix::io::RawFd {
        self.rsp_wake.consumer_wake.as_fd()
    }

    /// Raw fd of the request ring's producer wake.
    ///
    /// The data side notifies it whenever it pops requests, i.e. when request-ring
    /// space frees up.
    pub fn request_space_fd(&self) -> std::os::unix::io::RawFd {
        self.req_wake.producer_wake.as_fd()
    }

    /// Feeds the current request-ring utilization into the shared controller and
    /// logs any resulting state transition.
    fn sample_pressure(&mut self) {
        let util = self.producer.utilization();
        if let Some(new_state) = self.backpressure.update(util) {
            tracing::info!(
                utilization = util,
                state = ?new_state,
                "bridge backpressure transition"
            );
        }
    }
}
impl<Req, Rsp> DataHandle<Req, Rsp> {
    /// Consumes this handle and wraps it in a `!Send + !Sync` [`PinnedDataHandle`].
    ///
    /// The wrapper stays on the calling thread.
    pub fn pin(self) -> PinnedDataHandle<Req, Rsp> {
        PinnedDataHandle {
            _not_send: PhantomData,
            inner: self,
        }
    }

    /// Pops one request without blocking.
    ///
    /// On success it signals the request ring's producer wake, because a slot was
    /// freed. Errors from the ring's `try_pop` are returned unchanged.
    pub fn try_recv_request(&mut self) -> Result<Req> {
        self.consumer.try_pop().map(|req| {
            let _ = self.req_wake.producer_wake.notify();
            req
        })
    }

    /// Moves up to `max` queued requests into `buf` and returns how many were moved.
    ///
    /// If at least one was moved, the request ring's producer wake is signalled.
    pub fn drain_requests(&mut self, buf: &mut Vec<Req>, max: usize) -> usize {
        let drained = self.consumer.drain_into(buf, max);
        if drained != 0 {
            let _ = self.req_wake.producer_wake.notify();
        }
        drained
    }

    /// Pushes one response without blocking.
    ///
    /// On success it signals the response ring's consumer wake. Errors from the
    /// ring's `try_push` are returned unchanged.
    pub fn try_send_response(&mut self, rsp: Rsp) -> Result<()> {
        self.producer.try_push(rsp).map(|()| {
            let _ = self.rsp_wake.consumer_wake.notify();
        })
    }

    /// Reads the shared backpressure state.
    ///
    /// NOTE(review): this handle never updates the controller itself, so the value
    /// reflects the control side's most recent utilization sample.
    pub fn pressure(&self) -> PressureState {
        self.backpressure.state()
    }

    /// `true` when the shared state is `Throttled` or `Suspended`.
    pub fn should_throttle(&self) -> bool {
        let state = self.pressure();
        matches!(state, PressureState::Suspended | PressureState::Throttled)
    }

    /// `true` only when the shared state is `Suspended`.
    pub fn should_suspend(&self) -> bool {
        let state = self.pressure();
        state == PressureState::Suspended
    }

    /// Raw fd of the request ring's consumer wake.
    ///
    /// The control side signals it after each successful request push.
    pub fn request_wake_fd(&self) -> std::os::unix::io::RawFd {
        let wake = &self.req_wake.consumer_wake;
        wake.as_fd()
    }

    /// Raw fd of the response ring's producer wake.
    ///
    /// The control side signals it when it pops responses.
    pub fn response_space_fd(&self) -> std::os::unix::io::RawFd {
        let wake = &self.rsp_wake.producer_wake;
        wake.as_fd()
    }
}
impl<Req, Rsp> PinnedDataHandle<Req, Rsp> {
pub fn try_recv_request(&mut self) -> Result<Req> {
self.inner.try_recv_request()
}
pub fn drain_requests(&mut self, buf: &mut Vec<Req>, max: usize) -> usize {
self.inner.drain_requests(buf, max)
}
pub fn try_send_response(&mut self, rsp: Rsp) -> Result<()> {
self.inner.try_send_response(rsp)
}
pub fn pressure(&self) -> PressureState {
self.inner.pressure()
}
pub fn should_throttle(&self) -> bool {
self.inner.should_throttle()
}
pub fn should_suspend(&self) -> bool {
self.inner.should_suspend()
}
pub fn request_wake_fd(&self) -> std::os::unix::io::RawFd {
self.inner.request_wake_fd()
}
pub fn response_space_fd(&self) -> std::os::unix::io::RawFd {
self.inner.response_space_fd()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One request and one response survive a full round trip unchanged.
    #[test]
    fn bridge_channel_roundtrip() {
        let BridgeChannel {
            mut control,
            mut data,
        } = BridgeChannel::<u64, String>::new(16, 16).unwrap();

        control.try_send_request(42).unwrap();
        assert_eq!(data.try_recv_request().unwrap(), 42);

        data.try_send_response("result".to_string()).unwrap();
        assert_eq!(control.try_recv_response().unwrap(), "result");
    }

    /// Filling 14 of 16 request slots moves the shared state to `Throttled`.
    #[test]
    fn backpressure_updates_on_send() {
        // `_data` (not `_`) keeps the data handle alive until the end of the test.
        let BridgeChannel {
            data: _data,
            mut control,
        } = BridgeChannel::<u64, u64>::new(16, 16).unwrap();

        (0..14).for_each(|n| control.try_send_request(n).unwrap());

        assert_eq!(control.pressure(), PressureState::Throttled);
    }

    /// A successful request push leaves a readable count on the data side's wake.
    #[test]
    fn eventfd_wake_on_push() {
        let BridgeChannel { mut control, data } =
            BridgeChannel::<u64, u64>::new(16, 16).unwrap();

        control.try_send_request(1).unwrap();

        let signalled = data.req_wake.consumer_wake.try_read().unwrap();
        assert!(signalled > 0);
    }

    /// Draining responses returns them in FIFO order along with the count.
    #[test]
    fn drain_responses_signals_producer() {
        let BridgeChannel {
            mut control,
            mut data,
        } = BridgeChannel::<u64, u64>::new(16, 16).unwrap();

        for &value in &[10, 20, 30] {
            data.try_send_response(value).unwrap();
        }

        let mut drained = Vec::new();
        let moved = control.drain_responses(&mut drained, 10);
        assert_eq!(moved, 3);
        assert_eq!(drained, vec![10, 20, 30]);
    }

    /// The data side sees Normal -> Throttled (14/16) -> Suspended (16/16).
    #[test]
    fn data_handle_throttle_queries() {
        let BridgeChannel { mut control, data } =
            BridgeChannel::<u64, u64>::new(16, 16).unwrap();

        assert!(!data.should_throttle());
        assert!(!data.should_suspend());

        (0..14).for_each(|n| control.try_send_request(n).unwrap());
        assert!(data.should_throttle());
        assert!(!data.should_suspend());

        for n in 14..16 {
            control.try_send_request(n).unwrap();
        }
        assert!(data.should_throttle());
        assert!(data.should_suspend());
    }
}