use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch, Mutex};
use tracing::{debug, error, warn};
use rust_tg_bot_raw::error::TelegramError;
use crate::utils::network_loop::{network_retry_loop, NetworkLoopConfig};
#[cfg(feature = "webhooks")]
use tokio::sync::Notify;
#[cfg(feature = "webhooks")]
use crate::utils::webhook_handler::WebhookServer;
#[cfg(feature = "webhooks")]
use rust_tg_bot_raw::types::update::Update;
/// Async callback that fetches a batch of raw updates.
///
/// The polling loop in `Updater::start_polling` calls it with three arguments:
/// - the offset (the stored last update id, i.e. the last seen `update_id` + 1, or 0);
/// - the long-poll timeout;
/// - the optional list of allowed update types.
///
/// Each returned JSON value is expected to carry an integer `"update_id"` field.
pub type GetUpdatesFn = Arc<
    dyn Fn(i64, Duration, Option<Vec<String>>) -> std::pin::Pin<
            Box<
                dyn std::future::Future<
                        Output = Result<Vec<serde_json::Value>, TelegramError>,
                    > + Send,
            >,
        > + Send
        + Sync,
>;
/// Async callback that removes any configured webhook before polling starts.
///
/// The single `bool` argument is the `drop_pending_updates` flag from
/// [`PollingConfig`].
pub type DeleteWebhookFn = Arc<
    dyn Fn(bool) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<(), TelegramError>> + Send>,
        > + Send
        + Sync,
>;
/// Settings for `Updater::start_polling`.
#[derive(Clone)]
pub struct PollingConfig {
    /// Delay between polling iterations (handed to the retry loop in seconds).
    pub poll_interval: Duration,
    /// Long-poll timeout forwarded to every `get_updates` call.
    pub timeout: Duration,
    /// Retry budget for the bootstrap `delete_webhook` call.
    pub bootstrap_retries: i32,
    /// Optional update-type filter forwarded to every `get_updates` call.
    pub allowed_updates: Option<Vec<String>>,
    /// Flag passed to `delete_webhook` during bootstrap.
    pub drop_pending_updates: bool,
    /// Fetches update batches; see [`GetUpdatesFn`].
    pub get_updates: GetUpdatesFn,
    /// Removes a configured webhook before polling; see [`DeleteWebhookFn`].
    pub delete_webhook: DeleteWebhookFn,
}
/// Settings for `Updater::start_webhook`.
///
/// NOTE(review): in this file `start_webhook` reads only `listen`, `port`,
/// `url_path`, `secret_token`, `cert_path` and `key_path`. The remaining fields
/// are not consumed here. Confirm whether a caller elsewhere uses them.
#[cfg(feature = "webhooks")]
#[derive(Clone)]
pub struct WebhookConfig {
    /// Address the local HTTP server binds to.
    pub listen: String,
    /// Port the local HTTP server binds to.
    pub port: u16,
    /// Path the local server accepts webhook requests on.
    pub url_path: String,
    /// Webhook URL set by `WebhookConfig::new`.
    pub webhook_url: Option<String>,
    /// Secret token handed to the webhook server.
    pub secret_token: Option<String>,
    /// Retry budget for bootstrap calls.
    pub bootstrap_retries: i32,
    /// Whether pending updates should be dropped on start.
    pub drop_pending_updates: bool,
    /// Optional update-type filter.
    pub allowed_updates: Option<Vec<String>>,
    /// Connection limit; defaults to 40.
    pub max_connections: u32,
    /// PEM certificate path. TLS is used only when both paths are set and the
    /// `webhooks-tls` feature is enabled.
    pub cert_path: Option<String>,
    /// PEM private-key path; see `cert_path`.
    pub key_path: Option<String>,
}
#[cfg(feature = "webhooks")]
impl Default for WebhookConfig {
    /// The default configuration:
    /// - listens on `127.0.0.1:80` with an empty `url_path`;
    /// - no webhook URL, secret token or TLS paths;
    /// - no bootstrap retries and pending updates kept;
    /// - no update filter and `max_connections` of 40.
    fn default() -> Self {
        let listen = String::from("127.0.0.1");
        let url_path = String::default();
        Self {
            listen,
            port: 80,
            url_path,
            webhook_url: None,
            secret_token: None,
            bootstrap_retries: 0,
            drop_pending_updates: false,
            allowed_updates: None,
            max_connections: 40,
            cert_path: None,
            key_path: None,
        }
    }
}
#[cfg(feature = "webhooks")]
impl WebhookConfig {
    /// Creates a configuration whose `webhook_url` is `url`.
    ///
    /// Every other field comes from [`WebhookConfig::default`].
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            webhook_url: Some(url.into()),
            ..Default::default()
        }
    }

    /// Sets the address the local server binds to (default `127.0.0.1`).
    #[must_use = "builder methods return the updated config"]
    pub fn listen(mut self, addr: impl Into<String>) -> Self {
        self.listen = addr.into();
        self
    }

    /// Sets the port the local server binds to (default 80).
    #[must_use = "builder methods return the updated config"]
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the path the local server accepts webhook requests on.
    #[must_use = "builder methods return the updated config"]
    pub fn url_path(mut self, path: impl Into<String>) -> Self {
        self.url_path = path.into();
        self
    }

    /// Sets the secret token handed to the webhook server.
    #[must_use = "builder methods return the updated config"]
    pub fn secret_token(mut self, token: impl Into<String>) -> Self {
        self.secret_token = Some(token.into());
        self
    }

    /// Sets the bootstrap retry budget.
    #[must_use = "builder methods return the updated config"]
    pub fn bootstrap_retries(mut self, n: i32) -> Self {
        self.bootstrap_retries = n;
        self
    }

    /// Sets whether pending updates should be dropped on start.
    #[must_use = "builder methods return the updated config"]
    pub fn drop_pending_updates(mut self, drop: bool) -> Self {
        self.drop_pending_updates = drop;
        self
    }

    /// Restricts delivery to the given update types.
    #[must_use = "builder methods return the updated config"]
    pub fn allowed_updates(mut self, types: Vec<String>) -> Self {
        self.allowed_updates = Some(types);
        self
    }

    /// Sets the connection limit (default 40).
    #[must_use = "builder methods return the updated config"]
    pub fn max_connections(mut self, n: u32) -> Self {
        self.max_connections = n;
        self
    }

    /// Sets both the PEM certificate and private-key paths.
    ///
    /// They take effect only when the `webhooks-tls` feature is enabled.
    /// Otherwise `Updater::start_webhook` logs a warning and serves plain HTTP.
    #[must_use = "builder methods return the updated config"]
    pub fn tls(mut self, cert: impl Into<String>, key: impl Into<String>) -> Self {
        self.cert_path = Some(cert.into());
        self.key_path = Some(key.into());
        self
    }

    /// Returns `true` when both a certificate and a key path are configured.
    #[must_use]
    pub fn has_tls(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }
}
/// Fetches Telegram updates (by polling or webhook) and queues them as raw JSON.
pub struct Updater {
    /// Producer side of the update queue.
    update_tx: mpsc::Sender<serde_json::Value>,
    /// Consumer side of the update queue, handed out once by `take_update_rx`.
    update_rx: Mutex<Option<mpsc::Receiver<serde_json::Value>>>,
    /// Set while polling or the webhook server is active.
    running: std::sync::atomic::AtomicBool,
    /// Set by `initialize`, cleared by `shutdown`.
    initialized: std::sync::atomic::AtomicBool,
    /// Offset for the next `get_updates` call (last seen `update_id` + 1).
    last_update_id: Mutex<i64>,
    /// Signals the polling loop to stop (`true`) or run (`false`).
    stop_tx: watch::Sender<bool>,
    /// Currently started webhook server, if any.
    #[cfg(feature = "webhooks")]
    httpd: Mutex<Option<Arc<WebhookServer>>>,
}
impl std::fmt::Debug for Updater {
    /// Shows only the lifecycle flags; channels and callbacks are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let running = self.is_running();
        let initialized = self
            .initialized
            .load(std::sync::atomic::Ordering::Relaxed);
        let mut out = f.debug_struct("Updater");
        out.field("running", &running);
        out.field("initialized", &initialized);
        out.finish()
    }
}
impl Updater {
pub fn new(channel_size: usize) -> Self {
let (update_tx, update_rx) = mpsc::channel(channel_size);
let (stop_tx, _stop_rx) = watch::channel(false);
Self {
update_tx,
update_rx: Mutex::new(Some(update_rx)),
running: false.into(),
initialized: false.into(),
last_update_id: Mutex::new(0),
stop_tx,
#[cfg(feature = "webhooks")]
httpd: Mutex::new(None),
}
}
pub async fn take_update_rx(&self) -> Option<mpsc::Receiver<serde_json::Value>> {
self.update_rx.lock().await.take()
}
pub fn is_running(&self) -> bool {
self.running.load(std::sync::atomic::Ordering::Relaxed)
}
pub async fn initialize(&self) {
if self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
debug!("Updater already initialized");
return;
}
self.initialized
.store(true, std::sync::atomic::Ordering::Relaxed);
debug!("Updater initialized");
}
pub async fn shutdown(&self) -> Result<(), UpdaterError> {
if self.is_running() {
return Err(UpdaterError::StillRunning);
}
if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
debug!("Updater already shut down");
return Ok(());
}
self.initialized
.store(false, std::sync::atomic::Ordering::Relaxed);
debug!("Updater shut down");
Ok(())
}
pub async fn start_polling(
self: &Arc<Self>,
config: PollingConfig,
) -> Result<(), UpdaterError> {
if self.is_running() {
return Err(UpdaterError::AlreadyRunning);
}
if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
return Err(UpdaterError::NotInitialized);
}
self.running
.store(true, std::sync::atomic::Ordering::Relaxed);
let _ = self.stop_tx.send(false);
let delete_fn = config.delete_webhook.clone();
let drop_pending = config.drop_pending_updates;
let bootstrap_retries = config.bootstrap_retries;
if let Err(e) = self
.bootstrap_delete_webhook(delete_fn, drop_pending, bootstrap_retries)
.await
{
self.running
.store(false, std::sync::atomic::Ordering::Relaxed);
return Err(UpdaterError::Bootstrap(e.to_string()));
}
debug!("Bootstrap complete, starting polling loop");
let updater = Arc::clone(self);
let stop_rx = self.stop_tx.subscribe();
tokio::spawn(async move {
let tx = updater.update_tx.clone();
let timeout = config.timeout;
let poll_interval = config.poll_interval;
let allowed = config.allowed_updates.clone();
let get_updates_fn = config.get_updates.clone();
let result = network_retry_loop(NetworkLoopConfig {
action_cb: || {
let tx = tx.clone();
let updater_inner = updater.clone();
let allowed_inner = allowed.clone();
let get_fn = get_updates_fn.clone();
async move {
let last_id = { *updater_inner.last_update_id.lock().await };
let updates: Vec<serde_json::Value> =
get_fn(last_id, timeout, allowed_inner).await?;
if !updates.is_empty() {
if !updater_inner.is_running() {
warn!(
"Updater stopped unexpectedly. Pulled updates will be \
ignored and pulled again on restart."
);
return Ok(());
}
for update in &updates {
if let Err(e) = tx.send(update.clone()).await {
error!("Failed to enqueue update: {e}");
}
}
if let Some(last) = updates.last() {
if let Some(uid) = last.get("update_id").and_then(|v| v.as_i64()) {
*updater_inner.last_update_id.lock().await = uid + 1;
}
}
}
Ok(())
}
},
on_err_cb: Some(|e: &TelegramError| {
error!("Error while polling for updates: {e}");
}),
description: "Polling Updates",
interval: poll_interval.as_secs_f64(),
stop_rx: Some(stop_rx),
is_running: Some(Box::new({
let u = updater.clone();
move || u.is_running()
})),
max_retries: -1,
repeat_on_success: true,
})
.await;
if let Err(e) = result {
error!("Polling loop exited with error: {e}");
}
});
Ok(())
}
#[cfg(feature = "webhooks")]
pub async fn start_webhook(
self: &Arc<Self>,
config: WebhookConfig,
) -> Result<(), UpdaterError> {
if self.is_running() {
return Err(UpdaterError::AlreadyRunning);
}
if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
return Err(UpdaterError::NotInitialized);
}
self.running
.store(true, std::sync::atomic::Ordering::Relaxed);
let _ = self.stop_tx.send(false);
let (typed_tx, mut typed_rx) = mpsc::channel::<Update>(256);
let value_tx = self.update_tx.clone();
tokio::spawn(async move {
while let Some(update) = typed_rx.recv().await {
match serde_json::to_value(&update) {
Ok(v) => {
let _ = value_tx.send(v).await;
}
Err(e) => {
error!("Failed to serialize Update to Value: {e}");
}
}
}
});
#[cfg(feature = "webhooks-tls")]
let tls_config = if config.has_tls() {
let cert_path = config
.cert_path
.as_deref()
.expect("cert_path checked by has_tls");
let key_path = config
.key_path
.as_deref()
.expect("key_path checked by has_tls");
match crate::utils::webhook_handler::TlsConfig::from_pem_files(cert_path, key_path)
.await
{
Ok(tls) => Some(tls),
Err(e) => {
self.running
.store(false, std::sync::atomic::Ordering::Relaxed);
return Err(UpdaterError::Bootstrap(format!(
"TLS configuration failed: {e}"
)));
}
}
} else {
None
};
#[cfg(not(feature = "webhooks-tls"))]
if config.has_tls() {
warn!(
"TLS cert_path/key_path are set but the `webhooks-tls` feature is not enabled. \
The server will start without TLS. Enable the `webhooks-tls` feature to use HTTPS."
);
}
let server = Arc::new(WebhookServer::new(
&config.listen,
config.port,
&config.url_path,
typed_tx,
config.secret_token,
#[cfg(feature = "webhooks-tls")]
tls_config,
));
let ready = Arc::new(Notify::new());
let ready_clone = ready.clone();
let srv = server.clone();
tokio::spawn(async move {
if let Err(e) = srv.serve_forever(Some(ready_clone)).await {
error!("Webhook server error: {e}");
}
});
ready.notified().await;
debug!(
"Webhook server started on {}:{}",
config.listen, config.port
);
*self.httpd.lock().await = Some(server);
Ok(())
}
pub async fn stop(&self) -> Result<(), UpdaterError> {
if !self.is_running() {
return Err(UpdaterError::NotRunning);
}
debug!("Stopping updater");
self.running
.store(false, std::sync::atomic::Ordering::Relaxed);
let _ = self.stop_tx.send(true);
#[cfg(feature = "webhooks")]
{
let httpd = self.httpd.lock().await;
if let Some(ref server) = *httpd {
server.shutdown();
}
}
debug!("Updater stopped");
Ok(())
}
async fn bootstrap_delete_webhook(
&self,
delete_fn: DeleteWebhookFn,
drop_pending: bool,
max_retries: i32,
) -> Result<(), TelegramError> {
debug!("Deleting webhook (bootstrap)");
network_retry_loop(NetworkLoopConfig {
action_cb: || {
let f = delete_fn.clone();
async move { f(drop_pending).await }
},
on_err_cb: None::<fn(&TelegramError)>,
description: "Bootstrap delete webhook",
interval: 1.0,
stop_rx: None,
is_running: None,
max_retries,
repeat_on_success: false,
})
.await
}
}
/// Errors returned by the [`Updater`] lifecycle methods.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum UpdaterError {
    /// A `start_*` method was called while the updater was already started.
    #[error("this Updater is already running")]
    AlreadyRunning,
    /// `stop` was called while the updater was not started.
    #[error("this Updater is not running")]
    NotRunning,
    /// A `start_*` method was called before `initialize`.
    #[error("this Updater was not initialized")]
    NotInitialized,
    /// `shutdown` was called before `stop`.
    #[error("this Updater is still running")]
    StillRunning,
    /// Startup failed; the payload describes the underlying failure.
    #[error("bootstrap failed: {0}")]
    Bootstrap(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `get_updates` stub that always returns an empty batch.
    fn noop_get_updates() -> GetUpdatesFn {
        Arc::new(|_offset, _timeout, _allowed| Box::pin(async { Ok(Vec::new()) }))
    }

    /// `delete_webhook` stub that always succeeds.
    fn noop_delete_webhook() -> DeleteWebhookFn {
        Arc::new(|_drop_pending| Box::pin(async { Ok(()) }))
    }

    /// Polling config wired to the no-op stubs.
    fn default_config() -> PollingConfig {
        PollingConfig {
            poll_interval: Duration::ZERO,
            timeout: Duration::from_secs(1),
            bootstrap_retries: 0,
            allowed_updates: None,
            drop_pending_updates: false,
            get_updates: noop_get_updates(),
            delete_webhook: noop_delete_webhook(),
        }
    }

    /// Returns a fresh, initialized updater.
    async fn initialized_updater() -> Arc<Updater> {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;
        updater
    }

    #[tokio::test]
    async fn lifecycle() {
        let updater = Arc::new(Updater::new(16));
        assert!(!updater.is_running());
        updater.initialize().await;
        assert!(updater.stop().await.is_err());
        updater.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn start_polling_requires_init() {
        let updater = Arc::new(Updater::new(16));
        let result = updater.start_polling(default_config()).await;
        assert!(matches!(result, Err(UpdaterError::NotInitialized)));
    }

    #[tokio::test]
    async fn start_and_stop_polling() {
        let updater = initialized_updater().await;
        updater.start_polling(default_config()).await.unwrap();
        assert!(updater.is_running());
        let result = updater.start_polling(default_config()).await;
        assert!(matches!(result, Err(UpdaterError::AlreadyRunning)));
        updater.stop().await.unwrap();
        assert!(!updater.is_running());
    }

    #[tokio::test]
    async fn shutdown_refused_while_running() {
        let updater = initialized_updater().await;
        updater.start_polling(default_config()).await.unwrap();
        assert!(matches!(
            updater.shutdown().await,
            Err(UpdaterError::StillRunning)
        ));
        updater.stop().await.unwrap();
        updater.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn second_stop_reports_not_running() {
        let updater = initialized_updater().await;
        updater.start_polling(default_config()).await.unwrap();
        updater.stop().await.unwrap();
        assert!(matches!(updater.stop().await, Err(UpdaterError::NotRunning)));
    }

    #[tokio::test]
    async fn polling_can_restart_after_stop() {
        let updater = initialized_updater().await;
        updater.start_polling(default_config()).await.unwrap();
        updater.stop().await.unwrap();
        updater.start_polling(default_config()).await.unwrap();
        assert!(updater.is_running());
        updater.stop().await.unwrap();
        assert!(!updater.is_running());
    }

    #[tokio::test]
    async fn take_update_rx_once() {
        let updater = Arc::new(Updater::new(16));
        assert!(updater.take_update_rx().await.is_some());
        assert!(updater.take_update_rx().await.is_none());
    }

    #[tokio::test]
    async fn polling_delivers_updates() {
        let updater = initialized_updater().await;
        let mut rx = updater.take_update_rx().await.unwrap();
        let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let cc = call_count.clone();
        // First call yields a single update; every later call yields nothing.
        let get_fn: GetUpdatesFn = Arc::new(move |_offset, _timeout, _allowed| {
            let cc = cc.clone();
            Box::pin(async move {
                let n = cc.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n == 0 {
                    Ok(vec![serde_json::json!({"update_id": 100, "message": {}})])
                } else {
                    Ok(Vec::new())
                }
            })
        });
        let config = PollingConfig {
            poll_interval: Duration::from_millis(10),
            timeout: Duration::from_secs(1),
            bootstrap_retries: 0,
            allowed_updates: None,
            drop_pending_updates: false,
            get_updates: get_fn,
            delete_webhook: noop_delete_webhook(),
        };
        updater.start_polling(config).await.unwrap();
        let update = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("timeout waiting for update")
            .expect("channel closed");
        assert_eq!(update["update_id"], 100);
        updater.stop().await.unwrap();
    }
}