use std::{fmt::Display, future::Future, net::SocketAddr, sync::Arc, time::Duration};
use async_trait::async_trait;
use axum::{
extract::{Json as Body, Query, State},
http::HeaderMap,
response::IntoResponse,
routing::{get, post},
Router,
};
use reqwest::{Client, ClientBuilder, Response, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
/// Transport protocol of a server interface.
///
/// On the wire the variants are the lowercase strings `"tcp"` and `"udp"`.
// Variant names stay upper-case because they are part of the public API.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Transport {
    #[serde(rename = "tcp")]
    TCP,
    #[serde(rename = "udp")]
    UDP,
}
/// Identifies a session by a pair of socket addresses.
///
/// NOTE(review): `address` looks like the client's address and `interface`
/// the local server interface it reached — confirm against the server API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SessionAddr {
    /// Socket address half of the session key.
    pub address: SocketAddr,
    /// Interface half of the session key.
    pub interface: SocketAddr,
}
/// One network interface reported in [`Info::interfaces`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Interface {
    /// Protocol served on this interface.
    pub transport: Transport,
    /// Local address the server is bound to.
    pub bind: SocketAddr,
    /// Publicly advertised address (presumably the NAT-mapped one — confirm).
    pub external: SocketAddr,
}
/// Server status returned by `GET /info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Info {
    /// Server software identification string.
    pub software: String,
    /// Server uptime (unit not stated here — presumably seconds, verify).
    pub uptime: u64,
    /// Number of relay ports currently allocated.
    pub port_allocated: u16,
    /// Total number of relay ports available.
    pub port_capacity: u16,
    /// Interfaces the server is listening on.
    pub interfaces: Vec<Interface>,
}
/// Session state returned by `GET /session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Username the session authenticated with.
    pub username: String,
    /// Password associated with `username`.
    pub password: String,
    /// Bound channel numbers.
    pub channels: Vec<u16>,
    /// Allocated relay port, if any.
    pub port: Option<u16>,
    /// Expiry value (unit/epoch not stated here — confirm with the server API).
    pub expires: u32,
    /// Ports for which permissions were created.
    pub permissions: Vec<u16>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Statistics {
pub received_bytes: u64,
pub send_bytes: u64,
pub received_pkts: u64,
pub send_pkts: u64,
pub error_pkts: u64,
}
impl<'a> Display for SessionAddr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
format!("address={}&interface={}", self.address, self.interface)
)
}
}
/// A controller API response: the `realm` and `nonce` response headers
/// together with the decoded body.
#[derive(Debug)]
pub struct Message<T> {
    /// Value of the `realm` response header.
    pub realm: String,
    /// Value of the `nonce` response header.
    pub nonce: String,
    /// Payload produced from the response.
    pub payload: T,
}
impl<T> Message<T> {
    /// Builds a [`Message`] from an HTTP response.
    ///
    /// The `realm`/`nonce` headers are read first. Only then is `res` handed to
    /// `handler` to produce the payload.
    ///
    /// Returns `None` if either header is missing or not a valid string, or if
    /// `handler` yields `None`.
    async fn from_res<F>(res: Response, handler: impl FnOnce(Response) -> F) -> Option<Self>
    where
        F: Future<Output = Option<T>>,
    {
        // Take owned copies so the header borrow ends before `res` is moved.
        let (realm, nonce) = {
            let (realm, nonce) = get_realm_and_nonce(res.headers())?;
            (realm.to_owned(), nonce.to_owned())
        };
        let payload = handler(res).await?;
        Some(Self {
            realm,
            nonce,
            payload,
        })
    }
}
/// HTTP client for the server's control API (`/info`, `/session`, ...).
pub struct Controller {
    /// Shared HTTP client, configured in the constructor.
    client: Client,
    /// Base URL that endpoint paths are appended to.
    server: String,
}
impl Controller {
pub fn new(server: &str) -> Result<Self, reqwest::Error> {
Ok(Self {
server: server.to_string(),
client: ClientBuilder::new()
.timeout(Duration::from_secs(5))
.build()?,
})
}
pub async fn get_info(&self) -> Option<Message<Info>> {
Message::from_res(
self.client
.get(format!("{}/info", self.server))
.send()
.await
.ok()?,
|res| async { res.json().await.ok() },
)
.await
}
pub async fn get_session(&self, query: &SessionAddr) -> Option<Message<Session>> {
Message::from_res(
self.client
.get(format!("{}/session?{}", self.server, query))
.send()
.await
.ok()?,
|res| async { res.json().await.ok() },
)
.await
}
pub async fn get_session_statistics(&self, query: &SessionAddr) -> Option<Message<Statistics>> {
Message::from_res(
self.client
.get(format!("{}/session/statistics?{}", self.server, query))
.send()
.await
.ok()?,
|res| async { res.json().await.ok() },
)
.await
}
pub async fn remove_session(&self, query: &SessionAddr) -> Option<Message<bool>> {
Message::from_res(
self.client
.delete(format!("{}/session?{}", self.server, query))
.send()
.await
.ok()?,
|res| async move { Some(res.status() == StatusCode::OK) },
)
.await
}
}
/// Notifications received on the hooks server's `POST /events` route.
///
/// Deserialized from JSON that is internally tagged by a `kind` field holding
/// the snake_case variant name, e.g. `{"kind": "channel_bind", ...}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Events {
    /// `kind = "allocated"`: a relay port was allocated.
    Allocated {
        session: SessionAddr,
        username: String,
        port: u16,
    },
    /// `kind = "channel_bind"`: a channel number was bound.
    ChannelBind {
        session: SessionAddr,
        username: String,
        channel: u16,
    },
    /// `kind = "create_permission"`: permissions were created for `ports`.
    CreatePermission {
        session: SessionAddr,
        username: String,
        ports: Vec<u16>,
    },
    /// `kind = "refresh"`: the session was refreshed
    /// (`lifetime` unit presumably seconds — confirm).
    Refresh {
        session: SessionAddr,
        username: String,
        lifetime: u32,
    },
    /// `kind = "closed"`: the session was closed.
    Closed {
        session: SessionAddr,
        username: String,
    },
}
/// Callbacks driven by the hooks HTTP server (`start_hooks_server`).
///
/// Both methods have default bodies, so implementors override only what they
/// need.
#[async_trait]
pub trait Hooks {
    /// Resolves the password for `username` on `session`.
    ///
    /// `realm` and `nonce` come from the request headers of `GET /password`.
    /// `Some(password)` is sent back as the response body. `None` makes the
    /// server reply `404 Not Found`. The returned string borrows from `self`.
    /// The default implementation always returns `None`.
    // The default body ignores its arguments; the allow avoids renaming the
    // parameters (names show up in rustdoc and implementor hints).
    #[allow(unused_variables)]
    async fn auth(
        &self,
        session: &SessionAddr,
        username: &str,
        realm: &str,
        nonce: &str,
    ) -> Option<&str> {
        None
    }

    /// Handles one event posted to `POST /events`, together with the `realm`
    /// and `nonce` request headers. The default implementation does nothing.
    // Same reason as above: the default body ignores its arguments.
    #[allow(unused_variables)]
    async fn on(&self, event: &Events, realm: &str, nonce: &str) {}
}
/// Query parameters of the hooks server's `GET /password` route.
#[derive(Deserialize)]
struct GetPasswordQuery {
    /// `address` half of the [`SessionAddr`].
    address: SocketAddr,
    /// `interface` half of the [`SessionAddr`].
    interface: SocketAddr,
    /// User whose password is requested.
    username: String,
}
pub async fn start_hooks_server<T>(bind: SocketAddr, hooks: T) -> Result<(), std::io::Error>
where
T: Hooks + Send + Sync + 'static,
{
let app = Router::new()
.route(
"/password",
get(
|headers: HeaderMap,
State(state): State<Arc<T>>,
Query(query): Query<GetPasswordQuery>| async move {
if let Some((realm, nonce)) = get_realm_and_nonce(&headers) {
if let Some(password) =
state.auth(&SessionAddr {
address: query.address,
interface: query.interface,
}, &query.username, realm, nonce).await
{
return password.to_string().into_response();
}
}
StatusCode::NOT_FOUND.into_response()
},
),
)
.route(
"/events",
post(
|headers: HeaderMap, State(state): State<Arc<T>>, Body(event): Body<Events>| async move {
if let Some((realm, nonce)) = get_realm_and_nonce(&headers) {
state.on(&event, realm, nonce).await;
}
StatusCode::OK
},
),
)
.with_state(Arc::new(hooks));
axum::serve(TcpListener::bind(bind).await?, app).await?;
Ok(())
}
/// Reads the `realm` and `nonce` headers as string slices borrowed from
/// `headers`.
///
/// Returns `None` if either header is absent or its value fails
/// `HeaderValue::to_str` (i.e. is not visible ASCII).
fn get_realm_and_nonce(headers: &HeaderMap) -> Option<(&str, &str)> {
    let realm = headers.get("realm")?.to_str().ok()?;
    let nonce = headers.get("nonce")?.to_str().ok()?;
    Some((realm, nonce))
}