use std::fmt::Debug;
use std::marker::PhantomData;
use std::slice;
use futures::future::FutureExt;
use tokio::sync::mpsc;
use tracing::{span, Level};
use crate::{
batch_function::BatchFunction,
cache::Cache,
loader_op::{LoadRequest, LoaderOp},
};
pub struct LoaderWorker<K, V, F, CacheT, ContextT>
where
K: 'static + Eq + Debug + Ord + Copy + Send + Sync,
V: 'static + Send + Debug + Clone,
F: 'static + BatchFunction<K, V, Context = ContextT> + Send,
CacheT: Cache,
ContextT: Send + Sync + 'static,
{
cache: CacheT,
request_rx: mpsc::UnboundedReceiver<LoaderOp<K, V>>,
keys_to_load: Vec<K>,
pending_request: Vec<LoadRequest<K, V>>,
context: ContextT,
phantom_batch_function: PhantomData<F>,
debug_name: &'static str,
}
impl<K, V, F, CacheT, ContextT> LoaderWorker<K, V, F, CacheT, ContextT>
where
K: 'static + Eq + Debug + Copy + Ord + Send + Sync,
V: 'static + Send + Debug + Clone,
F: 'static + BatchFunction<K, V, Context = ContextT> + Send,
CacheT: Cache<K = K, V = V>,
ContextT: Send + Sync + 'static,
{
pub fn new(
cache: CacheT,
request_rx: mpsc::UnboundedReceiver<LoaderOp<K, V>>,
context: ContextT,
) -> Self {
Self {
cache,
request_rx,
keys_to_load: Vec::new(),
pending_request: Vec::new(),
context,
phantom_batch_function: PhantomData,
debug_name: std::any::type_name::<(K, V)>(),
}
}
pub async fn start(mut self) {
let span = span!(Level::TRACE, "LoaderWorker", kv = self.debug_name,);
let _enter = span.enter();
loop {
match self.request_rx.recv().await {
None => {
tracing::info!("Tx channel closed. Terminating LoaderWorker.");
return;
}
Some(op) => self.mux_op(op),
}
while let Some(Some(op)) = self.request_rx.recv().now_or_never() {
self.mux_op(op);
}
if !self.pending_request.is_empty() {
self.execute_load().await;
}
}
}
#[tracing::instrument(skip(self))]
fn mux_op(&mut self, op: LoaderOp<K, V>) {
match op {
LoaderOp::Load(request) => {
let cached = self.cache.get_key_vals(request.keys());
let keys_to_load = cached
.iter()
.filter_map(|(k, v)| if v.is_none() { Some(**k) } else { None })
.collect::<Vec<_>>();
tracing::debug!(requested_keys = ?request.keys(), ?keys_to_load);
if keys_to_load.is_empty() {
let values = cached.into_iter().map(|(_k, v)| v).collect::<Vec<_>>();
request.send_response(values);
} else {
self.keys_to_load.extend(&keys_to_load);
self.pending_request.push(request);
}
}
LoaderOp::Prime(key, value) => self.cache.insert(key, value),
LoaderOp::PrimeMany(key_vals) => self.cache.insert_many(key_vals),
LoaderOp::Clear(key) => self.cache.remove(slice::from_ref(&key)),
LoaderOp::ClearMany(keys) => self.cache.remove(&keys),
}
}
#[tracing::instrument(skip(self))]
async fn execute_load(&mut self) {
self.keys_to_load.sort();
self.keys_to_load.dedup();
let loaded_keyvals = F::load(&self.keys_to_load, &self.context).await;
tracing::debug!(?loaded_keyvals);
self.cache.insert_many(loaded_keyvals);
for request in self.pending_request.drain(..) {
let values = self.cache.get(request.keys());
request.send_response(values);
}
}
}