use std::{error, fmt, sync::Arc, time::Duration};
use reifydb_engine::engine::StandardEngine;
use reifydb_runtime::context::clock::Clock;
use reifydb_type::{
error::{Diagnostic, Error},
params::Params,
value::{frame::frame::Frame, identity::IdentityId},
};
use tokio::{task::spawn_blocking, time};
use tracing::warn;
use crate::interceptor::{Operation, RequestContext, RequestInterceptorChain, ResponseContext};
/// Failure modes of request execution in this module.
#[derive(Debug)]
pub enum ExecuteError {
	/// The work did not finish within the caller-supplied timeout.
	Timeout,
	/// In this module: the blocking worker task failed to join.
	Cancelled,
	/// The query stream disconnected unexpectedly (not constructed in this file).
	Disconnected,
	/// The engine returned an error.
	Engine {
		/// Diagnostic extracted from the engine error.
		diagnostic: Arc<Diagnostic>,
		/// Statement text associated with the failure; left empty by the
		/// `From<Error>` conversion below.
		statement: String,
	},
	/// The request was refused with a code and message (not constructed in
	/// this file — presumably by an interceptor; verify against callers).
	Rejected {
		code: String,
		message: String,
	},
}

impl fmt::Display for ExecuteError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match self {
			Self::Timeout => f.write_str("Query execution timed out"),
			Self::Cancelled => f.write_str("Query was cancelled"),
			Self::Disconnected => f.write_str("Query stream disconnected unexpectedly"),
			Self::Engine {
				diagnostic,
				..
			} => write!(f, "Engine error: {}", diagnostic.message),
			Self::Rejected {
				code,
				message,
			} => write!(f, "Rejected [{}]: {}", code, message),
		}
	}
}

impl error::Error for ExecuteError {}

impl From<Error> for ExecuteError {
	/// Wraps an engine error's diagnostic; the statement text is not known here.
	fn from(err: Error) -> Self {
		let diagnostic = Arc::new(err.diagnostic());
		Self::Engine {
			diagnostic,
			statement: String::new(),
		}
	}
}

/// Convenience alias for results carrying an [`ExecuteError`].
pub type ExecuteResult<T> = Result<T, ExecuteError>;
fn retry_on_conflict<F>(mut f: F) -> Result<Vec<Frame>, Error>
where
F: FnMut() -> Result<Vec<Frame>, Error>,
{
let mut last_err = None;
for attempt in 0..3u32 {
match f() {
Ok(frames) => return Ok(frames),
Err(err) if err.code == "TXN_001" => {
warn!(attempt = attempt + 1, "Transaction conflict detected, retrying");
last_err = Some(err);
}
Err(err) => return Err(err),
}
}
Err(last_err.unwrap())
}
async fn raw_query(
engine: StandardEngine,
query: String,
identity: IdentityId,
params: Params,
timeout: Duration,
clock: Clock,
) -> ExecuteResult<(Vec<Frame>, Duration)> {
let task = spawn_blocking(move || -> (Result<Vec<Frame>, Error>, Duration) {
let t = clock.instant();
let r = engine.query_as(identity, &query, params);
(r, t.elapsed())
});
match time::timeout(timeout, task).await {
Err(_elapsed) => Err(ExecuteError::Timeout),
Ok(Ok((result, compute))) => result.map(|f| (f, compute)).map_err(ExecuteError::from),
Ok(Err(_join_error)) => Err(ExecuteError::Cancelled),
}
}
async fn raw_command(
engine: StandardEngine,
statements: String,
identity: IdentityId,
params: Params,
timeout: Duration,
clock: Clock,
) -> ExecuteResult<(Vec<Frame>, Duration)> {
let task = spawn_blocking(move || -> (Result<Vec<Frame>, Error>, Duration) {
let t = clock.instant();
let r = retry_on_conflict(|| engine.command_as(identity, &statements, params.clone()));
(r, t.elapsed())
});
match time::timeout(timeout, task).await {
Err(_elapsed) => Err(ExecuteError::Timeout),
Ok(Ok((result, compute))) => result.map(|f| (f, compute)).map_err(ExecuteError::from),
Ok(Err(_join_error)) => Err(ExecuteError::Cancelled),
}
}
async fn raw_admin(
engine: StandardEngine,
statements: String,
identity: IdentityId,
params: Params,
timeout: Duration,
clock: Clock,
) -> ExecuteResult<(Vec<Frame>, Duration)> {
let task = spawn_blocking(move || -> (Result<Vec<Frame>, Error>, Duration) {
let t = clock.instant();
let r = retry_on_conflict(|| engine.admin_as(identity, &statements, params.clone()));
(r, t.elapsed())
});
match time::timeout(timeout, task).await {
Err(_elapsed) => Err(ExecuteError::Timeout),
Ok(Ok((result, compute))) => result.map(|f| (f, compute)).map_err(ExecuteError::from),
Ok(Err(_join_error)) => Err(ExecuteError::Cancelled),
}
}
async fn raw_subscription(
engine: StandardEngine,
statement: String,
identity: IdentityId,
params: Params,
timeout: Duration,
clock: Clock,
) -> ExecuteResult<(Vec<Frame>, Duration)> {
let task = spawn_blocking(move || -> (Result<Vec<Frame>, Error>, Duration) {
let t = clock.instant();
let r = retry_on_conflict(|| engine.subscribe_as(identity, &statement, params.clone()));
(r, t.elapsed())
});
match time::timeout(timeout, task).await {
Err(_elapsed) => Err(ExecuteError::Timeout),
Ok(Ok((result, compute))) => result.map(|f| (f, compute)).map_err(ExecuteError::from),
Ok(Err(_join_error)) => Err(ExecuteError::Cancelled),
}
}
pub async fn execute(
chain: &RequestInterceptorChain,
engine: StandardEngine,
mut ctx: RequestContext,
timeout: Duration,
clock: &Clock,
) -> ExecuteResult<(Vec<Frame>, Duration)> {
if !chain.is_empty() {
chain.pre_execute(&mut ctx).await?;
}
let start = clock.instant();
let operation = ctx.operation;
let combined = ctx.statements.join("; ");
let response_parts = if !chain.is_empty() {
Some((ctx.identity, ctx.statements, ctx.params.clone(), ctx.metadata))
} else {
None
};
let result = match operation {
Operation::Query => raw_query(engine, combined, ctx.identity, ctx.params, timeout, clock.clone()).await,
Operation::Command => {
raw_command(engine, combined, ctx.identity, ctx.params, timeout, clock.clone()).await
}
Operation::Admin => raw_admin(engine, combined, ctx.identity, ctx.params, timeout, clock.clone()).await,
Operation::Subscribe => {
raw_subscription(engine, combined, ctx.identity, ctx.params, timeout, clock.clone()).await
}
};
let duration = start.elapsed();
let (result, compute_duration) = match result {
Ok((frames, cd)) => (Ok(frames), cd),
Err(e) => (Err(e), duration),
};
if let Some((identity, statements, params, metadata)) = response_parts {
let response_ctx = ResponseContext {
identity,
operation,
statements,
params,
metadata,
result: match &result {
Ok(frames) => Ok(frames.len()),
Err(e) => Err(e.to_string()),
},
duration,
compute_duration,
};
chain.post_execute(&response_ctx).await;
}
result.map(|frames| (frames, duration))
}