nftables_async/
helper.rs

1use std::{ffi::OsStr, future::Future};
2
3use nftables::{
4    helper::{NftablesError, DEFAULT_ARGS, DEFAULT_NFT},
5    schema::Nftables,
6};
7
8use crate::{driver::Driver, util::MapFuture};
9
10/// A helper that provides asynchronous "nft" invocations with nftables-crate types, backed by a [Driver].
11/// As such, the [Helper] trait is implemented for every [Driver].
12pub trait Helper {
13    /// Apply a [Nftables] ruleset.
14    fn apply_ruleset(
15        nftables: &Nftables,
16    ) -> impl Future<Output = Result<(), NftablesError>> + Send {
17        Self::apply_ruleset_with_args(nftables, DEFAULT_NFT, DEFAULT_ARGS)
18    }
19
20    /// Apply a [Nftables] ruleset with, optionally, a custom "nft" program and extra arguments.
21    fn apply_ruleset_with_args<
22        'a,
23        P: AsRef<OsStr> + Sync + ?Sized,
24        A: AsRef<OsStr> + Sync + ?Sized + 'a,
25        I: IntoIterator<Item = &'a A> + Send,
26    >(
27        nftables: &Nftables,
28        program: Option<&P>,
29        args: I,
30    ) -> impl Future<Output = Result<(), NftablesError>> + Send {
31        let payload =
32            serde_json::to_string(nftables).expect("Failed to serialize Nftables struct to JSON");
33        Self::apply_ruleset_raw(payload, program, args)
34    }
35
36    /// Apply a ruleset consisting of an untyped [String] payload with, optionally, a custom "nft" program and
37    /// extra arguments.
38    fn apply_ruleset_raw<
39        'a,
40        P: AsRef<OsStr> + Sync + ?Sized,
41        A: AsRef<OsStr> + Sync + ?Sized + 'a,
42        I: IntoIterator<Item = &'a A> + Send,
43    >(
44        payload: String,
45        program: Option<&P>,
46        args: I,
47    ) -> impl Future<Output = Result<(), NftablesError>> + Send;
48
49    /// Get the current [Nftables] ruleset.
50    fn get_current_ruleset() -> impl Future<Output = Result<Nftables<'static>, NftablesError>> + Send
51    {
52        Self::get_current_ruleset_with_args(DEFAULT_NFT, DEFAULT_ARGS)
53    }
54
55    /// Get the current [Nftables] ruleset with, optionally, a custom "nft" program and extra arguments.
56    fn get_current_ruleset_with_args<
57        'a,
58        P: AsRef<OsStr> + Sync + ?Sized,
59        A: AsRef<OsStr> + Sync + ?Sized + 'a,
60        I: IntoIterator<Item = &'a A> + Send,
61    >(
62        program: Option<&P>,
63        args: I,
64    ) -> impl Future<Output = Result<Nftables<'static>, NftablesError>> + Send {
65        MapFuture::new(
66            Self::get_current_ruleset_raw(program, args),
67            |result: Result<String, NftablesError>| {
68                result.and_then(|output| {
69                    serde_json::from_str(&output).map_err(NftablesError::NftInvalidJson)
70                })
71            },
72        )
73    }
74
75    /// Get the current ruleset as an untyped [String] payload with, optionally, a custom "nft" program and
76    /// extra arguments.
77    fn get_current_ruleset_raw<
78        'a,
79        P: AsRef<OsStr> + Sync + ?Sized,
80        A: AsRef<OsStr> + Sync + ?Sized + 'a,
81        I: IntoIterator<Item = &'a A> + Send,
82    >(
83        program: Option<&P>,
84        args: I,
85    ) -> impl Future<Output = Result<String, NftablesError>> + Send;
86}
87
88impl<D: Driver> Helper for D {
89    async fn apply_ruleset_raw<
90        'a,
91        P: AsRef<OsStr> + Sync + ?Sized,
92        A: AsRef<OsStr> + Sync + ?Sized + 'a,
93        I: IntoIterator<Item = &'a A> + Send,
94    >(
95        payload: String,
96        program: Option<&P>,
97        args: I,
98    ) -> Result<(), NftablesError> {
99        let program = program.map(|v| v.as_ref()).unwrap_or(OsStr::new("nft"));
100        let mut all_args = vec![OsStr::new("-j"), OsStr::new("-f"), OsStr::new("-")];
101
102        all_args.extend(args.into_iter().map(|v| v.as_ref()));
103
104        match D::run_process(&program, all_args.as_slice(), Some(payload.as_bytes())).await {
105            Ok(output) if output.status.success() => Ok(()),
106            Ok(output) => {
107                let stdout = read(&program, output.stdout)?;
108                let stderr = read(&program, output.stderr)?;
109
110                Err(NftablesError::NftFailed {
111                    program: program.into(),
112                    hint: "applying ruleset".to_string(),
113                    stdout,
114                    stderr,
115                })
116            }
117            Err(err) => Err(NftablesError::NftExecution {
118                program: program.into(),
119                inner: err,
120            }),
121        }
122    }
123
124    async fn get_current_ruleset_raw<
125        'a,
126        P: AsRef<OsStr> + Sync + ?Sized,
127        A: AsRef<OsStr> + Sync + ?Sized + 'a,
128        I: IntoIterator<Item = &'a A> + Send,
129    >(
130        program: Option<&P>,
131        args: I,
132    ) -> Result<String, NftablesError> {
133        let program = program.map(|v| v.as_ref()).unwrap_or(OsStr::new("nft"));
134        let mut all_args = vec![OsStr::new("-j"), OsStr::new("list"), OsStr::new("ruleset")];
135
136        all_args.extend(args.into_iter().map(|v| v.as_ref()));
137
138        let output = D::run_process(program, all_args.as_slice(), None)
139            .await
140            .map_err(|err| NftablesError::NftExecution {
141                program: program.into(),
142                inner: err,
143            })?;
144
145        let stdout = read(&program, output.stdout)?;
146
147        if !output.status.success() {
148            let stderr = read(&program, output.stderr)?;
149
150            return Err(NftablesError::NftFailed {
151                program: program.into(),
152                hint: "getting the current ruleset".to_string(),
153                stdout,
154                stderr,
155            });
156        }
157
158        Ok(stdout)
159    }
160}
161
162#[inline]
163fn read(program: &OsStr, stream: Vec<u8>) -> Result<String, NftablesError> {
164    String::from_utf8(stream).map_err(|err| NftablesError::NftOutputEncoding {
165        program: program.into(),
166        inner: err,
167    })
168}