use std::time::Duration;
use super::diff::NftablesDiff;
use super::super::connection::Transaction;
use super::super::types::Chain;
use crate::netlink::{connection::Connection, error::Result, protocol::Nftables};
impl NftablesDiff {
    /// Queues every change recorded in this diff on one [`Transaction`] and commits it.
    ///
    /// Operations are queued in this order:
    ///
    /// 1. deletions: rules, chains, flowtables, tables;
    /// 2. additions: tables, chains, flowtables, rules;
    /// 3. rule replacements.
    ///
    /// Flowtables are queued before rules so that a rule added by this diff which refers
    /// to a flowtable added by the same diff finds it declared earlier in the batch.
    ///
    /// Returns the number of changes (`self.change_count()`). When that count is zero the
    /// connection is not touched and `Ok(0)` is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by committing the transaction.
    pub async fn apply(&self, conn: &Connection<Nftables>) -> Result<usize> {
        let total = self.change_count();
        if total == 0 {
            return Ok(0);
        }

        let mut tx: Transaction = conn.transaction();

        // Deletions: contained objects go before their containers.
        for (table, family, chain, handle) in &self.rules_to_delete {
            tx = tx.del_rule(table, chain, *family, handle.0);
        }
        for (table, family, name) in &self.chains_to_delete {
            tx = tx.del_chain(table, name, *family);
        }
        for (family, table, name) in &self.flowtables_to_delete {
            tx = tx.del_flowtable(*family, table, name);
        }
        for (family, name) in &self.tables_to_delete {
            tx = tx.del_table(name, *family);
        }

        // Additions: containers go before the objects that live in them.
        for table in &self.tables_to_add {
            tx = if table.flags() != 0 {
                tx.add_table_with_flags(table.name(), table.family(), table.flags())
            } else {
                tx.add_table(table.name(), table.family())
            };
        }
        for (table_name, family, declared) in &self.chains_to_add {
            // Only the attributes the declaration actually sets are copied over.
            let mut chain = Chain::new(table_name, declared.name()).family(*family);
            if let Some(h) = declared.hook() {
                chain = chain.hook(h);
            }
            if let Some(p) = declared.priority() {
                chain = chain.priority(p);
            }
            if let Some(pol) = declared.policy() {
                chain = chain.policy(pol);
            }
            if let Some(ct) = declared.chain_type() {
                chain = chain.chain_type(ct);
            }
            if let Some(dev) = declared.device() {
                chain = chain.device(dev);
            }
            tx = tx.add_chain(chain);
        }

        // Flowtables must be in the batch before any rule that may reference them.
        for ft in &self.flowtables_to_add {
            let mut runtime =
                super::super::Flowtable::new(ft.family(), ft.table(), ft.name())
                    .priority(ft.priority());
            for dev in ft.devs() {
                runtime = runtime.device(dev.clone());
            }
            if ft.flags() & super::super::NFT_FLOWTABLE_HW_OFFLOAD != 0 {
                runtime = runtime.hw_offload(true);
            }
            if ft.flags() & super::super::NFT_FLOWTABLE_COUNTER != 0 {
                runtime = runtime.counter(true);
            }
            tx = tx.add_flowtable(&runtime);
        }

        for rule in &self.rules_to_add {
            let mut body = rule.body.clone();
            // A rule with a handle key but no comment gets the key as its comment;
            // an explicit comment is never overwritten.
            if let Some(key) = rule.handle_key()
                && body.comment.is_none()
            {
                body.comment = Some(key.to_string());
            }
            tx = tx.add_rule(body);
        }
        for (_table, _family, _chain, handle, declared) in &self.rules_to_replace {
            let mut body = declared.body.clone();
            // Same comment defaulting as for newly added rules.
            if let Some(key) = declared.handle_key()
                && body.comment.is_none()
            {
                body.comment = Some(key.to_string());
            }
            tx = tx.replace_rule(body, handle.0);
        }

        tx.commit(conn).await?;
        Ok(total)
    }

    /// Calls [`apply`](Self::apply), retrying on transient errors.
    ///
    /// An error is retried only when it reports `is_busy()` or `is_try_again()` and fewer
    /// than `opts.max_retries` retries have been made. Before each retry the task sleeps
    /// for `opts.backoff * 2^attempt`; the exponent is capped at 10 and the multiplication
    /// saturates. The same diff is re-applied on every attempt; it is not recomputed.
    ///
    /// On success the report holds the number of `apply` calls made and the change count
    /// returned by the successful call.
    ///
    /// # Errors
    ///
    /// Returns the error from `apply` as soon as it is not retryable or the retry budget
    /// is exhausted.
    pub async fn apply_reconcile(
        &self,
        conn: &Connection<Nftables>,
        opts: ReconcileOptions,
    ) -> Result<ReconcileReport> {
        let mut attempt: usize = 0;
        loop {
            match self.apply(conn).await {
                Ok(change_count) => {
                    return Ok(ReconcileReport {
                        attempts: attempt + 1,
                        change_count,
                    });
                }
                Err(e) if (e.is_busy() || e.is_try_again()) && attempt < opts.max_retries => {
                    // Capping the shift at 10 keeps `1u32 << _` from overflowing.
                    let backoff = opts.backoff.saturating_mul(1u32 << attempt.min(10));
                    tokio::time::sleep(backoff).await;
                    attempt += 1;
                }
                Err(e) => return Err(e),
            }
        }
    }
}
/// Retry policy consumed by [`NftablesDiff::apply_reconcile`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReconcileOptions {
    /// Maximum number of retries after the first attempt. A retryable error is
    /// returned to the caller once this many retries have been made.
    pub max_retries: usize,
    /// Base delay slept before a retry. It is multiplied by `2^attempt`, with the
    /// exponent capped at 10 and the multiplication saturating.
    pub backoff: Duration,
}
impl Default for ReconcileOptions {
    /// Builds the stock policy: 3 retries with a 100 ms base backoff.
    fn default() -> Self {
        let max_retries = 3;
        let backoff = Duration::from_millis(100);
        Self {
            max_retries,
            backoff,
        }
    }
}
impl ReconcileOptions {
    /// Returns these options with `max_retries` replaced by `retries`.
    #[must_use]
    pub fn max_retries(self, retries: usize) -> Self {
        Self {
            max_retries: retries,
            ..self
        }
    }

    /// Returns these options with `backoff` replaced by the given duration.
    #[must_use]
    pub fn backoff(self, backoff: Duration) -> Self {
        Self { backoff, ..self }
    }
}
/// Outcome of a successful [`NftablesDiff::apply_reconcile`] call.
///
/// The derived `Default` yields zero for both fields.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct ReconcileReport {
    /// Number of `apply` calls made, including the one that succeeded.
    pub attempts: usize,
    /// Change count of the diff that was applied.
    pub change_count: usize,
}
#[cfg(test)]
mod reconcile_tests {
    use super::*;

    /// The default policy is 3 retries with a 100 ms base backoff.
    #[test]
    fn default_reconcile_options_match_plan_spec() {
        let opts = ReconcileOptions::default();
        assert_eq!(opts.max_retries, 3);
        assert_eq!(opts.backoff, Duration::from_millis(100));
    }

    /// Each builder setter changes its own field and leaves the other untouched.
    #[test]
    fn reconcile_options_builders_override_only_their_field() {
        let retries_only = ReconcileOptions::default().max_retries(7);
        assert_eq!(retries_only.max_retries, 7);
        assert_eq!(retries_only.backoff, Duration::from_millis(100));

        let backoff_only = ReconcileOptions::default().backoff(Duration::from_millis(250));
        assert_eq!(backoff_only.max_retries, 3);
        assert_eq!(backoff_only.backoff, Duration::from_millis(250));
    }

    /// A default report carries zero attempts and zero changes.
    #[test]
    fn reconcile_report_default_is_zero_attempts() {
        let r = ReconcileReport::default();
        assert_eq!(r.attempts, 0);
        assert_eq!(r.change_count, 0);
    }

    /// A default diff is empty. Exercising `apply_reconcile` itself needs a live
    /// connection, so it is not covered by this unit test.
    #[test]
    fn default_diff_is_empty() {
        let d = NftablesDiff::default();
        assert!(d.is_empty());
    }
}