use std::{net::SocketAddr, sync::Arc};
use bytes::BytesMut;
use rand::seq::IteratorRandom;
use super::{
InterfaceAddr, Service, ServiceHandler,
session::{DEFAULT_SESSION_LIFETIME, Identifier, SessionManager},
};
use crate::{
codec::{
DecodeResult, Decoder,
channel_data::ChannelData,
crypto::Password,
message::{
Message, MessageEncoder,
attributes::{address::IpAddrExt, error::ErrorType, *},
methods::*,
},
},
service::Transport,
};
/// One decoded inbound packet, bundled with everything a handler needs to
/// answer it.
///
/// `M` is the decoded payload type (a STUN [`Message`] or a [`ChannelData`]
/// frame). Handlers encode any reply into `response_buffer`.
struct Request<'a, 'b, T: ServiceHandler, M> {
    /// Destination buffer for the encoded reply.
    response_buffer: &'b mut BytesMut,
    /// State of the router that received the packet.
    state: &'a RouterState<T>,
    /// The decoded packet being handled.
    payload: &'a M,
}
impl<'a, T: ServiceHandler> Request<'a, '_, T, Message<'a>> {
    /// Returns `true` when `address` uses the external IP of any configured
    /// interface. Only the IP is compared; port and transport are ignored.
    #[inline(always)]
    fn verify_ip(&self, address: &SocketAddr) -> bool {
        let wanted = address.ip();
        self.state
            .interfaces
            .iter()
            .map(|iface| iface.external.ip())
            .any(|candidate| candidate == wanted)
    }

    /// Authenticates the request against the session manager.
    ///
    /// Reads USERNAME (no username means `None`) and PASSWORD-ALGORITHM
    /// (MD5 when absent), then asks the manager for the matching password.
    /// It then checks the message against that password, presumably via its
    /// MESSAGE-INTEGRITY attribute.
    ///
    /// Returns the username and password on success, or `None` if any step
    /// fails.
    #[inline(always)]
    async fn verify(&self) -> Option<(&str, Password)> {
        let username = self.payload.get::<UserName>()?;
        let algorithm = match self.payload.get::<PasswordAlgorithm>() {
            Some(requested) => requested,
            None => PasswordAlgorithm::Md5,
        };

        let password = self
            .state
            .manager
            .get_password(&self.state.id, username, algorithm)
            .await?;

        // A failed integrity check is reported the same way as a missing user.
        self.payload.verify(&password).ok()?;
        Some((username, password))
    }
}
/// Outcome of routing one inbound packet. Any encoded packet lives in the
/// response buffer that was passed to [`Router::route`].
#[derive(Debug)]
pub struct RouteResult {
    /// Method of the STUN message written to the response buffer, or `None`
    /// when a ChannelData frame was written instead.
    pub method: Option<Method>,
    /// Session the written packet should be relayed to. Set only for Data
    /// indications and ChannelData frames; `None` for direct responses to
    /// the requester.
    pub relay: Option<Identifier>,
}
/// Immutable context shared by every request handled by one [`Router`].
pub(crate) struct RouterState<T: ServiceHandler> {
    /// Identifier of the session this router serves. Handlers read its
    /// `source`, `external` and `transport` fields.
    pub id: Identifier,
    /// Realm advertised in Unauthorized error responses.
    pub realm: String,
    /// Value written into SOFTWARE attributes.
    pub software: String,
    /// Session manager, shared with the owning service.
    pub manager: Arc<SessionManager<T>>,
    /// Interfaces used to pick relay addresses and to validate peer addresses.
    pub interfaces: Arc<Vec<InterfaceAddr>>,
    /// Callbacks notified on allocate, permission, channel-bind and refresh.
    pub handler: T,
}
/// Per-session packet router.
///
/// Decodes inbound STUN messages and ChannelData frames and dispatches them
/// to the method-specific handlers in this module.
pub struct Router<T: ServiceHandler> {
    /// Context handed to every request.
    state: RouterState<T>,
    /// Decoder for inbound packets, kept across calls to `route`.
    decoder: Decoder,
}
impl<T> Router<T>
where
    T: ServiceHandler + Clone,
{
    /// Builds a router for the session `id`, cloning the shared configuration
    /// and handles out of `service`.
    pub fn new(service: &Service<T>, id: Identifier) -> Self {
        let decoder = Decoder::default();
        let state = RouterState {
            interfaces: service.interfaces.clone(),
            software: service.software.clone(),
            handler: service.handler.clone(),
            manager: service.manager.clone(),
            realm: service.realm.clone(),
            id,
        };

        Self { state, decoder }
    }

    /// Decodes `bytes` and dispatches the packet to its handler. Any reply or
    /// relayed packet is encoded into `response_buffer`.
    ///
    /// Returns `Ok(None)` for unhandled methods or when a handler produced no
    /// output.
    ///
    /// # Errors
    ///
    /// Propagates the decoder's error when `bytes` cannot be decoded.
    pub async fn route(
        &mut self,
        bytes: &[u8],
        response_buffer: &mut BytesMut,
    ) -> Result<Option<RouteResult>, crate::codec::Error> {
        let state = &self.state;

        let routed = match self.decoder.decode(bytes)? {
            DecodeResult::Message(message) => {
                let req = Request {
                    response_buffer,
                    state,
                    payload: &message,
                };

                match message.method() {
                    BINDING_REQUEST => binding(req),
                    ALLOCATE_REQUEST => allocate(req).await,
                    CREATE_PERMISSION_REQUEST => create_permission(req).await,
                    CHANNEL_BIND_REQUEST => channel_bind(req).await,
                    REFRESH_REQUEST => refresh(req).await,
                    SEND_INDICATION => indication(req),
                    _ => None,
                }
            }
            DecodeResult::ChannelData(frame) => channel_data(Request {
                response_buffer,
                state,
                payload: &frame,
            }),
        };

        Ok(routed)
    }
}
fn reject<T>(req: Request<'_, '_, T, Message<'_>>, error: ErrorType) -> Option<RouteResult>
where
T: ServiceHandler,
{
let method = req.payload.method().error()?;
{
let mut message = MessageEncoder::extend(method, req.payload, req.response_buffer);
message.append::<ErrorCode>(ErrorCode::from(error));
if error == ErrorType::Unauthorized {
message.append::<Realm>(&req.state.realm);
message.append::<Nonce>(
req.state
.manager
.get_session_or_default(&req.state.id)
.get_ref()?
.nonce(),
);
message.append::<PasswordAlgorithms>(vec![
PasswordAlgorithm::Md5,
PasswordAlgorithm::Sha256,
]);
}
message.flush(None).ok()?;
}
Some(RouteResult {
method: Some(method),
relay: None,
})
}
fn binding<T>(req: Request<'_, '_, T, Message<'_>>) -> Option<RouteResult>
where
T: ServiceHandler,
{
{
let mut message =
MessageEncoder::extend(BINDING_RESPONSE, req.payload, req.response_buffer);
message.append::<XorMappedAddress>(req.state.id.source);
message.append::<MappedAddress>(req.state.id.source);
message.append::<ResponseOrigin>(req.state.id.external);
message.append::<Software>(&req.state.software);
message.flush(None).ok()?;
}
Some(RouteResult {
method: Some(BINDING_RESPONSE),
relay: None,
})
}
async fn allocate<T>(req: Request<'_, '_, T, Message<'_>>) -> Option<RouteResult>
where
T: ServiceHandler,
{
let xor_relayed_ip = {
let mut ip = req.state.id.external.ip();
let request_transport = if let Some(it) = req.payload.get::<RequestedTransport>() {
match it {
RequestedTransport::Tcp => Transport::Tcp,
RequestedTransport::Udp => Transport::Udp,
}
} else {
return reject(req, ErrorType::BadRequest);
};
let request_family = req
.payload
.get::<RequestedAddressFamily>()
.unwrap_or_else(|| ip.family());
if request_transport != req.state.id.transport || request_family != ip.family() {
if let Some(addr) = req
.state
.interfaces
.iter()
.filter(|addr| {
addr.transport == request_transport
&& addr.external.ip().family() == request_family
})
.choose(&mut rand::rng())
{
ip = addr.external.ip();
} else {
return reject(
req,
if request_family != ip.family() {
ErrorType::AddressFamilyNotSupported
} else {
ErrorType::UnsupportedTransportAddress
},
);
}
}
ip
};
let Some((username, password)) = req.verify().await else {
return reject(req, ErrorType::Unauthorized);
};
let lifetime = req.payload.get::<Lifetime>();
let Some(port) = req.state.manager.allocate(&req.state.id, lifetime) else {
return reject(req, ErrorType::AllocationQuotaReached);
};
req.state
.handler
.on_allocated(&req.state.id, username, port);
{
let mut message =
MessageEncoder::extend(ALLOCATE_RESPONSE, req.payload, req.response_buffer);
message.append::<XorRelayedAddress>(SocketAddr::new(xor_relayed_ip, port));
message.append::<XorMappedAddress>(req.state.id.source);
message.append::<Lifetime>(lifetime.unwrap_or(DEFAULT_SESSION_LIFETIME as u32));
message.append::<Software>(&req.state.software);
message.flush(Some(&password)).ok()?;
}
Some(RouteResult {
method: Some(ALLOCATE_RESPONSE),
relay: None,
})
}
/// Handles a CreatePermission request.
///
/// Steps:
/// 1. Authenticate the request (`Unauthorized` on failure).
/// 2. Check every XOR-PEER-ADDRESS uses the external IP of a configured
///    interface (`PeerAddressFamilyMismatch` otherwise).
/// 3. Require at least one XOR-PEER-ADDRESS. RFC 8656 §9.2 mandates a 400
///    (`BadRequest`) when none is present.
/// 4. Pass the peer ports to the session manager (`Forbidden` if it refuses),
///    then to the handler.
///
/// The success response has no method-specific attributes and is flushed with
/// the user's password.
async fn create_permission<T>(req: Request<'_, '_, T, Message<'_>>) -> Option<RouteResult>
where
    T: ServiceHandler,
{
    let Some((username, password)) = req.verify().await else {
        return reject(req, ErrorType::Unauthorized);
    };

    let mut ports = Vec::with_capacity(15);
    for peer in req.payload.get_all::<XorPeerAddress>() {
        if !req.verify_ip(&peer) {
            return reject(req, ErrorType::PeerAddressFamilyMismatch);
        }

        ports.push(peer.port());
    }

    // Without this check an attribute-less request would be acknowledged as a
    // successful no-op instead of being rejected as malformed.
    if ports.is_empty() {
        return reject(req, ErrorType::BadRequest);
    }

    if !req.state.manager.create_permission(&req.state.id, &ports) {
        return reject(req, ErrorType::Forbidden);
    }

    req.state
        .handler
        .on_create_permission(&req.state.id, username, &ports);

    {
        MessageEncoder::extend(CREATE_PERMISSION_RESPONSE, req.payload, req.response_buffer)
            .flush(Some(&password))
            .ok()?;
    }

    Some(RouteResult {
        method: Some(CREATE_PERMISSION_RESPONSE),
        relay: None,
    })
}
/// Handles a ChannelBind request.
///
/// Steps:
/// 1. Authenticate the request (`Unauthorized` on failure). This comes before
///    validation, as in the other authenticated handlers and per RFC 8656.
/// 2. Require XOR-PEER-ADDRESS (`BadRequest`) and check it uses the external
///    IP of a configured interface (`PeerAddressFamilyMismatch`).
/// 3. Require CHANNEL-NUMBER inside the valid channel range (`BadRequest`).
/// 4. Ask the session manager to bind the channel to the peer's port
///    (`Forbidden` if it refuses), then notify the handler.
///
/// The success response has no method-specific attributes and is flushed with
/// the user's password.
async fn channel_bind<T>(req: Request<'_, '_, T, Message<'_>>) -> Option<RouteResult>
where
    T: ServiceHandler,
{
    let Some((username, password)) = req.verify().await else {
        return reject(req, ErrorType::Unauthorized);
    };

    let Some(peer) = req.payload.get::<XorPeerAddress>() else {
        return reject(req, ErrorType::BadRequest);
    };

    if !req.verify_ip(&peer) {
        return reject(req, ErrorType::PeerAddressFamilyMismatch);
    }

    let Some(number) = req.payload.get::<ChannelNumber>() else {
        return reject(req, ErrorType::BadRequest);
    };

    // ChannelData frames are told apart from STUN messages by their leading
    // bits being 0b01, i.e. channel numbers 0x4000..=0x7FFF (RFC 5766 §11).
    // Numbers from 0x8000 up are reserved, and frames using them would not
    // demultiplex as ChannelData. RFC 8656 narrows the range to 0x4FFF; the
    // wider RFC 5766 range is kept here for older clients.
    if !(0x4000..=0x7FFF).contains(&number) {
        return reject(req, ErrorType::BadRequest);
    }

    if !req
        .state
        .manager
        .bind_channel(&req.state.id, peer.port(), number)
    {
        return reject(req, ErrorType::Forbidden);
    }

    req.state
        .handler
        .on_channel_bind(&req.state.id, username, number);

    {
        MessageEncoder::extend(CHANNEL_BIND_RESPONSE, req.payload, req.response_buffer)
            .flush(Some(&password))
            .ok()?;
    }

    Some(RouteResult {
        method: Some(CHANNEL_BIND_RESPONSE),
        relay: None,
    })
}
#[rustfmt::skip]
fn indication<T>(req: Request<'_, '_, T, Message<'_>>) -> Option<RouteResult>
where
T: ServiceHandler,
{
let peer = req.payload.get::<XorPeerAddress>()?;
let data = req.payload.get::<Data>()?;
let (local_port, relay) = req.state.manager.get_port_relay_address(&req.state.id, peer.port())?;
{
let mut message = MessageEncoder::extend(DATA_INDICATION, req.payload, req.response_buffer);
message.append::<XorPeerAddress>(SocketAddr::new(req.state.id.external.ip(), local_port));
message.append::<Data>(data);
message.flush(None).ok()?;
}
Some(RouteResult {
method: Some(DATA_INDICATION),
relay: Some(relay),
})
}
async fn refresh<T>(req: Request<'_, '_, T, Message<'_>>) -> Option<RouteResult>
where
T: ServiceHandler,
{
let Some((username, password)) = req.verify().await else {
return reject(req, ErrorType::Unauthorized);
};
let lifetime = req
.payload
.get::<Lifetime>()
.unwrap_or(DEFAULT_SESSION_LIFETIME as u32);
if !req.state.manager.refresh(&req.state.id, lifetime) {
return reject(req, ErrorType::AllocationMismatch);
}
req.state
.handler
.on_refresh(&req.state.id, username, lifetime);
{
let mut message =
MessageEncoder::extend(REFRESH_RESPONSE, req.payload, req.response_buffer);
message.append::<Lifetime>(lifetime);
message.flush(Some(&password)).ok()?;
}
Some(RouteResult {
method: Some(REFRESH_RESPONSE),
relay: None,
})
}
fn channel_data<T>(req: Request<'_, '_, T, ChannelData<'_>>) -> Option<RouteResult>
where
T: ServiceHandler,
{
let (relay_channel, relay) = req
.state
.manager
.get_channel_relay_address(&req.state.id, req.payload.number())?;
{
ChannelData::new(relay_channel, req.payload.bytes()).encode(req.response_buffer);
}
Some(RouteResult {
relay: Some(relay),
method: None,
})
}