#![deny(missing_docs)]
/// Maximum size, in bytes, of a single websocket message accepted from a
/// client; passed to `WebSocketUpgrade::max_message_size` in `handle_ws`.
const MAX_MSG_BYTES: i32 = 20_000;
use base64::Engine;
use opentelemetry::global;
use std::collections::HashMap;
use std::io::{Error, Result};
use std::net::{IpAddr, Ipv6Addr};
use std::sync::{Arc, Mutex};
mod config;
pub use config::*;
mod maybe_tls;
pub use maybe_tls::*;
mod ip_deny;
mod ip_rate;
pub use ip_rate::*;
mod cslot;
pub use cslot::*;
mod cmd;
mod metrics;
pub use metrics::*;
/// Transport-agnostic websocket abstraction used by the server's connection
/// handling.
pub mod ws {
    /// A message payload, stored either as a `Vec<u8>` or a
    /// `bytes::BytesMut`.
    ///
    /// Dereferences to `[u8]` for read access; use [`Payload::to_mut`] for
    /// in-place mutation regardless of the backing storage.
    pub enum Payload {
        /// Payload backed by a `Vec<u8>`.
        Vec(Vec<u8>),
        /// Payload backed by a `bytes::BytesMut`.
        BytesMut(bytes::BytesMut),
    }

    impl std::ops::Deref for Payload {
        type Target = [u8];

        #[inline]
        fn deref(&self) -> &Self::Target {
            match self {
                Payload::Vec(v) => v.as_slice(),
                Payload::BytesMut(b) => b.as_ref(),
            }
        }
    }

    impl Payload {
        /// Returns the payload bytes as a mutable slice, whichever backing
        /// storage is in use.
        #[inline]
        pub fn to_mut(&mut self) -> &mut [u8] {
            match self {
                Payload::Vec(v) => v.as_mut_slice(),
                Payload::BytesMut(b) => b.as_mut(),
            }
        }
    }

    use futures::future::BoxFuture;

    /// A bidirectional websocket connection as seen by the server.
    ///
    /// All methods return `'static` boxed futures, so the returned future
    /// does not borrow `self` and can be awaited independently of it.
    pub trait SbdWebsocket: Send + Sync + 'static {
        /// Receives the next message payload from the peer.
        fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;

        /// Sends `payload` to the peer.
        fn send(
            &self,
            payload: Payload,
        ) -> BoxFuture<'static, std::io::Result<()>>;

        /// Closes the connection.
        fn close(&self) -> BoxFuture<'static, ()>;
    }
}
pub use ws::{Payload, SbdWebsocket};
/// A 32-byte Ed25519 public key identifying a client connection.
///
/// The key bytes are held in an `Arc`, so clones are cheap.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PubKey(pub Arc<[u8; 32]>);

impl PubKey {
    /// Verifies an Ed25519 signature made by this key.
    ///
    /// Returns `true` only if `sig` is a valid signature over `data` for
    /// this key. Returns `false` if the signature does not verify, or if the
    /// stored 32 bytes cannot be decoded as an Ed25519 verifying key.
    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
        use ed25519_dalek::Verifier;

        let Ok(key) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) else {
            return false;
        };
        key.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
            .is_ok()
    }
}
/// A running SBD server.
///
/// Owns the background tasks started by `SbdServer::new`: the IP-rate prune
/// task and one task per bound listener. It also stores the addresses those
/// listeners actually bound and holds the connection slot. Dropping the
/// server aborts all of its tasks.
pub struct SbdServer {
    /// Handles of every spawned task; each is aborted on drop.
    task_list: Vec<tokio::task::JoinHandle<()>>,
    /// Local addresses reported by the listeners.
    bind_addrs: Vec<std::net::SocketAddr>,
    /// Keeps the connection slot alive for the server's lifetime. Handlers
    /// only hold a weak reference to it.
    _cslot: CSlot,
}
impl Drop for SbdServer {
    /// Aborts every task this server spawned, so nothing outlives it.
    fn drop(&mut self) {
        self.task_list.iter().for_each(|handle| handle.abort());
    }
}
/// Normalizes an IP address to a single IPv6 representation.
///
/// IPv4 addresses are converted to their IPv4-mapped IPv6 form
/// (`::ffff:a.b.c.d`). IPv6 addresses, including already-mapped ones, are
/// returned unchanged. As a result, both spellings of an IPv4 peer yield the
/// same value.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6 = match ip {
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
        IpAddr::V6(v6) => v6,
    };
    Arc::new(v6)
}
/// Performs the IP checks on the raw socket peer that can be done before a
/// websocket upgrade.
///
/// Returns the canonical (IPv6) form of `addr`'s IP when the connection may
/// proceed, or `None` when it must be rejected.
///
/// - If `config.trusted_ip_header` is set, no checks are made here and the
///   peer IP is always returned. `handle_upgraded` checks the IP derived from
///   that header instead.
/// - Otherwise the peer IP is rejected if `ip_rate.is_blocked` reports it.
///   It is also rejected if `ip_rate.is_ok` (called with a cost of `1`)
///   refuses it.
pub async fn preflight_ip_check(
    config: &Config,
    ip_rate: &IpRate,
    addr: std::net::SocketAddr,
) -> Option<Arc<Ipv6Addr>> {
    let raw_ip = to_canonical_ip(addr.ip());

    // Short-circuit order matters: `is_ok` is only consulted for IPs that
    // are not already blocked.
    if config.trusted_ip_header.is_none()
        && (ip_rate.is_blocked(&raw_ip).await
            || !ip_rate.is_ok(&raw_ip, 1).await)
    {
        return None;
    }

    Some(raw_ip)
}
/// Finishes admission of an upgraded websocket and hands it to the
/// connection slot.
///
/// The connection is dropped without being registered when any of these
/// hold:
/// - The first 28 bytes of `pub_key` equal `cmd::CMD_PREFIX`.
///   NOTE(review): presumably that prefix is reserved for server commands;
///   confirm.
/// - `config.trusted_ip_header` is set and `calc_ip` is blocked, or is
///   refused by `ip_rate.is_ok(.., 1)`. When no trusted header is configured,
///   no IP checks happen here. Callers are expected to have checked the
///   socket peer already (e.g. with `preflight_ip_check`).
/// - The connection slot behind `weak_cslot` no longer exists.
///
/// Otherwise the websocket is passed to `CSlot::insert` together with
/// `calc_ip`, `pub_key` and `maybe_auth`.
pub async fn handle_upgraded(
    config: Arc<Config>,
    ip_rate: Arc<IpRate>,
    weak_cslot: WeakCSlot,
    ws: Arc<impl SbdWebsocket>,
    pub_key: PubKey,
    calc_ip: Arc<Ipv6Addr>,
    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
) {
    // Keys beginning with the command prefix are never accepted as clients.
    if &pub_key.0[..28] == cmd::CMD_PREFIX {
        return;
    }

    // In trusted-header mode the real client IP is only known now, so the
    // block / rate checks happen here rather than before the upgrade.
    if config.trusted_ip_header.is_some()
        && (ip_rate.is_blocked(&calc_ip).await
            || !ip_rate.is_ok(&calc_ip, 1).await)
    {
        return;
    }

    let Some(cslot) = weak_cslot.upgrade() else {
        return;
    };
    cslot
        .insert(&config, calc_ip, pub_key, ws, maybe_auth)
        .await;
}
async fn handle_auth(
axum::extract::State(app_state): axum::extract::State<AppState>,
body: bytes::Bytes,
) -> axum::response::Response {
use AuthenticateTokenError::*;
match process_authenticate_token(
&app_state.config,
&app_state.token_tracker,
app_state.auth_failures,
body,
)
.await
{
Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
serde_json::json!({
"authToken": *token,
}),
)),
Err(Unauthorized) => {
tracing::debug!("/authenticate: UNAUTHORIZED");
axum::response::IntoResponse::into_response((
axum::http::StatusCode::UNAUTHORIZED,
"Unauthorized",
))
}
Err(HookServerError(err)) => {
tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
axum::response::IntoResponse::into_response((
axum::http::StatusCode::BAD_GATEWAY,
format!("BAD_GATEWAY: {err:?}"),
))
}
Err(OtherError(err)) => {
tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
axum::response::IntoResponse::into_response((
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("INTERNAL_SERVER_ERROR: {err:?}"),
))
}
}
}
/// Reasons `process_authenticate_token` can fail to issue a token.
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The authentication hook server answered with HTTP status 401.
    Unauthorized,
    /// Talking to the hook server failed. This covers a transport error, a
    /// non-401 error status, or a failure reading the response body.
    HookServerError(Error),
    /// Any other failure, such as the blocking worker task dying or the hook
    /// response not being the expected JSON.
    OtherError(Error),
}
/// Obtains a new auth token and registers it with `token_tracker`.
///
/// If `config.authentication_hook_server` is set:
/// - `auth_material` is sent as an `application/octet-stream` `PUT` to that
///   URL. This runs on a blocking thread because `ureq` is a blocking client.
/// - The response body is parsed as JSON of the form
///   `{"authToken": "<token>"}`.
///
/// Without a hook server, `auth_material` is ignored. A random 32-byte token
/// is generated and encoded as unpadded URL-safe base64.
///
/// `auth_failures` is incremented once whenever sending the hook request
/// fails, including the 401 case. Body-read and JSON errors are not counted.
///
/// # Errors
///
/// - [`AuthenticateTokenError::Unauthorized`] if the hook server answers
///   with HTTP status 401.
/// - [`AuthenticateTokenError::HookServerError`] for any other failure to
///   send the request or to read the response body.
/// - [`AuthenticateTokenError::OtherError`] if the blocking task fails, or
///   if the response is not the expected JSON.
pub async fn process_authenticate_token(
    config: &Config,
    token_tracker: &AuthTokenTracker,
    auth_failures: opentelemetry::metrics::Counter<u64>,
    auth_material: bytes::Bytes,
) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
    use AuthenticateTokenError::*;

    let token: Arc<str> = match &config.authentication_hook_server {
        Some(url) => {
            let url = url.clone();
            let body = tokio::task::spawn_blocking(move || {
                ureq::put(&url)
                    .header("Content-Type", "application/octet-stream")
                    .send(&auth_material[..])
                    .map_err(|err| {
                        auth_failures.add(1, &[]);
                        match err {
                            ureq::Error::StatusCode(401) => Unauthorized,
                            oth => HookServerError(Error::other(oth)),
                        }
                    })?
                    .into_body()
                    .read_to_string()
                    .map_err(|err| HookServerError(Error::other(err)))
            })
            .await
            .map_err(|_| OtherError(Error::other("tokio task died")))??;

            /// Expected shape of the hook server's JSON response.
            #[derive(serde::Deserialize)]
            #[serde(rename_all = "camelCase")]
            struct Token {
                auth_token: String,
            }

            let parsed: Token = serde_json::from_str(&body)
                .map_err(|err| OtherError(Error::other(err)))?;
            parsed.auth_token.into()
        }
        None => {
            use base64::prelude::*;
            use rand::Rng;

            let mut bytes = [0; 32];
            rand::thread_rng().fill(&mut bytes);
            BASE64_URL_SAFE_NO_PAD.encode(&bytes[..]).into()
        }
    };

    token_tracker.register_token(token.clone());
    Ok(token)
}
/// Write half of a split axum websocket.
type WsSink = futures::stream::SplitSink<
    axum::extract::ws::WebSocket,
    axum::extract::ws::Message,
>;

/// Read half of a split axum websocket.
type WsStream = futures::stream::SplitStream<axum::extract::ws::WebSocket>;

/// `SbdWebsocket` implementation backed by an axum websocket.
///
/// The socket is split so that reads and writes can proceed independently.
/// Each half sits behind its own async mutex. Cloning shares both halves and
/// the metric handles.
#[derive(Clone)]
struct WebsocketImpl {
    /// Outgoing half, locked for the duration of each send/close.
    write: Arc<tokio::sync::Mutex<WsSink>>,
    /// Incoming half, locked for the duration of each recv.
    read: Arc<tokio::sync::Mutex<WsStream>>,
    /// Metric attributes (the client's encoded public key).
    attr: Vec<opentelemetry::KeyValue>,
    /// Counts payload bytes sent to the client.
    bytes_send: opentelemetry::metrics::Counter<u64>,
    /// Counts payload bytes received from the client.
    bytes_recv: opentelemetry::metrics::Counter<u64>,
}
impl SbdWebsocket for WebsocketImpl {
    /// Waits for the next data frame from the client.
    ///
    /// Text and binary frames are both returned as `Payload::Vec` (text as
    /// its UTF-8 bytes), and their length is added to `bytes_recv`. Ping and
    /// pong frames are skipped. A close frame, end of stream, or transport
    /// error yields an `Err`.
    fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
        let this = self.clone();
        Box::pin(async move {
            use futures::stream::StreamExt;

            // The read half stays locked for the whole call, so concurrent
            // `recv` calls are serialized.
            let mut read = this.read.lock().await;
            loop {
                let msg = match read.next().await {
                    None => return Err(Error::other("closed")),
                    Some(r) => r.map_err(Error::other)?,
                };
                match msg {
                    axum::extract::ws::Message::Text(s) => {
                        this.bytes_recv.add(s.len() as u64, &this.attr);
                        return Ok(Payload::Vec(s.as_bytes().to_vec()));
                    }
                    axum::extract::ws::Message::Binary(v) => {
                        this.bytes_recv.add(v.len() as u64, &this.attr);
                        return Ok(Payload::Vec(v[..].to_vec()));
                    }
                    // No application data; keep waiting for a data frame.
                    axum::extract::ws::Message::Ping(_)
                    | axum::extract::ws::Message::Pong(_) => (),
                    axum::extract::ws::Message::Close(_) => {
                        return Err(Error::other("closed"));
                    }
                }
            }
        })
    }

    /// Sends `payload` to the client as a single binary frame.
    ///
    /// The payload length is added to `bytes_send` before the frame is
    /// written.
    fn send(
        &self,
        payload: Payload,
    ) -> futures::future::BoxFuture<'static, Result<()>> {
        use futures::SinkExt;
        let this = self.clone();
        Box::pin(async move {
            let mut write = this.write.lock().await;

            // `Bytes::from(Vec)` and `BytesMut::freeze` reuse the existing
            // allocation. The previous code copied every payload once or
            // twice here.
            let data = match payload {
                Payload::Vec(v) => bytes::Bytes::from(v),
                Payload::BytesMut(b) => b.freeze(),
            };
            this.bytes_send.add(data.len() as u64, &this.attr);

            // `SinkExt::send` flushes the sink before it completes, so no
            // separate flush is needed.
            write
                .send(axum::extract::ws::Message::Binary(data))
                .await
                .map_err(Error::other)
        })
    }

    /// Closes the outgoing half of the websocket.
    ///
    /// This is best-effort: any error from closing is deliberately ignored.
    fn close(&self) -> futures::future::BoxFuture<'static, ()> {
        use futures::SinkExt;
        let this = self.clone();
        Box::pin(async move {
            let _ = this.write.lock().await.close().await;
        })
    }
}
impl WebsocketImpl {
    /// Wraps an upgraded axum websocket for client `pk`.
    ///
    /// Splits the socket into separately locked read and write halves. It
    /// also creates the `sbd.server.bytes_send` / `sbd.server.bytes_recv`
    /// counters on `meter`, whose attribute is the key encoded as unpadded
    /// URL-safe base64.
    fn new(
        ws: axum::extract::ws::WebSocket,
        pk: PubKey,
        meter: &opentelemetry::metrics::Meter,
    ) -> Self {
        use futures::StreamExt;

        let (sink, stream) = ws.split();
        let encoded_key =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(*pk.0);

        let bytes_send = meter
            .u64_counter("sbd.server.bytes_send")
            .with_description("Number of bytes sent to client")
            .with_unit("bytes")
            .build();
        let bytes_recv = meter
            .u64_counter("sbd.server.bytes_recv")
            .with_description("Number of bytes received from client")
            .with_unit("bytes")
            .build();

        let attr = vec![opentelemetry::KeyValue::new("pub_key", encoded_key)];

        Self {
            write: Arc::new(tokio::sync::Mutex::new(sink)),
            read: Arc::new(tokio::sync::Mutex::new(stream)),
            attr,
            bytes_send,
            bytes_recv,
        }
    }
}
async fn handle_ws(
axum::extract::Path(pub_key): axum::extract::Path<String>,
headers: axum::http::HeaderMap,
ws: axum::extract::WebSocketUpgrade,
axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
std::net::SocketAddr,
>,
axum::extract::State(app_state): axum::extract::State<AppState>,
) -> impl axum::response::IntoResponse {
use axum::response::IntoResponse;
use base64::Engine;
let token: Option<Arc<str>> = headers
.get("Authorization")
.and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
if !app_state
.token_tracker
.check_is_token_valid(&app_state.config, token)
{
app_state
.auth_failures
.add(1, &[opentelemetry::KeyValue::new("pub_key", pub_key)]);
return axum::response::IntoResponse::into_response((
axum::http::StatusCode::UNAUTHORIZED,
"Unauthorized",
));
}
let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
Ok(pk) if pk.len() == 32 => {
let mut sized_pk = [0; 32];
sized_pk.copy_from_slice(&pk);
PubKey(Arc::new(sized_pk))
}
_ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
};
let mut calc_ip = to_canonical_ip(addr.ip());
if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
if let Some(header) =
headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
{
if let Ok(ip) = header.parse::<IpAddr>() {
calc_ip = to_canonical_ip(ip);
}
}
}
ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
move |socket| async move {
handle_upgraded(
app_state.config.clone(),
app_state.ip_rate.clone(),
app_state.cslot.clone(),
Arc::new(WebsocketImpl::new(
socket,
pk.clone(),
&app_state.meter,
)),
pk,
calc_ip,
maybe_auth,
)
.await;
},
)
}
/// Tracks issued auth tokens along with the time each was last used.
///
/// Clones share the same underlying map, because it sits behind an `Arc`.
#[derive(Clone, Default)]
pub struct AuthTokenTracker {
    token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
}

impl AuthTokenTracker {
    /// Records `token` as valid, stamped with the current time.
    ///
    /// Registering an already known token refreshes its timestamp.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn register_token(&self, token: Arc<str>) {
        self.token_map
            .lock()
            .unwrap()
            .insert(token, std::time::Instant::now());
    }

    /// Checks an `Authorization` header value against the registered
    /// tokens.
    ///
    /// - `None` (no header) is accepted only when no
    ///   `authentication_hook_server` is configured.
    /// - Otherwise the value must be exactly one `"Bearer "` prefix followed
    ///   by a token that was registered and used within `config.idle_dur()`.
    ///
    /// When a header is present, every token idle for `config.idle_dur()` or
    /// longer is pruned. A successfully validated token has its last-used
    /// time refreshed.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn check_is_token_valid(
        &self,
        config: &Config,
        token: Option<Arc<str>>,
    ) -> bool {
        let Some(header) = token else {
            return config.authentication_hook_server.is_none();
        };

        // Strip exactly one scheme prefix. `trim_start_matches` would also
        // accept e.g. "Bearer Bearer <token>".
        let Some(token) = header.strip_prefix("Bearer ") else {
            return false;
        };

        let idle_dur = config.idle_dur();
        let mut map = self.token_map.lock().unwrap();
        map.retain(|_token, last_used| last_used.elapsed() < idle_dur);

        // `Arc<str>: Borrow<str>`, so the lookup needs no allocation.
        match map.get_mut(token) {
            Some(last_used) => {
                *last_used = std::time::Instant::now();
                true
            }
            None => false,
        }
    }
}
/// Shared state handed to the axum handlers (`handle_auth`, `handle_ws`).
#[derive(Clone)]
struct AppState {
    /// Server configuration.
    config: Arc<Config>,
    /// Issued auth tokens, shared by all handlers.
    token_tracker: AuthTokenTracker,
    /// Per-IP blocking / rate state.
    ip_rate: Arc<IpRate>,
    /// Weak handle to the connection slot owned by `SbdServer`.
    cslot: WeakCSlot,
    /// Counter of failed authentication attempts.
    auth_failures: opentelemetry::metrics::Counter<u64>,
    /// Meter used to create per-connection instruments.
    meter: opentelemetry::metrics::Meter,
}

impl AppState {
    /// Builds handler state with an empty token tracker. It also creates the
    /// `sbd.server.auth_failures` counter on `meter`.
    pub fn new(
        config: Arc<Config>,
        ip_rate: Arc<IpRate>,
        cslot: WeakCSlot,
        meter: opentelemetry::metrics::Meter,
    ) -> Self {
        let auth_failures = meter
            .u64_counter("sbd.server.auth_failures")
            .with_description("Number of failed authentication attempts")
            .with_unit("count")
            .build();

        Self {
            meter,
            auth_failures,
            cslot,
            ip_rate,
            token_tracker: AuthTokenTracker::default(),
            config,
        }
    }
}
impl SbdServer {
    /// Starts a server configured by `config`.
    ///
    /// - TLS is enabled when both `cert_pem_file` and `priv_key_pem_file`
    ///   are set.
    /// - The IP-rate prune task is spawned.
    /// - The `/authenticate` (`PUT`) and `/{pub_key}` (websocket) routes are
    ///   served on one listener per entry of `config.bind`. Request bodies
    ///   are limited to 1024 bytes.
    /// - Once a bind entry with port `0` has been assigned a concrete port,
    ///   later port-`0` entries reuse that port.
    /// - A listener whose handle reports no local address is left out of
    ///   [`SbdServer::bind_addrs`] but does not make this function fail.
    ///   This means `Handle::listening` returned `None`, presumably because
    ///   binding failed; its task logs the error.
    ///
    /// # Errors
    ///
    /// Returns an error if the TLS files cannot be loaded, or if any `bind`
    /// entry is not a valid socket address.
    pub async fn new(config: Arc<Config>) -> Result<Self> {
        /// Spawns a listener future, logging rather than propagating its
        /// error.
        fn spawn_server<F>(server: F) -> tokio::task::JoinHandle<()>
        where
            F: std::future::Future<Output = std::io::Result<()>>
                + Send
                + 'static,
        {
            tokio::task::spawn(async move {
                if let Err(err) = server.await {
                    tracing::error!(?err);
                }
            })
        }

        let tls_config = if let (Some(cert), Some(pk)) =
            (&config.cert_pem_file, &config.priv_key_pem_file)
        {
            Some(Arc::new(TlsConfig::new(cert, pk).await?))
        } else {
            None
        };

        // Parse every bind address before spawning anything. Dropping a
        // `JoinHandle` only detaches its task, so an early `?` after the
        // spawns below would leave those tasks running with no owner.
        let bind_list = config
            .bind
            .iter()
            .map(|bind| {
                bind.parse::<std::net::SocketAddr>().map_err(Error::other)
            })
            .collect::<Result<Vec<_>>>()?;

        let sbd_server_meter = global::meter("sbd-server");
        let mut task_list = Vec::new();
        let mut bind_addrs = Vec::new();

        let ip_rate = Arc::new(IpRate::new(config.clone()));
        task_list.push(spawn_prune_task(ip_rate.clone()));

        let cslot = CSlot::new(
            config.clone(),
            ip_rate.clone(),
            sbd_server_meter.clone(),
        );
        let weak_cslot = cslot.weak();

        let app: axum::Router<()> = axum::Router::new()
            .route("/authenticate", axum::routing::put(handle_auth))
            .route("/{pub_key}", axum::routing::any(handle_ws))
            .layer(axum::extract::DefaultBodyLimit::max(1024))
            .with_state(AppState::new(
                config.clone(),
                ip_rate.clone(),
                weak_cslot.clone(),
                sbd_server_meter,
            ));
        let app =
            app.into_make_service_with_connect_info::<std::net::SocketAddr>();

        let mut found_port_zero: Option<u16> = None;
        for mut a in bind_list {
            if a.port() == 0 {
                if let Some(port) = found_port_zero {
                    a.set_port(port);
                }
            }

            let h = axum_server::Handle::new();
            let task = match &tls_config {
                Some(tls_config) => {
                    let rustls_config =
                        axum_server::tls_rustls::RustlsConfig::from_config(
                            tls_config.config(),
                        );
                    spawn_server(
                        axum_server::bind_rustls(a, rustls_config)
                            .handle(h.clone())
                            .serve(app.clone()),
                    )
                }
                None => spawn_server(
                    axum_server::bind(a).handle(h.clone()).serve(app.clone()),
                ),
            };
            task_list.push(task);

            if let Some(addr) = h.listening().await {
                if found_port_zero.is_none() && a.port() == 0 {
                    found_port_zero = Some(addr.port());
                }
                bind_addrs.push(addr);
            }
        }

        Ok(Self {
            task_list,
            bind_addrs,
            _cslot: cslot,
        })
    }

    /// Returns the local addresses the listeners actually bound, in
    /// `config.bind` order, with port `0` resolved to the assigned port.
    pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
        &self.bind_addrs
    }
}
#[cfg(test)]
mod test;