#[cfg(feature = "std")]
use std::time::{Duration, Instant};
use crate::transport::packet::PacketCodec;
/// Thresholds that decide when a connection is due for a new key exchange.
///
/// A rekey is due as soon as *any* limit is reached in *either* direction
/// (see [`RekeyPolicy::should_rekey`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RekeyPolicy {
    /// Byte count, checked separately for inbound and outbound traffic, at or
    /// above which a rekey is due. The `Default` impl uses 1 GiB.
    pub max_bytes: u64,
    /// Time since the last key exchange at or above which a rekey is due.
    /// Only present with the `std` feature; the `Default` impl uses one hour.
    #[cfg(feature = "std")]
    pub max_duration: Duration,
    /// Packet sequence number, checked separately for inbound and outbound
    /// traffic, at or above which a rekey is due. The `Default` impl uses 2^31.
    pub max_seq: u32,
}
impl Default for RekeyPolicy {
    /// Returns the stock thresholds: 1 GiB of traffic, 2^31 packets, and
    /// (with the `std` feature) one hour since the last key exchange.
    fn default() -> Self {
        const ONE_GIB: u64 = 1024 * 1024 * 1024;
        const SEQ_LIMIT: u32 = 0x8000_0000;
        #[cfg(feature = "std")]
        let max_duration = Duration::from_secs(3600);

        Self {
            max_bytes: ONE_GIB,
            #[cfg(feature = "std")]
            max_duration,
            max_seq: SEQ_LIMIT,
        }
    }
}
impl RekeyPolicy {
    /// Decides whether a new key exchange should be started.
    ///
    /// Returns `true` when [`volume_exceeded`](Self::volume_exceeded) reports
    /// that a byte or sequence-number limit has been reached, or when at least
    /// `max_duration` has elapsed between `last_kex` and `now`. If `now` is
    /// earlier than `last_kex`, the elapsed time saturates to zero instead of
    /// panicking.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn should_rekey(&self, codec: &PacketCodec, last_kex: Instant, now: Instant) -> bool {
        self.volume_exceeded(codec)
            || now.saturating_duration_since(last_kex) >= self.max_duration
    }

    /// Decides whether a new key exchange should be started (no-`std` build).
    ///
    /// Without a clock only the byte and sequence-number limits are checked;
    /// this is equivalent to [`volume_exceeded`](Self::volume_exceeded).
    #[cfg(not(feature = "std"))]
    #[must_use]
    pub fn should_rekey(&self, codec: &PacketCodec) -> bool {
        self.volume_exceeded(codec)
    }

    /// Returns `true` when either direction has reached the byte or the
    /// sequence-number limit.
    ///
    /// Unlike `should_rekey`, whose signature changes with the `std` feature,
    /// this method has the same signature in every configuration. Code that
    /// must compile with and without `std` can therefore call it directly.
    #[must_use]
    pub fn volume_exceeded(&self, codec: &PacketCodec) -> bool {
        self.bytes_exceeded(codec) || self.seq_exceeded(codec)
    }

    /// Inbound or outbound byte counter has reached `max_bytes`.
    fn bytes_exceeded(&self, codec: &PacketCodec) -> bool {
        codec.bytes_in >= self.max_bytes || codec.bytes_out >= self.max_bytes
    }

    /// Inbound or outbound sequence number has reached `max_seq`.
    fn seq_exceeded(&self, codec: &PacketCodec) -> bool {
        codec.seq_in >= self.max_seq || codec.seq_out >= self.max_seq
    }
}
/// Returns `true` when `b` is a key-exchange message number.
///
/// The accepted values are 20, 21 and the inclusive range 30..=49. These are
/// the SSH key-exchange message-number assignments from RFC 4250 §4.1.2:
/// `SSH_MSG_KEXINIT`, `SSH_MSG_NEWKEYS`, and kex-method-specific messages.
pub fn is_kex_msg(b: u8) -> bool {
    let is_kexinit_or_newkeys = b == 20 || b == 21;
    is_kexinit_or_newkeys || (30..=49).contains(&b)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "std")]
    #[test]
    fn defaults_are_one_gib_and_one_hour() {
        let p = RekeyPolicy::default();
        assert_eq!(p.max_bytes, 1u64 << 30);
        assert_eq!(p.max_duration, Duration::from_secs(3600));
        assert_eq!(p.max_seq, 1u32 << 31);
    }

    #[cfg(feature = "std")]
    #[test]
    fn fresh_codec_is_not_due_for_rekey() {
        let codec = PacketCodec::new();
        let p = RekeyPolicy::default();
        let now = Instant::now();
        assert!(!p.should_rekey(&codec, now, now));
    }

    #[cfg(feature = "std")]
    #[test]
    fn byte_threshold_triggers() {
        let p = RekeyPolicy { max_bytes: 1024, ..RekeyPolicy::default() };
        let now = Instant::now();

        // Outbound direction.
        let mut codec = PacketCodec::new();
        codec.bytes_in = 0;
        codec.bytes_out = 1023;
        assert!(!p.should_rekey(&codec, now, now));
        codec.bytes_out = 1024;
        assert!(p.should_rekey(&codec, now, now));

        // Inbound direction is checked independently.
        let mut codec = PacketCodec::new();
        codec.bytes_out = 0;
        codec.bytes_in = 1023;
        assert!(!p.should_rekey(&codec, now, now));
        codec.bytes_in = 1024;
        assert!(p.should_rekey(&codec, now, now));
    }

    #[cfg(feature = "std")]
    #[test]
    fn duration_threshold_triggers() {
        let codec = PacketCodec::new();
        let p = RekeyPolicy { max_duration: Duration::from_secs(1), ..RekeyPolicy::default() };
        // Later instants are built by addition: `Instant - Duration` panics when
        // the result cannot be represented (e.g. shortly after boot on platforms
        // whose monotonic clock starts at boot).
        let last_kex = Instant::now();
        assert!(!p.should_rekey(&codec, last_kex, last_kex));
        assert!(!p.should_rekey(&codec, last_kex, last_kex + Duration::from_millis(999)));
        // The limit is inclusive: exactly `max_duration` is enough.
        assert!(p.should_rekey(&codec, last_kex, last_kex + Duration::from_secs(1)));
        assert!(p.should_rekey(&codec, last_kex, last_kex + Duration::from_secs(2)));
        // A `now` earlier than `last_kex` saturates to zero elapsed time.
        assert!(!p.should_rekey(&codec, last_kex + Duration::from_secs(2), last_kex));
    }

    #[cfg(feature = "std")]
    #[test]
    fn seq_threshold_triggers() {
        let p = RekeyPolicy { max_seq: 16, ..RekeyPolicy::default() };
        let now = Instant::now();

        // Inbound direction.
        let mut codec = PacketCodec::new();
        codec.seq_out = 0;
        codec.seq_in = 15;
        assert!(!p.should_rekey(&codec, now, now));
        codec.seq_in = 16;
        assert!(p.should_rekey(&codec, now, now));

        // Outbound direction is checked independently.
        let mut codec = PacketCodec::new();
        codec.seq_in = 0;
        codec.seq_out = 15;
        assert!(!p.should_rekey(&codec, now, now));
        codec.seq_out = 16;
        assert!(p.should_rekey(&codec, now, now));
    }

    #[test]
    fn is_kex_msg_basic() {
        assert!(is_kex_msg(20));
        assert!(is_kex_msg(21));
        assert!(is_kex_msg(30));
        assert!(is_kex_msg(49));
        assert!(!is_kex_msg(50));
        assert!(!is_kex_msg(0));
        assert!(!is_kex_msg(19));
        assert!(!is_kex_msg(22));
        assert!(!is_kex_msg(29));
        assert!(!is_kex_msg(255));
        // 20, 21 and the 20 values in 30..=49.
        assert_eq!((0u8..=255).filter(|&b| is_kex_msg(b)).count(), 22);
    }
}