use crate::polling::UpdateHandler;
use crate::types::Update;
use crate::{Bot, BotError};
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
routing::post,
Json, Router,
};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::{error, info, warn};
/// State shared (behind an `Arc`) by every webhook request handled by the router.
struct AppState {
    /// When set, requests must carry this value in the
    /// `x-telegram-bot-api-secret-token` header to be accepted.
    secret_token: Option<String>,
    /// Bot client; a clone is passed to the handler for each update.
    bot: Bot,
    /// User update handler, reference-counted so spawned tasks can share it.
    handler: Arc<UpdateHandler>,
}
/// Builder and runner for a Telegram webhook endpoint.
///
/// Configure it with the chained setters, then call [`WebhookServer::start`] to register
/// the webhook with Telegram and serve incoming updates over HTTP.
pub struct WebhookServer {
    // ---- local HTTP server ----
    /// TCP port the server listens on.
    port: u16,
    /// URL path updates are POSTed to; also appended to the public base URL
    /// when the webhook is registered.
    path: String,

    // ---- setWebhook parameters ----
    /// Optional secret sent to Telegram and then required on every incoming request.
    secret_token: Option<String>,
    /// Update types to subscribe to; an empty list means the parameter is not sent.
    allowed_updates: Vec<String>,
    /// Optional `max_connections` value forwarded to `setWebhook`.
    max_connections: Option<i64>,
    /// Whether `setWebhook` is asked to drop updates queued before registration.
    drop_pending_updates: bool,

    // ---- dispatch ----
    /// Bot client used for registration and handed to the handler.
    bot: Bot,
    /// Handler invoked for each received update.
    handler: UpdateHandler,
}
impl WebhookServer {
    /// Creates a webhook server for `bot` that dispatches every update to `handler`.
    ///
    /// Defaults: port `8080`, path `/webhook`, no secret token, no `allowed_updates`
    /// filter, no `max_connections` override, and pending updates are kept.
    pub fn new(bot: Bot, handler: UpdateHandler) -> Self {
        Self {
            bot,
            handler,
            port: 8080,
            path: "/webhook".to_string(),
            secret_token: None,
            allowed_updates: vec![],
            max_connections: None,
            drop_pending_updates: false,
        }
    }

    /// Sets the local TCP port to listen on (bound on all interfaces).
    #[must_use]
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the URL path updates are received on.
    ///
    /// A leading `/` is added when missing, so `"hook"` and `"/hook"` are equivalent.
    /// Without this, the registered URL would be glued onto the host
    /// (`https://example.comhook`), and axum panics on routes that do not start with `/`.
    #[must_use]
    pub fn path(mut self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.path = if path.starts_with('/') {
            path
        } else {
            format!("/{}", path)
        };
        self
    }

    /// Sets a secret that Telegram will send in the `X-Telegram-Bot-API-Secret-Token`
    /// header; requests without the matching value are rejected with `403`.
    #[must_use]
    pub fn secret_token(mut self, token: impl Into<String>) -> Self {
        self.secret_token = Some(token.into());
        self
    }

    /// Restricts which update types Telegram delivers. An empty list leaves the
    /// parameter unset in the `setWebhook` call.
    #[must_use]
    pub fn allowed_updates(mut self, updates: Vec<String>) -> Self {
        self.allowed_updates = updates;
        self
    }

    /// Sets the `max_connections` value forwarded to `setWebhook`.
    #[must_use]
    pub fn max_connections(mut self, n: i64) -> Self {
        self.max_connections = Some(n);
        self
    }

    /// Asks Telegram to discard updates queued before the webhook is registered.
    #[must_use]
    pub fn drop_pending_updates(mut self) -> Self {
        self.drop_pending_updates = true;
        self
    }

    /// Binds the listener, registers the webhook with Telegram, then serves updates
    /// until the server stops.
    ///
    /// `webhook_url` is the public base URL (for example `https://example.com`). Trailing
    /// slashes are trimmed and the configured path is appended.
    ///
    /// The port is bound *before* `setWebhook` is called. A bind failure (port in use,
    /// permission denied) is then reported without first pointing Telegram at an endpoint
    /// nobody serves. Connections that arrive between registration and `serve` wait in the
    /// listen backlog instead of being refused.
    ///
    /// # Errors
    ///
    /// Returns an error if the port cannot be bound, if the `setWebhook` request fails,
    /// or if the HTTP server terminates with an error.
    pub async fn start(self, webhook_url: &str) -> Result<(), BotError> {
        let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(|e| BotError::Other(format!("Failed to bind port {}: {}", self.port, e)))?;

        let full_url = format!("{}{}", webhook_url.trim_end_matches('/'), self.path);
        let mut req = self.bot.set_webhook(full_url.clone());
        if let Some(ref token) = self.secret_token {
            req = req.secret_token(token.clone());
        }
        if let Some(n) = self.max_connections {
            req = req.max_connections(n);
        }
        if !self.allowed_updates.is_empty() {
            req = req.allowed_updates(self.allowed_updates.clone());
        }
        if self.drop_pending_updates {
            req = req.drop_pending_updates(true);
        }
        req.await?;
        info!(url = %full_url, "webhook registered");

        let state = Arc::new(AppState {
            bot: self.bot,
            handler: Arc::new(self.handler),
            secret_token: self.secret_token,
        });
        let app = Router::new()
            .route(&self.path, post(handle_update))
            .with_state(state);

        // Logged only once the bind above has actually succeeded.
        info!(addr = %addr, "webhook server listening");
        axum::serve(listener, app)
            .await
            .map_err(|e| BotError::Other(format!("Webhook server error: {}", e)))?;
        Ok(())
    }
}
async fn handle_update(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(update): Json<Update>,
) -> StatusCode {
if let Some(ref expected) = state.secret_token {
let provided = headers
.get("x-telegram-bot-api-secret-token")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if provided != expected {
warn!("invalid secret token - webhook request rejected");
return StatusCode::FORBIDDEN;
}
}
let bot = state.bot.clone();
let handler = Arc::clone(&state.handler);
tokio::spawn(async move {
if let Err(join_err) = tokio::spawn(async move { (handler)(bot, update).await }).await {
if join_err.is_panic() {
error!("handler panicked on webhook update - continuing");
}
}
});
StatusCode::OK
}