use crate::{entity::Entity, EntityId, Item, Property, RestApi, RestApiError};
use futures::prelude::*;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
/// Concurrency limit used by [`EntityContainerBuilder::build`] when no limit (or 0) was set.
const MAX_CONCURRENT_LOAD_DEFAULT: usize = 10;
/// Cache of items and properties fetched through a [`RestApi`].
///
/// Both caches live behind `Arc<RwLock<..>>`, so cloning the container is cheap and the
/// clone shares the same cached data.
#[derive(Debug, Clone)]
pub struct EntityContainer {
    /// API client used to fetch entities.
    api: Arc<RestApi>,
    /// Loaded items, keyed by the ID string each loaded item reports.
    items: Arc<RwLock<HashMap<String, Item>>>,
    /// Loaded properties, keyed by the ID string each loaded property reports.
    properties: Arc<RwLock<HashMap<String, Property>>>,
    /// Maximum number of in-flight requests per entity kind during `load`.
    max_concurrent_load: usize,
}
impl EntityContainer {
    /// Returns a builder for configuring and constructing an [`EntityContainer`].
    pub fn builder() -> EntityContainerBuilder {
        EntityContainerBuilder::default()
    }

    /// Fetches the given entities from the API and caches them in this container.
    ///
    /// Only item and property IDs are handled; any other kind of [`EntityId`] is ignored.
    /// IDs that are already cached are skipped, and an ID that appears more than once in
    /// `entity_ids` is requested only once. Items and properties are fetched concurrently,
    /// with at most `max_concurrent_load` requests in flight for each of the two kinds.
    ///
    /// Loading is best-effort: an entity whose request fails, or whose returned ID cannot
    /// be read, is silently skipped and not cached. Concurrent `load` calls may fetch the
    /// same entity twice; the later write replaces the earlier one.
    ///
    /// # Errors
    ///
    /// Propagates errors from the fetch helpers. Because those discard individual request
    /// failures, this currently always returns `Ok(())`.
    pub async fn load(&self, entity_ids: &[EntityId]) -> Result<(), RestApiError> {
        // Each read guard lives only inside its block, so no lock is held while the
        // network requests below are in flight.
        let item_ids = {
            let items = self.items.read().await;
            Self::get_items_to_load(&items, entity_ids)
        };
        let property_ids = {
            let properties = self.properties.read().await;
            Self::get_properties_to_load(&properties, entity_ids)
        };
        // Items and properties are independent, so fetch both kinds concurrently.
        let (loaded_items, loaded_properties) = tokio::join!(
            self.fetch_items(&item_ids),
            self.fetch_properties(&property_ids),
        );
        let loaded_items = loaded_items?;
        let loaded_properties = loaded_properties?;
        // Avoid taking a write lock when there is nothing to store.
        if !loaded_items.is_empty() {
            let mut items = self.items.write().await;
            for item in loaded_items {
                // Entities whose ID cannot be read are dropped.
                if let Ok(id) = item.id().id() {
                    items.insert(id.to_owned(), item);
                }
            }
        }
        if !loaded_properties.is_empty() {
            let mut properties = self.properties.write().await;
            for property in loaded_properties {
                // Entities whose ID cannot be read are dropped.
                if let Ok(id) = property.id().id() {
                    properties.insert(id.to_owned(), property);
                }
            }
        }
        Ok(())
    }

    /// Collects the IDs chosen by `select` that are not yet present in `cache`.
    ///
    /// The result keeps first-seen order and lists each ID at most once, so a duplicated
    /// ID in `entity_ids` does not trigger a duplicate request.
    fn ids_to_load<'a, V>(
        cache: &HashMap<String, V>,
        entity_ids: &'a [EntityId],
        select: impl Fn(&'a EntityId) -> Option<&'a str>,
    ) -> Vec<String> {
        let mut seen = std::collections::HashSet::new();
        entity_ids
            .iter()
            .filter_map(select)
            // `insert` returns false for an ID already emitted, which filters duplicates.
            .filter(|id| !cache.contains_key(*id) && seen.insert(*id))
            .map(str::to_owned)
            .collect()
    }

    /// Returns the distinct item IDs in `entity_ids` that are not cached in `items`.
    fn get_items_to_load(items: &HashMap<String, Item>, entity_ids: &[EntityId]) -> Vec<String> {
        Self::ids_to_load(items, entity_ids, |id| match id {
            EntityId::Item(id) => Some(id.as_str()),
            _ => None,
        })
    }

    /// Fetches the given items, running at most `max_concurrent_load` requests at once.
    ///
    /// Failed requests are dropped rather than reported, so the returned vector may be
    /// shorter than `item_ids` and is not necessarily in request order. This never
    /// returns `Err`; the `Result` lets callers use `?`.
    async fn fetch_items(&self, item_ids: &[String]) -> Result<Vec<Item>, RestApiError> {
        if item_ids.is_empty() {
            return Ok(Vec::new());
        }
        let futures = item_ids
            .iter()
            .map(|id| Item::get(EntityId::item(id), &self.api));
        let results: Vec<_> = futures::stream::iter(futures)
            .buffer_unordered(self.max_concurrent_load)
            .collect()
            .await;
        // `flatten` keeps only the successful fetches (best-effort loading).
        Ok(results.into_iter().flatten().collect())
    }

    /// Returns the distinct property IDs in `entity_ids` that are not cached in `properties`.
    fn get_properties_to_load(
        properties: &HashMap<String, Property>,
        entity_ids: &[EntityId],
    ) -> Vec<String> {
        Self::ids_to_load(properties, entity_ids, |id| match id {
            EntityId::Property(id) => Some(id.as_str()),
            _ => None,
        })
    }

    /// Fetches the given properties, running at most `max_concurrent_load` requests at once.
    ///
    /// Failed requests are dropped rather than reported, so the returned vector may be
    /// shorter than `property_ids` and is not necessarily in request order. This never
    /// returns `Err`; the `Result` lets callers use `?`.
    async fn fetch_properties(
        &self,
        property_ids: &[String],
    ) -> Result<Vec<Property>, RestApiError> {
        if property_ids.is_empty() {
            return Ok(Vec::new());
        }
        let futures = property_ids
            .iter()
            .map(|id| Property::get(EntityId::property(id), &self.api));
        let results: Vec<_> = futures::stream::iter(futures)
            .buffer_unordered(self.max_concurrent_load)
            .collect()
            .await;
        // `flatten` keeps only the successful fetches (best-effort loading).
        Ok(results.into_iter().flatten().collect())
    }

    /// Returns a shared handle to the item cache; changes made through it are visible to
    /// this container and to all of its clones.
    pub fn items(&self) -> Arc<RwLock<HashMap<String, Item>>> {
        self.items.clone()
    }

    /// Returns a shared handle to the property cache; changes made through it are visible
    /// to this container and to all of its clones.
    pub fn properties(&self) -> Arc<RwLock<HashMap<String, Property>>> {
        self.properties.clone()
    }
}
/// Builder for [`EntityContainer`].
///
/// An API client must be supplied with [`EntityContainerBuilder::api`]. The concurrency
/// limit is optional; when it is left unset or set to 0, `MAX_CONCURRENT_LOAD_DEFAULT`
/// is used.
#[derive(Debug, Default)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct EntityContainerBuilder {
    /// API client for all requests; `build` fails while this is `None`.
    api: Option<Arc<RestApi>>,
    /// Requested concurrency limit; 0 means "use `MAX_CONCURRENT_LOAD_DEFAULT`".
    max_concurrent_load: usize,
}

impl EntityContainerBuilder {
    /// Sets the API client the container will use to fetch entities.
    pub fn api(mut self, api: Arc<RestApi>) -> Self {
        self.api = Some(api);
        self
    }

    /// Sets the maximum number of in-flight requests per entity kind during
    /// [`EntityContainer::load`]. Passing 0 selects the default limit.
    pub const fn max_concurrent(mut self, max_concurrent_load: usize) -> Self {
        self.max_concurrent_load = max_concurrent_load;
        self
    }

    /// Builds an [`EntityContainer`] with empty item and property caches.
    ///
    /// # Errors
    ///
    /// Returns [`RestApiError::ApiNotSet`] if no API client was supplied.
    pub fn build(self) -> Result<EntityContainer, RestApiError> {
        let api = self.api.ok_or(RestApiError::ApiNotSet)?;
        // 0 is the "not configured" value (it is also what `Default` produces).
        let max_concurrent_load = match self.max_concurrent_load {
            0 => MAX_CONCURRENT_LOAD_DEFAULT,
            limit => limit,
        };
        Ok(EntityContainer {
            api,
            items: Arc::new(RwLock::new(HashMap::new())),
            properties: Arc::new(RwLock::new(HashMap::new())),
            max_concurrent_load,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RestApi;
    use serde_json::Value;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Parses the JSON file at `fixture` and serves it with status 200 for GET `route`.
    async fn serve_fixture(server: &MockServer, route: &str, fixture: &str) {
        let raw = std::fs::read_to_string(fixture).unwrap();
        let body: Value = serde_json::from_str(&raw).unwrap();
        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(ResponseTemplate::new(200).set_body_json(&body))
            .mount(server)
            .await;
    }

    /// API client for tests that only exercise the builder.
    fn offline_api() -> Arc<RestApi> {
        Arc::new(
            RestApi::builder("https://test.wikidata.org/w/rest.php")
                .unwrap()
                .build(),
        )
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_entity_container() {
        let server = MockServer::start().await;
        let fixtures = [
            ("/w/rest.php/wikibase/v1/entities/items/Q42", "test_data/Q42.json"),
            ("/w/rest.php/wikibase/v1/entities/items/Q255", "test_data/Q255.json"),
            (
                "/w/rest.php/wikibase/v1/entities/properties/P214",
                "test_data/P214.json",
            ),
        ];
        for (route, fixture) in fixtures {
            serve_fixture(&server, route, fixture).await;
        }

        let rest_api = RestApi::builder(&(server.uri() + "/w/rest.php"))
            .unwrap()
            .build();
        let container = EntityContainer::builder()
            .api(Arc::new(rest_api))
            .build()
            .unwrap();
        let wanted = [
            EntityId::item("Q42"),
            EntityId::property("P214"),
            EntityId::item("Q255"),
        ];
        container.load(&wanted).await.unwrap();

        {
            let item_cache = container.items();
            let cached_items = item_cache.read().await;
            assert!(cached_items.contains_key("Q42"));
            assert!(cached_items.contains_key("Q255"));
            assert!(!cached_items.contains_key("P214"));
        }
        {
            let property_cache = container.properties();
            let cached_properties = property_cache.read().await;
            assert!(cached_properties.contains_key("P214"));
            assert!(!cached_properties.contains_key("Q42"));
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_max_concurrent() {
        let container = EntityContainer::builder()
            .api(offline_api())
            .max_concurrent(5)
            .build()
            .unwrap();
        assert_eq!(container.max_concurrent_load, 5);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_max_concurrent_default() {
        let container = EntityContainer::builder()
            .api(offline_api())
            .max_concurrent(0)
            .build()
            .unwrap();
        assert_eq!(container.max_concurrent_load, MAX_CONCURRENT_LOAD_DEFAULT);
    }
}