use std::{net::IpAddr, time::Duration};
use super::{
Connection,
error::{Error, Result},
link::IfbLink,
protocol::Route,
tc::{FqCodelConfig, HtbClassConfig, HtbQdiscConfig, IngressConfig},
tc_handle::TcHandle,
};
/// Shaping parameters for one traffic direction of a [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Rate given to the HTB classes built from this limit.
    pub rate: crate::util::Rate,
    /// Optional HTB ceiling; when `None`, no ceiling is set on the class config.
    pub ceil: Option<crate::util::Rate>,
    /// Optional HTB burst size; when `None`, no burst is set on the class config.
    pub burst: Option<crate::util::Bytes>,
    /// Optional `fq_codel` target delay for the leaf qdisc; when `None`, no target is set.
    pub latency: Option<Duration>,
}
impl RateLimit {
    /// Builds a limit of `rate` with no ceiling, burst size or latency target.
    pub fn new(rate: crate::util::Rate) -> Self {
        Self {
            ceil: None,
            burst: None,
            latency: None,
            rate,
        }
    }

    /// Returns this limit with its HTB ceiling set to `ceil`.
    pub fn ceil(self, ceil: crate::util::Rate) -> Self {
        Self {
            ceil: Some(ceil),
            ..self
        }
    }

    /// Returns this limit with its HTB burst size set to `burst`.
    pub fn burst(self, burst: crate::util::Bytes) -> Self {
        Self {
            burst: Some(burst),
            ..self
        }
    }

    /// Returns this limit with its `fq_codel` target delay set to `latency`.
    pub fn latency(self, latency: Duration) -> Self {
        Self {
            latency: Some(latency),
            ..self
        }
    }
}
/// Shapes the egress and/or ingress traffic of one network interface.
///
/// Built with [`RateLimiter::new`] plus the builder methods, which only record settings;
/// the kernel configuration is changed by [`RateLimiter::apply`] and undone by
/// [`RateLimiter::remove`].
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Name of the interface being shaped.
    dev: String,
    /// Limit for traffic sent by `dev`; `None` means `apply` leaves egress untouched.
    egress: Option<RateLimit>,
    /// Limit for traffic received on `dev`. It is enforced on an IFB device named
    /// `ifb_<dev>` (truncated to 15 bytes) that ingress traffic is redirected to.
    ingress: Option<RateLimit>,
}
impl RateLimiter {
    /// HTB minor id of the leaf class that carries all shaped traffic (`1:10` in `tc` notation).
    ///
    /// The same value is used for the HTB qdisc's `default` option and for the class that is
    /// actually created. Unclassified packets therefore land in the shaped class instead of
    /// bypassing it.
    const DEFAULT_CLASS_MINOR: u16 = 0x10;

    /// Creates a limiter for interface `dev` with no direction configured yet.
    pub fn new(dev: &str) -> Self {
        Self {
            dev: dev.to_string(),
            egress: None,
            ingress: None,
        }
    }

    /// Enables egress shaping at `rate`.
    ///
    /// Replaces any earlier egress limit, including a ceiling, burst size or latency set
    /// before this call, so apply those modifiers afterwards.
    pub fn egress(mut self, rate: crate::util::Rate) -> Self {
        self.egress = Some(RateLimit::new(rate));
        self
    }

    /// Enables ingress shaping at `rate`; it is enforced on an IFB device.
    ///
    /// Replaces any earlier ingress limit, including a ceiling, burst size or latency set
    /// before this call, so apply those modifiers afterwards.
    pub fn ingress(mut self, rate: crate::util::Rate) -> Self {
        self.ingress = Some(RateLimit::new(rate));
        self
    }

    /// Sets the HTB ceiling of every direction enabled so far.
    ///
    /// Directions enabled later are not affected.
    pub fn burst_to(mut self, ceil: crate::util::Rate) -> Self {
        for limit in self.limits_mut() {
            limit.ceil = Some(ceil);
        }
        self
    }

    /// Sets the HTB burst size of every direction enabled so far.
    ///
    /// Directions enabled later are not affected.
    pub fn burst_size(mut self, size: crate::util::Bytes) -> Self {
        for limit in self.limits_mut() {
            limit.burst = Some(size);
        }
        self
    }

    /// Sets the `fq_codel` target delay of every direction enabled so far.
    ///
    /// Directions enabled later are not affected.
    pub fn latency(mut self, latency: Duration) -> Self {
        for limit in self.limits_mut() {
            limit.latency = Some(latency);
        }
        self
    }

    /// Iterates over the egress and ingress limits that are currently configured.
    fn limits_mut(&mut self) -> impl Iterator<Item = &mut RateLimit> {
        self.egress.iter_mut().chain(self.ingress.iter_mut())
    }

    /// Installs the configured shaping on the device: egress first, then ingress.
    ///
    /// # Errors
    ///
    /// Returns the first netlink error. The device may then be partially configured.
    #[tracing::instrument(level = "info", skip_all, fields(dev = %self.dev, egress = self.egress.is_some(), ingress = self.ingress.is_some()))]
    pub async fn apply(&self, conn: &Connection<Route>) -> Result<()> {
        if let Some(ref egress) = self.egress {
            self.apply_egress(conn, egress).await?;
        }
        if let Some(ref ingress) = self.ingress {
            self.apply_ingress(conn, ingress).await?;
        }
        Ok(())
    }

    /// Removes the root and ingress qdiscs from the device and deletes its IFB device.
    ///
    /// Best-effort: every failure (for example, nothing to delete) is ignored. The IFB
    /// device is deleted by name even if it was not created by this limiter.
    #[tracing::instrument(level = "info", skip_all, fields(dev = %self.dev))]
    pub async fn remove(&self, conn: &Connection<Route>) -> Result<()> {
        let _ = conn.del_qdisc(&self.dev, TcHandle::ROOT).await;
        let _ = conn.del_qdisc(&self.dev, TcHandle::INGRESS).await;
        let ifb_name = self.ifb_name();
        let _ = conn.del_link(&ifb_name).await;
        Ok(())
    }

    /// Name of the IFB device used for ingress shaping: `ifb_` followed by `dev`, truncated
    /// so the result fits in 15 bytes.
    ///
    /// Truncation happens on a UTF-8 character boundary. Two long device names that share
    /// their first 11 bytes map to the same IFB name.
    fn ifb_name(&self) -> String {
        const PREFIX: &str = "ifb_";
        // IFNAMSIZ is 16 bytes including the terminating NUL.
        const MAX_IFNAME_LEN: usize = 15;
        let mut end = self.dev.len().min(MAX_IFNAME_LEN - PREFIX.len());
        // Slicing inside a multi-byte character would panic, so back up to a boundary.
        while !self.dev.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}{}", PREFIX, &self.dev[..end])
    }

    /// Shapes traffic leaving `dev` according to `limit`.
    async fn apply_egress(&self, conn: &Connection<Route>, limit: &RateLimit) -> Result<()> {
        Self::install_htb(conn, &self.dev, limit).await
    }

    /// Shapes traffic arriving on `dev`.
    ///
    /// Ingress packets are redirected to an IFB device, whose root qdisc applies `limit`.
    async fn apply_ingress(&self, conn: &Connection<Route>, limit: &RateLimit) -> Result<()> {
        let ifb_name = self.ifb_name();
        if conn.get_link_by_name(&ifb_name).await?.is_none() {
            conn.add_link(IfbLink::new(&ifb_name)).await?;
        }
        conn.set_link_up(&ifb_name).await?;
        // Build the shaper on the IFB before any traffic is redirected to it, so no packets
        // pass through an unshaped IFB in between.
        Self::install_htb(conn, &ifb_name, limit).await?;
        let _ = conn.del_qdisc(&self.dev, TcHandle::INGRESS).await;
        conn.add_qdisc_full(&self.dev, TcHandle::INGRESS, None, IngressConfig::new())
            .await?;
        self.add_ingress_redirect(conn, &ifb_name).await
    }

    /// Replaces the root qdisc of `dev` with the tree
    /// `htb 1: (default 1:10)` -> class `1:1` -> class `1:10` -> `fq_codel 10:`,
    /// where both classes use `limit`'s rate, ceiling and burst.
    async fn install_htb(conn: &Connection<Route>, dev: &str, limit: &RateLimit) -> Result<()> {
        // Both classes share one parameter set; build a fresh config for each.
        let class_config = || {
            let mut config = HtbClassConfig::new(limit.rate);
            if let Some(ceil) = limit.ceil {
                config = config.ceil(ceil);
            }
            if let Some(burst) = limit.burst {
                config = config.burst(burst);
            }
            config
        };
        let root_class = TcHandle::new(1, 1);
        let default_class = TcHandle::new(1, Self::DEFAULT_CLASS_MINOR);

        // Best-effort: there may be no root qdisc to delete yet.
        let _ = conn.del_qdisc(dev, TcHandle::ROOT).await;
        let htb = HtbQdiscConfig::new()
            .default_class(u32::from(Self::DEFAULT_CLASS_MINOR))
            .build();
        conn.add_qdisc_full(dev, TcHandle::ROOT, Some(TcHandle::major_only(1)), htb)
            .await?;
        conn.add_class_config(dev, TcHandle::major_only(1), root_class, class_config().build())
            .await?;
        conn.add_class_config(dev, root_class, default_class, class_config().build())
            .await?;

        let mut fq_codel = FqCodelConfig::new();
        if let Some(latency) = limit.latency {
            fq_codel = fq_codel.target(latency);
        }
        conn.add_qdisc_full(
            dev,
            default_class,
            Some(TcHandle::major_only(10)),
            fq_codel.build(),
        )
        .await?;
        Ok(())
    }

    /// Looks up the IFB device `ifb_name` and redirects all of `dev`'s ingress traffic to it.
    async fn add_ingress_redirect(&self, conn: &Connection<Route>, ifb_name: &str) -> Result<()> {
        let ifb_link = conn
            .get_link_by_name(ifb_name)
            .await?
            .ok_or_else(|| Error::InvalidMessage(format!("IFB device not found: {}", ifb_name)))?;
        let ifb_ifindex = ifb_link.ifindex();
        self.add_u32_redirect_filter(conn, ifb_ifindex).await
    }

    /// Installs a catch-all `u32` filter on `dev`'s ingress qdisc. Its `mirred` action
    /// redirects every packet to the device with index `ifb_ifindex`.
    ///
    /// Intended to match `tc filter add dev <dev> parent ffff: protocol all prio 1
    /// u32 match u32 0 0 action mirred egress redirect dev <ifb>`.
    async fn add_u32_redirect_filter(
        &self,
        conn: &Connection<Route>,
        ifb_ifindex: u32,
    ) -> Result<()> {
        use super::{
            connection::ack_request,
            message::NlMsgType,
            types::tc::{
                TcMsg, TcaAttr,
                action::{self, mirred},
                filter::u32 as u32_mod,
                tc_handle,
            },
        };
        // `ETH_P_ALL` from <linux/if_ether.h>: match packets of every protocol.
        const ETH_P_ALL: u16 = 0x0003;
        // Filter priority, carried in the upper 16 bits of `tcm_info`.
        const FILTER_PRIO: u32 = 1;
        // `TC_U32_TERMINAL` from <linux/pkt_cls.h>. Without it the kernel treats the knode as
        // a link to a sub-table and never executes its actions.
        const TC_U32_TERMINAL: u8 = 1;
        // sizeof(struct tc_u32_sel) and sizeof(struct tc_u32_key) from <linux/pkt_cls.h>.
        const U32_SEL_LEN: usize = 16;
        const U32_KEY_LEN: usize = 16;

        let link = conn
            .get_link_by_name(&self.dev)
            .await?
            .ok_or_else(|| Error::InvalidMessage(format!("interface not found: {}", self.dev)))?;
        let ifindex = link.ifindex();
        // tcm_info = (prio << 16) | protocol, with the protocol in network byte order.
        let info = (FILTER_PRIO << 16) | u32::from(ETH_P_ALL.to_be());
        let tcmsg = TcMsg::new()
            .with_ifindex(ifindex as i32)
            .with_parent(tc_handle::INGRESS)
            .with_info(info);
        let mut builder = ack_request(NlMsgType::RTM_NEWTFILTER);
        builder.append(&tcmsg);
        builder.append_attr_str(TcaAttr::Kind as u16, "u32");
        let opt_token = builder.nest_start(TcaAttr::Options as u16);

        // Selector header plus one all-zero key (`match u32 0 0`), which matches every packet.
        // The kernel rejects a selector shorter than header + nkeys * key.
        let mut sel_data = [0u8; U32_SEL_LEN + U32_KEY_LEN];
        sel_data[0] = TC_U32_TERMINAL; // tc_u32_sel.flags
        sel_data[2] = 1; // tc_u32_sel.nkeys
        let sel_token = builder.nest_start(u32_mod::TCA_U32_SEL);
        builder.append_bytes(&sel_data);
        builder.nest_end(sel_token);

        // `TCA_EGRESS_REDIR` (1) from <linux/tc_act/tc_mirred.h>. The packet must be
        // *transmitted* on the IFB so it passes through the IFB's root HTB qdisc. An ingress
        // redirect would hand it to the IFB's receive path and bypass the shaper.
        let egress_redirect = 1;
        let act_token = builder.nest_start(u32_mod::TCA_U32_ACT);
        let act1_token = builder.nest_start(1);
        builder.append_attr_str(action::TCA_ACT_KIND, "mirred");
        let mirred_opt_token = builder.nest_start(action::TCA_ACT_OPTIONS);
        let mirred_parms =
            mirred::TcMirred::new(egress_redirect, ifb_ifindex, action::TC_ACT_STOLEN);
        builder.append_attr(mirred::TCA_MIRRED_PARMS, mirred_parms.as_bytes());
        builder.nest_end(mirred_opt_token);
        builder.nest_end(act1_token);
        builder.nest_end(act_token);
        builder.nest_end(opt_token);
        conn.send_ack(builder).await?;
        Ok(())
    }
}
/// Splits a device's egress bandwidth (its root HTB qdisc) into per-host / per-port classes.
///
/// Each rule gets its own HTB class, `fq_codel` leaf and flower filter(s). Traffic that
/// matches no rule is meant to be carried by a catch-all class limited to `default_rate`.
#[derive(Debug, Clone)]
pub struct PerHostLimiter {
    /// Interface whose root qdisc is replaced by `apply`.
    dev: String,
    /// Rate (and ceiling) of the catch-all class.
    default_rate: crate::util::Rate,
    /// Classification rules in insertion order; rule `i` gets filter priority `i + 1`.
    rules: Vec<HostRule>,
    /// `fq_codel` target delay for every leaf qdisc; `None` sets no target.
    latency: Option<Duration>,
}
/// One classification rule of a [`PerHostLimiter`]: which packets it selects and how they are shaped.
#[derive(Debug, Clone)]
pub struct HostRule {
    /// Packet selector for this rule.
    match_: HostMatch,
    /// HTB rate of the rule's class.
    rate: crate::util::Rate,
    /// HTB ceiling of the rule's class; `None` uses `rate` as the ceiling.
    ceil: Option<crate::util::Rate>,
}
/// Packet selector used by a [`HostRule`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum HostMatch {
    /// Destination address, matched as a full-length prefix (/32 for IPv4, /128 for IPv6).
    Ip(IpAddr),
    /// Destination subnet: address and prefix length.
    Subnet(IpAddr, u8),
    /// Destination port. TCP and UDP filters are installed for IPv4 traffic only.
    Port(u16),
    /// Inclusive destination port range `(start, end)`. Every port gets its own filters, so
    /// only small ranges (span of at most 10) are expanded into filters.
    PortRange(u16, u16),
    /// Source address, matched as a full-length prefix (/32 for IPv4, /128 for IPv6).
    SrcIp(IpAddr),
    /// Source subnet: address and prefix length.
    SrcSubnet(IpAddr, u8),
}
impl PerHostLimiter {
pub fn new(dev: &str, default_rate: crate::util::Rate) -> Self {
Self {
dev: dev.to_string(),
default_rate,
rules: Vec::new(),
latency: None,
}
}
pub fn limit_ip(mut self, ip: IpAddr, rate: crate::util::Rate) -> Self {
self.rules.push(HostRule {
match_: HostMatch::Ip(ip),
rate,
ceil: None,
});
self
}
pub fn limit_ip_with_ceil(
mut self,
ip: IpAddr,
rate: crate::util::Rate,
ceil: crate::util::Rate,
) -> Self {
self.rules.push(HostRule {
match_: HostMatch::Ip(ip),
rate,
ceil: Some(ceil),
});
self
}
pub fn limit_subnet(mut self, subnet: &str, rate: crate::util::Rate) -> Result<Self> {
let (addr, prefix) = parse_subnet(subnet)?;
self.rules.push(HostRule {
match_: HostMatch::Subnet(addr, prefix),
rate,
ceil: None,
});
Ok(self)
}
pub fn limit_src_ip(mut self, ip: IpAddr, rate: crate::util::Rate) -> Self {
self.rules.push(HostRule {
match_: HostMatch::SrcIp(ip),
rate,
ceil: None,
});
self
}
pub fn limit_src_subnet(mut self, subnet: &str, rate: crate::util::Rate) -> Result<Self> {
let (addr, prefix) = parse_subnet(subnet)?;
self.rules.push(HostRule {
match_: HostMatch::SrcSubnet(addr, prefix),
rate,
ceil: None,
});
Ok(self)
}
pub fn limit_port(mut self, port: u16, rate: crate::util::Rate) -> Self {
self.rules.push(HostRule {
match_: HostMatch::Port(port),
rate,
ceil: None,
});
self
}
pub fn limit_port_range(mut self, start: u16, end: u16, rate: crate::util::Rate) -> Self {
self.rules.push(HostRule {
match_: HostMatch::PortRange(start, end),
rate,
ceil: None,
});
self
}
pub fn latency(mut self, latency: Duration) -> Self {
self.latency = Some(latency);
self
}
#[tracing::instrument(level = "info", skip_all, fields(dev = %self.dev, rules = self.rules.len()))]
pub async fn apply(&self, conn: &Connection<Route>) -> Result<()> {
let _ = conn.del_qdisc(&self.dev, TcHandle::ROOT).await;
let default_classid = (self.rules.len() + 1) as u32;
let htb = HtbQdiscConfig::new().default_class(default_classid).build();
conn.add_qdisc_full(
&self.dev,
TcHandle::ROOT,
Some(TcHandle::major_only(1)),
htb,
)
.await?;
let parent_classid = TcHandle::new(1, 1);
let major_only_1 = TcHandle::major_only(1);
let total_rate: crate::util::Rate =
self.default_rate + self.rules.iter().map(|r| r.rate).sum::<crate::util::Rate>();
let root_config = HtbClassConfig::new(total_rate).ceil(total_rate).build();
conn.add_class_config(&self.dev, major_only_1, parent_classid, root_config)
.await?;
for (i, rule) in self.rules.iter().enumerate() {
let classid = TcHandle::new(1, (i + 2) as u16);
let leaf_handle = TcHandle::major_only((i + 10) as u16);
let class_config = HtbClassConfig::new(rule.rate).ceil(rule.ceil.unwrap_or(rule.rate));
conn.add_class_config(&self.dev, parent_classid, classid, class_config.build())
.await?;
let mut fq_codel = FqCodelConfig::new();
if let Some(latency) = self.latency {
fq_codel = fq_codel.target(latency);
}
conn.add_qdisc_full(&self.dev, classid, Some(leaf_handle), fq_codel.build())
.await?;
self.add_filter_for_rule(conn, i, rule).await?;
}
let default_classid = TcHandle::new(1, (self.rules.len() + 2) as u16);
let default_handle = TcHandle::major_only((self.rules.len() + 10) as u16);
let default_config = HtbClassConfig::new(self.default_rate)
.ceil(self.default_rate)
.build();
conn.add_class_config(&self.dev, parent_classid, default_classid, default_config)
.await?;
let mut fq_codel = FqCodelConfig::new();
if let Some(latency) = self.latency {
fq_codel = fq_codel.target(latency);
}
conn.add_qdisc_full(
&self.dev,
default_classid,
Some(default_handle),
fq_codel.build(),
)
.await?;
Ok(())
}
#[tracing::instrument(level = "info", skip_all, fields(dev = %self.dev))]
pub async fn remove(&self, conn: &Connection<Route>) -> Result<()> {
let _ = conn.del_qdisc(&self.dev, TcHandle::ROOT).await;
Ok(())
}
async fn add_filter_for_rule(
&self,
conn: &Connection<Route>,
index: usize,
rule: &HostRule,
) -> Result<()> {
use super::filter::FlowerFilter;
const ETH_P_IP: u16 = 0x0800;
const ETH_P_IPV6: u16 = 0x86DD;
let classid = TcHandle::new(1, (index + 2) as u16);
let priority = (index + 1) as u16;
match &rule.match_ {
HostMatch::Ip(ip) | HostMatch::Subnet(ip, _) => {
let prefix = match &rule.match_ {
HostMatch::Subnet(_, p) => *p,
_ => {
if ip.is_ipv4() {
32
} else {
128
}
}
};
match ip {
IpAddr::V4(addr) => {
let filter = FlowerFilter::new()
.classid(classid)
.priority(priority)
.dst_ipv4(*addr, prefix)
.build();
conn.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IP,
priority,
filter,
)
.await?;
}
IpAddr::V6(addr) => {
let filter = FlowerFilter::new()
.classid(classid)
.priority(priority)
.dst_ipv6(*addr, prefix)
.build();
conn.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IPV6,
priority,
filter,
)
.await?;
}
}
}
HostMatch::SrcIp(ip) | HostMatch::SrcSubnet(ip, _) => {
let prefix = match &rule.match_ {
HostMatch::SrcSubnet(_, p) => *p,
_ => {
if ip.is_ipv4() {
32
} else {
128
}
}
};
match ip {
IpAddr::V4(addr) => {
let filter = FlowerFilter::new()
.classid(classid)
.priority(priority)
.src_ipv4(*addr, prefix)
.build();
conn.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IP,
priority,
filter,
)
.await?;
}
IpAddr::V6(addr) => {
let filter = FlowerFilter::new()
.classid(classid)
.priority(priority)
.src_ipv6(*addr, prefix)
.build();
conn.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IPV6,
priority,
filter,
)
.await?;
}
}
}
HostMatch::Port(port) => {
let tcp_filter = FlowerFilter::new()
.classid(classid)
.priority(priority)
.ip_proto_tcp()
.dst_port(*port)
.build();
conn.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IP,
priority,
tcp_filter,
)
.await?;
let udp_filter = FlowerFilter::new()
.classid(classid)
.priority(priority + 100) .ip_proto_udp()
.dst_port(*port)
.build();
conn.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IP,
priority + 100,
udp_filter,
)
.await?;
}
HostMatch::PortRange(start, end) => {
if *end - *start <= 10 {
for port in *start..=*end {
let filter = FlowerFilter::new()
.classid(classid)
.priority(priority)
.ip_proto_tcp()
.dst_port(port)
.build();
let _ = conn
.add_filter_full(
&self.dev,
TcHandle::major_only(1),
None,
ETH_P_IP,
priority,
filter,
)
.await;
}
}
}
}
Ok(())
}
}
fn parse_subnet(subnet: &str) -> Result<(IpAddr, u8)> {
let parts: Vec<&str> = subnet.split('/').collect();
if parts.len() != 2 {
return Err(Error::InvalidMessage(format!(
"invalid subnet format: {}",
subnet
)));
}
let addr: IpAddr = parts[0]
.parse()
.map_err(|_| Error::InvalidMessage(format!("invalid IP address: {}", parts[0])))?;
let prefix: u8 = parts[1]
.parse()
.map_err(|_| Error::InvalidMessage(format!("invalid prefix length: {}", parts[1])))?;
let max_prefix = if addr.is_ipv4() { 32 } else { 128 };
if prefix > max_prefix {
return Err(Error::InvalidMessage(format!(
"prefix length {} exceeds maximum {} for address type",
prefix, max_prefix
)));
}
Ok((addr, prefix))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_new() {
        use crate::util::Rate;
        let limit = RateLimit::new(Rate::bytes_per_sec(1_000_000));
        assert_eq!(limit.rate, Rate::bytes_per_sec(1_000_000));
        assert!(limit.ceil.is_none());
        assert!(limit.burst.is_none());
    }

    #[test]
    fn test_rate_limit_typed_units() {
        use crate::util::Rate;
        let limit = RateLimit::new(Rate::mbit(100));
        assert_eq!(limit.rate.as_bytes_per_sec(), 12_500_000);
        let limit = RateLimit::new(Rate::gbit(1));
        assert_eq!(limit.rate.as_bytes_per_sec(), 125_000_000);
    }

    /// The optional builder setters fill in only the fields they name.
    #[test]
    fn test_rate_limit_builder_sets_optional_fields() {
        use crate::util::Rate;
        let limit = RateLimit::new(Rate::mbit(10))
            .ceil(Rate::mbit(20))
            .latency(Duration::from_millis(5));
        assert_eq!(limit.ceil, Some(Rate::mbit(20)));
        assert_eq!(limit.latency, Some(Duration::from_millis(5)));
        assert!(limit.burst.is_none());
    }

    #[test]
    fn test_rate_limiter_builder() {
        use crate::util::Rate;
        let limiter = RateLimiter::new("eth0")
            .egress(Rate::bytes_per_sec(1_000_000))
            .ingress(Rate::bytes_per_sec(2_000_000))
            .burst_to(Rate::bytes_per_sec(3_000_000));
        assert_eq!(limiter.dev, "eth0");
        assert!(limiter.egress.is_some());
        assert!(limiter.ingress.is_some());
        assert_eq!(
            limiter.egress.as_ref().unwrap().rate,
            Rate::bytes_per_sec(1_000_000)
        );
        assert_eq!(
            limiter.egress.as_ref().unwrap().ceil,
            Some(Rate::bytes_per_sec(3_000_000))
        );
        assert_eq!(
            limiter.ingress.as_ref().unwrap().rate,
            Rate::bytes_per_sec(2_000_000)
        );
        assert_eq!(
            limiter.ingress.as_ref().unwrap().ceil,
            Some(Rate::bytes_per_sec(3_000_000))
        );
    }

    /// Modifiers such as `latency` only affect directions that are already enabled.
    #[test]
    fn test_rate_limiter_latency_applies_to_enabled_directions() {
        use crate::util::Rate;
        let limiter = RateLimiter::new("eth0")
            .egress(Rate::mbit(10))
            .latency(Duration::from_millis(5));
        assert_eq!(
            limiter.egress.as_ref().unwrap().latency,
            Some(Duration::from_millis(5))
        );
        assert!(limiter.ingress.is_none());
    }

    #[test]
    fn test_ifb_name_generation() {
        let limiter = RateLimiter::new("eth0");
        assert_eq!(limiter.ifb_name(), "ifb_eth0");
        let limiter = RateLimiter::new("verylonginterfacename");
        assert!(limiter.ifb_name().len() <= 15);
    }

    /// Long names are cut to exactly 15 bytes; names that already fit are kept whole.
    #[test]
    fn test_ifb_name_truncation_is_exact() {
        assert_eq!(
            RateLimiter::new("verylonginterfacename").ifb_name(),
            "ifb_verylongint"
        );
        assert_eq!(RateLimiter::new("eth0123456").ifb_name(), "ifb_eth0123456");
    }

    #[test]
    fn test_parse_subnet() {
        let (addr, prefix) = parse_subnet("10.0.0.0/8").unwrap();
        assert_eq!(addr, "10.0.0.0".parse::<IpAddr>().unwrap());
        assert_eq!(prefix, 8);
        let (addr, prefix) = parse_subnet("192.168.1.0/24").unwrap();
        assert_eq!(addr, "192.168.1.0".parse::<IpAddr>().unwrap());
        assert_eq!(prefix, 24);
        let (addr, prefix) = parse_subnet("2001:db8::/32").unwrap();
        assert!(addr.is_ipv6());
        assert_eq!(prefix, 32);
        assert!(parse_subnet("10.0.0.0").is_err());
        assert!(parse_subnet("10.0.0.0/33").is_err());
    }

    /// Malformed CIDR strings are rejected; a zero-length prefix is accepted.
    #[test]
    fn test_parse_subnet_rejects_malformed_input() {
        for bad in [
            "",
            "10.0.0.0/8/1",
            "not-an-ip/8",
            "10.0.0.0/abc",
            "10.0.0.0/-1",
            "10.0.0.0/256",
            "2001:db8::/129",
        ] {
            assert!(parse_subnet(bad).is_err(), "expected an error for {:?}", bad);
        }
        let (addr, prefix) = parse_subnet("::/0").unwrap();
        assert!(addr.is_ipv6());
        assert_eq!(prefix, 0);
    }

    #[test]
    fn test_per_host_limiter_builder() {
        use crate::util::Rate;
        let limiter = PerHostLimiter::new("eth0", Rate::mbit(10));
        assert_eq!(limiter.dev, "eth0");
        assert_eq!(limiter.default_rate, Rate::mbit(10));
        assert!(limiter.rules.is_empty());
    }

    #[test]
    fn test_per_host_limiter_with_rules() {
        use crate::util::Rate;
        let limiter = PerHostLimiter::new("eth0", Rate::mbit(10))
            .limit_ip("192.168.1.100".parse().unwrap(), Rate::mbit(100))
            .limit_subnet("10.0.0.0/8", Rate::mbit(50))
            .unwrap()
            .limit_port(80, Rate::mbit(500));
        assert_eq!(limiter.rules.len(), 3);
    }

    /// Each builder method records the matching selector, in insertion order.
    #[test]
    fn test_per_host_limiter_rule_kinds() {
        use crate::util::Rate;
        let limiter = PerHostLimiter::new("eth0", Rate::mbit(10))
            .limit_ip_with_ceil("192.168.1.1".parse().unwrap(), Rate::mbit(1), Rate::mbit(2))
            .limit_src_ip("10.0.0.1".parse().unwrap(), Rate::mbit(3))
            .limit_src_subnet("fd00::/8", Rate::mbit(4))
            .unwrap()
            .limit_port_range(8000, 8005, Rate::mbit(5))
            .latency(Duration::from_millis(10));
        assert_eq!(limiter.rules.len(), 4);
        assert!(matches!(limiter.rules[0].match_, HostMatch::Ip(_)));
        assert_eq!(limiter.rules[0].ceil, Some(Rate::mbit(2)));
        assert!(matches!(limiter.rules[1].match_, HostMatch::SrcIp(_)));
        assert!(matches!(limiter.rules[2].match_, HostMatch::SrcSubnet(_, 8)));
        assert!(matches!(
            limiter.rules[3].match_,
            HostMatch::PortRange(8000, 8005)
        ));
        assert_eq!(limiter.latency, Some(Duration::from_millis(10)));
        assert!(
            PerHostLimiter::new("eth0", Rate::mbit(10))
                .limit_subnet("10.0.0.0", Rate::mbit(1))
                .is_err()
        );
    }
}