use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot};
use tracing;
use crate::bot_api::BotApi;
use crate::error::{ApiError, HandlerError, HandlerResult};
use crate::screen::Screen;
use crate::types::*;
/// Identifies the already-sent message that a progressive session keeps editing.
#[derive(Clone, Debug)]
pub enum EditTarget {
    /// A message in a regular chat, addressed by its chat id and message id.
    Chat {
        chat_id: ChatId,
        message_id: MessageId,
    },
    /// A message addressed only by an inline message id (presumably one sent via inline mode).
    ///
    /// Note: `make_bot_editor` does not handle this variant yet — it logs a warning and skips the
    /// edit — so inline sessions need a custom `EditorFn`.
    Inline {
        inline_message_id: String,
    },
}
/// Shared callback that applies one [`Screen`] to the target message.
///
/// The background task may call it many times; each call returns a boxed `Send` future that
/// resolves to the API result. The task treats `ApiError::MessageNotModified` as success and
/// backs off on `ApiError::TooManyRequests`.
pub type EditorFn =
    Arc<dyn Fn(Screen) -> Pin<Box<dyn Future<Output = Result<(), ApiError>> + Send>> + Send + Sync>;
/// Minimum spacing between intermediate edits for sessions started by [`start_progressive`].
const DEFAULT_CHAT_INTERVAL: Duration = Duration::from_millis(1500);
/// Minimum spacing between intermediate edits for sessions started by
/// [`start_progressive_inline`].
// NOTE(review): presumably longer than the chat interval because inline edits are rate-limited
// more strictly — confirm against the bot API limits.
const DEFAULT_INLINE_INTERVAL: Duration = Duration::from_millis(3000);
/// Command sent from a [`ProgressiveHandle`] to its background task.
enum ProgressiveCmd {
    /// Newer intermediate screen; unsent updates are coalesced, so only the newest is kept.
    Update(Screen),
    /// Final screen. The task sends it, reports the outcome on the oneshot sender, then exits.
    Finalize(Screen, oneshot::Sender<Result<(), HandlerError>>),
}
/// Handle to a running progressive-edit session.
///
/// Dropping the handle without calling `finalize` closes the command channel. The background
/// task then sends any still-unsent update and exits.
pub struct ProgressiveHandle {
    // Command channel into the background task (the only sender).
    tx: mpsc::UnboundedSender<ProgressiveCmd>,
    // Join handle of the spawned task; kept but never awaited, so the task runs detached.
    _task: tokio::task::JoinHandle<()>,
    // Abort handle for the same task, exposed through `abort_handle()`.
    abort_handle: tokio::task::AbortHandle,
}
impl ProgressiveHandle {
    /// Returns a clone of the background task's abort handle.
    ///
    /// Once the task is aborted no further edits are sent, and a later
    /// [`finalize`](Self::finalize) resolves to `Ok(())` without editing anything.
    pub fn abort_handle(&self) -> tokio::task::AbortHandle {
        tokio::task::AbortHandle::clone(&self.abort_handle)
    }

    /// Queues `screen` as the newest intermediate state.
    ///
    /// The channel is unbounded, so this never waits. The task may coalesce this screen with
    /// other unsent updates, so not every screen is necessarily shown. If the task has already
    /// stopped, the screen is silently discarded.
    pub async fn update(&self, screen: Screen) {
        // A send error only means the task is gone; there is nobody to report it to.
        self.tx.send(ProgressiveCmd::Update(screen)).ok();
    }

    /// Sends `screen` as the final state and waits for the task to apply it.
    ///
    /// Any unsent intermediate update is discarded in favour of `screen`.
    ///
    /// # Errors
    ///
    /// Returns the error from the final edit. If the background task is no longer running
    /// (aborted, panicked, or already finished), this returns `Ok(())` without editing.
    pub async fn finalize(self, screen: Screen) -> HandlerResult {
        let (reply_tx, reply_rx) = oneshot::channel();
        let queued = self
            .tx
            .send(ProgressiveCmd::Finalize(screen, reply_tx))
            .is_ok();
        if !queued {
            return Ok(());
        }
        // The reply sender is only dropped unanswered if the task dies before finishing.
        reply_rx.await.unwrap_or(Ok(()))
    }
}
/// How long to wait before the next edit may be sent.
///
/// Returns `Duration::ZERO` when nothing has been sent yet, or when `min_interval` has already
/// passed since the last attempt.
fn remaining_wait(last_edit: Option<Instant>, min_interval: Duration) -> Duration {
    last_edit.map_or(Duration::ZERO, |at| min_interval.saturating_sub(at.elapsed()))
}

/// Runs [`do_edit`] and returns the instant it recorded for that attempt.
async fn edit_and_stamp(editor: &EditorFn, screen: Screen) -> Instant {
    // `do_edit` overwrites the stamp on every outcome; this initial value is only a placeholder.
    let mut stamp = Instant::now();
    do_edit(editor, screen, &mut stamp).await;
    stamp
}

/// Background loop that applies the screens sent through a [`ProgressiveHandle`].
///
/// Intermediate updates are throttled so that each edit starts at least `min_interval` after the
/// previous attempt finished. They are also coalesced: only the newest unsent screen is kept.
///
/// A `Finalize` command discards any unsent update, waits out the remaining interval, sends the
/// final screen via `do_edit_result`, reports the outcome and ends the task. If the command
/// channel closes (the handle was dropped), the newest unsent screen is flushed before exiting.
async fn progressive_task(
    mut rx: mpsc::UnboundedReceiver<ProgressiveCmd>,
    editor: EditorFn,
    min_interval: Duration,
) {
    // `None` until the first edit, so the first update goes out immediately. This avoids
    // back-dating with `Instant::now() - min_interval`, which panics when the subtraction
    // underflows (e.g. a very large `min_interval`).
    let mut last_edit: Option<Instant> = None;
    // Newest screen not yet sent; a newer update simply replaces it.
    let mut pending: Option<Screen> = None;
    loop {
        let cmd = if pending.is_some() {
            let wait = remaining_wait(last_edit, min_interval);
            if wait.is_zero() {
                if let Some(screen) = pending.take() {
                    last_edit = Some(edit_and_stamp(&editor, screen).await);
                }
                rx.recv().await
            } else {
                tokio::select! {
                    cmd = rx.recv() => cmd,
                    _ = tokio::time::sleep(wait) => {
                        if let Some(screen) = pending.take() {
                            last_edit = Some(edit_and_stamp(&editor, screen).await);
                        }
                        continue;
                    }
                }
            }
        } else {
            rx.recv().await
        };
        match cmd {
            None => {
                // All senders are gone: flush the last unsent screen, then stop.
                if let Some(screen) = pending.take() {
                    edit_and_stamp(&editor, screen).await;
                }
                return;
            }
            Some(ProgressiveCmd::Update(screen)) => {
                // Always park the newest screen and let the loop head decide when to send it
                // (immediately if the interval has already passed). Editing here directly could
                // leave an older screen in `pending` that would later overwrite this newer one.
                pending = Some(screen);
            }
            Some(ProgressiveCmd::Finalize(screen, done_tx)) => {
                // The final screen supersedes any unsent intermediate update.
                pending = None;
                let wait = remaining_wait(last_edit, min_interval);
                if !wait.is_zero() {
                    tokio::time::sleep(wait).await;
                }
                let result = do_edit_result(&editor, screen).await;
                // The caller may have stopped waiting; nothing to do in that case.
                let _ = done_tx.send(result);
                return;
            }
        }
    }
}
async fn do_edit(editor: &EditorFn, screen: Screen, last_edit: &mut Instant) {
match editor(screen).await {
Ok(()) => {
*last_edit = Instant::now();
}
Err(ApiError::MessageNotModified) => {
*last_edit = Instant::now();
}
Err(ApiError::TooManyRequests { retry_after }) => {
tracing::warn!("progressive edit rate-limited, waiting {}s", retry_after);
tokio::time::sleep(Duration::from_secs((retry_after as u64).min(30))).await;
*last_edit = Instant::now();
}
Err(e) => {
tracing::error!("progressive edit failed: {}", e);
*last_edit = Instant::now();
}
}
}
/// Sends the final edit and reports whether it was delivered.
///
/// `ApiError::MessageNotModified` counts as success. On `ApiError::TooManyRequests` the call
/// waits `retry_after` seconds (capped at 30) and retries exactly once.
///
/// # Errors
///
/// Returns `HandlerError::Api` with the error from the retry, or with any other error from the
/// first attempt.
async fn do_edit_result(editor: &EditorFn, screen: Screen) -> HandlerResult {
    // Upper bound on how long a rate-limit response may stall finalization.
    const MAX_BACKOFF_SECS: u64 = 30;

    // Clone up front: the screen is needed again if the first attempt is rate-limited.
    let first_attempt = editor(screen.clone()).await;
    let outcome = match first_attempt {
        Err(ApiError::TooManyRequests { retry_after }) => {
            tracing::warn!(
                "progressive finalize rate-limited, waiting {}s then retrying",
                retry_after
            );
            // Convert before capping, the same way `do_edit` does. Capping in the source type
            // and casting afterwards would turn a negative signed value into an effectively
            // unbounded sleep.
            let backoff = Duration::from_secs((retry_after as u64).min(MAX_BACKOFF_SECS));
            tokio::time::sleep(backoff).await;
            editor(screen).await
        }
        other => other,
    };
    match outcome {
        Ok(()) | Err(ApiError::MessageNotModified) => Ok(()),
        Err(e) => Err(HandlerError::Api(e)),
    }
}
/// Sends the first message of `initial` to `chat_id`, then starts a progressive session that
/// keeps editing that message.
///
/// Only the first message of `initial` is sent. If `initial` has no messages, a `"…"` HTML text
/// placeholder is sent instead. Later edits are spaced at least [`DEFAULT_CHAT_INTERVAL`] apart.
///
/// # Errors
///
/// Returns the [`ApiError`] from sending the initial message; no task is spawned in that case.
pub async fn start_progressive(
    bot: Arc<dyn BotApi>,
    chat_id: ChatId,
    initial: Screen,
) -> Result<ProgressiveHandle, ApiError> {
    let opening = match initial.messages.into_iter().next() {
        Some(first) => first.content,
        None => MessageContent::Text {
            text: "…".to_string(),
            parse_mode: ParseMode::Html,
            keyboard: None,
            link_preview: LinkPreview::Disabled,
        },
    };
    let message_id = bot
        .send_message(chat_id, opening, crate::bot_api::SendOptions::default())
        .await?
        .message_id;
    let editor = make_bot_editor(bot, EditTarget::Chat { chat_id, message_id });
    Ok(spawn_progressive(editor, DEFAULT_CHAT_INTERVAL))
}
/// Starts a progressive session driven by a caller-supplied `editor`, throttled to
/// [`DEFAULT_INLINE_INTERVAL`] between intermediate edits.
///
/// No initial message is sent. `editor` must already know which message it edits
/// (presumably an inline-mode message).
pub fn start_progressive_inline(editor: EditorFn) -> ProgressiveHandle {
    let interval = DEFAULT_INLINE_INTERVAL;
    spawn_progressive(editor, interval)
}
/// Starts a progressive session with a caller-supplied `editor` and an explicit throttle.
///
/// Intermediate edits are spaced at least `min_interval` apart; `Duration::ZERO` disables
/// throttling. No message is sent up front, because `editor` decides what is edited.
pub fn start_progressive_with_editor(editor: EditorFn, min_interval: Duration) -> ProgressiveHandle {
    spawn_progressive(editor, min_interval)
}
/// Builds an [`EditorFn`] that edits `target` through `bot`.
///
/// For a chat target, only the first message of each screen is applied, using
/// `edit_chat_message`. If the screen is empty, a `"…"` HTML text placeholder is used instead.
/// `EditTarget::Inline` is not supported yet: the editor logs a warning and returns `Ok(())`
/// without editing anything.
fn make_bot_editor(bot: Arc<dyn BotApi>, target: EditTarget) -> EditorFn {
    Arc::new(move |screen: Screen| {
        // Each call gets its own owned handles so the boxed future can outlive this closure call.
        let api = Arc::clone(&bot);
        let destination = target.clone();
        Box::pin(async move {
            match destination {
                EditTarget::Chat {
                    chat_id,
                    message_id,
                } => {
                    let content = match screen.messages.into_iter().next() {
                        Some(first) => first.content,
                        None => MessageContent::Text {
                            text: "…".to_string(),
                            parse_mode: ParseMode::Html,
                            keyboard: None,
                            link_preview: LinkPreview::Disabled,
                        },
                    };
                    edit_chat_message(&*api, chat_id, message_id, content).await
                }
                EditTarget::Inline { .. } => {
                    tracing::warn!(
                        "EditTarget::Inline not yet supported via make_bot_editor; \
                         use start_progressive_with_editor instead"
                    );
                    Ok(())
                }
            }
        })
    })
}
/// Edits an existing chat message in place so that it shows `content`.
///
/// - Text content goes through `edit_message_text`.
/// - Photo, video, animation and document content goes through `edit_message_media`, with the
///   keyboard passed separately from the media.
/// - Any other content type is logged and skipped, returning `Ok(())` without calling the API.
async fn edit_chat_message(
    bot: &dyn BotApi,
    chat_id: ChatId,
    message_id: MessageId,
    content: MessageContent,
) -> Result<(), ApiError> {
    // Media variants carry their keyboard inside the content. Take a copy before the content is
    // moved into the media edit.
    let media_keyboard = match &content {
        MessageContent::Photo { keyboard, .. }
        | MessageContent::Video { keyboard, .. }
        | MessageContent::Animation { keyboard, .. }
        | MessageContent::Document { keyboard, .. } => Some(keyboard.clone()),
        _ => None,
    };
    if let Some(markup) = media_keyboard {
        let media = content_with_no_keyboard(content);
        return bot
            .edit_message_media(chat_id, message_id, media, markup)
            .await;
    }
    match content {
        MessageContent::Text {
            text,
            parse_mode,
            keyboard,
            link_preview,
        } => {
            let show_preview = link_preview == LinkPreview::Enabled;
            bot.edit_message_text(chat_id, message_id, text, parse_mode, keyboard, show_preview)
                .await
        }
        other => {
            tracing::warn!(
                "progressive edit: unsupported content type {:?}, skipping",
                other.content_type()
            );
            Ok(())
        }
    }
}
/// Returns `content` with the inline keyboard stripped from media variants.
///
/// Photo, video, animation and document content get `keyboard` set to `None`, with every other
/// field kept. Any other variant is returned unchanged.
fn content_with_no_keyboard(mut content: MessageContent) -> MessageContent {
    match &mut content {
        MessageContent::Photo { keyboard, .. }
        | MessageContent::Video { keyboard, .. }
        | MessageContent::Animation { keyboard, .. }
        | MessageContent::Document { keyboard, .. } => *keyboard = None,
        _ => {}
    }
    content
}
fn spawn_progressive(editor: EditorFn, min_interval: Duration) -> ProgressiveHandle {
let (tx, rx) = mpsc::unbounded_channel();
let task = tokio::spawn(progressive_task(rx, editor, min_interval));
let abort_handle = task.abort_handle();
ProgressiveHandle {
tx,
_task: task,
abort_handle,
}
}
// Tests for the throttling/coalescing task. They drive it through custom `EditorFn`s that record
// calls instead of contacting a real bot API. They rely on real-time sleeps, so the timing
// values below are deliberately generous.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::screen::Screen;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Builds a screen with the `Screen::text` builder. `finalize_always_delivers` relies on its
    // first message being `MessageContent::Text` carrying `text`. The meaning of the `"test"`
    // argument is defined by `Screen::text` (not visible here).
    fn dummy_screen(text: &str) -> Screen {
        Screen::text("test", text).build()
    }

    // After a burst of 10 updates, `finalize` must leave the final screen applied. The editor
    // must be called at least once and never more than once per command (10 updates + 1 finalize).
    #[tokio::test]
    async fn finalize_always_delivers() {
        let call_count = Arc::new(AtomicUsize::new(0));
        // Text of the most recent edit the fake editor received.
        let last_text = Arc::new(tokio::sync::Mutex::new(String::new()));
        let cc = call_count.clone();
        let lt = last_text.clone();
        let editor: EditorFn = Arc::new(move |screen: Screen| {
            let cc = cc.clone();
            let lt = lt.clone();
            Box::pin(async move {
                cc.fetch_add(1, Ordering::SeqCst);
                if let Some(msg) = screen.messages.first() {
                    if let MessageContent::Text { text, .. } = &msg.content {
                        *lt.lock().await = text.clone();
                    }
                }
                Ok(())
            })
        });
        let handle = start_progressive_with_editor(editor, Duration::from_millis(50));
        for i in 0..10 {
            handle.update(dummy_screen(&format!("update {}", i))).await;
        }
        let result = handle.finalize(dummy_screen("final")).await;
        assert!(result.is_ok());
        let text = last_text.lock().await.clone();
        assert_eq!(text, "final");
        let count = call_count.load(Ordering::SeqCst);
        assert!(count >= 1, "at least the finalize should be delivered");
        assert!(count <= 11, "should not exceed total updates + finalize");
    }

    // 50 updates sent back-to-back with a 100 ms throttle must be coalesced into fewer edits.
    // The 500 ms pause lets any pending update flush before finalizing.
    #[tokio::test]
    async fn coalesces_rapid_updates() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = call_count.clone();
        let editor: EditorFn = Arc::new(move |_screen: Screen| {
            let cc = cc.clone();
            Box::pin(async move {
                cc.fetch_add(1, Ordering::SeqCst);
                // Simulates API latency so that edits take non-zero time.
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(())
            })
        });
        let handle = start_progressive_with_editor(editor, Duration::from_millis(100));
        for i in 0..50 {
            handle.update(dummy_screen(&format!("u{}", i))).await;
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
        handle.finalize(dummy_screen("done")).await.unwrap();
        let count = call_count.load(Ordering::SeqCst);
        assert!(
            count < 50,
            "expected coalescing to reduce edits, got {}",
            count
        );
    }

    // An editor that always reports `MessageNotModified` must not make `finalize` fail.
    #[tokio::test]
    async fn handles_message_not_modified() {
        let editor: EditorFn =
            Arc::new(|_screen: Screen| Box::pin(async move { Err(ApiError::MessageNotModified) }));
        let handle = start_progressive_with_editor(editor, Duration::from_millis(10));
        handle.update(dummy_screen("same")).await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let result = handle.finalize(dummy_screen("same")).await;
        assert!(result.is_ok());
    }
}