use std::{os::fd::RawFd, sync::Arc};
use msb_krun::backends::net::{NetBackend, ReadError, WriteError};
use crate::shared::SharedState;
const VIRTIO_NET_HDR_LEN: usize = 12;
/// Network backend that moves Ethernet frames between the guest and the
/// TX/RX rings held in [`SharedState`].
pub struct SmoltcpBackend {
    /// Shared rings, wake handles and TX statistics used by the `NetBackend` impl.
    shared: Arc<SharedState>,
}
impl SmoltcpBackend {
pub fn new(shared: Arc<SharedState>) -> Self {
Self { shared }
}
}
impl NetBackend for SmoltcpBackend {
    /// Queues one guest-transmitted frame for the other side of `tx_ring`.
    ///
    /// The first `hdr_len` bytes of `buf` (the virtio-net header) are stripped.
    /// The remaining Ethernet frame is copied into `shared.tx_ring`, its length
    /// is added to the TX byte counter, and `tx_wake` is signalled.
    ///
    /// # Errors
    ///
    /// Returns [`WriteError::NothingWritten`] in two cases:
    ///
    /// - `hdr_len` exceeds `buf.len()`.
    /// - `tx_ring` rejects the frame. NOTE(review): presumably because the ring
    ///   is full; confirm against the ring type.
    ///
    /// A frame that is not queued is not counted in the TX byte statistics.
    fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> {
        // A header length past the end of the buffer would make `buf[hdr_len..]`
        // panic; report it as a failed write instead.
        let ethernet_frame = buf
            .get(hdr_len..)
            .ok_or(WriteError::NothingWritten)?
            .to_vec();
        let frame_len = ethernet_frame.len();
        self.shared
            .tx_ring
            .push(ethernet_frame)
            .map_err(|_| WriteError::NothingWritten)?;
        // Count only after the ring accepted the frame, so a rejected frame
        // (possibly re-sent by the caller later) is never counted as transmitted.
        self.shared.add_tx_bytes(frame_len);
        self.shared.tx_wake.wake();
        Ok(())
    }

    /// Copies the next frame from `shared.rx_ring` into `buf`, if one fits.
    ///
    /// The wake pipe is drained before popping. A frame pushed (and signalled)
    /// after the drain therefore re-arms `rx_wake`, assuming producers push to
    /// the ring before signalling. The consequence is that frames already queued
    /// at drain time are only delivered if the caller keeps calling `read_frame`
    /// until it returns [`ReadError::NothingRead`].
    ///
    /// On success, `buf` starts with a zeroed `VIRTIO_NET_HDR_LEN`-byte header
    /// followed by the frame, and the total number of bytes written is returned.
    /// Frames too large for `buf` are logged at debug level, dropped, and the
    /// next queued frame is tried.
    ///
    /// # Errors
    ///
    /// Returns [`ReadError::NothingRead`] once `rx_ring` has no frame left.
    fn read_frame(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
        self.shared.rx_wake.drain();
        // Do not return `NothingRead` for an oversized frame: the wake pipe is
        // already drained, so stopping here would strand every frame queued
        // behind it until some unrelated wakeup.
        while let Some(frame) = self.shared.rx_ring.pop() {
            let total_len = VIRTIO_NET_HDR_LEN + frame.len();
            if total_len > buf.len() {
                tracing::debug!(
                    frame_len = frame.len(),
                    buf_len = buf.len(),
                    "dropping oversized frame from rx_ring"
                );
                continue;
            }
            buf[..VIRTIO_NET_HDR_LEN].fill(0);
            buf[VIRTIO_NET_HDR_LEN..total_len].copy_from_slice(&frame);
            return Ok(total_len);
        }
        Err(ReadError::NothingRead)
    }

    /// Always `false`: `write_frame` either queues the whole frame or nothing,
    /// so a partial write is never left pending.
    fn has_unfinished_write(&self) -> bool {
        false
    }

    /// No-op, because `write_frame` never leaves a partial write to finish.
    fn try_finish_write(&mut self, _hdr_len: usize, _buf: &[u8]) -> Result<(), WriteError> {
        Ok(())
    }

    /// Returns the descriptor the caller should poll for incoming frames.
    ///
    /// There is no real socket here. The `rx_wake` descriptor becomes readable
    /// when an RX frame is pushed and signalled, and `read_frame` drains it.
    fn raw_socket_fd(&self) -> RawFd {
        self.shared.rx_wake.as_raw_fd()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Reading a frame must consume the pending wake signal, and a later push
    /// must re-arm it.
    #[test]
    fn read_frame_drains_rx_wake_pipe() {
        let state = Arc::new(SharedState::new(4));
        let mut backend = SmoltcpBackend::new(Arc::clone(&state));
        let mut rx_buf = [0u8; 64];

        // A queued frame makes the wake descriptor readable.
        assert!(state.push_rx_frame_and_wake(vec![0xaa, 0xbb]));
        assert!(fd_is_readable(backend.raw_socket_fd()));

        // The frame arrives behind a virtio-net header, and the wake signal is consumed.
        let written = backend.read_frame(&mut rx_buf).expect("frame should be read");
        assert_eq!(written, VIRTIO_NET_HDR_LEN + 2);
        assert_eq!(&rx_buf[VIRTIO_NET_HDR_LEN..written], &[0xaa, 0xbb]);
        assert!(!fd_is_readable(backend.raw_socket_fd()));

        // A fresh push re-arms the wake descriptor.
        assert!(state.push_rx_frame_and_wake(vec![0xcc]));
        assert!(fd_is_readable(backend.raw_socket_fd()));
    }

    /// Non-blocking readability probe: polls `fd` for `POLLIN` with a zero timeout.
    fn fd_is_readable(fd: RawFd) -> bool {
        let mut poll_fd = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };
        // SAFETY: `poll_fd` is a valid, exclusively borrowed `pollfd` that outlives
        // the call, and the entry count passed is exactly one.
        let ready = unsafe { libc::poll(&mut poll_fd, 1, 0) };
        assert!(ready >= 0, "poll failed: {}", std::io::Error::last_os_error());
        let has_pollin = (poll_fd.revents & libc::POLLIN) != 0;
        ready == 1 && has_pollin
    }
}