use std::time::Duration;
use serde_json::Value;
use tracing::debug;
use crate::error::{Result, TelegramError};
use super::request_data::RequestData;
/// HTTP verbs understood by the request layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HttpMethod {
    /// The `POST` verb.
    Post,
    /// The `GET` verb.
    Get,
}

impl HttpMethod {
    /// Returns the upper-case wire token for this verb (`"POST"` or `"GET"`).
    pub fn as_str(self) -> &'static str {
        match self {
            HttpMethod::Get => "GET",
            HttpMethod::Post => "POST",
        }
    }
}

impl std::fmt::Display for HttpMethod {
    /// Writes the same token as [`HttpMethod::as_str`]; width and fill flags
    /// are not applied because the token is written verbatim.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = HttpMethod::as_str(*self);
        formatter.write_str(token)
    }
}
/// Per-request timeout overrides, one slot per phase of a request.
///
/// NOTE(review): the nested `Option` presumably distinguishes "no override"
/// (outer `None`) from an explicit override whose inner `None` disables the
/// timeout — confirm against the code that resolves these values.
#[derive(Debug, Clone, Copy, Default)]
pub struct TimeoutOverride {
    /// Override for the connect timeout.
    pub connect: Option<Option<Duration>>,
    /// Override for the read timeout.
    pub read: Option<Option<Duration>>,
    /// Override for the write timeout.
    pub write: Option<Option<Duration>>,
    /// Override for the pool timeout.
    pub pool: Option<Option<Duration>>,
}

impl TimeoutOverride {
    /// Builds a value in which every slot is the outer `None`.
    ///
    /// Produces the same value as `Default::default()`, but is usable in
    /// `const` contexts.
    pub const fn default_none() -> Self {
        const UNSET: Option<Option<Duration>> = None;
        Self {
            connect: UNSET,
            read: UNSET,
            write: UNSET,
            pool: UNSET,
        }
    }
}
/// Concrete timeout values for one request, one optional [`Duration`] per
/// phase.
///
/// NOTE(review): presumably the result of combining a `TimeoutOverride` with
/// an implementation's defaults, with `None` meaning "no limit" — the code
/// that builds this type is not in this file, so confirm there.
#[derive(Debug, Clone, Copy)]
pub struct ResolvedTimeouts {
    /// Connect timeout.
    pub connect: Option<Duration>,
    /// Read timeout.
    pub read: Option<Duration>,
    /// Write timeout.
    pub write: Option<Duration>,
    /// Pool timeout.
    pub pool: Option<Duration>,
}
#[async_trait::async_trait]
pub trait BaseRequest: Send + Sync {
async fn initialize(&self) -> Result<()>;
async fn shutdown(&self) -> Result<()>;
fn default_read_timeout(&self) -> Option<Duration>;
async fn do_request(
&self,
url: &str,
method: HttpMethod,
request_data: Option<&RequestData>,
timeouts: TimeoutOverride,
) -> Result<(u16, bytes::Bytes)>;
async fn do_request_json_bytes(
&self,
url: &str,
body: &[u8],
timeouts: TimeoutOverride,
) -> Result<(u16, bytes::Bytes)>;
async fn post(
&self,
url: &str,
request_data: Option<&RequestData>,
timeouts: TimeoutOverride,
) -> Result<Value> {
let raw = self
.request_wrapper(url, HttpMethod::Post, request_data, timeouts)
.await?;
let json_data = parse_json_payload_impl(&raw)?;
json_data
.get("result")
.cloned()
.ok_or_else(|| TelegramError::Network("Missing 'result' field in API response".into()))
}
async fn post_json(&self, url: &str, body: &[u8], timeouts: TimeoutOverride) -> Result<Value> {
let (code, payload) = self.do_request_json_bytes(url, body, timeouts).await?;
if (200..=299).contains(&code) {
let json_data = parse_json_payload_impl(&payload)?;
return json_data.get("result").cloned().ok_or_else(|| {
TelegramError::Network("Missing 'result' field in API response".into())
});
}
let (message, migrate_chat_id, retry_after, extra_params) =
parse_error_body(&payload, code);
if let Some(new_chat_id) = migrate_chat_id {
return Err(TelegramError::ChatMigrated { new_chat_id });
}
if let Some(secs) = retry_after {
return Err(TelegramError::RetryAfter {
retry_after: Duration::from_secs(secs),
});
}
let full_message = if let Some(params) = extra_params {
format!("{message}. The server response contained unknown parameters: {params}")
} else {
message
};
let err = match code {
403 => TelegramError::Forbidden(full_message),
401 | 404 => TelegramError::InvalidToken(full_message),
400 => TelegramError::BadRequest(full_message),
409 => TelegramError::Conflict(full_message),
_ => TelegramError::Network(full_message),
};
Err(err)
}
async fn retrieve(&self, url: &str, timeouts: TimeoutOverride) -> Result<bytes::Bytes> {
self.request_wrapper(url, HttpMethod::Get, None, timeouts)
.await
}
async fn request_wrapper(
&self,
url: &str,
method: HttpMethod,
request_data: Option<&RequestData>,
timeouts: TimeoutOverride,
) -> Result<bytes::Bytes> {
let (code, payload) = match self.do_request(url, method, request_data, timeouts).await {
Ok(pair) => pair,
Err(e) => return Err(e),
};
if (200..=299).contains(&code) {
return Ok(payload);
}
let (message, migrate_chat_id, retry_after, extra_params) =
parse_error_body(&payload, code);
if let Some(new_chat_id) = migrate_chat_id {
return Err(TelegramError::ChatMigrated { new_chat_id });
}
if let Some(secs) = retry_after {
return Err(TelegramError::RetryAfter {
retry_after: Duration::from_secs(secs),
});
}
let full_message = if let Some(params) = extra_params {
format!("{message}. The server response contained unknown parameters: {params}")
} else {
message
};
let err = match code {
403 => TelegramError::Forbidden(full_message),
401 | 404 => TelegramError::InvalidToken(full_message),
400 => TelegramError::BadRequest(full_message),
409 => TelegramError::Conflict(full_message),
_ => TelegramError::Network(full_message),
};
Err(err)
}
fn parse_json_payload(&self, payload: &[u8]) -> Result<Value> {
parse_json_payload_impl(payload)
}
}
/// Decodes a response body as JSON.
///
/// The bytes are first converted to text with `String::from_utf8_lossy`, so
/// invalid UTF-8 sequences are replaced rather than rejected up front.
///
/// # Errors
///
/// Returns [`TelegramError::Network`] when the text is not valid JSON; the
/// offending payload is logged at debug level.
pub fn parse_json_payload_impl(payload: &[u8]) -> Result<Value> {
    let text = String::from_utf8_lossy(payload);
    match serde_json::from_str::<Value>(&text) {
        Ok(parsed) => Ok(parsed),
        Err(e) => {
            debug!("Cannot parse server response as JSON: {e} payload={text:?}");
            Err(TelegramError::Network(format!("Invalid server response: {e}")))
        }
    }
}
/// Extracts error details from the body of a failed response.
///
/// Returns `(message, migrate_to_chat_id, retry_after, extra)`:
/// * `message` — the body's `"description"` string, or the status phrase for
///   `code` when it is absent or not a string. If the body is not JSON at all,
///   the status phrase followed by a note quoting the raw body.
/// * `migrate_to_chat_id` / `retry_after` — the same-named members of
///   `"parameters"`, when present and representable as `i64` / `u64`.
/// * `extra` — the remaining `"parameters"` members serialized as a JSON
///   object, or `None` when there are none.
fn parse_error_body(
    payload: &[u8],
    code: u16,
) -> (String, Option<i64>, Option<u64>, Option<String>) {
    let fallback_message = http_status_phrase(code);

    let body = match parse_json_payload_impl(payload) {
        Ok(parsed) => parsed,
        Err(_) => {
            let raw = String::from_utf8_lossy(payload);
            return (
                format!("{fallback_message}. Parsing the server response {raw:?} failed"),
                None,
                None,
                None,
            );
        }
    };

    let description = match body.get("description").and_then(Value::as_str) {
        Some(found) => found.to_owned(),
        None => fallback_message,
    };

    let parameters = body.get("parameters");
    let migrate_to_chat_id = parameters.and_then(|p| p.get("migrate_to_chat_id")?.as_i64());
    let retry_after = parameters.and_then(|p| p.get("retry_after")?.as_u64());

    // Anything in `parameters` besides the two recognised keys is reported
    // back to the caller as a serialized JSON object.
    let extra = parameters.and_then(Value::as_object).and_then(|map| {
        let unknown: serde_json::Map<String, Value> = map
            .iter()
            .filter(|(k, _)| !matches!(k.as_str(), "migrate_to_chat_id" | "retry_after"))
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        (!unknown.is_empty()).then(|| Value::Object(unknown).to_string())
    });

    (description, migrate_to_chat_id, retry_after, extra)
}
/// Formats an HTTP status code as `"<phrase> (<code>)"`.
///
/// Codes without an entry in the table below use the phrase
/// `"Unknown HTTP Error"`.
fn http_status_phrase(code: u16) -> String {
    const KNOWN_PHRASES: &[(u16, &str)] = &[
        (200, "OK"),
        (201, "Created"),
        (204, "No Content"),
        (400, "Bad Request"),
        (401, "Unauthorized"),
        (403, "Forbidden"),
        (404, "Not Found"),
        (409, "Conflict"),
        (420, "Enhance Your Calm"),
        (429, "Too Many Requests"),
        (500, "Internal Server Error"),
        (502, "Bad Gateway"),
        (503, "Service Unavailable"),
        (504, "Gateway Timeout"),
    ];

    let phrase = KNOWN_PHRASES
        .iter()
        .find(|(known, _)| *known == code)
        .map_or("Unknown HTTP Error", |(_, text)| *text);
    format!("{phrase} ({code})")
}
pub use async_trait::async_trait;
/// Unit tests for the pure helpers in this module (method tokens, JSON
/// decoding, error-body parsing and status phrases).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn http_method_as_str() {
        assert_eq!(HttpMethod::Post.as_str(), "POST");
        assert_eq!(HttpMethod::Get.as_str(), "GET");
    }

    #[test]
    fn http_method_display() {
        assert_eq!(HttpMethod::Post.to_string(), "POST");
    }

    #[test]
    fn parse_valid_json() {
        let raw = br#"{"ok":true,"result":{"id":1}}"#;
        let v = parse_json_payload_impl(raw).unwrap();
        assert_eq!(v["ok"], true);
        assert_eq!(v["result"]["id"], 1);
    }

    #[test]
    fn parse_invalid_json_returns_network_error() {
        let raw = b"not json {{";
        let err = parse_json_payload_impl(raw).unwrap_err();
        assert!(
            matches!(err, TelegramError::Network(_)),
            "expected Network, got {err:?}"
        );
    }

    #[test]
    fn parse_invalid_utf8_with_replacement() {
        // The two invalid bytes are replaced with U+FFFD before parsing, so the
        // text starts with characters that cannot begin a JSON document. The
        // call must not panic and must report a Network error.
        let raw = b"\xff\xfe{\"ok\":true}";
        let err = parse_json_payload_impl(raw).unwrap_err();
        assert!(
            matches!(err, TelegramError::Network(_)),
            "expected Network, got {err:?}"
        );
    }

    #[test]
    fn parse_error_body_extracts_description() {
        let body = br#"{"ok":false,"error_code":400,"description":"Bad Request: chat not found"}"#;
        let (msg, migrate, retry, extra) = parse_error_body(body, 400);
        assert_eq!(msg, "Bad Request: chat not found");
        assert!(migrate.is_none());
        assert!(retry.is_none());
        assert!(extra.is_none());
    }

    #[test]
    fn parse_error_body_migrate_chat_id() {
        let body = br#"{"ok":false,"error_code":400,"description":"...","parameters":{"migrate_to_chat_id":-1001234567}}"#;
        let (_, migrate, _, _) = parse_error_body(body, 400);
        assert_eq!(migrate, Some(-1_001_234_567_i64));
    }

    #[test]
    fn parse_error_body_retry_after() {
        let body = br#"{"ok":false,"error_code":429,"description":"Too Many Requests","parameters":{"retry_after":30}}"#;
        let (_, _, retry, _) = parse_error_body(body, 429);
        assert_eq!(retry, Some(30));
    }

    #[test]
    fn parse_error_body_invalid_json() {
        let body = b"<html>502 Bad Gateway</html>";
        let (msg, _, _, _) = parse_error_body(body, 502);
        assert!(msg.contains("Parsing the server response"), "got: {msg}");
    }

    #[test]
    fn parse_error_body_unknown_parameters() {
        let body = br#"{"ok":false,"description":"err","parameters":{"some_future_field":1}}"#;
        let (msg, _, _, extra) = parse_error_body(body, 400);
        assert_eq!(msg, "err");
        assert!(extra.is_some(), "expected extra params, got none");
    }

    #[test]
    fn known_status_codes() {
        assert!(http_status_phrase(400).contains("Bad Request"));
        assert!(http_status_phrase(403).contains("Forbidden"));
        assert!(http_status_phrase(409).contains("Conflict"));
        assert!(http_status_phrase(502).contains("Bad Gateway"));
    }

    #[test]
    fn unknown_status_code() {
        // 418 has no entry in the phrase table, so it exercises the fallback.
        assert!(http_status_phrase(418).contains("Unknown HTTP Error"));
    }
}