use futures::{
future::FutureExt,
stream::{FuturesUnordered, StreamExt},
};
#[cfg(feature = "metrics")]
use std::sync::Arc;
use std::{future::Future, time::Duration};
use tracing::{Instrument, trace_span};
use crate::errors::{RequestAttemptError, RequestError};
#[cfg(feature = "metrics")]
use crate::observability::metrics::Metrics;
use crate::response::Coordinator;
/// Information made available to a [`SpeculativeExecutionPolicy`] when it
/// computes the retry count and interval for a request.
///
/// Marked `#[non_exhaustive]` so new fields can be added without breaking
/// downstream constructors.
#[non_exhaustive]
pub struct Context {
    // Session-wide metrics; used e.g. to derive latency-percentile-based
    // retry intervals. Only present when the `metrics` feature is enabled.
    #[cfg(feature = "metrics")]
    pub metrics: Arc<Metrics>,
}
/// Controls if and how speculative executions of a request are started.
///
/// Implementations must be `Send + Sync` because a single policy instance
/// may be consulted from multiple concurrently running requests.
pub trait SpeculativeExecutionPolicy: std::fmt::Debug + Send + Sync {
    /// Maximum number of speculative executions that may be spawned in
    /// addition to the original request.
    fn max_retry_count(&self, context: &Context) -> usize;

    /// Delay to wait before launching the next speculative execution.
    fn retry_interval(&self, context: &Context) -> Duration;
}
/// A [`SpeculativeExecutionPolicy`] with a fixed retry count and a fixed
/// delay between speculative executions.
#[derive(Debug, Clone)]
pub struct SimpleSpeculativeExecutionPolicy {
    // Maximum number of speculative executions per request.
    pub max_retry_count: usize,
    // Fixed delay between consecutive speculative executions.
    pub retry_interval: Duration,
}
/// A [`SpeculativeExecutionPolicy`] with a fixed retry count whose delay is
/// derived from the session's observed latency distribution (hence it is
/// only available with the `metrics` feature).
#[cfg(feature = "metrics")]
#[derive(Debug, Clone)]
pub struct PercentileSpeculativeExecutionPolicy {
    // Maximum number of speculative executions per request.
    pub max_retry_count: usize,
    // Latency percentile (e.g. 99.0) used as the retry delay.
    pub percentile: f64,
}
impl SpeculativeExecutionPolicy for SimpleSpeculativeExecutionPolicy {
    /// Returns the statically configured retry budget; the context is
    /// not consulted by this policy.
    fn max_retry_count(&self, _context: &Context) -> usize {
        self.max_retry_count
    }

    /// Returns the statically configured delay; the context is not
    /// consulted by this policy.
    fn retry_interval(&self, _context: &Context) -> Duration {
        self.retry_interval
    }
}
#[cfg(feature = "metrics")]
impl SpeculativeExecutionPolicy for PercentileSpeculativeExecutionPolicy {
    /// The retry budget is fixed at construction time; the context is
    /// not consulted for it.
    fn max_retry_count(&self, _context: &Context) -> usize {
        self.max_retry_count
    }

    /// Derives the delay from the session's latency histogram: the
    /// configured percentile of observed latency, in milliseconds. Falls
    /// back to 100 ms (with a warning) when the percentile cannot be read.
    fn retry_interval(&self, context: &Context) -> Duration {
        let millis = context
            .metrics
            .get_latency_percentile_ms(self.percentile)
            .unwrap_or_else(|err| {
                tracing::warn!(
                    "Failed to get latency percentile ({}), defaulting to 100 ms",
                    err
                );
                100
            });
        Duration::from_millis(millis)
    }
}
/// Classifies an attempt's outcome: `true` means the result may be dropped
/// in favour of other (running or future) speculative executions, `false`
/// means it is definitive and must be returned to the caller.
///
/// Successes are never ignored. Among errors, only those that look
/// transient (pool/connection trouble, stream-id exhaustion, DB errors that
/// report themselves retryable) are ignorable; deterministic failures such
/// as serialization or protocol-parse errors would fail again on retry.
/// The `deny(clippy::wildcard_enum_match_arm)` attributes force every newly
/// added error variant to be classified here explicitly.
fn can_be_ignored<ResT>(result: &Result<ResT, RequestError>) -> bool {
    match result {
        Ok(_) => false,
        #[deny(clippy::wildcard_enum_match_arm)]
        Err(e) => match e {
            // No nodes to try / overall timeout: speculation cannot help.
            RequestError::EmptyPlan => false,
            RequestError::RequestTimeout(_) => false,
            // Pool problems are transient; another attempt may succeed.
            RequestError::ConnectionPoolError { .. } => true,
            RequestError::LastAttemptError(e) => {
                #[deny(clippy::wildcard_enum_match_arm)]
                match e {
                    // Deterministic request/response failures — retrying
                    // would reproduce them, so surface immediately.
                    RequestAttemptError::SerializationError(_)
                    | RequestAttemptError::CqlRequestSerialization(_)
                    | RequestAttemptError::BodyExtensionsParseError(_)
                    | RequestAttemptError::CqlResultParseError(_)
                    | RequestAttemptError::CqlErrorParseError(_)
                    | RequestAttemptError::UnexpectedResponse(_)
                    | RequestAttemptError::RepreparedIdChanged { .. }
                    | RequestAttemptError::RepreparedIdMissingInBatch
                    | RequestAttemptError::NonfinishedPagingState => false,
                    // Connection-level hiccups — worth trying elsewhere.
                    RequestAttemptError::BrokenConnectionError(_)
                    | RequestAttemptError::UnableToAllocStreamId => true,
                    // The DB error itself knows whether it is retryable.
                    RequestAttemptError::DbError(db_error, _) => db_error.can_speculative_retry(),
                }
            }
        },
    }
}
// Fallback returned by `execute` if it finishes without having recorded any
// attempt result (i.e. `last_error` was never set).
const EMPTY_PLAN_ERROR: RequestError = RequestError::EmptyPlan;
/// Runs a request together with its speculative executions and returns the
/// first non-ignorable result.
///
/// The original attempt is started immediately; every `retry_interval` an
/// additional speculative attempt is spawned, up to `max_retry_count` extra
/// attempts. All attempts are polled concurrently. Results classified as
/// ignorable by [`can_be_ignored`] are remembered in `last_error` and only
/// reported once every attempt has finished and no retries remain.
///
/// `query_runner_generator` receives `false` for the original attempt and
/// `true` for speculative ones; it may yield `None` to signal that the query
/// plan is exhausted, which stops further speculative spawning.
pub(crate) async fn execute<QueryFut, ResT>(
    policy: &dyn SpeculativeExecutionPolicy,
    context: &Context,
    mut query_runner_generator: impl FnMut(bool) -> QueryFut,
) -> Result<(ResT, Coordinator), RequestError>
where
    QueryFut: Future<Output = Option<Result<(ResT, Coordinator), RequestError>>>,
{
    let mut retries_remaining = policy.max_retry_count(context);
    let retry_interval = policy.retry_interval(context);

    // All in-flight attempts (original + speculative), polled concurrently.
    let mut async_tasks = FuturesUnordered::new();
    async_tasks.push(
        query_runner_generator(false)
            .instrument(trace_span!("Speculative execution: original query")),
    );

    // Timer that triggers the next speculative attempt; re-armed after each
    // spawn so attempts are spaced `retry_interval` apart.
    let sleep = tokio::time::sleep(retry_interval).fuse();
    tokio::pin!(sleep);

    // Most recent ignorable error; returned only when everything is drained.
    let mut last_error = None;
    loop {
        futures::select! {
            _ = &mut sleep => {
                if retries_remaining > 0 {
                    async_tasks.push(query_runner_generator(true).instrument(trace_span!("Speculative execution", retries_remaining = retries_remaining)));
                    retries_remaining -= 1;
                    sleep.set(tokio::time::sleep(retry_interval).fuse());
                }
            }
            res = async_tasks.select_next_some() => {
                if let Some(r) = res {
                    if !can_be_ignored(&r) {
                        // Definitive success or fatal error — return it now,
                        // dropping (cancelling) any still-running attempts.
                        return r;
                    } else {
                        last_error = Some(r)
                    }
                } else {
                    // `None`: the query plan is exhausted, so spawning more
                    // speculative attempts would be pointless.
                    retries_remaining = 0;
                }
                // All attempts finished and no more will be started: report
                // the remembered error (or the fallback if none was seen).
                if async_tasks.is_empty() && retries_remaining == 0 {
                    return last_error.unwrap_or({
                        Err(EMPTY_PLAN_ERROR)
                    });
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "metrics")]
    use std::sync::Arc;
    use std::sync::LazyLock;
    use std::time::Duration;
    use assert_matches::assert_matches;
    use crate::errors::{RequestAttemptError, RequestError};
    #[cfg(feature = "metrics")]
    use crate::observability::metrics::Metrics;
    use crate::policies::speculative_execution::{Context, SimpleSpeculativeExecutionPolicy};
    use crate::response::Coordinator;

    // Minimal context for policies that do not inspect it.
    static EMPTY_CONTEXT: LazyLock<Context> = LazyLock::new(|| Context {
        #[cfg(feature = "metrics")]
        metrics: Arc::new(Metrics::new()),
    });

    // An error that `can_be_ignored` classifies as ignorable, so `execute`
    // keeps speculating and only reports it after everything is exhausted.
    static IGNORABLE_ERROR: Option<Result<((), Coordinator), RequestError>> = Some(Err(
        RequestError::LastAttemptError(RequestAttemptError::UnableToAllocStreamId),
    ));

    // Fibers 0..=3 sleep 5 s and then fail ignorably; fiber 4 immediately
    // yields `None` (exhausted plan). With a 1 s interval, fiber 4 starts at
    // t=4 s and stops further spawning, while the already-running fibers are
    // allowed to finish — the last one (started at t=3 s) at t=8 s. The
    // paused tokio clock lets us assert that exact total duration.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_exhausted_plan_with_running_fibers() {
        let policy = SimpleSpeculativeExecutionPolicy {
            max_retry_count: 5,
            retry_interval: Duration::from_secs(1),
        };
        let generator = {
            // Counts how many fibers were created; captured by the closure.
            let mut counter = 0;
            move |_first: bool| {
                let future = {
                    let fiber_idx = counter;
                    async move {
                        match fiber_idx.cmp(&4) {
                            std::cmp::Ordering::Less => {
                                tokio::time::sleep(Duration::from_secs(5)).await;
                                IGNORABLE_ERROR.clone()
                            }
                            std::cmp::Ordering::Equal => None,
                            std::cmp::Ordering::Greater => {
                                panic!("Too many speculative executions - expected 4")
                            }
                        }
                    }
                };
                counter += 1;
                future
            }
        };
        let now = tokio::time::Instant::now();
        let res = super::execute(&policy, &EMPTY_CONTEXT, generator).await;
        assert_matches!(
            res,
            Err(RequestError::LastAttemptError(
                RequestAttemptError::UnableToAllocStreamId
            ))
        );
        assert_eq!(
            tokio::time::Instant::now(),
            now.checked_add(Duration::from_secs(8)).unwrap()
        )
    }

    // Same generator, but the 6 s interval exceeds the 5 s fiber runtime, so
    // each fiber finishes before the next spawns: fibers start at
    // t=0,6,12,18,24 s; fiber 4 (t=24 s) returns `None` with no tasks left,
    // so `execute` returns the remembered ignorable error at exactly t=24 s.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_exhausted_plan_last_running_fiber() {
        let policy = SimpleSpeculativeExecutionPolicy {
            max_retry_count: 5,
            retry_interval: Duration::from_secs(6),
        };
        let generator = {
            // Counts how many fibers were created; captured by the closure.
            let mut counter = 0;
            move |_first: bool| {
                let future = {
                    let fiber_idx = counter;
                    async move {
                        match fiber_idx.cmp(&4) {
                            std::cmp::Ordering::Less => {
                                tokio::time::sleep(Duration::from_secs(5)).await;
                                IGNORABLE_ERROR.clone()
                            }
                            std::cmp::Ordering::Equal => None,
                            std::cmp::Ordering::Greater => {
                                panic!("Too many speculative executions - expected 4")
                            }
                        }
                    }
                };
                counter += 1;
                future
            }
        };
        let now = tokio::time::Instant::now();
        let res = super::execute(&policy, &EMPTY_CONTEXT, generator).await;
        assert_matches!(
            res,
            Err(RequestError::LastAttemptError(
                RequestAttemptError::UnableToAllocStreamId
            ))
        );
        assert_eq!(
            tokio::time::Instant::now(),
            now.checked_add(Duration::from_secs(24)).unwrap()
        )
    }

    // Every fiber fails ignorably after 5 s. With a 1 s interval the retry
    // budget (5) is spent by t=5 s; the last fiber finishes at t=10 s, at
    // which point the remembered ignorable error is returned.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_se_panic_on_ignorable_errors() {
        let policy = SimpleSpeculativeExecutionPolicy {
            max_retry_count: 5,
            retry_interval: Duration::from_secs(1),
        };
        let generator = {
            move |_first: bool| async move {
                tokio::time::sleep(Duration::from_secs(5)).await;
                IGNORABLE_ERROR.clone()
            }
        };
        let now = tokio::time::Instant::now();
        let res = super::execute(&policy, &EMPTY_CONTEXT, generator).await;
        assert_matches!(
            res,
            Err(RequestError::LastAttemptError(
                RequestAttemptError::UnableToAllocStreamId
            ))
        );
        assert_eq!(
            tokio::time::Instant::now(),
            now.checked_add(Duration::from_secs(10)).unwrap()
        )
    }
}