use super::{
addr::AddressConfig,
builder::MessageBuilder,
connection::Connection,
error::{Error, Result},
fdb::FdbEntryBuilder,
link::LinkConfig,
message::{
MessageIter, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, NlMsgError, NlMsgType,
},
neigh::NeighborConfig,
protocol::Route,
route::RouteConfig,
tc::QdiscConfig,
types::{
link::IfInfoMsg,
tc::{TcMsg, TcaAttr, tc_handle},
},
};
const MAX_BATCH_SIZE: usize = 200 * 1024;
pub struct Batch<'a> {
conn: &'a Connection<Route>,
ops: Vec<BatchOp>,
}
struct BatchOp {
seq: u32,
msg: Vec<u8>,
}
impl<'a> Batch<'a> {
pub(crate) fn new(conn: &'a Connection<Route>) -> Self {
Self {
conn,
ops: Vec::new(),
}
}
pub fn add_route<R: RouteConfig>(mut self, config: R) -> Self {
let mut builder = MessageBuilder::new(
NlMsgType::RTM_NEWROUTE,
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
config.write_add(&mut builder, &Default::default());
self.push(builder);
self
}
pub fn del_route<R: RouteConfig>(mut self, config: R) -> Self {
let mut builder = MessageBuilder::new(NlMsgType::RTM_DELROUTE, NLM_F_REQUEST | NLM_F_ACK);
config.write_delete(&mut builder);
self.push(builder);
self
}
pub fn add_link<L: LinkConfig>(mut self, config: L) -> Self {
let mut builder = MessageBuilder::new(
NlMsgType::RTM_NEWLINK,
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
let ifinfo = IfInfoMsg::new();
builder.append(&ifinfo);
config.write_to(&mut builder, None);
self.push(builder);
self
}
pub fn del_link_by_index(mut self, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(NlMsgType::RTM_DELLINK, NLM_F_REQUEST | NLM_F_ACK);
let mut ifinfo = IfInfoMsg::new();
ifinfo.ifi_index = ifindex as i32;
builder.append(&ifinfo);
self.push(builder);
self
}
pub fn add_address<A: AddressConfig>(mut self, config: A, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(
NlMsgType::RTM_NEWADDR,
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
if config.write_add(&mut builder, ifindex).is_ok() {
self.push(builder);
}
self
}
pub fn del_address<A: AddressConfig>(mut self, config: A, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(NlMsgType::RTM_DELADDR, NLM_F_REQUEST | NLM_F_ACK);
if config.write_delete(&mut builder, ifindex).is_ok() {
self.push(builder);
}
self
}
pub fn add_neighbor<N: NeighborConfig>(mut self, config: N, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(
NlMsgType::RTM_NEWNEIGH,
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
if config.write_add(&mut builder, ifindex).is_ok() {
self.push(builder);
}
self
}
pub fn del_neighbor<N: NeighborConfig>(mut self, config: N, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(NlMsgType::RTM_DELNEIGH, NLM_F_REQUEST | NLM_F_ACK);
if config.write_delete(&mut builder, ifindex).is_ok() {
self.push(builder);
}
self
}
pub fn add_fdb(
mut self,
entry: FdbEntryBuilder,
ifindex: u32,
master_idx: Option<u32>,
) -> Self {
let mut builder = MessageBuilder::new(
NlMsgType::RTM_NEWNEIGH,
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
entry.write_add(&mut builder, ifindex, master_idx);
self.push(builder);
self
}
pub fn del_fdb(mut self, entry: FdbEntryBuilder, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(NlMsgType::RTM_DELNEIGH, NLM_F_REQUEST | NLM_F_ACK);
entry.write_delete(&mut builder, ifindex);
self.push(builder);
self
}
pub fn add_qdisc(mut self, ifindex: u32, config: impl QdiscConfig) -> Self {
let mut builder = MessageBuilder::new(
NlMsgType::RTM_NEWQDISC,
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE,
);
let tcmsg = TcMsg::new()
.with_ifindex(ifindex as i32)
.with_parent(tc_handle::ROOT)
.with_handle(config.default_handle().unwrap_or(0));
builder.append(&tcmsg);
builder.append_attr_str(TcaAttr::Kind as u16, config.kind());
let options_token = builder.nest_start(TcaAttr::Options as u16);
if config.write_options(&mut builder).is_ok() {
builder.nest_end(options_token);
self.push(builder);
}
self
}
pub fn del_qdisc(mut self, ifindex: u32) -> Self {
let mut builder = MessageBuilder::new(NlMsgType::RTM_DELQDISC, NLM_F_REQUEST | NLM_F_ACK);
let tcmsg = TcMsg::new()
.with_ifindex(ifindex as i32)
.with_parent(tc_handle::ROOT);
builder.append(&tcmsg);
self.push(builder);
self
}
fn push(&mut self, mut builder: MessageBuilder) {
let seq = self.conn.socket().next_seq();
builder.set_seq(seq);
builder.set_pid(self.conn.socket().pid());
let msg = builder.finish();
self.ops.push(BatchOp { seq, msg });
}
pub fn len(&self) -> usize {
self.ops.len()
}
pub fn is_empty(&self) -> bool {
self.ops.is_empty()
}
#[tracing::instrument(level = "debug", skip_all, fields(ops = self.ops.len()))]
pub async fn execute(self) -> Result<BatchResults> {
if self.ops.is_empty() {
return Ok(BatchResults {
results: Vec::new(),
});
}
let mut all_results = Vec::with_capacity(self.ops.len());
let mut chunk_start = 0;
let mut chunk_size = 0;
for (i, op) in self.ops.iter().enumerate() {
if chunk_size + op.msg.len() > MAX_BATCH_SIZE && chunk_size > 0 {
let chunk_results = self.send_chunk(&self.ops[chunk_start..i]).await?;
all_results.extend(chunk_results);
chunk_start = i;
chunk_size = 0;
}
chunk_size += op.msg.len();
}
if chunk_start < self.ops.len() {
let chunk_results = self.send_chunk(&self.ops[chunk_start..]).await?;
all_results.extend(chunk_results);
}
Ok(BatchResults {
results: all_results,
})
}
pub async fn execute_all(self) -> Result<()> {
let results = self.execute().await?;
for result in &results.results {
if let Err(e) = result {
return Err(Error::InvalidMessage(format!(
"batch operation failed: {e}"
)));
}
}
Ok(())
}
async fn send_chunk(&self, ops: &[BatchOp]) -> Result<Vec<std::result::Result<(), Error>>> {
let total_size: usize = ops.iter().map(|o| o.msg.len()).sum();
let mut buf = Vec::with_capacity(total_size);
for op in ops {
buf.extend_from_slice(&op.msg);
}
self.conn.socket().send(&buf).await?;
let mut results: Vec<Option<std::result::Result<(), Error>>> =
(0..ops.len()).map(|_| None).collect();
let mut remaining = ops.len();
while remaining > 0 {
let response = self.conn.socket().recv_msg().await?;
for result in MessageIter::new(&response) {
let (header, payload) = result?;
if let Some(idx) = ops.iter().position(|op| op.seq == header.nlmsg_seq) {
if results[idx].is_some() {
continue; }
if header.is_error() {
let err = NlMsgError::from_bytes(payload)?;
if err.is_ack() {
results[idx] = Some(Ok(()));
} else {
results[idx] = Some(Err(Error::from_errno(err.error)));
}
remaining -= 1;
}
}
}
}
Ok(results.into_iter().map(|r| r.unwrap_or(Ok(()))).collect())
}
}
/// The outcome of each operation in an executed [`Batch`].
///
/// Entries are in the order the corresponding operations were queued on the batch.
pub struct BatchResults {
    results: Vec<std::result::Result<(), Error>>,
}

impl BatchResults {
    /// Iterates over every operation's outcome, in order.
    pub fn iter(&self) -> impl Iterator<Item = &std::result::Result<(), Error>> {
        self.results.as_slice().iter()
    }

    /// Iterates over the failed operations as `(index, error)` pairs.
    ///
    /// `index` is the position of the failure within [`iter`](Self::iter).
    pub fn errors(&self) -> impl Iterator<Item = (usize, &Error)> {
        self.results
            .iter()
            .enumerate()
            .filter_map(|(index, outcome)| match outcome {
                Ok(()) => None,
                Err(err) => Some((index, err)),
            })
    }

    /// Number of operations that succeeded.
    pub fn success_count(&self) -> usize {
        self.len() - self.error_count()
    }

    /// Number of operations that failed.
    pub fn error_count(&self) -> usize {
        self.errors().count()
    }

    /// Returns `true` when no operation failed. This is vacuously true for an
    /// empty result set.
    pub fn all_ok(&self) -> bool {
        self.errors().next().is_none()
    }

    /// Total number of recorded outcomes.
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// Returns `true` when there are no recorded outcomes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps raw outcomes in a `BatchResults` so its accessors can be exercised
    /// without a live connection.
    fn results_of(outcomes: Vec<std::result::Result<(), Error>>) -> BatchResults {
        BatchResults { results: outcomes }
    }

    #[test]
    fn test_empty_results() {
        let results = results_of(Vec::new());
        assert!(results.is_empty());
        assert!(results.all_ok());
        assert_eq!(results.len(), 0);
        assert_eq!(results.success_count(), 0);
        assert_eq!(results.error_count(), 0);
        assert_eq!(results.errors().count(), 0);
    }

    #[test]
    fn test_all_success() {
        let results = results_of(vec![Ok(()); 3]);
        assert!(results.all_ok());
        assert_eq!(results.len(), 3);
        assert_eq!(results.success_count(), 3);
        assert_eq!(results.error_count(), 0);
        assert_eq!(results.errors().count(), 0);
    }

    #[test]
    fn test_mixed_results() {
        let results = results_of(vec![
            Ok(()),
            Err(Error::from_errno(-2)), // ENOENT
            Ok(()),
            Err(Error::from_errno(-1)), // EPERM
        ]);
        assert!(!results.all_ok());
        assert_eq!(results.len(), 4);
        assert_eq!(results.success_count(), 2);
        assert_eq!(results.error_count(), 2);

        let failures: Vec<(usize, &Error)> = results.errors().collect();
        assert_eq!(failures.len(), 2);

        let (first_index, first_err) = failures[0];
        assert_eq!(first_index, 1);
        assert!(first_err.is_not_found());

        let (second_index, second_err) = failures[1];
        assert_eq!(second_index, 3);
        assert!(second_err.is_permission_denied());
    }

    #[test]
    fn test_all_errors() {
        let results = results_of(vec![
            Err(Error::from_errno(-17)), // EEXIST
            Err(Error::from_errno(-16)), // EBUSY
        ]);
        assert!(!results.all_ok());
        assert_eq!(results.success_count(), 0);
        assert_eq!(results.error_count(), 2);
    }

    #[test]
    fn test_iter() {
        let results = results_of(vec![Ok(()), Err(Error::from_errno(-1))]);
        let outcomes: Vec<&std::result::Result<(), Error>> = results.iter().collect();
        assert_eq!(outcomes.len(), 2);
        assert!(outcomes[0].is_ok());
        assert!(outcomes[1].is_err());
    }
}