1use std::{ffi::OsStr, future::Future};
2
3use nftables::{
4 helper::{NftablesError, DEFAULT_ARGS, DEFAULT_NFT},
5 schema::Nftables,
6};
7
8use crate::{driver::Driver, util::MapFuture};
9
10pub trait Helper {
13 fn apply_ruleset(
15 nftables: &Nftables,
16 ) -> impl Future<Output = Result<(), NftablesError>> + Send {
17 Self::apply_ruleset_with_args(nftables, DEFAULT_NFT, DEFAULT_ARGS)
18 }
19
20 fn apply_ruleset_with_args<
22 'a,
23 P: AsRef<OsStr> + Sync + ?Sized,
24 A: AsRef<OsStr> + Sync + ?Sized + 'a,
25 I: IntoIterator<Item = &'a A> + Send,
26 >(
27 nftables: &Nftables,
28 program: Option<&P>,
29 args: I,
30 ) -> impl Future<Output = Result<(), NftablesError>> + Send {
31 let payload =
32 serde_json::to_string(nftables).expect("Failed to serialize Nftables struct to JSON");
33 Self::apply_ruleset_raw(payload, program, args)
34 }
35
36 fn apply_ruleset_raw<
39 'a,
40 P: AsRef<OsStr> + Sync + ?Sized,
41 A: AsRef<OsStr> + Sync + ?Sized + 'a,
42 I: IntoIterator<Item = &'a A> + Send,
43 >(
44 payload: String,
45 program: Option<&P>,
46 args: I,
47 ) -> impl Future<Output = Result<(), NftablesError>> + Send;
48
49 fn get_current_ruleset() -> impl Future<Output = Result<Nftables<'static>, NftablesError>> + Send
51 {
52 Self::get_current_ruleset_with_args(DEFAULT_NFT, DEFAULT_ARGS)
53 }
54
55 fn get_current_ruleset_with_args<
57 'a,
58 P: AsRef<OsStr> + Sync + ?Sized,
59 A: AsRef<OsStr> + Sync + ?Sized + 'a,
60 I: IntoIterator<Item = &'a A> + Send,
61 >(
62 program: Option<&P>,
63 args: I,
64 ) -> impl Future<Output = Result<Nftables<'static>, NftablesError>> + Send {
65 MapFuture::new(
66 Self::get_current_ruleset_raw(program, args),
67 |result: Result<String, NftablesError>| {
68 result.and_then(|output| {
69 serde_json::from_str(&output).map_err(NftablesError::NftInvalidJson)
70 })
71 },
72 )
73 }
74
75 fn get_current_ruleset_raw<
78 'a,
79 P: AsRef<OsStr> + Sync + ?Sized,
80 A: AsRef<OsStr> + Sync + ?Sized + 'a,
81 I: IntoIterator<Item = &'a A> + Send,
82 >(
83 program: Option<&P>,
84 args: I,
85 ) -> impl Future<Output = Result<String, NftablesError>> + Send;
86}
87
88impl<D: Driver> Helper for D {
89 async fn apply_ruleset_raw<
90 'a,
91 P: AsRef<OsStr> + Sync + ?Sized,
92 A: AsRef<OsStr> + Sync + ?Sized + 'a,
93 I: IntoIterator<Item = &'a A> + Send,
94 >(
95 payload: String,
96 program: Option<&P>,
97 args: I,
98 ) -> Result<(), NftablesError> {
99 let program = program.map(|v| v.as_ref()).unwrap_or(OsStr::new("nft"));
100 let mut all_args = vec![OsStr::new("-j"), OsStr::new("-f"), OsStr::new("-")];
101
102 all_args.extend(args.into_iter().map(|v| v.as_ref()));
103
104 match D::run_process(&program, all_args.as_slice(), Some(payload.as_bytes())).await {
105 Ok(output) if output.status.success() => Ok(()),
106 Ok(output) => {
107 let stdout = read(&program, output.stdout)?;
108 let stderr = read(&program, output.stderr)?;
109
110 Err(NftablesError::NftFailed {
111 program: program.into(),
112 hint: "applying ruleset".to_string(),
113 stdout,
114 stderr,
115 })
116 }
117 Err(err) => Err(NftablesError::NftExecution {
118 program: program.into(),
119 inner: err,
120 }),
121 }
122 }
123
124 async fn get_current_ruleset_raw<
125 'a,
126 P: AsRef<OsStr> + Sync + ?Sized,
127 A: AsRef<OsStr> + Sync + ?Sized + 'a,
128 I: IntoIterator<Item = &'a A> + Send,
129 >(
130 program: Option<&P>,
131 args: I,
132 ) -> Result<String, NftablesError> {
133 let program = program.map(|v| v.as_ref()).unwrap_or(OsStr::new("nft"));
134 let mut all_args = vec![OsStr::new("-j"), OsStr::new("list"), OsStr::new("ruleset")];
135
136 all_args.extend(args.into_iter().map(|v| v.as_ref()));
137
138 let output = D::run_process(program, all_args.as_slice(), None)
139 .await
140 .map_err(|err| NftablesError::NftExecution {
141 program: program.into(),
142 inner: err,
143 })?;
144
145 let stdout = read(&program, output.stdout)?;
146
147 if !output.status.success() {
148 let stderr = read(&program, output.stderr)?;
149
150 return Err(NftablesError::NftFailed {
151 program: program.into(),
152 hint: "getting the current ruleset".to_string(),
153 stdout,
154 stderr,
155 });
156 }
157
158 Ok(stdout)
159 }
160}
161
162#[inline]
163fn read(program: &OsStr, stream: Vec<u8>) -> Result<String, NftablesError> {
164 String::from_utf8(stream).map_err(|err| NftablesError::NftOutputEncoding {
165 program: program.into(),
166 inner: err,
167 })
168}