sota 0.9.1

API crate for Summits on the Air
Documentation
use std::collections::HashMap;

use http::Extensions;
use reqwest::Url;
use reqwest_middleware::{
    reqwest::{Request, Response},
    Error, Middleware, Next,
};
use tokio::sync::RwLock;

/// Path segments naming the resources for which an epoch is queried.
///
/// A request URL is treated as epoch-governed when one of its path segments
/// equals an entry of this list.
const EPOCH_KEYS: &[&str] = &["alerts", "spots"];

/// This [`Middleware`] adds support for querying SOTA API epochs.
///
/// Before a request for an epoch-governed resource (see [`EPOCH_KEYS`]) is
/// forwarded, the resource's current epoch is fetched and compared with the
/// value held in `epochs`. When both are equal the request is not forwarded and
/// fails with [`EpochMiddlewareError::EpochUnchanged`].
#[derive(Default)]
pub(super) struct EpochMiddleware {
    /// Recorded epoch per resource, keyed by the resource's name (an entry of
    /// [`EPOCH_KEYS`]). Behind an async lock because `handle` only gets `&self`.
    epochs: RwLock<HashMap<String, String>>,
    /// For making epoch requests.
    ///
    /// A plain [`reqwest::Client`], so epoch requests are sent directly rather
    /// than through this middleware.
    client: reqwest::Client,
}

impl EpochMiddleware {
    /// Create a middleware with no recorded epochs and a default HTTP client.
    // NOTE(review): presumably only constructed from tests in some build
    // configurations, hence the lint override — confirm before removing.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Fetch the server's current epoch for the resource named `key`.
    ///
    /// The epoch URL is derived from `orig_url` (see [`epoch_url`]); the response
    /// body is returned verbatim as the epoch.
    ///
    /// # Errors
    ///
    /// Returns [`EpochMiddlewareError::EpochRequestFailure`] when the request
    /// cannot be sent, the server answers with a 4xx/5xx status, or the body
    /// cannot be read.
    ///
    /// # Panics
    ///
    /// Panics if no epoch URL can be derived from `orig_url` for `key`; callers
    /// are expected to pass a `key` obtained from [`epoch_key`] for the same URL.
    async fn get_epoch(&self, orig_url: &Url, key: &str) -> Result<String, EpochMiddlewareError> {
        let url = epoch_url(&mut orig_url.clone(), key)
            .expect("epoch key should be a path segment of the request URL");

        let epoch = self
            .client
            .get(url)
            .send()
            .await?
            // An error page is not an epoch: without this check its body would
            // be recorded and compared as if it were one.
            .error_for_status()?
            .text()
            .await?;

        Ok(epoch)
    }
}

/// Return the name of the queried resource if it's governed by [epochal caching](EPOCH_KEYS).
///
/// The first path segment of `url` that equals an entry of [`EPOCH_KEYS`] is
/// returned. `None` is returned when no segment matches, and also when `url`
/// has no path segments at all (a cannot-be-a-base URL such as `data:` or
/// `mailto:`), so such requests pass through untouched instead of panicking.
fn epoch_key(url: &Url) -> Option<String> {
    url.path_segments()?
        .find(|segment| EPOCH_KEYS.contains(segment))
        .map(String::from)
}

/// Replace path segments after queried resource's name with `/epoch`.
///
/// For example, `http://localhost/spots/-1/all/all` becomes
/// `http://localhost/spots/epoch`. Trailing slashes and empty segments after
/// the resource name are discarded as well, so `http://localhost/spots/` also
/// becomes `http://localhost/spots/epoch`.
///
/// When `key` occurs more than once, the last occurrence is used. Everything
/// outside the path (query, fragment, host, …) is kept as is.
///
/// On success `url` itself is rewritten and a copy of it is returned. `None` is
/// returned — and `url` is left untouched — when `key` is empty, when `key` is
/// not a path segment of `url`, or when `url` has no path segments
/// (cannot-be-a-base URL).
fn epoch_url(url: &mut Url, key: &str) -> Option<Url> {
    // An empty key could only ever match an empty segment, which names nothing.
    if key.is_empty() {
        return None;
    }

    // Number of segments that follow the last occurrence of `key`.
    let trailing = url
        .path_segments()?
        .rev()
        .position(|segment| segment == key)?;

    {
        // Popping (rather than rebuilding the path) leaves the kept segments
        // byte-for-byte as they were, so nothing is percent-encoded twice.
        let mut segments = url.path_segments_mut().ok()?;
        for _ in 0..trailing {
            segments.pop();
        }
        segments.push("epoch");
    } // `segments` must be dropped before `url` can be read again.

    Some(url.clone())
}

#[async_trait::async_trait]
impl Middleware for EpochMiddleware {
    /// Forward `req` only if the queried resource's epoch has changed.
    ///
    /// Requests whose URL contains no [epoch key](EPOCH_KEYS) are forwarded
    /// unconditionally. For all others the current epoch is fetched first:
    ///
    /// * if it equals the recorded epoch, the request is not forwarded and
    ///   [`EpochMiddlewareError::EpochUnchanged`] is returned;
    /// * otherwise the request is forwarded, and the new epoch is recorded once
    ///   a successful (2xx) response has come back.
    ///
    /// # Errors
    ///
    /// Besides `EpochUnchanged`, returns
    /// [`EpochMiddlewareError::EpochRequestFailure`] when the epoch cannot be
    /// fetched, and any error produced further down the middleware chain.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> reqwest_middleware::Result<Response> {
        let key = match epoch_key(req.url()) {
            Some(key) => key,
            None => return next.run(req, extensions).await,
        };

        let current_epoch = self.get_epoch(req.url(), &key).await?;

        // The read guard is a temporary of this statement, so it is released
        // before the write lock is requested further down.
        let unchanged = self.epochs.read().await.get(&key) == Some(&current_epoch);
        if unchanged {
            return Err(EpochMiddlewareError::EpochUnchanged.into());
        }

        let response = next.run(req, extensions).await?;

        // Record the epoch only once the data has actually been delivered.
        // Recording it up front would make every retry after a failed request
        // bounce off `EpochUnchanged` although the caller never got the data.
        if response.status().is_success() {
            self.epochs.write().await.insert(key, current_epoch);
        }

        Ok(response)
    }
}

/// Errors raised by [`EpochMiddleware`] itself, as opposed to errors of the
/// forwarded request.
#[derive(Debug, thiserror::Error)]
pub enum EpochMiddlewareError {
    /// The epoch fetched from the server equals the recorded one, so the
    /// request was not forwarded.
    #[error("Current data still fresh")]
    EpochUnchanged,
    /// The HTTP request for the epoch failed; wraps the underlying
    /// [`reqwest::Error`].
    #[error("Couldn't get epoch: {0}")]
    EpochRequestFailure(#[from] reqwest::Error),
}

impl From<EpochMiddlewareError> for Error {
    fn from(value: EpochMiddlewareError) -> Self {
        Error::middleware(value)
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use httpmock::{Method::GET, MockServer};
    use reqwest::Url;
    use reqwest_middleware::ClientBuilder;

    use super::{epoch_key, epoch_url, EpochMiddleware};
    use crate::client::BASE_PATH;

    /// Checks `epoch_key` and `epoch_url` against URLs below `BASE_PATH`.
    #[test]
    fn test_epoch() {
        /// Build a URL by appending `suffix` to `BASE_PATH`.
        fn base(suffix: &str) -> Url {
            Url::parse(&format!("{BASE_PATH}/{suffix}")).unwrap()
        }

        // Resource detection: (path below the base, expected key).
        let key_cases = [
            ("spots/1/all/all", Some("spots")),
            ("alerts/12/all/all", Some("alerts")),
            ("null", None),
        ];
        for (path, expected) in key_cases {
            assert_eq!(epoch_key(&base(path)), expected.map(String::from));
        }

        // Epoch URL derivation: hit, wrong key, no key at all.
        let mut spots = base("spots/1/all/all");
        assert_eq!(epoch_url(&mut spots, "spots"), Some(base("spots/epoch")));

        let mut alerts = base("alerts/12/all/all");
        assert_eq!(epoch_url(&mut alerts, "spots"), None);

        let mut unrelated = base("null");
        assert_eq!(epoch_url(&mut unrelated, "spots"), None);
    }

    // Paths of the mocked data endpoints; the matching epoch endpoints are
    // `/spots/epoch` and `/alerts/epoch`.
    const SPOTS_PATH: &str = "/spots/-1/all/all";
    const ALERTS_PATH: &str = "/alerts/12/all/all";

    /// End-to-end check of the middleware against a mock server: the first
    /// request per resource goes through, a repeat with an unchanged epoch is
    /// rejected, and a request after the epoch advanced goes through again.
    #[tokio::test]
    async fn test_middleware() {
        // Shared via `Arc` so the test can inspect the middleware's stored
        // epochs while the client owns another handle to it.
        let middleware = Arc::new(EpochMiddleware::new());

        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware.clone())
            .build();

        // Data endpoints.
        let server = MockServer::start();
        let mock_spots = server.mock(|when, then| {
            when.path(SPOTS_PATH);
            then.status(200)
                .header("content-type", "application/json")
                .body(include_str!("test/spot_variants.json"));
        });
        let mock_alerts = server.mock(|when, then| {
            when.path(ALERTS_PATH);
            then.status(200)
                .header("content-type", "application/json")
                .body(include_str!("test/alerts.json"));
        });

        // Epoch endpoints. Note that the alerts epoch mock serves the value of
        // `spots_epoch` as it is at this point ("1").
        let mut spots_epoch = "1";
        let mut mock_spots_epoch = server.mock(|when, then| {
            when.method(GET).path("/spots/epoch");
            then.status(200).body(spots_epoch);
        });
        let mock_alerts_epoch = server.mock(|when, then| {
            when.method(GET).path("/alerts/epoch");
            then.status(200).body(spots_epoch);
        });

        // Get spots for first time; store epoch.
        let _ = client.get(server.url(SPOTS_PATH)).send().await.unwrap();
        test_contents(&middleware, "spots", spots_epoch).await;
        mock_spots_epoch.assert_hits_async(1).await;
        mock_spots.assert_hits_async(1).await;

        // Get spots again, but server info is fresh; fail request.
        // The epoch endpoint is hit a second time, the data endpoint is not.
        let res = client.get(server.url(SPOTS_PATH)).send().await;
        assert!(res.is_err());
        mock_spots_epoch.assert_hits_async(2).await;
        mock_spots.assert_hits_async(1).await;

        // Meanwhile, get alerts, which follow a separate timeline.
        let _ = client.get(server.url(ALERTS_PATH)).send().await.unwrap();
        test_contents(&middleware, "alerts", spots_epoch).await;
        mock_alerts_epoch.assert_hits_async(1).await;
        mock_alerts.assert_hits_async(1).await;

        // Advance spots epoch.
        // The old mock is removed and a new one registered with the new value.
        mock_spots_epoch.delete();
        spots_epoch = "2";
        mock_spots_epoch = server.mock(|when, then| {
            when.method(GET).path("/spots/epoch");
            then.status(200).body(spots_epoch);
        });

        // Next attempt to get spots succeeds.
        let _ = client.get(server.url(SPOTS_PATH)).send().await.unwrap();
        test_contents(&middleware, "spots", spots_epoch).await;
        mock_spots_epoch.assert_hits_async(1).await; // not 3; `.delete()` wiped first two
        mock_spots.assert_hits_async(2).await;
    }

    /// Assert that the epoch stored by `middleware` under `key` equals `expected`.
    async fn test_contents(middleware: &Arc<EpochMiddleware>, key: &str, expected: &str) {
        let epochs = middleware.epochs.read().await;
        let stored = epochs.get(key).map(String::as_str);
        assert_eq!(stored, Some(expected));
    }
}