1#[cfg(feature = "grpc")]
2pub mod grpc;
3#[cfg(feature = "longpoll")]
4pub mod longpoll;
5pub mod net;
6#[cfg(feature = "ratelimit")]
7pub mod ratelimit;
8#[cfg(feature = "webhook")]
9pub mod webhook;
10
11use crate::api::types::*;
12#[cfg(feature = "ratelimit")]
13use crate::bot::ratelimit::RateLimiter;
14use crate::error::{BotError, Result};
15use net::ConnectionPool;
16use net::*;
17use once_cell::sync::OnceCell;
18use reqwest::Url;
19use serde::Serialize;
20use std::fmt;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicU32, Ordering};
23#[cfg(feature = "ratelimit")]
24use tokio::sync::Mutex;
25use tracing::debug;
26
/// Handle used to build request URLs and send requests to the bot API.
///
/// Cloning is cheap for the `Arc`-backed fields, and clones share `event_id`.
///
/// NOTE(review): `connection_pool` and `rate_limiter` are `OnceCell`s. A clone taken
/// before a cell is first initialized gets its own empty cell. It will therefore build a
/// separate connection pool / rate limiter instead of sharing the original's. Confirm
/// that this is intended.
#[derive(Clone)]
pub struct Bot {
    /// HTTP connection pool, initialized on first use via `get_or_init`.
    pub(crate) connection_pool: OnceCell<ConnectionPool>,
    /// API token. `get_parsed_url` appends it as the `token` query parameter.
    pub(crate) token: Arc<str>,
    /// Base API URL. Its path is replaced by the request path in `get_parsed_url`.
    pub(crate) base_api_url: Url,
    /// Version-specific path prefix (e.g. `bot/v1/`) that `set_path` prepends to method names.
    pub(crate) base_api_path: Arc<str>,
    /// Last processed event id. It sits behind an `Arc`, so clones see the same value.
    pub(crate) event_id: Arc<AtomicU32>,
    /// Per-chat rate limiter, initialized on first use via `get_or_init`.
    #[cfg(feature = "ratelimit")]
    pub(crate) rate_limiter: OnceCell<Arc<Mutex<RateLimiter>>>,
}
46
47impl fmt::Debug for Bot {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 f.debug_struct("Bot")
50 .field("connection_pool", &"<pool>")
51 .field("token", &self.token)
52 .field("base_api_url", &self.base_api_url)
53 .field("base_api_path", &self.base_api_path)
54 .field("event_id", &self.event_id)
55 .finish()
56 }
57}
58
59impl Bot {
60 pub fn new(version: APIVersionUrl) -> Self {
80 debug!("Creating new bot with API version: {:?}", version);
81
82 let token = get_env_token().expect("Failed to get token from environment");
83 debug!("Token successfully obtained from environment");
84
85 let base_api_url = get_env_url().expect("Failed to get API URL from environment");
86 debug!("API URL successfully obtained from environment");
87
88 Self::with_params(&version, token.as_str(), base_api_url.as_str())
89 .expect("Failed to create bot")
90 }
91
92 pub fn with_params(version: &APIVersionUrl, token: &str, api_url: &str) -> Result<Self> {
133 debug!("Creating new bot with API version: {:?}", version);
134 debug!("Using provided token and API URL");
135
136 let base_api_url = Url::parse(api_url).map_err(BotError::Url)?;
137
138 match base_api_url.scheme() {
139 "http" | "https" => {
140 debug!("Base API URL scheme is valid: {}", base_api_url.scheme());
141 }
142 _ => {
143 return Err(BotError::Url(url::ParseError::InvalidIpv4Address));
144 }
145 }
146 debug!("API URL successfully parsed");
147
148 let base_api_path = version.to_string();
149 debug!("Set API base path: {}", base_api_path);
150
151 Ok(Self {
152 connection_pool: OnceCell::new(),
153 token: Arc::<str>::from(token),
154 base_api_url,
155 base_api_path: Arc::<str>::from(base_api_path),
156 event_id: Arc::new(AtomicU32::new(0)),
157 #[cfg(feature = "ratelimit")]
158 rate_limiter: OnceCell::new(),
159 })
160 }
161
162 pub fn get_last_event_id(&self) -> EventId {
164 self.event_id.load(Ordering::Acquire)
165 }
166
167 pub fn set_last_event_id(&self, id: EventId) {
171 self.event_id.store(id, Ordering::Release);
172 }
173
174 pub fn set_path(&self, path: &str) -> String {
177 let mut full_path = self.base_api_path.as_ref().to_string();
178 full_path.push_str(path);
179 full_path
180 }
181
182 pub fn get_parsed_url(&self, path: String, query: String) -> Result<Url> {
191 let mut url = self.base_api_url.clone();
192 url.set_path(&path);
193 url.set_query(Some(&query));
194 url.query_pairs_mut().append_pair("token", &self.token);
195 Ok(url)
196 }
197
198 #[tracing::instrument(skip(self, message))]
216 pub async fn send_api_request<Rq>(&self, message: Rq) -> Result<<Rq>::ResponseType>
217 where
218 Rq: BotRequest + Serialize + std::fmt::Debug,
219 {
220 debug!("Starting send_api_request");
221 #[cfg(feature = "ratelimit")]
223 {
224 if let Some(chat_id) = message.get_chat_id() {
225 let mut rate_limiter = self
226 .rate_limiter
227 .get_or_init(|| Arc::new(Mutex::new(RateLimiter::default())))
228 .lock()
229 .await;
230 if !rate_limiter.wait_if_needed(chat_id).await {
231 return Err(BotError::Validation(
232 "Rate limit exceeded for this chat".to_string(),
233 ));
234 }
235 } else {
236 debug!("No chat_id found in message");
237 }
238 }
239
240 let query = serde_url_params::to_string(&message)?;
241 let url = self.get_parsed_url(self.set_path(<Rq>::METHOD), query.to_owned())?;
242
243 debug!("Request URL: {}", url.path());
244
245 let body = match <Rq>::HTTP_METHOD {
246 HTTPMethod::POST => {
247 debug!(
248 "Sending POST request {:?} {:?}",
249 message,
250 message.get_multipart()
251 );
252 let form = file_to_multipart(message.get_multipart()).await?;
253
254 self.connection_pool
255 .get_or_init(ConnectionPool::optimized)
256 .post_file(url, form)
257 .await?
258 }
259 HTTPMethod::GET => {
260 debug!("Sending GET request");
261 self.connection_pool
262 .get_or_init(ConnectionPool::optimized)
263 .get_text(url)
264 .await?
265 }
266 };
267
268 let response: ApiResponseWrapper<<Rq>::ResponseType> = serde_json::from_str(&body)?;
269 response.into()
270 }
271}
272
273impl Default for Bot {
274 fn default() -> Self {
275 Self::new(APIVersionUrl::V1)
276 }
277}
278
279impl Bot {
280 pub fn with_default_version(token: &str, api_url: &str) -> Result<Self> {
310 Self::with_params(&APIVersionUrl::V1, token, api_url)
311 }
312}
313
314fn get_env_token() -> Result<String> {
315 std::env::var(VKTEAMS_BOT_API_TOKEN).map_err(BotError::from)
316}
317
318fn get_env_url() -> Result<String> {
319 std::env::var(VKTEAMS_BOT_API_URL).map_err(BotError::from)
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Url;
    use std::sync::Arc;

    #[test]
    fn test_bot_with_params_valid() {
        // Exercise the real constructor rather than building the struct by hand.
        let bot = Bot::with_params(&APIVersionUrl::V1, "test_token", "https://example.com/api")
            .expect("a valid https URL must be accepted");

        assert_eq!(bot.token.as_ref(), "test_token");
        assert_eq!(bot.base_api_url, Url::parse("https://example.com/api").unwrap());
        assert_eq!(bot.base_api_path.as_ref(), "bot/v1/");
        assert_eq!(bot.get_last_event_id(), 0);
    }

    #[test]
    fn test_bot_with_params_invalid_url() {
        let result = Bot::with_params(&APIVersionUrl::V1, "test_token", "");
        assert!(
            matches!(result, Err(BotError::Url(_))),
            "an empty URL must be rejected with BotError::Url"
        );
    }

    #[test]
    fn test_bot_with_default_version_valid() {
        let bot = Bot::with_default_version("test_token", "https://example.com/api")
            .expect("a valid https URL must be accepted");

        assert_eq!(bot.token.as_ref(), "test_token");
        assert_eq!(bot.base_api_path.as_ref(), "bot/v1/");
    }

    #[test]
    fn test_bot_with_default_version_invalid_url() {
        let result = Bot::with_default_version("test_token", "not a url");
        assert!(
            matches!(result, Err(BotError::Url(_))),
            "an unparsable URL must be rejected with BotError::Url"
        );
    }

    #[test]
    fn test_set_and_get_last_event_id() {
        let url = Url::parse("https://example.com/api").unwrap();
        let token: Arc<str> = Arc::from("test_token");
        let bot = Bot {
            connection_pool: OnceCell::new(),
            token: token.clone(),
            base_api_url: url.clone(),
            base_api_path: Arc::from("/api"),
            event_id: Arc::new(AtomicU32::new(0u32)),
            #[cfg(feature = "ratelimit")]
            rate_limiter: OnceCell::new(),
        };

        bot.set_last_event_id(42u32);
        assert_eq!(bot.get_last_event_id(), 42u32);
    }

    #[tokio::test]
    async fn test_get_and_set_last_event_id_sync() {
        let bot =
            Bot::with_params(&APIVersionUrl::V1, "test_token", "https://example.com").unwrap();

        assert_eq!(bot.get_last_event_id(), 0);

        bot.set_last_event_id(123);
        assert_eq!(bot.get_last_event_id(), 123);

        bot.set_last_event_id(456);
        assert_eq!(bot.get_last_event_id(), 456);
    }

    #[test]
    fn test_set_path() {
        let bot =
            Bot::with_params(&APIVersionUrl::V1, "test_token", "https://example.com").unwrap();

        let path = bot.set_path("messages/sendText");
        assert_eq!(path, "bot/v1/messages/sendText");

        let path2 = bot.set_path("chats/getInfo");
        assert_eq!(path2, "bot/v1/chats/getInfo");
    }

    #[test]
    fn test_get_parsed_url_basic() {
        let bot =
            Bot::with_params(&APIVersionUrl::V1, "test_token", "https://api.example.com").unwrap();

        let path = "/bot/v1/messages/sendText".to_string();
        let query = "chatId=test@chat.agent&text=hello".to_string();

        let result = bot.get_parsed_url(path, query);
        assert!(result.is_ok());

        let url = result.unwrap();
        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("api.example.com"));
        assert_eq!(url.path(), "/bot/v1/messages/sendText");
        assert!(url.query().unwrap().contains("token=test_token"));
        assert!(url.query().unwrap().contains("chatId=test@chat.agent"));
        assert!(url.query().unwrap().contains("text=hello"));
    }

    #[test]
    fn test_get_parsed_url_with_special_chars() {
        let bot = Bot::with_params(
            &APIVersionUrl::V1,
            "special_token",
            "https://api.example.com",
        )
        .unwrap();

        let path = "bot/v1/messages/sendText".to_string();
        let query = "text=hello world&chatId=test+chat".to_string();

        let result = bot.get_parsed_url(path, query);
        assert!(result.is_ok());

        let url = result.unwrap();
        assert!(url.query().unwrap().contains("token=special_token"));
    }

    #[test]
    fn test_bot_debug_format() {
        let bot = Bot::with_params(
            &APIVersionUrl::V1,
            "debug_token",
            "https://debug.example.com",
        )
        .unwrap();
        let debug_str = format!("{bot:?}");

        assert!(debug_str.contains("Bot"));
        assert!(debug_str.contains("debug_token"));
        assert!(debug_str.contains("debug.example.com"));
        assert!(debug_str.contains("<pool>"));
    }

    #[test]
    fn test_bot_clone() {
        let bot1 = Bot::with_params(
            &APIVersionUrl::V1,
            "clone_token",
            "https://clone.example.com",
        )
        .unwrap();
        let bot2 = bot1.clone();

        assert_eq!(bot1.token, bot2.token);
        assert_eq!(bot1.base_api_url, bot2.base_api_url);
        assert_eq!(bot1.base_api_path, bot2.base_api_path);
    }

    #[test]
    fn test_bot_with_default_version() {
        let result = Bot::with_default_version("default_token", "https://default.example.com");
        assert!(result.is_ok());

        let bot = result.unwrap();
        assert_eq!(bot.token.as_ref(), "default_token");
        assert_eq!(bot.base_api_path.as_ref(), "bot/v1/");
        assert_eq!(bot.base_api_url.as_str(), "https://default.example.com/");
    }

    #[test]
    fn test_bot_with_params_invalid_urls() {
        let invalid_urls = [
            "",
            "not-a-url",
            "ftp://invalid-scheme.com",
            "://missing-scheme.com",
        ];

        for invalid_url in invalid_urls {
            let result = Bot::with_params(&APIVersionUrl::V1, "token", invalid_url);
            assert!(
                matches!(result, Err(BotError::Url(_))),
                "Expected URL error for: {invalid_url}"
            );
        }
    }

    #[test]
    fn test_bot_with_empty_token() {
        let result = Bot::with_params(&APIVersionUrl::V1, "", "https://example.com");
        assert!(result.is_ok());
        let bot = result.unwrap();
        assert_eq!(bot.token.as_ref(), "");
    }

    #[tokio::test]
    async fn test_concurrent_event_id_access() {
        let bot = Bot::with_params(
            &APIVersionUrl::V1,
            "concurrent_token",
            "https://example.com",
        )
        .unwrap();

        let writer = bot.clone();
        let handle1 = tokio::spawn(async move {
            for i in 0..100 {
                writer.set_last_event_id(i);
                tokio::task::yield_now().await;
            }
        });

        let reader = bot.clone();
        let handle2 = tokio::spawn(async move {
            for _ in 0..100 {
                // Every observed value must be one the writer actually stored.
                assert!(reader.get_last_event_id() < 100);
                tokio::task::yield_now().await;
            }
        });

        // Surface panics from either task instead of silently discarding JoinErrors.
        let (writer_result, reader_result) = tokio::join!(handle1, handle2);
        writer_result.expect("writer task panicked");
        reader_result.expect("reader task panicked");

        // Clones share the same counter, so the writer's final store is visible here.
        assert_eq!(bot.get_last_event_id(), 99);
    }
}