use std::{net::SocketAddr, sync::Arc};
use axum::{
Router,
extract::State,
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::post,
};
use bytes::Bytes;
use tracing::{error, info, warn};
use crate::dispatcher::Dispatcher;
/// HTTP server that receives bot updates pushed via webhook and forwards them
/// to a [`Dispatcher`].
///
/// Configure it with [`WebhookServer::new`], optionally [`WebhookServer::secret`]
/// and [`WebhookServer::path`], then run it with [`WebhookServer::serve`].
pub struct WebhookServer {
    /// Dispatcher that receives every successfully parsed update.
    dispatcher: Arc<Dispatcher>,
    /// Expected value of the `X-Max-Bot-Api-Secret` header; `None` disables the check.
    secret: Option<String>,
    /// Route the update handler is mounted on; always starts with `/`.
    path: String,
}

impl WebhookServer {
    /// Creates a server that accepts updates on `/` without secret verification.
    pub fn new(dispatcher: Dispatcher) -> Self {
        Self {
            dispatcher: Arc::new(dispatcher),
            secret: None,
            path: "/".into(),
        }
    }

    /// Requires every request to carry `secret` in the `X-Max-Bot-Api-Secret` header.
    /// Requests without it, or with a different value, get `401 Unauthorized`.
    pub fn secret(mut self, secret: impl Into<String>) -> Self {
        self.secret = Some(secret.into());
        self
    }

    /// Sets the route on which updates are accepted.
    ///
    /// A missing leading `/` is added, and an empty path becomes `/`. Axum panics
    /// when a route does not start with `/`, so `"webhook"` is accepted as
    /// `"/webhook"` instead of crashing in [`WebhookServer::serve`].
    pub fn path(mut self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.path = if path.starts_with('/') {
            path
        } else {
            format!("/{path}")
        };
        self
    }

    /// Binds to `addr` and serves webhook requests until the server stops.
    ///
    /// # Panics
    ///
    /// Panics if `addr` is not a valid socket address (e.g. `"0.0.0.0:8080"`),
    /// if the listener cannot be bound, or if the server terminates with an error.
    pub async fn serve(self, addr: impl Into<String>) {
        let raw_addr = addr.into();
        let addr: SocketAddr = raw_addr.parse().unwrap_or_else(|e| {
            panic!("Invalid socket address for webhook server ({raw_addr:?}): {e}")
        });

        let state = Arc::new(WebhookState {
            dispatcher: self.dispatcher,
            secret: self.secret,
        });
        let app = Router::new()
            .route(&self.path, post(handle_update))
            .with_state(state);

        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .unwrap_or_else(|e| panic!("Failed to bind webhook server to {addr}: {e}"));
        // Log only after a successful bind, and report the address actually bound
        // (it differs from `addr` when port 0 was requested).
        let local_addr = listener.local_addr().unwrap_or(addr);
        info!("Webhook server listening on {local_addr}");

        axum::serve(listener, app)
            .await
            .unwrap_or_else(|e| panic!("Webhook server terminated with error: {e}"));
    }
}
/// Shared state handed to every invocation of the `handle_update` route handler.
struct WebhookState {
    /// Value the `X-Max-Bot-Api-Secret` header must equal; `None` skips the check.
    secret: Option<String>,
    /// Receives each parsed update body.
    dispatcher: Arc<Dispatcher>,
}
/// Axum handler for a single webhook request.
///
/// If a secret is configured, the `X-Max-Bot-Api-Secret` header must match it
/// exactly, otherwise `401 Unauthorized` is returned. The body is then parsed as
/// JSON and passed to `Dispatcher::dispatch_raw`. The response is sent only after
/// dispatch completes.
///
/// A body that is not valid JSON is logged and still answered with `200 OK`.
/// NOTE(review): presumably this is so the platform does not keep redelivering a
/// payload that can never be parsed; confirm before changing.
async fn handle_update(
    State(state): State<Arc<WebhookState>>,
    headers: HeaderMap,
    body: Bytes,
) -> impl IntoResponse {
    if let Some(expected) = &state.secret {
        // Compare raw header bytes: a value that is not visible ASCII is a
        // mismatch, not a missing header.
        match headers.get("x-max-bot-api-secret") {
            Some(provided) if secret_matches(provided.as_bytes(), expected.as_bytes()) => {}
            Some(_) => {
                // Never log the received value: it may be a real or near-miss secret.
                warn!("Webhook secret mismatch");
                return StatusCode::UNAUTHORIZED;
            }
            None => {
                warn!("Missing X-Max-Bot-Api-Secret header");
                return StatusCode::UNAUTHORIZED;
            }
        }
    }

    let update: serde_json::Value = match serde_json::from_slice(&body) {
        Ok(u) => u,
        Err(e) => {
            error!("Failed to parse webhook update: {e}");
            return StatusCode::OK;
        }
    };
    state.dispatcher.dispatch_raw(update).await;
    StatusCode::OK
}

/// Compares a received secret with the expected one without short-circuiting
/// on the first differing byte.
///
/// Every byte pair is visited, so response time does not reveal how long the
/// matching prefix is. The length of the secret can still be inferred, which is
/// the usual trade-off for this kind of check.
fn secret_matches(provided: &[u8], expected: &[u8]) -> bool {
    if provided.len() != expected.len() {
        return false;
    }
    let diff = provided
        .iter()
        .zip(expected)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    diff == 0
}