use dashmap::DashMap;
use crate::error::Result;
use super::result::FetchResult;
use super::traits::SchemaFetcher;
/// A [`SchemaFetcher`] decorator that memoises successful fetches by URL.
///
/// Results live in a concurrent [`DashMap`], so the cache is read and
/// populated through `&self`.
pub struct CachingFetcher<F>
where
    F: SchemaFetcher,
{
    /// Fetcher consulted on a cache miss.
    inner: F,
    /// Cached results keyed by URL string (requested URLs, redirect targets,
    /// and seeded entries).
    cache: DashMap<String, FetchResult>,
}
impl<F> CachingFetcher<F>
where
    F: SchemaFetcher,
{
    /// Wraps `inner` in a cache that starts out empty.
    pub fn new(inner: F) -> Self {
        let cache = DashMap::new();
        Self { inner, cache }
    }

    /// Pre-populates the cache so that `url` resolves to `content` without
    /// consulting the inner fetcher.
    ///
    /// The stored entry reports `url` itself as its final URL and is marked
    /// as not redirected. Any entry already cached for `url` is replaced.
    pub fn seed(&self, url: &str, content: Vec<u8>) {
        let key = url.to_string();
        let seeded = FetchResult {
            content,
            final_url: url.to_string(),
            redirected: false,
        };
        self.cache.insert(key, seeded);
    }

    /// Number of URL keys currently held in the cache.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when nothing has been cached yet.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Borrows the wrapped fetcher.
    pub fn inner(&self) -> &F {
        &self.inner
    }
}
impl<F: SchemaFetcher> SchemaFetcher for CachingFetcher<F> {
    /// Returns the cached result for `url`, fetching it through the inner
    /// fetcher on a miss.
    ///
    /// On a successful miss the result is cached under `url`. When the inner
    /// fetcher reports a different `final_url` (e.g. after a redirect), it is
    /// also cached under that final URL, so later requests for it hit the
    /// cache. An entry already present for the final URL is left untouched.
    /// Such an entry may have been installed with [`CachingFetcher::seed`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner fetcher. Failures are not cached,
    /// so a later call for the same URL retries the fetch.
    fn fetch(&self, url: &str) -> Result<FetchResult> {
        // Clone the hit out so the map's read guard is released before we
        // call the inner fetcher or write back into the map.
        let cached = self.cache.get(url).map(|entry| entry.value().clone());
        if let Some(hit) = cached {
            return Ok(hit);
        }

        let result = self.inner.fetch(url)?;
        self.cache.insert(url.to_string(), result.clone());
        if result.final_url != url {
            // Alias the redirect target, but never clobber an entry that is
            // already cached (or seeded) for it: once a URL is in the cache it
            // is always served from there.
            self.cache
                .entry(result.final_url.clone())
                .or_insert_with(|| result.clone());
        }
        Ok(result)
    }
}
#[cfg(feature = "tokio")]
/// Async counterpart of [`CachingFetcher`]: memoises successful fetches made
/// through an [`AsyncSchemaFetcher`](super::traits::AsyncSchemaFetcher).
pub struct AsyncCachingFetcher<F>
where
    F: super::traits::AsyncSchemaFetcher,
{
    /// Fetcher awaited on a cache miss.
    inner: F,
    /// Cached results keyed by URL string (requested URLs, redirect targets,
    /// and seeded entries).
    cache: DashMap<String, FetchResult>,
}
#[cfg(feature = "tokio")]
impl<F> AsyncCachingFetcher<F>
where
    F: super::traits::AsyncSchemaFetcher,
{
    /// Wraps `inner` in a cache that starts out empty.
    pub fn new(inner: F) -> Self {
        let cache = DashMap::new();
        Self { inner, cache }
    }

    /// Pre-populates the cache so that `url` resolves to `content` without
    /// awaiting the inner fetcher.
    ///
    /// The stored entry reports `url` itself as its final URL and is marked
    /// as not redirected. Any entry already cached for `url` is replaced.
    pub fn seed(&self, url: &str, content: Vec<u8>) {
        let key = url.to_string();
        let seeded = FetchResult {
            content,
            final_url: url.to_string(),
            redirected: false,
        };
        self.cache.insert(key, seeded);
    }

    /// Number of URL keys currently held in the cache.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when nothing has been cached yet.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Borrows the wrapped fetcher.
    pub fn inner(&self) -> &F {
        &self.inner
    }
}
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl<F: super::traits::AsyncSchemaFetcher> super::traits::AsyncSchemaFetcher
    for AsyncCachingFetcher<F>
{
    /// Returns the cached result for `url`, awaiting the inner fetcher on a
    /// miss.
    ///
    /// On a successful miss the result is cached under `url`. When the inner
    /// fetcher reports a different `final_url` (e.g. after a redirect), it is
    /// also cached under that final URL. An entry already present for the
    /// final URL is left untouched. Such an entry may have been installed
    /// with [`AsyncCachingFetcher::seed`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner fetcher. Failures are not cached,
    /// so a later call for the same URL retries the fetch.
    async fn fetch(&self, url: &str) -> Result<FetchResult> {
        // Clone the hit out so the map guard is dropped here and is never
        // alive across the `.await` below.
        let cached = self.cache.get(url).map(|entry| entry.value().clone());
        if let Some(hit) = cached {
            return Ok(hit);
        }

        let result = self.inner.fetch(url).await?;
        self.cache.insert(url.to_string(), result.clone());
        if result.final_url != url {
            // Alias the redirect target without clobbering an entry that is
            // already cached (or seeded) for it.
            self.cache
                .entry(result.final_url.clone())
                .or_insert_with(|| result.clone());
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::fetcher::NoopFetcher;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    /// Inner fetcher that serves canned responses and records every URL it is
    /// asked for, so tests can tell cache hits from misses.
    struct TrackingFetcher {
        responses: HashMap<String, Vec<u8>>,
        calls: Arc<Mutex<Vec<String>>>,
    }

    impl TrackingFetcher {
        fn new(responses: HashMap<String, Vec<u8>>) -> Self {
            Self {
                responses,
                calls: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Number of times the inner `fetch` has been invoked.
        fn call_count(&self) -> usize {
            self.calls.lock().unwrap().len()
        }
    }

    impl SchemaFetcher for TrackingFetcher {
        fn fetch(&self, url: &str) -> Result<FetchResult> {
            self.calls.lock().unwrap().push(url.to_string());
            match self.responses.get(url) {
                Some(content) => Ok(FetchResult {
                    content: content.clone(),
                    final_url: url.to_string(),
                    redirected: false,
                }),
                None => Err(crate::schema::fetcher::error::FetchError::RequestFailed {
                    url: url.to_string(),
                    message: "Not found".to_string(),
                }
                .into()),
            }
        }
    }

    /// Inner fetcher that answers every request as though the server had
    /// redirected it to `target`, recording each requested URL.
    struct RedirectingFetcher {
        target: String,
        calls: Mutex<Vec<String>>,
    }

    impl SchemaFetcher for RedirectingFetcher {
        fn fetch(&self, url: &str) -> Result<FetchResult> {
            self.calls.lock().unwrap().push(url.to_string());
            Ok(FetchResult {
                content: b"<redirected/>".to_vec(),
                final_url: self.target.clone(),
                redirected: true,
            })
        }
    }

    #[test]
    fn test_caching_fetcher_caches_result() {
        let mut responses = HashMap::new();
        responses.insert(
            "http://example.com/a.xsd".to_string(),
            b"<schema/>".to_vec(),
        );
        let inner = TrackingFetcher::new(responses);
        let fetcher = CachingFetcher::new(inner);
        let r1 = fetcher.fetch("http://example.com/a.xsd").unwrap();
        assert_eq!(r1.content, b"<schema/>");
        assert_eq!(fetcher.inner().call_count(), 1);
        let r2 = fetcher.fetch("http://example.com/a.xsd").unwrap();
        assert_eq!(r2.content, b"<schema/>");
        assert_eq!(fetcher.inner().call_count(), 1);
    }

    #[test]
    fn test_caching_fetcher_does_not_cache_errors() {
        let fetcher = CachingFetcher::new(TrackingFetcher::new(HashMap::new()));
        assert!(fetcher.fetch("http://example.com/missing.xsd").is_err());
        assert!(fetcher.fetch("http://example.com/missing.xsd").is_err());
        // Both calls reached the inner fetcher: the failure was not memoised.
        assert_eq!(fetcher.inner().call_count(), 2);
        assert!(fetcher.is_empty());
    }

    #[test]
    fn test_caching_fetcher_caches_redirect_target() {
        let inner = RedirectingFetcher {
            target: "http://example.com/final.xsd".to_string(),
            calls: Mutex::new(Vec::new()),
        };
        let fetcher = CachingFetcher::new(inner);
        let r1 = fetcher.fetch("http://example.com/start.xsd").unwrap();
        assert_eq!(r1.final_url, "http://example.com/final.xsd");
        assert!(r1.redirected);
        // Both the requested URL and the redirect target are now cached.
        assert_eq!(fetcher.len(), 2);
        let r2 = fetcher.fetch("http://example.com/final.xsd").unwrap();
        assert_eq!(r2.content, b"<redirected/>");
        assert_eq!(
            *fetcher.inner().calls.lock().unwrap(),
            vec!["http://example.com/start.xsd".to_string()]
        );
    }

    #[test]
    fn test_caching_fetcher_seed() {
        let fetcher = CachingFetcher::new(NoopFetcher);
        fetcher.seed("http://example.com/test.xsd", b"<seeded/>".to_vec());
        let result = fetcher.fetch("http://example.com/test.xsd").unwrap();
        assert_eq!(result.content, b"<seeded/>");
        assert_eq!(fetcher.len(), 1);
    }

    #[test]
    fn test_caching_fetcher_len_is_empty() {
        let fetcher = CachingFetcher::new(NoopFetcher);
        assert!(fetcher.is_empty());
        assert_eq!(fetcher.len(), 0);
        fetcher.seed("http://example.com/a.xsd", b"a".to_vec());
        assert!(!fetcher.is_empty());
        assert_eq!(fetcher.len(), 1);
    }
}