use std::{collections::BTreeMap, sync::Arc};
use futures_core::Future;
use matrix_sdk_common::SendOutsideWasm;
use tokio::sync::Mutex;
use crate::{Error, Result};
/// State shared between the caller running a deduplicated request and the
/// callers waiting on it for the same key.
enum QueryState {
    /// No outcome has been recorded yet. This is the initial state. It is also
    /// what remains if the future running the request is dropped before it
    /// finishes. A waiter that observes this state runs its own request.
    Cancelled,
    /// The request finished successfully; waiters return `Ok(())`.
    Success,
    /// The request finished with an error; waiters report a concurrent
    /// request failure.
    Failure,
}
/// Registry of in-flight requests: each key maps to the shared, lockable state
/// of the request currently registered for it.
type DeduplicatedRequestMap<Key> = Mutex<BTreeMap<Key, Arc<Mutex<QueryState>>>>;

/// Runs async requests so that concurrent callers using the same key share a
/// single execution. Waiters reuse the outcome of the registered request
/// instead of running their own, unless that request was cancelled.
pub(crate) struct DeduplicatingHandler<Key> {
    /// Requests currently registered, keyed by their deduplication key.
    inflight: DeduplicatedRequestMap<Key>,
}
impl<Key> Default for DeduplicatingHandler<Key> {
fn default() -> Self {
Self { inflight: Default::default() }
}
}
impl<Key: Clone + Ord + std::hash::Hash> DeduplicatingHandler<Key> {
pub async fn run<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
&self,
key: Key,
code: F,
) -> Result<()> {
let mut map = self.inflight.lock().await;
if let Some(request_mutex) = map.get(&key).cloned() {
drop(map);
let mut request_guard = request_mutex.lock().await;
return match *request_guard {
QueryState::Success => {
Ok(())
}
QueryState::Failure => {
Err(Error::ConcurrentRequestFailed)
}
QueryState::Cancelled => {
self.run_code(key, code, &mut request_guard).await
}
};
}
let request_mutex = Arc::new(Mutex::new(QueryState::Cancelled));
map.insert(key.clone(), request_mutex.clone());
let mut request_guard = request_mutex.lock().await;
drop(map);
self.run_code(key, code, &mut request_guard).await
}
async fn run_code<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
&self,
key: Key,
code: F,
result: &mut QueryState,
) -> Result<()> {
match code.await {
Ok(()) => {
*result = QueryState::Success;
self.inflight.lock().await.remove(&key);
Ok(())
}
Err(err) => {
*result = QueryState::Failure;
self.inflight.lock().await.remove(&key);
Err(err)
}
}
}
}
// Unit tests for `DeduplicatingHandler`. Excluded on wasm targets, presumably
// because they spawn tasks on the tokio runtime (TODO confirm).
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::sync::Arc;

    use matrix_sdk_test::async_test;
    use tokio::{join, spawn, sync::Mutex, task::yield_now};

    use crate::deduplicating_handler::DeduplicatingHandler;

    /// Two concurrent calls with the same key run the query body only once,
    /// and both report success.
    #[async_test]
    async fn test_deduplicating_handler_same_key() -> anyhow::Result<()> {
        // Counts how many times a query body actually executes.
        let num_calls = Arc::new(Mutex::new(0));

        // Builds a query that yields before and after the increment, so the
        // other joined call gets polled while this one is in flight.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(0, inner()));

        assert!(first.is_ok());
        assert!(second.is_ok());
        // Only one of the two query bodies ran.
        assert_eq!(*num_calls.lock().await, 1);

        Ok(())
    }

    /// Concurrent calls with different keys are not deduplicated: each query
    /// body runs.
    #[async_test]
    async fn test_deduplicating_handler_different_keys() -> anyhow::Result<()> {
        // Counts how many times a query body actually executes.
        let num_calls = Arc::new(Mutex::new(0));

        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(1, inner()));

        assert!(first.is_ok());
        assert!(second.is_ok());
        // Keys 0 and 1 each ran their own query body.
        assert_eq!(*num_calls.lock().await, 2);

        Ok(())
    }

    /// A failing query is run once and its failure is reported to both
    /// callers. Afterwards, a new call for the same key runs its own query.
    #[async_test]
    async fn test_deduplicating_handler_failure() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        // Query that always fails after bumping the counter.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Err(crate::Error::AuthenticationRequired)
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(0, inner()));

        // Both callers see an error, but only one query body ran.
        assert!(first.is_err());
        assert!(second.is_err());
        assert_eq!(*num_calls.lock().await, 1);

        // The failed request must not stay registered: a later call for the
        // same key runs this new, succeeding query.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                *num_calls_cloned.lock().await += 1;
                Ok(())
            }
        };

        *num_calls.lock().await = 0;
        handler.run(0, inner()).await?;
        assert_eq!(*num_calls.lock().await, 1);

        Ok(())
    }

    /// When the task running a deduplicated query is aborted, a waiting caller
    /// takes over and runs its own query.
    #[async_test]
    async fn test_cancelling_deduplicated_query() -> anyhow::Result<()> {
        // Held by the test to block queries mid-flight until released.
        let allow_progress = Arc::new(Mutex::new(()));

        // Number of queries that started, and number that got past the gate.
        let num_before = Arc::new(Mutex::new(0));
        let num_after = Arc::new(Mutex::new(0));

        let inner = || {
            let num_before = num_before.clone();
            let num_after = num_after.clone();
            let allow_progress = allow_progress.clone();

            async move {
                *num_before.lock().await += 1;
                // Binding to `_` drops the guard immediately: this only waits
                // until the test releases `progress_guard`.
                let _ = allow_progress.lock().await;
                *num_after.lock().await += 1;
                Ok(())
            }
        };

        let handler = Arc::new(DeduplicatingHandler::default());

        // Close the gate so the first query blocks after starting.
        let progress_guard = allow_progress.lock().await;

        let first = spawn({
            let handler = handler.clone();
            let query = inner();
            async move { handler.run(0, query).await }
        });

        let second = spawn({
            let handler = handler.clone();
            let query = inner();
            async move { handler.run(0, query).await }
        });

        // Let the spawned tasks make progress: only one query has started,
        // and it is blocked at the gate.
        yield_now().await;

        assert_eq!(*num_before.lock().await, 1);
        assert_eq!(*num_after.lock().await, 0);

        // Abort the task that owns the in-flight query.
        first.abort();
        assert!(first.await.unwrap_err().is_cancelled());

        // The waiting task sees the cancelled state and starts its own query,
        // which also blocks at the gate.
        yield_now().await;

        assert_eq!(*num_before.lock().await, 2);
        assert_eq!(*num_after.lock().await, 0);

        // Open the gate: the second task's query completes.
        drop(progress_guard);

        assert!(second.await.unwrap().is_ok());

        // Only the second (surviving) query got past the gate.
        assert_eq!(*num_before.lock().await, 2);
        assert_eq!(*num_after.lock().await, 1);

        Ok(())
    }
}