use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, RwLock};
use tracing::{debug, error, warn};
use rust_tg_bot_raw::types::update::Update;
use super::base::{Handler, HandlerResult, MatchResult};
/// Identifies a single conversation.
///
/// Built from the ids selected by the handler's `per_chat`, `per_user` and `per_message`
/// flags, appended in that order (chat id, user id, message component).
pub type ConversationKey = Vec<i64>;
/// Transition requested by a step callback after it has handled an update.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ConversationResult<S> {
    /// Move the conversation to the given state. If the state is a key of the handler's
    /// `map_to_parent` mapping, the conversation ends instead.
    NextState(S),
    /// End the conversation and forget its state.
    End,
    /// Keep the current state. When no conversation is active, none is created.
    Stay,
}
/// Asynchronous step callback.
///
/// Receives the triggering update and the match result produced by the step's handler.
/// Resolves to the handler result plus the conversation transition to apply.
pub type ConversationCallback<S> = Arc<
    dyn Fn(
            Arc<Update>,
            MatchResult,
        ) -> Pin<Box<dyn Future<Output = (HandlerResult, ConversationResult<S>)> + Send>>
        + Send
        + Sync,
>;
/// One step of a conversation: a matcher plus the callback that performs the step.
pub struct ConversationStepHandler<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Decides whether the step applies (`check_update`) and whether it runs inline
    /// (`block`). `ConversationHandler` never calls this handler's own `handle_update`.
    pub handler: Box<dyn Handler>,
    /// Work performed when the step is selected; returns the requested transition.
    pub conv_callback: ConversationCallback<S>,
}
/// Multi-step conversation router.
///
/// Tracks one state `S` per [`ConversationKey`] and dispatches each update to the entry
/// points, the current state's steps, or the fallbacks. Construct it with
/// `ConversationHandler::builder()`.
pub struct ConversationHandler<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Steps that can start a conversation (and re-enter one when `allow_reentry` is set).
    entry_points: Vec<ConversationStepHandler<S>>,
    /// Steps available while a conversation is in a given state.
    states: HashMap<S, Vec<ConversationStepHandler<S>>>,
    /// Steps tried after the current state's steps fail to match.
    fallbacks: Vec<ConversationStepHandler<S>>,
    /// Current state of every active conversation; shared with spawned tasks.
    conversations: Arc<RwLock<HashMap<ConversationKey, S>>>,
    /// Whether entry points are still tried while a conversation is active.
    allow_reentry: bool,
    /// Include the effective chat id in the key.
    per_chat: bool,
    /// Include the effective user id in the key.
    per_user: bool,
    /// Include a callback-query message component in the key.
    per_message: bool,
    /// Delay after the last handled update before the conversation is dropped;
    /// `None` disables timeouts.
    conversation_timeout: Option<Duration>,
    /// `NextState` targets that end this conversation. Only the keys are consulted in
    /// this file; the mapped values are not read here.
    map_to_parent: Option<HashMap<S, S>>,
    /// Callbacks run when a conversation times out.
    timeout_handlers: Vec<ConversationStepHandler<S>>,
    /// Cancel signal for the timeout task currently armed for each key.
    timeout_cancellers: Arc<RwLock<HashMap<ConversationKey, watch::Sender<bool>>>>,
    /// Exposed via `is_persistent`; the builder requires a `name` when this is set.
    persistent: bool,
    /// Optional handler name, exposed via `name()`.
    name: Option<String>,
    /// Keys whose non-blocking step callback is still running; `check_update` skips them.
    pending_callbacks: Arc<RwLock<HashSet<ConversationKey>>>,
}
impl<S: Hash + Eq + Clone + Send + Sync + 'static> ConversationHandler<S> {
    /// Returns a [`ConversationHandlerBuilder`] with its default settings.
    pub fn builder() -> ConversationHandlerBuilder<S> {
        ConversationHandlerBuilder::default()
    }

    /// Computes the conversation key for `update`.
    ///
    /// For each enabled flag a component is appended, in this order:
    /// - the effective chat id (`per_chat`);
    /// - the effective user id (`per_user`);
    /// - a message component (`per_message`). This is a hash of the callback query's
    ///   `inline_message_id` when present, otherwise the id of the callback query's message.
    ///
    /// Returns `None` if an enabled component is unavailable for this update (for example
    /// `per_message` on an update that is not a callback query), or if no component was
    /// added at all.
    fn build_key(&self, update: &Update) -> Option<ConversationKey> {
        let mut key = Vec::new();
        if self.per_chat {
            let chat = update.effective_chat()?;
            key.push(chat.id);
        }
        if self.per_user {
            let user = update.effective_user()?;
            key.push(user.id);
        }
        if self.per_message {
            let cq = update.callback_query()?;
            if let Some(ref inline_id) = cq.inline_message_id {
                use std::hash::Hasher;
                // NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to stay the same
                // across Rust releases, so this component may change after a toolchain
                // upgrade. That matters when keys are persisted via `save_conversations`.
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                hasher.write(inline_id.as_bytes());
                key.push(hasher.finish() as i64);
            } else if let Some(ref msg) = cq.message {
                key.push(msg.message_id());
            } else {
                return None;
            }
        }
        if key.is_empty() {
            return None;
        }
        Some(key)
    }

    /// Returns the index and match result of the first step in `handlers` whose handler
    /// accepts `update`, or `None` when no step matches.
    fn find_matching(
        handlers: &[ConversationStepHandler<S>],
        update: &Update,
    ) -> Option<(usize, MatchResult)> {
        handlers
            .iter()
            .enumerate()
            .find_map(|(idx, step)| step.handler.check_update(update).map(|mr| (idx, mr)))
    }

    /// Returns the current state of the conversation identified by `key`, if any.
    pub async fn get_state(&self, key: &ConversationKey) -> Option<S> {
        self.conversations.read().await.get(key).cloned()
    }

    /// Returns a snapshot of every active conversation and its state.
    pub async fn active_conversations(&self) -> HashMap<ConversationKey, S> {
        self.conversations.read().await.clone()
    }

    /// Replaces all tracked conversation states with `data`.
    ///
    /// Running timeouts and pending non-blocking callbacks are left untouched.
    pub async fn load_conversations(&self, data: HashMap<ConversationKey, S>) {
        *self.conversations.write().await = data;
    }

    /// Returns a snapshot of all conversation states, suitable for persisting.
    pub async fn save_conversations(&self) -> HashMap<ConversationKey, S> {
        self.active_conversations().await
    }

    /// Whether this handler was built with `persistent(true)`.
    pub fn is_persistent(&self) -> bool {
        self.persistent
    }

    /// The handler's name, if one was configured.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Applies `conv_result` for `key` and returns the state that should be stored next.
    ///
    /// - `End` removes the conversation and its pending marker, then returns `None`.
    /// - `Stay` returns `current_state` unchanged; this is `None` if nothing was active.
    /// - `NextState(s)` returns `Some(s)`. If `s` is a key of `map_to_parent`, the
    ///   conversation is removed as for `End` and `None` is returned instead.
    ///
    /// The caller is responsible for inserting a returned `Some` state.
    async fn apply_state_transition(
        conversations: &RwLock<HashMap<ConversationKey, S>>,
        pending_callbacks: &RwLock<HashSet<ConversationKey>>,
        key: &ConversationKey,
        conv_result: ConversationResult<S>,
        current_state: &Option<S>,
        map_to_parent: &Option<HashMap<S, S>>,
    ) -> Option<S> {
        match conv_result {
            ConversationResult::End => {
                conversations.write().await.remove(key);
                pending_callbacks.write().await.remove(key);
                None
            }
            ConversationResult::Stay => current_state.clone(),
            ConversationResult::NextState(s) => {
                if let Some(ref mtp) = map_to_parent {
                    if mtp.contains_key(&s) {
                        conversations.write().await.remove(key);
                        pending_callbacks.write().await.remove(key);
                        debug!(
                            "ConversationHandler: map_to_parent triggered for key {:?}",
                            key
                        );
                        return None;
                    }
                }
                Some(s)
            }
        }
    }

    /// Arms a timeout for `key` and returns the sender that cancels it.
    ///
    /// Callers store the returned sender in `timeout_cancellers`. Sending `true` on it, or
    /// dropping it (for example by replacing the map entry), cancels the timer. When the
    /// timer elapses first, the task:
    /// 1. runs `timeout_cbs` with `update`, but only if the conversation is still active;
    /// 2. removes the conversation, its pending marker and its canceller entry, unless the
    ///    timeout was cancelled or superseded while the callbacks were running.
    fn spawn_timeout(
        conversations: Arc<RwLock<HashMap<ConversationKey, S>>>,
        pending_callbacks: Arc<RwLock<HashSet<ConversationKey>>>,
        timeout_cancellers: Arc<RwLock<HashMap<ConversationKey, watch::Sender<bool>>>>,
        key: ConversationKey,
        update: Arc<Update>,
        duration: Duration,
        timeout_cbs: Vec<ConversationCallback<S>>,
    ) -> watch::Sender<bool> {
        let (cancel_tx, mut cancel_rx) = watch::channel(false);
        tokio::spawn(async move {
            tokio::select! {
                _ = tokio::time::sleep(duration) => {}
                _ = cancel_rx.changed() => {
                    debug!("Timeout cancelled for {:?}", key);
                    return;
                }
            }

            // A conversation that has already ended must not be "timed out" after the fact.
            let still_active = conversations.read().await.contains_key(&key);
            if still_active {
                for cb in &timeout_cbs {
                    let _ = cb(update.clone(), MatchResult::Empty).await;
                }
            }

            // The callbacks above may have awaited for a while. Meanwhile a new update can
            // have cancelled this timer (a value was sent) or replaced it with a fresh one
            // (the sender was dropped). Decide under the cancellers lock, because
            // `handle_update` cancels timers under that same lock. That way a newer state
            // and its newly armed timeout are never wiped by this stale task.
            let mut cancellers = timeout_cancellers.write().await;
            if !matches!(cancel_rx.has_changed(), Ok(false)) {
                debug!("Timeout superseded for {:?}", key);
                return;
            }
            cancellers.remove(&key);
            if still_active {
                conversations.write().await.remove(&key);
                pending_callbacks.write().await.remove(&key);
                debug!("Conversation {:?} timed out", key);
            }
            drop(cancellers);
        });
        cancel_tx
    }
}
impl<S: Hash + Eq + Clone + Send + Sync + 'static> Handler for ConversationHandler<S> {
fn check_update(&self, update: &Update) -> Option<MatchResult> {
if update.channel_post().is_some() || update.edited_channel_post().is_some() {
return None;
}
let key = self.build_key(update)?;
if let Ok(pending) = self.pending_callbacks.try_read() {
if pending.contains(&key) {
debug!(
"ConversationHandler: skipping update for {:?} (pending callback)",
key
);
return None;
}
}
let current_state = match self.conversations.try_read() {
Ok(guard) => guard.get(&key).cloned(),
Err(_) => {
debug!(
"ConversationHandler: conversations lock contended, skipping {:?}",
key
);
return None;
}
};
match current_state {
None => {
if Self::find_matching(&self.entry_points, update).is_some() {
return Some(MatchResult::Empty);
}
}
Some(ref state) => {
if self.allow_reentry && Self::find_matching(&self.entry_points, update).is_some() {
return Some(MatchResult::Empty);
}
if let Some(handlers) = self.states.get(state) {
if Self::find_matching(handlers, update).is_some() {
return Some(MatchResult::Empty);
}
}
if Self::find_matching(&self.fallbacks, update).is_some() {
return Some(MatchResult::Empty);
}
}
}
None
}
fn handle_update(
&self,
update: Arc<Update>,
_match_result: MatchResult,
) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
let conversations = Arc::clone(&self.conversations);
let pending_callbacks = Arc::clone(&self.pending_callbacks);
let allow_reentry = self.allow_reentry;
#[derive(Debug, Clone, Copy)]
enum HandlerSource {
EntryPoint(usize),
State(usize),
Fallback(usize),
}
let key = self.build_key(&update);
let current_state = key.as_ref().and_then(|k| {
self.conversations
.try_read()
.ok()
.and_then(|g| g.get(k).cloned())
});
let mut source = None;
let mut match_result = MatchResult::Empty;
let check_entries = current_state.is_none() || allow_reentry;
if check_entries {
if let Some((idx, mr)) = Self::find_matching(&self.entry_points, &update) {
source = Some(HandlerSource::EntryPoint(idx));
match_result = mr;
}
}
if source.is_none() {
if let Some(ref state) = current_state {
if let Some(handlers) = self.states.get(state) {
if let Some((idx, mr)) = Self::find_matching(handlers, &update) {
source = Some(HandlerSource::State(idx));
match_result = mr;
}
}
}
}
if source.is_none() {
if let Some((idx, mr)) = Self::find_matching(&self.fallbacks, &update) {
source = Some(HandlerSource::Fallback(idx));
match_result = mr;
}
}
let conv_cb = match source {
Some(HandlerSource::EntryPoint(idx)) => {
Arc::clone(&self.entry_points[idx].conv_callback)
}
Some(HandlerSource::State(idx)) => {
let mut cb = None;
if let Some(ref state) = current_state {
if let Some(handlers) = self.states.get(state) {
if idx < handlers.len() {
cb = Some(Arc::clone(&handlers[idx].conv_callback));
}
}
}
cb.unwrap_or_else(|| {
Arc::new(|_u, _m| {
Box::pin(async { (HandlerResult::Continue, ConversationResult::Stay) })
})
})
}
Some(HandlerSource::Fallback(idx)) => Arc::clone(&self.fallbacks[idx].conv_callback),
None => {
return Box::pin(async { HandlerResult::Continue });
}
};
let is_entry = matches!(source, Some(HandlerSource::EntryPoint(_)));
let is_blocking = match source {
Some(HandlerSource::EntryPoint(idx)) => self.entry_points[idx].handler.block(),
Some(HandlerSource::State(idx)) => current_state
.as_ref()
.and_then(|s| self.states.get(s))
.and_then(|handlers| handlers.get(idx))
.map_or(true, |step| step.handler.block()),
Some(HandlerSource::Fallback(idx)) => self.fallbacks[idx].handler.block(),
None => true,
};
let map_to_parent = self.map_to_parent.clone();
let has_timeout = self.conversation_timeout.is_some();
let timeout_cancellers = Arc::clone(&self.timeout_cancellers);
let timeout_duration = self.conversation_timeout;
let timeout_cbs: Vec<_> = self
.timeout_handlers
.iter()
.map(|step| Arc::clone(&step.conv_callback))
.collect();
let is_persistent = self.persistent;
let _handler_name = self.name.clone();
Box::pin(async move {
let key = match key {
Some(k) => k,
None => return HandlerResult::Continue,
};
let current_state = conversations.read().await.get(&key).cloned();
if is_entry && current_state.is_some() && !allow_reentry {
debug!("ConversationHandler: ignoring re-entry for key {:?}", key);
return HandlerResult::Continue;
}
if has_timeout {
if let Some(tx) = timeout_cancellers.write().await.remove(&key) {
let _ = tx.send(true);
}
}
if !is_blocking {
pending_callbacks.write().await.insert(key.clone());
let conversations2 = Arc::clone(&conversations);
let pending2 = Arc::clone(&pending_callbacks);
let map_to_parent2 = map_to_parent.clone();
let key2 = key.clone();
let current_state2 = current_state.clone();
let update2 = update.clone();
let timeout_cancellers2 = Arc::clone(&timeout_cancellers);
let timeout_cbs2 = timeout_cbs;
tokio::spawn(async move {
let result = tokio::spawn(conv_cb(update2.clone(), match_result)).await;
match result {
Ok((_handler_result, conv_result)) => {
let new_state = Self::apply_state_transition(
&conversations2,
&pending2,
&key2,
conv_result,
¤t_state2,
&map_to_parent2,
)
.await;
if let Some(new_s) = new_state {
conversations2.write().await.insert(key2.clone(), new_s);
}
}
Err(join_err) => {
error!(
"ConversationHandler: non-blocking callback failed for {:?}: {}. \
Reverting to previous state.",
key2, join_err
);
if let Some(ref prev) = current_state2 {
conversations2
.write()
.await
.insert(key2.clone(), prev.clone());
} else {
conversations2.write().await.remove(&key2);
}
}
}
pending2.write().await.remove(&key2);
if has_timeout {
if let Some(duration) = timeout_duration {
let cancel_tx = Self::spawn_timeout(
Arc::clone(&conversations2),
Arc::clone(&pending2),
Arc::clone(&timeout_cancellers2),
key2.clone(),
update2,
duration,
timeout_cbs2,
);
timeout_cancellers2.write().await.insert(key2, cancel_tx);
}
}
});
return HandlerResult::Continue;
}
let (handler_result, conv_result) = conv_cb(update.clone(), match_result).await;
let new_state = Self::apply_state_transition(
&conversations,
&pending_callbacks,
&key,
conv_result,
¤t_state,
&map_to_parent,
)
.await;
if new_state.is_none() && !conversations.read().await.contains_key(&key) {
return handler_result;
}
if let Some(new_s) = new_state {
conversations.write().await.insert(key.clone(), new_s);
}
if has_timeout {
if let Some(duration) = timeout_duration {
let cancel_tx = Self::spawn_timeout(
Arc::clone(&conversations),
Arc::clone(&pending_callbacks),
Arc::clone(&timeout_cancellers),
key.clone(),
update,
duration,
timeout_cbs,
);
timeout_cancellers.write().await.insert(key, cancel_tx);
}
}
if is_persistent {
debug!("ConversationHandler: state changed (persistent handler)");
}
handler_result
})
}
fn block(&self) -> bool {
true
}
}
/// Builder for [`ConversationHandler`]; create it with `ConversationHandler::builder()`.
pub struct ConversationHandlerBuilder<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Steps that can start a conversation.
    entry_points: Vec<ConversationStepHandler<S>>,
    /// Steps per conversation state.
    states: HashMap<S, Vec<ConversationStepHandler<S>>>,
    /// Steps tried when the current state's steps do not match.
    fallbacks: Vec<ConversationStepHandler<S>>,
    /// Whether entry points remain active during a conversation.
    allow_reentry: bool,
    /// Key component: effective chat id.
    per_chat: bool,
    /// Key component: effective user id.
    per_user: bool,
    /// Key component: callback-query message.
    per_message: bool,
    /// Optional timeout re-armed after each handled update.
    conversation_timeout: Option<Duration>,
    /// Optional handler name (required by `build` when `persistent` is set).
    name: Option<String>,
    /// States that end this conversation when targeted by `NextState`.
    map_to_parent: Option<HashMap<S, S>>,
    /// Callbacks run when a conversation times out.
    timeout_handlers: Vec<ConversationStepHandler<S>>,
    /// Marks the resulting handler as persistent.
    persistent: bool,
}
impl<S: Hash + Eq + Clone + Send + Sync + 'static> Default for ConversationHandlerBuilder<S> {
    /// Empty builder keyed per chat and per user, with re-entry, per-message keying,
    /// timeouts and persistence all disabled.
    fn default() -> Self {
        Self {
            // Keying: (chat, user) only.
            per_chat: true,
            per_user: true,
            per_message: false,
            // Behaviour switches.
            allow_reentry: false,
            persistent: false,
            conversation_timeout: None,
            name: None,
            map_to_parent: None,
            // No steps registered yet.
            entry_points: Vec::default(),
            states: HashMap::default(),
            fallbacks: Vec::default(),
            timeout_handlers: Vec::default(),
        }
    }
}
impl<S: Hash + Eq + Clone + Send + Sync + 'static> ConversationHandlerBuilder<S> {
    /// Appends a single entry-point step.
    pub fn entry_point(mut self, handler: ConversationStepHandler<S>) -> Self {
        self.entry_points.push(handler);
        self
    }

    /// Appends several entry-point steps, keeping any already registered.
    pub fn entry_points(mut self, mut handlers: Vec<ConversationStepHandler<S>>) -> Self {
        self.entry_points.append(&mut handlers);
        self
    }

    /// Registers the steps for `state`, replacing any previously registered for it.
    pub fn state(mut self, state: S, handlers: Vec<ConversationStepHandler<S>>) -> Self {
        let _previous = self.states.insert(state, handlers);
        self
    }

    /// Appends a single fallback step.
    pub fn fallback(mut self, handler: ConversationStepHandler<S>) -> Self {
        self.fallbacks.push(handler);
        self
    }

    /// Appends several fallback steps, keeping any already registered.
    pub fn fallbacks(mut self, mut handlers: Vec<ConversationStepHandler<S>>) -> Self {
        self.fallbacks.append(&mut handlers);
        self
    }

    /// Sets whether entry points are still tried while a conversation is active.
    pub fn allow_reentry(self, allow: bool) -> Self {
        Self {
            allow_reentry: allow,
            ..self
        }
    }

    /// Sets whether the chat id is part of the conversation key.
    pub fn per_chat(self, enabled: bool) -> Self {
        Self {
            per_chat: enabled,
            ..self
        }
    }

    /// Sets whether the user id is part of the conversation key.
    pub fn per_user(self, enabled: bool) -> Self {
        Self {
            per_user: enabled,
            ..self
        }
    }

    /// Sets whether the callback-query message is part of the conversation key.
    pub fn per_message(self, enabled: bool) -> Self {
        Self {
            per_message: enabled,
            ..self
        }
    }

    /// Enables a conversation timeout of `timeout`.
    pub fn conversation_timeout(self, timeout: Duration) -> Self {
        Self {
            conversation_timeout: Some(timeout),
            ..self
        }
    }

    /// Names the handler.
    pub fn name(self, name: String) -> Self {
        Self {
            name: Some(name),
            ..self
        }
    }

    /// Sets the states that hand control back to a parent conversation.
    pub fn map_to_parent(self, mapping: HashMap<S, S>) -> Self {
        Self {
            map_to_parent: Some(mapping),
            ..self
        }
    }

    /// Replaces the whole list of timeout steps.
    pub fn timeout_handlers(self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
        Self {
            timeout_handlers: handlers,
            ..self
        }
    }

    /// Appends a single timeout step.
    pub fn timeout_handler(mut self, handler: ConversationStepHandler<S>) -> Self {
        self.timeout_handlers.push(handler);
        self
    }

    /// Marks the handler as persistent; `build` then requires a name.
    pub fn persistent(self, enabled: bool) -> Self {
        Self {
            persistent: enabled,
            ..self
        }
    }

    /// Validates the configuration and produces the handler.
    ///
    /// # Panics
    ///
    /// Panics if none of `per_chat`, `per_user` or `per_message` is enabled, or if
    /// `persistent` is set without a `name`.
    pub fn build(self) -> ConversationHandler<S> {
        let Self {
            entry_points,
            states,
            fallbacks,
            allow_reentry,
            per_chat,
            per_user,
            per_message,
            conversation_timeout,
            name,
            map_to_parent,
            timeout_handlers,
            persistent,
        } = self;

        assert!(
            per_chat || per_user || per_message,
            "At least one of per_chat, per_user, per_message must be true"
        );
        assert!(
            !persistent || name.is_some(),
            "Conversations can't be persistent when handler is unnamed"
        );
        if per_message && !per_chat {
            warn!(
                "ConversationHandler: per_message=true without per_chat=true \
                 -- message IDs are not globally unique"
            );
        }

        ConversationHandler {
            entry_points,
            states,
            fallbacks,
            conversations: Arc::new(RwLock::new(HashMap::new())),
            allow_reentry,
            per_chat,
            per_user,
            per_message,
            conversation_timeout,
            map_to_parent,
            timeout_handlers,
            timeout_cancellers: Arc::new(RwLock::new(HashMap::new())),
            persistent,
            name,
            pending_callbacks: Arc::new(RwLock::new(HashSet::new())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    enum TestState {
        AskName,
        AskAge,
    }

    /// Stub step handler with a fixed decision. When `accepts` is set it matches any update
    /// carrying a message; otherwise it matches nothing. Handling is a no-op.
    struct MessageMatcher {
        accepts: bool,
    }

    impl Handler for MessageMatcher {
        fn check_update(&self, update: &Update) -> Option<MatchResult> {
            if self.accepts && update.message().is_some() {
                Some(MatchResult::Empty)
            } else {
                None
            }
        }

        fn handle_update(
            &self,
            _update: Arc<Update>,
            _match_result: MatchResult,
        ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
            Box::pin(async { HandlerResult::Continue })
        }
    }

    fn always_match_handler() -> Box<dyn Handler> {
        Box::new(MessageMatcher { accepts: true })
    }

    fn never_match_handler() -> Box<dyn Handler> {
        Box::new(MessageMatcher { accepts: false })
    }

    /// Wraps `handler` in a step whose callback always yields `result`.
    fn make_step<S: Hash + Eq + Clone + Send + Sync + 'static>(
        handler: Box<dyn Handler>,
        result: ConversationResult<S>,
    ) -> ConversationStepHandler<S> {
        let conv_callback: ConversationCallback<S> = Arc::new(move |_update, _match_result| {
            let outcome = result.clone();
            Box::pin(async move { (HandlerResult::Continue, outcome) })
        });
        ConversationStepHandler {
            handler,
            conv_callback,
        }
    }

    /// Private-chat message update sent by `user_id` in `chat_id`.
    fn make_update(chat_id: i64, user_id: i64) -> Update {
        let payload = json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": chat_id, "type": "private"},
                "from": {"id": user_id, "is_bot": false, "first_name": "Test"}
            }
        });
        serde_json::from_value(payload).expect("test update JSON must be valid")
    }

    /// Channel post update, which conversations must ignore.
    fn make_channel_post_update() -> Update {
        let payload = json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel", "title": "Test"}
            }
        });
        serde_json::from_value(payload).expect("test update JSON must be valid")
    }

    /// Key of `make_update(100, 42)` under the default (chat, user) keying.
    fn default_key() -> ConversationKey {
        vec![100, 42]
    }

    /// Entry point that always starts the conversation in `AskName`.
    fn start_asking_name() -> ConversationStepHandler<TestState> {
        make_step(
            always_match_handler(),
            ConversationResult::NextState(TestState::AskName),
        )
    }

    /// `AskName` steps: always advance to `AskAge`.
    fn name_then_age() -> Vec<ConversationStepHandler<TestState>> {
        vec![make_step(
            always_match_handler(),
            ConversationResult::NextState(TestState::AskAge),
        )]
    }

    #[tokio::test]
    async fn state_transition_entry_to_state1_to_state2_to_end() {
        let conv = ConversationHandler::builder()
            .entry_point(start_asking_name())
            .state(TestState::AskName, name_then_age())
            .state(
                TestState::AskAge,
                vec![make_step(always_match_handler(), ConversationResult::End)],
            )
            .build();
        let update = Arc::new(make_update(100, 42));
        let expected_states = vec![Some(TestState::AskName), Some(TestState::AskAge), None];
        for expected in expected_states {
            assert!(conv.check_update(&update).is_some());
            conv.handle_update(Arc::clone(&update), MatchResult::Empty).await;
            assert_eq!(conv.get_state(&default_key()).await, expected);
        }
    }

    #[tokio::test]
    async fn timeout_removes_conversation() {
        let conv = ConversationHandler::builder()
            .entry_point(start_asking_name())
            .state(TestState::AskName, name_then_age())
            .conversation_timeout(Duration::from_millis(50))
            .build();
        let update = Arc::new(make_update(100, 42));
        conv.handle_update(Arc::clone(&update), MatchResult::Empty).await;
        assert_eq!(
            conv.get_state(&default_key()).await,
            Some(TestState::AskName)
        );
        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(conv.get_state(&default_key()).await, None);
    }

    #[tokio::test]
    async fn fallback_triggers_on_unmatched_input() {
        let unmatched_step = make_step(
            never_match_handler(),
            ConversationResult::NextState(TestState::AskAge),
        );
        let conv = ConversationHandler::builder()
            .entry_point(start_asking_name())
            .state(TestState::AskName, vec![unmatched_step])
            .fallback(make_step(always_match_handler(), ConversationResult::End))
            .build();
        let update = Arc::new(make_update(100, 42));

        conv.handle_update(Arc::clone(&update), MatchResult::Empty).await;
        assert_eq!(
            conv.get_state(&default_key()).await,
            Some(TestState::AskName)
        );

        assert!(conv.check_update(&update).is_some());
        conv.handle_update(Arc::clone(&update), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&default_key()).await, None);
    }

    #[test]
    fn channel_post_returns_none() {
        let conv = ConversationHandler::<TestState>::builder()
            .entry_point(start_asking_name())
            .build();
        let verdict = conv.check_update(&make_channel_post_update());
        assert!(
            verdict.is_none(),
            "Channel posts must be rejected by ConversationHandler"
        );
    }

    #[tokio::test]
    async fn persistence_load_save_roundtrip() {
        let conv = ConversationHandler::<TestState>::builder()
            .entry_point(start_asking_name())
            .state(TestState::AskName, name_then_age())
            .name("test_conv".to_string())
            .persistent(true)
            .build();

        let first: ConversationKey = vec![1, 2];
        let second: ConversationKey = vec![3, 4];
        let data: HashMap<ConversationKey, TestState> = vec![
            (first.clone(), TestState::AskAge),
            (second.clone(), TestState::AskName),
        ]
        .into_iter()
        .collect();
        conv.load_conversations(data).await;

        assert_eq!(conv.get_state(&first).await, Some(TestState::AskAge));
        assert_eq!(conv.get_state(&second).await, Some(TestState::AskName));

        let saved = conv.save_conversations().await;
        assert_eq!(saved.len(), 2);
        assert_eq!(saved.get(&first), Some(&TestState::AskAge));
    }

    #[test]
    fn builder_name_and_persistence() {
        let conv = ConversationHandler::<TestState>::builder()
            .entry_point(start_asking_name())
            .name("my_conv".to_string())
            .persistent(true)
            .build();
        assert!(conv.is_persistent());
        assert_eq!(conv.name(), Some("my_conv"));
    }

    #[test]
    #[should_panic(expected = "At least one of per_chat, per_user, per_message must be true")]
    fn builder_panics_without_key_components() {
        ConversationHandler::<TestState>::builder()
            .per_chat(false)
            .per_user(false)
            .per_message(false)
            .build();
    }

    #[test]
    #[should_panic(expected = "Conversations can't be persistent when handler is unnamed")]
    fn builder_panics_persistent_without_name() {
        ConversationHandler::<TestState>::builder()
            .persistent(true)
            .build();
    }
}