use std::{
collections::VecDeque,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tokio_stream::Stream;
use super::{
builder::MessageBuilder,
connection::Connection,
error::Result,
message::{MessageIter, NLM_F_DUMP, NLM_F_REQUEST, NlMsgError},
parse::FromNetlink,
protocol::ProtocolState,
};
/// Stream of decoded entries produced by a single netlink dump request.
///
/// Built by [`Connection::dump_stream`] or [`Connection::dump_stream_with_body`].
/// Only replies whose sequence number matches the request are consumed.
/// The stream ends after the matching `NLMSG_DONE`, or after the first fatal error.
#[non_exhaustive]
pub struct DumpStream<'a, P, T>
where
    P: ProtocolState,
    T: FromNetlink + Unpin,
{
    /// Connection the request was sent on; replies are read from its socket.
    conn: &'a Connection<P>,
    /// Sequence number stamped on the request; replies carrying any other value are skipped.
    expected_seq: u32,
    /// Decoded items and errors waiting to be yielded, in arrival order.
    pending: VecDeque<Result<T>>,
    /// Set once the `NLMSG_DONE` message for this request has been seen.
    done: bool,
    /// Set once a fatal error has been queued or returned; no further data is consumed.
    errored: bool,
    /// Type marker for `T`; the `fn() -> T` form is covariant in `T` and does not
    /// itself influence `Send`/`Sync`.
    _marker: PhantomData<fn() -> T>,
}
impl<'a, P: ProtocolState, T: FromNetlink + Unpin> DumpStream<'a, P, T> {
    /// Starts a dump of `msg_type`, using `T`'s dump header as the request body.
    ///
    /// The body bytes are whatever `T::write_dump_header` writes; the request is then
    /// built and sent exactly as for [`DumpStream::send_with_body`].
    ///
    /// # Errors
    ///
    /// Returns the error reported by the socket if sending the request fails.
    pub(crate) async fn send(conn: &'a Connection<P>, msg_type: u16) -> Result<Self> {
        let mut header_buf = Vec::new();
        T::write_dump_header(&mut header_buf);
        Self::send_with_body_bytes(conn, msg_type, &header_buf).await
    }

    /// Starts a dump of `msg_type` with a caller-supplied request `body`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the socket if sending the request fails.
    pub(crate) async fn send_with_body(
        conn: &'a Connection<P>,
        msg_type: u16,
        body: &[u8],
    ) -> Result<Self> {
        Self::send_with_body_bytes(conn, msg_type, body).await
    }

    /// Builds and sends an `NLM_F_REQUEST | NLM_F_DUMP` message of type `msg_type`.
    ///
    /// `body` is appended only when non-empty. The message is stamped with the socket's
    /// next sequence number and its pid. The returned stream only accepts replies
    /// carrying that sequence number.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the socket if sending the request fails.
    async fn send_with_body_bytes(
        conn: &'a Connection<P>,
        msg_type: u16,
        body: &[u8],
    ) -> Result<Self> {
        let mut builder = MessageBuilder::new(msg_type, NLM_F_REQUEST | NLM_F_DUMP);
        if !body.is_empty() {
            builder.append_bytes(body);
        }
        let socket = conn.socket();
        let seq = socket.next_seq();
        builder.set_seq(seq);
        builder.set_pid(socket.pid());
        let msg = builder.finish();
        socket.send(&msg).await?;
        Ok(Self {
            conn,
            expected_seq: seq,
            pending: VecDeque::new(),
            done: false,
            errored: false,
            _marker: PhantomData,
        })
    }

    /// Parses every netlink message in `data` and queues the results for `poll_next`.
    ///
    /// - Messages with a foreign sequence number are skipped.
    /// - ACK error messages (per `NlMsgError::is_ack`) are skipped.
    /// - `NLMSG_DONE` marks the stream finished and stops processing the buffer.
    /// - A framing error, or a non-ACK error message, queues an `Err`, marks the stream
    ///   errored and stops processing.
    /// - Any other message is decoded as `T`. A decode failure is queued as an `Err` but
    ///   does not end the stream.
    ///
    /// Once the stream is finished or errored, this is a no-op.
    fn drain_into_pending(&mut self, data: &[u8]) {
        // The batched receive path drains several frames in a row. Without this guard,
        // items from frames after a terminal message could be queued behind an `Err`
        // and yielded after the error.
        if self.done || self.errored {
            return;
        }
        for result in MessageIter::new(data) {
            let (header, payload) = match result {
                Ok(parts) => parts,
                Err(e) => {
                    self.pending.push_back(Err(e));
                    self.errored = true;
                    return;
                }
            };
            if header.nlmsg_seq != self.expected_seq {
                continue;
            }
            if header.is_error() {
                // ACKs are not failures; anything else (including an unparsable
                // error payload) ends the dump.
                let failure = match NlMsgError::from_bytes(payload) {
                    Ok(err) if err.is_ack() => continue,
                    Ok(err) => err.into_error(payload),
                    Err(e) => e,
                };
                self.pending.push_back(Err(failure));
                self.errored = true;
                return;
            }
            if header.is_done() {
                // NOTE(review): the kernel may place an errno in the NLMSG_DONE payload
                // to report a failed dump; it is currently ignored. Confirm whether
                // callers need it surfaced.
                self.done = true;
                return;
            }
            // Per-item decode errors are reported but leave the stream running.
            self.pending.push_back(T::from_bytes(payload));
        }
    }
}
impl<P: ProtocolState, T: FromNetlink + Unpin> Stream for DumpStream<'_, P, T> {
    type Item = Result<T>;

    /// Yields the next decoded dump entry or error.
    ///
    /// Buffered results are always handed out before the terminal flags are consulted.
    /// Entries that arrived ahead of `NLMSG_DONE` or an error in the same receive are
    /// therefore still delivered. Once the buffer is empty and the dump has finished or
    /// failed, every poll returns `Ready(None)`. A socket error is returned once and then
    /// ends the stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();
        loop {
            if let Some(next) = stream.pending.pop_front() {
                return Poll::Ready(Some(next));
            }
            if stream.done || stream.errored {
                return Poll::Ready(None);
            }

            // Nothing buffered and not finished: read more from the socket, then go
            // back to the checks above.
            #[cfg(feature = "syscall_batch")]
            {
                match stream
                    .conn
                    .socket()
                    .poll_recv_batch(cx, crate::netlink::socket::NL_BATCH_SIZE)
                {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(err)) => {
                        stream.errored = true;
                        return Poll::Ready(Some(Err(err)));
                    }
                    Poll::Ready(Ok(frames)) => {
                        for frame in &frames {
                            stream.drain_into_pending(frame);
                        }
                    }
                }
            }

            #[cfg(not(feature = "syscall_batch"))]
            {
                match stream.conn.socket().poll_recv(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(err)) => {
                        stream.errored = true;
                        return Poll::Ready(Some(Err(err)));
                    }
                    Poll::Ready(Ok(data)) => stream.drain_into_pending(&data),
                }
            }
        }
    }
}
// `poll_next` calls `Pin::get_mut`, which requires `Self: Unpin`. This impl states
// that explicitly for every `T: Unpin`, instead of relying on the auto-trait
// derivation through each field type.
impl<P, T> Unpin for DumpStream<'_, P, T>
where
    P: ProtocolState,
    T: FromNetlink + Unpin,
{
}
impl<P: ProtocolState> Connection<P> {
    /// Sends a dump request of type `msg_type` and returns a stream of decoded `T`s.
    ///
    /// The request body is the dump header that `T::write_dump_header` produces.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the request cannot be sent.
    pub async fn dump_stream<T: FromNetlink + Unpin>(
        &self,
        msg_type: u16,
    ) -> Result<DumpStream<'_, P, T>> {
        DumpStream::<P, T>::send(self, msg_type).await
    }

    /// Sends a dump request of type `msg_type` with an explicit request `body`.
    ///
    /// Use this when the default dump header of `T` is not the right filter.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the request cannot be sent.
    pub async fn dump_stream_with_body<T: FromNetlink + Unpin>(
        &self,
        msg_type: u16,
        body: &[u8],
    ) -> Result<DumpStream<'_, P, T>> {
        DumpStream::<P, T>::send_with_body(self, msg_type, body).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::netlink::message::NLMSG_HDRLEN;

    /// Trivial `FromNetlink` implementation that accepts any payload.
    #[derive(Debug, PartialEq)]
    struct Dummy;

    impl FromNetlink for Dummy {
        fn parse(_input: &mut &[u8]) -> super::super::parse::PResult<Self> {
            Ok(Dummy)
        }
    }

    /// Builds a stream expecting sequence number 1, without sending any request.
    fn make_stream<'a>(
        conn: &'a Connection<crate::netlink::Route>,
    ) -> DumpStream<'a, crate::netlink::Route, Dummy> {
        DumpStream {
            conn,
            expected_seq: 1,
            pending: VecDeque::new(),
            done: false,
            errored: false,
            _marker: PhantomData,
        }
    }

    /// Builds a header-only `NLMSG_DONE` message carrying sequence number `seq`.
    fn synth_done_frame(seq: u32) -> Vec<u8> {
        const DONE_TYPE: u16 = 3;
        let mut frame = Vec::with_capacity(NLMSG_HDRLEN);
        frame.extend_from_slice(&(NLMSG_HDRLEN as u32).to_ne_bytes()); // nlmsg_len
        frame.extend_from_slice(&DONE_TYPE.to_ne_bytes()); // nlmsg_type
        frame.extend_from_slice(&0u16.to_ne_bytes()); // nlmsg_flags
        frame.extend_from_slice(&seq.to_ne_bytes()); // nlmsg_seq
        frame.extend_from_slice(&0u32.to_ne_bytes()); // nlmsg_pid
        // Zero-pad to the full header length, matching a zero-initialised buffer.
        frame.resize(NLMSG_HDRLEN, 0);
        frame
    }

    #[tokio::test]
    async fn drain_recognizes_nlmsg_done() {
        let conn = Connection::<crate::netlink::Route>::new().unwrap();
        let mut stream = make_stream(&conn);
        stream.drain_into_pending(&synth_done_frame(1));
        // A DONE for our sequence finishes the dump cleanly and queues nothing.
        assert!(stream.done);
        assert!(!stream.errored);
        assert!(stream.pending.is_empty());
    }

    #[tokio::test]
    async fn drain_skips_mismatched_seq() {
        let conn = Connection::<crate::netlink::Route>::new().unwrap();
        let mut stream = make_stream(&conn);
        stream.drain_into_pending(&synth_done_frame(42));
        // A DONE for some other request must be ignored entirely.
        assert!(!stream.done);
        assert!(!stream.errored);
        assert!(stream.pending.is_empty());
    }
}