use std::collections::HashMap;
use async_trait::async_trait;
use chrono::{Duration, Local};
use reqwest::Client as reqClient;
use tokio::sync::RwLock;
use url::Url;
use crate::error::*;
use crate::model::RideTime;
use crate::parser::{FrontPageParser, GenericParkParser, ParkParser};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Root URL of the queue-times.com site; `Client` joins the park index path onto it.
pub static BASE_URL: &str = "https://queue-times.com";
/// Asynchronous source of park listings and ride wait times.
#[async_trait]
pub trait QueueTimesClient {
    /// Returns a map from park name (e.g. `"Cedar Point"`) to the URL of that park's page.
    async fn get_park_urls(&self) -> Result<HashMap<String, Url>>;
    /// Returns the ride times listed on the park page at `park_url`.
    async fn get_ride_times(&self, park_url: Url) -> Result<Vec<RideTime>>;
}
/// [`QueueTimesClient`] that fetches queue-times.com pages over HTTP and parses the HTML.
pub struct Client {
    /// HTTP client used for every request.
    reqwest_client: reqClient,
    /// Extracts the park name → URL map from the park index page.
    front_parser: FrontPageParser,
    /// Extracts ride times from an individual park page.
    park_parser: GenericParkParser,
}
impl Client {
    /// Creates a client with a new `reqwest` HTTP client and freshly constructed parsers.
    pub fn new() -> Self {
        let reqwest_client = reqClient::new();
        let front_parser = FrontPageParser::new();
        let park_parser = GenericParkParser::new();
        Self {
            reqwest_client,
            front_parser,
            park_parser,
        }
    }
}
impl Default for Client {
fn default() -> Self {
Client {
park_parser: GenericParkParser::new(),
front_parser: FrontPageParser::new(),
reqwest_client: reqClient::new(),
}
}
}
#[async_trait]
impl QueueTimesClient for Client {
    /// Downloads the `/en-US/parks` index under [`BASE_URL`] and hands it to the front-page
    /// parser.
    ///
    /// # Errors
    /// Fails if the request fails, the server answers with a non-success status, or the
    /// parser rejects the page.
    async fn get_park_urls(&self) -> Result<HashMap<String, Url>> {
        let parks_url = Url::parse(BASE_URL)
            .and_then(|base| base.join("/en-US/parks"))
            .expect("BASE_URL joined with a fixed path is always a valid URL");
        let html = self
            .reqwest_client
            .get(parks_url)
            .send()
            .await?
            // Surface 4xx/5xx as errors instead of feeding an error page to the parser.
            .error_for_status()?
            .text()
            .await?;
        self.front_parser.get_park_urls(&html)
    }

    /// Downloads the park page at `park_url` and hands it to the park parser.
    ///
    /// # Errors
    /// Fails if the request fails, the server answers with a non-success status, or the
    /// parser rejects the page.
    async fn get_ride_times(&self, park_url: Url) -> Result<Vec<RideTime>> {
        let html = self
            .reqwest_client
            .get(park_url)
            .send()
            .await?
            // Surface 4xx/5xx as errors instead of feeding an error page to the parser.
            .error_for_status()?
            .text()
            .await?;
        self.park_parser.get_ride_times(&html)
    }
}
/// [`QueueTimesClient`] wrapper that keeps the park list once fetched and serves ride times
/// from a cache filled by a background task.
pub struct CachedClient<T>
where
    T: QueueTimesClient + Send + Sync + 'static,
{
    /// Underlying client that performs real fetches; shared with the refresh task.
    client: Arc<T>,
    /// Ride times keyed by park URL, written by the background refresh task.
    ride_cache: Arc<dashmap::DashMap<Url, Vec<RideTime>>>,
    /// Park name → URL map; empty until the first successful fetch.
    parks_cache: RwLock<HashMap<String, Url>>,
    /// When the last refresh finished; `ride_cache` is served while this is recent.
    last_updated: Arc<RwLock<chrono::DateTime<Local>>>,
    /// Set while a background refresh is running; checked before starting another.
    currently_updating_cache: Arc<AtomicBool>,
}
impl<T> CachedClient<T>
where
T: QueueTimesClient + Send + Sync + 'static,
{
pub fn new(client: T) -> Self {
CachedClient {
client: Arc::new(client),
ride_cache: Arc::new(dashmap::DashMap::new()),
parks_cache: RwLock::new(HashMap::new()),
last_updated: Arc::new(RwLock::new(Local::now() - Duration::minutes(6))),
currently_updating_cache: Arc::new(Default::default())
}
}
}
#[async_trait]
impl<T> QueueTimesClient for CachedClient<T>
where
T: QueueTimesClient + Send + Sync + 'static,
{
async fn get_park_urls(&self) -> Result<HashMap<String, Url>> {
if self.parks_cache.read().await.is_empty() {
let parks = self.client.get_park_urls().await?;
let mut lock = self.parks_cache.write().await;
*lock = parks.clone();
Ok(parks)
} else {
let lock = self.parks_cache.read().await;
Ok(lock.clone())
}
}
async fn get_ride_times(&self, park_url: Url) -> Result<Vec<RideTime>> {
{
let time_lock = self.last_updated.read().await;
if (Local::now() - *time_lock) < chrono::Duration::minutes(5) {
let rides = self
.ride_cache
.get(&park_url)
.ok_or_else(|| Error::from(ErrorKind::BadUrl(park_url)))?;
return Ok(rides.value().clone());
}
}
let mut parks = self.get_park_urls().await?;
let client = self.client.clone();
let ride_cache = self.ride_cache.clone();
let last_updated = self.last_updated.clone();
let currently_updating = self.currently_updating_cache.clone();
if !currently_updating.load(Ordering::SeqCst) {
log::debug!("Updating cache");
currently_updating.store(true, Ordering::SeqCst);
tokio::spawn(async move {
let _guard = CompletionGuard{complete: currently_updating};
for (_, park_url) in parks.drain() {
let times = client.get_ride_times(park_url.clone()).await.unwrap();
ride_cache.insert(park_url, times);
}
let mut time_lock = last_updated.write().await;
*time_lock = Local::now();
});
}
let times = self.client.get_ride_times(park_url).await?;
Ok(times)
}
}
impl Default for CachedClient<Client> {
    /// Wraps a newly constructed HTTP [`Client`] in a cache.
    fn default() -> Self {
        Self::new(Client::new())
    }
}
/// Resets a shared "refresh in progress" flag to `false` when dropped, whether its owner
/// finishes normally or is unwound/dropped early.
struct CompletionGuard {
    complete: Arc<AtomicBool>,
}

impl Drop for CompletionGuard {
    fn drop(&mut self) {
        let in_progress = &self.complete;
        in_progress.store(false, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// URL of Cedar Point from a park listing; panics if the park is not listed.
    fn cedar_point_url(parks: &HashMap<String, Url>) -> Url {
        parks.get("Cedar Point").unwrap().to_owned()
    }

    /// The Millennium Force entry from a park's rides; panics if it is missing.
    fn millennium_force(rides: &[RideTime]) -> &RideTime {
        rides.iter().find(|r| r.name == "Millennium Force").unwrap()
    }

    // Hits the live queue-times.com site.
    #[tokio::test]
    async fn test_client() {
        let client = Client::new();
        let parks = client.get_park_urls().await.unwrap();
        let cp_url = cedar_point_url(&parks);
        println!("CP URL {}", cp_url);
        let waits = client.get_ride_times(cp_url).await.unwrap();
        let mille = millennium_force(&waits);
        println!(
            "The current wait for Millennium Force is: {:?}",
            mille.status
        );
    }

    // Hits the live site; later lookups must agree with the first one.
    #[tokio::test]
    async fn test_cache_client() {
        let client = CachedClient::new(Client::new());
        let parks = client.get_park_urls().await.unwrap();
        let cp_url = cedar_point_url(&parks);
        println!("CP URL {}", cp_url);
        let first_waits = client.get_ride_times(cp_url.clone()).await.unwrap();
        let og_mille_wait = millennium_force(&first_waits);
        println!(
            "The current wait for Millennium Force is: {:?}",
            og_mille_wait.status
        );
        for _ in 0..2 {
            let repeat_waits = client.get_ride_times(cp_url.clone()).await.unwrap();
            assert_eq!(millennium_force(&repeat_waits), og_mille_wait);
        }
    }
}