use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use serde_json::Value;
use tokio::sync::RwLock;
use rust_tg_bot_raw::bot::MessageOrBool;
use rust_tg_bot_raw::types::files::input_file::InputFile;
use rust_tg_bot_raw::types::update::Update;
use crate::context_types::DefaultData;
use crate::ext_bot::ExtBot;
#[cfg(feature = "job-queue")]
use crate::job_queue::JobQueue;
/// Read-only view of a shared [`DefaultData`] map, obtained by acquiring a
/// `tokio` read lock.
///
/// The read lock is held for as long as this guard lives and is released when it
/// is dropped, so keep it in a narrow scope. Besides its typed accessors the
/// guard dereferences to the underlying map.
#[derive(Debug)]
#[must_use = "the read lock is released as soon as the guard is dropped"]
pub struct DataReadGuard<'a> {
    inner: tokio::sync::RwLockReadGuard<'a, DefaultData>,
}
impl<'a> DataReadGuard<'a> {
#[must_use]
pub fn get_str(&self, key: &str) -> Option<&str> {
self.inner.get(key).and_then(|v| v.as_str())
}
#[must_use]
pub fn get_i64(&self, key: &str) -> Option<i64> {
self.inner.get(key).and_then(|v| v.as_i64())
}
#[must_use]
pub fn get_f64(&self, key: &str) -> Option<f64> {
self.inner.get(key).and_then(|v| v.as_f64())
}
#[must_use]
pub fn get_bool(&self, key: &str) -> Option<bool> {
self.inner.get(key).and_then(|v| v.as_bool())
}
#[must_use]
pub fn get(&self, key: &str) -> Option<&Value> {
self.inner.get(key)
}
#[must_use]
pub fn get_id_set(&self, key: &str) -> HashSet<i64> {
self.inner
.get(key)
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
.unwrap_or_default()
}
#[must_use]
pub fn raw(&self) -> &DefaultData {
&self.inner
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
}
/// Lets a read guard be used anywhere a `&DefaultData` is expected, so every
/// read-only map method is available on it.
impl std::ops::Deref for DataReadGuard<'_> {
    type Target = DefaultData;

    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
/// Mutable view of a shared [`DefaultData`] map, obtained by acquiring a
/// `tokio` write lock.
///
/// The write lock is held (blocking all other readers and writers of the same
/// map) for as long as this guard lives and is released when it is dropped.
/// Besides its typed accessors and setters the guard dereferences (mutably) to
/// the underlying map.
#[derive(Debug)]
#[must_use = "the write lock is released as soon as the guard is dropped"]
pub struct DataWriteGuard<'a> {
    inner: tokio::sync::RwLockWriteGuard<'a, DefaultData>,
}
impl DataWriteGuard<'_> {
    /// Returns the value under `key` when it is a JSON string.
    #[must_use]
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.inner.get(key).and_then(Value::as_str)
    }

    /// Returns the value under `key` when it is a JSON number representable as `i64`.
    #[must_use]
    pub fn get_i64(&self, key: &str) -> Option<i64> {
        self.inner.get(key).and_then(Value::as_i64)
    }

    /// Returns the value under `key` when it is a JSON number, converted to `f64`.
    #[must_use]
    pub fn get_f64(&self, key: &str) -> Option<f64> {
        self.inner.get(key).and_then(Value::as_f64)
    }

    /// Returns the value under `key` when it is a JSON boolean.
    #[must_use]
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.inner.get(key).and_then(Value::as_bool)
    }

    /// Returns the raw JSON value stored under `key`, if any.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&Value> {
        self.inner.get(key)
    }

    /// Interprets the value under `key` as an array of ids.
    ///
    /// Elements that are not `i64`-representable numbers are skipped. A missing
    /// key or a non-array value yields an empty set.
    #[must_use]
    pub fn get_id_set(&self, key: &str) -> HashSet<i64> {
        match self.inner.get(key).and_then(Value::as_array) {
            Some(items) => items.iter().filter_map(Value::as_i64).collect(),
            None => HashSet::new(),
        }
    }

    /// Stores `value` as a JSON string under `key`, replacing any previous value.
    pub fn set_str(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let key: String = key.into();
        let text: String = value.into();
        self.inner.insert(key, Value::String(text));
    }

    /// Stores `value` as a JSON number under `key`, replacing any previous value.
    pub fn set_i64(&mut self, key: impl Into<String>, value: i64) {
        self.inner.insert(key.into(), Value::from(value));
    }

    /// Stores `value` as a JSON boolean under `key`, replacing any previous value.
    pub fn set_bool(&mut self, key: impl Into<String>, value: bool) {
        self.inner.insert(key.into(), Value::from(value));
    }

    /// Inserts a raw JSON value, returning whatever was previously under `key`.
    pub fn insert(&mut self, key: String, value: Value) -> Option<Value> {
        self.inner.insert(key, value)
    }

    /// Adds `id` to the JSON array under `key`, creating an empty array first
    /// when the key is absent.
    ///
    /// The id is not pushed again if an equal JSON value is already present. If
    /// `key` currently holds something other than an array, it is left
    /// untouched and `id` is not recorded.
    pub fn add_to_id_set(&mut self, key: &str, id: i64) {
        let slot = self
            .inner
            .entry(key.to_owned())
            .or_insert_with(|| Value::Array(Vec::new()));
        let Some(members) = slot.as_array_mut() else {
            return;
        };
        let candidate = Value::Number(id.into());
        if !members.contains(&candidate) {
            members.push(candidate);
        }
    }

    /// Removes every element equal to `id` (as `i64`) from the array under `key`.
    ///
    /// Does nothing when `key` is absent or does not hold an array.
    pub fn remove_from_id_set(&mut self, key: &str, id: i64) {
        let Some(members) = self.inner.get_mut(key).and_then(Value::as_array_mut) else {
            return;
        };
        members.retain(|member| member.as_i64() != Some(id));
    }

    /// Borrows the whole underlying map.
    #[must_use]
    pub fn raw(&self) -> &DefaultData {
        &*self.inner
    }

    /// Mutably borrows the whole underlying map.
    pub fn raw_mut(&mut self) -> &mut DefaultData {
        &mut *self.inner
    }

    /// `HashMap::entry` on the underlying map.
    pub fn entry(&mut self, key: String) -> std::collections::hash_map::Entry<'_, String, Value> {
        self.inner.entry(key)
    }

    /// Mutable access to the raw JSON value under `key`, if any.
    pub fn get_mut(&mut self, key: &str) -> Option<&mut Value> {
        self.inner.get_mut(key)
    }

    /// `true` when the map holds no keys.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.raw().is_empty()
    }

    /// Number of keys in the map.
    #[must_use]
    pub fn len(&self) -> usize {
        self.raw().len()
    }

    /// Removes `key`, returning its previous value if there was one.
    pub fn remove(&mut self, key: &str) -> Option<Value> {
        self.inner.remove(key)
    }
}
/// Lets a write guard be used anywhere a `&DefaultData` is expected.
impl std::ops::Deref for DataWriteGuard<'_> {
    type Target = DefaultData;

    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
impl std::ops::DerefMut for DataWriteGuard<'_> {
fn deref_mut(&mut self) -> &mut DefaultData {
&mut self.inner
}
}
/// Per-invocation context handed to handler callbacks.
///
/// Cloning is cheap for the shared parts: the bot and the three data stores are
/// `Arc`s, so clones observe and mutate the same underlying data.
#[derive(Debug, Clone)]
pub struct CallbackContext {
    /// Bot used by the `reply_*` / `answer_*` / `edit_*` convenience methods.
    bot: Arc<ExtBot>,
    /// Effective chat id of the triggering update, when it has one.
    chat_id: Option<i64>,
    /// Effective user id of the triggering update, when it has one.
    user_id: Option<i64>,
    /// Shared per-user data maps, keyed by user id.
    user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
    /// Shared per-chat data maps, keyed by chat id.
    chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
    /// Shared bot-wide data map.
    bot_data: Arc<RwLock<DefaultData>>,
    /// Match results; `match_result` returns the first entry. Presumably filled
    /// by a pattern-matching handler — confirm against the handlers.
    pub matches: Option<Vec<String>>,
    /// Named match results (presumably named regex groups) — confirm against the handlers.
    pub named_matches: Option<HashMap<String, String>>,
    /// Arguments, presumably parsed from a command by the dispatching handler.
    pub args: Option<Vec<String>>,
    /// Error being reported; set by `from_error`.
    pub error: Option<Arc<dyn std::error::Error + Send + Sync>>,
    /// Lazily created map of arbitrary extra values.
    extra: Option<HashMap<String, Value>>,
    /// Job queue attached with `with_job_queue`.
    #[cfg(feature = "job-queue")]
    pub job_queue: Option<Arc<JobQueue>>,
}
impl CallbackContext {
#[must_use]
pub fn new(
bot: Arc<ExtBot>,
chat_id: Option<i64>,
user_id: Option<i64>,
user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
bot_data: Arc<RwLock<DefaultData>>,
) -> Self {
Self {
bot,
chat_id,
user_id,
user_data_store,
chat_data_store,
bot_data,
matches: None,
named_matches: None,
args: None,
error: None,
extra: None,
#[cfg(feature = "job-queue")]
job_queue: None,
}
}
#[must_use]
pub fn from_update(
update: &Update,
bot: Arc<ExtBot>,
user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
bot_data: Arc<RwLock<DefaultData>>,
) -> Self {
let (chat_id, user_id) = extract_ids(update);
Self::new(
bot,
chat_id,
user_id,
user_data_store,
chat_data_store,
bot_data,
)
}
#[must_use]
pub fn from_error(
update: Option<&Update>,
error: Arc<dyn std::error::Error + Send + Sync>,
bot: Arc<ExtBot>,
user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
bot_data: Arc<RwLock<DefaultData>>,
) -> Self {
let (chat_id, user_id) = update.map_or((None, None), extract_ids);
let mut ctx = Self::new(
bot,
chat_id,
user_id,
user_data_store,
chat_data_store,
bot_data,
);
ctx.error = Some(error);
ctx
}
#[must_use]
pub fn bot(&self) -> &Arc<ExtBot> {
&self.bot
}
#[must_use]
pub fn chat_id(&self) -> Option<i64> {
self.chat_id
}
#[must_use]
pub fn user_id(&self) -> Option<i64> {
self.user_id
}
pub async fn bot_data(&self) -> DataReadGuard<'_> {
DataReadGuard {
inner: self.bot_data.read().await,
}
}
pub async fn bot_data_mut(&self) -> DataWriteGuard<'_> {
DataWriteGuard {
inner: self.bot_data.write().await,
}
}
pub async fn user_data(&self) -> Option<DefaultData> {
let uid = self.user_id?;
let store = self.user_data_store.read().await;
store.get(&uid).cloned()
}
pub async fn chat_data(&self) -> Option<DefaultData> {
let cid = self.chat_id?;
let store = self.chat_data_store.read().await;
store.get(&cid).cloned()
}
pub async fn set_user_data(&self, key: String, value: Value) -> bool {
let uid = match self.user_id {
Some(id) => id,
None => return false,
};
let mut store = self.user_data_store.write().await;
store
.entry(uid)
.or_insert_with(HashMap::new)
.insert(key, value);
true
}
pub async fn set_chat_data(&self, key: String, value: Value) -> bool {
let cid = match self.chat_id {
Some(id) => id,
None => return false,
};
let mut store = self.chat_data_store.write().await;
store
.entry(cid)
.or_insert_with(HashMap::new)
.insert(key, value);
true
}
#[must_use]
pub fn match_result(&self) -> Option<&str> {
self.matches
.as_ref()
.and_then(|m| m.first().map(String::as_str))
}
#[must_use]
pub fn extra(&self) -> Option<&HashMap<String, Value>> {
self.extra.as_ref()
}
pub fn extra_mut(&mut self) -> &mut HashMap<String, Value> {
self.extra.get_or_insert_with(HashMap::new)
}
pub fn set_extra(&mut self, key: String, value: Value) {
self.extra
.get_or_insert_with(HashMap::new)
.insert(key, value);
}
#[must_use]
pub fn get_extra(&self, key: &str) -> Option<&Value> {
self.extra.as_ref().and_then(|m| m.get(key))
}
pub async fn drop_callback_data(
&self,
callback_query_id: &str,
) -> Result<(), crate::callback_data_cache::InvalidCallbackData> {
let cache = self.bot.callback_data_cache().ok_or(
crate::callback_data_cache::InvalidCallbackData {
callback_data: None,
},
)?;
let mut guard = cache.write().await;
guard.drop_data(callback_query_id)
}
#[cfg(feature = "job-queue")]
pub fn with_job_queue(mut self, jq: Arc<JobQueue>) -> Self {
self.job_queue = Some(jq);
self
}
pub async fn reply_text(
&self,
update: &Update,
text: &str,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot().send_message(chat_id, text).await
}
pub async fn reply_html(
&self,
update: &Update,
text: &str,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot()
.send_message(chat_id, text)
.parse_mode("HTML")
.await
}
pub async fn reply_markdown_v2(
&self,
update: &Update,
text: &str,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot()
.send_message(chat_id, text)
.parse_mode("MarkdownV2")
.await
}
pub async fn reply_photo(
&self,
update: &Update,
photo: InputFile,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot().send_photo(chat_id, photo).await
}
pub async fn reply_document(
&self,
update: &Update,
document: InputFile,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot().send_document(chat_id, document).await
}
pub async fn reply_sticker(
&self,
update: &Update,
sticker: InputFile,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot().send_sticker(chat_id, sticker).await
}
pub async fn reply_location(
&self,
update: &Update,
latitude: f64,
longitude: f64,
) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
{
let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
})?;
self.bot().send_location(chat_id, latitude, longitude).await
}
pub async fn answer_callback_query(
&self,
update: &Update,
) -> Result<bool, rust_tg_bot_raw::error::TelegramError> {
let cq = update.callback_query().ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No callback query in update".into())
})?;
self.bot().answer_callback_query(&cq.id).await
}
pub async fn edit_callback_message_text(
&self,
update: &Update,
text: &str,
) -> Result<MessageOrBool, rust_tg_bot_raw::error::TelegramError> {
let cq = update.callback_query().ok_or_else(|| {
rust_tg_bot_raw::error::TelegramError::Network("No callback query in update".into())
})?;
if let Some(msg) = cq.message.as_deref() {
self.bot()
.edit_message_text(text)
.chat_id(msg.chat().id)
.message_id(msg.message_id())
.await
} else if let Some(ref iid) = cq.inline_message_id {
self.bot()
.edit_message_text(text)
.inline_message_id(iid)
.await
} else {
Err(rust_tg_bot_raw::error::TelegramError::Network(
"No message in callback query".into(),
))
}
}
}
/// Pulls `(effective chat id, effective user id)` out of `update`; either side
/// is `None` when the update does not carry that entity.
fn extract_ids(update: &Update) -> (Option<i64>, Option<i64>) {
    (
        update.effective_chat().map(|chat| chat.id),
        update.effective_user().map(|user| user.id),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ext_bot::test_support::mock_request;
    use rust_tg_bot_raw::bot::Bot;
    use rust_tg_bot_raw::error::TelegramError;

    /// An `ExtBot` backed by the crate's mock request layer.
    fn make_bot() -> Arc<ExtBot> {
        let bot = Bot::new("test", mock_request());
        Arc::new(ExtBot::from_bot(bot))
    }

    /// Fresh, empty user / chat / bot data stores.
    fn make_stores() -> (
        Arc<RwLock<HashMap<i64, DefaultData>>>,
        Arc<RwLock<HashMap<i64, DefaultData>>>,
        Arc<RwLock<DefaultData>>,
    ) {
        (
            Arc::new(RwLock::new(HashMap::new())),
            Arc::new(RwLock::new(HashMap::new())),
            Arc::new(RwLock::new(HashMap::new())),
        )
    }

    /// Deserializes a raw update fixture; panics on a malformed fixture.
    fn make_update(json_val: serde_json::Value) -> Update {
        serde_json::from_value(json_val).expect("test fixture must be a valid Update")
    }

    /// A context with no chat/user ids over empty stores.
    fn bare_context() -> CallbackContext {
        let (ud, cd, bd) = make_stores();
        CallbackContext::new(make_bot(), None, None, ud, cd, bd)
    }

    #[test]
    fn context_basic_creation() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let ctx = CallbackContext::new(bot.clone(), Some(42), Some(7), ud, cd, bd);
        assert_eq!(ctx.chat_id(), Some(42));
        assert_eq!(ctx.user_id(), Some(7));
        assert!(ctx.error.is_none());
        assert!(ctx.args.is_none());
        assert!(ctx.matches.is_none());
        assert!(ctx.named_matches.is_none());
        #[cfg(feature = "job-queue")]
        assert!(ctx.job_queue.is_none());
    }

    #[test]
    fn extract_ids_from_message_update() {
        let update = make_update(
            serde_json::json!({"update_id": 1, "message": {"message_id": 1, "date": 0, "chat": {"id": 100, "type": "private"}, "from": {"id": 200, "is_bot": false, "first_name": "Test"}}}),
        );
        let (chat_id, user_id) = extract_ids(&update);
        assert_eq!(chat_id, Some(100));
        assert_eq!(user_id, Some(200));
    }

    #[test]
    fn extract_ids_from_callback_query() {
        let update = make_update(
            serde_json::json!({"update_id": 2, "callback_query": {"id": "abc", "from": {"id": 300, "is_bot": false, "first_name": "U"}, "chat_instance": "ci", "message": {"message_id": 5, "date": 0, "chat": {"id": 400, "type": "group"}}}}),
        );
        let (chat_id, user_id) = extract_ids(&update);
        assert_eq!(chat_id, Some(400));
        assert_eq!(user_id, Some(300));
    }

    #[test]
    fn extract_ids_returns_none_for_empty() {
        let update = make_update(serde_json::json!({"update_id": 3}));
        let (chat_id, user_id) = extract_ids(&update);
        assert!(chat_id.is_none());
        assert!(user_id.is_none());
    }

    #[test]
    fn from_update_factory() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let update = make_update(
            serde_json::json!({"update_id": 1, "message": {"message_id": 1, "date": 0, "chat": {"id": 10, "type": "private"}, "from": {"id": 20, "is_bot": false, "first_name": "T"}}}),
        );
        let ctx = CallbackContext::from_update(&update, bot, ud, cd, bd);
        assert_eq!(ctx.chat_id(), Some(10));
        assert_eq!(ctx.user_id(), Some(20));
    }

    #[test]
    fn from_error_factory() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let err: Arc<dyn std::error::Error + Send + Sync> =
            Arc::new(std::io::Error::other("boom"));
        let ctx = CallbackContext::from_error(None, err, bot, ud, cd, bd);
        assert!(ctx.error.is_some());
        assert!(ctx.chat_id().is_none());
        assert!(ctx.user_id().is_none());
    }

    #[test]
    fn from_error_factory_extracts_ids_from_update() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let update = make_update(
            serde_json::json!({"update_id": 9, "message": {"message_id": 1, "date": 0, "chat": {"id": 11, "type": "private"}, "from": {"id": 22, "is_bot": false, "first_name": "E"}}}),
        );
        let err: Arc<dyn std::error::Error + Send + Sync> =
            Arc::new(std::io::Error::other("boom"));
        let ctx = CallbackContext::from_error(Some(&update), err, bot, ud, cd, bd);
        assert!(ctx.error.is_some());
        assert_eq!(ctx.chat_id(), Some(11));
        assert_eq!(ctx.user_id(), Some(22));
    }

    #[tokio::test]
    async fn bot_data_access() {
        let ctx = bare_context();
        {
            let mut guard = ctx.bot_data_mut().await;
            guard.insert("key".into(), Value::String("val".into()));
        }
        let guard = ctx.bot_data().await;
        assert_eq!(guard.get("key"), Some(&Value::String("val".into())));
    }

    #[tokio::test]
    async fn user_data_returns_none_without_user_id() {
        let ctx = bare_context();
        assert!(ctx.user_data().await.is_none());
    }

    #[tokio::test]
    async fn chat_data_returns_none_without_chat_id() {
        let ctx = bare_context();
        assert!(ctx.chat_data().await.is_none());
    }

    #[tokio::test]
    async fn user_data_returns_stored_snapshot() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let ctx = CallbackContext::new(bot, None, Some(7), ud, cd, bd);
        // A known user with no stored entry still yields `None`.
        assert!(ctx.user_data().await.is_none());
        assert!(
            ctx.set_user_data("lang".into(), Value::String("en".into()))
                .await
        );
        let snapshot = ctx
            .user_data()
            .await
            .expect("entry created by set_user_data");
        assert_eq!(snapshot.get("lang"), Some(&Value::String("en".into())));
    }

    #[tokio::test]
    async fn set_user_data_works() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let ctx = CallbackContext::new(bot, None, Some(42), ud.clone(), cd, bd);
        assert!(
            ctx.set_user_data("score".into(), Value::Number(100.into()))
                .await
        );
        let store = ud.read().await;
        assert_eq!(
            store.get(&42).unwrap().get("score"),
            Some(&Value::Number(100.into()))
        );
    }

    #[tokio::test]
    async fn set_chat_data_works() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let ctx = CallbackContext::new(bot, Some(10), None, ud, cd.clone(), bd);
        assert!(
            ctx.set_chat_data("topic".into(), Value::String("rust".into()))
                .await
        );
        let store = cd.read().await;
        assert_eq!(
            store.get(&10).unwrap().get("topic"),
            Some(&Value::String("rust".into()))
        );
    }

    #[tokio::test]
    async fn set_user_data_returns_false_without_user_id() {
        let ctx = bare_context();
        assert!(!ctx.set_user_data("k".into(), Value::Null).await);
    }

    #[tokio::test]
    async fn set_chat_data_returns_false_without_chat_id() {
        let bot = make_bot();
        let (ud, cd, bd) = make_stores();
        let ctx = CallbackContext::new(bot, None, None, ud, cd.clone(), bd);
        assert!(!ctx.set_chat_data("k".into(), Value::Null).await);
        assert!(cd.read().await.is_empty());
    }

    #[test]
    fn match_result_shortcut() {
        let mut ctx = bare_context();
        assert!(ctx.match_result().is_none());
        ctx.matches = Some(vec!["hello".into(), "world".into()]);
        assert_eq!(ctx.match_result(), Some("hello"));
    }

    #[test]
    fn extra_is_lazily_initialized() {
        let mut ctx = bare_context();
        assert!(ctx.extra().is_none());
        assert!(ctx.get_extra("missing").is_none());
        ctx.extra_mut()
            .insert("count".into(), Value::Number(1.into()));
        assert_eq!(ctx.get_extra("count"), Some(&Value::Number(1.into())));
        ctx.set_extra("name".into(), Value::String("Alice".into()));
        assert_eq!(
            ctx.extra().and_then(|extra| extra.get("name")),
            Some(&Value::String("Alice".into()))
        );
    }

    #[cfg(feature = "job-queue")]
    #[test]
    fn with_job_queue() {
        let ctx = bare_context();
        let jq = Arc::new(JobQueue::new());
        let ctx = ctx.with_job_queue(Arc::clone(&jq));
        let attached = ctx.job_queue.as_ref().expect("job queue attached");
        assert!(Arc::ptr_eq(attached, &jq));
    }

    #[tokio::test]
    async fn data_write_guard_typed_setters() {
        let ctx = bare_context();
        {
            let mut guard = ctx.bot_data_mut().await;
            guard.set_str("name", "Alice");
            guard.set_i64("score", 42);
            guard.set_bool("active", true);
        }
        let guard = ctx.bot_data().await;
        assert_eq!(guard.get_str("name"), Some("Alice"));
        assert_eq!(guard.get_i64("score"), Some(42));
        assert_eq!(guard.get_bool("active"), Some(true));
    }

    #[tokio::test]
    async fn data_write_guard_id_set_operations() {
        let ctx = bare_context();
        {
            let mut guard = ctx.bot_data_mut().await;
            guard.add_to_id_set("user_ids", 100);
            guard.add_to_id_set("user_ids", 200);
            // Adding an id that is already present must not duplicate it.
            guard.add_to_id_set("user_ids", 100);
        }
        let guard = ctx.bot_data().await;
        let ids = guard.get_id_set("user_ids");
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&100));
        assert!(ids.contains(&200));
        drop(guard);
        {
            let mut guard = ctx.bot_data_mut().await;
            guard.remove_from_id_set("user_ids", 100);
        }
        let guard = ctx.bot_data().await;
        let ids = guard.get_id_set("user_ids");
        assert_eq!(ids.len(), 1);
        assert!(ids.contains(&200));
    }

    #[tokio::test]
    async fn add_to_id_set_leaves_non_array_value_untouched() {
        let ctx = bare_context();
        let mut guard = ctx.bot_data_mut().await;
        guard.set_str("user_ids", "oops");
        guard.add_to_id_set("user_ids", 1);
        assert_eq!(guard.get_str("user_ids"), Some("oops"));
        assert!(guard.get_id_set("user_ids").is_empty());
    }

    #[tokio::test]
    async fn data_read_guard_empty_id_set() {
        let ctx = bare_context();
        let guard = ctx.bot_data().await;
        let ids = guard.get_id_set("nonexistent");
        assert!(ids.is_empty());
    }

    #[tokio::test]
    async fn data_guard_deref_to_hashmap() {
        let ctx = bare_context();
        {
            let mut guard = ctx.bot_data_mut().await;
            guard.set_str("key", "val");
        }
        let guard = ctx.bot_data().await;
        assert!(guard.contains_key("key"));
        assert_eq!(guard.len(), 1);
    }

    #[tokio::test]
    async fn reply_text_without_chat_fails_before_sending() {
        let ctx = bare_context();
        let update = make_update(serde_json::json!({"update_id": 3}));
        let result = ctx.reply_text(&update, "hi").await;
        assert!(matches!(result, Err(TelegramError::Network(_))));
    }

    #[tokio::test]
    async fn answer_callback_query_without_callback_query_fails() {
        let ctx = bare_context();
        let update = make_update(
            serde_json::json!({"update_id": 5, "message": {"message_id": 1, "date": 0, "chat": {"id": 10, "type": "private"}}}),
        );
        let result = ctx.answer_callback_query(&update).await;
        assert!(matches!(result, Err(TelegramError::Network(_))));
    }

    #[tokio::test]
    async fn edit_callback_message_text_without_target_fails() {
        let ctx = bare_context();
        let update = make_update(
            serde_json::json!({"update_id": 6, "callback_query": {"id": "abc", "from": {"id": 300, "is_bot": false, "first_name": "U"}, "chat_instance": "ci"}}),
        );
        let result = ctx.edit_callback_message_text(&update, "new").await;
        assert!(matches!(result, Err(TelegramError::Network(_))));
    }
}