use std::{
collections::VecDeque,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tokio_stream::Stream;
use crate::macros::{GenlFamily, GenlMessage};
use crate::netlink::{
connection::Connection,
genl::{GENL_HDRLEN, GenlMsgHdr},
message::{MessageIter, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST, NlMsgError},
MessageBuilder, ProtocolState,
};
use crate::{Error, Result};
impl<F> Connection<F>
where
F: ProtocolState + GenlFamily,
{
pub fn subscribe_group(&self, name: &str) -> Result<()> {
let id = self.state().mcast_group(name).ok_or_else(|| {
crate::Error::FamilyNotFound {
name: ::std::format!("{}::{}", F::NAME, name),
}
})?;
self.socket().add_membership(id)?;
Ok(())
}
pub async fn send_typed<M, R>(&self, request: M) -> Result<R>
where
M: GenlMessage,
R: GenlMessage + Default,
{
let builder =
build_genl_request::<F, M>(self, &request, NLM_F_REQUEST | NLM_F_ACK)?;
let response = self.send_request(builder).await?;
parse_first_genl_reply::<R>(&response)
}
pub async fn dump_typed_stream<M, R>(
&self,
request: M,
) -> Result<GenlTypedDumpStream<'_, F, R>>
where
M: GenlMessage,
R: GenlMessage + Default + Unpin,
{
let mut builder =
build_genl_request::<F, M>(self, &request, NLM_F_REQUEST | NLM_F_DUMP)?;
let socket = self.socket();
let seq = socket.next_seq();
builder.set_seq(seq);
builder.set_pid(socket.pid());
let msg = builder.finish();
socket.send(&msg).await?;
Ok(GenlTypedDumpStream::new(self, seq))
}
}
/// Starts a netlink message for `request` addressed to the connection's
/// generic-netlink family.
///
/// The message type is the family id taken from the connection state, the
/// netlink flags are `flags`, and the payload is a generic-netlink header
/// (`M::CMD`, `F::VERSION`) followed by whatever `request.to_bytes` emits.
///
/// # Errors
///
/// Propagates any error returned by `request.to_bytes`.
fn build_genl_request<F, M>(
    conn: &Connection<F>,
    request: &M,
    flags: u16,
) -> Result<MessageBuilder>
where
    F: ProtocolState + GenlFamily,
    M: GenlMessage,
{
    let mut msg = MessageBuilder::new(conn.state().family_id(), flags);
    msg.append(&GenlMsgHdr::new(M::CMD, F::VERSION));
    request.to_bytes(&mut msg)?;
    Ok(msg)
}
fn parse_first_genl_reply<R>(response: &[u8]) -> Result<R>
where
R: GenlMessage + Default,
{
for result in MessageIter::new(response) {
let (header, payload) = result?;
if header.is_error() {
let err = NlMsgError::from_bytes(payload)?;
if !err.is_ack() {
return Err(err.into_error(payload));
}
continue;
}
if header.is_done() {
return Ok(R::default());
}
if payload.len() < GENL_HDRLEN {
return Err(Error::InvalidMessage(
"GENL response payload too short for header".into(),
));
}
let attrs = &payload[GENL_HDRLEN..];
return R::from_bytes(attrs);
}
Ok(R::default())
}
/// Stream of typed replies to a generic-netlink dump request.
///
/// Each item is one reply frame decoded into `R`, or an error. The stream
/// ends after a done frame or after a fatal error has been yielded.
#[non_exhaustive]
pub struct GenlTypedDumpStream<
    'a,
    F: ProtocolState + GenlFamily,
    R: GenlMessage + Default + Unpin,
> {
    /// Connection whose socket is polled for incoming frames.
    conn: &'a Connection<F>,
    /// Sequence number of the dump request; frames carrying any other
    /// sequence number are skipped.
    expected_seq: u32,
    /// Decoded items (or errors) waiting to be handed out by `poll_next`.
    pending: VecDeque<Result<R>>,
    /// Set once a done frame for `expected_seq` has been seen.
    done: bool,
    /// Set once a fatal error has been queued or returned.
    errored: bool,
    /// Ties the stream to its item type `R`.
    _marker: PhantomData<fn() -> R>,
}
impl<'a, F, R> GenlTypedDumpStream<'a, F, R>
where
F: ProtocolState + GenlFamily,
R: GenlMessage + Default + Unpin,
{
fn new(conn: &'a Connection<F>, seq: u32) -> Self {
Self {
conn,
expected_seq: seq,
pending: VecDeque::new(),
done: false,
errored: false,
_marker: PhantomData,
}
}
fn drain_into_pending(&mut self, data: &[u8]) {
for result in MessageIter::new(data) {
let (header, payload) = match result {
Ok(p) => p,
Err(e) => {
self.pending.push_back(Err(e));
self.errored = true;
return;
}
};
if header.nlmsg_seq != self.expected_seq {
continue;
}
if header.is_error() {
match NlMsgError::from_bytes(payload) {
Ok(err) => {
if err.is_ack() {
continue;
}
self.pending.push_back(Err(err.into_error(payload)));
self.errored = true;
return;
}
Err(e) => {
self.pending.push_back(Err(e));
self.errored = true;
return;
}
}
}
if header.is_done() {
self.done = true;
return;
}
if payload.len() < GENL_HDRLEN {
self.pending.push_back(Err(Error::InvalidMessage(
"GENL dump frame too short for header".into(),
)));
continue;
}
let attrs = &payload[GENL_HDRLEN..];
self.pending.push_back(R::from_bytes(attrs));
}
}
}
/// Yields the decoded dump replies one at a time.
impl<F, R> Stream for GenlTypedDumpStream<'_, F, R>
where
    F: ProtocolState + GenlFamily,
    R: GenlMessage + Default + Unpin,
{
    type Item = Result<R>;

    /// Returns the next queued item, receiving more data from the socket
    /// when the queue is empty.
    ///
    /// Queued items are always handed out first. Once the queue is empty and
    /// the stream has seen a done frame or a fatal error, `None` is returned.
    /// A receive error from the socket is returned as the final item.
    ///
    /// With the `syscall_batch` feature the socket is polled for up to
    /// `NL_BATCH_SIZE` buffers at a time; otherwise one buffer is received
    /// per poll.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        loop {
            // Hand out anything already decoded before touching the socket.
            if let Some(item) = this.pending.pop_front() {
                return Poll::Ready(Some(item));
            }
            if this.done || this.errored {
                return Poll::Ready(None);
            }

            #[cfg(feature = "syscall_batch")]
            {
                match this
                    .conn
                    .socket()
                    .poll_recv_batch(cx, crate::netlink::NL_BATCH_SIZE)
                {
                    Poll::Ready(Ok(frames)) => {
                        for data in &frames {
                            this.drain_into_pending(data);
                            // Stop at the first terminal condition. Draining
                            // the remaining buffers would queue items after
                            // the final error or after the end of the dump,
                            // which the single-buffer path never does.
                            if this.done || this.errored {
                                break;
                            }
                        }
                    }
                    Poll::Ready(Err(e)) => {
                        this.errored = true;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }

            #[cfg(not(feature = "syscall_batch"))]
            {
                match this.conn.socket().poll_recv(cx) {
                    Poll::Ready(Ok(data)) => this.drain_into_pending(&data),
                    Poll::Ready(Err(e)) => {
                        this.errored = true;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }
            // A successful receive falls through to the top of the loop,
            // which serves the new items or polls again if none matched.
        }
    }
}
/// Declares the stream `Unpin` explicitly. `poll_next` obtains `&mut Self`
/// through `Pin::get_mut`, which is only available for `Unpin` types.
impl<F: ProtocolState + GenlFamily, R: GenlMessage + Default + Unpin> Unpin
    for GenlTypedDumpStream<'_, F, R>
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::macros::__rt;
    use crate::netlink::message::{NlMsgType, NLMSG_HDRLEN};

    const ATTR_ID: u16 = 1;
    const ATTR_LABEL: u16 = 2;

    /// Small message type used to exercise encoding and decoding.
    #[derive(Debug, Default, PartialEq, Eq)]
    struct Reply {
        id: u32,
        label: String,
    }

    impl GenlMessage for Reply {
        const CMD: u8 = 0;

        fn to_bytes(&self, b: &mut MessageBuilder) -> Result<()> {
            __rt::emit_u32_attr(b, ATTR_ID, self.id);
            __rt::emit_str_attr(b, ATTR_LABEL, &self.label);
            Ok(())
        }

        fn from_bytes(payload: &[u8]) -> Result<Self> {
            let mut reply = Reply::default();
            for (ty, p) in __rt::attr_iter(payload) {
                // Unknown attribute types are ignored.
                if ty == ATTR_ID {
                    reply.id = __rt::parse_u32_attr(p)?;
                } else if ty == ATTR_LABEL {
                    reply.label = __rt::parse_str_attr(p)?;
                }
            }
            Ok(reply)
        }
    }

    /// Returns a zero-filled buffer of `len` bytes whose first 16 bytes hold
    /// a netlink header: total length, `msg_type`, zero flags, `seq`, zero pid.
    fn synth_header_only(len: usize, msg_type: &[u8], seq: u32) -> Vec<u8> {
        let mut buf = vec![0u8; len];
        let total = len as u32;
        buf[0..4].copy_from_slice(&total.to_ne_bytes());
        buf[4..6].copy_from_slice(msg_type);
        buf[6..8].copy_from_slice(&0u16.to_ne_bytes());
        buf[8..12].copy_from_slice(&seq.to_ne_bytes());
        buf[12..16].copy_from_slice(&0u32.to_ne_bytes());
        buf
    }

    /// Builds an error frame carrying `code`; the echoed request header that
    /// follows the code is left zeroed.
    fn synth_error_frame(seq: u32, code: i32) -> Vec<u8> {
        let len = NLMSG_HDRLEN + 4 + NLMSG_HDRLEN;
        let mut buf = synth_header_only(len, &NlMsgType::ERROR.to_ne_bytes(), seq);
        buf[16..20].copy_from_slice(&code.to_ne_bytes());
        buf
    }

    /// Builds a reply frame carrying the `id` and `label` attributes.
    fn synth_reply_frame(seq: u32, id: u32, label: &str) -> Vec<u8> {
        let mut b = MessageBuilder::new(0x42, 0);
        b.append(&GenlMsgHdr::new(Reply::CMD, 1));
        __rt::emit_u32_attr(&mut b, ATTR_ID, id);
        __rt::emit_str_attr(&mut b, ATTR_LABEL, label);
        b.set_seq(seq);
        b.finish()
    }

    /// Builds a header-only done frame.
    fn synth_done_frame(seq: u32) -> Vec<u8> {
        synth_header_only(NLMSG_HDRLEN, &NlMsgType::DONE.to_ne_bytes(), seq)
    }

    /// Builds an acknowledgement: an error frame whose code is zero.
    fn synth_ack_frame(seq: u32) -> Vec<u8> {
        synth_error_frame(seq, 0)
    }

    #[test]
    fn parse_first_genl_reply_decodes_real_frame() {
        let frame = synth_reply_frame(7, 0xCAFE_BABE, "hello");
        let parsed = parse_first_genl_reply::<Reply>(&frame).expect("parse");
        assert_eq!(parsed.id, 0xCAFE_BABE);
        assert_eq!(parsed.label, "hello");
    }

    #[test]
    fn parse_first_genl_reply_returns_default_on_nlmsg_done() {
        let frame = synth_done_frame(1);
        let parsed = parse_first_genl_reply::<Reply>(&frame).expect("parse");
        assert_eq!(parsed, Reply::default());
    }

    #[test]
    fn parse_first_genl_reply_skips_pure_ack_then_returns_default() {
        let frame = synth_ack_frame(1);
        let parsed = parse_first_genl_reply::<Reply>(&frame).expect("parse");
        assert_eq!(parsed, Reply::default());
    }

    #[test]
    fn parse_first_genl_reply_consumes_typed_reply_after_ack() {
        let mut frame = synth_ack_frame(1);
        frame.extend_from_slice(&synth_reply_frame(1, 42, "ok"));
        let parsed = parse_first_genl_reply::<Reply>(&frame).expect("parse");
        assert_eq!(parsed.id, 42);
        assert_eq!(parsed.label, "ok");
    }

    #[test]
    fn parse_first_genl_reply_propagates_kernel_error() {
        let buf = synth_error_frame(1, -libc::EINVAL);
        let res: Result<Reply> = parse_first_genl_reply(&buf);
        assert!(res.is_err(), "expected kernel error to propagate");
    }

    #[test]
    fn build_genl_request_emits_correct_header_layout() {
        let family_id: u16 = 0x55;
        let version: u8 = 3;
        let cmd: u8 = Reply::CMD;
        let flags = NLM_F_REQUEST | NLM_F_ACK;

        // Mirrors the steps of `build_genl_request` without a `Connection`.
        let mut b = MessageBuilder::new(family_id, flags);
        b.append(&GenlMsgHdr::new(cmd, version));
        let request = Reply { id: 9, label: "x".into() };
        request.to_bytes(&mut b).expect("emit");
        let bytes = b.finish();

        let wire_type = u16::from_ne_bytes([bytes[4], bytes[5]]);
        let wire_flags = u16::from_ne_bytes([bytes[6], bytes[7]]);
        assert_eq!(wire_type, family_id);
        assert_eq!(wire_flags, flags);
        assert_eq!(bytes[NLMSG_HDRLEN], cmd);
        assert_eq!(bytes[NLMSG_HDRLEN + 1], version);
    }
}