use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::mpsc;
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use either::Either;
use crate::arc_erase::ArcEraseDyn;
use crate::storage::data::DataKey;
use crate::storage::data::PagableData;
use crate::storage::support::SerializerForPaging;
use crate::storage::traits::PagableStorage;
use crate::traits::SessionContext;
/// Pagable storage that keeps serialized data in a sled database and tracks
/// arcs that are waiting to be paged out.
pub struct SledBackedPagableStorage {
/// Sending half of the page-out channel; `schedule_for_paging` pushes arcs
/// through it.
sender: mpsc::Sender<Box<dyn ArcEraseDyn>>,
/// sled database holding the encoded `PagableData` entries and raw values.
db: sled::Db,
/// Arcs registered through `on_arc_deserialized`, keyed by their type and
/// data key; consulted by `fetch_arc_or_data_blocking` before reading sled.
arcs: DashMap<(TypeId, DataKey), Box<dyn ArcEraseDyn>>,
/// Receiving half of the page-out channel plus the arcs already drained from
/// it, guarded by a mutex.
pending: Mutex<SledPendingPageOut>,
/// Session context handed to the paging serializer and exposed through
/// `session_context()`.
session_context: SessionContext,
}
/// Mutex-protected state of the page-out queue.
struct SledPendingPageOut {
/// Receiving half of the channel fed by `schedule_for_paging`.
pending_messages: mpsc::Receiver<Box<dyn ArcEraseDyn>>,
/// Arcs already taken off the channel that have not been processed yet;
/// `page_out_pending` pops from the back of this list.
pending: Vec<Box<dyn ArcEraseDyn>>,
}
impl SledBackedPagableStorage {
pub fn try_new(path: &std::path::Path) -> anyhow::Result<Self> {
let db = sled::Config::new()
.cache_capacity(1024 * 1024 * 2) .path(path)
.open()?;
let (sender, receiver) = mpsc::channel();
Ok(Self {
sender,
db,
arcs: DashMap::new(),
pending: Mutex::new(SledPendingPageOut {
pending_messages: receiver,
pending: Vec::new(),
}),
session_context: SessionContext::new(),
})
}
pub fn pending_paging_count(&self) -> usize {
let mut lock = self.pending.lock().expect("lock poisoned");
while let Ok(v) = lock.pending_messages.try_recv() {
lock.pending.push(v);
}
lock.pending.len()
}
fn serialize_arcs(
&self,
roots: Vec<Box<dyn ArcEraseDyn>>,
finished: &mut HashMap<usize, DataKey>,
session_context: &SessionContext,
) -> anyhow::Result<()> {
enum Task {
Start(Box<dyn ArcEraseDyn>),
Finish((Box<dyn ArcEraseDyn>, Vec<u8>, Vec<Box<dyn ArcEraseDyn>>)),
}
let mut tasks: Vec<Task> = roots.into_iter().map(Task::Start).collect();
while let Some(task) = tasks.pop() {
match task {
Task::Start(v) => {
if finished.contains_key(&v.identity()) {
continue;
}
let mut serializer = SerializerForPaging::new(session_context);
v.serialize(&mut serializer)?;
let (data, arcs) = serializer.finish();
let subtasks: Vec<_> = arcs
.iter()
.filter(|arc| !finished.contains_key(&arc.identity()))
.map(|arc| Task::Start(arc.clone_dyn()))
.collect();
tasks.push(Task::Finish((v, data, arcs)));
tasks.extend(subtasks);
}
Task::Finish((arc, data, serialized_arcs)) => {
let arcs = serialized_arcs
.iter()
.map(|arc| {
*finished
.get(&arc.identity())
.expect("nested arc should have been serialized first")
})
.collect();
let key = self.store_data(PagableData { data, arcs })?;
finished.insert(arc.identity(), key);
arc.set_data_key(key);
}
}
}
Ok(())
}
pub fn page_out_pending(&self) -> anyhow::Result<()> {
loop {
let item = {
let mut lock = self.pending.lock().expect("lock poisoned");
while let Ok(v) = lock.pending_messages.try_recv() {
lock.pending.push(v);
}
lock.pending.pop()
};
match item {
Some(v) if v.needs_paging_out() => {
let mut finished: HashMap<usize, DataKey> = HashMap::new();
self.serialize_arcs(vec![v], &mut finished, &self.session_context)?;
}
Some(_) => continue,
None => break,
}
}
Ok(())
}
pub fn write_bytes<T: bytemuck::Pod>(&self, key: &str, data: T) {
self.db
.insert(key.as_bytes(), sled::IVec::from(bytemuck::bytes_of(&data)))
.expect("sled insert failed");
}
pub fn flush(&self) {
self.db.flush().expect("sled flush failed");
}
pub fn fetch_bytes_blocking<T: bytemuck::Pod>(&self, key: &str) -> anyhow::Result<T> {
let bytes = self
.db
.get(key.as_bytes())?
.ok_or_else(|| anyhow::anyhow!("no data for key {:?}", key))?;
Ok(bytemuck::pod_read_unaligned(&bytes))
}
pub fn fetch_data_blocking(&self, key: &DataKey) -> anyhow::Result<Arc<PagableData>> {
let bytes = self
.db
.get(bytemuck::bytes_of(key))?
.ok_or_else(|| anyhow::anyhow!("no data for key {:?}", key))?;
Self::decode_pagable_data(&bytes, key)
}
fn decode_pagable_data(bytes: &[u8], key: &DataKey) -> anyhow::Result<Arc<PagableData>> {
if bytes.len() < 16 {
return Err(anyhow::anyhow!(
"corrupt sled entry for key {:?}: too short for header",
key
));
}
let data_len = u64::from_le_bytes(bytes[..8].try_into()?) as usize;
let arcs_len = u64::from_le_bytes(bytes[8..16].try_into()?) as usize;
let expected_len = arcs_len
.checked_mul(16)
.and_then(|v| v.checked_add(data_len))
.and_then(|v| v.checked_add(16))
.ok_or_else(|| {
anyhow::anyhow!(
"corrupt sled entry for key {:?}: length overflow (data_len={}, arcs_len={})",
key,
data_len,
arcs_len
)
})?;
if bytes.len() < expected_len {
return Err(anyhow::anyhow!(
"corrupt sled entry for key {:?}: expected {} bytes, got {}",
key,
expected_len,
bytes.len()
));
}
let data = bytes[16..16 + data_len].to_vec();
let arcs = (0..arcs_len)
.map(|i| {
let offset = 16 + data_len + i * 16;
bytemuck::pod_read_unaligned(&bytes[offset..offset + 16])
})
.collect();
Ok(Arc::new(PagableData { data, arcs }))
}
}
#[async_trait::async_trait]
impl PagableStorage for SledBackedPagableStorage {
    /// Returns the arc registered for `(type_id, key)` when there is one,
    /// otherwise reads and decodes the stored data for `key` from sled.
    fn fetch_arc_or_data_blocking(
        &self,
        type_id: &TypeId,
        key: &DataKey,
    ) -> anyhow::Result<Either<Box<dyn ArcEraseDyn>, Arc<PagableData>>> {
        let cached = self
            .arcs
            .get(&(*type_id, *key))
            .map(|entry| entry.clone_dyn());
        match cached {
            Some(arc) => Ok(Either::Left(arc)),
            None => self.fetch_data_blocking(key).map(Either::Right),
        }
    }

    /// Reads and decodes the data for `key` on tokio's blocking thread pool.
    ///
    /// Fails if sled fails, if no entry exists, if the entry cannot be
    /// decoded, or if the blocking task cannot be joined.
    #[cfg(any(feature = "tokio", test))]
    async fn fetch_data(&self, key: &DataKey) -> anyhow::Result<Arc<PagableData>> {
        // The closure must own everything it touches, so copy the key and
        // clone the database handle before moving them in.
        let owned_key = *key;
        let db_handle = self.db.clone();
        let lookup = move || -> anyhow::Result<Arc<PagableData>> {
            match db_handle.get(bytemuck::bytes_of(&owned_key))? {
                Some(raw) => Self::decode_pagable_data(&raw, &owned_key),
                None => Err(anyhow::anyhow!("no data for key {:?}", owned_key)),
            }
        };
        tokio::task::spawn_blocking(lookup).await?
    }

    /// Fallback used when the `tokio` feature is disabled: always fails.
    #[cfg(not(any(feature = "tokio", test)))]
    async fn fetch_data(&self, _key: &DataKey) -> anyhow::Result<Arc<PagableData>> {
        Err(anyhow::anyhow!("sled backend requires tokio feature"))
    }

    /// Registers `arc` under `(typeid, key)` unless an arc is already
    /// registered there.
    ///
    /// Returns `None` when `arc` was inserted, or a clone of the existing arc
    /// when the slot was already taken (in which case `arc` is dropped).
    fn on_arc_deserialized(
        &self,
        typeid: TypeId,
        key: DataKey,
        arc: Box<dyn ArcEraseDyn>,
    ) -> Option<Box<dyn ArcEraseDyn>> {
        match self.arcs.entry((typeid, key)) {
            Entry::Vacant(slot) => {
                slot.insert(arc);
                None
            }
            Entry::Occupied(existing) => Some(existing.get().clone_dyn()),
        }
    }

    /// Queues `arc` for a later `page_out_pending` pass; a failed send is
    /// deliberately ignored.
    fn schedule_for_paging(&self, arc: Box<dyn ArcEraseDyn>) {
        let _ = self.sender.send(arc);
    }

    /// Returns the session context owned by this storage.
    fn session_context(&self) -> &SessionContext {
        &self.session_context
    }

    /// Stores `data` under the key returned by `compute_key` and returns that
    /// key. If sled already has an entry for the key, nothing is written.
    ///
    /// Encoding: 8-byte little-endian payload length, 8-byte little-endian
    /// key count, the payload bytes, then the nested keys' raw bytes.
    fn store_data(&self, data: PagableData) -> anyhow::Result<DataKey> {
        let key = data.compute_key();
        let db_key = bytemuck::bytes_of(&key);
        if !self.db.contains_key(db_key)? {
            let payload_len = data.data.len();
            let arc_count = data.arcs.len();
            let total = 8 + 8 + payload_len + arc_count * 16;
            let mut encoded = Vec::with_capacity(total);
            encoded.extend_from_slice(&(payload_len as u64).to_le_bytes());
            encoded.extend_from_slice(&(arc_count as u64).to_le_bytes());
            encoded.extend_from_slice(&data.data);
            encoded.extend_from_slice(bytemuck::cast_slice(&data.arcs));
            // Guards the assumption that each nested key occupies 16 bytes.
            assert_eq!(encoded.len(), total);
            self.db.insert(db_key, sled::IVec::from(encoded))?;
        }
        Ok(key)
    }
}