use std::{io, net::SocketAddr, sync::Arc};
use byte_string::ByteStr;
use log::warn;
use crate::{
config::{ReplayAttackPolicy, ServerType},
crypto::{v1::random_iv_or_salt, CipherKind},
dns_resolver::DnsResolver,
security::replay::ReplayProtector,
};
/// Service context bundling nonce replay protection, the DNS resolver and the
/// IPv6 preference flag.
pub struct Context {
    /// Tracks nonces (IV / salt) so repeated ones can be noticed; queried by
    /// `check_nonce_and_set` and `check_nonce_replay`.
    replay_protector: ReplayProtector,
    /// Selects how `check_nonce_replay` reacts to a repeated nonce.
    replay_policy: ReplayAttackPolicy,
    /// Resolver used by `dns_resolve`; replaceable through `set_dns_resolver`.
    dns_resolver: Arc<DnsResolver>,
    /// Flag exposed through `ipv6_first()` / `set_ipv6_first()`.
    /// NOTE(review): presumably tells consumers to prefer IPv6 addresses; the
    /// code that reads it is outside this file — confirm.
    ipv6_first: bool,
}
/// `Context` wrapped in an `Arc`, as produced by `Context::new_shared`.
pub type SharedContext = Arc<Context>;
impl Context {
    /// Builds a context for the given server type.
    ///
    /// The replay protector is created from `config_type`, the replay policy
    /// starts as `ReplayAttackPolicy::Default`, DNS resolution uses
    /// `DnsResolver::system_resolver()` and `ipv6_first` is `false`.
    pub fn new(config_type: ServerType) -> Context {
        let replay_protector = ReplayProtector::new(config_type);
        let dns_resolver = Arc::new(DnsResolver::system_resolver());

        Context {
            replay_protector,
            replay_policy: ReplayAttackPolicy::Default,
            dns_resolver,
            ipv6_first: false,
        }
    }

    /// Same as [`Context::new`], but returns the context wrapped in an `Arc`.
    pub fn new_shared(config_type: ServerType) -> SharedContext {
        Arc::new(Context::new(config_type))
    }

    /// Asks the replay protector whether `nonce` is already known, recording it
    /// through the protector's `check_nonce_and_set`.
    ///
    /// Under `ReplayAttackPolicy::Ignore` the protector is not consulted at all
    /// and `false` is returned.
    #[inline(always)]
    fn check_nonce_and_set(&self, method: CipherKind, nonce: &[u8]) -> bool {
        if matches!(self.replay_policy, ReplayAttackPolicy::Ignore) {
            return false;
        }

        self.replay_protector.check_nonce_and_set(method, nonce)
    }

    /// Fills `nonce` (IV or salt) using `random_iv_or_salt`.
    ///
    /// An empty `nonce` is left untouched. When `unique` is `true`, the nonce is
    /// regenerated for as long as `check_nonce_and_set` reports it as already
    /// known; under the `Ignore` policy that check always passes, so the first
    /// generated value is kept.
    pub fn generate_nonce(&self, method: CipherKind, nonce: &mut [u8], unique: bool) {
        if nonce.is_empty() {
            return;
        }

        // Do-while: always generate once, retry only on a detected collision.
        loop {
            random_iv_or_salt(nonce);

            let collided = unique && self.check_nonce_and_set(method, nonce);
            if !collided {
                break;
            }
        }
    }

    /// Checks a received `nonce` against the replay protector according to the
    /// effective replay policy.
    ///
    /// * `Default` / `Ignore` — no check is performed.
    /// * `Detect` — a repeated nonce is logged with `warn!`, the call still succeeds.
    /// * `Reject` — a repeated nonce yields an error.
    ///
    /// With the `aead-cipher-2022` feature enabled, any method for which
    /// `is_aead_2022()` is true is handled as `Reject` regardless of the
    /// configured policy. An empty `nonce` always succeeds without a check.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::Other` error when the effective policy is
    /// `Reject` and the replay protector reports the nonce as repeated.
    pub fn check_nonce_replay(&self, method: CipherKind, nonce: &[u8]) -> io::Result<()> {
        if nonce.is_empty() {
            return Ok(());
        }

        // Pick the policy that actually applies to this method.
        #[cfg(feature = "aead-cipher-2022")]
        let effective_policy = if method.is_aead_2022() {
            ReplayAttackPolicy::Reject
        } else {
            self.replay_policy
        };
        #[cfg(not(feature = "aead-cipher-2022"))]
        let effective_policy = self.replay_policy;

        match effective_policy {
            ReplayAttackPolicy::Default | ReplayAttackPolicy::Ignore => {}
            ReplayAttackPolicy::Detect => {
                let repeated = self.replay_protector.check_nonce_and_set(method, nonce);
                if repeated {
                    warn!("detected repeated nonce (iv/salt) {:?}", ByteStr::new(nonce));
                }
            }
            ReplayAttackPolicy::Reject => {
                let repeated = self.replay_protector.check_nonce_and_set(method, nonce);
                if repeated {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "detected repeated nonce (iv/salt)",
                    ));
                }
            }
        }

        Ok(())
    }

    /// Replaces the DNS resolver used by this context.
    pub fn set_dns_resolver(&mut self, resolver: Arc<DnsResolver>) {
        self.dns_resolver = resolver;
    }

    /// Borrows the DNS resolver currently held by this context.
    pub fn dns_resolver(&self) -> &Arc<DnsResolver> {
        &self.dns_resolver
    }

    /// Resolves `addr:port` through the context's DNS resolver and returns the
    /// resolver's result unchanged.
    #[allow(clippy::needless_lifetimes)]
    pub async fn dns_resolve<'a>(&self, addr: &'a str, port: u16) -> io::Result<impl Iterator<Item = SocketAddr> + 'a> {
        self.dns_resolver().resolve(addr, port).await
    }

    /// Stores the `ipv6_first` flag.
    /// NOTE(review): how the flag is consumed is not visible in this file.
    pub fn set_ipv6_first(&mut self, ipv6_first: bool) {
        self.ipv6_first = ipv6_first;
    }

    /// Returns the stored `ipv6_first` flag.
    pub fn ipv6_first(&self) -> bool {
        self.ipv6_first
    }

    /// Sets the policy applied by `check_nonce_replay` (and consulted by
    /// `check_nonce_and_set`).
    pub fn set_replay_attack_policy(&mut self, replay_policy: ReplayAttackPolicy) {
        self.replay_policy = replay_policy;
    }

    /// Returns the configured replay attack policy.
    pub fn replay_attack_policy(&self) -> ReplayAttackPolicy {
        self.replay_policy
    }
}