use std::pin::Pin;
use futures_core::Stream;
use serde_json::Value;
use crate::errors::AudDError;
use crate::http::{BareHttpClient, HttpResponse};
use crate::retry::{retry_async, RetryClass, RetryPolicy};
/// Default endpoint polled by `LongpollConsumer`; callers can override it
/// through `LongpollConsumerBuilder::base_url`.
const LONGPOLL_URL: &str = "https://api.audd.io/longpoll/";
/// Lowest HTTP status that `decode` reports as an `AudDError::Server` error.
const HTTP_CLIENT_ERROR_FLOOR: u16 = 400;
/// Builder for [`LongpollConsumer`], obtained from [`LongpollConsumer::builder`].
#[derive(Debug, Clone)]
pub struct LongpollConsumerBuilder {
    // Sent as the `category` query parameter on every poll request.
    category: String,
    // Forwarded to `RetryPolicy::with_max_attempts` in `build`.
    max_attempts: u32,
    // Forwarded to `RetryPolicy::with_backoff_factor` in `build`.
    backoff_factor: f64,
    // Caller-supplied client; when `None`, `build` creates one with `BareHttpClient::new`.
    reqwest_client: Option<reqwest::Client>,
    // URL the built consumer polls; defaults to `LONGPOLL_URL`.
    base_url: String,
}
impl LongpollConsumerBuilder {
    /// Starts a builder for `category` with the defaults used by
    /// `LongpollConsumer::new`:
    /// - 3 attempts;
    /// - a backoff factor of 0.5;
    /// - no custom HTTP client;
    /// - the public longpoll URL.
    fn new(category: impl Into<String>) -> Self {
        let category = category.into();
        Self {
            category,
            max_attempts: 3,
            backoff_factor: 0.5,
            reqwest_client: None,
            base_url: String::from(LONGPOLL_URL),
        }
    }

    /// Sets the attempt count handed to the retry policy for each poll request.
    #[must_use]
    pub fn max_attempts(self, n: u32) -> Self {
        Self {
            max_attempts: n,
            ..self
        }
    }

    /// Sets the backoff factor handed to the retry policy.
    #[must_use]
    pub fn backoff_factor(self, f: f64) -> Self {
        Self {
            backoff_factor: f,
            ..self
        }
    }

    /// Uses `client` as the transport instead of building a default one.
    #[must_use]
    pub fn reqwest_client(self, client: reqwest::Client) -> Self {
        Self {
            reqwest_client: Some(client),
            ..self
        }
    }

    /// Overrides the endpoint URL that will be polled.
    #[must_use]
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }

    /// Consumes the builder and produces a [`LongpollConsumer`].
    ///
    /// The retry policy is of class `RetryClass::Read`, configured with this
    /// builder's attempt count and backoff factor.
    ///
    /// # Errors
    ///
    /// Propagates any error from `BareHttpClient::new`. That constructor is
    /// only called when no custom `reqwest::Client` was supplied.
    pub fn build(self) -> Result<LongpollConsumer, AudDError> {
        let http = match self.reqwest_client {
            Some(client) => BareHttpClient::from_client(client),
            None => BareHttpClient::new()?,
        };
        let policy = RetryPolicy::new(RetryClass::Read)
            .with_max_attempts(self.max_attempts)
            .with_backoff_factor(self.backoff_factor);
        Ok(LongpollConsumer {
            category: self.category,
            http,
            policy,
            url: self.base_url,
        })
    }
}
/// Client for a longpoll endpoint (AudD's by default); see [`LongpollConsumer::iterate`].
#[derive(Debug, Clone)]
pub struct LongpollConsumer {
    // Sent as the `category` query parameter on every request.
    category: String,
    // Transport used for each GET request.
    http: BareHttpClient,
    // Retry policy applied to each individual poll request.
    policy: RetryPolicy,
    // Endpoint URL that is polled.
    url: String,
}
/// Per-call settings for [`LongpollConsumer::iterate`].
#[derive(Debug, Clone)]
pub struct LongpollIterateOptions {
    /// Initial `since_time` query parameter. With `None`, the parameter is
    /// omitted until a response supplies an integer `timestamp`.
    pub since_time: Option<i64>,
    /// Value sent as the `timeout` query parameter. It is presumably in
    /// seconds (NOTE(review): confirm against the API docs). Defaults to 50.
    pub timeout: i64,
}

impl Default for LongpollIterateOptions {
    /// Starts with no `since_time` and a `timeout` of 50.
    fn default() -> Self {
        let timeout = 50;
        Self {
            timeout,
            since_time: None,
        }
    }
}
impl LongpollConsumer {
    /// Creates a consumer for `category` using the builder's default settings.
    ///
    /// # Panics
    ///
    /// Panics if the default HTTP client cannot be constructed. Use
    /// [`LongpollConsumer::builder`] followed by `build()` to receive that
    /// failure as an error instead.
    #[must_use]
    pub fn new(category: impl Into<String>) -> Self {
        Self::builder(category)
            .build()
            .expect("default reqwest::Client should build")
    }

    /// Returns a builder for customising retries, the HTTP client and the base URL.
    #[must_use]
    pub fn builder(category: impl Into<String>) -> LongpollConsumerBuilder {
        LongpollConsumerBuilder::new(category)
    }

    /// Repeatedly polls the endpoint, yielding each decoded JSON object.
    ///
    /// Every request sends `category` and `timeout`. It also sends `since_time`
    /// once one is known. After each response carrying an integer `timestamp`
    /// field, that value becomes the `since_time` of the next request.
    ///
    /// The returned stream is lazy: nothing is requested until it is polled.
    /// It never finishes on its own. The first error, from `retry_async` or
    /// from decoding the response, is yielded and the stream then ends.
    #[must_use = "streams do nothing unless polled"]
    pub fn iterate(
        &self,
        opts: LongpollIterateOptions,
    ) -> Pin<Box<dyn Stream<Item = Result<Value, AudDError>> + Send>> {
        let http = self.http.clone();
        let category = self.category.clone();
        let url = self.url.clone();
        let policy = self.policy;
        let mut cur_since = opts.since_time;
        // The timeout never changes between polls, so format it once.
        let timeout = opts.timeout.to_string();
        let stream = async_stream::try_stream! {
            loop {
                let mut params: Vec<(&str, String)> = vec![
                    ("category", category.clone()),
                    ("timeout", timeout.clone()),
                ];
                if let Some(t) = cur_since {
                    params.push(("since_time", t.to_string()));
                }
                let resp = retry_async(
                    || {
                        // `async move` takes ownership of its captures, so each
                        // invocation (presumably one per attempt) gets fresh clones.
                        let http = http.clone();
                        let url = url.clone();
                        let params = params.clone();
                        async move { http.get(&url, &params).await }
                    },
                    policy,
                )
                .await?;
                let body = decode(resp)?;
                if let Some(ts) = body.get("timestamp").and_then(Value::as_i64) {
                    cur_since = Some(ts);
                }
                yield body;
            }
        };
        Box::pin(stream)
    }

    /// Consumes the consumer, dropping its HTTP client handle.
    ///
    /// This is equivalent to dropping the value. Streams already returned by
    /// [`LongpollConsumer::iterate`] own clones of the client, category and
    /// URL, so this call does not invalidate them.
    pub fn close(self) {
        drop(self);
    }
}
/// Converts a raw longpoll [`HttpResponse`] into its JSON object body.
///
/// # Errors
///
/// * `AudDError::Server` when the HTTP status is `HTTP_CLIENT_ERROR_FLOOR`
///   (400) or higher. This check runs before the body is inspected.
/// * `AudDError::Serialization` when `json_body` is absent or is not a JSON
///   object. The raw response text is attached in both cases.
fn decode(resp: HttpResponse) -> Result<Value, AudDError> {
    let HttpResponse {
        json_body,
        http_status,
        request_id,
        raw_text,
    } = resp;
    if http_status < HTTP_CLIENT_ERROR_FLOOR {
        match json_body {
            Some(body) if body.is_object() => Ok(body),
            _ => Err(AudDError::Serialization {
                message: "Longpoll response was not a JSON object".into(),
                raw_text,
            }),
        }
    } else {
        Err(AudDError::Server {
            http_status,
            message: format!("Longpoll endpoint returned HTTP {http_status}"),
            request_id,
            raw_response: raw_text,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `HttpResponse` fixture with no request id.
    fn response(http_status: u16, json_body: Option<Value>, raw_text: &'static str) -> HttpResponse {
        HttpResponse {
            json_body,
            http_status,
            request_id: None,
            raw_text: raw_text.into(),
        }
    }

    #[test]
    fn decode_2xx_garbage_is_serialization() {
        let e = decode(response(200, None, "boom")).unwrap_err();
        assert!(matches!(e, AudDError::Serialization { .. }));
    }

    #[test]
    fn decode_2xx_non_object_json_is_serialization() {
        let e = decode(response(200, Some(serde_json::json!([1, 2])), "[1,2]")).unwrap_err();
        assert!(matches!(e, AudDError::Serialization { .. }));
    }

    #[test]
    fn decode_2xx_object_is_ok() {
        let body = decode(response(
            200,
            Some(serde_json::json!({ "timestamp": 1 })),
            "{\"timestamp\":1}",
        ))
        .unwrap();
        assert!(body.is_object());
        assert_eq!(body.get("timestamp").and_then(Value::as_i64), Some(1));
    }

    #[test]
    fn decode_non_2xx_is_server() {
        match decode(response(500, None, "<html>")).unwrap_err() {
            AudDError::Server { http_status, .. } => assert_eq!(http_status, 500),
            other => panic!("not Server: {other:?}"),
        }
    }

    #[test]
    fn decode_status_floor_wins_over_valid_body() {
        // A well-formed object body must not mask an error status at the floor.
        match decode(response(400, Some(serde_json::json!({})), "{}")).unwrap_err() {
            AudDError::Server { http_status, .. } => assert_eq!(http_status, 400),
            other => panic!("not Server: {other:?}"),
        }
    }
}