pub mod ports;
use super::{
ServiceHandler,
session::ports::{PortAllocator, PortRange},
};
use crate::codec::{crypto::Password, message::attributes::PasswordAlgorithm};
use std::{
hash::Hash,
net::SocketAddr,
ops::{Deref, DerefMut},
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
thread::{self, sleep},
time::Duration,
};
use ahash::{HashMap, HashMapExt};
use parking_lot::{Mutex, RwLock, RwLockReadGuard};
use rand::{Rng, distr::Alphanumeric};
/// Key identifying a client session: the client's source address paired with
/// a local interface address (presumably the one the traffic arrived on —
/// confirm against callers).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Identifier {
    source: SocketAddr,
    interface: SocketAddr,
}

impl Identifier {
    /// Builds an identifier from a source address and an interface address.
    #[inline]
    pub fn new(source: SocketAddr, interface: SocketAddr) -> Self {
        Identifier { source, interface }
    }

    /// Returns the source address.
    #[inline]
    pub fn source(&self) -> SocketAddr {
        let Identifier { source, .. } = *self;
        source
    }

    /// Returns the interface address.
    #[inline]
    pub fn interface(&self) -> SocketAddr {
        let Identifier { interface, .. } = *self;
        interface
    }

    /// Returns a mutable reference to the source address.
    #[inline]
    pub fn source_mut(&mut self) -> &mut SocketAddr {
        let Identifier { source, .. } = self;
        source
    }

    /// Returns a mutable reference to the interface address.
    #[inline]
    pub fn interface_mut(&mut self) -> &mut SocketAddr {
        let Identifier { interface, .. } = self;
        interface
    }
}
/// A pair of addresses stored in the relay tables: a session's `source`
/// address and the `endpoint` address supplied when a permission or channel
/// was created for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Endpoint {
    source: SocketAddr,
    endpoint: SocketAddr,
}

impl Endpoint {
    /// Builds an endpoint record from its two addresses.
    #[inline]
    pub fn new(source: SocketAddr, endpoint: SocketAddr) -> Self {
        Endpoint { source, endpoint }
    }

    /// Returns the source address.
    #[inline]
    pub fn source(&self) -> SocketAddr {
        let Endpoint { source, .. } = *self;
        source
    }

    /// Returns the endpoint address.
    #[inline]
    pub fn endpoint(&self) -> SocketAddr {
        let Endpoint { endpoint, .. } = *self;
        endpoint
    }

    /// Returns a mutable reference to the source address.
    #[inline]
    pub fn source_mut(&mut self) -> &mut SocketAddr {
        let Endpoint { source, .. } = self;
        source
    }

    /// Returns a mutable reference to the endpoint address.
    #[inline]
    pub fn endpoint_mut(&mut self) -> &mut SocketAddr {
        let Endpoint { endpoint, .. } = self;
        endpoint
    }
}
/// A hash-map newtype that pre-sizes itself for the default port range.
///
/// Dereferences (mutably and immutably) to the inner [`HashMap`], so every
/// map method is callable directly on a `Table`.
pub struct Table<K, V>(HashMap<K, V>);

impl<K, V> Default for Table<K, V> {
    /// Creates an empty table with capacity `PortRange::default().size()`.
    fn default() -> Self {
        let capacity = PortRange::default().size();
        Table(HashMap::with_capacity(capacity))
    }
}

impl<K, V> AsRef<HashMap<K, V>> for Table<K, V> {
    fn as_ref(&self) -> &HashMap<K, V> {
        let Table(inner) = self;
        inner
    }
}

impl<K, V> Deref for Table<K, V> {
    type Target = HashMap<K, V>;

    fn deref(&self) -> &Self::Target {
        let Table(inner) = self;
        inner
    }
}

impl<K, V> DerefMut for Table<K, V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Table(inner) = self;
        inner
    }
}
/// A held read guard bundled with the key to look up in the guarded table.
///
/// The lock stays held for as long as this value (and any reference obtained
/// from [`ReadLock::get_ref`]) is alive.
pub struct ReadLock<'a, 'b, K, R> {
    pub key: &'a K,
    pub lock: RwLockReadGuard<'b, R>,
}

impl<K, V> ReadLock<'_, '_, K, Table<K, V>>
where
    K: Eq + Hash,
{
    /// Looks up `key` in the locked table, returning `None` when it is absent.
    pub fn get_ref(&self) -> Option<&V> {
        let table: &Table<K, V> = &self.lock;
        table.get(self.key)
    }
}
/// Shared tick counter.
///
/// `SessionManager`'s background thread advances it once per second, so one
/// tick is roughly one second. `Relaxed` ordering is used: only the counter
/// value itself is meaningful.
#[derive(Default)]
pub struct Timer(AtomicU64);

impl Timer {
    /// Returns the current tick.
    pub fn get(&self) -> u64 {
        let Timer(ticks) = self;
        ticks.load(Ordering::Relaxed)
    }

    /// Advances the counter by one and returns the new tick value.
    pub fn add(&self) -> u64 {
        let Timer(ticks) = self;
        let previous = ticks.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}

/// Default session / allocation lifetime, in timer ticks (seconds).
pub const DEFAULT_SESSION_LIFETIME: u64 = 600;
/// Per-client session state.
#[derive(Debug, Clone)]
pub enum Session {
    /// A session that has been issued a nonce but has not authenticated yet.
    New {
        nonce: String,
        /// Timer tick at or after which the session is swept.
        expires: u64,
    },
    /// A session whose credentials were obtained from the service handler.
    Authenticated {
        nonce: String,
        username: String,
        password: Password,
        /// Relay port assigned by `SessionManager::allocate`, if any.
        allocate_port: Option<u16>,
        /// Channel numbers recorded by `SessionManager::bind_channel`.
        allocate_channels: Vec<u16>,
        /// Peer relay ports this session created permissions for.
        permissions: Vec<u16>,
        /// Timer tick at or after which the session is swept.
        expires: u64,
    },
}

impl Session {
    /// Returns the nonce issued to this session, whatever its state.
    pub fn nonce(&self) -> &str {
        match self {
            Session::New { nonce, .. } => nonce,
            Session::Authenticated { nonce, .. } => nonce,
        }
    }

    /// `true` while the session has not authenticated.
    pub fn is_new(&self) -> bool {
        match self {
            Session::New { .. } => true,
            Session::Authenticated { .. } => false,
        }
    }

    /// `true` once the session has authenticated.
    pub fn is_authenticated(&self) -> bool {
        !self.is_new()
    }
}
/// Construction parameters for [`SessionManager::new`].
pub struct SessionManagerOptions<T> {
/// Range of relay ports handed to the manager's `PortAllocator`.
pub port_range: PortRange,
/// Service callbacks used for credential lookup (`get_password`) and
/// session teardown notification (`on_destroy`).
pub handler: T,
}
/// Tracks client sessions together with their relay ports, permissions and
/// channel bindings.
///
/// Lock order: every method that holds more than one of these locks at once
/// acquires `sessions` first.
pub struct SessionManager<T> {
/// All known sessions, keyed by client identifier.
sessions: RwLock<Table<Identifier, Session>>,
/// Source of relay ports for `allocate`; ports are returned on session removal.
port_allocator: Mutex<PortAllocator>,
/// Relay port -> identifier of the session that allocated it.
port_mapping_table: RwLock<Table< u16, Identifier>>,
/// Keyed by peer identifier: relay port of a session that created a
/// permission towards that peer -> that session's endpoint record.
port_relay_table: RwLock<Table<Identifier, HashMap< u16, Endpoint>>>,
/// Keyed by peer identifier: channel number -> endpoint record of the
/// session that bound the channel.
channel_relay_table: RwLock<Table<Identifier, HashMap< u16, Endpoint>>>,
/// Tick counter advanced once per second by the background expiry thread.
timer: Timer,
/// Service callbacks supplied through `SessionManagerOptions`.
handler: T,
}
impl<T> SessionManager<T>
where
T: ServiceHandler,
{
pub fn new(options: SessionManagerOptions<T>) -> Arc<Self> {
let this = Arc::new(Self {
port_allocator: Mutex::new(PortAllocator::new(options.port_range)),
channel_relay_table: RwLock::new(Table::default()),
port_mapping_table: RwLock::new(Table::default()),
port_relay_table: RwLock::new(Table::default()),
sessions: RwLock::new(Table::default()),
timer: Timer::default(),
handler: options.handler,
});
let this_ = Arc::downgrade(&this);
thread::spawn(move || {
let mut identifiers = Vec::with_capacity(255);
while let Some(this) = this_.upgrade() {
let now = this.timer.add();
{
{
this.sessions
.read()
.iter()
.filter(|(_, v)| match v {
Session::New { expires, .. }
| Session::Authenticated { expires, .. } => *expires <= now,
})
.for_each(|(k, _)| identifiers.push(*k));
}
if !identifiers.is_empty() {
this.remove_session(&identifiers);
identifiers.clear();
}
}
sleep(Duration::from_secs(1));
}
});
this
}
fn remove_session(&self, identifiers: &[Identifier]) {
let mut sessions = self.sessions.write();
let mut port_allocator = self.port_allocator.lock();
let mut port_mapping_table = self.port_mapping_table.write();
let mut port_relay_table = self.port_relay_table.write();
let mut channel_relay_table = self.channel_relay_table.write();
identifiers.iter().for_each(|k| {
port_relay_table.remove(k);
channel_relay_table.remove(k);
if let Some(Session::Authenticated {
allocate_port,
username,
..
}) = sessions.remove(k)
{
if let Some(port) = allocate_port {
port_mapping_table.remove(&port);
port_allocator.deallocate(port);
}
self.handler.on_destroy(k, &username);
}
});
}
pub fn get_session_or_default<'a, 'b>(
&'a self,
key: &'b Identifier,
) -> ReadLock<'b, 'a, Identifier, Table<Identifier, Session>> {
{
let lock = self.sessions.read();
if lock.contains_key(key) {
return ReadLock {
lock: self.sessions.read(),
key,
};
}
}
{
self.sessions.write().insert(
*key,
Session::New {
nonce: generate_nonce(),
expires: self.timer.get() + DEFAULT_SESSION_LIFETIME,
},
);
}
ReadLock {
lock: self.sessions.read(),
key,
}
}
pub fn get_session<'a, 'b>(
&'a self,
key: &'b Identifier,
) -> ReadLock<'b, 'a, Identifier, Table<Identifier, Session>> {
ReadLock {
lock: self.sessions.read(),
key,
}
}
pub async fn get_password(
&self,
identifier: &Identifier,
username: &str,
algorithm: PasswordAlgorithm,
) -> Option<Password> {
{
if let Some(Session::Authenticated { password, .. }) =
self.sessions.read().get(identifier)
{
return Some(*password);
}
}
let password = self.handler.get_password(username, algorithm).await?;
{
let mut lock = self.sessions.write();
let nonce = if let Some(Session::New { nonce, .. }) = lock.remove(identifier) {
nonce
} else {
generate_nonce()
};
lock.insert(
*identifier,
Session::Authenticated {
allocate_channels: Vec::with_capacity(10),
permissions: Vec::with_capacity(10),
expires: self.timer.get() + DEFAULT_SESSION_LIFETIME,
username: username.to_string(),
allocate_port: None,
password,
nonce,
},
);
}
Some(password)
}
pub fn allocated(&self) -> usize {
self.port_allocator.lock().len()
}
pub fn allocate(&self, identifier: &Identifier, lifetime: Option<u32>) -> Option<u16> {
let mut lock = self.sessions.write();
if let Some(Session::Authenticated {
allocate_port,
expires,
..
}) = lock.get_mut(identifier)
{
if let Some(port) = allocate_port {
return Some(*port);
}
let port = self.port_allocator.lock().allocate(None)?;
*allocate_port = Some(port);
*expires =
self.timer.get() + (lifetime.unwrap_or(DEFAULT_SESSION_LIFETIME as u32) as u64);
self.port_mapping_table.write().insert(port, *identifier);
Some(port)
} else {
None
}
}
pub fn create_permission(
&self,
identifier: &Identifier,
endpoint: &SocketAddr,
ports: &[u16],
) -> bool {
let mut sessions = self.sessions.write();
let mut port_relay_table = self.port_relay_table.write();
let port_mapping_table = self.port_mapping_table.read();
if let Some(Session::Authenticated {
allocate_port,
permissions,
..
}) = sessions.get_mut(identifier)
{
let local_port = if let Some(it) = allocate_port {
*it
} else {
return false;
};
if ports.contains(&local_port) {
return false;
}
let mut peers = Vec::with_capacity(15);
for port in ports {
if let Some(it) = port_mapping_table.get(port) {
peers.push((it, *port));
} else {
return false;
}
}
for (peer, port) in peers {
port_relay_table
.entry(*peer)
.or_insert_with(|| HashMap::with_capacity(20))
.insert(
local_port,
Endpoint {
source: identifier.source,
endpoint: *endpoint,
},
);
if !permissions.contains(&port) {
permissions.push(port);
}
}
true
} else {
false
}
}
pub fn bind_channel(
&self,
identifier: &Identifier,
endpoint: &SocketAddr,
port: u16,
channel: u16,
) -> bool {
let peer = if let Some(it) = self.port_mapping_table.read().get(&port) {
*it
} else {
return false;
};
{
let mut lock = self.sessions.write();
if let Some(Session::Authenticated {
allocate_channels, ..
}) = lock.get_mut(identifier)
{
if !allocate_channels.contains(&channel) {
allocate_channels.push(channel);
}
} else {
return false;
};
}
if !self.create_permission(identifier, endpoint, &[port]) {
return false;
}
self.channel_relay_table
.write()
.entry(peer)
.or_insert_with(|| HashMap::with_capacity(10))
.insert(
channel,
Endpoint {
source: identifier.source,
endpoint: *endpoint,
},
);
true
}
pub fn get_channel_relay_address(
&self,
identifier: &Identifier,
channel: u16,
) -> Option<Endpoint> {
self.channel_relay_table
.read()
.get(identifier)?
.get(&channel)
.copied()
}
pub fn get_relay_address(&self, identifier: &Identifier, port: u16) -> Option<Endpoint> {
self.port_relay_table
.read()
.get(identifier)?
.get(&port)
.copied()
}
pub fn refresh(&self, identifier: &Identifier, lifetime: u32) -> bool {
if lifetime > 3600 {
return false;
}
if lifetime == 0 {
self.remove_session(&[*identifier]);
} else if let Some(Session::Authenticated { expires, .. }) =
self.sessions.write().get_mut(identifier)
{
*expires = self.timer.get() + lifetime as u64;
} else {
return false;
}
true
}
}
/// Generates a random 16-character nonce drawn from `rand`'s `Alphanumeric`
/// distribution (`[A-Za-z0-9]`), using the thread-local RNG.
fn generate_nonce() -> String {
    const NONCE_LEN: usize = 16;
    let mut rng = rand::rng();
    (0..NONCE_LEN)
        .map(|_| {
            let byte: u8 = rng.sample(Alphanumeric);
            char::from(byte)
        })
        .collect()
}