use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::channels::ComponentStatus;
use crate::queries::output_state::{FetchError, SnapshotStream};
use crate::reactions::QueryProvider;
/// Abstraction over how a snapshot of a query's results is obtained.
///
/// Implementors must be `Send + Sync` so a fetcher can be shared across tasks.
#[async_trait]
pub trait SnapshotFetcher: Send + Sync {
    /// Fetches a snapshot of the query identified by `query_id`, returned as a
    /// [`SnapshotStream`].
    ///
    /// # Errors
    ///
    /// Returns a [`FetchError`] when the snapshot cannot be obtained; which
    /// conditions produce which variant is up to the implementation.
    async fn fetch_snapshot(&self, query_id: &str) -> Result<SnapshotStream, FetchError>;
}
/// [`SnapshotFetcher`] that reads snapshots from a query provider living in the
/// same process, restricted to an explicit allow-list of query ids.
pub(crate) struct InProcessSnapshotFetcher {
    /// Shared, swappable slot for the provider; `None` means no provider is
    /// currently available.
    query_provider: Arc<RwLock<Option<Arc<dyn QueryProvider>>>>,
    /// Query ids this fetcher is permitted to serve.
    allowed_queries: HashSet<String>,
}

impl InProcessSnapshotFetcher {
    /// Creates a fetcher backed by `query_provider` that will only serve the
    /// queries named in `allowed_queries`.
    ///
    /// Duplicate ids in `allowed_queries` collapse into a single entry.
    pub fn new(
        query_provider: Arc<RwLock<Option<Arc<dyn QueryProvider>>>>,
        allowed_queries: Vec<String>,
    ) -> Self {
        // A set gives O(1) membership checks on every fetch.
        let allow_list: HashSet<String> = allowed_queries.into_iter().collect();
        Self {
            allowed_queries: allow_list,
            query_provider,
        }
    }
}
#[async_trait]
impl SnapshotFetcher for InProcessSnapshotFetcher {
    /// Fetches a snapshot of `query_id` from the in-process query provider.
    ///
    /// # Errors
    ///
    /// Returns [`FetchError::NotRunning`] with [`ComponentStatus::Error`] when:
    /// - `query_id` is not in this fetcher's allow-list,
    /// - no provider is currently set, or
    /// - the provider fails to resolve the query instance.
    ///
    /// Errors returned by the query's own `fetch_snapshot` are propagated via `?`.
    async fn fetch_snapshot(&self, query_id: &str) -> Result<SnapshotStream, FetchError> {
        if !self.allowed_queries.contains(query_id) {
            return Err(FetchError::NotRunning {
                status: ComponentStatus::Error,
            });
        }

        // Clone the provider `Arc` out of the lock in a single statement so the
        // read guard (a temporary here) is dropped before any further `.await`.
        // Holding it across the lookup and snapshot fetch below would make a
        // writer wait for the whole fetch. Because tokio's RwLock is fair, every
        // later reader would then also queue behind that writer.
        let provider = self
            .query_provider
            .read()
            .await
            .as_ref()
            .cloned()
            .ok_or(FetchError::NotRunning {
                status: ComponentStatus::Error,
            })?;

        let query = provider
            .get_query_instance(query_id)
            .await
            // The provider's error detail is discarded; any failure to resolve
            // the instance is reported as "not running".
            .map_err(|_| FetchError::NotRunning {
                status: ComponentStatus::Error,
            })?;

        let snapshot = query.fetch_snapshot().await?;
        Ok(SnapshotStream::from_snapshot(snapshot))
    }
}