use crate::runtime::graph::Graph;
use crate::runtime::scheduler::Scheduler;
pub use mio::{Interest, event::Source};
use petgraph::prelude::NodeIndex;
use slab::Slab;
use std::io;
use std::sync::Arc;
/// An event source that has been registered with an [`IoDriver`], paired with
/// the registry token the driver assigned to it.
///
/// Instances are created by [`IoDriver::register_source`] and are passed back to
/// [`IoDriver::reregister_source`] and [`IoDriver::deregister_source`].
pub struct IoSource<S>
where
    S: Source,
{
    /// The wrapped mio event source.
    source: S,
    /// Token under which `source` is registered; doubles as the driver's slab key.
    token: mio::Token,
}
impl<S> IoSource<S>
where
    S: Source,
{
    /// Pairs `source` with the registry `token` it was registered under.
    const fn new(source: S, token: mio::Token) -> Self {
        Self { token, source }
    }

    /// Borrows the wrapped source immutably.
    pub const fn source(&self) -> &S {
        let inner = &self.source;
        inner
    }

    /// Borrows the wrapped source mutably.
    pub const fn source_mut(&mut self) -> &mut S {
        let inner = &mut self.source;
        inner
    }
}
/// Readiness driver built on [`mio::Poll`]: each registered source is mapped to a
/// graph node, and readiness events on that source schedule the node.
pub struct IoDriver {
/// OS-level poller that every source and the waker are registered with.
poller: mio::Poll,
/// Event buffer reused by every call to `poll`.
events: mio::Events,
/// Maps each registration token (the slab key) to the graph node it wakes.
indices: Slab<NodeIndex>,
/// Waker registered on `poller`; handed out as a shared handle by `waker()`.
waker: Arc<mio::Waker>,
}
impl IoDriver {
    /// Token reserved for the driver's [`mio::Waker`].
    ///
    /// Slab keys are dense indices starting at zero, so a registered source can
    /// never realistically be assigned this value.
    const WAKER_TOKEN: mio::Token = mio::Token(usize::MAX);

    /// Creates a driver whose event buffer and token slab are pre-sized for
    /// `capacity` entries.
    ///
    /// # Panics
    /// Panics if the OS poller or the waker cannot be created.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        let poller = mio::Poll::new().expect("failed to create mio poll");
        let waker = Arc::new(
            mio::Waker::new(poller.registry(), Self::WAKER_TOKEN)
                .expect("failed to create waker"),
        );
        Self {
            poller,
            events: mio::Events::with_capacity(capacity),
            indices: Slab::with_capacity(capacity),
            waker,
        }
    }

    /// Registers `source` for `interest` under a freshly allocated token and
    /// associates that token with graph node `idx`.
    ///
    /// The slab slot is only claimed after mio accepts the registration, so a
    /// failed call does not consume a token.
    ///
    /// # Errors
    /// Returns any error reported by [`mio::Registry::register`].
    #[inline]
    pub fn register_source<S: Source>(
        &mut self,
        mut source: S,
        idx: NodeIndex,
        interest: Interest,
    ) -> io::Result<IoSource<S>> {
        let entry = self.indices.vacant_entry();
        let token = mio::Token(entry.key());
        self.poller
            .registry()
            .register(&mut source, token, interest)?;
        entry.insert(idx);
        Ok(IoSource::new(source, token))
    }

    /// Deregisters `source` from the poller and releases its token, returning
    /// the graph node it was associated with.
    ///
    /// # Errors
    /// Returns any error reported by [`mio::Registry::deregister`]. In that case
    /// the token's slab entry is left in place. NOTE(review): the `IoSource` has
    /// been consumed by then, so that slot can no longer be released.
    ///
    /// # Panics
    /// Panics if the token has no slab entry (`Slab::remove` on a vacant key).
    #[inline]
    pub fn deregister_source<S: Source>(
        &mut self,
        mut source: IoSource<S>,
    ) -> io::Result<NodeIndex> {
        self.poller.registry().deregister(&mut source.source)?;
        Ok(self.indices.remove(source.token.0))
    }

    /// Replaces the readiness `interest` of an already-registered source,
    /// keeping its token (and therefore its node association) unchanged.
    ///
    /// # Errors
    /// Returns any error reported by [`mio::Registry::reregister`].
    #[inline]
    pub fn reregister_source<S: Source>(
        &mut self,
        source: &mut IoSource<S>,
        interest: Interest,
    ) -> io::Result<()> {
        self.poller
            .registry()
            .reregister(&mut source.source, source.token, interest)
    }

    /// Returns a shared handle to the driver's waker.
    ///
    /// Waking it produces an event that [`IoDriver::poll`] skips, so it
    /// schedules no node.
    #[inline]
    pub(crate) fn waker(&self) -> Arc<mio::Waker> {
        Arc::clone(&self.waker)
    }

    /// Waits up to `timeout` for readiness events and schedules the node
    /// associated with each ready source.
    ///
    /// A node is pushed onto `scheduler` only when
    /// `graph.can_schedule(node, epoch)` returns its depth. Waker events and
    /// tokens without a slab entry are ignored.
    ///
    /// # Errors
    /// Returns errors from [`mio::Poll::poll`], except `Interrupted`. That error
    /// means a signal cut the wait short and is treated as a wake-up with no events.
    ///
    /// # Panics
    /// Panics if `scheduler` rejects a node that the graph reported as schedulable.
    #[inline]
    pub(crate) fn poll(
        &mut self,
        graph: &mut Graph,
        scheduler: &mut Scheduler,
        timeout: Option<std::time::Duration>,
        epoch: usize,
    ) -> io::Result<()> {
        self.events.clear();
        match self.poller.poll(&mut self.events, timeout) {
            Ok(()) => {}
            Err(err) if err.kind() == io::ErrorKind::Interrupted => return Ok(()),
            Err(err) => return Err(err),
        }
        for event in self.events.iter() {
            let token = event.token();
            if token == Self::WAKER_TOKEN {
                // The waker only interrupts the wait; there is no node behind it.
                continue;
            }
            // Skip tokens without a slab entry instead of panicking on the index.
            let Some(&node_index) = self.indices.get(token.0) else {
                continue;
            };
            if let Some(depth) = graph.can_schedule(node_index, epoch) {
                let _ = scheduler
                    .schedule(node_index, depth)
                    .expect("scheduler rejected a node the graph reported as schedulable");
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Control;
    use crate::runtime::graph::NodeContext;
    use mio::net::TcpListener;
    use std::net::SocketAddr;
    use std::time::Duration;

    /// Binds a mio listener on an OS-assigned loopback port.
    ///
    /// Port 0 lets the OS choose a free port, so tests never collide with each
    /// other or with services already listening on the machine.
    fn create_test_listener() -> io::Result<TcpListener> {
        TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
    }

    #[test]
    fn test_io_driver_creation() {
        let driver = IoDriver::with_capacity(64);
        // `Slab::with_capacity` only guarantees *at least* the requested capacity.
        assert!(driver.indices.capacity() >= 64);
        assert!(driver.indices.is_empty());
    }

    #[test]
    fn test_register_source() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let node_idx = NodeIndex::from(100);
        let listener = create_test_listener()?;
        let io_source = driver.register_source(listener, node_idx, Interest::READABLE)?;
        assert_eq!(driver.indices[io_source.token.0], node_idx);
        assert!(io_source.source().local_addr().is_ok());
        Ok(())
    }

    #[test]
    fn test_deregister_source() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let node_idx = NodeIndex::from(100);
        let listener = create_test_listener()?;
        let io_source = driver.register_source(listener, node_idx, Interest::READABLE)?;
        let token_key = io_source.token.0;
        assert!(driver.indices.contains(token_key));
        let returned_idx = driver.deregister_source(io_source)?;
        assert_eq!(returned_idx, node_idx);
        assert!(!driver.indices.contains(token_key));
        Ok(())
    }

    #[test]
    fn test_reregister_source() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let node_idx = NodeIndex::from(100);
        let listener = create_test_listener()?;
        let mut io_source = driver.register_source(listener, node_idx, Interest::READABLE)?;
        let token_key = io_source.token.0;
        driver.reregister_source(&mut io_source, Interest::WRITABLE)?;
        // Re-registration keeps the same token and node association.
        assert_eq!(io_source.token.0, token_key);
        assert_eq!(driver.indices[io_source.token.0], node_idx);
        Ok(())
    }

    #[test]
    fn test_multiple_registrations() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let listener1 = create_test_listener()?;
        let listener2 = create_test_listener()?;
        let source1 = driver.register_source(listener1, NodeIndex::from(1), Interest::READABLE)?;
        let source2 = driver.register_source(listener2, NodeIndex::from(2), Interest::READABLE)?;
        assert_ne!(source1.token.0, source2.token.0);
        assert_eq!(driver.indices[source1.token.0], NodeIndex::from(1));
        assert_eq!(driver.indices[source2.token.0], NodeIndex::from(2));
        Ok(())
    }

    #[test]
    fn test_notifier_notify() -> io::Result<()> {
        let driver = IoDriver::with_capacity(64);
        let notifier = driver.waker();
        notifier.wake()?;
        Ok(())
    }

    #[test]
    fn test_io_source_access() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let node_idx = NodeIndex::from(100);
        let listener = create_test_listener()?;
        let original_addr = listener.local_addr()?;
        let mut io_source = driver.register_source(listener, node_idx, Interest::READABLE)?;
        assert_eq!(io_source.source().local_addr()?, original_addr);
        assert_eq!(io_source.source_mut().local_addr()?, original_addr);
        Ok(())
    }

    #[test]
    fn test_poll_without_readiness_schedules_nothing() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let mut graph = Graph::new();
        let mut scheduler = Scheduler::new();
        scheduler.resize(5);
        let node1_idx = graph.add_node(NodeContext::new(Box::new(|_| Control::Unchanged), 1));
        let node2_idx = graph.add_node(NodeContext::new(Box::new(|_| Control::Unchanged), 3));
        let listener1 = create_test_listener()?;
        let listener2 = create_test_listener()?;
        let _source1 = driver.register_source(listener1, node1_idx, Interest::READABLE)?;
        let _source2 = driver.register_source(listener2, node2_idx, Interest::READABLE)?;
        driver.poll(&mut graph, &mut scheduler, Some(Duration::ZERO), 1)?;
        // Nobody has connected, so neither listener is readable.
        assert!(scheduler.pop().is_none());
        Ok(())
    }

    #[test]
    fn test_can_schedule_deduplicates_within_epoch() {
        // `IoDriver::poll` relies on this gate to avoid scheduling a node twice per epoch.
        let mut graph = Graph::new();
        let node_idx = graph.add_node(NodeContext::new(Box::new(|_| Control::Unchanged), 2));
        assert_eq!(graph.can_schedule(node_idx, 1), Some(2));
        assert_eq!(graph.can_schedule(node_idx, 1), None);
    }

    #[test]
    fn test_slab_token_consistency() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let listener1 = create_test_listener()?;
        let listener2 = create_test_listener()?;
        let io_source1 =
            driver.register_source(listener1, NodeIndex::from(2), Interest::READABLE)?;
        let io_source2 =
            driver.register_source(listener2, NodeIndex::from(3), Interest::READABLE)?;
        assert_eq!(driver.indices[io_source1.token.0], NodeIndex::from(2));
        assert_eq!(driver.indices[io_source2.token.0], NodeIndex::from(3));
        let returned_idx1 = driver.deregister_source(io_source1)?;
        let returned_idx2 = driver.deregister_source(io_source2)?;
        assert_eq!(returned_idx1, NodeIndex::from(2));
        assert_eq!(returned_idx2, NodeIndex::from(3));
        // Every token has been released back to the slab.
        assert!(driver.indices.is_empty());
        Ok(())
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::Control;
    use crate::runtime::graph::NodeContext;
    use std::net::{SocketAddr, TcpListener, TcpStream};
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for readiness that should arrive almost at once.
    const READY_DEADLINE: Duration = Duration::from_secs(5);

    /// Binds a loopback listener on an OS-assigned port and converts it to mio.
    ///
    /// `mio::net::TcpListener::from_std` leaves the socket's blocking mode untouched,
    /// so the socket is switched to non-blocking first, as mio requires.
    fn bind_listener() -> io::Result<(mio::net::TcpListener, SocketAddr)> {
        let listener = TcpListener::bind("127.0.0.1:0")?;
        listener.set_nonblocking(true)?;
        let addr = listener.local_addr()?;
        Ok((mio::net::TcpListener::from_std(listener), addr))
    }

    /// Polls at `epoch` until at least `expected` nodes have been scheduled or
    /// `READY_DEADLINE` passes, draining the scheduler after every poll.
    ///
    /// Looping, instead of sleeping for a fixed time, keeps the tests fast on idle
    /// machines and stable on loaded ones.
    fn poll_until_scheduled(
        driver: &mut IoDriver,
        graph: &mut Graph,
        scheduler: &mut Scheduler,
        epoch: usize,
        expected: usize,
    ) -> io::Result<Vec<NodeIndex>> {
        let deadline = Instant::now() + READY_DEADLINE;
        let mut scheduled = Vec::new();
        while scheduled.len() < expected && Instant::now() < deadline {
            driver.poll(graph, scheduler, Some(Duration::from_millis(50)), epoch)?;
            scheduled.extend(std::iter::from_fn(|| scheduler.pop()));
        }
        Ok(scheduled)
    }

    #[test]
    fn test_tcp_stream_polling_integration() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let mut graph = Graph::new();
        let mut scheduler = Scheduler::new();
        scheduler.resize(5);
        let (listener, listener_addr) = bind_listener()?;
        let node_idx = graph.add_node(NodeContext::new(Box::new(|_| Control::Unchanged), 1));
        let _io_source = driver.register_source(listener, node_idx, Interest::READABLE)?;

        driver.poll(&mut graph, &mut scheduler, Some(Duration::ZERO), 1)?;
        assert!(
            scheduler.pop().is_none(),
            "No events should be ready initially"
        );

        // The pending connection makes the listener readable. The client is kept
        // alive until the end of the test so the connection is not torn down early.
        let _client = TcpStream::connect(listener_addr)?;
        let scheduled = poll_until_scheduled(&mut driver, &mut graph, &mut scheduler, 2, 1)?;
        assert_eq!(
            scheduled,
            vec![node_idx],
            "Node should be scheduled after TCP connection"
        );
        assert!(scheduler.pop().is_none());
        Ok(())
    }

    #[test]
    fn test_multiple_tcp_events() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let mut graph = Graph::new();
        let mut scheduler = Scheduler::new();
        scheduler.resize(5);
        let (listener1, addr1) = bind_listener()?;
        let (listener2, addr2) = bind_listener()?;
        let node1_idx = graph.add_node(NodeContext::new(Box::new(|_| Control::Unchanged), 1));
        let node2_idx = graph.add_node(NodeContext::new(Box::new(|_| Control::Unchanged), 2));
        let _source1 = driver.register_source(listener1, node1_idx, Interest::READABLE)?;
        let _source2 = driver.register_source(listener2, node2_idx, Interest::READABLE)?;

        // Keep both clients alive for the duration of the test.
        let _client1 = TcpStream::connect(addr1)?;
        let _client2 = TcpStream::connect(addr2)?;
        let scheduled = poll_until_scheduled(&mut driver, &mut graph, &mut scheduler, 1, 2)?;
        assert_eq!(scheduled.len(), 2);
        assert!(scheduled.contains(&node1_idx));
        assert!(scheduled.contains(&node2_idx));
        Ok(())
    }

    #[test]
    fn test_notifier_polling() -> io::Result<()> {
        let mut driver = IoDriver::with_capacity(64);
        let mut graph = Graph::new();
        let mut scheduler = Scheduler::new();
        scheduler.resize(5);
        let notifier = driver.waker();
        driver.poll(&mut graph, &mut scheduler, Some(Duration::ZERO), 1)?;
        assert!(scheduler.pop().is_none());
        notifier.wake()?;
        driver.poll(
            &mut graph,
            &mut scheduler,
            Some(Duration::from_millis(10)),
            2,
        )?;
        // The waker interrupts the wait but maps to no node, so nothing is scheduled.
        assert!(scheduler.pop().is_none());
        Ok(())
    }
}