use std::pin::Pin;
use std::sync::Arc;
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::io::AsyncRead;
use crate::service::MockService;
use crate::vfs::VfsError;
/// Error type returned by the fallible state-view operations of
/// [`StatefulService`] and [`BlobBackedService`].
#[derive(Debug)]
pub enum StateViewError {
    /// A `serde_json` (de)serialization failure.
    Serialize(serde_json::Error),
    /// The supplied state was rejected; the string describes why.
    Invalid(String),
    /// A failure reported by the blob / VFS layer.
    Blob(VfsError),
}
impl std::fmt::Display for StateViewError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StateViewError::Serialize(e) => write!(f, "serialization error: {e}"),
StateViewError::Invalid(msg) => write!(f, "invalid state: {msg}"),
StateViewError::Blob(e) => write!(f, "blob error: {e}"),
}
}
}
impl std::error::Error for StateViewError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
StateViewError::Serialize(e) => Some(e),
StateViewError::Invalid(_) => None,
StateViewError::Blob(e) => Some(e),
}
}
}
impl From<VfsError> for StateViewError {
fn from(e: VfsError) -> Self {
StateViewError::Blob(e)
}
}
impl From<serde_json::Error> for StateViewError {
fn from(e: serde_json::Error) -> Self {
StateViewError::Serialize(e)
}
}
/// Callback invoked as `listener(account_id, region, view)` on each notification.
///
/// Listeners are reference-counted so [`StateChangeNotifier::notify`] can take a
/// cheap snapshot of the list and run callbacks without holding its lock.
pub type StateChangeListener<V> = Arc<dyn Fn(&str, &str, &V) + Send + Sync>;

/// Thread-safe registry of [`StateChangeListener`]s for a state view type `V`.
///
/// Listeners are invoked in subscription order. There is no unsubscribe: a
/// listener stays registered for the lifetime of the notifier.
pub struct StateChangeNotifier<V> {
    /// Registered callbacks, in subscription order.
    listeners: std::sync::RwLock<Vec<StateChangeListener<V>>>,
}
impl<V: Send + Sync> StateChangeNotifier<V> {
    /// Creates a notifier with no listeners.
    pub fn new() -> Self {
        Self {
            listeners: std::sync::RwLock::new(Vec::new()),
        }
    }

    /// Registers `f` to be called on every subsequent [`notify`](Self::notify).
    ///
    /// This may be called from inside a listener. A listener added that way
    /// is not invoked by the notification that is currently running.
    pub fn subscribe(&self, f: impl Fn(&str, &str, &V) + Send + Sync + 'static) {
        self.listeners
            .write()
            // A panic while holding this lock can only come from `push`, which
            // leaves the Vec intact, so a poisoned list is still usable.
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push(Arc::new(f));
    }

    /// Calls every registered listener with `account_id`, `region` and `view`.
    ///
    /// The listener list is snapshotted (`Arc` clones) and the lock is released
    /// before any callback runs. Listeners may therefore call
    /// [`subscribe`](Self::subscribe) or `notify` on this notifier without
    /// deadlocking.
    ///
    /// # Panics
    ///
    /// A panic raised by a listener propagates to the caller, and listeners
    /// after it are not called for this notification.
    pub fn notify(&self, account_id: &str, region: &str, view: &V) {
        // The read guard is a temporary dropped at the end of this statement,
        // before any user callback executes.
        let snapshot: Vec<StateChangeListener<V>> = self
            .listeners
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        for listener in &snapshot {
            listener(account_id, region, view);
        }
    }
}
impl<V: Send + Sync> Default for StateChangeNotifier<V> {
    /// Equivalent to [`StateChangeNotifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A [`MockService`] whose state, keyed by account id and region, can be
/// exported as a serializable view, applied back via `restore` / `merge`, and
/// observed through a [`StateChangeNotifier`].
// `async fn` in a public trait leaves the `Send`-ness of the returned futures
// unspecified, which rustc reports via the `async_fn_in_trait` lint; it is
// allowed here deliberately. NOTE(review): if any caller must spawn these
// futures on a multi-threaded executor, switch to explicit
// `-> impl Future<Output = ..> + Send` signatures.
#[allow(async_fn_in_trait)]
pub trait StatefulService: MockService {
    /// Serializable representation of the state for one account/region.
    type StateView: Serialize + DeserializeOwned + Send + Sync;

    /// Captures the current state for `account_id` / `region` as a view.
    async fn snapshot(&self, account_id: &str, region: &str) -> Self::StateView;

    /// Applies `view` to `account_id` / `region`.
    /// NOTE(review): presumably replaces existing state (versus `merge`) —
    /// the exact semantics are implementor-defined.
    async fn restore(
        &self,
        account_id: &str,
        region: &str,
        view: Self::StateView,
    ) -> Result<(), StateViewError>;

    /// Combines `view` into the existing state for `account_id` / `region`.
    /// NOTE(review): merge semantics are implementor-defined.
    async fn merge(
        &self,
        account_id: &str,
        region: &str,
        view: Self::StateView,
    ) -> Result<(), StateViewError>;

    /// Returns the notifier used by [`notify_state_changed`](Self::notify_state_changed).
    fn notifier(&self) -> &StateChangeNotifier<Self::StateView>;

    /// Takes a fresh snapshot and passes it to every listener registered on
    /// [`notifier`](Self::notifier).
    ///
    /// The snapshot is taken even when no listeners are registered.
    async fn notify_state_changed(&self, account_id: &str, region: &str) {
        let current = self.snapshot(account_id, region).await;
        let notifier = self.notifier();
        notifier.notify(account_id, region, &current);
    }
}
/// Default number of entries per blob batch.
///
/// NOTE(review): presumably the batch size used when grouping
/// [`BlobExportEntry`] values for [`BlobVisitor::visit`] — confirm at call sites.
pub const DEFAULT_BLOB_BATCH_SIZE: usize = 64;
/// One exported blob, delivered to a [`BlobVisitor`] as part of a batch.
pub struct BlobExportEntry {
    /// Key identifying the blob.
    pub key: String,
    /// Async reader over the blob's contents.
    pub reader: Box<dyn AsyncRead + Send + Unpin>,
    /// Size of the blob when known up front (presumably in bytes), or `None`.
    pub size: Option<u64>,
}
/// Sink for exported blobs. It is passed as `&mut dyn BlobVisitor` to
/// [`BlobBackedService::snapshot_with_blobs`].
///
/// `visit` returns a boxed future rather than being an `async fn` so the trait
/// stays dyn-compatible and can be used as a trait object.
pub trait BlobVisitor: Send {
    /// Consumes one batch of entries; the returned future borrows `self`.
    fn visit(
        &mut self,
        batch: Vec<BlobExportEntry>,
    ) -> Pin<Box<dyn Future<Output = Result<(), VfsError>> + Send + '_>>;
}
/// Provider of blob contents. It is passed as `&mut dyn BlobSource` to
/// [`BlobBackedService::restore_with_blobs`].
///
/// Like [`BlobVisitor`], it uses a boxed future so it can be used as a trait object.
pub trait BlobSource: Send {
    /// Resolves `key` to an async reader over that blob's contents; the returned
    /// future borrows `self`.
    fn fetch(
        &mut self,
        key: String,
    ) -> Pin<
        Box<dyn Future<Output = Result<Box<dyn AsyncRead + Send + Unpin>, VfsError>> + Send + '_>,
    >;
}
/// A [`StatefulService`] whose state includes blob data that is streamed
/// separately from the serializable view.
// Allowed for the same reason as on `StatefulService`: the `Send`-ness of
// these `async fn` futures is intentionally left unspecified.
#[allow(async_fn_in_trait)]
pub trait BlobBackedService: StatefulService {
    /// Produces the state view for `account_id` / `region`, handing blob data
    /// to `visitor`.
    /// NOTE(review): batching and ordering of `visit` calls are implementor-defined.
    async fn snapshot_with_blobs(
        &self,
        account_id: &str,
        region: &str,
        visitor: &mut dyn BlobVisitor,
    ) -> Result<Self::StateView, StateViewError>;

    /// Restores `view` for `account_id` / `region`, pulling blob contents from
    /// `source` as needed.
    async fn restore_with_blobs(
        &self,
        account_id: &str,
        region: &str,
        view: Self::StateView,
        source: &mut dyn BlobSource,
    ) -> Result<(), StateViewError>;
}