use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use asupersync::Cx;
use crate::error::{McpError, McpErrorCode, McpResult};
/// A heap-allocated, type-erased, `Send` future that may borrow data for the
/// lifetime `'a`. Every combinator below accepts its inputs as a
/// `Vec<BoxFuture<'_, T>>`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Polls the future held in `slot`, if there is one.
///
/// Returns `Some(output)` on the poll that completes the future and empties
/// the slot at the same moment, so a finished future is dropped and never
/// polled again. Returns `None` when the slot is already empty or the future
/// is still pending.
fn poll_slot<T>(slot: &mut Option<BoxFuture<'_, T>>, cx: &mut Context<'_>) -> Option<T> {
    let progress = slot.as_mut().map(|fut| fut.as_mut().poll(cx))?;
    if let Poll::Ready(output) = progress {
        *slot = None;
        return Some(output);
    }
    None
}
/// Drives every future in `futures` concurrently on the current task and
/// returns their outputs in the same order as the input vector.
///
/// An empty input yields an empty vector without suspending, and a single
/// future is simply awaited in place. `_cx` is currently unused.
pub async fn join_all<T: Send + 'static>(_cx: &Cx, futures: Vec<BoxFuture<'_, T>>) -> Vec<T> {
    match futures.len() {
        0 => Vec::new(),
        1 => {
            let mut single = futures;
            vec![single.swap_remove(0).await]
        }
        count => {
            let mut state = JoinAllState {
                results: std::iter::repeat_with(|| None).take(count).collect(),
                futures: futures.into_iter().map(Some).collect(),
                remaining: count,
            };
            std::future::poll_fn(move |cx| state.poll(cx)).await
        }
    }
}
/// Poll state for [`join_all`]: one slot per input future, an output slot at
/// the same index, and a count of futures that have not finished yet.
struct JoinAllState<'a, T> {
    futures: Vec<Option<BoxFuture<'a, T>>>,
    results: Vec<Option<T>>,
    remaining: usize,
}

impl<T> JoinAllState<'_, T> {
    /// Polls every unfinished future once and stores each output at its input
    /// index. Resolves with all outputs, in input order, once none remain.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Vec<T>> {
        let slots = self.futures.iter_mut().zip(self.results.iter_mut());
        for (pending, output) in slots {
            if let Some(value) = poll_slot(pending, cx) {
                *output = Some(value);
                self.remaining -= 1;
            }
        }
        if self.remaining != 0 {
            return Poll::Pending;
        }
        let collected = self
            .results
            .iter_mut()
            .map(|output| output.take().expect("all futures completed"))
            .collect();
        Poll::Ready(collected)
    }
}
/// Awaits every fallible future and returns all of their results, `Ok` and
/// `Err` alike, in input order; one failure does not cut the others short.
pub async fn join_all_results<T: Send + 'static>(
    cx: &Cx,
    futures: Vec<BoxFuture<'_, McpResult<T>>>,
) -> Vec<McpResult<T>> {
    join_all::<McpResult<T>>(cx, futures).await
}
/// Returns the output of whichever future in `futures` completes first.
///
/// On every wake-up the futures are polled in input order, so when several are
/// ready at once the lowest index wins, and the rest are dropped. An empty
/// input is rejected with `InvalidParams`, and a single future is awaited
/// directly. `_cx` is currently unused.
pub async fn race<T: Send + 'static>(_cx: &Cx, futures: Vec<BoxFuture<'_, T>>) -> McpResult<T> {
    match futures.len() {
        0 => Err(McpError::new(
            McpErrorCode::InvalidParams,
            "race requires at least one future",
        )),
        1 => {
            let mut single = futures;
            Ok(single.swap_remove(0).await)
        }
        _ => {
            let mut state = RaceAllState {
                futures: futures.into_iter().map(Some).collect(),
            };
            let winner = std::future::poll_fn(move |cx| state.poll(cx)).await;
            Ok(winner)
        }
    }
}
/// Poll state for [`race`]: the competing futures, emptied once one wins.
struct RaceAllState<'a, T> {
    futures: Vec<Option<BoxFuture<'a, T>>>,
}

impl<T> RaceAllState<'_, T> {
    /// Polls the futures in input order and stops at the first one that is
    /// ready. Its output is returned, and every other future is dropped.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
        let first_ready = self.futures.iter_mut().find_map(|slot| poll_slot(slot, cx));
        match first_ready {
            Some(output) => {
                self.futures.clear();
                Poll::Ready(output)
            }
            None => Poll::Pending,
        }
    }
}
/// Races `futures` like [`race`], but gives up once `timeout` has elapsed.
///
/// Returns `Err(InvalidParams)` for an empty input. Returns
/// `Err(RequestCancelled, "operation timed out")` when the deadline passes
/// without a winner.
///
/// The futures are always polled before the clock is checked, so a future
/// that is already ready wins even with a zero timeout. A `timeout` too large
/// to add to `Instant::now()` (e.g. `Duration::MAX`) means "no deadline"
/// rather than a panic.
///
/// Note: there is no timer here. While a deadline is set and nothing is ready,
/// the task re-wakes itself on every poll, which keeps its executor busy until
/// the race resolves. `_cx` is currently unused.
pub async fn race_timeout<T: Send + 'static>(
    _cx: &Cx,
    timeout: Duration,
    futures: Vec<BoxFuture<'_, T>>,
) -> McpResult<T> {
    if futures.is_empty() {
        return Err(McpError::new(
            McpErrorCode::InvalidParams,
            "race requires at least one future",
        ));
    }
    // `Instant + Duration` panics on overflow; an unrepresentable deadline is
    // treated as "never expires".
    let deadline = Instant::now().checked_add(timeout);
    let mut state = RaceTimeoutState {
        futures: futures.into_iter().map(Some).collect(),
        deadline,
    };
    std::future::poll_fn(move |cx| state.poll(cx)).await
}

/// Poll state for [`race_timeout`].
struct RaceTimeoutState<'a, T> {
    futures: Vec<Option<BoxFuture<'a, T>>>,
    /// `None` when the timeout could not be represented as an `Instant`.
    deadline: Option<Instant>,
}

impl<T> RaceTimeoutState<'_, T> {
    /// Polls the futures in input order, returning the first ready output.
    /// Only when none is ready is the deadline consulted.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<McpResult<T>> {
        // Poll before checking the clock so a ready result is never thrown
        // away as a timeout (e.g. with a zero timeout).
        let winner = self.futures.iter_mut().find_map(|slot| poll_slot(slot, cx));
        if let Some(val) = winner {
            self.futures.clear();
            return Poll::Ready(Ok(val));
        }

        if let Some(deadline) = self.deadline {
            if Instant::now() >= deadline {
                self.futures.clear();
                return Poll::Ready(Err(McpError::new(
                    McpErrorCode::RequestCancelled,
                    "operation timed out",
                )));
            }
            // No timer is registered here, so ask to be polled again to
            // re-check the deadline. Without a deadline this is unnecessary:
            // pending futures wake the task themselves.
            cx.waker().wake_by_ref();
        }
        Poll::Pending
    }
}
/// Outcome of a quorum wait ([`quorum`] / [`quorum_timeout`]).
#[derive(Debug)]
pub struct QuorumResult<T> {
    /// Outputs of the futures that resolved to `Ok`, in the order they were observed.
    pub successes: Vec<T>,
    /// `true` when at least the required number of successes was collected.
    pub quorum_met: bool,
    /// How many futures were observed resolving to `Err`.
    pub failure_count: usize,
}

impl<T> QuorumResult<T> {
    /// Reports whether the quorum was met (same as reading `quorum_met`).
    #[must_use]
    pub fn is_success(&self) -> bool {
        self.quorum_met
    }

    /// Consumes the result and yields the successes only if the quorum was met.
    #[must_use]
    pub fn into_results(self) -> Option<Vec<T>> {
        self.quorum_met.then_some(self.successes)
    }
}
/// Runs `futures` until at least `required` of them resolve to `Ok`.
///
/// Returns `Err(InvalidParams)` when `required` exceeds the number of futures.
/// When `required` is zero, it returns a met quorum without polling anything.
///
/// Otherwise the result reports whether the quorum was met. It is returned as
/// soon as the quorum is reached or becomes unreachable, and any futures still
/// running at that point are dropped. Each wake-up polls every pending future,
/// so `successes` may hold more than `required` values. `_cx` is currently
/// unused.
pub async fn quorum<T: Send + 'static>(
    _cx: &Cx,
    required: usize,
    futures: Vec<BoxFuture<'_, McpResult<T>>>,
) -> McpResult<QuorumResult<T>> {
    let total = futures.len();
    match required {
        0 => Ok(QuorumResult {
            successes: Vec::new(),
            quorum_met: true,
            failure_count: 0,
        }),
        needed if needed > total => Err(McpError::new(
            McpErrorCode::InvalidParams,
            format!("quorum requires {required} successes but only {total} futures provided"),
        )),
        _ => {
            let mut state = QuorumState {
                futures: futures.into_iter().map(Some).collect(),
                successes: Vec::with_capacity(required),
                failures: 0,
                required,
                total,
            };
            std::future::poll_fn(move |cx| state.poll(cx)).await
        }
    }
}
/// Poll state for [`quorum`]: the futures still in flight and the tallies
/// gathered so far.
struct QuorumState<'a, T> {
    futures: Vec<Option<BoxFuture<'a, McpResult<T>>>>,
    successes: Vec<T>,
    failures: usize,
    required: usize,
    total: usize,
}

impl<T> QuorumState<'_, T> {
    /// Polls every pending future once, then resolves if the quorum is met,
    /// can no longer be met, or no futures are left. Otherwise it stays pending.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<McpResult<QuorumResult<T>>> {
        for slot in &mut self.futures {
            match poll_slot(slot, cx) {
                Some(Ok(value)) => self.successes.push(value),
                Some(Err(_)) => self.failures += 1,
                None => {}
            }
        }

        let quorum_met = self.successes.len() >= self.required;
        // More than `total - required` failures leaves too few futures to
        // ever reach `required` successes.
        let quorum_impossible = self.failures > self.total - self.required;
        let exhausted = self.futures.iter().all(Option::is_none);
        if !(quorum_met || quorum_impossible || exhausted) {
            return Poll::Pending;
        }

        self.futures.clear();
        Poll::Ready(Ok(QuorumResult {
            successes: std::mem::take(&mut self.successes),
            quorum_met,
            failure_count: self.failures,
        }))
    }
}
/// Like [`quorum`], but stops waiting once `timeout` has elapsed.
///
/// Returns `Err(InvalidParams)` when `required` exceeds the number of futures.
/// When `required` is zero, it returns a met quorum immediately.
///
/// Otherwise the futures are driven until one of these happens first:
/// - `required` of them have succeeded;
/// - so many have failed that the quorum is unreachable;
/// - all of them have finished;
/// - the deadline has passed.
///
/// On timeout, the successes gathered so far are returned and `quorum_met`
/// reflects whether they suffice. The futures are always polled before the
/// clock is checked, so results that are already available count even with a
/// zero timeout. A `timeout` too large to add to `Instant::now()` (e.g.
/// `Duration::MAX`) means "no deadline" rather than a panic.
///
/// Note: there is no timer here. While a deadline is set, the task re-wakes
/// itself on every pending poll. `_cx` is currently unused.
pub async fn quorum_timeout<T: Send + 'static>(
    _cx: &Cx,
    required: usize,
    timeout: Duration,
    futures: Vec<BoxFuture<'_, McpResult<T>>>,
) -> McpResult<QuorumResult<T>> {
    let total = futures.len();
    if required > total {
        return Err(McpError::new(
            McpErrorCode::InvalidParams,
            format!("quorum requires {required} successes but only {total} futures provided"),
        ));
    }
    if required == 0 {
        return Ok(QuorumResult {
            successes: Vec::new(),
            quorum_met: true,
            failure_count: 0,
        });
    }
    // `Instant + Duration` panics on overflow; an unrepresentable deadline is
    // treated as "never expires".
    let deadline = Instant::now().checked_add(timeout);
    let mut state = QuorumTimeoutState {
        futures: futures.into_iter().map(Some).collect(),
        successes: Vec::with_capacity(required),
        failures: 0,
        required,
        total,
        deadline,
    };
    std::future::poll_fn(move |cx| state.poll(cx)).await
}

/// Poll state for [`quorum_timeout`].
struct QuorumTimeoutState<'a, T> {
    futures: Vec<Option<BoxFuture<'a, McpResult<T>>>>,
    successes: Vec<T>,
    failures: usize,
    required: usize,
    total: usize,
    /// `None` when the timeout could not be represented as an `Instant`.
    deadline: Option<Instant>,
}

impl<T> QuorumTimeoutState<'_, T> {
    /// Polls every pending future once, then resolves if the quorum is met,
    /// is unreachable, all futures are done, or the deadline has passed.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<McpResult<QuorumResult<T>>> {
        // Drive the futures before looking at the clock so results that are
        // already available are never discarded because the deadline passed.
        for slot in &mut self.futures {
            match poll_slot(slot, cx) {
                Some(Ok(val)) => self.successes.push(val),
                Some(Err(_)) => self.failures += 1,
                None => {}
            }
        }

        let quorum_met = self.successes.len() >= self.required;
        // More than `total - required` failures leaves too few futures to
        // ever reach `required` successes.
        let quorum_impossible = self.failures > self.total - self.required;
        let exhausted = self.futures.iter().all(Option::is_none);
        let timed_out = matches!(self.deadline, Some(deadline) if Instant::now() >= deadline);

        if quorum_met || quorum_impossible || exhausted || timed_out {
            // Any futures still running are no longer needed.
            self.futures.clear();
            return Poll::Ready(Ok(QuorumResult {
                successes: std::mem::take(&mut self.successes),
                quorum_met,
                failure_count: self.failures,
            }));
        }

        // No timer is registered here, so ask to be polled again to re-check
        // the deadline. Without a deadline this is unnecessary: pending
        // futures wake the task themselves.
        if self.deadline.is_some() {
            cx.waker().wake_by_ref();
        }
        Poll::Pending
    }
}
/// Returns the output of the first future observed to resolve to `Ok`.
///
/// Futures that fail are skipped. If every future fails, the error from the
/// last failure observed is returned. An empty input is rejected with
/// `InvalidParams`. `_cx` is currently unused.
pub async fn first_ok<T: Send + 'static>(
    _cx: &Cx,
    futures: Vec<BoxFuture<'_, McpResult<T>>>,
) -> McpResult<T> {
    if futures.is_empty() {
        return Err(McpError::new(
            McpErrorCode::InvalidParams,
            "first_ok requires at least one future",
        ));
    }
    let slots: Vec<_> = futures.into_iter().map(Some).collect();
    let mut state = FirstOkState {
        futures: slots,
        last_error: None,
    };
    std::future::poll_fn(move |cx| state.poll(cx)).await
}
/// Poll state for [`first_ok`]: the futures still in flight and the most recent
/// error seen, which is reported if every future fails.
struct FirstOkState<'a, T> {
    futures: Vec<Option<BoxFuture<'a, McpResult<T>>>>,
    last_error: Option<McpError>,
}

impl<T> FirstOkState<'_, T> {
    /// Polls the pending futures in input order.
    ///
    /// The first `Ok` stops the scan, drops every other future and is
    /// returned. Each `Err` overwrites `last_error`. Once nothing is left
    /// pending, that error is returned, or a generic internal error if none
    /// was recorded.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<McpResult<T>> {
        let mut winner = None;
        for slot in &mut self.futures {
            match poll_slot(slot, cx) {
                Some(Ok(value)) => {
                    winner = Some(value);
                    break;
                }
                Some(Err(error)) => self.last_error = Some(error),
                None => {}
            }
        }

        if let Some(value) = winner {
            self.futures.clear();
            return Poll::Ready(Ok(value));
        }
        if self.futures.iter().any(Option::is_some) {
            return Poll::Pending;
        }

        let error = self
            .last_error
            .take()
            .unwrap_or_else(|| McpError::new(McpErrorCode::InternalError, "all futures failed"));
        Poll::Ready(Err(error))
    }
}
// Unit tests for the combinators in this file. Every input future is an
// `async` block without any `.await`, so it completes on its first poll. The
// futures are driven to completion with the crate-level `block_on` helper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::block_on;

    /// Creates the `Cx` passed to each combinator under test.
    fn make_cx() -> Cx {
        Cx::for_testing()
    }

    // ---- join_all ----

    #[test]
    fn test_join_all_empty() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![];
        let results = block_on(join_all(&cx, futures));
        assert!(results.is_empty());
    }

    // Exercises the single-future fast path.
    #[test]
    fn test_join_all_single() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![Box::pin(async { 42 })];
        let results = block_on(join_all(&cx, futures));
        assert_eq!(results, vec![42]);
    }

    // Results must come back in input order.
    #[test]
    fn test_join_all_multiple() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![
            Box::pin(async { 1 }),
            Box::pin(async { 2 }),
            Box::pin(async { 3 }),
        ];
        let results = block_on(join_all(&cx, futures));
        assert_eq!(results, vec![1, 2, 3]);
    }

    // ---- race ----

    // An empty input is rejected with an error.
    #[test]
    fn test_race_empty() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![];
        let result = block_on(race(&cx, futures));
        assert!(result.is_err());
    }

    #[test]
    fn test_race_single() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![Box::pin(async { 42 })];
        let result = block_on(race(&cx, futures));
        assert_eq!(result.unwrap(), 42);
    }

    // ---- quorum ----

    // `required == 0` is met immediately without collecting any successes.
    #[test]
    fn test_quorum_trivial() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> =
            vec![Box::pin(async { Ok(1) }), Box::pin(async { Ok(2) })];
        let result = block_on(quorum(&cx, 0, futures));
        assert!(result.is_ok());
        let qr = result.unwrap();
        assert!(qr.quorum_met);
        assert!(qr.successes.is_empty());
    }

    #[test]
    fn test_quorum_all() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Ok(1) }),
            Box::pin(async { Ok(2) }),
            Box::pin(async { Ok(3) }),
        ];
        let result = block_on(quorum(&cx, 3, futures));
        assert!(result.is_ok());
        let qr = result.unwrap();
        assert!(qr.quorum_met);
        assert_eq!(qr.successes.len(), 3);
    }

    // One failure out of three still leaves enough successes for 2-of-3.
    #[test]
    fn test_quorum_partial() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Ok(1) }),
            Box::pin(async { Err(McpError::internal_error("fail")) }),
            Box::pin(async { Ok(3) }),
        ];
        let result = block_on(quorum(&cx, 2, futures));
        assert!(result.is_ok());
        let qr = result.unwrap();
        assert!(qr.quorum_met);
        assert_eq!(qr.successes.len(), 2);
    }

    // Requiring more successes than there are futures is an input error.
    #[test]
    fn test_quorum_impossible() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![Box::pin(async { Ok(1) })];
        let result = block_on(quorum(&cx, 5, futures));
        assert!(result.is_err());
    }

    // Two failures make 2-of-3 unreachable; the lone success is still reported.
    #[test]
    fn test_quorum_insufficient() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Ok(1) }),
            Box::pin(async { Err(McpError::internal_error("fail 1")) }),
            Box::pin(async { Err(McpError::internal_error("fail 2")) }),
        ];
        let result = block_on(quorum(&cx, 2, futures));
        assert!(result.is_ok());
        let qr = result.unwrap();
        assert!(!qr.quorum_met);
        assert_eq!(qr.successes.len(), 1);
    }

    // ---- first_ok ----

    #[test]
    fn test_first_ok_empty() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![];
        let result = block_on(first_ok(&cx, futures));
        assert!(result.is_err());
    }

    // When several futures succeed, the first in input order wins.
    #[test]
    fn test_first_ok_first_succeeds() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> =
            vec![Box::pin(async { Ok(1) }), Box::pin(async { Ok(2) })];
        let result = block_on(first_ok(&cx, futures));
        assert_eq!(result.unwrap(), 1);
    }

    // A leading failure is skipped in favour of the next success.
    #[test]
    fn test_first_ok_fallback() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Err(McpError::internal_error("fail 1")) }),
            Box::pin(async { Ok(2) }),
            Box::pin(async { Ok(3) }),
        ];
        let result = block_on(first_ok(&cx, futures));
        assert_eq!(result.unwrap(), 2);
    }

    #[test]
    fn test_first_ok_all_fail() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Err(McpError::internal_error("fail 1")) }),
            Box::pin(async { Err(McpError::internal_error("fail 2")) }),
        ];
        let result = block_on(first_ok(&cx, futures));
        assert!(result.is_err());
    }

    // ---- join_all_results ----

    // Errors are collected alongside successes, in input order.
    #[test]
    fn join_all_results_collects_ok_and_err() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Ok(1) }),
            Box::pin(async { Err(McpError::internal_error("oops")) }),
            Box::pin(async { Ok(3) }),
        ];
        let results = block_on(join_all_results(&cx, futures));
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].as_ref().unwrap(), &1);
        assert!(results[1].is_err());
        assert_eq!(results[2].as_ref().unwrap(), &3);
    }

    // ---- race with several inputs / timeout variants ----

    // All three are ready on the first poll, so the lowest index wins.
    #[test]
    fn race_multiple_returns_first_ready() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![
            Box::pin(async { 10 }),
            Box::pin(async { 20 }),
            Box::pin(async { 30 }),
        ];
        let result = block_on(race(&cx, futures));
        assert_eq!(result.unwrap(), 10);
    }

    // Covers both the success path and the empty-input error of `race_timeout`.
    #[test]
    fn race_timeout_succeeds_within_deadline() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, i32>> = vec![Box::pin(async { 42 })];
        let result = block_on(race_timeout(&cx, Duration::from_secs(5), futures));
        assert_eq!(result.unwrap(), 42);
        let empty: Vec<BoxFuture<'_, i32>> = vec![];
        let err = block_on(race_timeout(&cx, Duration::from_secs(5), empty));
        assert!(err.is_err());
    }

    #[test]
    fn quorum_timeout_succeeds_within_deadline() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> =
            vec![Box::pin(async { Ok(1) }), Box::pin(async { Ok(2) })];
        let result = block_on(quorum_timeout(&cx, 2, Duration::from_secs(5), futures));
        let qr = result.unwrap();
        assert!(qr.quorum_met);
        assert_eq!(qr.successes.len(), 2);
    }

    // ---- QuorumResult ----

    #[test]
    fn quorum_result_is_success_and_into_results() {
        let met = QuorumResult {
            successes: vec![1, 2],
            quorum_met: true,
            failure_count: 1,
        };
        assert!(met.is_success());
        let values = met.into_results().unwrap();
        assert_eq!(values, vec![1, 2]);
        let not_met = QuorumResult {
            successes: vec![1],
            quorum_met: false,
            failure_count: 2,
        };
        assert!(!not_met.is_success());
        assert!(not_met.into_results().is_none());
    }

    // Relies on the derived `Debug` output format.
    #[test]
    fn quorum_result_debug() {
        let qr = QuorumResult {
            successes: vec![42],
            quorum_met: true,
            failure_count: 0,
        };
        let debug = format!("{qr:?}");
        assert!(debug.contains("QuorumResult"));
        assert!(debug.contains("42"));
        assert!(debug.contains("quorum_met: true"));
    }

    // With 2-of-3 required, two failures already make the quorum unreachable.
    // The assertion therefore only requires that many failures, not all three.
    #[test]
    fn quorum_all_failures() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Err(McpError::internal_error("fail 1")) }),
            Box::pin(async { Err(McpError::internal_error("fail 2")) }),
            Box::pin(async { Err(McpError::internal_error("fail 3")) }),
        ];
        let result = block_on(quorum(&cx, 2, futures));
        let qr = result.unwrap();
        assert!(!qr.quorum_met);
        assert!(qr.successes.is_empty());
        assert!(qr.failure_count >= 2);
    }

    // `first_ok` reports the error of the last future that failed.
    #[test]
    fn first_ok_all_fail_returns_last_error_message() {
        let cx = make_cx();
        let futures: Vec<BoxFuture<'_, McpResult<i32>>> = vec![
            Box::pin(async { Err(McpError::internal_error("first")) }),
            Box::pin(async { Err(McpError::internal_error("last")) }),
        ];
        let err = block_on(first_ok(&cx, futures)).unwrap_err();
        assert!(err.message.contains("last"));
    }
}