Skip to main content

e2etest_firewall/
lib.rs

1/*
2 * Copyright 2026-present ScyllaDB
3 * SPDX-License-Identifier: MIT OR Apache-2.0
4 */
5
//! This crate provides a Firewall emulator server for the e2etest suite. It runs an actor that
//! receives `enum Firewall` messages through a `tokio::sync::mpsc` channel; callers hold the
//! channel's `tokio::sync::mpsc::Sender`. It also provides `trait FirewallExt` with helper
//! methods for sending messages to the actor. Currently, the Firewall actor adds and removes
//! blackhole routes to block traffic to specified IP addresses. It uses the `neli` crate to talk
//! to the Linux kernel's netlink interface when managing these routes.
11
12use async_backtrace::frame;
13use async_backtrace::framed;
14use neli::consts::nl::*;
15use neli::consts::rtnl::*;
16use neli::consts::socket::*;
17use neli::nl::NlPayload;
18use neli::nl::Nlmsghdr;
19use neli::router::asynchronous::NlRouter;
20use neli::rtnl::*;
21use neli::types::RtBuffer;
22use neli::utils::Groups;
23use std::mem;
24use std::net::IpAddr;
25use std::net::Ipv4Addr;
26use tokio::sync::mpsc;
27use tokio::sync::oneshot;
28use tracing::Instrument;
29use tracing::debug;
30use tracing::error;
31use tracing::error_span;
32use tracing::info;
33
34/// Messages for the Firewall actor.
35pub enum Firewall {
36    DropTraffic {
37        ips: Vec<Ipv4Addr>,
38        tx: oneshot::Sender<()>,
39    },
40    TurnOffRules {
41        tx: oneshot::Sender<()>,
42    },
43}
44
/// Helper methods for talking to the Firewall actor through its `mpsc::Sender<Firewall>`.
///
/// Each method returns a future; the implementation for `mpsc::Sender<Firewall>` completes it
/// only after the actor has replied to the request.
pub trait FirewallExt {
    /// Removes the blackhole routes for every address the actor is currently blocking.
    fn turn_off_rules(&self) -> impl Future<Output = ()>;

    /// Blocks traffic to `ips`, replacing whatever set of addresses was blocked before.
    fn drop_traffic(&self, ips: Vec<Ipv4Addr>) -> impl Future<Output = ()>;
}
54
55impl FirewallExt for mpsc::Sender<Firewall> {
56    #[framed]
57    async fn drop_traffic(&self, ips: Vec<Ipv4Addr>) {
58        let (tx, rx) = oneshot::channel();
59        self.send(Firewall::DropTraffic { ips, tx })
60            .await
61            .expect("FirewallExt::drop_traffic: internal actor should receive request");
62        rx.await
63            .expect("FirewallExt::drop_traffic: internal actor should send response")
64    }
65
66    #[framed]
67    async fn turn_off_rules(&self) {
68        let (tx, rx) = oneshot::channel();
69        self.send(Firewall::TurnOffRules { tx })
70            .await
71            .expect("FirewallExt::turn_off_rules: internal actor should receive request");
72        rx.await
73            .expect("FirewallExt::turn_off_rules: internal actor should send response")
74    }
75}
76
77/// Creates a new Firewall actor and returns a sender to send messages to it.
78#[framed]
79pub async fn new() -> mpsc::Sender<Firewall> {
80    let (tx, mut rx) = mpsc::channel(10);
81
82    tokio::spawn(
83        frame!(async move {
84            debug!("starting");
85
86            let (socket, _) = NlRouter::connect(NlFamily::Route, None, Groups::empty())
87                .await
88                .unwrap();
89
90            let mut disabled_ips = Vec::new();
91
92            while let Some(msg) = rx.recv().await {
93                process(msg, &socket, &mut disabled_ips).await;
94            }
95
96            debug!("finished");
97        })
98        .instrument(error_span!("firewall")),
99    );
100
101    tx
102}
103
104#[framed]
105async fn process(msg: Firewall, socket: &NlRouter, disabled_ips: &mut Vec<Ipv4Addr>) {
106    match msg {
107        Firewall::DropTraffic { ips, tx } => {
108            info!("Removing rules for: {disabled_ips:?}");
109            turn_off_rules(socket, mem::take(disabled_ips)).await;
110            *disabled_ips = ips;
111            info!("Adding rules for: {disabled_ips:?}");
112            drop_traffic(socket, disabled_ips).await;
113            if let Err(err) = log_routes(socket).await {
114                error!("Failed to list routes: {err}");
115            }
116            tx.send(())
117                .expect("process Firewall::DropTraffic: failed to send a response");
118        }
119
120        Firewall::TurnOffRules { tx } => {
121            info!("Removing rules for: {disabled_ips:?}");
122            turn_off_rules(socket, mem::take(disabled_ips)).await;
123            if let Err(err) = log_routes(socket).await {
124                error!("Failed to list routes: {err}");
125            }
126            tx.send(())
127                .expect("process Firewall::TurnOffRules: failed to send a response");
128        }
129    }
130}
131
132#[framed]
133async fn drop_traffic(socket: &NlRouter, ips: &[Ipv4Addr]) {
134    for ip in ips.iter() {
135        let Err(err) = add_unreachable_route(socket, ip).await else {
136            continue;
137        };
138        error!("Failed to add unreachable route for ip {ip}: {err}");
139    }
140}
141
142#[framed]
143async fn turn_off_rules(socket: &NlRouter, ips: Vec<Ipv4Addr>) {
144    for ip in ips.into_iter() {
145        let Err(err) = remove_unreachable_route(socket, ip).await else {
146            continue;
147        };
148        error!("Failed to remove unreachable route for ip {ip}: {err}");
149    }
150}
151
152async fn add_unreachable_route(socket: &NlRouter, ip: &Ipv4Addr) -> anyhow::Result<()> {
153    let mut attrs = RtBuffer::new();
154    attrs.push(
155        RtattrBuilder::default()
156            .rta_type(Rta::Dst)
157            .rta_payload(ip.octets())
158            .build()?,
159    );
160    let rtmsg = RtmsgBuilder::default()
161        .rtm_family(RtAddrFamily::Inet)
162        .rtm_dst_len(32)
163        .rtm_src_len(0)
164        .rtm_tos(0)
165        .rtm_table(RtTable::Main)
166        .rtm_protocol(Rtprot::Unspec)
167        .rtm_scope(RtScope::Universe)
168        .rtm_type(Rtn::Blackhole)
169        .rtattrs(attrs)
170        .build()?;
171    socket
172        .send::<Rtm, Rtmsg, NlTypeWrapper, Rtmsg>(
173            Rtm::Newroute,
174            NlmF::REQUEST | NlmF::CREATE | NlmF::REPLACE,
175            NlPayload::Payload(rtmsg),
176        )
177        .await?;
178    Ok(())
179}
180
181async fn remove_unreachable_route(socket: &NlRouter, ip: Ipv4Addr) -> anyhow::Result<()> {
182    let mut attrs = RtBuffer::new();
183    attrs.push(
184        RtattrBuilder::default()
185            .rta_type(Rta::Dst)
186            .rta_payload(ip.octets())
187            .build()?,
188    );
189    let rtmsg = RtmsgBuilder::default()
190        .rtm_family(RtAddrFamily::Inet)
191        .rtm_dst_len(32)
192        .rtm_src_len(0)
193        .rtm_tos(0)
194        .rtm_table(RtTable::Main)
195        .rtm_protocol(Rtprot::Unspec)
196        .rtm_scope(RtScope::Universe)
197        .rtm_type(Rtn::Blackhole)
198        .rtattrs(attrs)
199        .build()?;
200    socket
201        .send::<Rtm, Rtmsg, NlTypeWrapper, Rtmsg>(
202            Rtm::Delroute,
203            NlmF::REQUEST,
204            NlPayload::Payload(rtmsg),
205        )
206        .await?;
207    Ok(())
208}
209
210async fn log_routes(socket: &NlRouter) -> anyhow::Result<()> {
211    let rtmsg = RtmsgBuilder::default()
212        .rtm_family(RtAddrFamily::Inet)
213        .rtm_dst_len(0)
214        .rtm_src_len(0)
215        .rtm_tos(0)
216        .rtm_table(RtTable::Unspec)
217        .rtm_protocol(Rtprot::Unspec)
218        .rtm_scope(RtScope::Universe)
219        .rtm_type(Rtn::Unspec)
220        .build()?;
221    let mut recv = socket
222        .send::<Rtm, Rtmsg, NlTypeWrapper, Rtmsg>(
223            Rtm::Getroute,
224            NlmF::DUMP,
225            NlPayload::Payload(rtmsg),
226        )
227        .await?;
228
229    while let Some(rtm_result) = recv.next().await {
230        let rtm = rtm_result?;
231        if let NlTypeWrapper::Rtm(_) = rtm.nl_type() {
232            parse_route_table(rtm)?;
233        }
234    }
235
236    Ok(())
237}
238
239fn parse_route_table(rtm: Nlmsghdr<NlTypeWrapper, Rtmsg>) -> anyhow::Result<()> {
240    if let Some(payload) = rtm.get_payload() {
241        let mut dst = None;
242
243        for attr in payload.rtattrs().iter() {
244            fn to_addr(b: &[u8]) -> Option<IpAddr> {
245                if let Ok(tup) = <&[u8; 4]>::try_from(b) {
246                    Some(IpAddr::from(*tup))
247                } else if let Ok(tup) = <&[u8; 16]>::try_from(b) {
248                    Some(IpAddr::from(*tup))
249                } else {
250                    None
251                }
252            }
253
254            if attr.rta_type() == &Rta::Dst {
255                dst = to_addr(attr.rta_payload().as_ref())
256            }
257        }
258
259        let dst = if let Some(dst) = dst {
260            format!("{}/{} ", dst, payload.rtm_dst_len())
261        } else {
262            "default".to_string()
263        };
264
265        info!(
266            "active route for {:?}: {dst}: {:?}",
267            payload.rtm_table(),
268            payload.rtm_type()
269        );
270    }
271    Ok(())
272}