use std::sync::Arc;
use tokio::sync::RwLock;
use rust_tg_bot_raw::bot::Bot;
use rust_tg_bot_raw::request::base::BaseRequest;
use crate::callback_data_cache::CallbackDataCache;
use crate::defaults::Defaults;
#[cfg(feature = "rate-limiter")]
use crate::rate_limiter::{DynRateLimiter, RateLimitedRequest};
/// Extension-layer wrapper around a raw [`Bot`].
///
/// Holds the raw bot plus optional extension state: [`Defaults`], a shared
/// arbitrary-callback-data cache and (with the `rate-limiter` feature) a rate
/// limiter handle. `ExtBot` implements `Deref<Target = Bot>`, so raw-bot methods
/// can be called on it directly.
pub struct ExtBot {
    /// The wrapped raw bot.
    bot: Bot,
    /// Optional defaults, stored as given and exposed via `ExtBot::defaults`.
    defaults: Option<Defaults>,
    /// Shared callback-data cache; `None` when arbitrary callback data was not requested.
    callback_data_cache: Option<Arc<RwLock<CallbackDataCache>>>,
    /// Rate limiter handle; its lifecycle is driven by `ExtBot::initialize`/`shutdown`.
    #[cfg(feature = "rate-limiter")]
    rate_limiter: Option<Arc<dyn DynRateLimiter>>,
    /// Feature-disabled stand-in so the `has_rate_limiter` API still compiles.
    #[cfg(not(feature = "rate-limiter"))]
    rate_limiter: Option<()>,
}
impl std::ops::Deref for ExtBot {
type Target = Bot;
fn deref(&self) -> &Bot {
&self.bot
}
}
/// Diagnostic formatting for [`ExtBot`].
///
/// The bot token is a credential and is never printed (it appears as
/// `"[REDACTED]"`, the same placeholder `ExtBotBuilder`'s `Debug` uses), so
/// logging an `ExtBot` with `{:?}` cannot leak it. The optional cache and rate
/// limiter are reported only as presence flags.
impl std::fmt::Debug for ExtBot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExtBot")
            .field("token", &"[REDACTED]")
            .field("defaults", &self.defaults)
            .field(
                "has_callback_data_cache",
                &self.callback_data_cache.is_some(),
            )
            .field("has_rate_limiter", &self.rate_limiter.is_some())
            .finish()
    }
}
impl ExtBot {
    /// Cache capacity substituted when arbitrary callback data is enabled with a
    /// requested `maxsize` of `0`.
    const DEFAULT_CALLBACK_DATA_CACHE_SIZE: usize = 1024;

    /// Creates the shared callback-data cache requested by `arbitrary_callback_data`.
    ///
    /// `None` means no cache. `Some(0)` selects `DEFAULT_CALLBACK_DATA_CACHE_SIZE`;
    /// any other `Some(n)` passes `n` to `CallbackDataCache::new`. Both cfg variants
    /// of `new` call this so their cache construction cannot drift apart.
    fn build_callback_data_cache(
        arbitrary_callback_data: Option<usize>,
    ) -> Option<Arc<RwLock<CallbackDataCache>>> {
        arbitrary_callback_data.map(|maxsize| {
            let effective = if maxsize == 0 {
                Self::DEFAULT_CALLBACK_DATA_CACHE_SIZE
            } else {
                maxsize
            };
            Arc::new(RwLock::new(CallbackDataCache::new(effective)))
        })
    }

    /// Assembles an `ExtBot` from an already-built [`Bot`].
    ///
    /// * `defaults` — stored as given.
    /// * `arbitrary_callback_data` — `None` disables the callback-data cache;
    ///   `Some(0)` uses a capacity of 1024; `Some(n)` uses `n`.
    /// * `rate_limiter` — stored as given. Unlike `ExtBotBuilder::build`, this does
    ///   NOT wrap `bot`'s request backend in a `RateLimitedRequest`; callers that
    ///   want requests throttled must do that wrapping themselves.
    #[cfg(feature = "rate-limiter")]
    #[must_use]
    pub(crate) fn new(
        bot: Bot,
        defaults: Option<Defaults>,
        arbitrary_callback_data: Option<usize>,
        rate_limiter: Option<Arc<dyn DynRateLimiter>>,
    ) -> Self {
        Self {
            bot,
            defaults,
            callback_data_cache: Self::build_callback_data_cache(arbitrary_callback_data),
            rate_limiter,
        }
    }

    /// Assembles an `ExtBot` from an already-built [`Bot`] (rate limiting disabled).
    ///
    /// * `defaults` — stored as given.
    /// * `arbitrary_callback_data` — `None` disables the callback-data cache;
    ///   `Some(0)` uses a capacity of 1024; `Some(n)` uses `n`.
    /// * `rate_limiter` — feature-disabled placeholder, stored as given.
    #[cfg(not(feature = "rate-limiter"))]
    #[must_use]
    pub(crate) fn new(
        bot: Bot,
        defaults: Option<Defaults>,
        arbitrary_callback_data: Option<usize>,
        rate_limiter: Option<()>,
    ) -> Self {
        Self {
            bot,
            defaults,
            callback_data_cache: Self::build_callback_data_cache(arbitrary_callback_data),
            rate_limiter,
        }
    }

    /// Wraps `bot` with no defaults, no callback-data cache and no rate limiter.
    #[must_use]
    pub fn from_bot(bot: Bot) -> Self {
        Self::new(bot, None, None, None)
    }

    /// Borrows the wrapped raw bot (equivalent to dereferencing the `ExtBot`).
    #[must_use]
    pub fn inner(&self) -> &Bot {
        &self.bot
    }

    /// Returns the wrapped bot's token.
    #[must_use]
    pub fn token(&self) -> &str {
        self.bot.token()
    }

    /// Returns the configured defaults, if any.
    #[must_use]
    pub fn defaults(&self) -> Option<&Defaults> {
        self.defaults.as_ref()
    }

    /// Returns the shared callback-data cache, if arbitrary callback data is enabled.
    #[must_use]
    pub fn callback_data_cache(&self) -> Option<&Arc<RwLock<CallbackDataCache>>> {
        self.callback_data_cache.as_ref()
    }

    /// Whether arbitrary callback data (and therefore the cache) is enabled.
    #[must_use]
    pub fn has_callback_data_cache(&self) -> bool {
        self.callback_data_cache.is_some()
    }

    /// Whether a rate limiter was configured.
    #[must_use]
    pub fn has_rate_limiter(&self) -> bool {
        self.rate_limiter.is_some()
    }

    /// Returns the configured rate limiter handle, if any.
    #[cfg(feature = "rate-limiter")]
    #[must_use]
    pub fn rate_limiter(&self) -> Option<&Arc<dyn DynRateLimiter>> {
        self.rate_limiter.as_ref()
    }

    /// Feature-disabled counterpart of `rate_limiter`; returns the stored placeholder.
    #[cfg(not(feature = "rate-limiter"))]
    #[must_use]
    pub fn rate_limiter(&self) -> Option<()> {
        self.rate_limiter
    }

    /// Starts an [`ExtBotBuilder`] for `token` using `request` as the HTTP backend.
    #[must_use]
    pub fn builder(token: impl Into<String>, request: Arc<dyn BaseRequest>) -> ExtBotBuilder {
        ExtBotBuilder::new(token, request)
    }

    /// Initializes extension-layer resources: the rate limiter, if one is configured.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; no error path exists today.
    ///
    /// NOTE(review): only the rate limiter is touched — the wrapped `Bot` and its
    /// request backend are not initialized here; confirm callers handle that.
    pub async fn initialize(&self) -> rust_tg_bot_raw::error::Result<()> {
        #[cfg(feature = "rate-limiter")]
        if let Some(ref rl) = self.rate_limiter {
            rl.initialize().await;
        }
        Ok(())
    }

    /// Shuts down extension-layer resources: the rate limiter, if one is configured.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; no error path exists today.
    ///
    /// NOTE(review): the wrapped `Bot` and its request backend are not shut down
    /// here; confirm callers handle that.
    pub async fn shutdown(&self) -> rust_tg_bot_raw::error::Result<()> {
        #[cfg(feature = "rate-limiter")]
        if let Some(ref rl) = self.rate_limiter {
            rl.shutdown().await;
        }
        Ok(())
    }
}
/// Builder for [`ExtBot`]; obtain one via `ExtBot::builder` or `ExtBotBuilder::new`.
pub struct ExtBotBuilder {
    /// Bot token handed to `Bot::new` by `build`.
    token: String,
    /// HTTP backend; `build` wraps it in a `RateLimitedRequest` when a rate limiter is set.
    request: Arc<dyn BaseRequest>,
    /// Custom API base URL.
    /// NOTE(review): never read by `build`, so setting it currently has no effect.
    base_url: Option<String>,
    /// Custom file-download base URL.
    /// NOTE(review): never read by `build`, so setting it currently has no effect.
    base_file_url: Option<String>,
    /// Defaults passed through to the built `ExtBot`.
    defaults: Option<Defaults>,
    /// Requested callback-data cache size; `Some(0)` selects the default size in `ExtBot::new`.
    arbitrary_callback_data: Option<usize>,
    /// Rate limiter to install; `build` wraps `request` with it.
    #[cfg(feature = "rate-limiter")]
    rate_limiter: Option<Arc<dyn DynRateLimiter>>,
    /// Feature-disabled placeholder mirroring the rate-limiter field.
    #[cfg(not(feature = "rate-limiter"))]
    rate_limiter: Option<()>,
}
impl ExtBotBuilder {
    /// Creates a builder for `token` that will send requests through `request`.
    /// Every optional setting starts unset.
    #[must_use]
    pub fn new(token: impl Into<String>, request: Arc<dyn BaseRequest>) -> Self {
        Self {
            token: token.into(),
            request,
            base_url: None,
            base_file_url: None,
            defaults: None,
            arbitrary_callback_data: None,
            rate_limiter: None,
        }
    }

    /// Records a custom API base URL.
    ///
    /// NOTE(review): `build` does not currently pass this value to the created
    /// `Bot`, so it has no effect yet.
    #[must_use]
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Records a custom file-download base URL.
    ///
    /// NOTE(review): `build` does not currently pass this value to the created
    /// `Bot`, so it has no effect yet.
    #[must_use]
    pub fn base_file_url(mut self, url: impl Into<String>) -> Self {
        self.base_file_url = Some(url.into());
        self
    }

    /// Sets the defaults handed to the built `ExtBot`.
    #[must_use]
    pub fn defaults(mut self, defaults: Defaults) -> Self {
        self.defaults = Some(defaults);
        self
    }

    /// Enables arbitrary callback data with a cache of `maxsize` entries
    /// (`0` selects the default size chosen by `ExtBot::new`).
    #[must_use]
    pub fn arbitrary_callback_data(mut self, maxsize: usize) -> Self {
        self.arbitrary_callback_data = Some(maxsize);
        self
    }

    /// Installs a rate limiter; `build` routes all requests through it.
    #[cfg(feature = "rate-limiter")]
    #[must_use]
    pub fn rate_limiter(mut self, rl: Arc<dyn DynRateLimiter>) -> Self {
        self.rate_limiter = Some(rl);
        self
    }

    /// Feature-disabled counterpart of `rate_limiter`; only records the placeholder.
    #[cfg(not(feature = "rate-limiter"))]
    #[must_use]
    pub fn rate_limiter(mut self, _rl: ()) -> Self {
        self.rate_limiter = Some(());
        self
    }

    /// Consumes the builder and produces the configured [`ExtBot`].
    ///
    /// With a rate limiter set, the request backend is wrapped in a
    /// `RateLimitedRequest`, and the limiter handle is also stored on the `ExtBot`
    /// so `ExtBot::initialize`/`shutdown` can drive its lifecycle.
    #[must_use]
    pub fn build(self) -> ExtBot {
        // Match by value: the option and the request are moved out of `self`
        // instead of being borrowed, cloned, and then moved.
        #[cfg(feature = "rate-limiter")]
        let (request, rate_limiter) = match self.rate_limiter {
            Some(rl) => {
                let wrapped: Arc<dyn BaseRequest> =
                    Arc::new(RateLimitedRequest::new(self.request, Arc::clone(&rl)));
                (wrapped, Some(rl))
            }
            None => (self.request, None),
        };
        #[cfg(not(feature = "rate-limiter"))]
        let (request, rate_limiter) = (self.request, self.rate_limiter);

        // NOTE(review): `self.base_url` / `self.base_file_url` are not applied here.
        let bot = Bot::new(&self.token, request);
        ExtBot::new(
            bot,
            self.defaults,
            self.arbitrary_callback_data,
            rate_limiter,
        )
    }
}
/// Diagnostic formatting for [`ExtBotBuilder`].
///
/// Shows the pending configuration so a misconfigured builder can be diagnosed
/// from logs. The token is always redacted. The custom URLs are reported only as
/// presence flags, because a base URL may embed credentials.
impl std::fmt::Debug for ExtBotBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExtBotBuilder")
            .field("token", &"[REDACTED]")
            .field("has_base_url", &self.base_url.is_some())
            .field("has_base_file_url", &self.base_file_url.is_some())
            .field("defaults", &self.defaults)
            .field("arbitrary_callback_data", &self.arbitrary_callback_data)
            .field("has_rate_limiter", &self.rate_limiter.is_some())
            .finish()
    }
}
#[cfg(test)]
pub(crate) mod test_support {
    use std::time::Duration;

    use rust_tg_bot_raw::request::base::{HttpMethod, TimeoutOverride};
    use rust_tg_bot_raw::request::request_data::RequestData;

    use super::*;

    /// Body every mocked call answers with: a successful response whose result is empty.
    const EMPTY_OK_BODY: &[u8] = br#"{"ok":true,"result":[]}"#;

    /// The `(status, body)` pair returned by both request entry points of the mock.
    fn canned_ok() -> (u16, bytes::Bytes) {
        (200, bytes::Bytes::from_static(EMPTY_OK_BODY))
    }

    /// Request backend that performs no I/O and answers every call with `canned_ok()`.
    #[derive(Debug)]
    pub struct MockRequest;

    #[async_trait::async_trait]
    impl BaseRequest for MockRequest {
        async fn initialize(&self) -> rust_tg_bot_raw::error::Result<()> {
            Ok(())
        }

        async fn shutdown(&self) -> rust_tg_bot_raw::error::Result<()> {
            Ok(())
        }

        fn default_read_timeout(&self) -> Option<Duration> {
            Some(Duration::from_secs(5))
        }

        async fn do_request(
            &self,
            _url: &str,
            _method: HttpMethod,
            _request_data: Option<&RequestData>,
            _timeouts: TimeoutOverride,
        ) -> rust_tg_bot_raw::error::Result<(u16, bytes::Bytes)> {
            Ok(canned_ok())
        }

        async fn do_request_json_bytes(
            &self,
            _url: &str,
            _body: &[u8],
            _timeouts: TimeoutOverride,
        ) -> rust_tg_bot_raw::error::Result<(u16, bytes::Bytes)> {
            Ok(canned_ok())
        }
    }

    /// Returns a fresh `MockRequest` as a shareable trait object.
    pub fn mock_request() -> Arc<dyn BaseRequest> {
        Arc::new(MockRequest)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_support::*;

    /// Wraps a mock-backed `Bot` for `token` with no extensions configured.
    fn plain_ext(token: &str) -> ExtBot {
        ExtBot::from_bot(Bot::new(token, mock_request()))
    }

    /// Builds an `ExtBot` through the builder with a no-op rate limiter attached.
    #[cfg(feature = "rate-limiter")]
    fn ext_with_noop_limiter(token: &str) -> ExtBot {
        use crate::rate_limiter::NoRateLimiter;
        let limiter: Arc<dyn DynRateLimiter> = Arc::new(NoRateLimiter);
        ExtBot::builder(token, mock_request())
            .rate_limiter(limiter)
            .build()
    }

    /// Checks that none of the optional extensions are enabled on `ext`.
    fn assert_no_extensions(ext: &ExtBot) {
        assert!(ext.defaults().is_none());
        assert!(!ext.has_callback_data_cache());
        assert!(!ext.has_rate_limiter());
    }

    #[test]
    fn ext_bot_creation() {
        let ext = plain_ext("test_token");
        assert_eq!(ext.token(), "test_token");
        assert_no_extensions(&ext);
    }

    #[test]
    fn ext_bot_with_callback_cache() {
        let raw = Bot::new("token", mock_request());
        let ext = ExtBot::new(raw, None, Some(512), None);
        assert!(ext.has_callback_data_cache());
    }

    #[test]
    fn ext_bot_with_defaults() {
        let defaults = Defaults::builder().parse_mode("HTML").build();
        let raw = Bot::new("token", mock_request());
        let ext = ExtBot::new(raw, Some(defaults), None, None);
        let stored = ext.defaults().expect("defaults were passed to ExtBot::new");
        assert_eq!(stored.parse_mode(), Some("HTML"));
    }

    #[test]
    fn ext_bot_builder() {
        let ext = ExtBot::builder("my_token", mock_request())
            .arbitrary_callback_data(256)
            .build();
        assert_eq!(ext.token(), "my_token");
        assert!(ext.has_callback_data_cache());
    }

    #[tokio::test]
    async fn ext_bot_lifecycle() {
        let ext = plain_ext("token");
        assert!(ext.initialize().await.is_ok());
        assert!(ext.shutdown().await.is_ok());
    }

    #[test]
    fn ext_bot_debug() {
        let rendered = format!("{:?}", plain_ext("token"));
        assert!(rendered.contains("ExtBot"));
        assert!(rendered.contains("token"));
    }

    #[test]
    fn ext_bot_from_bot_convenience() {
        let ext = plain_ext("tk");
        assert_eq!(ext.token(), "tk");
        assert_no_extensions(&ext);
    }

    #[test]
    fn ext_bot_deref_provides_bot_methods() {
        let ext = plain_ext("deref_token");
        // Deref coercion: `&ExtBot` -> `&Bot` through `Deref<Target = Bot>`.
        let raw: &Bot = &ext;
        let via_deref = raw.token();
        assert_eq!(via_deref, "deref_token");
        assert_eq!(ext.token(), via_deref);
    }

    #[cfg(feature = "rate-limiter")]
    #[test]
    fn ext_bot_builder_with_rate_limiter() {
        let ext = ext_with_noop_limiter("rl_token");
        assert_eq!(ext.token(), "rl_token");
        assert!(ext.has_rate_limiter());
        assert!(ext.rate_limiter().is_some());
    }

    #[cfg(feature = "rate-limiter")]
    #[test]
    fn ext_bot_builder_without_rate_limiter() {
        let ext = ExtBot::builder("no_rl", mock_request()).build();
        assert!(!ext.has_rate_limiter());
        assert!(ext.rate_limiter().is_none());
    }

    #[cfg(feature = "rate-limiter")]
    #[tokio::test]
    async fn ext_bot_lifecycle_with_rate_limiter() {
        let ext = ext_with_noop_limiter("rl_lc");
        assert!(ext.initialize().await.is_ok());
        assert!(ext.shutdown().await.is_ok());
    }
}