pub mod ports;
use super::{
ServiceHandler, Transport,
session::ports::{PortAllocator, PortRange},
};
use crate::codec::{crypto::Password, message::attributes::PasswordAlgorithm};
use std::{
hash::Hash,
net::SocketAddr,
ops::{Deref, DerefMut},
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
thread::{self, sleep},
time::Duration,
};
use ahash::{HashMap, HashMapExt};
use parking_lot::{Mutex, RwLock, RwLockReadGuard};
use rand::{Rng, distr::Alphanumeric};
/// Key under which a session is stored in the session table.
///
/// Equality and hashing are derived over all four fields, so two identifiers
/// name the same session only when every field matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Identifier {
// NOTE(review): presumably the remote (client) address the request came from — confirm.
pub source: SocketAddr,
// NOTE(review): presumably the externally visible address of the server — confirm.
pub external: SocketAddr,
// NOTE(review): presumably the local interface address that received the request — confirm.
pub interface: SocketAddr,
// Transport the session is associated with (see `Transport` in the parent module).
pub transport: Transport,
}
/// Newtype around the crate's `HashMap` whose `Default` pre-allocates room
/// for one entry per port in the default [`PortRange`].
///
/// The wrapper dereferences to the inner map, so every `HashMap` method is
/// available on a `Table` directly.
pub struct Table<K, V>(HashMap<K, V>);

impl<K, V> Default for Table<K, V> {
    /// Builds an empty table with capacity equal to the default port range size.
    fn default() -> Self {
        let capacity = PortRange::default().size();
        Self(HashMap::with_capacity(capacity))
    }
}

impl<K, V> AsRef<HashMap<K, V>> for Table<K, V> {
    /// Borrows the wrapped map.
    fn as_ref(&self) -> &HashMap<K, V> {
        let Self(map) = self;
        map
    }
}

impl<K, V> Deref for Table<K, V> {
    type Target = HashMap<K, V>;

    /// Gives shared access to the wrapped map.
    fn deref(&self) -> &Self::Target {
        let Self(map) = self;
        map
    }
}

impl<K, V> DerefMut for Table<K, V> {
    /// Gives exclusive access to the wrapped map.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(map) = self;
        map
    }
}
/// A held read guard paired with the key the caller is interested in.
///
/// The guard keeps the protected value readable for as long as this struct
/// lives; the lookup itself is deferred until [`ReadLock::get_ref`] is called.
pub struct ReadLock<'a, 'b, K, R> {
    /// Key to look up in the guarded value.
    pub key: &'a K,
    /// Read guard over the protected value.
    pub lock: RwLockReadGuard<'b, R>,
}

impl<'a, 'b, K, V> ReadLock<'a, 'b, K, Table<K, V>>
where
    K: Eq + Hash,
{
    /// Looks `key` up in the guarded table.
    ///
    /// Returns `None` when the table holds no entry for the key.
    pub fn get_ref(&self) -> Option<&V> {
        let table: &Table<K, V> = &self.lock;
        table.get(self.key)
    }
}
/// Monotonic tick counter backed by an atomic integer; starts at zero.
#[derive(Default)]
pub struct Timer(AtomicU64);

impl Timer {
    /// Returns the current tick count.
    pub fn get(&self) -> u64 {
        let Self(ticks) = self;
        ticks.load(Ordering::Relaxed)
    }

    /// Advances the counter by one tick and returns the new value.
    pub fn add(&self) -> u64 {
        // `fetch_add` hands back the value *before* the increment.
        let previous = self.0.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}
/// Lifetime given to a session when the caller supplies none, measured in
/// [`Timer`] ticks.
///
/// The background thread started by `SessionManager::new` advances the timer
/// once per loop iteration and sleeps one second between iterations, so one
/// tick is roughly one second.
pub const DEFAULT_SESSION_LIFETIME: u64 = 600;
/// State kept for one [`Identifier`].
///
/// A session starts as `New` (created by `get_session_or_default`) and is
/// replaced by `Authenticated` once `get_password` obtains a password from
/// the handler.
#[derive(Debug, Clone)]
pub enum Session {
// Session that exists but has not been authenticated yet.
New {
// Random nonce generated when the session was created.
nonce: String,
// Timer tick at or after which the background thread removes the session.
expires: u64,
},
// Session for which the handler has supplied a password.
Authenticated {
// Nonce carried over from the `New` session, or freshly generated if there was none.
nonce: String,
// Username passed to `get_password` when the session was authenticated.
username: String,
// Password returned by the handler; served from here on later lookups.
password: Password,
// Port handed out by `allocate`, or `None` while nothing is allocated.
allocated_port: Option<u16>,
// Filled by `create_permission`: port -> identifier of the session owning that port.
port_relay_table: HashMap< u16, Identifier>,
// Filled by `bind_channel`: channel number -> identifier of the session owning the bound port.
channel_relay_table: HashMap< u16, Identifier>,
// Timer tick at or after which the background thread removes the session.
expires: u64,
},
}
impl Session {
    /// Returns the nonce stored in the session, whichever state it is in.
    pub fn nonce(&self) -> &str {
        // Both variants carry a nonce, so this pattern is irrefutable.
        let (Session::New { nonce, .. } | Session::Authenticated { nonce, .. }) = self;
        nonce
    }

    /// Returns `true` while the session has not been authenticated.
    pub fn is_new(&self) -> bool {
        match self {
            Session::New { .. } => true,
            Session::Authenticated { .. } => false,
        }
    }

    /// Returns `true` once the session has been authenticated.
    pub fn is_authenticated(&self) -> bool {
        match self {
            Session::Authenticated { .. } => true,
            Session::New { .. } => false,
        }
    }
}
/// Construction parameters for [`SessionManager::new`].
pub struct SessionManagerOptions<T> {
// Range handed to the `PortAllocator` the manager creates.
pub port_range: PortRange,
// Handler stored in the manager; `SessionManager`'s methods require `T: ServiceHandler`.
pub handler: T,
}
/// Owns all sessions, the port allocator and the port ownership table.
///
/// Methods that need several locks take them in the order
/// `sessions` -> `port_allocator` -> `port_mapping_table`.
pub struct SessionManager<T> {
// Every known session, keyed by its identifier.
sessions: RwLock<Table<Identifier, Session>>,
// Allocator that hands out and reclaims ports.
port_allocator: Mutex<PortAllocator>,
// Allocated port -> identifier of the session that owns it (written by `allocate`).
port_mapping_table: RwLock<Table< u16, Identifier>>,
// Tick counter advanced by the background expiry thread.
timer: Timer,
// Handler used to fetch passwords and notified when an authenticated session is destroyed.
handler: T,
}
impl<T> SessionManager<T>
where
T: ServiceHandler,
{
pub fn new(options: SessionManagerOptions<T>) -> Arc<Self> {
let this = Arc::new(Self {
port_allocator: Mutex::new(PortAllocator::new(options.port_range)),
port_mapping_table: RwLock::new(Table::default()),
sessions: RwLock::new(Table::default()),
timer: Timer::default(),
handler: options.handler,
});
let this_ = Arc::downgrade(&this);
thread::spawn(move || {
let mut identifiers = Vec::with_capacity(255);
while let Some(this) = this_.upgrade() {
let now = this.timer.add();
{
{
this.sessions
.read()
.iter()
.filter(|(_, v)| match v {
Session::New { expires, .. }
| Session::Authenticated { expires, .. } => *expires <= now,
})
.for_each(|(k, _)| identifiers.push(*k));
}
if !identifiers.is_empty() {
this.remove_session(&identifiers);
identifiers.clear();
}
}
sleep(Duration::from_secs(1));
}
});
this
}
/// Removes the given sessions and releases whatever they held.
///
/// For an authenticated session the allocated port (if any) is dropped from
/// the port mapping table and returned to the allocator, and the handler's
/// `on_destroy` is called with the identifier and username. Sessions that
/// were still `New` are removed silently; unknown identifiers are ignored.
fn remove_session(&self, identifiers: &[Identifier]) {
    // All three locks are held for the whole batch, taken in the usual order.
    let mut sessions = self.sessions.write();
    let mut port_allocator = self.port_allocator.lock();
    let mut port_mapping_table = self.port_mapping_table.write();

    for identifier in identifiers {
        let Some(Session::Authenticated {
            allocated_port,
            username,
            ..
        }) = sessions.remove(identifier)
        else {
            continue;
        };

        if let Some(port) = allocated_port {
            port_mapping_table.remove(&port);
            port_allocator.deallocate(port);
        }

        self.handler.on_destroy(identifier, &username);
    }
}
pub fn get_session_or_default<'a, 'b>(
&'a self,
key: &'b Identifier,
) -> ReadLock<'b, 'a, Identifier, Table<Identifier, Session>> {
{
let lock = self.sessions.read();
if lock.contains_key(key) {
return ReadLock { lock, key };
}
}
{
self.sessions.write().insert(
*key,
Session::New {
nonce: generate_nonce(),
expires: self.timer.get() + DEFAULT_SESSION_LIFETIME,
},
);
}
ReadLock {
lock: self.sessions.read(),
key,
}
}
/// Returns a read lock over the session table together with `key`.
///
/// No session is created; `get_ref` on the result yields `None` when there
/// is no session for `key`.
pub fn get_session<'a, 'b>(
    &'a self,
    key: &'b Identifier,
) -> ReadLock<'b, 'a, Identifier, Table<Identifier, Session>> {
    let lock = self.sessions.read();
    ReadLock { key, lock }
}
/// Returns the password for the session, authenticating it if necessary.
///
/// If the session is already authenticated its stored password is returned
/// without consulting the handler. Otherwise the handler is asked for the
/// password; on success the session becomes `Session::Authenticated` (keeping
/// the nonce of an existing `New` session, or generating one) with the
/// default lifetime and no allocation.
///
/// Returns `None` when the handler does not provide a password.
pub async fn get_password(
    &self,
    identifier: &Identifier,
    username: &str,
    algorithm: PasswordAlgorithm,
) -> Option<Password> {
    // Fast path. The guard is confined to this block so it is released
    // before the await below.
    {
        if let Some(Session::Authenticated { password, .. }) =
            self.sessions.read().get(identifier)
        {
            return Some(*password);
        }
    }

    let password = self
        .handler
        .get_password(identifier, username, algorithm)
        .await?;

    {
        let mut sessions = self.sessions.write();

        // The session may have been authenticated by a concurrent call while
        // we were awaiting the handler. Replacing it would discard its
        // allocated port and relay tables, so keep it and return what it
        // stores, exactly as the fast path would.
        if let Some(Session::Authenticated { password, .. }) = sessions.get(identifier) {
            return Some(*password);
        }

        let nonce = match sessions.remove(identifier) {
            Some(Session::New { nonce, .. }) => nonce,
            _ => generate_nonce(),
        };

        sessions.insert(
            *identifier,
            Session::Authenticated {
                port_relay_table: HashMap::with_capacity(10),
                channel_relay_table: HashMap::with_capacity(10),
                expires: self.timer.get() + DEFAULT_SESSION_LIFETIME,
                username: username.to_string(),
                allocated_port: None,
                password,
                nonce,
            },
        );
    }

    Some(password)
}
/// Returns the value reported by the port allocator's `len()`.
pub fn allocated(&self) -> usize {
    let allocator = self.port_allocator.lock();
    allocator.len()
}
/// Allocates a port for an authenticated session.
///
/// If the session already has a port, that port is returned and nothing else
/// changes. Otherwise a port is taken from the allocator, recorded on the
/// session and in the port mapping table, and the session's expiry is set to
/// the current tick plus `lifetime` (or the default lifetime when `None`).
///
/// Returns `None` when the session is missing or not authenticated, or when
/// the allocator has no port to give.
pub fn allocate(&self, identifier: &Identifier, lifetime: Option<u32>) -> Option<u16> {
    let mut sessions = self.sessions.write();

    let Some(Session::Authenticated {
        allocated_port,
        expires,
        ..
    }) = sessions.get_mut(identifier)
    else {
        return None;
    };

    if let Some(existing) = *allocated_port {
        return Some(existing);
    }

    let port = self.port_allocator.lock().allocate(None)?;
    *allocated_port = Some(port);

    let ttl = lifetime.map_or(DEFAULT_SESSION_LIFETIME, u64::from);
    *expires = self.timer.get() + ttl;

    self.port_mapping_table.write().insert(port, *identifier);
    Some(port)
}
/// Installs permissions from this session towards the sessions that own
/// `ports`.
///
/// The request is accepted only if all of the following hold:
/// - the session is authenticated and has an allocated port;
/// - `ports` does not contain the session's own allocated port;
/// - every port is currently owned by some session (per the port mapping
///   table);
/// - no port already has a permission pointing at a different owner.
///
/// The operation is all-or-nothing: every port is validated before any
/// permission is written, so a rejected request leaves the session
/// untouched. Returns `true` when the permissions were installed.
pub fn create_permission(&self, identifier: &Identifier, ports: &[u16]) -> bool {
    let mut sessions = self.sessions.write();

    let Some(Session::Authenticated {
        allocated_port,
        port_relay_table,
        ..
    }) = sessions.get_mut(identifier)
    else {
        return false;
    };

    let Some(local_port) = *allocated_port else {
        return false;
    };

    // A session may not create a permission towards itself.
    if ports.contains(&local_port) {
        return false;
    }

    // One read lock for the whole request instead of one per port.
    let port_mapping_table = self.port_mapping_table.read();

    // Validation pass: resolve every port to its owner without mutating.
    let mut resolved = Vec::with_capacity(ports.len());
    for port in ports {
        let Some(peer) = port_mapping_table.get(port) else {
            return false;
        };

        // The port is already bound to a different peer.
        if port_relay_table
            .get(port)
            .is_some_and(|relay| relay != peer)
        {
            return false;
        }

        resolved.push((*port, *peer));
    }

    // Commit pass: everything checked out, install all permissions.
    port_relay_table.extend(resolved);
    true
}
/// Binds `channel` of this session to the session that owns `port`, and
/// installs the matching permission for `port`.
///
/// The bind is accepted only if all of the following hold:
/// - the session is authenticated;
/// - `port` is currently owned by some session (per the port mapping table);
/// - `channel` is not already bound to a different owner;
/// - the permission rules of `create_permission` hold for `port`: the
///   session has an allocated port, `port` is not that port, and `port` has
///   no existing permission pointing at a different owner.
///
/// Everything is validated under a single write lock before anything is
/// written, so a rejected bind leaves neither a channel binding nor a
/// permission behind. Returns `true` when the binding was installed.
pub fn bind_channel(&self, identifier: &Identifier, port: u16, channel: u16) -> bool {
    let mut sessions = self.sessions.write();

    let Some(Session::Authenticated {
        allocated_port,
        port_relay_table,
        channel_relay_table,
        ..
    }) = sessions.get_mut(identifier)
    else {
        return false;
    };

    let Some(peer) = self.port_mapping_table.read().get(&port).copied() else {
        return false;
    };

    // The channel is already bound to a different peer.
    if channel_relay_table
        .get(&channel)
        .is_some_and(|relay| *relay != peer)
    {
        return false;
    }

    // Permission checks, mirroring `create_permission` for a single port.
    let Some(local_port) = *allocated_port else {
        return false;
    };

    if port == local_port {
        return false;
    }

    if port_relay_table
        .get(&port)
        .is_some_and(|relay| *relay != peer)
    {
        return false;
    }

    // Commit: both tables are updated together.
    channel_relay_table.insert(channel, peer);
    port_relay_table.insert(port, peer);
    true
}
/// Resolves where data sent on `channel` by this session should go.
///
/// Looks up the peer bound to `channel` in this session, then searches the
/// peer's own channel table for a channel bound back to `identifier`.
///
/// Returns the peer's channel number and the peer's identifier, or `None`
/// when either session is missing or not authenticated, the channel is not
/// bound, or the peer has no channel bound back to this session.
pub fn get_channel_relay_address(
    &self,
    identifier: &Identifier,
    channel: u16,
) -> Option<(/* peer channel */ u16, Identifier)> {
    let sessions = self.sessions.read();

    let Session::Authenticated {
        channel_relay_table: local_channels,
        ..
    } = sessions.get(identifier)?
    else {
        return None;
    };

    let peer = local_channels.get(&channel)?;

    let Session::Authenticated {
        channel_relay_table: peer_channels,
        ..
    } = sessions.get(peer)?
    else {
        return None;
    };

    // Reverse lookup: which of the peer's channels points back at us?
    peer_channels
        .iter()
        .find(|(_, bound)| *bound == identifier)
        .map(|(peer_channel, _)| (*peer_channel, *peer))
}
/// Resolves the peer this session has a permission for on `port`.
///
/// Returns this session's own allocated port together with the peer's
/// identifier, or `None` when the session is missing or not authenticated,
/// has no allocated port, or has no permission recorded for `port`.
pub fn get_port_relay_address(
    &self,
    identifier: &Identifier,
    port: u16,
) -> Option<(/* local port */ u16, Identifier)> {
    let sessions = self.sessions.read();

    match sessions.get(identifier)? {
        Session::Authenticated {
            port_relay_table,
            allocated_port,
            ..
        } => {
            let local_port = (*allocated_port)?;
            let peer = port_relay_table.get(&port).copied()?;
            Some((local_port, peer))
        }
        Session::New { .. } => None,
    }
}
/// Refreshes or terminates a session.
///
/// - `lifetime == 0` removes the session and returns `true`, whether or not
///   a session existed.
/// - `lifetime` in `1..=3600` sets the expiry of an authenticated session to
///   the current tick plus `lifetime`; returns `false` if the session is
///   missing or not authenticated.
/// - `lifetime > 3600` is rejected with `false`.
pub fn refresh(&self, identifier: &Identifier, lifetime: u32) -> bool {
    match lifetime {
        0 => {
            self.remove_session(&[*identifier]);
            true
        }
        1..=3600 => {
            let mut sessions = self.sessions.write();
            match sessions.get_mut(identifier) {
                Some(Session::Authenticated { expires, .. }) => {
                    *expires = self.timer.get() + u64::from(lifetime);
                    true
                }
                Some(Session::New { .. }) | None => false,
            }
        }
        _ => false,
    }
}
}
/// Produces a random 16-character ASCII alphanumeric string.
fn generate_nonce() -> String {
    const NONCE_LEN: usize = 16;

    let mut nonce = String::with_capacity(NONCE_LEN);
    nonce.extend(
        rand::rng()
            .sample_iter(&Alphanumeric)
            .take(NONCE_LEN)
            .map(char::from),
    );
    nonce
}