use std::pin::Pin;
use tokio_stream::Stream;
use super::events::NftablesEvent;
use super::types::Family;
use crate::netlink::protocol::Nftables;
use crate::netlink::resync::{ConnectionFactory, ResyncStream, events_with_resync};
use crate::netlink::stream::{EventSubscription, OwnedEventStream};
use crate::{Connection, Result};
/// Dumps the current nftables ruleset visible through `conn` as a list of `New*` events.
///
/// Ordering of the returned events:
/// 1. one `NewTable` per table, for all tables;
/// 2. then, table by table, its `NewChain`, `NewFlowtable`, `NewSet` and `NewRule`
///    events, in that group order.
///
/// Each group is fetched with its own list request on `conn`. The result is therefore
/// presumably not an atomic view if the ruleset changes while the dump is running.
///
/// # Errors
///
/// Returns the first error reported by any list request. Events collected up to that
/// point are discarded.
pub async fn nftables_snapshot(conn: &Connection<Nftables>) -> Result<Vec<NftablesEvent>> {
    let tables = conn.list_tables().await?;

    // All tables are emitted up front, presumably so consumers see every parent
    // table before any of the objects that live inside it.
    let mut events: Vec<NftablesEvent> = Vec::new();
    events.extend(
        tables
            .iter()
            .map(|table| NftablesEvent::NewTable(table.clone())),
    );

    for table in &tables {
        let name = &table.name;
        let family: Family = table.family;

        let chains = conn.list_chains_in(name, family).await?;
        events.extend(chains.into_iter().map(NftablesEvent::NewChain));

        let flowtables = conn.list_flowtables_in(name, family).await?;
        events.extend(flowtables.into_iter().map(NftablesEvent::NewFlowtable));

        let sets = conn.list_sets_in(name, family).await?;
        events.extend(sets.into_iter().map(NftablesEvent::NewSet));

        // Rules come last within a table, after the chains, flowtables and sets
        // they may refer to.
        let rules = conn.list_rules(name, family).await?;
        events.extend(rules.into_iter().map(NftablesEvent::NewRule));
    }

    Ok(events)
}
/// Boxed, `Send`, `'static` future for one snapshot run. It resolves to the snapshot's
/// events, or to an error.
type SnapshotFuture =
    Pin<Box<dyn Future<Output = Result<Vec<NftablesEvent>>> + Send + 'static>>;
/// Type-erased, repeatedly callable producer of [`SnapshotFuture`]s. `make_snapshot_fn`
/// builds the value of this type that is handed to `events_with_resync`.
type SnapshotFn = Box<dyn FnMut() -> SnapshotFuture + Send + Unpin + 'static>;
/// Wraps `factory` into the boxed snapshot producer passed to `events_with_resync`.
///
/// Each call of the returned closure yields a new future. When polled, that future gets
/// a fresh connection from `factory` and runs [`nftables_snapshot`] on it. The
/// connection is local to the future and is dropped when the future finishes. Nothing
/// is connected until the future is actually polled.
fn make_snapshot_fn(factory: ConnectionFactory<Nftables>) -> SnapshotFn {
    let produce = move || -> SnapshotFuture {
        // The returned future must be `'static`, so it owns its own handle to the
        // factory instead of borrowing the one captured by this closure.
        let handle = factory.clone();
        Box::pin(async move {
            let snapshot_conn = handle().await?;
            nftables_snapshot(&snapshot_conn).await
        })
    };
    Box::new(produce)
}
/// Resync-capable nftables event stream that owns its underlying connection.
///
/// Returned by `Connection::<Nftables>::into_events_with_resync`.
pub type OwnedResyncStream =
    ResyncStream<'static, OwnedEventStream<Nftables>, NftablesEvent, SnapshotFn>;
/// Resync-capable nftables event stream that borrows its connection for `'a`.
///
/// Returned by `Connection::<Nftables>::subscribe_all_with_resync`.
// NOTE(review): the first `ResyncStream` lifetime is `'static` even though the event source
// borrows for `'a`. Presumably that parameter relates to the snapshot side rather than the
// event source. Confirm against `ResyncStream`'s definition.
pub type BorrowedResyncStream<'a> =
    ResyncStream<'static, EventSubscription<'a, Nftables>, NftablesEvent, SnapshotFn>;
impl Connection<Nftables> {
    /// Calls `subscribe_all` on this connection, then converts it into an owned,
    /// resync-capable event stream.
    ///
    /// Consumes `self`. Snapshots requested by the resync machinery are taken on
    /// separate connections obtained from `factory` (see [`nftables_snapshot`]), never
    /// on this one.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `subscribe_all`.
    #[tracing::instrument(level = "info", skip_all)]
    pub fn into_events_with_resync(
        mut self,
        factory: ConnectionFactory<Nftables>,
    ) -> Result<OwnedResyncStream> {
        self.subscribe_all()?;
        let snapshot = make_snapshot_fn(factory);
        Ok(events_with_resync(self.into_events(), snapshot))
    }

    /// Calls `subscribe_all` on this connection, then returns a resync-capable event
    /// stream that borrows the connection for as long as the stream lives.
    ///
    /// Snapshots requested by the resync machinery are taken on separate connections
    /// obtained from `factory` (see [`nftables_snapshot`]).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `subscribe_all`.
    #[tracing::instrument(level = "info", skip_all)]
    pub fn subscribe_all_with_resync(
        &mut self,
        factory: ConnectionFactory<Nftables>,
    ) -> Result<BorrowedResyncStream<'_>> {
        self.subscribe_all()?;
        let snapshot = make_snapshot_fn(factory);
        Ok(events_with_resync(self.events(), snapshot))
    }
}
/// Compile-time assertion that both resync stream aliases implement `Stream`.
///
/// The function is never called; it only has to type-check. `allow(dead_code)` stops
/// the compiler from warning that it is unused.
#[allow(dead_code)]
fn _streams_are_streams() {
    /// Type-checks only when `T: Stream`.
    fn is_stream<T: ?Sized + Stream>() {}

    is_stream::<BorrowedResyncStream<'static>>();
    is_stream::<OwnedResyncStream>();
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::netlink::resync::ConnectionFuture;

    /// Checks that a hand-built `ConnectionFactory` can be cloned. Also checks that the
    /// factory type is `Send + Sync` and the `ConnectionFuture` type is `Send`.
    ///
    /// The factory is never invoked, so no real connection is opened.
    #[test]
    fn factory_is_clone_and_send() {
        fn require_send<T: Send>() {}
        fn require_send_sync<T: Send + Sync>() {}

        let factory: ConnectionFactory<Nftables> = Arc::new(
            || -> Pin<Box<dyn Future<Output = Result<Connection<Nftables>>> + Send + 'static>> {
                Box::pin(async { Connection::<Nftables>::new() })
            },
        );
        let _cloned = factory.clone();

        require_send_sync::<ConnectionFactory<Nftables>>();
        require_send::<ConnectionFuture<Nftables>>();
    }
}