use crate::errors::rpc_invalid_request;
use crate::metrics::Metrics;
use anyhow::Result;
use prometheus::{IntCounter, IntGauge};
use std::convert::TryInto;
use std::net::IpAddr;
use std::sync::atomic::AtomicI32;
use std::sync::atomic::Ordering;
use std::sync::Mutex;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
/// Prometheus metrics describing RPC connection admission; registered by
/// `GlobalLimits::new`.
struct ConnectionMetrics {
    /// Current number of accepted RPC connections (`rostrum_rpc_connections`).
    connections: IntGauge,
    /// Connections rejected because the global connection limit was reached.
    connections_rejected_global: IntCounter,
    /// Connections rejected because the limit for their IP prefix was reached.
    connections_rejected_prefix: IntCounter,
    /// Every connection attempt passed to `inc_connection`, whether accepted
    /// or rejected (it is incremented before any limit is checked).
    connections_total: IntCounter,
}
/// Server-wide RPC connection limits.
///
/// Two limits are enforced: a cap on the total number of connections, and a
/// cap on connections whose addresses share the same two-byte prefix (see
/// `get_prefix`).
pub struct GlobalLimits {
    /// Maximum number of simultaneous connections. It is stored as `i32` so
    /// that it can be compared directly with the value in `total_connections`.
    max_connections_total: i32,
    /// Maximum number of simultaneous connections per two-byte address prefix.
    max_connections_shared_prefix: u32,
    /// Number of currently accepted connections.
    total_connections: AtomicI32,
    /// Connection count keyed by the two-byte address prefix.
    total_prefixed_connections: Mutex<HashMap<[u8; 2], u32>>,
    /// Prometheus metrics updated as connections are accepted, rejected and
    /// released.
    metrics: ConnectionMetrics,
}
/// Returns the first two bytes of `addr`, used as the key for per-prefix
/// connection accounting.
///
/// For IPv4 this is the first two octets; for IPv6 it is the two bytes of the
/// first 16-bit segment.
fn get_prefix(addr: &IpAddr) -> [u8; 2] {
    // Both `octets()` arrays are longer than two bytes, so neither slice
    // conversion can fail and the `unwrap` below never panics.
    let leading: Result<[u8; 2], _> = match *addr {
        IpAddr::V4(v4) => v4.octets()[..2].try_into(),
        IpAddr::V6(v6) => v6.octets()[..2].try_into(),
    };
    leading.unwrap()
}
impl GlobalLimits {
pub fn new(
max_connections_total: u32,
max_connections_shared_prefix: u32,
metric: &Metrics,
) -> GlobalLimits {
GlobalLimits {
max_connections_total: max_connections_total as i32,
max_connections_shared_prefix,
total_connections: AtomicI32::new(0),
total_prefixed_connections: Mutex::new(HashMap::new()),
metrics: ConnectionMetrics {
connections: metric.gauge_int(prometheus::Opts::new(
"rostrum_rpc_connections",
"# of RPC connections",
)),
connections_rejected_global: metric.counter_int(prometheus::Opts::new(
"rostrum_rpc_connections_rejected_global",
"# of rejected RPC connections due to global slot limits",
)),
connections_rejected_prefix: metric.counter_int(prometheus::Opts::new(
"rostrum_rpc_connections_rejected_prefix",
"# of rejected RPC connections due to prefix slot limits",
)),
connections_total: metric.counter_int(prometheus::Opts::new(
"rostrum_rpc_connections_total",
"# of RPC connections since server start",
)),
},
}
}
pub fn inc_connection(&self, addr: &IpAddr) -> Result<(u32, u32)> {
self.metrics.connections_total.inc();
let mut prefix_table = self.total_prefixed_connections.lock().unwrap();
let prefix_count = match prefix_table.entry(get_prefix(addr)) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => v.insert(0),
};
if *prefix_count >= self.max_connections_shared_prefix {
self.metrics.connections_rejected_prefix.inc();
bail!(format!(
"Maximum connection limit of {} reached for IP prefix {:?}.",
self.max_connections_shared_prefix,
get_prefix(addr)
))
}
let c =
self.total_connections
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
if current < self.max_connections_total {
Some(current + 1)
} else {
None
}
});
if c.is_err() {
self.metrics.connections_rejected_global.inc();
bail!(format!(
"Maximum connection limit of {} reached.",
self.max_connections_total
))
};
*prefix_count += 1;
let c = c.unwrap() + 1;
self.metrics.connections.set(c as i64);
Ok((c as u32, *prefix_count))
}
pub fn dec_connection(&self, addr: &IpAddr) -> Result<(u32, u32)> {
let mut prefix_table = self.total_prefixed_connections.lock().unwrap();
let prefix_count = match prefix_table.get_mut(&get_prefix(addr)) {
Some(count) => {
*count -= 1;
*count
}
None => {
warn!("IP not found in prefix table");
0
}
};
let c =
self.total_connections
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
if current <= 0 {
None
} else {
Some(current - 1)
}
});
if c.is_err() {
bail!("Cannot decrease connection counter. Already at zero.");
}
let c = c.unwrap() - 1;
self.metrics.connections.set(c as i64);
Ok((c as u32, prefix_count))
}
pub fn connection_limits(&self) -> (u32, u32) {
(
self.max_connections_total as u32,
self.max_connections_shared_prefix,
)
}
}
/// Per-connection RPC limits, copied into each connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConnectionLimits {
    /// RPC timeout value. Its unit is not established in this file; presumably
    /// seconds — TODO confirm against the caller.
    pub rpc_timeout: u16,
    /// Maximum number of scripthash subscriptions accepted by
    /// `check_subscriptions`.
    pub max_subscriptions: u32,
    /// Maximum number of bytes of address/alias subscriptions accepted by
    /// `check_alias_usage`.
    pub max_alias_bytes: u32,
}
impl ConnectionLimits {
    /// Creates a set of per-connection limits from the given values.
    pub fn new(rpc_timeout: u16, max_subscriptions: u32, max_alias_bytes: u32) -> ConnectionLimits {
        ConnectionLimits {
            rpc_timeout,
            max_subscriptions,
            max_alias_bytes,
        }
    }

    /// Checks a scripthash subscription count against `max_subscriptions`.
    ///
    /// Reaching the limit exactly is still allowed.
    ///
    /// # Errors
    ///
    /// Returns an error built by `rpc_invalid_request` when
    /// `num_subscriptions` is greater than `max_subscriptions`.
    pub fn check_subscriptions(&self, num_subscriptions: u32) -> Result<()> {
        if num_subscriptions > self.max_subscriptions {
            return Err(rpc_invalid_request(format!(
                "Scripthash subscriptions limit reached (max {})",
                self.max_subscriptions
            )));
        }
        Ok(())
    }

    /// Checks the bytes used by address/alias subscriptions against
    /// `max_alias_bytes`.
    ///
    /// Reaching the limit exactly is still allowed.
    ///
    /// # Errors
    ///
    /// Returns an error built by `rpc_invalid_request` when `bytes_used` is
    /// greater than `max_alias_bytes`.
    pub fn check_alias_usage(&self, bytes_used: usize) -> Result<()> {
        if bytes_used > self.max_alias_bytes as usize {
            // The two sentences previously ran together
            // ("... bytes) Use ..."); separate them with a period.
            return Err(rpc_invalid_request(format!(
                "Address/alias subscriptions limit reached (max {} bytes). \
                 Use scripthash subscriptions for more subscriptions or increase this limit.",
                self.max_alias_bytes
            )));
        }
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Addresses sharing their first two bytes compete for the same prefix
    /// slots. With a prefix limit of 2, the third such address is refused
    /// until one of the earlier connections is released. Covers IPv4 and IPv6.
    #[test]
    fn test_ip_shared_prefix() {
        let metrics = Metrics::dummy();
        let limits = GlobalLimits::new(100, 2, &metrics);
        let connect = |ip: IpAddr| limits.inc_connection(&ip);
        let disconnect = |ip: IpAddr| limits.dec_connection(&ip);

        // IPv4: 1.2.x.x addresses share prefix [1, 2]; 1.3.0.4 has its own.
        let v4_first = IpAddr::from(Ipv4Addr::new(1, 2, 0, 4));
        let v4_second = IpAddr::from(Ipv4Addr::new(1, 2, 100, 5));
        let v4_third = IpAddr::from(Ipv4Addr::new(1, 2, 254, 6));
        let v4_other = IpAddr::from(Ipv4Addr::new(1, 3, 0, 4));

        assert_eq!(connect(v4_first).unwrap(), (1, 1));
        assert_eq!(connect(v4_second).unwrap(), (2, 2));
        assert!(connect(v4_third).is_err());
        assert_eq!(connect(v4_other).unwrap(), (3, 1));
        assert_eq!(disconnect(v4_first).unwrap(), (2, 1));
        assert_eq!(connect(v4_third).unwrap(), (3, 2));

        // IPv6: the prefix is the first 16-bit segment, so the 0001:...
        // addresses collide, while f00d:... does not.
        let v6_first = IpAddr::from(Ipv6Addr::new(1, 2, 1, 0, 0, 0, 0, 0));
        let v6_second = IpAddr::from(Ipv6Addr::new(1, 2, 2, 0, 0, 0, 0, 0));
        let v6_third = IpAddr::from(Ipv6Addr::new(1, 2, 3, 0, 0, 0, 0, 0));
        let v6_other = IpAddr::from(Ipv6Addr::new(0xf00d, 2, 1, 0, 0, 0, 0, 0));

        assert_eq!(connect(v6_first).unwrap(), (4, 1));
        assert_eq!(connect(v6_second).unwrap(), (5, 2));
        assert!(connect(v6_third).is_err());
        assert_eq!(connect(v6_other).unwrap(), (6, 1));
        assert_eq!(disconnect(v6_first).unwrap(), (5, 1));
        assert_eq!(connect(v6_third).unwrap(), (6, 2));
    }
}