use super::errors::WireGuardError;
use crate::noise::Tunn;
use crate::packet::WgKind;
use std::mem;
use std::ops::{Index, IndexMut};
use std::time::Duration;
use bytes::BytesMut;
#[cfg(feature = "mock_instant")]
use mock_instant::thread_local::Instant;
#[cfg(not(feature = "mock_instant"))]
use crate::sleepyinstant::Instant;
/// Session age after which the handshake initiator starts a new handshake (once data has
/// been sent on that session).
pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(2 * 60);
/// Session age after which a session slot is discarded; three times this value without a
/// new session expires the whole connection.
const REJECT_AFTER_TIME: Duration = Duration::from_secs(3 * 60);
/// How long handshake initiations keep being retried before the connection is expired.
const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
/// Delay after which an unanswered handshake initiation is sent again.
pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
/// Quiet period after received data before an empty keepalive is sent back.
const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
/// Age after which a received cookie is cleared.
const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(2 * 60);
/// Names the individual timestamps tracked by `Timers`.
///
/// Each variant's discriminant is used directly as an index into the timer array, so the
/// variants must stay in declaration order starting at zero, and `Top` must remain last:
/// its discriminant is the number of real timers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimerName {
    /// The tunnel clock, refreshed on every timer update.
    TimeCurrent,
    /// When the most recent session was established.
    TimeSessionEstablished,
    /// When the current handshake attempt started; bounded by `REKEY_ATTEMPT_TIME`.
    TimeLastHandshakeStarted,
    /// Last packet received from the peer.
    TimeLastPacketReceived,
    /// Last packet sent to the peer.
    TimeLastPacketSent,
    /// Last data packet received from the peer.
    TimeLastDataPacketReceived,
    /// Last data packet sent to the peer.
    TimeLastDataPacketSent,
    /// When the current cookie was received; bounded by `COOKIE_EXPIRATION_TIME`.
    TimeCookieReceived,
    /// When the last persistent keepalive was triggered.
    TimePersistentKeepalive,
    /// Sentinel whose discriminant is the number of timers; not a real timer.
    Top,
}
use self::TimerName::*;
/// Per-tunnel timer state that drives handshake, keepalive and session-expiry decisions.
///
/// Timestamps are stored as `Duration`s measured from `time_started`, not as `Instant`s.
#[derive(Debug)]
pub struct Timers {
    /// Whether this side initiated the handshake of the most recently established session;
    /// only the initiator re-keys on session age.
    is_initiator: bool,
    /// Reference point from which every stored timestamp is measured (see `Timers::now`).
    time_started: Instant,
    /// Timestamps indexed by `TimerName` through the `Index`/`IndexMut` impls.
    timers: [Duration; TimerName::Top as usize],
    /// Per-session-slot timestamp; a slot older than `REJECT_AFTER_TIME` has its session
    /// dropped.
    pub(super) session_timers: [Duration; super::N_SESSIONS],
    /// Set when a packet is received and cleared when any packet is sent; while set, an
    /// empty keepalive may be sent after `KEEPALIVE_TIMEOUT`.
    want_keepalive: bool,
    /// `Some(t)` when data has been sent since the last received packet, `t` being the
    /// first such send; triggers a handshake if no reply arrives in time.
    want_handshake: Option<Duration>,
    /// Persistent keepalive interval in seconds; 0 disables it.
    persistent_keepalive: usize,
}
impl Timers {
    /// Builds a fresh timer set anchored at the current instant.
    ///
    /// `persistent_keepalive` is the keepalive interval in seconds; `None` or `Some(0)`
    /// disables persistent keepalives.
    pub(super) fn new(persistent_keepalive: Option<u16>) -> Timers {
        let keepalive_secs = persistent_keepalive.map_or(0, usize::from);
        Timers {
            is_initiator: false,
            time_started: Instant::now(),
            timers: Default::default(),
            session_timers: Default::default(),
            want_keepalive: false,
            want_handshake: None,
            persistent_keepalive: keepalive_secs,
        }
    }

    /// Whether this side initiated the most recently established session.
    fn is_initiator(&self) -> bool {
        self.is_initiator
    }

    /// Stamps every named timer with the current time and drops pending keepalive and
    /// handshake requests.
    pub(super) fn clear(&mut self) {
        let stamp = self.now();
        self.timers.fill(stamp);
        self.want_handshake = None;
        self.want_keepalive = false;
    }

    /// Time elapsed since `time_started`, clamped so it never runs behind the last
    /// recorded `TimeCurrent` (and is zero if the clock reads earlier than the start).
    fn now(&self) -> Duration {
        let elapsed = match Instant::now().checked_duration_since(self.time_started) {
            Some(since_start) => since_start,
            None => Duration::ZERO,
        };
        elapsed.max(self[TimeCurrent])
    }
}
/// Read access to a single named timestamp: `timers[TimerName::X]`.
impl Index<TimerName> for Timers {
    type Output = Duration;

    fn index(&self, name: TimerName) -> &Duration {
        let slot = name as usize;
        &self.timers[slot]
    }
}
/// Write access to a single named timestamp: `timers[TimerName::X] = value`.
impl IndexMut<TimerName> for Timers {
    fn index_mut(&mut self, name: TimerName) -> &mut Duration {
        let slot = name as usize;
        &mut self.timers[slot]
    }
}
impl Tunn {
    /// Stamps `timer_name` with the tunnel clock (`TimeCurrent`) and updates the
    /// keepalive/handshake bookkeeping tied to that event.
    ///
    /// `TimeCurrent` is only refreshed by `update_timers`, so the stamp is the time of the
    /// most recent timer update rather than the exact moment of the event.
    pub(super) fn timer_tick(&mut self, timer_name: TimerName) {
        let time = self.timers[TimeCurrent];
        match timer_name {
            TimeLastPacketReceived => {
                // A received packet arms the keepalive reply and cancels the
                // "data sent without reply" handshake trigger.
                self.timers.want_keepalive = true;
                self.timers.want_handshake = None;
            }
            TimeLastPacketSent => {
                // Any outgoing packet satisfies a pending keepalive.
                self.timers.want_keepalive = false;
            }
            TimeLastDataPacketSent => {
                // Keep the time of the *first* unanswered send; later sends don't move it.
                self.timers.want_handshake.get_or_insert(time);
            }
            _ => {}
        }
        self.timers[timer_name] = time;
    }

    /// Records that a session was established in slot `session_idx` and whether this side
    /// initiated the handshake for it.
    pub(super) fn timer_tick_session_established(
        &mut self,
        is_initiator: bool,
        session_idx: usize,
    ) {
        self.timer_tick(TimeSessionEstablished);
        // Per-slot stamp checked by `update_session_timers` to expire old sessions.
        self.timers.session_timers[session_idx % crate::noise::N_SESSIONS] =
            self.timers[TimeCurrent];
        self.timers.is_initiator = is_initiator;
    }

    /// Drops every session, clears the packet queue and resets all timers.
    fn clear_all(&mut self) {
        for session in &mut self.sessions {
            *session = None;
        }
        self.packet_queue.clear();
        self.timers.clear();
    }

    /// Drops the session in any slot whose stamp is more than `REJECT_AFTER_TIME` old and
    /// restamps that slot with `time_now` (whether or not a session was present).
    fn update_session_timers(&mut self, time_now: Duration) {
        let timers = &mut self.timers;
        for (i, t) in timers.session_timers.iter_mut().enumerate() {
            // Saturating: a stamp ahead of `time_now` counts as fresh instead of panicking
            // on `Duration` underflow.
            if time_now.saturating_sub(*t) > REJECT_AFTER_TIME {
                if let Some(session) = self.sessions[i].take() {
                    log::trace!(
                        "SESSION_EXPIRED(REJECT_AFTER_TIME): {}",
                        session.receiving_index
                    );
                }
                *t = time_now;
            }
        }
    }

    /// Advances the tunnel clock, expires stale sessions and cookies, and decides whether a
    /// handshake initiation or a keepalive has to be sent.
    ///
    /// Returns `Ok(Some(packet))` with a handshake initiation or keepalive to transmit, or
    /// `Ok(None)` otherwise. A due handshake takes precedence over a keepalive.
    ///
    /// # Errors
    ///
    /// Returns [`WireGuardError::ConnectionExpired`] when the handshake is already marked
    /// expired, when no session has been established for `REJECT_AFTER_TIME * 3`, or when
    /// handshake attempts have been running for `REKEY_ATTEMPT_TIME`. In the latter two
    /// cases all sessions, queued packets and timers are cleared first.
    pub fn update_timers(&mut self) -> Result<Option<WgKind>, WireGuardError> {
        let mut handshake_initiation_required = false;
        let mut keepalive_required = false;
        self.rate_limiter.try_reset_count();

        // `Timers::now` is clamped so it never runs behind the previous `TimeCurrent`.
        let now = self.timers.now();
        self.timers[TimeCurrent] = now;
        self.update_session_timers(now);

        // Load the timers the checks below compare against. All elapsed-time computations
        // use `saturating_sub` so a timestamp ahead of `now` can never panic.
        let session_established = self.timers[TimeSessionEstablished];
        let handshake_started = self.timers[TimeLastHandshakeStarted];
        let aut_packet_sent = self.timers[TimeLastPacketSent];
        let data_packet_received = self.timers[TimeLastDataPacketReceived];
        let data_packet_sent = self.timers[TimeLastDataPacketSent];
        let persistent_keepalive = self.timers.persistent_keepalive;

        if self.handshake.is_expired() {
            return Err(WireGuardError::ConnectionExpired);
        }

        // Forget the cookie once it has outlived COOKIE_EXPIRATION_TIME.
        if self.handshake.has_cookie()
            && now.saturating_sub(self.timers[TimeCookieReceived]) >= COOKIE_EXPIRATION_TIME
        {
            self.handshake.clear_cookie();
        }

        // No new session for REJECT_AFTER_TIME * 3: give up on the connection.
        if now.saturating_sub(session_established) >= REJECT_AFTER_TIME * 3 {
            log::trace!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
            self.handshake.set_expired();
            self.clear_all();
            return Err(WireGuardError::ConnectionExpired);
        }

        if let Some(time_init_sent) = self.handshake.timer() {
            // Handshake attempts have run too long: give up on the connection.
            if now.saturating_sub(handshake_started) >= REKEY_ATTEMPT_TIME {
                log::debug!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
                self.handshake.set_expired();
                self.clear_all();
                return Err(WireGuardError::ConnectionExpired);
            }
            // Checked with `time_init_sent`'s own `elapsed()` rather than against `now`.
            if time_init_sent.elapsed() >= REKEY_TIMEOUT {
                log::debug!("HANDSHAKE(REKEY_TIMEOUT)");
                handshake_initiation_required = true;
            }
        } else {
            if self.timers.is_initiator() {
                // Initiator re-keys a REKEY_AFTER_TIME-old session once data was sent on it.
                if session_established < data_packet_sent
                    && now.saturating_sub(session_established) >= REKEY_AFTER_TIME
                {
                    log::trace!("HANDSHAKE(REKEY_AFTER_TIME (on send))");
                    handshake_initiation_required = true;
                }
                // ...and re-keys ahead of REJECT_AFTER_TIME once data was received on it.
                if session_established < data_packet_received
                    && now.saturating_sub(session_established)
                        >= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
                {
                    log::trace!(
                        "HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \
                         REKEY_TIMEOUT \
                         (on receive))"
                    );
                    handshake_initiation_required = true;
                }
            }

            // Data has been sent with no packet received for KEEPALIVE + REKEY_TIMEOUT.
            if let Some(since) = self.timers.want_handshake
                && now.saturating_sub(since) >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT
            {
                log::trace!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)");
                handshake_initiation_required = true;
                self.timers.want_handshake = None;
            }

            if !handshake_initiation_required {
                // Data was received but nothing has gone back for KEEPALIVE_TIMEOUT; the
                // `mem::replace` consumes the request so it fires only once.
                if data_packet_received > aut_packet_sent
                    && now.saturating_sub(aut_packet_sent) >= KEEPALIVE_TIMEOUT
                    && mem::replace(&mut self.timers.want_keepalive, false)
                {
                    log::trace!("KEEPALIVE(KEEPALIVE_TIMEOUT)");
                    keepalive_required = true;
                }
                // Persistent keepalive interval has elapsed.
                if persistent_keepalive > 0
                    && now.saturating_sub(self.timers[TimePersistentKeepalive])
                        >= Duration::from_secs(persistent_keepalive as _)
                {
                    log::trace!("KEEPALIVE(PERSISTENT_KEEPALIVE)");
                    self.timer_tick(TimePersistentKeepalive);
                    keepalive_required = true;
                }
            }
        }

        if handshake_initiation_required {
            return Ok(self.format_handshake_initiation(true).map(Into::into));
        }
        if keepalive_required {
            // The keepalive is an outgoing packet with an empty payload.
            return Ok(self
                .handle_outgoing_packet(crate::packet::Packet::from_bytes(BytesMut::new()), None));
        }
        Ok(None)
    }

    /// Time elapsed since the current session was established, or `None` when the current
    /// session slot is empty.
    pub fn time_since_last_handshake(&self) -> Option<Duration> {
        let current_session = self.current;
        if self.sessions[current_session % super::N_SESSIONS].is_some() {
            let duration_since_tun_start = self.timers.now();
            let duration_since_session_established = self.timers[TimeSessionEstablished];
            Some(duration_since_tun_start.saturating_sub(duration_since_session_established))
        } else {
            None
        }
    }

    /// The persistent keepalive interval in seconds, or `None` when it is disabled.
    pub fn persistent_keepalive(&self) -> Option<u16> {
        let keepalive = self.timers.persistent_keepalive;
        if keepalive > 0 {
            // Only ever stored from a `u16`, so the cast cannot truncate.
            Some(keepalive as u16)
        } else {
            None
        }
    }

    /// Sets the persistent keepalive interval in seconds. `None` or `Some(0)` disables it
    /// and resets the persistent keepalive timer to zero.
    pub fn set_persistent_keepalive(&mut self, seconds: Option<u16>) {
        self.timers.persistent_keepalive = usize::from(seconds.unwrap_or(0));
        if self.timers.persistent_keepalive == 0 {
            self.timers[TimePersistentKeepalive] = Duration::ZERO;
        }
    }
}