fluxer-rest 0.3.0

HTTP REST client for the Fluxer bot API
Documentation
use std::sync::Arc;
use std::time::Duration;

use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue, USER_AGENT};
use serde::Serialize;
use serde::de::DeserializeOwned;

use crate::error::{FieldError, FluxerApiError, HttpError, RateLimitError, RestError};
use crate::rate_limit::RateLimitManager;

const DEFAULT_API_URL: &str = "https://api.fluxer.app/v1";
const DEFAULT_USER_AGENT: &str = "FluxerBot (Rust, 0.1)";
const DEFAULT_TIMEOUT_SECS: u64 = 15;
const MAX_RETRIES: u32 = 3;

#[derive(Debug, Clone)]
pub struct RestOptions {
    pub api_url: String,
    pub user_agent: String,
    pub timeout: Duration,
    pub max_retries: u32,
}

impl Default for RestOptions {
    fn default() -> Self {
        Self {
            api_url: DEFAULT_API_URL.to_string(),
            user_agent: DEFAULT_USER_AGENT.to_string(),
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            max_retries: MAX_RETRIES,
        }
    }
}

#[derive(Clone)]
pub struct Rest {
    http: reqwest::Client,
    options: RestOptions,
    token: Arc<tokio::sync::RwLock<Option<String>>>,
    rate_limiter: Arc<RateLimitManager>,
}

impl Rest {
    pub fn new(options: RestOptions) -> Self {
        let http = reqwest::Client::builder()
            .timeout(options.timeout)
            .build()
            .expect("TLS backend available");
        Self {
            http,
            options,
            token: Arc::new(tokio::sync::RwLock::new(None)),
            rate_limiter: Arc::new(RateLimitManager::new()),
        }
    }

    pub async fn set_token(&self, token: impl Into<String>) {
        let raw = token.into();
        let normalized = if raw.starts_with("Bot ") || raw.starts_with("Bearer ") {
            raw
        } else {
            format!("Bot {raw}")
        };
        let mut guard = self.token.write().await;
        *guard = Some(normalized);
    }

    pub async fn get<T: DeserializeOwned>(&self, route: &str) -> Result<T, RestError> {
        self.request(reqwest::Method::GET, route, Option::<&()>::None)
            .await
    }

    pub async fn post<T: DeserializeOwned>(
        &self,
        route: &str,
        body: Option<&(impl Serialize + Sync)>,
    ) -> Result<T, RestError> {
        self.request(reqwest::Method::POST, route, body).await
    }

    pub async fn patch<T: DeserializeOwned>(
        &self,
        route: &str,
        body: Option<&(impl Serialize + Sync)>,
    ) -> Result<T, RestError> {
        self.request(reqwest::Method::PATCH, route, body).await
    }

    pub async fn put<T: DeserializeOwned>(
        &self,
        route: &str,
        body: Option<&(impl Serialize + Sync)>,
    ) -> Result<T, RestError> {
        self.request(reqwest::Method::PUT, route, body).await
    }

    pub async fn delete_route(&self, route: &str) -> Result<(), RestError> {
        self.request_empty(reqwest::Method::DELETE, route).await
    }

    pub async fn put_empty(&self, route: &str) -> Result<(), RestError> {
        self.request_empty(reqwest::Method::PUT, route).await
    }

    pub async fn post_multipart<T: DeserializeOwned>(
        &self,
        route: &str,
        form: reqwest::multipart::Form,
    ) -> Result<T, RestError> {
        self.request_multipart(reqwest::Method::POST, route, form)
            .await
    }

    pub async fn patch_multipart<T: DeserializeOwned>(
        &self,
        route: &str,
        form: reqwest::multipart::Form,
    ) -> Result<T, RestError> {
        self.request_multipart(reqwest::Method::PATCH, route, form)
            .await
    }

    async fn request<T: DeserializeOwned>(
        &self,
        method: reqwest::Method,
        route: &str,
        body: Option<&(impl Serialize + Sync)>,
    ) -> Result<T, RestError> {
        let url = format!("{}{}", self.options.api_url, route);
        let mut attempt = 0u32;

        loop {
            self.rate_limiter.wait_if_needed(route).await;

            let mut req = self.http.request(method.clone(), &url);
            req = req.headers(self.build_headers().await);

            if let Some(b) = body {
                req = req.json(b);
            }

            let res = req.send().await?;
            let status = res.status().as_u16();
            self.read_rate_limit_headers_from(route, res.headers());
            let text = res.text().await.unwrap_or_default();

            if status == 429
                && let Ok(rl) = serde_json::from_str::<fluxer_types::RateLimitErrorBody>(&text)
            {
                let global = rl.global.unwrap_or(false);
                if global {
                    self.rate_limiter.set_global(rl.retry_after);
                }
                attempt += 1;
                if attempt < self.options.max_retries {
                    tokio::time::sleep(Duration::from_secs_f64(rl.retry_after)).await;
                    continue;
                }
                return Err(RateLimitError {
                    retry_after: rl.retry_after,
                    global,
                    message: rl.message,
                }
                .into());
            }

            if status >= 400 {
                return Err(self.parse_error(status, &text));
            }

            if text.is_empty() {
                return serde_json::from_str("null").map_err(Into::into);
            }
            return serde_json::from_str(&text).map_err(Into::into);
        }
    }

    async fn request_empty(&self, method: reqwest::Method, route: &str) -> Result<(), RestError> {
        let url = format!("{}{}", self.options.api_url, route);
        self.rate_limiter.wait_if_needed(route).await;

        let req = self
            .http
            .request(method, &url)
            .headers(self.build_headers().await);
        let res = req.send().await?;
        let status = res.status().as_u16();
        self.read_rate_limit_headers_from(route, res.headers());
        let text = res.text().await.unwrap_or_default();

        if status == 429
            && let Ok(rl) = serde_json::from_str::<fluxer_types::RateLimitErrorBody>(&text)
        {
            return Err(RateLimitError {
                retry_after: rl.retry_after,
                global: rl.global.unwrap_or(false),
                message: rl.message,
            }
            .into());
        }

        if status >= 400 {
            return Err(self.parse_error(status, &text));
        }

        Ok(())
    }

    async fn request_multipart<T: DeserializeOwned>(
        &self,
        method: reqwest::Method,
        route: &str,
        form: reqwest::multipart::Form,
    ) -> Result<T, RestError> {
        let url = format!("{}{}", self.options.api_url, route);
        self.rate_limiter.wait_if_needed(route).await;

        let mut headers = self.build_headers().await;
        headers.remove(CONTENT_TYPE);

        let res = self
            .http
            .request(method, &url)
            .headers(headers)
            .multipart(form)
            .send()
            .await?;

        let status = res.status().as_u16();
        self.read_rate_limit_headers_from(route, res.headers());
        let text = res.text().await.unwrap_or_default();

        if status >= 400 {
            return Err(self.parse_error(status, &text));
        }

        serde_json::from_str(&text).map_err(Into::into)
    }

    fn parse_error(&self, status: u16, text: &str) -> RestError {
        if let Ok(api_err) = serde_json::from_str::<fluxer_types::ApiErrorBody>(text) {
            let field_errors: Vec<FieldError> = api_err
                .errors
                .unwrap_or_default()
                .into_iter()
                .map(|e| FieldError {
                    path: e.path,
                    message: e.message,
                })
                .collect();
            if !field_errors.is_empty() {
                let detail = field_errors
                    .iter()
                    .map(|f| format!("  .{}: {}", f.path, f.message))
                    .collect::<Vec<_>>()
                    .join("\n");
                tracing::error!(
                    "API {status} {code}: {msg}\n{detail}",
                    status = status,
                    code = api_err.code,
                    msg = api_err.message,
                    detail = detail,
                );
            }
            FluxerApiError {
                code: api_err.code,
                message: api_err.message,
                status_code: status,
                errors: field_errors,
            }
            .into()
        } else {
            HttpError {
                status_code: status,
                body: text.to_string(),
            }
            .into()
        }
    }

    async fn build_headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            USER_AGENT,
            HeaderValue::from_str(&self.options.user_agent).expect("valid user agent"),
        );
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let token = self.token.read().await;
        if let Some(ref t) = *token
            && let Ok(val) = HeaderValue::from_str(t)
        {
            headers.insert(AUTHORIZATION, val);
        }
        headers
    }

    fn read_rate_limit_headers_from(&self, route: &str, headers: &HeaderMap) {
        let remaining = headers
            .get("x-ratelimit-remaining")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u32>().ok());
        let reset_after = headers
            .get("x-ratelimit-reset-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        let is_global = headers
            .get("x-ratelimit-global")
            .and_then(|v| v.to_str().ok())
            .map(|v| v == "true")
            .unwrap_or(false);

        self.rate_limiter
            .update(route, remaining, reset_after, is_global);
    }
}

impl Default for Rest {
    fn default() -> Self {
        Self::new(RestOptions::default())
    }
}