use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
/// Outcome of a single [`ConnectivityProbe::check`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectivityStatus {
    /// Connectivity is available.
    Online,
    /// Connectivity is unavailable.
    Offline,
    /// Connectivity is impaired; `reason` carries the probe-supplied explanation.
    Degraded { reason: String },
}
/// Asynchronously reports the current [`ConnectivityStatus`].
///
/// Implementors must be both `Send` and `Sync`.
#[async_trait]
pub trait ConnectivityProbe: Sync + Send {
    /// Performs one connectivity check and returns its outcome.
    async fn check(&self) -> ConnectivityStatus;
}
/// What happened to an operation submitted to a circuit breaker.
#[derive(Debug)]
pub enum CallResult<T, E> {
    /// The operation ran immediately; this holds the result it produced.
    Executed(Result<T, E>),
    /// The operation was deferred; no result is available.
    Queued,
}

impl<T, E> CallResult<T, E> {
    /// Returns `true` only for [`CallResult::Queued`].
    pub fn is_queued(&self) -> bool {
        match self {
            CallResult::Executed(_) => false,
            CallResult::Queued => true,
        }
    }
}
/// A deferred, type-erased unit of work.
///
/// It is a one-shot closure that, when called, yields a pinned, boxed future
/// with `()` output. Both the closure and the future it returns are `Send`.
type BoxedFnOnce =
    Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
/// Circuit breaker that runs work immediately while its probe reports
/// [`ConnectivityStatus::Online`] and queues it otherwise.
pub struct OfflineCircuitBreaker<P: ConnectivityProbe> {
    /// Probe consulted before each call.
    probe: P,
    /// Deferred operations, appended in submission order.
    queue: Arc<Mutex<Vec<BoxedFnOnce>>>,
    /// Identifier supplied at construction.
    // Nothing in this file reads `name` yet; the allowance keeps the
    // field without a dead-code warning.
    #[allow(dead_code)]
    name: String,
}
impl<P: ConnectivityProbe> OfflineCircuitBreaker<P> {
    /// Creates a breaker named `name` that consults `probe` before each call.
    ///
    /// The deferred-operation queue starts empty.
    pub fn new(name: impl Into<String>, probe: P) -> Self {
        Self {
            name: name.into(),
            probe,
            queue: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Runs `f` now if the probe reports online, otherwise defers it.
    ///
    /// - [`ConnectivityStatus::Online`]: `f` is awaited immediately and its
    ///   result is returned as [`CallResult::Executed`].
    /// - [`ConnectivityStatus::Offline`] / [`ConnectivityStatus::Degraded`]:
    ///   `f` is not invoked. It is boxed and appended to the queue, and
    ///   [`CallResult::Queued`] is returned. When the deferred call later runs
    ///   in [`Self::drain`], its result (including any error) is discarded.
    pub async fn call<F, Fut, T, E>(&self, f: F) -> CallResult<T, E>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, E>> + Send + 'static,
        T: Send + 'static,
        E: Send + 'static,
    {
        match self.probe.check().await {
            ConnectivityStatus::Online => CallResult::Executed(f().await),
            // Listed explicitly instead of `_` so that adding a new status
            // variant forces a decision here at compile time rather than
            // silently being treated as "offline".
            ConnectivityStatus::Offline | ConnectivityStatus::Degraded { .. } => {
                let deferred: BoxedFnOnce = Box::new(move || {
                    Box::pin(async move {
                        // The caller already got `Queued`, so there is no one
                        // to hand this result to.
                        let _ = f().await;
                    })
                });
                self.queue.lock().await.push(deferred);
                CallResult::Queued
            }
        }
    }

    /// Number of operations currently waiting in the queue.
    pub async fn queued_count(&self) -> usize {
        self.queue.lock().await.len()
    }

    /// Runs every queued operation, oldest first.
    ///
    /// The queue is emptied while the lock is held, and the lock is released
    /// before any operation runs. Operations queued concurrently during the
    /// drain therefore stay in the queue for a later drain. The probe is not
    /// consulted here.
    ///
    /// # Errors
    ///
    /// None at present: this always returns `Ok(())`. Individual operation
    /// results are discarded by the queued wrappers, so failures are not
    /// reported.
    pub async fn drain(&self) -> Result<(), String> {
        let pending: Vec<BoxedFnOnce> = {
            let mut guard = self.queue.lock().await;
            std::mem::take(&mut *guard)
        };
        for op in pending {
            op().await;
        }
        Ok(())
    }
}
/// A record of an operation awaiting replay, identified by `id`.
///
/// Derives equality and hashing so callers can compare, deduplicate, or key
/// pending operations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PendingOperation {
    /// Identifier supplied when the operation was recorded.
    pub id: String,
}
/// Counts produced by a replay pass.
///
/// `Default` is an empty report (both counts zero). `PartialEq`/`Eq` allow
/// reports to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReplayReport {
    /// Operations whose handler returned `Ok`.
    pub replayed: usize,
    /// Operations whose handler returned `Err`.
    pub failed: usize,
}
/// In-memory store of [`PendingOperation`]s guarded by an async mutex.
///
/// Cloning is shallow: clones share the same `Arc`-held vector, so a change
/// made through one handle is visible through every other.
#[derive(Clone)]
pub struct InMemoryQueue {
    /// Shared storage, in insertion order.
    ops: Arc<Mutex<Vec<PendingOperation>>>,
}
impl InMemoryQueue {
pub fn new() -> Self {
Self {
ops: Arc::new(Mutex::new(Vec::new())),
}
}
async fn push(&self, op: PendingOperation) {
self.ops.lock().await.push(op);
}
async fn drain_all(&self) -> Vec<PendingOperation> {
let mut q = self.ops.lock().await;
q.drain(..).collect()
}
async fn len(&self) -> usize {
self.ops.lock().await.len()
}
async fn peek_all(&self) -> Vec<PendingOperation> {
self.ops.lock().await.clone()
}
}
impl Default for InMemoryQueue {
fn default() -> Self {
Self::new()
}
}
/// Runs operations and records the ids of failed ones in a queue so they can
/// be replayed later.
///
/// Defaults to an [`InMemoryQueue`] and an [`AlwaysOnlineProbe`].
pub struct StoreAndForward<Q = InMemoryQueue, P: ConnectivityProbe = AlwaysOnlineProbe> {
    /// Connectivity probe supplied at construction.
    // Stored but not yet read by any method in this file; the allowance
    // silences the resulting dead-code warning.
    #[allow(dead_code)]
    probe: P,
    /// Where ids of failed operations are recorded.
    queue: Q,
}
/// A [`ConnectivityProbe`] whose `check` always reports
/// [`ConnectivityStatus::Online`].
///
/// It is a stateless unit struct, so it derives the cheap standard traits
/// (`Debug`, `Clone`, `Copy`, `Default`, `PartialEq`, `Eq`) for ergonomic use
/// as a default type parameter and in tests.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AlwaysOnlineProbe;
/// Unconditional probe: never reports a problem.
#[async_trait]
impl ConnectivityProbe for AlwaysOnlineProbe {
    /// Always resolves to [`ConnectivityStatus::Online`].
    async fn check(&self) -> ConnectivityStatus {
        ConnectivityStatus::Online
    }
}
impl StoreAndForward<InMemoryQueue, AlwaysOnlineProbe> {
    /// Creates an instance with an empty [`InMemoryQueue`] and an
    /// [`AlwaysOnlineProbe`].
    pub fn default_new() -> Self {
        Self {
            queue: InMemoryQueue::new(),
            probe: AlwaysOnlineProbe,
        }
    }
}

/// Lets the default configuration be built through the standard `Default`
/// trait (e.g. `StoreAndForward::default()` or `..Default::default()`),
/// not only through the ad-hoc `default_new` constructor.
impl Default for StoreAndForward<InMemoryQueue, AlwaysOnlineProbe> {
    /// Same as [`StoreAndForward::default_new`].
    fn default() -> Self {
        Self::default_new()
    }
}
impl<P: ConnectivityProbe> StoreAndForward<InMemoryQueue, P> {
pub fn new(queue: InMemoryQueue, probe: P) -> Self {
Self { queue, probe }
}
pub async fn execute<F, Fut>(&self, id: &str, f: F)
where
F: FnOnce() -> Fut + Send,
Fut: Future<Output = Result<(), String>> + Send,
{
let result = f().await;
if result.is_err() {
self.queue
.push(PendingOperation { id: id.to_string() })
.await;
}
}
pub async fn pending_count(&self) -> usize {
self.queue.len().await
}
pub async fn peek_pending(&self) -> Vec<PendingOperation> {
self.queue.peek_all().await
}
pub async fn replay_all<F, Fut>(&self, handler: F) -> Result<ReplayReport, String>
where
F: Fn(String) -> Fut + Send,
Fut: Future<Output = Result<(), String>> + Send,
{
let ops = self.queue.drain_all().await;
let mut replayed = 0;
let mut failed = 0;
for op in ops {
match handler(op.id).await {
Ok(()) => replayed += 1,
Err(_) => failed += 1,
}
}
Ok(ReplayReport { replayed, failed })
}
}