use std::collections::HashMap;
use http::Extensions;
use reqwest::Url;
use reqwest_middleware::{
reqwest::{Request, Response},
Error, Middleware, Next,
};
use tokio::sync::RwLock;
/// Path segments that mark a request as epoch-tracked. For a matching key, the
/// epoch is fetched from the same URL with its path cut after that segment and
/// `epoch` appended (e.g. `.../spots/epoch`).
const EPOCH_KEYS: &[&str] = &["alerts", "spots"];
/// Middleware that skips requests whose upstream data has not changed.
///
/// Requests whose path contains one of the `EPOCH_KEYS` segments are first checked
/// against the matching `<key>/epoch` endpoint. When that epoch equals the one
/// recorded for the key, the request is answered with
/// [`EpochMiddlewareError::EpochUnchanged`] instead of being sent.
#[derive(Default)]
pub(super) struct EpochMiddleware {
    /// Last epoch recorded per key (`"alerts"`, `"spots"`, ...).
    epochs: RwLock<HashMap<String, String>>,
    /// Plain client (not a middleware client) used for the epoch lookups, so those
    /// lookups are not themselves intercepted by this middleware.
    client: reqwest::Client,
}

impl EpochMiddleware {
    /// Creates a middleware with no recorded epochs.
    // NOTE(review): `dead_code` is presumably allowed because `new` may only be
    // reached from this file's tests in some builds — confirm and narrow if possible.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Fetches the current epoch for `key`, using the endpoint derived from `orig_url`.
    ///
    /// The response body is returned verbatim as the epoch string.
    ///
    /// # Errors
    /// Returns [`EpochMiddlewareError::EpochRequestFailure`] when the request cannot be
    /// sent, when the server replies with a 4xx/5xx status, or when the body cannot be
    /// read.
    ///
    /// # Panics
    /// Panics if `epoch_url` cannot derive an epoch endpoint for `key` from `orig_url`.
    async fn get_epoch(&self, orig_url: &Url, key: &str) -> Result<String, EpochMiddlewareError> {
        let url = epoch_url(&mut orig_url.clone(), key)
            .expect("no epoch URL could be derived from the request URL");
        self.client
            .get(url)
            .send()
            .await?
            // An error page must never be stored or compared as if it were an epoch.
            .error_for_status()?
            .text()
            .await
            .map_err(From::from)
    }
}
/// Returns the first path segment of `url` that is one of `EPOCH_KEYS`, if any.
///
/// URLs without a hierarchical path (cannot-be-a-base URLs) yield `None` rather
/// than panicking.
fn epoch_key(url: &Url) -> Option<String> {
    url.path_segments()?
        .find(|segment| EPOCH_KEYS.contains(segment))
        .map(String::from)
}
/// Builds the epoch endpoint for `key`. It truncates `url`'s path right after the
/// last segment equal to `key` and appends an `epoch` segment.
///
/// For example, `.../spots/1/all/all` with key `spots` becomes `.../spots/epoch`.
/// Empty segments (a trailing slash or `//`) after the key are dropped like any
/// other trailing segment.
///
/// `url` is modified in place and a copy of the result is returned. Returns `None`,
/// leaving `url` unchanged, when no path segment equals `key` or when `url` has no
/// hierarchical path.
fn epoch_url(url: &mut Url, key: &str) -> Option<Url> {
    // Number of segments that follow the last occurrence of `key`. This is computed
    // up front so the shared borrow of `url` ends before the path is mutated.
    let trailing = url.path_segments()?.rev().position(|segment| segment == key)?;
    {
        // Scoped so the mutable path handle is dropped before `url` is cloned.
        let mut path = url.path_segments_mut().ok()?;
        // The `key` segment itself is never popped, so the path always keeps at
        // least one segment here.
        for _ in 0..trailing {
            path.pop();
        }
        path.push("epoch");
    }
    Some(url.clone())
}
#[async_trait::async_trait]
impl Middleware for EpochMiddleware {
    /// Short-circuits requests for epoch-tracked resources whose data is unchanged.
    ///
    /// Requests without an epoch key in their path are forwarded untouched. Otherwise
    /// the current epoch for that key is fetched first:
    /// * If it equals the last recorded epoch, [`EpochMiddlewareError::EpochUnchanged`]
    ///   is returned and the request is never sent.
    /// * Otherwise the request is forwarded, and the new epoch is recorded only when
    ///   the response has a 2xx status.
    ///
    /// Recording the epoch before (or regardless of) the outcome would make a failed
    /// fetch look "fresh". Every retry would then be suppressed until the epoch
    /// changed again.
    ///
    /// # Errors
    /// Propagates epoch-lookup failures and errors from the rest of the middleware chain.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> reqwest_middleware::Result<Response> {
        if let Some(key) = epoch_key(req.url()) {
            let current_epoch = self.get_epoch(req.url(), &key).await?;
            // The read guard is a temporary of the condition and is released before
            // the request is sent.
            if self.epochs.read().await.get(&key) == Some(&current_epoch) {
                return Err(EpochMiddlewareError::EpochUnchanged.into());
            }
            let response = next.run(req, extensions).await?;
            // Concurrent requests that observe the same new epoch may both be
            // forwarded. That only costs a duplicate fetch.
            if response.status().is_success() {
                self.epochs.write().await.insert(key, current_epoch);
            }
            return Ok(response);
        }
        next.run(req, extensions).await
    }
}
#[derive(Debug, thiserror::Error)]
pub enum EpochMiddlewareError {
#[error("Current data still fresh")]
EpochUnchanged,
#[error("Couldn't get epoch: {0}")]
EpochRequestFailure(#[from] reqwest::Error),
}
impl From<EpochMiddlewareError> for Error {
fn from(value: EpochMiddlewareError) -> Self {
Error::middleware(value)
}
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use httpmock::{Method::GET, MockServer};
    use reqwest::Url;
    use reqwest_middleware::ClientBuilder;

    use super::{epoch_key, epoch_url, EpochMiddleware};
    use crate::client::BASE_PATH;

    /// Key detection and epoch-URL derivation on representative API paths.
    #[test]
    fn test_epoch() {
        let base = |s: &str| -> Url { format!("{BASE_PATH}/{s}").as_str().try_into().unwrap() };

        // `epoch_key` only needs a shared borrow.
        assert_eq!(epoch_key(&base("spots/1/all/all")), Some("spots".into()));
        assert_eq!(epoch_key(&base("alerts/12/all/all")), Some("alerts".into()));
        assert_eq!(epoch_key(&base("null")), None);

        assert_eq!(
            epoch_url(&mut base("spots/1/all/all"), "spots"),
            Some(base("spots/epoch"))
        );
        assert_eq!(epoch_url(&mut base("alerts/12/all/all"), "spots"), None);
        assert_eq!(epoch_url(&mut base("null"), "spots"), None);
    }

    const SPOTS_PATH: &str = "/spots/-1/all/all";
    const ALERTS_PATH: &str = "/alerts/12/all/all";

    /// Drives the middleware against a mock server: first fetch, unchanged-epoch
    /// short-circuit, an independent key, then a changed epoch.
    #[tokio::test]
    async fn test_middleware() {
        let middleware = Arc::new(EpochMiddleware::new());
        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware.clone())
            .build();
        let server = MockServer::start();

        let mock_spots = server.mock(|when, then| {
            when.path(SPOTS_PATH);
            then.status(200)
                .header("content-type", "application/json")
                .body(include_str!("test/spot_variants.json"));
        });
        let mock_alerts = server.mock(|when, then| {
            when.path(ALERTS_PATH);
            then.status(200)
                .header("content-type", "application/json")
                .body(include_str!("test/alerts.json"));
        });

        let alerts_epoch = "1";
        let mut spots_epoch = "1";
        let mut mock_spots_epoch = server.mock(|when, then| {
            when.method(GET).path("/spots/epoch");
            then.status(200).body(spots_epoch);
        });
        let mock_alerts_epoch = server.mock(|when, then| {
            when.method(GET).path("/alerts/epoch");
            then.status(200).body(alerts_epoch);
        });

        // First spots request: the epoch is fetched and recorded, and the data is fetched.
        let _ = client.get(server.url(SPOTS_PATH)).send().await.unwrap();
        test_contents(&middleware, "spots", spots_epoch).await;
        mock_spots_epoch.assert_hits_async(1).await;
        mock_spots.assert_hits_async(1).await;

        // Same epoch again: the request is rejected without touching the data endpoint.
        let res = client.get(server.url(SPOTS_PATH)).send().await;
        assert!(res.is_err());
        mock_spots_epoch.assert_hits_async(2).await;
        mock_spots.assert_hits_async(1).await;

        // Alerts are tracked independently of spots.
        let _ = client.get(server.url(ALERTS_PATH)).send().await.unwrap();
        test_contents(&middleware, "alerts", alerts_epoch).await;
        mock_alerts_epoch.assert_hits_async(1).await;
        mock_alerts.assert_hits_async(1).await;

        // A new spots epoch lets the data request through again.
        mock_spots_epoch.delete();
        spots_epoch = "2";
        mock_spots_epoch = server.mock(|when, then| {
            when.method(GET).path("/spots/epoch");
            then.status(200).body(spots_epoch);
        });
        let _ = client.get(server.url(SPOTS_PATH)).send().await.unwrap();
        test_contents(&middleware, "spots", spots_epoch).await;
        mock_spots_epoch.assert_hits_async(1).await;
        mock_spots.assert_hits_async(2).await;
    }

    /// Asserts that the middleware has recorded `expected` as the epoch for `key`.
    async fn test_contents(middleware: &EpochMiddleware, key: &str, expected: &str) {
        assert_eq!(
            middleware.epochs.read().await.get(key),
            Some(&expected.to_string())
        )
    }
}