#![doc(html_logo_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/uniflight/logo.png")]
#![doc(html_favicon_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/uniflight/favicon.ico")]
use std::borrow::Borrow;
use std::fmt::Debug;
use std::hash::Hash;
use std::panic::AssertUnwindSafe;
use std::sync::{Arc, Weak};
use ahash::RandomState;
use async_once_cell::OnceCell;
use dashmap::DashMap;
use dashmap::Entry::{Occupied, Vacant};
use futures_util::FutureExt; use thread_aware::affinity::Affinity;
use thread_aware::storage::Strategy;
use thread_aware::{Arc as TaArc, PerCore, PerNuma, PerProcess, ThreadAware};
/// Coalesces concurrent work by key.
///
/// Every key maps to a weakly-held, initialize-once cell. Callers of `execute`
/// that ask for the same key while its cell is still alive share that cell, and
/// therefore share the result stored in it.
///
/// `S` is the `thread_aware` storage [`Strategy`] applied to the underlying map;
/// it defaults to [`PerProcess`].
pub struct Merger<K, T, S: Strategy = PerProcess> {
// Key -> weak handle to the in-flight cell. The strong handles are owned by the
// callers currently executing, so a cell goes away once the last caller drops it.
inner: TaArc<DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, S>,
}
// Opaque `Debug`: the map contents are never printed, which is why neither `K`
// nor `T` needs a `Debug` bound here.
impl<K, T, S: Strategy> Debug for Merger<K, T, S> {
    /// Writes `Merger { inner: DashMap<...> }` using a fixed placeholder for the map.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Merger");
        // Placeholder text instead of the real entries.
        builder.field("inner", &format_args!("DashMap<...>"));
        builder.finish()
    }
}
impl<K, T, S: Strategy> Clone for Merger<K, T, S> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
impl<K, T, S> Default for Merger<K, T, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    T: Send + Sync + 'static,
    S: Strategy,
{
    /// Builds a merger with no entries.
    ///
    /// The thread-aware `Arc` is given a factory closure rather than a finished
    /// map; the closure produces an empty `DashMap` hashed with a fresh
    /// `ahash::RandomState`.
    fn default() -> Self {
        let inner = TaArc::new(|| DashMap::with_hasher(RandomState::new()));
        Self { inner }
    }
}
impl<K, T, S> Merger<K, T, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    T: Send + Sync + 'static,
    S: Strategy,
{
    /// Creates an empty merger for whichever storage strategy `S` is selected.
    ///
    /// Equivalent to the `Default` implementation.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl<K, T> Merger<K, T, PerProcess>
where
    K: Hash + Eq + Send + Sync + 'static,
    T: Send + Sync + 'static,
{
    /// Creates an empty merger whose storage strategy is fixed to [`PerProcess`].
    ///
    /// This pins the strategy type parameter without a turbofish; the value
    /// itself comes from the `Default` implementation.
    #[inline]
    #[must_use]
    #[cfg_attr(test, mutants::skip)]
    pub fn new_per_process() -> Self {
        <Self as Default>::default()
    }
}
impl<K, T> Merger<K, T, PerNuma>
where
    K: Hash + Eq + Send + Sync + 'static,
    T: Send + Sync + 'static,
{
    /// Creates an empty merger whose storage strategy is fixed to [`PerNuma`].
    ///
    /// This pins the strategy type parameter without a turbofish; the value
    /// itself comes from the `Default` implementation.
    #[inline]
    #[must_use]
    #[cfg_attr(test, mutants::skip)]
    pub fn new_per_numa() -> Self {
        <Self as Default>::default()
    }
}
impl<K, T> Merger<K, T, PerCore>
where
    K: Hash + Eq + Send + Sync + 'static,
    T: Send + Sync + 'static,
{
    /// Creates an empty merger whose storage strategy is fixed to [`PerCore`].
    ///
    /// This pins the strategy type parameter without a turbofish; the value
    /// itself comes from the `Default` implementation.
    #[inline]
    #[must_use]
    #[cfg_attr(test, mutants::skip)]
    pub fn new_per_core() -> Self {
        <Self as Default>::default()
    }
}
// Test-only introspection of the underlying map.
impl<K, T, S: Strategy> Merger<K, T, S>
where
    K: Hash + Eq,
{
    /// Number of entries currently stored in the map (live or expired).
    #[cfg(test)]
    fn len(&self) -> usize {
        DashMap::len(&self.inner)
    }

    /// `true` when the map currently holds no entries at all.
    #[cfg(test)]
    fn is_empty(&self) -> bool {
        DashMap::is_empty(&self.inner)
    }
}
impl<K, T, S> ThreadAware for Merger<K, T, S>
where
K: Send + Sync,
T: Send + Sync,
S: Strategy + Send + Sync,
{
#[cfg_attr(test, mutants::skip)]
fn relocate(&mut self, source: Option<Affinity>, destination: Affinity) {
self.inner.relocate(source, destination);
}
}
impl<K, T, S> Merger<K, T, S>
where
    K: Hash + Eq + Send + Sync,
    T: Send + Sync,
    S: Strategy + Send + Sync,
{
    /// Runs `func` for `key`, sharing a single result among callers that overlap
    /// on that key.
    ///
    /// The cell for `key` is looked up (or created and registered) right away,
    /// before the returned future is polled for the first time. Once polled, the
    /// future calls `func`, boxes the future it returns, and offers it to the
    /// cell as its initializer. Every caller invokes its own `func`; the value
    /// handed back is a clone of whatever the cell ends up holding.
    ///
    /// After obtaining the result the caller releases its handle on the cell and
    /// removes the map entry if no other caller still holds the cell.
    ///
    /// # Errors
    ///
    /// Returns [`LeaderPanicked`] when the future that initialized the cell
    /// panicked; the error carries the extracted panic message.
    ///
    /// NOTE(review): the cleanup step runs only after the await completes. If
    /// the returned future is dropped earlier, the expired entry appears to stay
    /// in the map until the same key is requested again — confirm this is
    /// acceptable.
    pub fn execute<Q, F, Fut>(&self, key: &Q, func: F) -> impl Future<Output = Result<T, LeaderPanicked>> + Send + use<Q, F, Fut, K, T, S>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = T> + Send,
        T: Clone,
    {
        let map = self.inner.clone();
        let shared = Self::get_or_create_cell(&map, key);
        // Owned copy of the key so the future does not borrow from the caller.
        let cleanup_key = key.to_owned();

        async move {
            // Boxed so the size of `Fut` does not become part of this future.
            let work = Box::pin(func());
            let outcome = shared.get_or_init(work).await.clone();

            // Give up our strong handle first; otherwise the liveness check
            // below would always see the cell as still in use.
            drop(shared);
            map.remove_if(cleanup_key.borrow(), |_, slot| slot.upgrade().is_none());

            outcome
        }
    }

    /// Returns the live cell registered for `key`, or registers a new one.
    ///
    /// Fast path: a plain lookup. Only when the key is absent or its weak handle
    /// has expired does this fall through to `insert_or_get_existing`.
    fn get_or_create_cell<Q>(map: &DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, key: &Q) -> Arc<PanicAwareCell<T>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
    {
        // The map guard is consumed by the closure, so it is gone before the
        // slow path below touches the map again.
        let live = map.get(key).and_then(|slot| slot.value().upgrade());

        match live {
            Some(cell) => cell,
            None => Self::insert_or_get_existing(map, key),
        }
    }

    /// Slow path: takes the entry for `key` and either joins the live cell found
    /// there or stores a weak handle to a freshly created one.
    ///
    /// A live cell can be present here even after the fast path missed, because
    /// another caller may have registered one in between.
    fn insert_or_get_existing<Q>(map: &DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, key: &Q) -> Arc<PanicAwareCell<T>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
    {
        let fresh = Arc::new(PanicAwareCell::new());
        let handle = Arc::downgrade(&fresh);

        match map.entry(key.to_owned()) {
            Occupied(mut occupied) => {
                let current = occupied.get().upgrade();
                match current {
                    // Someone else's cell is still alive: join it and discard ours.
                    Some(winner) => winner,
                    // The stored handle has expired: overwrite it with ours.
                    None => {
                        occupied.insert(handle);
                        fresh
                    }
                }
            }
            Vacant(vacant) => {
                vacant.insert(handle);
                fresh
            }
        }
    }
}
/// Error reported when the task that was producing a shared result panicked.
///
/// Holds the message extracted from the panic payload. The message is stored as
/// an `Arc<str>`, so cloning the error does not copy the text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeaderPanicked {
    message: Arc<str>,
}

impl LeaderPanicked {
    /// The panic message carried by this error.
    #[must_use]
    pub fn message(&self) -> &str {
        self.message.as_ref()
    }
}

impl std::fmt::Display for LeaderPanicked {
    /// Formats as `leader task panicked: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("leader task panicked: ")?;
        f.write_str(self.message.as_ref())
    }
}

impl std::error::Error for LeaderPanicked {}
/// Turns a panic payload into a shareable message.
///
/// Payloads of type `&'static str` or `String` yield their text; any other
/// payload type yields the fixed text `"unknown panic"`.
fn extract_panic_message(payload: &(dyn std::any::Any + Send)) -> Arc<str> {
    let text: &str = match payload.downcast_ref::<&str>() {
        Some(literal) => *literal,
        None => payload.downcast_ref::<String>().map_or("unknown panic", String::as_str),
    };
    Arc::from(text)
}
/// Initialize-once cell that stores either the produced value or a record that
/// the initializing future panicked.
struct PanicAwareCell<T> {
    inner: OnceCell<Result<T, LeaderPanicked>>,
}

impl<T> PanicAwareCell<T> {
    /// Creates a cell with nothing stored in it yet.
    fn new() -> Self {
        let inner = OnceCell::new();
        Self { inner }
    }

    /// Returns the stored result, using `f` as the initializer when needed.
    ///
    /// `f` is wrapped so that a panic raised while it is polled is caught and
    /// stored as `Err(LeaderPanicked)` instead of unwinding through the caller.
    #[expect(clippy::future_not_send, reason = "Send bounds enforced by Merger::execute")]
    async fn get_or_init<F>(&self, f: F) -> &Result<T, LeaderPanicked>
    where
        F: Future<Output = T>,
    {
        // `AssertUnwindSafe` is what allows `catch_unwind` on an arbitrary `F`.
        let guarded = AssertUnwindSafe(f).catch_unwind();
        let to_error = |payload: Box<dyn std::any::Any + Send>| LeaderPanicked {
            message: extract_panic_message(&*payload),
        };
        let init = guarded.map(|outcome| outcome.map_err(to_error));

        self.inner.get_or_init(init).await
    }
}
// Unit tests for the merger internals: cell lookup/insertion, map cleanup after
// `execute`, and panic capture.
#[cfg(test)]
mod tests {
use std::time::Duration;
use thread_aware::affinity::pinned_affinities;
use super::*;
// Relocating an empty merger must leave it empty.
#[test]
fn relocated_delegates_to_inner() {
// NOTE(review): presumably `&[2]` yields at least two affinities, since
// indices 0 and 1 are both read below — confirm against `pinned_affinities`.
let affinities = pinned_affinities(&[2]);
let source = Some(affinities[0]);
let destination = affinities[1];
let mut merger: Merger<String, String> = Merger::new();
merger.relocate(source, destination);
assert!(merger.is_empty());
}
// Fast path: a live entry is returned as-is (same allocation).
#[test]
fn fast_path_returns_existing() {
let map: DashMap<String, Weak<PanicAwareCell<String>>, RandomState> = DashMap::with_hasher(RandomState::new());
let existing_cell = Arc::new(PanicAwareCell::new());
map.insert("key".to_string(), Arc::downgrade(&existing_cell));
let result = Merger::<String, String>::get_or_create_cell(&map, "key");
assert!(Arc::ptr_eq(&result, &existing_cell));
}
// An entry whose weak handle has expired is replaced by the new cell.
#[test]
fn replaces_expired_entry() {
let map: DashMap<String, Weak<PanicAwareCell<String>>, RandomState> = DashMap::with_hasher(RandomState::new());
// The temporary `Arc` is dropped at the end of this statement, so the weak
// handle is already dead when it is inserted.
let expired_weak = Arc::downgrade(&Arc::new(PanicAwareCell::<String>::new()));
map.insert("key".to_string(), expired_weak);
let result = Merger::<String, String>::get_or_create_cell(&map, "key");
let entry = map.get("key").unwrap();
assert!(Arc::ptr_eq(&result, &entry.value().upgrade().unwrap()));
}
// Slow path invoked directly while a live entry is present: this stands in
// for losing the insertion race, and the existing cell must be returned.
#[test]
fn race_returns_existing() {
let map: DashMap<String, Weak<PanicAwareCell<String>>, RandomState> = DashMap::with_hasher(RandomState::new());
let other_cell = Arc::new(PanicAwareCell::new());
map.insert("key".to_string(), Arc::downgrade(&other_cell));
let result = Merger::<String, String>::insert_or_get_existing(&map, "key");
assert!(Arc::ptr_eq(&result, &other_cell));
}
// The map must be empty again once every caller for a key has finished.
// Skipped under Miri.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn cleanup_after_completion() {
let group: Merger<String, String> = Merger::new();
assert!(group.is_empty());
// Single caller.
let result = group.execute("key1", || async { "Result".to_string() }).await;
assert_eq!(result, Ok("Result".to_string()));
assert!(group.is_empty(), "Map should be empty after single call completes");
// Ten callers on one key; the futures are created but not yet awaited.
let futures: Vec<_> = (0..10)
.map(|_| {
group.execute("key2", || async {
tokio::time::sleep(Duration::from_millis(50)).await;
"Result".to_string()
})
})
.collect();
// The entry is registered when `execute` is called, before any polling.
assert_eq!(group.len(), 1);
for fut in futures {
assert_eq!(fut.await, Ok("Result".to_string()));
}
assert!(group.is_empty(), "Map should be empty after all concurrent calls complete");
// Three distinct keys in flight at once.
let fut1 = group.execute("a", || async { "A".to_string() });
let fut2 = group.execute("b", || async { "B".to_string() });
let fut3 = group.execute("c", || async { "C".to_string() });
assert_eq!(group.len(), 3);
let (r1, r2, r3) = tokio::join!(fut1, fut2, fut3);
assert_eq!(r1, Ok("A".to_string()));
assert_eq!(r2, Ok("B".to_string()));
assert_eq!(r3, Ok("C".to_string()));
assert!(group.is_empty(), "Map should be empty after all keys complete");
}
// Sanity check of the building block used by `PanicAwareCell`: a panic inside
// an async block is surfaced as `Err` by `catch_unwind`. Skipped under Miri.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn catch_unwind_works() {
let result = AssertUnwindSafe(async {
panic!("test panic");
#[expect(unreachable_code, reason = "Required to satisfy return type after panic")]
42i32
})
.catch_unwind()
.await;
assert!(result.is_err(), "catch_unwind should catch the panic");
}
// A panicking initializer is stored as `LeaderPanicked` carrying the panic
// message. Skipped under Miri.
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn panic_aware_cell_catches_panic() {
let cell = PanicAwareCell::<String>::new();
let result = cell
.get_or_init(async {
panic!("test panic");
#[expect(unreachable_code, reason = "Required to satisfy return type after panic")]
"never".to_string()
})
.await;
let err = result.as_ref().unwrap_err();
assert_eq!(err.message(), "test panic");
}
// `String` payloads yield their text.
#[test]
fn extract_panic_message_from_string() {
let payload: Box<dyn std::any::Any + Send> = Box::new(String::from("owned string panic"));
let message = extract_panic_message(&*payload);
assert_eq!(&*message, "owned string panic");
}
// Payloads that are neither `&str` nor `String` fall back to the fixed text.
#[test]
fn extract_panic_message_unknown_type() {
let payload: Box<dyn std::any::Any + Send> = Box::new(42i32);
let message = extract_panic_message(&*payload);
assert_eq!(&*message, "unknown panic");
}
}