use crate::channel::{Channel, ChannelType, EditMessage, OutboundMessage};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::time::Instant;
/// Addressing information used to deliver a response on a channel.
///
/// Its fields are copied into the `OutboundMessage` / `EditMessage` values built by the
/// delivery strategies in this module.
#[derive(Debug, Clone)]
pub struct MessageRef {
/// Channel the response is delivered on.
pub channel_type: ChannelType,
/// Channel account the response is sent from (forwarded as `account_id`).
pub account_id: String,
/// Recipient of the response (forwarded as `recipient_id`).
pub recipient_id: String,
/// Fallback edit target: streaming delivery uses this id when `Channel::send` returns none.
/// NOTE(review): presumably the id of the message that triggered the response — confirm.
pub message_id: Option<String>,
/// Message id the outbound response replies to (forwarded as `OutboundMessage::reply_to`).
pub reply_to: Option<String>,
}
/// Policy that turns a response's intermediate and final text into channel messages.
///
/// The implementations in this module treat `text` as the full response so far. They replace
/// the message content with it rather than appending it as a delta.
#[async_trait]
pub trait DeliveryStrategy: Send + Sync {
/// Handles an intermediate snapshot of the response text.
async fn on_partial(&self, text: &str, msg_ref: &MessageRef) -> anyhow::Result<()>;
/// Handles the final response text.
async fn on_complete(&self, text: &str, msg_ref: &MessageRef) -> anyhow::Result<()>;
}
/// Mutable per-response state of a `StreamingDelivery`, guarded by its mutex.
struct StreamingState {
/// Id of the message being edited in place. It is set after the first send, from the
/// channel's returned id or else `MessageRef::message_id`; `None` if neither supplied one.
sent_message_id: Option<String>,
/// Time of the last send or edit attempt, used to throttle edits to `edit_interval`.
last_edit: Option<Instant>,
/// Most recent text passed to `on_partial`/`on_complete`. A failed edit's retry sends this
/// (when non-empty) so the retry does not push stale content.
latest_text: String,
}
/// Delivery strategy that posts a message on the first partial response.
///
/// It then edits that message in place as more text arrives, at most once per `edit_interval`.
pub struct StreamingDelivery {
channel: Arc<dyn Channel>,
/// Per-response state. This is a tokio mutex because `on_partial` holds the guard across
/// the initial `send().await`.
state: Mutex<StreamingState>,
/// Minimum spacing between in-place edits; also the back-off before a failed edit is retried.
edit_interval: std::time::Duration,
}
impl StreamingDelivery {
pub fn new(channel: Arc<dyn Channel>) -> Self {
Self {
channel,
state: Mutex::new(StreamingState {
sent_message_id: None,
last_edit: None,
latest_text: String::new(),
}),
edit_interval: std::time::Duration::from_secs(1),
}
}
async fn try_edit(
&self,
msg_ref: &MessageRef,
text: &str,
message_id: &str,
) -> anyhow::Result<()> {
let edit_msg = EditMessage {
channel_type: msg_ref.channel_type,
account_id: msg_ref.account_id.clone(),
message_id: message_id.to_string(),
recipient_id: msg_ref.recipient_id.clone(),
text: text.to_string(),
};
match self.channel.edit(edit_msg).await {
Ok(()) => Ok(()),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("message is not modified") {
tracing::debug!("edit skipped: message content unchanged");
return Ok(());
}
tracing::warn!(error = %e, "edit failed, retrying after rate-limit window");
tokio::time::sleep(self.edit_interval).await;
let state = self.state.lock().await;
let retry_text = if state.latest_text.is_empty() {
text.to_string()
} else {
state.latest_text.clone()
};
drop(state);
let retry_msg = EditMessage {
channel_type: msg_ref.channel_type,
account_id: msg_ref.account_id.clone(),
message_id: message_id.to_string(),
recipient_id: msg_ref.recipient_id.clone(),
text: retry_text,
};
self.channel.edit(retry_msg).await
}
}
}
}
#[async_trait]
impl DeliveryStrategy for StreamingDelivery {
    /// Records `text` as the latest snapshot and posts or edits the streamed message.
    ///
    /// The first call sends `text` as a partial message and remembers its id. Later calls
    /// edit that message in place, but at most once per `edit_interval`; calls that land
    /// inside the window only update the stored snapshot.
    ///
    /// If no id is available to edit (the channel returned none and `msg_ref.message_id` is
    /// `None`), later partials are skipped instead of posting a new message each time. The
    /// final text is then delivered by `on_complete`.
    ///
    /// # Errors
    /// Propagates errors from the initial `send` and from `try_edit`.
    async fn on_partial(&self, text: &str, msg_ref: &MessageRef) -> anyhow::Result<()> {
        let mut state = self.state.lock().await;
        state.latest_text = text.to_string();

        // Nothing posted yet: no id and no send timestamp. The guard is held across the
        // send so concurrent partials cannot post the placeholder twice.
        if state.sent_message_id.is_none() && state.last_edit.is_none() {
            let outbound = OutboundMessage {
                channel_type: msg_ref.channel_type,
                account_id: msg_ref.account_id.clone(),
                recipient_id: msg_ref.recipient_id.clone(),
                text: text.to_string(),
                reply_to: msg_ref.reply_to.clone(),
                is_partial: true,
            };
            let sent_id = self.channel.send(outbound).await?;
            state.sent_message_id = sent_id.or_else(|| msg_ref.message_id.clone());
            state.last_edit = Some(Instant::now());
            return Ok(());
        }

        // A placeholder was posted but no id is known, so there is nothing to edit.
        let message_id = match state.sent_message_id.clone() {
            Some(id) => id,
            None => return Ok(()),
        };

        if let Some(last) = state.last_edit {
            if last.elapsed() < self.edit_interval {
                return Ok(());
            }
        }
        state.last_edit = Some(Instant::now());
        drop(state);
        self.try_edit(msg_ref, text, &message_id).await
    }

    /// Delivers the final text.
    ///
    /// If a streamed message id is known, the message is edited to show the final text;
    /// otherwise the text is sent as a new, non-partial message.
    ///
    /// # Errors
    /// Propagates errors from `try_edit` or `send`.
    async fn on_complete(&self, text: &str, msg_ref: &MessageRef) -> anyhow::Result<()> {
        let existing_id = {
            let mut state = self.state.lock().await;
            state.latest_text = text.to_string();
            state.sent_message_id.clone()
        };

        match existing_id {
            Some(message_id) => self.try_edit(msg_ref, text, &message_id).await,
            None => {
                let outbound = OutboundMessage {
                    channel_type: msg_ref.channel_type,
                    account_id: msg_ref.account_id.clone(),
                    recipient_id: msg_ref.recipient_id.clone(),
                    text: text.to_string(),
                    reply_to: msg_ref.reply_to.clone(),
                    is_partial: false,
                };
                self.channel.send(outbound).await?;
                Ok(())
            }
        }
    }
}
/// Delivery strategy that ignores partial responses and sends one message when the
/// response is complete.
pub struct BatchDelivery {
channel: Arc<dyn Channel>,
}
impl BatchDelivery {
pub fn new(channel: Arc<dyn Channel>) -> Self {
Self { channel }
}
}
#[async_trait]
impl DeliveryStrategy for BatchDelivery {
    /// Intentionally a no-op: batch delivery only ever sends the final text.
    async fn on_partial(&self, _text: &str, _msg_ref: &MessageRef) -> anyhow::Result<()> {
        Ok(())
    }

    /// Sends `text` as a single non-partial message, replying to `msg_ref.reply_to` if set.
    ///
    /// # Errors
    /// Propagates any error returned by the channel's `send`.
    async fn on_complete(&self, text: &str, msg_ref: &MessageRef) -> anyhow::Result<()> {
        let message = OutboundMessage {
            channel_type: msg_ref.channel_type,
            account_id: msg_ref.account_id.clone(),
            recipient_id: msg_ref.recipient_id.clone(),
            reply_to: msg_ref.reply_to.clone(),
            text: text.to_owned(),
            is_partial: false,
        };
        // The id of the sent message is not needed for batch delivery.
        self.channel.send(message).await.map(|_sent_id| ())
    }
}
/// Chooses the delivery strategy for `channel`.
///
/// Streaming (edit-in-place) is used when the channel supports editing and `stream_mode`
/// is anything other than `Some("complete")`; otherwise a single batch message is sent.
pub fn select_strategy(
    channel: Arc<dyn Channel>,
    stream_mode: Option<&str>,
) -> Arc<dyn DeliveryStrategy> {
    let complete_only = matches!(stream_mode, Some("complete"));
    if !channel.supports_editing() || complete_only {
        tracing::debug!(
            channel = %channel.channel_type(),
            "using batch delivery (single message)"
        );
        return Arc::new(BatchDelivery::new(channel));
    }
    tracing::debug!(
        channel = %channel.channel_type(),
        "using streaming delivery (edit-in-place)"
    );
    Arc::new(StreamingDelivery::new(channel))
}
/// Splits `text` into chunks of at most `max_len` bytes, preferring to break at a newline,
/// then at a space.
///
/// Leading whitespace at the start of each continuation chunk is trimmed. A `max_len` of 0,
/// or a `text` that already fits, yields the whole text as a single chunk (so `""` yields
/// `[""]`).
///
/// Splits always fall on UTF-8 character boundaries, so multibyte text never causes a
/// panic. If a single character is wider than `max_len` bytes, that character is emitted on
/// its own, so the chunk exceeds `max_len`; this guarantees forward progress.
pub fn split_message(text: &str, max_len: usize) -> Vec<String> {
    if max_len == 0 || text.len() <= max_len {
        return vec![text.to_string()];
    }
    let mut chunks = Vec::new();
    let mut remaining = text;
    while !remaining.is_empty() {
        if remaining.len() <= max_len {
            chunks.push(remaining.to_string());
            break;
        }
        // Clamp the window to the last char boundary at or before `max_len`; slicing in
        // the middle of a multibyte character would panic.
        let mut limit = max_len;
        while !remaining.is_char_boundary(limit) {
            limit -= 1;
        }
        if limit == 0 {
            // The first character alone is wider than `max_len`: take it whole.
            limit = remaining
                .chars()
                .next()
                .map_or(remaining.len(), char::len_utf8);
        }
        let window = &remaining[..limit];
        // A break at index 0 would produce an empty chunk and no progress; fall back to a
        // hard cut in that case.
        let at = window
            .rfind('\n')
            .or_else(|| window.rfind(' '))
            .filter(|&i| i > 0)
            .unwrap_or(limit);
        chunks.push(remaining[..at].to_string());
        remaining = remaining[at..].trim_start();
    }
    chunks
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::{Channel, ChannelType, EditMessage, InboundMessage, OutboundMessage};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::mpsc;

    /// In-memory channel that counts sends and successful edits, and can be told to fail
    /// its first N edits with a "rate limited" error.
    struct MockChannel {
        editing_supported: bool,
        send_count: AtomicU32,
        edit_count: AtomicU32,
        fail_first_n_edits: AtomicU32,
    }

    impl MockChannel {
        fn new(editing_supported: bool) -> Self {
            Self {
                editing_supported,
                send_count: AtomicU32::new(0),
                edit_count: AtomicU32::new(0),
                fail_first_n_edits: AtomicU32::new(0),
            }
        }

        fn with_failing_edits(mut self, n: u32) -> Self {
            self.fail_first_n_edits = AtomicU32::new(n);
            self
        }

        fn sends(&self) -> u32 {
            self.send_count.load(Ordering::SeqCst)
        }

        fn edits(&self) -> u32 {
            self.edit_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Channel for MockChannel {
        fn channel_type(&self) -> ChannelType {
            ChannelType::Telegram
        }

        async fn start(&self, _tx: mpsc::Sender<InboundMessage>) -> anyhow::Result<()> {
            Ok(())
        }

        async fn send(&self, _msg: OutboundMessage) -> anyhow::Result<Option<String>> {
            self.send_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some("mock-msg-123".to_string()))
        }

        async fn edit(&self, _msg: EditMessage) -> anyhow::Result<()> {
            let remaining = self.fail_first_n_edits.load(Ordering::SeqCst);
            if remaining > 0 {
                self.fail_first_n_edits.fetch_sub(1, Ordering::SeqCst);
                return Err(anyhow::anyhow!("rate limited"));
            }
            self.edit_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        fn supports_editing(&self) -> bool {
            self.editing_supported
        }

        async fn shutdown(&self) -> anyhow::Result<()> {
            Ok(())
        }
    }

    fn test_msg_ref() -> MessageRef {
        MessageRef {
            channel_type: ChannelType::Telegram,
            account_id: "default".to_string(),
            recipient_id: "user123".to_string(),
            message_id: Some("msg_1".to_string()),
            reply_to: Some("orig_msg".to_string()),
        }
    }

    #[tokio::test]
    async fn test_batch_ignores_partials() {
        let ch = Arc::new(MockChannel::new(false));
        let strategy = BatchDelivery::new(ch.clone());
        let msg_ref = test_msg_ref();
        strategy.on_partial("hello", &msg_ref).await.unwrap();
        strategy.on_partial("hello world", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 0);
        assert_eq!(ch.edits(), 0);
    }

    #[tokio::test]
    async fn test_batch_sends_on_complete() {
        let ch = Arc::new(MockChannel::new(false));
        let strategy = BatchDelivery::new(ch.clone());
        let msg_ref = test_msg_ref();
        strategy.on_complete("final text", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 1);
        assert_eq!(ch.edits(), 0);
    }

    #[tokio::test]
    async fn test_streaming_sends_placeholder_on_first_partial() {
        let ch = Arc::new(MockChannel::new(true));
        let strategy = StreamingDelivery::new(ch.clone());
        let msg_ref = test_msg_ref();
        strategy.on_partial("hel", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 1, "should send placeholder on first partial");
        assert_eq!(ch.edits(), 0, "no edits yet on first partial");
    }

    #[tokio::test]
    async fn test_streaming_rate_limits_edits() {
        let ch = Arc::new(MockChannel::new(true));
        let strategy = StreamingDelivery::new(ch.clone());
        let msg_ref = test_msg_ref();
        strategy.on_partial("a", &msg_ref).await.unwrap();
        strategy.on_partial("ab", &msg_ref).await.unwrap();
        strategy.on_partial("abc", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 1, "only one send (placeholder)");
        assert_eq!(ch.edits(), 0, "edits throttled within 1 sec");
    }

    #[tokio::test]
    async fn test_streaming_complete_performs_final_edit() {
        let ch = Arc::new(MockChannel::new(true));
        let strategy = StreamingDelivery::new(ch.clone());
        let msg_ref = test_msg_ref();
        strategy.on_partial("partial", &msg_ref).await.unwrap();
        strategy.on_complete("final", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 1, "placeholder send");
        assert_eq!(ch.edits(), 1, "final edit");
    }

    #[tokio::test]
    async fn test_streaming_complete_without_partial_sends_message() {
        let ch = Arc::new(MockChannel::new(true));
        let strategy = StreamingDelivery::new(ch.clone());
        let msg_ref = test_msg_ref();
        strategy
            .on_complete("instant response", &msg_ref)
            .await
            .unwrap();
        assert_eq!(ch.sends(), 1, "should send as regular message");
        assert_eq!(ch.edits(), 0, "no edits needed");
    }

    #[tokio::test]
    async fn test_select_strategy_streaming_when_editing_supported() {
        let ch = Arc::new(MockChannel::new(true));
        let strategy = select_strategy(ch.clone(), None);
        let msg_ref = test_msg_ref();
        strategy.on_partial("test", &msg_ref).await.unwrap();
        // Only the streaming strategy sends on a partial; batch would leave this at 0.
        assert_eq!(
            ch.sends(),
            1,
            "streaming sends a placeholder on the first partial"
        );
    }

    #[tokio::test]
    async fn test_select_strategy_batch_when_complete_mode() {
        let ch = Arc::new(MockChannel::new(true));
        let strategy = select_strategy(ch.clone(), Some("complete"));
        let msg_ref = test_msg_ref();
        strategy.on_partial("test", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 0, "batch ignores partials");
    }

    #[tokio::test]
    async fn test_select_strategy_batch_when_no_editing() {
        let ch = Arc::new(MockChannel::new(false));
        let strategy = select_strategy(ch.clone(), Some("partial"));
        let msg_ref = test_msg_ref();
        strategy.on_partial("test", &msg_ref).await.unwrap();
        assert_eq!(
            ch.sends(),
            0,
            "batch ignores partials even with partial mode"
        );
    }

    #[tokio::test]
    async fn test_streaming_edit_failure_retries() {
        let ch = Arc::new(MockChannel::new(true).with_failing_edits(1));
        let strategy = StreamingDelivery {
            channel: ch.clone(),
            state: Mutex::new(StreamingState {
                sent_message_id: None,
                last_edit: None,
                latest_text: String::new(),
            }),
            edit_interval: std::time::Duration::from_millis(10),
        };
        let msg_ref = test_msg_ref();
        strategy.on_partial("a", &msg_ref).await.unwrap();
        assert_eq!(ch.sends(), 1);
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        strategy.on_partial("ab", &msg_ref).await.unwrap();
        assert_eq!(ch.edits(), 1, "retry should succeed");
    }

    #[test]
    fn test_split_message_prefers_whitespace_boundaries() {
        assert_eq!(split_message("hello world", 0), vec!["hello world"]);
        assert_eq!(split_message("short", 10), vec!["short"]);
        assert_eq!(split_message("hello world", 5), vec!["hello", "world"]);
        assert_eq!(
            split_message("line one\nline two", 10),
            vec!["line one", "line two"]
        );
    }
}