use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::OnceCell;
use super::DependencyError;
/// Type-erased asynchronous factory for a `T`.
///
/// A `Send + Sync` callable taking no arguments; every call returns a boxed, pinned, `Send`
/// future that resolves to `Result<T, DependencyError>`.
type AsyncFactory<T> =
dyn Fn() -> Pin<Box<dyn Future<Output = Result<T, DependencyError>> + Send>> + Send + Sync;
/// Handle to a value of type `T` that is produced on demand by an asynchronous factory.
///
/// Both fields sit behind an [`Arc`], so the cell and the factory can be shared between
/// several handles without copying `T` or the factory itself.
pub struct LazyProvider<T: Clone + Send + Sync + 'static> {
/// Shared cell that holds the value once a factory call has succeeded.
cell: Arc<OnceCell<T>>,
/// Shared, type-erased factory used to fill `cell`.
factory: Arc<AsyncFactory<T>>,
}
impl<T: Clone + Send + Sync + 'static> Clone for LazyProvider<T> {
    /// Returns a new handle that points at the same cell and the same factory.
    ///
    /// Only the two reference counts are bumped; neither `T` nor the factory is copied.
    fn clone(&self) -> Self {
        let Self { cell, factory } = self;
        Self {
            cell: Arc::clone(cell),
            factory: Arc::clone(factory),
        }
    }
}
impl<T: Clone + Send + Sync + 'static> LazyProvider<T> {
pub fn new<F, Fut>(factory: F) -> Self
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<T, DependencyError>> + Send + 'static,
{
Self {
cell: Arc::new(OnceCell::new()),
factory: Arc::new(move || Box::pin(factory())),
}
}
pub async fn get(&self) -> Result<T, DependencyError> {
let factory = &self.factory;
self.cell
.get_or_try_init(|| factory())
.await
.cloned()
}
}
/// Type-erased view of a registered lazy dependency.
///
/// Erasing the value type lets entries for different `T` live in a single collection.
#[async_trait::async_trait]
trait LazyInit: Send + Sync {
    /// Name the dependency was registered under; used to label failures.
    fn name(&self) -> &str;

    /// Forces the dependency to initialise, discarding the produced value.
    async fn init(&self) -> Result<(), DependencyError>;
}

/// A named [`LazyProvider`] as stored inside a [`LazyContainer`].
struct LazyEntry<T: Clone + Send + Sync + 'static> {
    /// Registration name, reported when the entry's warm-up task fails to complete.
    name: String,
    /// Provider that performs the actual lazy initialisation.
    provider: LazyProvider<T>,
}

#[async_trait::async_trait]
impl<T: Clone + Send + Sync + 'static> LazyInit for LazyEntry<T> {
    fn name(&self) -> &str {
        &self.name
    }

    async fn init(&self) -> Result<(), DependencyError> {
        // Only the side effect of filling the cell matters; the cloned value is dropped.
        self.provider.get().await?;
        Ok(())
    }
}

/// Collection of lazily initialised dependencies that can be warmed up together.
#[derive(Default)]
pub struct LazyContainer {
    /// Registered entries, in registration order.
    entries: Vec<Arc<dyn LazyInit>>,
}

impl LazyContainer {
    /// Creates a container with no registered dependencies.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a lazy dependency called `name` whose value is produced by `factory`.
    ///
    /// The factory is not run here; it is invoked when the entry is initialised, for
    /// example by [`LazyContainer::warm_up`].
    pub fn register_lazy<T, F, Fut>(&mut self, name: &str, factory: F)
    where
        T: Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<T, DependencyError>> + Send + 'static,
    {
        self.entries.push(Arc::new(LazyEntry {
            name: name.to_owned(),
            provider: LazyProvider::new(factory),
        }));
    }

    /// Initialises every registered dependency.
    ///
    /// One Tokio task is spawned per entry so the initialisations run concurrently; the
    /// tasks are then awaited in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first failure in registration order. This is either the
    /// `DependencyError` produced by the entry's factory, or
    /// `DependencyError::InitializationFailed` carrying the entry's registration name when
    /// its task panicked or was cancelled. On such an early return the remaining join
    /// handles are dropped, which detaches their tasks rather than cancelling them.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    pub async fn warm_up(&self) -> Result<(), DependencyError> {
        // Collect eagerly so that every task is spawned before the first one is awaited.
        let tasks: Vec<_> = self
            .entries
            .iter()
            .map(|entry| {
                let task_entry = Arc::clone(entry);
                let handle = tokio::spawn(async move { task_entry.init().await });
                (entry, handle)
            })
            .collect();

        for (entry, handle) in tasks {
            handle
                .await
                // Attribute a panic or cancellation to the dependency that caused it.
                .map_err(|join_error| DependencyError::InitializationFailed {
                    name: entry.name().to_owned(),
                    source: Box::new(join_error),
                })??;
        }
        Ok(())
    }
}