#![cfg(feature = "webhooks")]
use std::net::SocketAddr;
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::serve::ListenerExt;
use axum::Router;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Notify};
use tracing::{debug, error, info, warn};
use rust_tg_bot_raw::types::update::Update;
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths are rejected up front. For equal lengths every
/// byte pair is XOR-ed and OR-ed into a single accumulator, so the loop has no
/// data-dependent early exit; the slices are equal exactly when the
/// accumulator is still zero afterwards.
#[inline]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut mismatch = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        mismatch |= lhs ^ rhs;
    }
    mismatch == 0
}
#[inline(always)]
fn ok_response() -> Response {
StatusCode::OK.into_response()
}
#[inline(always)]
fn forbidden_response() -> Response {
StatusCode::FORBIDDEN.into_response()
}
#[inline(always)]
fn bad_request_response() -> Response {
StatusCode::BAD_REQUEST.into_response()
}
#[inline(always)]
fn service_unavailable_response() -> Response {
StatusCode::SERVICE_UNAVAILABLE.into_response()
}
#[inline(always)]
fn internal_error_response() -> Response {
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
/// State shared with the webhook route handler.
///
/// `Debug` is implemented by hand so the secret token bytes never end up in
/// logs or panic messages; only its presence is reported.
#[derive(Clone)]
struct WebhookState {
    /// Sending half of the channel that parsed updates are pushed into.
    update_tx: mpsc::Sender<Update>,
    /// Expected value of the `x-telegram-bot-api-secret-token` header, or
    /// `None` when no secret check is performed.
    secret_token: Option<Arc<[u8]>>,
}

impl std::fmt::Debug for WebhookState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a secret is configured without revealing its value.
        let secret = self.secret_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("WebhookState")
            .field("update_tx", &self.update_tx)
            .field("secret_token", &secret)
            .finish()
    }
}
/// TLS settings for the webhook server; only compiled with the
/// `webhooks-tls` feature.
#[cfg(feature = "webhooks-tls")]
#[derive(Clone)]
pub struct TlsConfig {
// Acceptor used to perform the TLS handshake on each accepted TCP stream.
acceptor: tokio_rustls::TlsAcceptor,
}
#[cfg(feature = "webhooks-tls")]
impl std::fmt::Debug for TlsConfig {
    /// Formats the config with a fixed placeholder string in place of the
    /// acceptor's contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TlsConfig");
        out.field("acceptor", &"TlsAcceptor { .. }");
        out.finish()
    }
}
#[cfg(feature = "webhooks-tls")]
impl TlsConfig {
    /// Builds a TLS configuration from a PEM certificate file and a PEM
    /// private-key file.
    ///
    /// Both files are read fully into memory first, then parsed. The server
    /// config is created without client authentication and with the single
    /// certificate chain / key pair found in the files.
    ///
    /// # Errors
    ///
    /// * An error with the original I/O kind when either file cannot be read.
    /// * `InvalidData` when the certificate PEM is malformed or contains no
    ///   certificates, when the key PEM is malformed or contains no private
    ///   key, or when rustls rejects the certificate/key combination.
    pub async fn from_pem_files(cert_path: &str, key_path: &str) -> Result<Self, std::io::Error> {
        use rustls_pemfile::{certs, private_key};
        use std::io::{self, BufReader};
        use tokio_rustls::rustls::ServerConfig;

        let cert_bytes = match tokio::fs::read(cert_path).await {
            Ok(bytes) => bytes,
            Err(e) => {
                let msg = format!("failed to read cert file '{cert_path}': {e}");
                return Err(io::Error::new(e.kind(), msg));
            }
        };
        let key_bytes = match tokio::fs::read(key_path).await {
            Ok(bytes) => bytes,
            Err(e) => {
                let msg = format!("failed to read key file '{key_path}': {e}");
                return Err(io::Error::new(e.kind(), msg));
            }
        };

        // Every failure past this point means the bytes were readable but unusable.
        let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);

        let mut cert_reader = BufReader::new(cert_bytes.as_slice());
        let chain = certs(&mut cert_reader)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| invalid(format!("invalid cert PEM: {e}")))?;
        if chain.is_empty() {
            return Err(invalid(format!("no certificates found in '{cert_path}'")));
        }

        let mut key_reader = BufReader::new(key_bytes.as_slice());
        let key = match private_key(&mut key_reader) {
            Ok(Some(found)) => found,
            Ok(None) => {
                return Err(invalid(format!("no private key found in '{key_path}'")));
            }
            Err(e) => return Err(invalid(format!("invalid key PEM: {e}"))),
        };

        let server_config = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, key)
            .map_err(|e| invalid(format!("TLS config error: {e}")))?;

        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
        Ok(Self { acceptor })
    }
}
/// Wraps the webhook state and converts it into an axum [`Router`] via
/// `into_router`.
#[derive(Debug, Clone)]
pub struct WebhookHandler {
// Update channel and optional secret token handed to the route handlers.
state: WebhookState,
}
impl WebhookHandler {
    /// Creates a handler that forwards parsed updates into `update_tx`.
    ///
    /// When `secret_token` is `Some`, its bytes are stored and later compared
    /// against the `x-telegram-bot-api-secret-token` request header.
    pub fn new(update_tx: mpsc::Sender<Update>, secret_token: Option<String>) -> Self {
        let secret_token: Option<Arc<[u8]>> =
            secret_token.map(String::into_bytes).map(Arc::from);
        let state = WebhookState {
            update_tx,
            secret_token,
        };
        Self { state }
    }

    /// Consumes the handler and builds the router.
    ///
    /// `url_path` receives `POST` webhook deliveries; a leading `/` is added
    /// when it is missing. `GET /healthcheck` is always registered as well.
    pub fn into_router(self, url_path: &str) -> Router {
        let mut path = String::with_capacity(url_path.len() + 1);
        if !url_path.starts_with('/') {
            path.push('/');
        }
        path.push_str(url_path);

        let Self { state } = self;
        Router::new()
            .route(&path, post(handle_webhook))
            .route("/healthcheck", get(handle_healthcheck))
            .with_state(state)
    }
}
/// HTTP (optionally HTTPS) server that receives webhook deliveries and can be
/// stopped through `shutdown`.
#[derive(Debug)]
pub struct WebhookServer {
// Host part of the bind address; joined with `port` in `serve_forever`.
listen: String,
// TCP port to bind.
port: u16,
// Router with the webhook and healthcheck routes already attached.
router: Router,
// Signalled by `shutdown` to start graceful shutdown of the serve loop.
shutdown_notify: Arc<Notify>,
// Set to true once the listener is bound, false again after serving stops.
running: std::sync::atomic::AtomicBool,
// When `Some`, `serve_forever` serves over TLS instead of plain HTTP.
#[cfg(feature = "webhooks-tls")]
tls: Option<TlsConfig>,
}
impl WebhookServer {
    /// Creates a server that will bind to `listen:port` and route `POST`
    /// requests on `url_path` into `update_tx`.
    ///
    /// `secret_token`, when set, must match the
    /// `x-telegram-bot-api-secret-token` header of every delivery. With the
    /// `webhooks-tls` feature, passing `Some(tls)` makes the server speak
    /// HTTPS. Nothing is bound until [`serve_forever`](Self::serve_forever)
    /// is awaited.
    pub fn new(
        listen: impl Into<String>,
        port: u16,
        url_path: &str,
        update_tx: mpsc::Sender<Update>,
        secret_token: Option<String>,
        #[cfg(feature = "webhooks-tls")] tls: Option<TlsConfig>,
    ) -> Self {
        let handler = WebhookHandler::new(update_tx, secret_token);
        let router = handler.into_router(url_path);
        Self {
            listen: listen.into(),
            port,
            router,
            shutdown_notify: Arc::new(Notify::new()),
            running: std::sync::atomic::AtomicBool::new(false),
            #[cfg(feature = "webhooks-tls")]
            tls,
        }
    }

    /// Returns `true` between a successful bind in `serve_forever` and the
    /// moment serving ends (normally, with an error, or by cancellation).
    pub fn is_running(&self) -> bool {
        self.running.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Resolves the configured host and port into a socket address.
    ///
    /// A bare IP literal (including unbracketed IPv6 such as `::1`) is paired
    /// with the port directly. Anything else falls back to parsing
    /// `"{listen}:{port}"`, which keeps bracketed IPv6 (`[::1]`) working.
    fn bind_addr(&self) -> Result<SocketAddr, std::io::Error> {
        if let Ok(ip) = self.listen.parse::<std::net::IpAddr>() {
            return Ok(SocketAddr::new(ip, self.port));
        }
        format!("{}:{}", self.listen, self.port)
            .parse::<SocketAddr>()
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
    }

    /// Binds the listener and serves requests until [`shutdown`](Self::shutdown)
    /// is called.
    ///
    /// `ready`, if provided, is notified once the listener is bound and the
    /// server is about to accept connections.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when the listen address cannot be parsed, or the
    /// underlying I/O error when binding or serving fails.
    pub async fn serve_forever(&self, ready: Option<Arc<Notify>>) -> Result<(), std::io::Error> {
        use std::sync::atomic::{AtomicBool, Ordering};

        /// Clears the `running` flag when dropped, so the flag is also reset
        /// when serving ends with an error or this future is cancelled.
        struct RunningGuard<'a>(&'a AtomicBool);
        impl Drop for RunningGuard<'_> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::Relaxed);
            }
        }

        let addr = self.bind_addr()?;
        let listener = TcpListener::bind(addr).await?;
        self.running.store(true, Ordering::Relaxed);
        let _running = RunningGuard(&self.running);

        #[cfg(feature = "webhooks-tls")]
        if let Some(ref tls) = self.tls {
            info!("Webhook server (HTTPS) started on {addr}");
            if let Some(n) = ready {
                n.notify_one();
            }
            return self.serve_tls(listener, tls.clone(), addr).await;
        }

        let listener = listener.tap_io(|tcp_stream| {
            if let Err(e) = tcp_stream.set_nodelay(true) {
                warn!("Failed to set TCP_NODELAY: {e}");
            }
        });
        info!("Webhook server started on {addr}");
        if let Some(n) = ready {
            n.notify_one();
        }

        let shutdown_notify = self.shutdown_notify.clone();
        axum::serve(listener, self.router.clone())
            .with_graceful_shutdown(async move {
                shutdown_notify.notified().await;
            })
            .await?;

        // Reset before logging so observers never see "stopped" with the flag set.
        self.running.store(false, Ordering::Relaxed);
        info!("Webhook server stopped");
        Ok(())
    }

    /// Accept loop for the HTTPS variant.
    ///
    /// Each accepted TCP stream is handed to its own task, which performs the
    /// TLS handshake and then serves HTTP on top of it. On shutdown the loop
    /// stops accepting, asks every live connection to shut down gracefully
    /// and waits for all connection tasks to finish.
    #[cfg(feature = "webhooks-tls")]
    async fn serve_tls(
        &self,
        listener: TcpListener,
        tls: TlsConfig,
        addr: SocketAddr,
    ) -> Result<(), std::io::Error> {
        use hyper_util::service::TowerToHyperService;

        let router = self.router.clone();
        let graceful = tokio_util::sync::CancellationToken::new();

        // Wait for the shutdown signal inside this future rather than in a
        // detached task: if this future is dropped, no stale waiter remains
        // to swallow a notification meant for a later run.
        let shutdown = self.shutdown_notify.notified();
        tokio::pin!(shutdown);

        let mut connection_handles = tokio::task::JoinSet::new();
        loop {
            tokio::select! {
                _ = &mut shutdown => {
                    debug!("TLS server shutting down, waiting for in-flight connections");
                    graceful.cancel();
                    break;
                }
                // Reap finished connection tasks as they complete; otherwise the
                // set would keep one entry per connection until shutdown.
                Some(joined) = connection_handles.join_next() => {
                    if let Err(e) = joined {
                        debug!("Connection task failed: {e}");
                    }
                }
                accepted = listener.accept() => {
                    let (tcp_stream, remote_addr) = match accepted {
                        Ok(conn) => conn,
                        Err(e) => {
                            error!("Failed to accept TCP connection: {e}");
                            continue;
                        }
                    };
                    if let Err(e) = tcp_stream.set_nodelay(true) {
                        warn!("Failed to set TCP_NODELAY: {e}");
                    }
                    let acceptor = tls.acceptor.clone();
                    let token = graceful.clone();
                    let svc = TowerToHyperService::new(router.clone());
                    connection_handles.spawn(async move {
                        // Abandon a handshake that is still pending at shutdown so a
                        // stalled peer cannot block the final join forever.
                        let tls_stream = tokio::select! {
                            handshake = acceptor.accept(tcp_stream) => match handshake {
                                Ok(s) => s,
                                Err(e) => {
                                    debug!("TLS handshake failed from {remote_addr}: {e}");
                                    return;
                                }
                            },
                            _ = token.cancelled() => return,
                        };
                        let io = hyper_util::rt::TokioIo::new(tls_stream);
                        let builder = hyper_util::server::conn::auto::Builder::new(
                            hyper_util::rt::TokioExecutor::new(),
                        );
                        let conn = builder.serve_connection(io, svc);
                        let mut conn = std::pin::pin!(conn);
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(e) = result {
                                    debug!("Connection error from {remote_addr}: {e}");
                                }
                            }
                            _ = token.cancelled() => {
                                // Let the in-flight request finish, then close.
                                conn.as_mut().graceful_shutdown();
                                if let Err(e) = conn.await {
                                    debug!("Connection error during shutdown from {remote_addr}: {e}");
                                }
                            }
                        }
                    });
                }
            }
        }

        // Drain whatever is still running after the shutdown signal.
        while connection_handles.join_next().await.is_some() {}
        self.running
            .store(false, std::sync::atomic::Ordering::Relaxed);
        info!("Webhook server (HTTPS) stopped on {addr}");
        Ok(())
    }

    /// Requests a graceful shutdown of a running server.
    ///
    /// Does nothing when the server is not currently running.
    pub fn shutdown(&self) {
        if self.is_running() {
            debug!("Shutting down webhook server");
            self.shutdown_notify.notify_one();
        }
    }
}
/// Handles a `POST` webhook delivery.
///
/// Checks run in this order, each answering with the listed status on failure:
///
/// 1. `Content-Type` must start with `application/json` — otherwise `403`.
/// 2. If a secret token is configured, the
///    `x-telegram-bot-api-secret-token` header must be present and equal to
///    it — otherwise `403`.
/// 3. The body must deserialize into an [`Update`] — otherwise `400`.
/// 4. The update is pushed into the channel without waiting: `200` on
///    success, `503` when the channel is full, `500` when it is closed.
async fn handle_webhook(
    State(state): State<WebhookState>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    // A missing or non-UTF-8 Content-Type is treated like an empty one.
    let ct = match headers.get("content-type").map(|value| value.to_str()) {
        Some(Ok(text)) => text,
        _ => "",
    };
    let is_json = ct.starts_with("application/json");
    if !is_json {
        debug!("Rejected request with Content-Type: {ct}");
        return forbidden_response();
    }

    if let Some(expected) = state.secret_token.as_deref() {
        let Some(provided) = headers.get("x-telegram-bot-api-secret-token") else {
            debug!("Request missing secret token header");
            return forbidden_response();
        };
        if !constant_time_eq(provided.as_bytes(), expected) {
            debug!("Request had invalid secret token");
            return forbidden_response();
        }
    }

    let update = match serde_json::from_slice::<Update>(&body) {
        Err(e) => {
            error!("Failed to parse update JSON: {e}");
            return bad_request_response();
        }
        Ok(parsed) => parsed,
    };
    debug!(update_id = update.update_id, "Webhook received update");

    // `try_send` never waits: a full channel is reported to the caller as 503.
    match state.update_tx.try_send(update) {
        Err(mpsc::error::TrySendError::Closed(_)) => {
            error!("Update channel closed");
            internal_error_response()
        }
        Err(mpsc::error::TrySendError::Full(_)) => {
            warn!("Update channel full -- applying backpressure (503)");
            service_unavailable_response()
        }
        Ok(()) => ok_response(),
    }
}
/// Healthcheck handler: unconditionally returns `200 OK`.
async fn handle_healthcheck() -> StatusCode {
StatusCode::OK
}
/// Field-less type whose `new` function builds a [`WebhookServer`].
pub struct WebhookApp;
impl WebhookApp {
/// Builds a [`WebhookServer`] by forwarding every argument to
/// `WebhookServer::new`.
///
/// When the `webhooks-tls` feature is enabled, `None` is passed for the TLS
/// configuration, so the returned server has no TLS config set.
pub fn new(
listen: impl Into<String>,
port: u16,
url_path: &str,
update_tx: mpsc::Sender<Update>,
secret_token: Option<String>,
) -> WebhookServer {
WebhookServer::new(
listen,
port,
url_path,
update_tx,
secret_token,
// The extra TLS parameter only exists when the feature is compiled in.
#[cfg(feature = "webhooks-tls")]
None,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Header carrying the secret token on webhook deliveries.
    const SECRET_HEADER: &str = "x-telegram-bot-api-secret-token";

    /// Builds handler state around `tx`, optionally requiring `secret`.
    fn state_with(tx: mpsc::Sender<Update>, secret: Option<&str>) -> WebhookState {
        WebhookState {
            update_tx: tx,
            secret_token: secret.map(|s| Arc::from(s.as_bytes().to_vec().into_boxed_slice())),
        }
    }

    /// Builds request headers with the given content type and, optionally, a
    /// secret-token header.
    fn headers_with(content_type: &str, secret: Option<&str>) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", content_type.parse().unwrap());
        if let Some(token) = secret {
            headers.insert(SECRET_HEADER, token.parse().unwrap());
        }
        headers
    }

    /// Shorthand for JSON request headers.
    fn json_headers(secret: Option<&str>) -> HeaderMap {
        headers_with("application/json", secret)
    }

    /// Runs the webhook handler on `body` and returns the response status.
    async fn deliver(state: WebhookState, headers: HeaderMap, body: &'static [u8]) -> StatusCode {
        handle_webhook(State(state), headers, Bytes::from_static(body))
            .await
            .status()
    }

    // --- constant_time_eq -------------------------------------------------

    #[test]
    fn ct_eq_equal_slices() {
        let same = constant_time_eq(b"hello", b"hello");
        assert!(same);
    }

    #[test]
    fn ct_eq_different_slices() {
        let same = constant_time_eq(b"hello", b"world");
        assert!(!same);
    }

    #[test]
    fn ct_eq_different_lengths() {
        let same = constant_time_eq(b"short", b"longer");
        assert!(!same);
    }

    #[test]
    fn ct_eq_empty_slices() {
        let same = constant_time_eq(b"", b"");
        assert!(same);
    }

    #[test]
    fn ct_eq_single_bit_diff() {
        let same = constant_time_eq(b"A", b"B");
        assert!(!same);
    }

    // --- handle_webhook: rejections ---------------------------------------

    #[tokio::test]
    async fn rejects_wrong_content_type() {
        let (tx, _rx) = mpsc::channel(1);
        let status = deliver(
            state_with(tx, None),
            headers_with("text/plain", None),
            b"{}",
        )
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn rejects_missing_secret_token() {
        let (tx, _rx) = mpsc::channel(1);
        let status = deliver(
            state_with(tx, Some("my-secret")),
            json_headers(None),
            b"{}",
        )
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn rejects_wrong_secret_token() {
        let (tx, _rx) = mpsc::channel(1);
        let status = deliver(
            state_with(tx, Some("correct")),
            json_headers(Some("wrong")),
            b"{}",
        )
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    // --- handle_webhook: accepted deliveries ------------------------------

    #[tokio::test]
    async fn accepts_valid_request() {
        let (tx, mut rx) = mpsc::channel(1);
        let status = deliver(
            state_with(tx, None),
            json_headers(None),
            br#"{"update_id": 1}"#,
        )
        .await;
        assert_eq!(status, StatusCode::OK);

        let update = rx.recv().await.unwrap();
        assert_eq!(update.update_id, 1);
    }

    #[tokio::test]
    async fn accepts_valid_request_with_secret() {
        let (tx, mut rx) = mpsc::channel(1);
        let status = deliver(
            state_with(tx, Some("mysecret")),
            json_headers(Some("mysecret")),
            br#"{"update_id": 42}"#,
        )
        .await;
        assert_eq!(status, StatusCode::OK);

        let update = rx.recv().await.unwrap();
        assert_eq!(update.update_id, 42);
    }

    // --- handle_webhook: channel and payload failures ---------------------

    #[tokio::test]
    async fn returns_503_when_channel_full() {
        let (tx, _rx) = mpsc::channel(1);
        // Occupy the single slot so the handler's try_send finds the channel full.
        let prefill: Update = serde_json::from_str("{\"update_id\": 0}").unwrap();
        tx.try_send(prefill).unwrap();

        let status = deliver(
            state_with(tx, None),
            json_headers(None),
            br#"{"update_id": 99}"#,
        )
        .await;
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn returns_500_when_channel_closed() {
        let (tx, rx) = mpsc::channel(1);
        // Dropping the receiver closes the channel.
        drop(rx);

        let status = deliver(
            state_with(tx, None),
            json_headers(None),
            br#"{"update_id": 1}"#,
        )
        .await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn returns_400_on_malformed_json() {
        let (tx, _rx) = mpsc::channel(1);
        let status = deliver(
            state_with(tx, None),
            json_headers(None),
            b"this is not json",
        )
        .await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    // --- healthcheck and router construction ------------------------------

    #[tokio::test]
    async fn healthcheck_returns_200() {
        assert_eq!(handle_healthcheck().await, StatusCode::OK);
    }

    #[test]
    fn webhook_handler_creates_router() {
        let (tx, _rx) = mpsc::channel(1);
        let _router = WebhookHandler::new(tx, Some("secret".into())).into_router("/webhook");
    }

    #[test]
    fn webhook_handler_normalizes_path_without_slash() {
        let (tx, _rx) = mpsc::channel(1);
        let _router = WebhookHandler::new(tx, None).into_router("webhook");
    }
}