use std::{net::SocketAddr, sync::Arc};
use bytes::BytesMut;
use super::{
Service, ServiceHandler,
session::{DEFAULT_SESSION_LIFETIME, Identifier, Session, SessionManager},
};
use crate::codec::{
DecodeResult, Decoder,
channel_data::ChannelData,
crypto::Password,
message::{
Message, MessageEncoder,
attributes::{error::ErrorType, *},
methods::*,
},
};
/// Per-datagram context handed to each method handler.
///
/// `id`, `state` and `payload` are shared borrows with lifetime `'a`; `encode_buffer` is a
/// separately-lived (`'b`) mutable borrow of the router's reusable output buffer. Handlers encode
/// their reply into that buffer and return it inside a [`Response`].
struct Request<'a, 'b, T, M>
where
    T: ServiceHandler,
{
    /// Identifier of the current sender; its source address is overwritten per packet by
    /// `Router::route`.
    id: &'a Identifier,
    /// Reusable buffer the reply message is encoded into.
    encode_buffer: &'b mut BytesMut,
    /// Router-wide configuration and shared services.
    state: &'a RouterState<T>,
    /// Decoded packet: a STUN `Message` or a `ChannelData` frame.
    payload: &'a M,
}
impl<'a, 'b, T> Request<'a, 'b, T, Message<'a>>
where
    T: ServiceHandler,
{
    /// Returns `true` when `address` has the same IP as one of the service's interfaces.
    ///
    /// Only the IP part is compared; the port of `address` is ignored.
    #[inline(always)]
    fn verify_ip(&self, address: &SocketAddr) -> bool {
        let wanted = address.ip();
        self.state
            .interfaces
            .iter()
            .map(SocketAddr::ip)
            .any(|candidate| candidate == wanted)
    }

    /// Authenticates the message with the credentials known to the session manager.
    ///
    /// USERNAME is required; PASSWORD-ALGORITHM falls back to MD5 when absent. The password is
    /// fetched from the session manager and passed to `Message::verify` (presumably the
    /// MESSAGE-INTEGRITY check). Returns the username and password on success, or `None` if any
    /// step is missing or fails.
    #[inline(always)]
    async fn verify(&self) -> Option<(&str, Password)> {
        let message = self.payload;
        let username = message.get::<UserName>()?;
        let algorithm = match message.get::<PasswordAlgorithm>() {
            Some(requested) => requested,
            None => PasswordAlgorithm::Md5,
        };
        let password = self
            .state
            .manager
            .get_password(self.id, username, algorithm)
            .await?;
        message.verify(&password).ok()?;
        Some((username, password))
    }
}
/// Delivery hint attached to a [`Response`].
///
/// Handlers that answer the sender directly return `Target::default()` (both fields `None`).
/// Relayed traffic (Send indications and ChannelData) fills in `relay`, and also `endpoint` when
/// the relay's endpoint differs from the handling router's endpoint.
#[derive(Debug, Clone, Copy, Default)]
pub struct Target {
    /// The relay's `endpoint()`, set only when it is not this router's own endpoint.
    pub endpoint: Option<SocketAddr>,
    /// The relay's `source()` for relayed traffic (presumably the destination of the bytes —
    /// confirm with the caller).
    pub relay: Option<SocketAddr>,
}
/// Outcome of routing one datagram.
#[derive(Debug)]
pub struct Response<'a> {
    /// STUN method of the encoded reply; `None` for ChannelData, which is forwarded unchanged.
    pub method: Option<Method>,
    /// Bytes to send: the router's encode buffer, or the original input for ChannelData.
    pub bytes: &'a [u8],
    /// Where the bytes should go (see [`Target`]).
    pub target: Target,
}
/// Configuration and shared services used by every handler of one router.
pub(crate) struct RouterState<T>
where
    T: ServiceHandler,
{
    /// Realm advertised in the REALM attribute of `Unauthorized` error responses.
    pub realm: String,
    /// Value of the SOFTWARE attribute in Binding and Allocate responses.
    pub software: String,
    /// Session manager: credentials, allocations, permissions, channels and relays.
    pub manager: Arc<SessionManager<T>>,
    /// This router's endpoint; compared against relay endpoints and passed to permission /
    /// channel creation.
    pub endpoint: SocketAddr,
    /// Local interface address. Its IP forms relayed addresses; the full address is sent as
    /// RESPONSE-ORIGIN in Binding responses.
    pub interface: SocketAddr,
    /// All interface addresses of the service; peer addresses must match one of them by IP.
    pub interfaces: Arc<Vec<SocketAddr>>,
    /// Callback target notified on allocate, channel bind, create permission and refresh.
    pub handler: T,
}
/// Decodes incoming datagrams and dispatches them to the STUN/TURN method handlers.
///
/// The identifier, decoder and encode buffer are reused across calls to `route`.
pub struct Router<T>
where
    T: ServiceHandler,
{
    /// Sender identifier; the interface is fixed at construction, the source is replaced per packet.
    current_id: Identifier,
    /// Shared configuration handed to every handler.
    state: RouterState<T>,
    /// Reusable packet decoder.
    decoder: Decoder,
    /// Reusable buffer replies are encoded into.
    bytes: BytesMut,
}
impl<T> Router<T>
where
    T: ServiceHandler + Clone,
{
    /// Creates a router for `endpoint` / `interface` that shares the state of `service`.
    ///
    /// The session manager and interface list are `Arc` clones of the service's. The realm,
    /// software string and handler are cloned by value. The router owns a fresh decoder and an
    /// encode buffer with an initial capacity of 4096 bytes.
    pub fn new(service: &Service<T>, endpoint: SocketAddr, interface: SocketAddr) -> Self {
        // Unspecified placeholder source built infallibly (no runtime parse / panic path);
        // `route` overwrites it before every dispatch.
        let placeholder = SocketAddr::from(([0, 0, 0, 0], 0));
        Self {
            bytes: BytesMut::with_capacity(4096),
            decoder: Decoder::default(),
            current_id: Identifier::new(placeholder, interface),
            state: RouterState {
                interfaces: service.interfaces.clone(),
                software: service.software.clone(),
                handler: service.handler.clone(),
                manager: service.manager.clone(),
                realm: service.realm.clone(),
                interface,
                endpoint,
            },
        }
    }

    /// Decodes one datagram received from `address` and dispatches it to the matching handler.
    ///
    /// Returns `Err` when decoding fails. Returns `Ok(None)` when nothing should be sent: an
    /// unsupported method, a dropped indication, or an unknown channel. Otherwise returns
    /// `Ok(Some(_))` with the bytes to send and their target.
    pub async fn route<'a, 'b: 'a>(
        &'b mut self,
        bytes: &'b [u8],
        address: SocketAddr,
    ) -> Result<Option<Response<'a>>, crate::codec::Error> {
        *self.current_id.source_mut() = address;

        Ok(match self.decoder.decode(bytes)? {
            DecodeResult::ChannelData(channel) => channel_data(
                bytes,
                Request {
                    id: &self.current_id,
                    state: &self.state,
                    encode_buffer: &mut self.bytes,
                    payload: &channel,
                },
            ),
            DecodeResult::Message(message) => {
                let request = Request {
                    id: &self.current_id,
                    state: &self.state,
                    encode_buffer: &mut self.bytes,
                    payload: &message,
                };
                match request.payload.method() {
                    BINDING_REQUEST => binding(request),
                    ALLOCATE_REQUEST => allocate(request).await,
                    CREATE_PERMISSION_REQUEST => create_permission(request).await,
                    CHANNEL_BIND_REQUEST => channel_bind(request).await,
                    REFRESH_REQUEST => refresh(request).await,
                    SEND_INDICATION => indication(request),
                    // Unsupported methods are silently dropped.
                    _ => None,
                }
            }
        })
    }
}
fn reject<'a, T>(req: Request<'_, 'a, T, Message<'_>>, error: ErrorType) -> Option<Response<'a>>
where
T: ServiceHandler,
{
let method = req.payload.method().error()?;
{
let mut message = MessageEncoder::extend(method, req.payload, req.encode_buffer);
message.append::<ErrorCode>(ErrorCode::from(error));
if error == ErrorType::Unauthorized {
message.append::<Realm>(&req.state.realm);
message.append::<Nonce>(
req.state
.manager
.get_session_or_default(req.id)
.get_ref()?
.nonce(),
);
message.append::<PasswordAlgorithms>(vec![
PasswordAlgorithm::Md5,
PasswordAlgorithm::Sha256,
]);
}
message.flush(None).ok()?;
}
Some(Response {
target: Target::default(),
bytes: req.encode_buffer,
method: Some(method),
})
}
/// Answers a Binding request with the sender's observed address.
///
/// The response carries XOR-MAPPED-ADDRESS and MAPPED-ADDRESS (both the sender's source),
/// RESPONSE-ORIGIN (this router's interface) and SOFTWARE. It is not signed with a password.
fn binding<'a, T>(req: Request<'_, 'a, T, Message<'_>>) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    let observed = req.id.source();

    // Scope the encoder so its mutable borrow of the buffer ends before the buffer is returned.
    {
        let mut encoder = MessageEncoder::extend(BINDING_RESPONSE, req.payload, req.encode_buffer);
        encoder.append::<XorMappedAddress>(observed);
        encoder.append::<MappedAddress>(observed);
        encoder.append::<ResponseOrigin>(req.state.interface);
        encoder.append::<Software>(&req.state.software);
        encoder.flush(None).ok()?;
    }

    Some(Response {
        bytes: req.encode_buffer,
        method: Some(BINDING_RESPONSE),
        target: Target::default(),
    })
}
/// Handles an Allocate request.
///
/// Checks, in order:
/// 1. REQUESTED-TRANSPORT is present (`BadRequest` otherwise).
/// 2. The message authenticates (`Unauthorized` otherwise).
/// 3. The session manager grants a relay port (`AllocationQuotaReached` otherwise).
///
/// On success the handler is notified via `on_allocated`. The reply is an Allocate success
/// response with XOR-RELAYED-ADDRESS, XOR-MAPPED-ADDRESS, LIFETIME and SOFTWARE, flushed with the
/// user's password.
async fn allocate<'a, T>(req: Request<'_, 'a, T, Message<'_>>) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    // RFC 5766 / RFC 8656 Allocate processing: a missing REQUESTED-TRANSPORT attribute is a
    // client error and must be answered with 400 (Bad Request), not a 500 server error.
    if req.payload.get::<RequestedTransport>().is_none() {
        return reject(req, ErrorType::BadRequest);
    }

    let Some((username, password)) = req.verify().await else {
        return reject(req, ErrorType::Unauthorized);
    };

    let lifetime = req.payload.get::<Lifetime>();
    let Some(port) = req.state.manager.allocate(req.id, lifetime) else {
        return reject(req, ErrorType::AllocationQuotaReached);
    };

    req.state.handler.on_allocated(req.id, username, port);

    // Scope the encoder so its mutable borrow of the buffer ends before the buffer is returned.
    {
        let mut message = MessageEncoder::extend(ALLOCATE_RESPONSE, req.payload, req.encode_buffer);
        // Relayed address = this router's interface IP + the newly allocated port.
        message.append::<XorRelayedAddress>(SocketAddr::new(req.state.interface.ip(), port));
        message.append::<XorMappedAddress>(req.id.source());
        // Echo the requested lifetime, or the default when the client did not request one.
        message.append::<Lifetime>(lifetime.unwrap_or(DEFAULT_SESSION_LIFETIME as u32));
        message.append::<Software>(&req.state.software);
        message.flush(Some(&password)).ok()?;
    }

    Some(Response {
        target: Target::default(),
        method: Some(ALLOCATE_RESPONSE),
        bytes: req.encode_buffer,
    })
}
/// Handles a ChannelBind request.
///
/// Checks, in order:
/// 1. XOR-PEER-ADDRESS is present (`BadRequest`).
/// 2. The peer IP belongs to one of the service's interfaces (`PeerAddressFamilyMismatch`).
/// 3. CHANNEL-NUMBER is present and within 0x4000..=0x7FFF (`BadRequest`).
/// 4. The message authenticates (`Unauthorized`).
/// 5. The session manager accepts the binding (`Forbidden`).
///
/// On success the handler is notified via `on_channel_bind`, and an empty success response is
/// flushed with the user's password.
async fn channel_bind<'a, T>(req: Request<'_, 'a, T, Message<'_>>) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    let peer = match req.payload.get::<XorPeerAddress>() {
        Some(address) => address,
        None => return reject(req, ErrorType::BadRequest),
    };

    if !req.verify_ip(&peer) {
        return reject(req, ErrorType::PeerAddressFamilyMismatch);
    }

    // A missing channel number and an out-of-range one are both rejected as BadRequest.
    // NOTE(review): RFC 8656 restricts channel numbers to 0x4000..=0x4FFF (RFC 5766 allowed up
    // to 0x7FFF); confirm which range this server intends to accept.
    let Some(number) = req
        .payload
        .get::<ChannelNumber>()
        .filter(|candidate| (0x4000..=0x7FFF).contains(candidate))
    else {
        return reject(req, ErrorType::BadRequest);
    };

    let Some((username, password)) = req.verify().await else {
        return reject(req, ErrorType::Unauthorized);
    };

    let bound = req
        .state
        .manager
        .bind_channel(req.id, &req.state.endpoint, peer.port(), number);
    if !bound {
        return reject(req, ErrorType::Forbidden);
    }

    req.state.handler.on_channel_bind(req.id, username, number);

    // The encoder is a temporary, so its borrow of the buffer ends with this block.
    {
        MessageEncoder::extend(CHANNEL_BIND_RESPONSE, req.payload, req.encode_buffer)
            .flush(Some(&password))
            .ok()?;
    }

    Some(Response {
        bytes: req.encode_buffer,
        method: Some(CHANNEL_BIND_RESPONSE),
        target: Target::default(),
    })
}
/// Handles a CreatePermission request.
///
/// Checks, in order:
/// 1. The message authenticates (`Unauthorized`).
/// 2. Every XOR-PEER-ADDRESS has the IP of one of the service's interfaces
///    (`PeerAddressFamilyMismatch`).
/// 3. At least one XOR-PEER-ADDRESS is present (`BadRequest`).
/// 4. The session manager accepts the permissions for the collected peer ports (`Forbidden`).
///
/// On success the handler is notified via `on_create_permission`, and an empty success response
/// is flushed with the user's password.
async fn create_permission<'a, T>(req: Request<'_, 'a, T, Message<'_>>) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    let Some((username, password)) = req.verify().await else {
        return reject(req, ErrorType::Unauthorized);
    };

    let mut ports = Vec::with_capacity(15);
    for peer in req.payload.get_all::<XorPeerAddress>() {
        if !req.verify_ip(&peer) {
            return reject(req, ErrorType::PeerAddressFamilyMismatch);
        }
        ports.push(peer.port());
    }

    // RFC 5766 / RFC 8656: a CreatePermission request MUST contain at least one
    // XOR-PEER-ADDRESS; without one it is a 400 (Bad Request), not a successful no-op.
    if ports.is_empty() {
        return reject(req, ErrorType::BadRequest);
    }

    if !req
        .state
        .manager
        .create_permission(req.id, &req.state.endpoint, &ports)
    {
        return reject(req, ErrorType::Forbidden);
    }

    req.state
        .handler
        .on_create_permission(req.id, username, &ports);

    // The encoder is a temporary, so its borrow of the buffer ends with this block.
    {
        MessageEncoder::extend(CREATE_PERMISSION_RESPONSE, req.payload, req.encode_buffer)
            .flush(Some(&password))
            .ok()?;
    }

    Some(Response {
        method: Some(CREATE_PERMISSION_RESPONSE),
        target: Target::default(),
        bytes: req.encode_buffer,
    })
}
/// Handles a Send indication by re-encoding it as a Data indication for the target peer.
///
/// Returns `None`, so nothing is sent, in any of these cases:
/// - XOR-PEER-ADDRESS or DATA is missing;
/// - the sender's session is not `Session::Authenticated` with an allocated port;
/// - the session manager has no relay for the peer's port.
///
/// The outgoing XOR-PEER-ADDRESS is this router's interface IP paired with the sender's allocated
/// port, i.e. the sender's relayed address. In the returned target, `relay` is the relay's
/// `source()`. `endpoint` is set only when the relay's `endpoint()` differs from this router's
/// endpoint.
#[rustfmt::skip]
fn indication<'a, T>(req: Request<'_, 'a, T, Message<'_>>) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    let peer = req.payload.get::<XorPeerAddress>()?;
    let data = req.payload.get::<Data>()?;
    // NOTE(review): the value returned by `get_session` is a temporary of this `if let`
    // condition, so it stays alive for the entire then-branch (relay lookup and encoding
    // included). If it is a lock or map guard, that lock is held over that span — confirm intended.
    if let Some(Session::Authenticated { allocate_port, .. }) =
        req.state.manager.get_session(req.id).get_ref() && let Some(local_port) = *allocate_port
    {
        // Relay lookup keyed by the peer address's port.
        let relay = req.state.manager.get_relay_address(req.id, peer.port())?;
        // Scope the encoder so its mutable borrow of the buffer ends before the buffer is returned.
        {
            let mut message = MessageEncoder::extend(DATA_INDICATION, req.payload, req.encode_buffer);
            message.append::<XorPeerAddress>(SocketAddr::new(req.state.interface.ip(), local_port));
            message.append::<Data>(data);
            message.flush(None).ok()?;
        }
        return Some(Response {
            method: Some(DATA_INDICATION),
            bytes: req.encode_buffer,
            target: Target {
                relay: Some(relay.source()),
                endpoint: if req.state.endpoint != relay.endpoint() {
                    Some(relay.endpoint())
                } else {
                    None
                },
            },
        });
    }
    None
}
/// Handles a Refresh request.
///
/// After authentication (`Unauthorized` on failure), the session is refreshed with the requested
/// LIFETIME, or with `DEFAULT_SESSION_LIFETIME` when none is given. If the session manager
/// refuses, the request is rejected with `AllocationMismatch`. On success the handler is notified
/// via `on_refresh`. The response echoes the lifetime used and is flushed with the user's password.
async fn refresh<'a, T>(req: Request<'_, 'a, T, Message<'_>>) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    let Some((username, password)) = req.verify().await else {
        return reject(req, ErrorType::Unauthorized);
    };

    let lifetime = match req.payload.get::<Lifetime>() {
        Some(requested) => requested,
        None => DEFAULT_SESSION_LIFETIME as u32,
    };

    let refreshed = req.state.manager.refresh(req.id, lifetime);
    if !refreshed {
        return reject(req, ErrorType::AllocationMismatch);
    }

    req.state.handler.on_refresh(req.id, username, lifetime);

    // Scope the encoder so its mutable borrow of the buffer ends before the buffer is returned.
    {
        let mut encoder = MessageEncoder::extend(REFRESH_RESPONSE, req.payload, req.encode_buffer);
        encoder.append::<Lifetime>(lifetime);
        encoder.flush(Some(&password)).ok()?;
    }

    Some(Response {
        bytes: req.encode_buffer,
        method: Some(REFRESH_RESPONSE),
        target: Target::default(),
    })
}
/// Forwards a ChannelData frame unchanged to the relay bound to its channel number.
///
/// Returns `None` when the session manager knows no relay for this channel. Otherwise the
/// original `bytes` are returned with `method: None`. The target's `relay` is the relay's
/// `source()`; `endpoint` is set only when the relay's `endpoint()` is not this router's endpoint.
fn channel_data<'a, T>(
    bytes: &'a [u8],
    req: Request<'_, 'a, T, ChannelData<'_>>,
) -> Option<Response<'a>>
where
    T: ServiceHandler,
{
    let channel = req.payload.number();
    let relay = req
        .state
        .manager
        .get_channel_relay_address(req.id, channel)?;

    let relay_endpoint = relay.endpoint();
    let endpoint = if relay_endpoint == req.state.endpoint {
        None
    } else {
        Some(relay_endpoint)
    };

    Some(Response {
        method: None,
        target: Target {
            endpoint,
            relay: Some(relay.source()),
        },
        bytes,
    })
}