use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use progscrape_scrapers::{ScrapeId, StoryDate, TypedScrape};
use serde::{Deserialize, Serialize};
use tracing::{error, info};
use crate::{PersistError, story::StoryScrapeId};
use super::{PersistLocation, db::DB, shard::Shard};
/// Version number that [`ScrapeStore::stats`] reports in the `version` field
/// of the [`ScrapeStoreStats`] it returns.
pub const SCRAPE_STORE_VERSION: usize = 1;
/// A store of scrapes that is split across one database per [`Shard`].
///
/// Shard databases are opened on demand and the open handles are kept in
/// `shards` for reuse.
pub struct ScrapeStore {
// Base location; a shard's own location is derived by joining the shard's
// string form onto it.
location: PersistLocation,
// Handles for the shard databases that have been opened so far.
shards: RwLock<HashMap<Shard, Arc<DB>>>,
}
/// Summary of a single shard, as produced by [`ScrapeStore::stats`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ScrapeStoreStats {
/// Store version; `stats` fills this with `SCRAPE_STORE_VERSION`. Falls back
/// to the type's default when the field is absent during deserialization.
#[serde(default)]
pub version: usize,
/// Minimum `date` over the shard's rows (the query coalesces to 0 when there
/// are none).
pub earliest: StoryDate,
/// Maximum `date` over the shard's rows (the query coalesces to 0 when there
/// are none).
pub latest: StoryDate,
/// Number of rows in the shard.
pub count: usize,
}
/// Row type persisted in each shard database; one row per scrape.
#[derive(Default, Serialize, Deserialize)]
struct ScrapeCacheEntry {
// Date of the scrape; `fetch_all` orders by this column first.
date: StoryDate,
// String form of the scrape's id; a unique index (`idx_id`) is created on it.
id: String,
// The scrape serialized with `serde_json`.
json: String,
}
impl ScrapeStore {
/// Creates a store rooted at `location`.
///
/// No shard database is opened here: the shard map starts out empty and is
/// filled lazily as shards are requested.
///
/// # Errors
///
/// This implementation always returns `Ok`; the `Result` is part of the
/// signature callers rely on.
pub fn new(location: PersistLocation) -> Result<Self, PersistError> {
    let shards = RwLock::new(HashMap::new());
    tracing::info!("Initialized ScrapeStore at {:?}", location);
    let store = Self { location, shards };
    Ok(store)
}
/// Returns the database handle for `shard`, opening and caching it on first use.
///
/// A shard that is already cached is served under the shared read lock. Only
/// when the shard has not been opened yet is the write lock taken, the
/// database opened, and the `ScrapeCacheEntry` table plus its unique `idx_id`
/// index created. The handle is cached only after that setup succeeded, so a
/// failed open is retried from scratch on the next call.
///
/// # Errors
///
/// Returns an error if the shard directory cannot be created, the database
/// cannot be opened, or any of the setup statements fail.
///
/// # Panics
///
/// Panics if the shard map lock is poisoned.
fn open_shard(&self, shard: Shard) -> Result<Arc<DB>, PersistError> {
    // Fast path: cached shards only need the shared lock. The read guard is
    // a temporary of this statement and is released before the write lock
    // below is requested.
    if let Some(db) = self.shards.read().expect("Poisoned").get(&shard) {
        return Ok(Arc::clone(db));
    }

    let mut lock = self.shards.write().expect("Poisoned");
    // Another thread may have opened the shard between the two lock
    // acquisitions, so check again before opening a second handle.
    if let Some(db) = lock.get(&shard) {
        return Ok(Arc::clone(db));
    }

    let db = match self.location.join(shard.to_string()) {
        PersistLocation::Memory => DB::open(":memory:")?,
        PersistLocation::Path(ref path) => {
            std::fs::create_dir_all(path)?;
            let path = path.join("scrapes.sqlite3");
            tracing::info!("Opening scrape database at {}", path.to_string_lossy());
            let db = DB::open(path)?;
            db.execute_raw("PRAGMA journal_mode = WAL")?;
            db
        }
    };

    // Schema setup is needed once per opened database, not once per lookup.
    db.create_table::<ScrapeCacheEntry>()?;
    db.create_unique_index::<ScrapeCacheEntry>("idx_id", &["id"])?;

    let db = Arc::new(db);
    lock.insert(shard, Arc::clone(&db));
    Ok(db)
}
/// Runs `PRAGMA integrity_check` against the database for `shard` and logs
/// the outcome.
///
/// # Errors
///
/// Returns an error if the shard cannot be opened or if executing the pragma
/// returns an error (which is also logged).
pub fn validate_shard(&self, shard: Shard) -> Result<(), PersistError> {
    let db = self.open_shard(shard)?;
    info!("Validating DB shard {shard:?}");
    // NOTE(review): only the `Result` of `execute_raw` is inspected here.
    // Whether `execute_raw` surfaces problems that the pragma reports as
    // result rows is not visible from this file; confirm.
    match db.execute_raw("PRAGMA integrity_check") {
        Ok(_) => {
            info!("DB shard {shard:?} OK.");
            Ok(())
        }
        Err(e) => {
            error!("Failed to validate DB shard {shard:?}: {e:?}");
            Err(e)
        }
    }
}
/// Stores a single scrape by delegating to [`Self::insert_scrape_batch`]
/// with a one-element batch.
///
/// # Errors
///
/// Returns whatever error the batch insert returns.
pub fn insert_scrape(&self, scrape: &TypedScrape) -> Result<(), PersistError> {
    let single = [scrape];
    self.insert_scrape_batch(single)
}
pub fn insert_scrape_batch<'a, I: IntoIterator<Item = &'a TypedScrape>>(
&self,
iter: I,
) -> Result<(), PersistError> {
let mut per_shard: HashMap<Shard, Vec<&TypedScrape>> = HashMap::new();
for item in iter {
let shard = Shard::from_date_time(item.date);
per_shard.entry(shard).or_default().push(item);
}
for (shard, stories) in per_shard {
let db = self.open_shard(shard)?;
let mut batch = vec![];
for item in stories {
let json = serde_json::to_string(item)?;
batch.push(ScrapeCacheEntry {
date: item.date,
id: item.id.to_string(),
json,
});
}
db.store_batch(batch)?;
}
Ok(())
}
/// Looks up the scrape with the given `id` in `shard`.
///
/// Returns `Ok(None)` when the shard has no row for that id.
///
/// # Errors
///
/// Returns an error if the shard cannot be opened, the lookup fails, or the
/// stored JSON cannot be deserialized.
pub fn fetch_scrape(
    &self,
    shard: Shard,
    id: &ScrapeId,
) -> Result<Option<TypedScrape>, PersistError> {
    let db = self.open_shard(shard)?;
    match db.load::<ScrapeCacheEntry>(id.to_string())? {
        Some(entry) => {
            let typed_scrape = serde_json::from_str(&entry.json)?;
            Ok(Some(typed_scrape))
        }
        None => Ok(None),
    }
}
/// Looks up every scrape named by `iter` and returns the results keyed by
/// scrape id.
///
/// An id with no stored row maps to `None`. If the same scrape id is yielded
/// more than once, the result of the later lookup replaces the earlier one.
///
/// # Errors
///
/// Returns an error if a shard cannot be opened, a lookup fails, or a stored
/// row's JSON cannot be deserialized as a `TypedScrape`. The first error
/// aborts the whole batch.
pub fn fetch_scrape_batch<'a, I: IntoIterator<Item = StoryScrapeId>>(
    &self,
    iter: I,
) -> Result<HashMap<ScrapeId, Option<TypedScrape>>, PersistError> {
    let mut map = HashMap::new();
    for id in iter {
        let db = self.open_shard(id.shard)?;
        let scrape = match db.load::<ScrapeCacheEntry>(id.id.to_string())? {
            // Deserialize explicitly as `TypedScrape` (same as `fetch_scrape`)
            // rather than letting the target be inferred from the map's
            // `Option<TypedScrape>` value type, which would turn a stored
            // JSON `null` into a silent "not found".
            Some(entry) => Some(serde_json::from_str::<TypedScrape>(&entry.json)?),
            None => None,
        };
        map.insert(id.id.clone(), scrape);
    }
    Ok(map)
}
/// Streams every row of `shard`, ordered by `date` then `id`, through the
/// supplied callbacks.
///
/// * `f` receives each scrape whose stored JSON deserializes successfully;
///   an error it returns is propagated out of the per-row callback.
/// * `fe` receives the error for each row whose JSON fails to deserialize;
///   processing then continues with the next row.
///
/// # Errors
///
/// Returns an error if the shard cannot be opened or if the underlying query
/// call reports one.
pub fn fetch_all<F: FnMut(TypedScrape) -> Result<(), PersistError>, FE: FnMut(PersistError)>(
    &self,
    shard: Shard,
    mut f: F,
    mut fe: FE,
) -> Result<(), PersistError> {
    let db = self.open_shard(shard)?;
    let table = DB::table_for::<ScrapeCacheEntry>();
    let sql = format!("select * from {} order by date, id", table);
    db.query_raw_callback(&sql, |entry: ScrapeCacheEntry| {
        let typed_scrape = match serde_json::from_str(&entry.json) {
            Ok(parsed) => parsed,
            Err(e) => {
                // A row that fails to deserialize is reported and skipped.
                fe(e.into());
                return Ok(());
            }
        };
        f(typed_scrape)?;
        Ok(())
    })?;
    Ok(())
}
/// Computes summary statistics for `shard`: the store version, the row
/// count, and the earliest and latest `date` values (0 when the shard has no
/// rows).
///
/// # Errors
///
/// Returns an error if the shard cannot be opened or the query fails, and
/// `PersistError::UnexpectedError` if the query yields no row at all.
pub fn stats(&self, shard: Shard) -> Result<ScrapeStoreStats, PersistError> {
    let db = self.open_shard(shard)?;
    let sql = format!(
        "select {} version, count(*) count, coalesce(min(date), 0) as earliest, coalesce(max(date), 0) as latest from {}",
        SCRAPE_STORE_VERSION,
        DB::table_for::<ScrapeCacheEntry>()
    );
    let mut rows = db.query_raw::<ScrapeStoreStats>(&sql)?.into_iter();
    // Only the first row is used.
    match rows.next() {
        Some(stats) => Ok(stats),
        None => Err(PersistError::UnexpectedError(
            "Failed to fetch single row for query".into(),
        )),
    }
}
}
#[cfg(test)]
mod test {
    use progscrape_scrapers::ScrapeConfig;
    use rstest::rstest;
    use crate::test::enable_tracing;
    use super::*;

    /// Inserts the first 100 sample scrapes into an in-memory store one at a
    /// time, reads each of them back by id, and checks that the shard of the
    /// first sample goes from empty to non-empty.
    #[rstest]
    fn test_insert(_enable_tracing: &bool) -> Result<(), Box<dyn std::error::Error>> {
        let store = ScrapeStore::new(PersistLocation::Memory)?;
        let samples = progscrape_scrapers::load_sample_scrapes(&ScrapeConfig::default());
        let first = &samples[0..100];

        // Nothing has been written yet, so the probed shard must be empty.
        let before = store.stats(Shard::from_date_time(first[0].date))?;
        assert_eq!(before.count, 0);

        for scrape in first.iter() {
            store.insert_scrape(scrape)?;
        }

        // Every inserted scrape must be retrievable from its own shard.
        for scrape in first.iter() {
            let shard = Shard::from_date_time(scrape.date);
            let loaded_scrape = store.fetch_scrape(shard, &scrape.id)?.unwrap();
            assert_eq!(scrape.id, loaded_scrape.id);
        }

        let after = store.stats(Shard::from_date_time(first[0].date))?;
        assert!(after.count >= 1);
        Ok(())
    }
}