use std::collections::HashMap;
use log::debug;
use reqwest::Url;
use tokio::sync::RwLock;
use crate::protocol::ServiceInfo;
use crate::services::ServiceType;
use crate::{client::AuthenticatedClient, ErrorKind};
use crate::{EndpointFilters, Error};
/// Cache of per-service information, keyed by service catalog type.
///
/// Entries are populated lazily by `EndpointCache::extract_service_info`.
/// Endpoint discovery for a missing entry is driven by `filters` and
/// `overrides`.
#[derive(Debug)]
pub struct EndpointCache {
/// Cached `ServiceInfo`, keyed by the service's catalog type string.
info: RwLock<HashMap<&'static str, ServiceInfo>>,
/// Filters passed to the client's endpoint lookup when no override exists.
pub filters: EndpointFilters,
/// Explicit endpoint URLs keyed by catalog type. These take precedence over
/// the client's endpoint lookup (they are still validated before use).
pub overrides: HashMap<String, Url>,
}
impl Clone for EndpointCache {
    /// Produce a copy that shares the same filters and overrides.
    ///
    /// The service-info cache is *not* carried over: the clone starts
    /// empty and will fetch service information again on first use.
    fn clone(&self) -> EndpointCache {
        let filters = self.filters.clone();
        let overrides = self.overrides.clone();
        let empty_info = RwLock::new(HashMap::new());
        EndpointCache {
            info: empty_info,
            filters,
            overrides,
        }
    }
}
impl EndpointCache {
#[inline]
pub fn new() -> Self {
EndpointCache {
info: RwLock::new(HashMap::new()),
filters: EndpointFilters::default(),
overrides: HashMap::new(),
}
}
#[cfg(test)]
pub fn new_with(service_type: &'static str, service_info: ServiceInfo) -> Self {
let mut hm = HashMap::new();
let _ = hm.insert(service_type, service_info);
EndpointCache {
info: RwLock::new(hm),
filters: EndpointFilters::default(),
overrides: HashMap::new(),
}
}
#[inline]
pub fn clear(&mut self) -> &mut Self {
self.info = RwLock::new(HashMap::new());
self
}
pub async fn extract_service_info<Srv, F, T>(
&self,
client: &AuthenticatedClient,
service: Srv,
filter: F,
) -> Result<T, Error>
where
Srv: ServiceType + Send,
F: FnOnce(&ServiceInfo) -> T + Send,
T: Send,
{
let catalog_type = service.catalog_type();
if let Some(info) = self.info.read().await.get(catalog_type) {
return Ok(filter(info));
}
debug!(
"No cached information for service {}, fetching",
catalog_type
);
let mut lock = self.info.write().await;
Ok(if let Some(info) = lock.get(catalog_type) {
filter(info)
} else {
let ep = match self.overrides.get(catalog_type) {
Some(found) => found.clone(),
None => client.get_endpoint(catalog_type, &self.filters).await?,
};
if ep.cannot_be_a_base() || !ep.has_host() {
return Err(Error::new(
ErrorKind::InvalidResponse,
format!("Invalid URL {} received for service {}", ep, catalog_type),
));
}
let info = ServiceInfo::fetch(service, ep, client).await?;
let value = filter(&info);
let _ = lock.insert(catalog_type, info);
value
})
}
}
#[cfg(test)]
mod test {
    use reqwest::Url;

    use super::EndpointCache;
    use crate::client::AuthenticatedClient;
    use crate::protocol::ServiceInfo;
    use crate::services::COMPUTE;
    use crate::ErrorKind;

    /// A pre-populated cache entry is returned as-is.
    #[tokio::test]
    async fn test_existing() {
        let client = AuthenticatedClient::new_noauth("http://localhost").await;
        let root_url = Url::parse("http://localhost").unwrap();
        let expected = ServiceInfo {
            root_url,
            major_version: None,
            current_version: None,
            minimum_version: None,
        };
        let cache = EndpointCache::new_with("compute", expected.clone());

        let actual = cache
            .extract_service_info(&client, COMPUTE, ServiceInfo::clone)
            .await
            .unwrap();

        assert_eq!(expected, actual);
    }

    /// An endpoint URL without a host is rejected as an invalid response.
    #[tokio::test]
    async fn test_invalid_url() {
        let client = AuthenticatedClient::new_noauth("unix:/run/foo.socket").await;
        let cache = EndpointCache::new();

        let outcome = cache
            .extract_service_info(&client, COMPUTE, ServiceInfo::clone)
            .await;
        let err = outcome.expect_err("host-less endpoint must be rejected");

        assert_eq!(err.kind(), ErrorKind::InvalidResponse);
    }
}