use crate::postgres_dao::ddl::ADMIN_DB_NAME;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use concepts::{
ComponentId, ComponentRetryConfig, ComponentType, ExecutionId, FunctionFqn, JoinSetId,
StrVariant, SupportedFunctionReturnValue,
component_id::{ComponentDigest, Digest},
prefixed_ulid::{DelayId, DeploymentId, ExecutionIdDerived, ExecutorId, RunId},
storage::{
AppendBatchResponse, AppendDelayResponseOutcome, AppendEventsToExecution, AppendRequest,
AppendResponse, AppendResponseToExecution, BacktraceFilter, BacktraceInfo,
ComponentMetadataRecord, ComponentUpgradeOutcome, ComponentUpgradeReason, CreateRequest,
DUMMY_CREATED, DUMMY_HISTORY_EVENT, DbConnection, DbErrorGeneric, DbErrorRead,
DbErrorReadWithTimeout, DbErrorStubResponse, DbErrorWrite, DbErrorWriteNonRetriable,
DbExecutor, DbExternalApi, DbPool, DbPoolCloseable, DeploymentComponentDetail,
DeploymentComponentRecord, DeploymentRecord, DeploymentState, DeploymentStatus,
ExecutionEvent, ExecutionListPagination, ExecutionRequest, ExecutionWithState,
ExecutionWithStateRequestsResponses, ExpiredDelay, ExpiredLock, ExpiredTimer,
HISTORY_EVENT_TYPE_JOIN_NEXT, HistoryEvent, JoinSetRequest, JoinSetResponse,
JoinSetResponseEvent, JoinSetResponseEventOuter, ListExecutionEventsResponse,
ListExecutionsFilter, ListLogsResponse, ListResponsesResponse, LockPendingResponse, Locked,
LockedBy, LockedExecution, LogEntry, LogEntryRow, LogFilter, LogInfoAppendRow, LogLevel,
LogStreamType, Pagination, PendingState, PendingStateBlockedByJoinSet,
PendingStateFinishedResultKind, PendingStateMergedPause, ResponseCursor,
ResponseWithCursor, STATE_BLOCKED_BY_JOIN_SET, STATE_FINISHED, STATE_LOCKED,
STATE_PENDING_AT, TimeoutOutcome, Unlocked, Version, VersionType, WasmBacktrace,
},
};
use db_common::{
AppendNotifier, CombinedState, CombinedStateDTO, NotifierExecutionFinished, NotifierPendingAt,
PendingFfqnSubscribersHolder,
};
use deadpool_postgres::{Client, ManagerConfig, Pool, RecyclingMethod};
use hashbrown::HashMap;
use secrecy::{ExposeSecret as _, SecretString};
use sha2::{Digest as _, Sha256};
use std::{collections::VecDeque, pin::Pin, str::FromStr as _, sync::Arc, time::Duration};
use std::{fmt::Write as _, panic::Location};
use strum::IntoEnumIterator as _;
use tokio::sync::{mpsc, oneshot};
use tokio_postgres::{
NoTls, Row, Transaction,
row::RowIndex,
types::{FromSql, Json, ToSql},
};
use tracing::{Level, debug, error, info, instrument, trace, warn};
use tracing_error::SpanTrace;
/// Reads column `name` from `row`, converting it into `T`.
///
/// A missing column or a failed conversion becomes a consistency error. Because
/// both this function and `consistency_db_err` are `#[track_caller]`, the error
/// records the source location of whoever called `get`.
#[track_caller]
fn get<'a, T: FromSql<'a>, I: RowIndex + std::fmt::Display + Copy>(
    row: &'a Row,
    name: I,
) -> Result<T, DbErrorGeneric> {
    let err = match row.try_get(name) {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };
    // Called directly (not from a closure) so the caller location propagates.
    Err(consistency_db_err(format!(
        "Failed to retrieve column '{name}': {err:?}"
    )))
}
/// Database-level constants.
mod ddl {
    /// Administrative database that `create_database` connects to while it
    /// checks for, and possibly creates, the configured database.
    pub const ADMIN_DB_NAME: &str = "postgres";
}
/// Schema migrations embedded from the `migrations` directory via `refinery`.
/// Their `runner()` is executed by `PostgresPool::new_with_outcome`.
mod embedded {
    refinery::embed_migrations!("migrations");
}
/// Connection settings for the Postgres backend.
#[derive(Debug, Clone)]
pub struct PostgresConfig {
    /// Server host name or address.
    pub host: String,
    /// Role used both for the administrative connection and for the pool.
    pub user: String,
    /// Password for `user`, kept as a `SecretString` and exposed only when
    /// building the connection configuration.
    pub password: SecretString,
    /// Database the pool connects to; it may be created first, depending on the
    /// `ProvisionPolicy`.
    pub db_name: String,
}
/// Opaque failure while provisioning the database or building the pool.
///
/// The underlying cause is not carried here. It is logged (`error!`/`warn!`)
/// at the point of failure.
#[derive(Debug, thiserror::Error)]
#[error("initialization error")]
pub struct InitializationError;
/// Connects to the administrative database (`ADMIN_DB_NAME`) and makes sure
/// the database `config.db_name` exists according to `provision_policy`.
///
/// * `Auto` creates the database when missing, otherwise reports `Existing`.
/// * `MustCreate` creates the database and fails if it already exists.
/// * `NeverCreate` must not be passed; the caller handles that policy.
///
/// # Errors
/// Returns [`InitializationError`] (after logging the cause) when the admin
/// pool cannot be created or connected to, when the lookup or the
/// `CREATE DATABASE` statement fails, or when `MustCreate` finds an existing
/// database.
///
/// # Panics
/// Panics when called with [`ProvisionPolicy::NeverCreate`].
async fn create_database(
    config: &PostgresConfig,
    provision_policy: ProvisionPolicy,
) -> Result<DbInitialzationOutcome, InitializationError> {
    /// Quotes `name` as a Postgres identifier by wrapping it in double quotes
    /// and doubling any embedded double quote.
    ///
    /// An unquoted name is folded to lower case by Postgres, and names such as
    /// `my-db` are a syntax error. Either way the created database would not
    /// match the case-sensitive `datname` lookup or the pool's `dbname`.
    fn quote_identifier(name: &str) -> String {
        format!("\"{}\"", name.replace('"', "\"\""))
    }

    let mut admin_cfg = deadpool_postgres::Config::new();
    admin_cfg.host = Some(config.host.clone());
    admin_cfg.user = Some(config.user.clone());
    admin_cfg.password = Some(config.password.expose_secret().to_string());
    admin_cfg.dbname = Some(ADMIN_DB_NAME.into());
    admin_cfg.manager = Some(ManagerConfig {
        recycling_method: RecyclingMethod::Fast,
    });
    let admin_pool = admin_cfg.create_pool(None, NoTls).map_err(|err| {
        error!("Cannot create the default pool - {err:?}");
        InitializationError
    })?;
    let client = admin_pool.get().await.map_err(|err| {
        error!("Cannot get a connection from the default pool - {err:?}");
        InitializationError
    })?;
    // The name is data here, so bind it instead of splicing it into the SQL.
    let row = client
        .query_opt(
            "SELECT 1 FROM pg_database WHERE datname = $1",
            &[&config.db_name],
        )
        .await
        .map_err(|err| {
            error!("Cannot select from the default database - {err:?}");
            InitializationError
        })?;
    match (row, provision_policy) {
        (None, ProvisionPolicy::MustCreate | ProvisionPolicy::Auto) => {
            // Utility statements like CREATE DATABASE cannot take bind
            // parameters, so the identifier is quoted instead.
            let statement = format!("CREATE DATABASE {}", quote_identifier(&config.db_name));
            client.execute(&statement, &[]).await.map_err(|err| {
                error!("Cannot create the database - {err:?}");
                InitializationError
            })?;
            info!("Database '{}' created.", config.db_name);
            Ok(DbInitialzationOutcome::Created)
        }
        (Some(_), ProvisionPolicy::Auto) => {
            info!("Database '{}' exists.", config.db_name);
            Ok(DbInitialzationOutcome::Existing)
        }
        (Some(_), ProvisionPolicy::MustCreate) => {
            warn!("Database '{}' already exists.", config.db_name);
            Err(InitializationError)
        }
        (_, ProvisionPolicy::NeverCreate) => unreachable!("checked by the caller"),
    }
}
/// One-shot senders delivering a [`ResponseWithCursor`], keyed by execution.
///
/// A `HashMap` keyed by [`ExecutionId`] holds at most one `(sender, u64)` entry
/// per execution. NOTE(review): the meaning of the `u64` is defined by the code
/// that registers subscribers, which is not visible here; confirm before relying on it.
/// Guarded by a blocking `std::sync::Mutex`, so guards must not be held across `.await`.
type ResponseSubscribers =
    Arc<std::sync::Mutex<HashMap<ExecutionId, (oneshot::Sender<ResponseWithCursor>, u64)>>>;
/// Shared holder of subscribers interested in pending executions
/// (`db_common::PendingFfqnSubscribersHolder`).
type PendingSubscribers = Arc<std::sync::Mutex<PendingFfqnSubscribersHolder>>;
/// Per execution, a map from a `u64` key to a one-shot sender of the
/// execution's [`SupportedFunctionReturnValue`]. The key presumably identifies
/// the individual subscriber (TODO confirm).
/// Unlike the aliases above, this one is not `Arc`-wrapped; holders wrap it in `Arc`.
type ExecutionFinishedSubscribers = std::sync::Mutex<
    HashMap<ExecutionId, HashMap<u64, oneshot::Sender<SupportedFunctionReturnValue>>>,
>;
/// Postgres-backed [`DbPool`]: a deadpool connection pool plus subscriber
/// registries that are shared (via `Arc`) with every connection it hands out.
pub struct PostgresPool {
    pool: Pool,
    response_subscribers: ResponseSubscribers,
    pending_subscribers: PendingSubscribers,
    execution_finished_subscribers: Arc<ExecutionFinishedSubscribers>,
    /// Configuration the pool was created from.
    pub config: PostgresConfig,
}
#[async_trait]
impl DbPool for PostgresPool {
async fn db_exec_conn(&self) -> Result<Box<dyn DbExecutor>, DbErrorGeneric> {
let client = self.pool.get().await?;
Ok(Box::new(PostgresConnection {
client: tokio::sync::Mutex::new(client),
response_subscribers: self.response_subscribers.clone(),
pending_subscribers: self.pending_subscribers.clone(),
execution_finished_subscribers: self.execution_finished_subscribers.clone(),
}))
}
async fn connection(&self) -> Result<Box<dyn DbConnection>, DbErrorGeneric> {
let client = self.pool.get().await?;
Ok(Box::new(PostgresConnection {
client: tokio::sync::Mutex::new(client),
response_subscribers: self.response_subscribers.clone(),
pending_subscribers: self.pending_subscribers.clone(),
execution_finished_subscribers: self.execution_finished_subscribers.clone(),
}))
}
async fn external_api_conn(&self) -> Result<Box<dyn DbExternalApi>, DbErrorGeneric> {
let client = self.pool.get().await?;
Ok(Box::new(PostgresConnection {
client: tokio::sync::Mutex::new(client),
response_subscribers: self.response_subscribers.clone(),
pending_subscribers: self.pending_subscribers.clone(),
execution_finished_subscribers: self.execution_finished_subscribers.clone(),
}))
}
#[cfg(feature = "test")]
async fn connection_test(
&self,
) -> Result<Box<dyn concepts::storage::DbConnectionTest>, DbErrorGeneric> {
let client = self.pool.get().await?;
Ok(Box::new(PostgresConnection {
client: tokio::sync::Mutex::new(client),
response_subscribers: self.response_subscribers.clone(),
pending_subscribers: self.pending_subscribers.clone(),
execution_finished_subscribers: self.execution_finished_subscribers.clone(),
}))
}
}
/// A single pooled Postgres client, together with handles to the subscriber
/// registries shared with the [`PostgresPool`] it was obtained from.
pub struct PostgresConnection {
    /// The pooled client, guarded by an async (`tokio::sync`) mutex.
    client: tokio::sync::Mutex<Client>,
    response_subscribers: ResponseSubscribers,
    pending_subscribers: PendingSubscribers,
    execution_finished_subscribers: Arc<ExecutionFinishedSubscribers>,
}
/// How [`PostgresPool::new`] treats the target database before running migrations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProvisionPolicy {
    /// Never attempt `CREATE DATABASE`; the database is assumed to exist.
    NeverCreate,
    /// Create the database if it is missing, otherwise reuse it.
    Auto,
    /// Create the database; initialization fails if it already exists.
    MustCreate,
}
/// Whether database provisioning created a new database or found an existing one.
///
/// NOTE: the type name is misspelled ("Initialzation"); renaming it would
/// break the public API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DbInitialzationOutcome {
    /// `CREATE DATABASE` was executed.
    Created,
    /// The database already existed, or creation was not attempted (`NeverCreate`).
    Existing,
}
impl PostgresPool {
    /// Provisions the database (per `provision_policy`), opens the pool and runs
    /// the embedded migrations, discarding the provisioning outcome.
    ///
    /// # Errors
    /// See [`PostgresPool::new_with_outcome`].
    #[instrument(skip_all, name = "postgres_new")]
    pub async fn new(
        config: PostgresConfig,
        provision_policy: ProvisionPolicy,
    ) -> Result<PostgresPool, InitializationError> {
        let (pool, _outcome) = Self::new_with_outcome(config, provision_policy).await?;
        Ok(pool)
    }

    /// Provisions the database according to `provision_policy`, builds the
    /// connection pool and applies the embedded schema migrations.
    ///
    /// Returns the pool and whether the database was created or already existed
    /// (`NeverCreate` always reports `Existing`).
    ///
    /// # Errors
    /// Returns [`InitializationError`] (cause logged) if provisioning, pool
    /// creation, connection checkout or migrations fail.
    pub async fn new_with_outcome(
        config: PostgresConfig,
        provision_policy: ProvisionPolicy,
    ) -> Result<(PostgresPool, DbInitialzationOutcome), InitializationError> {
        let outcome = match provision_policy {
            ProvisionPolicy::NeverCreate => DbInitialzationOutcome::Existing,
            ProvisionPolicy::Auto | ProvisionPolicy::MustCreate => {
                create_database(&config, provision_policy).await?
            }
        };

        let mut pool_cfg = deadpool_postgres::Config::new();
        pool_cfg.host = Some(config.host.clone());
        pool_cfg.user = Some(config.user.clone());
        pool_cfg.password = Some(config.password.expose_secret().to_string());
        pool_cfg.dbname = Some(config.db_name.clone());
        pool_cfg.manager = Some(ManagerConfig {
            recycling_method: RecyclingMethod::Fast,
        });
        let pool = pool_cfg.create_pool(None, NoTls).map_err(|err| {
            error!("Cannot create the database pool - {err:?}");
            InitializationError
        })?;

        {
            // One connection is checked out just for the migration run and is
            // returned to the pool when this scope ends.
            let mut migration_client = pool.get().await.map_err(|err| {
                error!("Cannot get a connection from the database pool - {err:?}");
                InitializationError
            })?;
            embedded::migrations::runner()
                .run_async(&mut **migration_client)
                .await
                .map_err(|err| {
                    error!("Cannot run migrations - {err:?}");
                    InitializationError
                })?;
        }
        debug!("Database schema initialized.");

        let db = PostgresPool {
            pool,
            execution_finished_subscribers: Arc::default(),
            pending_subscribers: Arc::default(),
            response_subscribers: Arc::default(),
            config,
        };
        Ok((db, outcome))
    }
}
/// Converts a row with deployment columns (`deployment_id`, `status`,
/// `created_at`, `last_active_at`, `config_json`, `obelisk_version`,
/// `created_by`) into a [`DeploymentRecord`].
///
/// # Errors
/// A missing column or an unparsable `deployment_id`/`status` yields a
/// consistency error wrapped in [`DbErrorRead::Generic`].
fn deployment_record_from_pg_row(row: &Row) -> Result<DeploymentRecord, DbErrorRead> {
    let deployment_id = get::<String, _>(row, "deployment_id")?
        .parse::<DeploymentId>()
        .map_err(|e| {
            DbErrorRead::Generic(consistency_db_err(format!("invalid deployment_id: {e}")))
        })?;
    let status = get::<String, _>(row, "status")?
        .parse::<DeploymentStatus>()
        .map_err(|e| DbErrorRead::Generic(consistency_db_err(format!("invalid status: {e}"))))?;
    // Remaining columns are read in the same order as they appear in the record.
    let created_at = get(row, "created_at")?;
    let last_active_at = get(row, "last_active_at")?;
    let config_json = get(row, "config_json")?;
    let obelisk_version = get(row, "obelisk_version")?;
    let created_by = get(row, "created_by")?;
    Ok(DeploymentRecord {
        deployment_id,
        created_at,
        last_active_at,
        status,
        config_json,
        obelisk_version,
        created_by,
    })
}
fn deployment_component_detail_from_pg_row(
row: &Row,
) -> Result<DeploymentComponentDetail, DbErrorRead> {
let component_name: String = get(row, "component_name")?;
let component_type: ComponentType =
get::<String, _>(row, "component_type")?
.parse()
.map_err(|err| {
DbErrorRead::Generic(consistency_db_err(format!("invalid component_type: {err}")))
})?;
let component_digest = ComponentDigest(Digest(
get::<Vec<u8>, _>(row, "component_digest")?
.try_into()
.map_err(|_| {
DbErrorRead::Generic(consistency_db_err("invalid component_digest length"))
})?,
));
let component_id = ComponentId::new(
component_type,
StrVariant::from(component_name),
component_digest,
)
.map_err(|err| DbErrorRead::Generic(consistency_db_err(err.to_string())))?;
let imports =
get::<Json<Vec<concepts::storage::PersistedFunctionMetadata>>, _>(row, "imports_json")?.0;
let exports =
get::<Json<Vec<concepts::storage::PersistedFunctionMetadata>>, _>(row, "exports_json")?.0;
let wit: String = get(row, "wit")?;
Ok(DeploymentComponentDetail {
component_id,
imports,
exports,
wit,
})
}
/// Builds an uncategorized [`DbErrorGeneric`] describing data that violates
/// the storage's expectations, without an underlying source error.
///
/// Captures the current span trace and, via `#[track_caller]`, the caller's
/// source location.
#[track_caller]
fn consistency_db_err(reason: impl Into<StrVariant>) -> DbErrorGeneric {
    let loc = Location::caller();
    let context = SpanTrace::capture();
    DbErrorGeneric::Uncategorized {
        reason: reason.into(),
        context,
        source: None,
        loc,
    }
}
/// Like `consistency_db_err`, but keeps `source` as the underlying cause.
///
/// Captures the current span trace and, via `#[track_caller]`, the caller's
/// source location.
#[track_caller]
fn consistency_db_err_src(
    reason: impl Into<StrVariant>,
    source: Arc<dyn std::error::Error + Send + Sync>,
) -> DbErrorGeneric {
    let loc = Location::caller();
    let context = SpanTrace::capture();
    DbErrorGeneric::Uncategorized {
        reason: reason.into(),
        context,
        source: Some(source),
        loc,
    }
}
/// A delay that `bump_state_next_version` inserts into `t_delay` together with
/// the version bump.
#[derive(Debug, Clone)]
struct DelayReq {
    join_set_id: JoinSetId,
    delay_id: DelayId,
    /// Written to the `expires_at` column.
    expires_at: DateTime<Utc>,
    /// Written to the `is_paused` column.
    paused: bool,
}
async fn fetch_created_event(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
) -> Result<CreateRequest, DbErrorRead> {
let stmt = "SELECT created_at, json_value FROM t_execution_log WHERE \
execution_id = $1 AND version = 0";
let row = tx.query_one(stmt, &[&execution_id.to_string()]).await?;
let created_at = get(&row, "created_at")?;
let event: Json<ExecutionRequest> = get(&row, "json_value")?;
let event = event.0;
if let ExecutionRequest::Created {
ffqn,
params,
parent,
scheduled_at,
component_id,
deployment_id,
metadata,
scheduled_by,
} = event
{
Ok(CreateRequest {
created_at,
execution_id: execution_id.clone(),
ffqn,
params,
parent,
scheduled_at,
component_id,
deployment_id,
metadata,
scheduled_by,
paused: false,
})
} else {
error!("Row with version=0 must be a `Created` event - {event:?}");
Err(consistency_db_err("expected `Created` event").into())
}
}
/// Ensures the caller appends exactly at the version the log expects next.
///
/// # Errors
/// A mismatch is logged at debug level and reported as a non-retriable
/// `VersionConflict` carrying both versions.
fn check_expected_next_and_appending_version(
    expected_version: &Version,
    appending_version: &Version,
) -> Result<(), DbErrorWrite> {
    if expected_version == appending_version {
        return Ok(());
    }
    debug!("Version conflict - expected: {expected_version:?}, appending: {appending_version:?}");
    Err(DbErrorWrite::NonRetriable(
        DbErrorWriteNonRetriable::VersionConflict {
            expected: expected_version.clone(),
            requested: appending_version.clone(),
        },
    ))
}
/// Inserts a new execution.
///
/// It writes the `Created` event (version 0) to `t_execution_log` and the
/// initial `PendingAt(scheduled_at)` row to `t_state`. The `t_state` row is
/// inserted with `is_paused = false`. If the request asks for a paused start, a
/// `Paused` event is then appended through `append`.
///
/// Returns the next version to append at and the notifier for subscribers. The
/// notifier's `pending_at` is `None` for paused executions, so nobody is woken
/// up for them.
#[instrument(level = Level::DEBUG, skip_all, fields(execution_id = %req.execution_id))]
async fn create_inner(
    tx: &Transaction<'_>,
    req: CreateRequest,
) -> Result<(AppendResponse, AppendNotifier), DbErrorWrite> {
    trace!("create_inner");
    let version = Version::default();
    let execution_id = req.execution_id.clone();
    let execution_id_str = execution_id.to_string();
    let ffqn = req.ffqn.clone();
    let created_at = req.created_at;
    let scheduled_at = req.scheduled_at;
    let component_id = req.component_id.clone();
    let deployment_id = req.deployment_id;
    let paused = req.paused;
    let event = Json(ExecutionRequest::from(req));
    tx.execute(
        "INSERT INTO t_execution_log (
execution_id, created_at, version, json_value, variant, join_set_id
) VALUES ($1, $2, $3, $4, $5, $6)",
        &[
            &execution_id_str,
            &created_at,
            &i64::from(version.0),
            &event,
            &event.0.variant(),
            &event.0.join_set_id().map(std::string::ToString::to_string),
        ],
    )
    .await?;

    debug!("Creating with Pending(`{scheduled_at:?}`)");
    tx.execute(
        r"
INSERT INTO t_state (
execution_id,
is_top_level,
corresponding_version,
pending_expires_finished,
ffqn,
state,
created_at,
component_id_input_digest,
component_type,
deployment_id,
first_scheduled_at,
updated_at,
intermittent_event_count,
is_paused
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, CURRENT_TIMESTAMP, 0, $12
)",
        &[
            &execution_id_str,
            &execution_id.is_top_level(),
            &i64::from(version.0),
            &scheduled_at,
            &ffqn.to_string(),
            &STATE_PENDING_AT,
            &created_at,
            &component_id.component_digest.as_slice(),
            &component_id.component_type.to_string(),
            &deployment_id.to_string(),
            &scheduled_at,
            // Paused state, if requested, is applied by the `Paused` event below.
            &false,
        ],
    )
    .await?;

    let notifier = AppendNotifier {
        pending_at: if paused {
            None
        } else {
            Some(NotifierPendingAt {
                scheduled_at,
                ffqn: ffqn.clone(),
                component_input_digest: component_id.component_digest,
            })
        },
        execution_finished: None,
        response: None,
    };

    let mut next_version = Version::new(version.0 + 1);
    if paused {
        let (v, _) = append(
            tx,
            &execution_id,
            AppendRequest {
                created_at,
                event: ExecutionRequest::Paused,
            },
            next_version,
        )
        .await?;
        next_version = v;
    }
    Ok((next_version, notifier))
}
/// Moves the execution's `t_state` row to `PendingAt(scheduled_at)` after a
/// response was appended.
///
/// Lock, retry, join-set and result columns are cleared.
/// `corresponding_version` is left unchanged.
///
/// Returns a notifier announcing the pending execution. Its ffqn is taken from
/// the `Created` event.
///
/// # Errors
/// `NotFound` if no row was updated.
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %scheduled_at))]
async fn update_state_pending_after_response_appended(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    scheduled_at: DateTime<Utc>,
    component_input_digest: ComponentDigest,
) -> Result<AppendNotifier, DbErrorWrite> {
    const SQL: &str = r"
UPDATE t_state
SET
pending_expires_finished = $1,
state = $2,
updated_at = CURRENT_TIMESTAMP,
max_retries = NULL,
retry_exp_backoff_millis = NULL,
last_lock_version = NULL,
join_set_id = NULL,
join_set_closing = NULL,
result_kind = NULL
WHERE execution_id = $3
";
    debug!("Setting t_state to Pending(`{scheduled_at:?}`) after response appended");
    let rows_updated = tx
        .execute(
            SQL,
            &[&scheduled_at, &STATE_PENDING_AT, &execution_id.to_string()],
        )
        .await?;
    if rows_updated == 0 {
        return Err(DbErrorWrite::NotFound);
    }
    let ffqn = fetch_created_event(tx, execution_id).await?.ffqn;
    Ok(AppendNotifier {
        pending_at: Some(NotifierPendingAt {
            scheduled_at,
            ffqn,
            component_input_digest,
        }),
        execution_finished: None,
        response: None,
    })
}
/// Arguments for `update_state_pending_after_event_appended`.
struct PendingAfterEventUpdate {
    /// New `pending_expires_finished` value (the time the execution becomes pending).
    scheduled_at: DateTime<Utc>,
    /// When `true`, `intermittent_event_count` is incremented by one.
    intermittent_failure: bool,
    /// Digest reported in the resulting `NotifierPendingAt`.
    component_input_digest: ComponentDigest,
}
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, scheduled_at = %update.scheduled_at, %appending_version))]
async fn update_state_pending_after_event_appended(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
appending_version: &Version,
update: PendingAfterEventUpdate,
) -> Result<(AppendResponse, AppendNotifier), DbErrorWrite> {
let scheduled_at = update.scheduled_at;
debug!("Setting t_state to Pending(`{scheduled_at:?}`) after event appended");
let intermittent_delta = i64::from(update.intermittent_failure); let sql = r"
UPDATE t_state
SET
corresponding_version = $1,
pending_expires_finished = $2,
state = $3,
updated_at = CURRENT_TIMESTAMP,
intermittent_event_count = intermittent_event_count + $4,
max_retries = NULL,
retry_exp_backoff_millis = NULL,
last_lock_version = NULL,
join_set_id = NULL,
join_set_closing = NULL,
result_kind = NULL
WHERE execution_id = $5;
";
let updated = tx
.execute(
sql,
&[
&i64::from(appending_version.0),
&scheduled_at,
&STATE_PENDING_AT,
&intermittent_delta,
&execution_id.to_string(),
],
)
.await?;
if updated != 1 {
return Err(DbErrorWrite::NotFound);
}
Ok((
appending_version.increment(),
AppendNotifier {
pending_at: Some(NotifierPendingAt {
scheduled_at,
ffqn: fetch_created_event(tx, execution_id).await?.ffqn,
component_input_digest: update.component_input_digest,
}),
execution_finished: None,
response: None,
},
))
}
/// Records a successful component upgrade in `t_state`.
///
/// It stores the new component digest and deployment, clears
/// `incompatible_digest`, and bumps `corresponding_version`.
///
/// # Errors
/// `NotFound` unless exactly one row was updated.
async fn update_state_component_upgrade_finished_success(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    component_digest: &ComponentDigest,
    deployment_id: DeploymentId,
    appending_version: &Version,
) -> Result<AppendResponse, DbErrorWrite> {
    const SQL: &str = r"
UPDATE t_state
SET
corresponding_version = $1,
updated_at = CURRENT_TIMESTAMP,
component_id_input_digest = $2,
deployment_id = $3,
incompatible_digest = NULL
WHERE execution_id = $4;
";
    debug!("Updating t_state to component {component_digest}");
    let version = i64::from(appending_version.0);
    let rows_updated = tx
        .execute(
            SQL,
            &[
                &version,
                &component_digest.as_slice(),
                &deployment_id.to_string(),
                &execution_id.to_string(),
            ],
        )
        .await?;
    match rows_updated {
        1 => Ok(appending_version.increment()),
        _ => Err(DbErrorWrite::NotFound),
    }
}
/// Returns a locked execution to `PendingAt(scheduled_at)`.
///
/// The unlock does not count as an intermittent failure. This delegates to
/// `update_state_pending_after_event_appended`.
async fn update_state_unlocked_from_locked(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    component_digest: ComponentDigest,
    scheduled_at: DateTime<Utc>,
    appending_version: &Version,
) -> Result<(AppendResponse, AppendNotifier), DbErrorWrite> {
    let update = PendingAfterEventUpdate {
        scheduled_at,
        intermittent_failure: false,
        component_input_digest: component_digest,
    };
    update_state_pending_after_event_appended(tx, execution_id, appending_version, update).await
}
/// Records a failed component upgrade by storing `target_digest` as
/// `incompatible_digest` and bumping `corresponding_version`.
///
/// # Errors
/// `NotFound` unless exactly one row was updated.
async fn update_state_component_upgrade_finished_failed(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    target_digest: &ComponentDigest,
    appending_version: &Version,
) -> Result<AppendResponse, DbErrorWrite> {
    const SQL: &str = r"
UPDATE t_state
SET
corresponding_version = $1,
updated_at = CURRENT_TIMESTAMP,
incompatible_digest = $2
WHERE execution_id = $3;
";
    debug!("Marking component {target_digest} incompatible after upgrade failure");
    let version = i64::from(appending_version.0);
    let rows_updated = tx
        .execute(
            SQL,
            &[&version, &target_digest.as_slice(), &execution_id.to_string()],
        )
        .await?;
    match rows_updated {
        1 => Ok(appending_version.increment()),
        _ => Err(DbErrorWrite::NotFound),
    }
}
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %appending_version))]
#[expect(clippy::too_many_arguments)]
async fn update_state_locked_get_intermittent_event_count(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
deployment_id: DeploymentId,
component_digest: Option<&ComponentDigest>,
executor_id: ExecutorId,
run_id: RunId,
lock_expires_at: DateTime<Utc>,
appending_version: &Version,
retry_config: ComponentRetryConfig,
) -> Result<u32, DbErrorWrite> {
debug!("Setting t_state to Locked(`{lock_expires_at:?}`)");
let backoff_millis =
i64::try_from(retry_config.retry_exp_backoff.as_millis()).map_err(|err| {
DbErrorGeneric::Uncategorized {
reason: "backoff too big".into(),
context: SpanTrace::capture(),
source: Some(Arc::new(err)),
loc: Location::caller(),
}
})?;
let execution_id_str = execution_id.to_string();
let updated = tx
.execute(
r"
UPDATE t_state
SET
corresponding_version = $1,
pending_expires_finished = $2,
state = $3,
updated_at = CURRENT_TIMESTAMP,
deployment_id = $4,
component_id_input_digest = COALESCE($5, component_id_input_digest),
max_retries = $6,
retry_exp_backoff_millis = $7,
last_lock_version = $1,
executor_id = $8,
run_id = $9,
join_set_id = NULL,
join_set_closing = NULL,
result_kind = NULL
WHERE execution_id = $10
AND is_paused = false
",
&[
&i64::from(appending_version.0),
&lock_expires_at,
&STATE_LOCKED,
&deployment_id.to_string(),
&component_digest.map(|component_digest| component_digest.as_slice().to_vec()), &retry_config.max_retries.map(i64::from),
&backoff_millis,
&executor_id.to_string(),
&run_id.to_string(),
&execution_id_str,
],
)
.await?;
if updated != 1 {
return Err(DbErrorWrite::NotFound);
}
let row = tx
.query_one(
"SELECT intermittent_event_count FROM t_state WHERE execution_id = $1",
&[&execution_id_str],
)
.await
.map_err(DbErrorGeneric::from)?;
let count: i64 = get(&row, "intermittent_event_count")?; let count = u32::try_from(count)
.map_err(|_| consistency_db_err("`intermittent_event_count` must not be negative"))?;
Ok(count)
}
/// Moves the execution's `t_state` row to `BlockedByJoinSet(join_set_id)`.
///
/// `lock_expires_at` becomes `pending_expires_finished`. `join_set_closing` is
/// recorded and lock/retry columns are cleared.
///
/// # Errors
/// `NotFound` unless exactly one row was updated.
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %appending_version))]
async fn update_state_blocked(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    appending_version: &Version,
    join_set_id: &JoinSetId,
    lock_expires_at: DateTime<Utc>,
    join_set_closing: bool,
) -> Result<AppendResponse, DbErrorWrite> {
    const SQL: &str = r"
UPDATE t_state
SET
corresponding_version = $1,
pending_expires_finished = $2,
state = $3,
updated_at = CURRENT_TIMESTAMP,
max_retries = NULL,
retry_exp_backoff_millis = NULL,
last_lock_version = NULL,
join_set_id = $4,
join_set_closing = $5,
result_kind = NULL
WHERE execution_id = $6
";
    debug!("Setting t_state to BlockedByJoinSet(`{join_set_id}`)");
    let version = i64::from(appending_version.0);
    let join_set = join_set_id.to_string();
    let execution = execution_id.to_string();
    let rows_updated = tx
        .execute(
            SQL,
            &[
                &version,
                &lock_expires_at,
                &STATE_BLOCKED_BY_JOIN_SET,
                &join_set,
                &join_set_closing,
                &execution,
            ],
        )
        .await?;
    match rows_updated {
        1 => Ok(appending_version.increment()),
        _ => Err(DbErrorWrite::NotFound),
    }
}
/// Moves the execution's `t_state` row to `Finished` at `finished_at`.
///
/// It stores `result_kind` as JSON, clears lock/executor/join-set columns and
/// unpauses the row. Unlike the other state transitions, it does not return a
/// next version.
///
/// # Errors
/// `NotFound` unless exactly one row was updated.
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %appending_version))]
async fn update_state_finished(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    appending_version: &Version,
    finished_at: DateTime<Utc>,
    result_kind: PendingStateFinishedResultKind,
) -> Result<(), DbErrorWrite> {
    const SQL: &str = r"
UPDATE t_state
SET
corresponding_version = $1,
pending_expires_finished = $2,
state = $3,
updated_at = CURRENT_TIMESTAMP,
max_retries = NULL,
retry_exp_backoff_millis = NULL,
last_lock_version = NULL,
executor_id = NULL,
run_id = NULL,
join_set_id = NULL,
join_set_closing = NULL,
is_paused = false,
result_kind = $4
WHERE execution_id = $5
";
    debug!("Setting t_state to Finished");
    let version = i64::from(appending_version.0);
    let result_kind = Json(result_kind);
    let rows_updated = tx
        .execute(
            SQL,
            &[
                &version,
                &finished_at,
                &STATE_FINISHED,
                &result_kind,
                &execution_id.to_string(),
            ],
        )
        .await?;
    if rows_updated == 1 {
        Ok(())
    } else {
        Err(DbErrorWrite::NotFound)
    }
}
/// Sets or clears the `is_paused` flag of the execution's `t_state` row and
/// bumps `corresponding_version`.
///
/// # Errors
/// `NotFound` unless exactly one row was updated.
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %appending_version, %is_paused))]
async fn update_state_paused(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    appending_version: &Version,
    is_paused: bool,
) -> Result<AppendResponse, DbErrorWrite> {
    const SQL: &str = r"
UPDATE t_state
SET
corresponding_version = $1,
is_paused = $2,
updated_at = CURRENT_TIMESTAMP
WHERE execution_id = $3
";
    let target = if is_paused { "paused" } else { "unpaused" };
    debug!("Setting t_state to {}", target);
    let version = i64::from(appending_version.0);
    let rows_updated = tx
        .execute(SQL, &[&version, &is_paused, &execution_id.to_string()])
        .await?;
    match rows_updated {
        1 => Ok(appending_version.increment()),
        _ => Err(DbErrorWrite::NotFound),
    }
}
/// Bumps `corresponding_version` of the execution's `t_state` row without
/// changing its state.
///
/// If `delay_req` is given, the delay is also inserted into `t_delay`.
///
/// # Errors
/// `NotFound` unless exactly one `t_state` row was updated. Insert failures
/// are propagated.
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %appending_version))]
async fn bump_state_next_version(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    appending_version: &Version,
    delay_req: Option<DelayReq>,
) -> Result<AppendResponse, DbErrorWrite> {
    const BUMP_SQL: &str = r"
UPDATE t_state
SET
corresponding_version = $1,
updated_at = CURRENT_TIMESTAMP
WHERE execution_id = $2
";
    debug!("update_index_version");
    let id = execution_id.to_string();
    let rows_updated = tx
        .execute(BUMP_SQL, &[&i64::from(appending_version.0), &id])
        .await?;
    if rows_updated != 1 {
        return Err(DbErrorWrite::NotFound);
    }
    if let Some(delay) = delay_req {
        debug!("Inserting delay to `t_delay`");
        tx.execute(
            "INSERT INTO t_delay (execution_id, join_set_id, delay_id, expires_at, is_paused) VALUES ($1, $2, $3, $4, $5)",
            &[
                &id,
                &delay.join_set_id.to_string(),
                &delay.delay_id.to_string(),
                &delay.expires_at,
                &delay.paused,
            ],
        )
        .await?;
    }
    Ok(appending_version.increment())
}
async fn get_combined_state(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
) -> Result<CombinedState, DbErrorRead> {
let row = tx
.query_one(
r"
SELECT
created_at, first_scheduled_at,
state, ffqn, component_id_input_digest, component_type, deployment_id, corresponding_version, pending_expires_finished,
last_lock_version, executor_id, run_id,
join_set_id, join_set_closing,
result_kind, is_paused
FROM t_state
WHERE execution_id = $1
",
&[&execution_id.to_string()],
)
.await
.map_err(DbErrorRead::from)?;
let created_at: DateTime<Utc> = get(&row, "created_at")?;
let first_scheduled_at: DateTime<Utc> = get(&row, "first_scheduled_at")?;
let digest_bytes: Vec<u8> = get(&row, "component_id_input_digest")?;
let digest = Digest::try_from(digest_bytes.as_slice()).map_err(|err| {
consistency_db_err_src("cannot parse `component_id_input_digest`", Arc::from(err))
})?;
let component_digest = ComponentDigest(digest);
let component_type: String = get(&row, "component_type")?;
let component_type = ComponentType::from_str(&component_type)
.map_err(|err| consistency_db_err_src("cannot parse `component_type`", Arc::from(err)))?;
let deployment_id: String = get(&row, "deployment_id")?;
let deployment_id = DeploymentId::from_str(&deployment_id).map_err(DbErrorGeneric::from)?;
let state: String = get(&row, "state")?;
let ffqn: String = get(&row, "ffqn")?;
let ffqn = FunctionFqn::from_str(&ffqn).map_err(|parse_err| {
consistency_db_err(format!("invalid ffqn value in `t_state` - {parse_err}"))
})?;
let pending_expires_finished: DateTime<Utc> = get(&row, "pending_expires_finished")?;
let last_lock_version_raw: Option<i64> = get(&row, "last_lock_version")?;
let last_lock_version = last_lock_version_raw
.map(Version::try_from)
.transpose()
.map_err(|_| consistency_db_err("version must be non-negative"))?;
let executor_id_raw: Option<String> = get(&row, "executor_id")?;
let executor_id = executor_id_raw
.map(|id| ExecutorId::from_str(&id))
.transpose()
.map_err(DbErrorGeneric::from)?;
let run_id_raw: Option<String> = get(&row, "run_id")?;
let run_id = run_id_raw
.map(|id| RunId::from_str(&id))
.transpose()
.map_err(DbErrorGeneric::from)?;
let join_set_id_raw: Option<String> = get(&row, "join_set_id")?;
let join_set_id = join_set_id_raw
.map(|id| JoinSetId::from_str(&id))
.transpose()
.map_err(DbErrorGeneric::from)?;
let join_set_closing: Option<bool> = get(&row, "join_set_closing")?;
let result_kind: Option<Json<PendingStateFinishedResultKind>> = get(&row, "result_kind")?;
let result_kind = result_kind.map(|it| it.0);
let is_paused: bool = get(&row, "is_paused")?;
let corresponding_version: i64 = get(&row, "corresponding_version")?;
let corresponding_version = Version::new(
VersionType::try_from(corresponding_version)
.map_err(|_| consistency_db_err("version must be non-negative"))?,
);
let dto = CombinedStateDTO {
execution_id: execution_id.clone(),
created_at,
first_scheduled_at,
state,
ffqn,
component_digest,
component_type,
deployment_id,
pending_expires_finished,
last_lock_version,
executor_id,
run_id,
join_set_id,
join_set_closing,
result_kind,
is_paused,
};
CombinedState::new(dto, corresponding_version).map_err(DbErrorRead::from)
}
async fn list_executions(
read_tx: &Transaction<'_>,
filter: ListExecutionsFilter,
pagination: &ExecutionListPagination,
) -> Result<Vec<ExecutionWithState>, DbErrorGeneric> {
struct QueryBuilder {
where_clauses: Vec<String>,
params: Vec<Box<dyn ToSql + Send + Sync>>,
}
impl QueryBuilder {
fn new() -> Self {
Self {
where_clauses: Vec::new(),
params: Vec::new(),
}
}
fn add_param<T>(&mut self, param: T) -> String
where
T: ToSql + Sync + Send + 'static,
{
self.params.push(Box::new(param));
format!("${}", self.params.len())
}
fn add_where(&mut self, clause: String) {
self.where_clauses.push(clause);
}
}
let mut qb = QueryBuilder::new();
let (limit, limit_desc) = match pagination {
ExecutionListPagination::CreatedBy(p) => {
let limit = p.length();
let is_desc = p.is_desc();
if let Some(cursor) = p.cursor() {
let placeholder = qb.add_param(*cursor);
qb.add_where(format!("created_at {} {placeholder}", p.rel()));
}
(limit, is_desc)
}
ExecutionListPagination::ExecutionId(p) => {
let limit = p.length();
let is_desc = p.is_desc();
if let Some(cursor) = p.cursor() {
let placeholder = qb.add_param(cursor.to_string());
qb.add_where(format!("execution_id {} {placeholder}", p.rel()));
}
(limit, is_desc)
}
};
if !filter.show_derived {
qb.add_where("is_top_level = true".to_string());
}
if let Some(function_name_filter) = filter.function_name_filter {
let placeholder = qb.add_param(function_name_filter.like_pattern());
qb.add_where(format!("ffqn LIKE {placeholder}"));
}
let like = |value: String| format!("{value}%");
if filter.hide_finished {
qb.add_where(format!("state != '{STATE_FINISHED}'"));
}
if let Some(prefix) = filter.execution_id_prefix {
let placeholder = qb.add_param(like(prefix));
qb.add_where(format!("execution_id LIKE {placeholder}"));
}
if let Some(component_digest) = filter.component_digest {
let placeholder = qb.add_param(component_digest);
qb.add_where(format!("component_id_input_digest = {placeholder}"));
}
if let Some(deployment_id) = filter.deployment_id {
let placeholder = qb.add_param(deployment_id.to_string());
qb.add_where(format!("deployment_id = {placeholder}"));
}
let where_str = if qb.where_clauses.is_empty() {
String::new()
} else {
format!("WHERE {}", qb.where_clauses.join(" AND "))
};
let order_col = match pagination {
ExecutionListPagination::CreatedBy(_) => "created_at",
ExecutionListPagination::ExecutionId(_) => "execution_id",
};
let (inner_order, outer_order) = if limit_desc {
("DESC", "")
} else {
("", "DESC")
};
let inner_sql = format!(
r"SELECT created_at, first_scheduled_at, component_id_input_digest, component_type, deployment_id,
state, execution_id, ffqn, corresponding_version, pending_expires_finished,
last_lock_version, executor_id, run_id,
join_set_id, join_set_closing,
result_kind, is_paused
FROM t_state {where_str} ORDER BY {order_col} {inner_order} LIMIT {limit}"
);
let sql = if outer_order.is_empty() {
inner_sql
} else {
format!("SELECT * FROM ({inner_sql}) AS sub ORDER BY {order_col} {outer_order}")
};
let params_refs: Vec<&(dyn ToSql + Sync)> = qb
.params
.iter()
.map(|p| p.as_ref() as &(dyn ToSql + Sync))
.collect();
let rows = read_tx.query(&sql, ¶ms_refs).await?;
let mut vec = Vec::with_capacity(rows.len());
for row in rows {
let unpack = || -> Result<ExecutionWithState, DbErrorGeneric> {
let execution_id_str: String = get(&row, "execution_id")?;
let execution_id = ExecutionId::from_str(&execution_id_str)
.map_err(|err| consistency_db_err(err.to_string()))?;
let digest_bytes: Vec<u8> = get(&row, "component_id_input_digest")?;
let digest = Digest::try_from(digest_bytes.as_slice()).map_err(|err| {
consistency_db_err_src("cannot parse `component_id_input_digest`", Arc::from(err))
})?;
let component_digest = ComponentDigest(digest);
let component_type: String = get(&row, "component_type")?;
let component_type = ComponentType::from_str(&component_type).map_err(|err| {
consistency_db_err_src("cannot parse `component_type`", Arc::from(err))
})?;
let deployment_id: String = get(&row, "deployment_id")?;
let deployment_id =
DeploymentId::from_str(&deployment_id).map_err(DbErrorGeneric::from)?;
let created_at: DateTime<Utc> = get(&row, "created_at")?;
let first_scheduled_at: DateTime<Utc> = get(&row, "first_scheduled_at")?;
let result_kind: Option<Json<PendingStateFinishedResultKind>> =
get(&row, "result_kind")?;
let result_kind = result_kind.map(|it| it.0);
let is_paused: bool = get(&row, "is_paused")?;
let corresponding_version: i64 = get(&row, "corresponding_version")?;
let corresponding_version = Version::try_from(corresponding_version)
.map_err(|_| consistency_db_err("version must be non-negative"))?;
let executor_id_str: Option<String> = get(&row, "executor_id")?;
let executor_id = executor_id_str
.map(|id| ExecutorId::from_str(&id))
.transpose()?;
let last_lock_version_raw: Option<i64> = get(&row, "last_lock_version")?;
let last_lock_version = last_lock_version_raw
.map(Version::try_from)
.transpose()
.map_err(|_| consistency_db_err("version must be non-negative"))?;
let run_id_str: Option<String> = get(&row, "run_id")?;
let run_id = run_id_str.map(|id| RunId::from_str(&id)).transpose()?;
let join_set_id_str: Option<String> = get(&row, "join_set_id")?;
let join_set_id = join_set_id_str
.map(|id| JoinSetId::from_str(&id))
.transpose()?;
let ffqn: String = get(&row, "ffqn")?;
let ffqn = FunctionFqn::from_str(&ffqn).map_err(|parse_err| {
error!("Error parsing ffqn - {parse_err:?}");
consistency_db_err("invalid ffqn value in `t_state`")
})?;
let combined_state_dto = CombinedStateDTO {
execution_id,
created_at,
first_scheduled_at,
component_digest,
component_type,
deployment_id,
state: get(&row, "state")?,
ffqn,
pending_expires_finished: get(&row, "pending_expires_finished")?,
executor_id,
last_lock_version,
run_id,
join_set_id,
join_set_closing: get(&row, "join_set_closing")?,
result_kind,
is_paused,
};
let combined_state = CombinedState::new(combined_state_dto, corresponding_version)?;
Ok(combined_state.execution_with_state)
};
match unpack() {
Ok(execution) => vec.push(execution),
Err(err) => {
warn!("Skipping corrupted row in t_state: {err:?}");
}
}
}
Ok(vec)
}
async fn list_responses(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
pagination: Option<Pagination<u32>>,
) -> Result<Vec<ResponseWithCursor>, DbErrorRead> {
let mut params: Vec<Box<dyn ToSql + Send + Sync>> = Vec::new();
let mut add_param = |p: Box<dyn ToSql + Send + Sync>| {
params.push(p);
format!("${}", params.len())
};
let p_execution_id = add_param(Box::new(execution_id.to_string()));
let mut sql = format!(
"SELECT \
r.id, r.created_at, r.join_set_id, r.delay_id, r.delay_success, r.child_execution_id, r.finished_version, l.json_value \
FROM t_join_set_response r LEFT OUTER JOIN t_execution_log l ON r.child_execution_id = l.execution_id \
WHERE \
r.execution_id = {p_execution_id} \
AND ( r.finished_version = l.version OR r.child_execution_id IS NULL )"
);
let limit = match &pagination {
Some(p @ (Pagination::NewerThan { cursor, .. } | Pagination::OlderThan { cursor, .. })) => {
let p_cursor = add_param(Box::new(i64::from(*cursor)));
write!(sql, " AND r.id {} {}", p.rel(), p_cursor).unwrap();
Some(p.length())
}
None => None,
};
sql.push_str(" ORDER BY r.id");
let is_desc = pagination.as_ref().is_some_and(Pagination::is_desc);
if is_desc {
sql.push_str(" DESC");
}
if let Some(limit) = limit {
let p_limit = add_param(Box::new(i64::from(limit)));
write!(sql, " LIMIT {p_limit}").unwrap();
}
if is_desc {
sql = format!("SELECT * FROM ({sql}) AS sub ORDER BY id ASC");
}
let params_refs: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(|p| p.as_ref() as &(dyn ToSql + Sync))
.collect();
let rows = tx
.query(&sql, ¶ms_refs)
.await
.map_err(DbErrorRead::from)?;
let mut results = Vec::with_capacity(rows.len());
for row in rows {
results.push(parse_response_with_cursor(&row)?);
}
Ok(results)
}
async fn list_logs_tx(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
show_derived: bool,
filter: &LogFilter,
pagination: &Pagination<DateTime<Utc>>,
) -> Result<ListLogsResponse, DbErrorRead> {
let mut param_index = 1;
let exec_id_str = execution_id.to_string();
let exec_id_filter = if show_derived {
format!("execution_id LIKE ${param_index} || '%'")
} else {
format!("execution_id = ${param_index}")
};
let mut query = format!(
"SELECT id, run_id, created_at, level, message, stream_type, payload, execution_id
FROM t_log
WHERE {exec_id_filter}"
);
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![&exec_id_str];
param_index += 1;
let level_filter = if filter.should_show_logs() {
let levels_str = if !filter.levels().is_empty() {
filter
.levels()
.iter()
.map(|lvl| (*lvl as u8).to_string())
.collect::<Vec<_>>()
.join(",")
} else {
LogLevel::iter()
.map(|lvl| (lvl as u8).to_string())
.collect::<Vec<_>>()
.join(",")
};
Some(format!(" level IN ({levels_str})"))
} else {
None
};
let stream_filter = if filter.should_show_streams() {
let streams_str = if !filter.stream_types().is_empty() {
filter
.stream_types()
.iter()
.map(|st| (*st as u8).to_string())
.collect::<Vec<_>>()
.join(",")
} else {
LogStreamType::iter()
.map(|st| (st as u8).to_string())
.collect::<Vec<_>>()
.join(",")
};
Some(format!(" stream_type IN ({streams_str})"))
} else {
None
};
match (level_filter, stream_filter) {
(Some(level_filter), Some(stream_filter)) => {
write!(&mut query, " AND ({level_filter} OR {stream_filter})")
.expect("writing to string");
}
(Some(level_filter), None) => {
write!(&mut query, " AND {level_filter}").expect("writing to string");
}
(None, Some(stream_filter)) => {
write!(&mut query, " AND {stream_filter}").expect("writing to string");
}
(None, None) => unreachable!("guarded by constructor"),
}
write!(
&mut query,
" AND created_at {} ${param_index}",
pagination.rel()
)
.expect("writing to string");
let cursor_val = pagination.cursor();
params.push(cursor_val);
param_index += 1;
let dir = if pagination.is_desc() { "DESC" } else { "ASC" };
write!(
&mut query,
" ORDER BY created_at {dir}, id {dir} LIMIT ${param_index}",
)
.expect("writing to string");
let length_val: i64 = i64::from(pagination.length());
params.push(&length_val);
let rows = tx.query(&query, ¶ms[..]).await?;
let mut items = Vec::with_capacity(rows.len());
for row in rows {
let created_at: chrono::DateTime<chrono::Utc> = get(&row, "created_at")?;
let run_id: String = get(&row, "run_id")?;
let run_id = RunId::from_str(&run_id).map_err(|parse_err| {
consistency_db_err_src(
format!("cannot convert RunId {run_id}"),
Arc::from(parse_err),
)
})?;
let execution_id_str: String = get(&row, "execution_id")?;
let execution_id = ExecutionId::from_str(&execution_id_str).map_err(|parse_err| {
consistency_db_err_src(
format!("cannot convert ExecutionId {execution_id_str}"),
Arc::from(parse_err),
)
})?;
let level: Option<i32> = get(&row, "level")?;
let message: Option<String> = get(&row, "message")?;
let stream_type: Option<i32> = get(&row, "stream_type")?;
let payload: Option<Vec<u8>> = get(&row, "payload")?;
let log_entry = match (level, message, stream_type, payload) {
(Some(lvl), Some(msg), None, None) => {
let map_err =
|err| consistency_db_err_src(format!("cannot convert {lvl} to LogLevel"), err);
LogEntry::Log {
created_at,
level: u8::try_from(lvl)
.map(|lvl| LogLevel::try_from(lvl).map_err(|err| map_err(Arc::from(err))))
.map_err(|err| map_err(Arc::from(err)))??,
message: msg,
}
}
(None, None, Some(stype), Some(pl)) => {
let map_err = |err| {
consistency_db_err_src(format!("cannot convert {stype} to LogStreamType"), err)
};
LogEntry::Stream {
created_at,
stream_type: u8::try_from(stype)
.map(|stype| {
LogStreamType::try_from(stype).map_err(|err| map_err(Arc::from(err)))
})
.map_err(|err| map_err(Arc::from(err)))??,
payload: pl,
}
}
_ => {
return Err(consistency_db_err("invalid t_log row".to_string()).into());
}
};
items.push(LogEntryRow {
cursor: created_at,
run_id,
log_entry,
execution_id,
});
}
Ok(ListLogsResponse {
next_page: items
.last()
.map(|item| Pagination::NewerThan {
length: pagination.length(),
cursor: item.cursor,
including_cursor: false,
})
.unwrap_or(if pagination.is_asc() {
*pagination } else {
Pagination::NewerThan {
length: pagination.length(),
cursor: DateTime::<Utc>::UNIX_EPOCH,
including_cursor: false,
}
}),
prev_page: match items.first() {
Some(item) => Some(Pagination::OlderThan {
length: pagination.length(),
cursor: item.cursor,
including_cursor: false,
}),
None if pagination.is_asc() && pagination.cursor() > &DateTime::<Utc>::UNIX_EPOCH => {
Some(pagination.invert())
}
None => None,
},
items,
})
}
async fn list_deployment_states(
tx: &Transaction<'_>,
current_time: DateTime<Utc>,
pagination: Pagination<Option<DeploymentId>>,
include_config_json: bool,
) -> Result<Vec<DeploymentState>, DbErrorRead> {
let mut params: Vec<Box<dyn ToSql + Send + Sync>> = Vec::new();
let mut add_param = |p: Box<dyn ToSql + Send + Sync>| {
params.push(p);
format!("${}", params.len())
};
let p_now = add_param(Box::new(current_time));
let config_json_col = if include_config_json {
"d.config_json"
} else {
"NULL::TEXT AS config_json"
};
let mut sql = format!(
"
SELECT
d.deployment_id,
COUNT(*) FILTER (WHERE s.state = '{STATE_LOCKED}') AS locked,
COUNT(*) FILTER (
WHERE s.state = '{STATE_PENDING_AT}'
AND s.pending_expires_finished <= {p_now}
) AS pending,
COUNT(*) FILTER (
WHERE s.state = '{STATE_PENDING_AT}'
AND s.pending_expires_finished > {p_now}
) AS scheduled,
COUNT(*) FILTER (WHERE s.state = '{STATE_BLOCKED_BY_JOIN_SET}') AS blocked,
COUNT(*) FILTER (WHERE s.state = '{STATE_FINISHED}') AS finished,
{config_json_col},
d.created_at,
d.last_active_at,
d.status
FROM t_deployment d
LEFT JOIN t_state s ON s.deployment_id = d.deployment_id"
);
if let Some(cursor) = pagination.cursor() {
let p_cursor = add_param(Box::new(cursor.to_string()));
write!(
sql,
" WHERE d.deployment_id {rel} {p_cursor}",
rel = pagination.rel()
)
.expect("writing to string");
}
let (inner_order, outer_order) = if pagination.is_desc() {
("DESC", "")
} else {
("ASC", "DESC")
};
write!(
sql,
" GROUP BY d.deployment_id, d.config_json, d.created_at, d.last_active_at, d.status ORDER BY d.deployment_id {inner_order} LIMIT {}",
pagination.length()
)
.expect("writing to string");
let final_sql = if outer_order.is_empty() {
sql
} else {
format!("SELECT * FROM ({sql}) AS sub ORDER BY deployment_id {outer_order}")
};
let params_refs: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(|p| p.as_ref() as &(dyn ToSql + Sync))
.collect();
let rows = tx
.query(&final_sql, ¶ms_refs)
.await
.map_err(DbErrorRead::from)?;
let mut result = Vec::with_capacity(rows.len());
for row in rows {
let deployment_id: String = get(&row, "deployment_id")?;
let status_str: String = get::<String, _>(&row, "status")?;
let status = status_str
.parse::<DeploymentStatus>()
.map_err(|e| consistency_db_err(format!("unknown deployment status: {e}")))?;
result.push(DeploymentState {
deployment_id: DeploymentId::from_str(&deployment_id).map_err(DbErrorGeneric::from)?,
locked: u32::try_from(get::<i64, _>(&row, "locked")?).expect("count is never negative"),
pending: u32::try_from(get::<i64, _>(&row, "pending")?)
.expect("count is never negative"),
scheduled: u32::try_from(get::<i64, _>(&row, "scheduled")?)
.expect("count is never negative"),
blocked: u32::try_from(get::<i64, _>(&row, "blocked")?)
.expect("count is never negative"),
finished: u32::try_from(get::<i64, _>(&row, "finished")?)
.expect("count is never negative"),
config_json: get::<Option<String>, _>(&row, "config_json")?,
created_at: get::<DateTime<Utc>, _>(&row, "created_at")?,
last_active_at: get::<Option<DateTime<Utc>>, _>(&row, "last_active_at")?,
status,
});
}
Ok(result)
}
fn parse_response_with_cursor(
row: &tokio_postgres::Row,
) -> Result<ResponseWithCursor, DbErrorRead> {
let id = u32::try_from(get::<i64, _>(row, "id")?)
.map_err(|_| consistency_db_err("id must not be negative"))?;
let created_at: DateTime<Utc> = get(row, "created_at")?;
let join_set_id_str: String = get(row, "join_set_id")?;
let join_set_id = JoinSetId::from_str(&join_set_id_str).map_err(DbErrorGeneric::from)?;
let delay_id: Option<String> = get(row, "delay_id")?;
let delay_id = delay_id
.map(|id| DelayId::from_str(&id))
.transpose()
.map_err(DbErrorGeneric::from)?;
let delay_success: Option<bool> = get(row, "delay_success")?;
let child_execution_id: Option<String> = get(row, "child_execution_id")?;
let child_execution_id = child_execution_id
.map(|id| ExecutionIdDerived::from_str(&id))
.transpose()
.map_err(DbErrorGeneric::from)?;
let finished_version = get::<Option<i64>, _>(row, "finished_version")?
.map(Version::try_from)
.transpose()
.map_err(|_| consistency_db_err("version must be non-negative"))?;
let json_value: Option<Json<ExecutionRequest>> = get(row, "json_value")?;
let json_value = json_value.map(|it| it.0);
let event = match (
delay_id,
delay_success,
child_execution_id,
finished_version,
json_value,
) {
(Some(delay_id), Some(delay_success), None, None, None) => JoinSetResponse::DelayFinished {
delay_id,
result: delay_success.then_some(()).ok_or(()),
},
(None, None, Some(child_execution_id), Some(finished_version), Some(json_val)) => {
if let ExecutionRequest::Finished { retval: result, .. } = json_val {
JoinSetResponse::ChildExecutionFinished {
child_execution_id,
finished_version,
result,
}
} else {
error!("Joined log entry must be 'Finished'");
return Err(consistency_db_err("joined log entry must be 'Finished'").into());
}
}
(delay, delay_success, child, finished, result) => {
error!(
"Invalid row in t_join_set_response {id} - {delay:?} {delay_success:?} {child:?} {finished:?} {result:?}",
);
return Err(consistency_db_err("invalid row in t_join_set_response").into());
}
};
Ok(ResponseWithCursor {
cursor: ResponseCursor(id),
event: JoinSetResponseEventOuter {
event: JoinSetResponseEvent { join_set_id, event },
created_at,
},
})
}
#[instrument(level = Level::TRACE, skip_all, fields(%execution_id, %run_id, %executor_id))]
#[expect(clippy::too_many_arguments)]
async fn lock_single_execution(
tx: &Transaction<'_>,
created_at: DateTime<Utc>,
component_id: &ComponentId,
update_component_digest: bool, deployment_id: DeploymentId,
execution_id: &ExecutionId,
run_id: RunId,
appending_version: &Version,
executor_id: ExecutorId,
lock_expires_at: DateTime<Utc>,
retry_config: ComponentRetryConfig,
) -> Result<LockedExecution, DbErrorWrite> {
trace!("lock_single_execution");
let combined_state = get_combined_state(tx, execution_id).await?;
let context_component_digest = if update_component_digest {
component_id.component_digest.clone()
} else {
combined_state.execution_with_state.component_digest.clone()
};
combined_state
.execution_with_state
.pending_state
.can_append_lock(created_at, executor_id, run_id, lock_expires_at)?;
let expected_version = combined_state.get_next_version_assert_not_finished();
check_expected_next_and_appending_version(&expected_version, appending_version)?;
let locked_event = Locked {
component_id: component_id.clone(),
deployment_id,
executor_id,
lock_expires_at,
run_id,
retry_config,
};
let event = ExecutionRequest::Locked(locked_event.clone());
let event = Json(event);
tx.execute(
"INSERT INTO t_execution_log \
(execution_id, created_at, json_value, version, variant) \
VALUES ($1, $2, $3, $4, $5)",
&[
&execution_id.to_string(),
&created_at,
&event,
&i64::from(appending_version.0),
&event.0.variant(),
],
)
.await
.map_err(|err| {
DbErrorWrite::NonRetriable(DbErrorWriteNonRetriable::IllegalState {
reason: "cannot lock".into(),
context: SpanTrace::capture(),
source: Some(Arc::new(err)),
loc: Location::caller(),
})
})?;
let responses = list_responses(tx, execution_id, None).await?;
trace!("Responses: {responses:?}");
let intermittent_event_count = update_state_locked_get_intermittent_event_count(
tx,
execution_id,
deployment_id,
update_component_digest.then_some(&component_id.component_digest),
executor_id,
run_id,
lock_expires_at,
appending_version,
retry_config,
)
.await?;
let rows = tx
.query(
"SELECT json_value, version FROM t_execution_log WHERE \
execution_id = $1 AND (variant = $2 OR variant = $3) \
ORDER BY version",
&[
&execution_id.to_string(),
&DUMMY_CREATED.variant(),
&DUMMY_HISTORY_EVENT.variant(),
],
)
.await
.map_err(DbErrorGeneric::from)?;
let mut events: VecDeque<ExecutionEvent> = VecDeque::new();
for row in rows {
let event: Json<ExecutionRequest> = get(&row, "json_value")?;
let event = event.0;
let version: i64 = get(&row, "version")?;
let version = Version::try_from(version)
.map_err(|_| consistency_db_err("version must be non-negative"))?;
events.push_back(ExecutionEvent {
created_at: DateTime::from_timestamp_nanos(0), event,
backtrace_id: None,
version,
});
}
let Some(ExecutionRequest::Created {
ffqn,
params,
parent,
metadata,
..
}) = events.pop_front().map(|outer| outer.event)
else {
error!("Execution log must contain at least `Created` event");
return Err(consistency_db_err("execution log must contain `Created` event").into());
};
let mut event_history = Vec::new();
for ExecutionEvent { event, version, .. } in events {
if let ExecutionRequest::HistoryEvent { event } = event {
event_history.push((event, version));
} else {
error!("Rows can only contain `Created` and `HistoryEvent` event kinds");
return Err(consistency_db_err(
"rows can only contain `Created` and `HistoryEvent` event kinds",
)
.into());
}
}
Ok(LockedExecution {
execution_id: execution_id.clone(),
metadata,
component_digest: context_component_digest,
next_version: appending_version.increment(),
ffqn,
params,
event_history,
responses,
parent,
intermittent_event_count,
locked_event,
})
}
/// Counts the `t_execution_log` rows of `execution_id` that belong to `join_set_id` and whose
/// `history_event_type` equals `HISTORY_EVENT_TYPE_JOIN_NEXT`.
///
/// # Panics
/// Panics if the `COUNT(*)` value cannot be converted to `u32`.
async fn count_join_next(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    join_set_id: &JoinSetId,
) -> Result<u32, DbErrorRead> {
    let execution_id_str = execution_id.to_string();
    let join_set_id_str = join_set_id.to_string();
    let count_row = tx
        .query_one(
            "SELECT COUNT(*) as count FROM t_execution_log WHERE execution_id = $1 AND join_set_id = $2 \
            AND history_event_type = $3",
            &[
                &execution_id_str,
                &join_set_id_str,
                &HISTORY_EVENT_TYPE_JOIN_NEXT,
            ],
        )
        .await
        .map_err(DbErrorRead::from)?;
    let raw_count: i64 = get(&count_row, "count")?;
    Ok(u32::try_from(raw_count).expect("COUNT cannot be negative"))
}
async fn nth_response(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
join_set_id: &JoinSetId,
skip_rows: u32,
) -> Result<Option<ResponseWithCursor>, DbErrorRead> {
let row = tx
.query_opt(
"SELECT r.id, r.created_at, r.join_set_id, \
r.delay_id, r.delay_success, \
r.child_execution_id, r.finished_version, l.json_value \
FROM t_join_set_response r LEFT OUTER JOIN t_execution_log l ON r.child_execution_id = l.execution_id \
WHERE \
r.execution_id = $1 AND r.join_set_id = $2 AND \
( \
r.finished_version = l.version \
OR \
r.child_execution_id IS NULL \
) \
ORDER BY id \
LIMIT 1 OFFSET $3",
&[
&execution_id.to_string(),
&join_set_id.to_string(),
&i64::from(skip_rows),
]
)
.await
.map_err(DbErrorRead::from)?;
match row {
Some(r) => Ok(Some(parse_response_with_cursor(&r)?)),
None => Ok(None),
}
}
#[instrument(level = Level::TRACE, skip_all, fields(%execution_id))]
async fn append(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
req: AppendRequest,
appending_version: Version,
) -> Result<(AppendResponse, AppendNotifier), DbErrorWrite> {
if matches!(req.event, ExecutionRequest::Created { .. }) {
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::ValidationFailed(
"cannot append `Created` event - use `create` instead".into(),
),
));
}
if let AppendRequest {
event:
ExecutionRequest::Locked(Locked {
component_id,
deployment_id,
executor_id,
run_id,
lock_expires_at,
retry_config,
}),
created_at,
} = req
{
return lock_single_execution(
tx,
created_at,
&component_id,
true, deployment_id,
execution_id,
run_id,
&appending_version,
executor_id,
lock_expires_at,
retry_config,
)
.await
.map(|locked_execution| (locked_execution.next_version, AppendNotifier::default()));
}
let combined_state = get_combined_state(tx, execution_id).await?;
if combined_state
.execution_with_state
.pending_state
.is_finished()
{
debug!("Execution is already finished");
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::AlreadyFinished,
));
}
check_expected_next_and_appending_version(
&combined_state.get_next_version_assert_not_finished(),
&appending_version,
)?;
let event = Json(req.event);
tx.execute(
"INSERT INTO t_execution_log (execution_id, created_at, json_value, version, variant, join_set_id) \
VALUES ($1, $2, $3, $4, $5, $6)",
&[
&execution_id.to_string(),
&req.created_at,
&event,
&i64::from(appending_version.0),
&event.0.variant(),
&event.0.join_set_id().map(std::string::ToString::to_string),
],
)
.await?;
match &event.0 {
ExecutionRequest::Created { .. } => {
unreachable!("handled in the caller")
}
ExecutionRequest::Locked { .. } => {
unreachable!("handled above")
}
ExecutionRequest::TemporarilyFailed {
backoff_expires_at, ..
}
| ExecutionRequest::TemporarilyTimedOut {
backoff_expires_at, ..
} => {
let (next_version, notifier) = update_state_pending_after_event_appended(
tx,
execution_id,
&appending_version,
PendingAfterEventUpdate {
scheduled_at: *backoff_expires_at,
intermittent_failure: true,
component_input_digest: combined_state.execution_with_state.component_digest,
},
)
.await?;
return Ok((next_version, notifier));
}
ExecutionRequest::Unlocked(unlocked) => {
match &combined_state.execution_with_state.pending_state {
PendingState::PendingAt(_) => {
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::UnlockedCannotBeAppended("pending"),
));
}
PendingState::Locked(_) => {
let (next_version, notifier) = update_state_unlocked_from_locked(
tx,
execution_id,
combined_state.execution_with_state.component_digest,
unlocked.backoff_expires_at,
&appending_version,
)
.await?;
return Ok((next_version, notifier));
}
PendingState::BlockedByJoinSet(_) => {
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::UnlockedCannotBeAppended("blocked"),
));
}
PendingState::Paused(_) => {
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::UnlockedCannotBeAppended("paused"),
));
}
PendingState::Finished(_) => {
unreachable!("handled above");
}
}
}
ExecutionRequest::ComponentUpgradeFinished {
component_digest,
deployment_id,
outcome,
} => match outcome {
ComponentUpgradeOutcome::Success { .. } => {
let next_version = update_state_component_upgrade_finished_success(
tx,
execution_id,
component_digest,
*deployment_id,
&appending_version,
)
.await?;
return Ok((next_version, AppendNotifier::default()));
}
ComponentUpgradeOutcome::Failed { .. } => {
let next_version = update_state_component_upgrade_finished_failed(
tx,
execution_id,
component_digest,
&appending_version,
)
.await?;
return Ok((next_version, AppendNotifier::default()));
}
},
ExecutionRequest::Paused => {
match &combined_state.execution_with_state.pending_state {
PendingState::Finished { .. } => {
unreachable!("handled above");
}
PendingState::Locked(..) => {
return Err(DbErrorWriteNonRetriable::IllegalState {
reason:
"cannot append Paused event when execution is locked; use pause_execution"
.into(),
context: SpanTrace::capture(),
source: None,
loc: Location::caller(),
}
.into());
}
PendingState::Paused(..) => {
return Err(DbErrorWriteNonRetriable::IllegalState {
reason: "cannot pause, execution is already paused".into(),
context: SpanTrace::capture(),
source: None,
loc: Location::caller(),
}
.into());
}
_ => {}
}
let next_version =
update_state_paused(tx, execution_id, &appending_version, true).await?;
return Ok((next_version, AppendNotifier::default()));
}
ExecutionRequest::Unpaused => {
if !combined_state
.execution_with_state
.pending_state
.is_paused()
{
return Err(DbErrorWriteNonRetriable::IllegalState {
reason: "cannot unpause, execution is not paused".into(),
context: SpanTrace::capture(),
source: None,
loc: Location::caller(),
}
.into());
}
let next_version =
update_state_paused(tx, execution_id, &appending_version, false).await?;
return Ok((next_version, AppendNotifier::default()));
}
ExecutionRequest::Finished { retval, .. } => {
update_state_finished(
tx,
execution_id,
&appending_version,
req.created_at,
PendingStateFinishedResultKind::from(retval),
)
.await?;
return Ok((
appending_version,
AppendNotifier {
pending_at: None,
execution_finished: Some(NotifierExecutionFinished {
execution_id: execution_id.clone(),
retval: retval.clone(),
}),
response: None,
},
));
}
ExecutionRequest::HistoryEvent {
event:
HistoryEvent::JoinSetCreate { .. }
| HistoryEvent::JoinSetRequest {
request: JoinSetRequest::ChildExecutionRequest { .. },
..
}
| HistoryEvent::Persist { .. }
| HistoryEvent::Schedule { .. }
| HistoryEvent::Stub { .. }
| HistoryEvent::JoinNextTooMany { .. }
| HistoryEvent::JoinNextTry { .. },
} => {
return Ok((
bump_state_next_version(tx, execution_id, &appending_version, None).await?,
AppendNotifier::default(),
));
}
ExecutionRequest::HistoryEvent {
event:
HistoryEvent::JoinSetRequest {
join_set_id,
request:
JoinSetRequest::DelayRequest {
delay_id,
expires_at,
paused,
..
},
},
} => {
return Ok((
bump_state_next_version(
tx,
execution_id,
&appending_version,
Some(DelayReq {
join_set_id: join_set_id.clone(),
delay_id: delay_id.clone(),
expires_at: *expires_at,
paused: *paused,
}),
)
.await?,
AppendNotifier::default(),
));
}
ExecutionRequest::HistoryEvent {
event:
HistoryEvent::JoinNext {
join_set_id,
run_expires_at,
closing,
requested_ffqn: _,
},
} => {
let join_next_count = count_join_next(tx, execution_id, join_set_id).await?;
let nth_response =
nth_response(tx, execution_id, join_set_id, join_next_count - 1).await?;
trace!("join_next_count: {join_next_count}, nth_response: {nth_response:?}");
assert!(join_next_count > 0);
if let Some(ResponseWithCursor {
event:
JoinSetResponseEventOuter {
created_at: nth_created_at,
..
},
cursor: _,
}) = nth_response
{
let scheduled_at = std::cmp::max(*run_expires_at, nth_created_at);
let (next_version, notifier) = update_state_pending_after_event_appended(
tx,
execution_id,
&appending_version,
PendingAfterEventUpdate {
scheduled_at,
intermittent_failure: false,
component_input_digest: combined_state
.execution_with_state
.component_digest,
},
)
.await?;
return Ok((next_version, notifier));
}
return Ok((
update_state_blocked(
tx,
execution_id,
&appending_version,
join_set_id,
*run_expires_at,
*closing,
)
.await?,
AppendNotifier::default(),
));
}
}
}
/// Records a join set response for `execution_id` in `t_join_set_response` and updates state.
///
/// * If the execution is blocked on this very join set, it becomes pending at the later of its
///   lock expiry and the response's `created_at`.
/// * A `DelayFinished` response also deletes the matching row from `t_delay`.
///
/// The returned notifier always carries the stored response, with its new cursor, in `response`.
///
/// # Errors
/// Returns [`DbErrorWrite`] on database errors or a negative generated response id.
async fn append_response(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    event: JoinSetResponseEventOuter,
) -> Result<AppendNotifier, DbErrorWrite> {
    let join_set_id = &event.event.join_set_id;
    // Delay responses fill the delay columns, child responses the child columns.
    let (delay_id, delay_success, child_execution_id, finished_version) = match &event.event.event
    {
        JoinSetResponse::DelayFinished { delay_id, result } => {
            (Some(delay_id.to_string()), Some(result.is_ok()), None, None)
        }
        JoinSetResponse::ChildExecutionFinished {
            child_execution_id,
            finished_version,
            result: _,
        } => (
            None,
            None,
            Some(child_execution_id.to_string()),
            Some(i64::from(finished_version.0)),
        ),
    };
    let inserted = tx
        .query_one(
            "INSERT INTO t_join_set_response (execution_id, created_at, join_set_id, delay_id, delay_success, child_execution_id, finished_version) \
            VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id",
            &[
                &execution_id.to_string(),
                &event.created_at,
                &join_set_id.to_string(),
                &delay_id,
                &delay_success,
                &child_execution_id,
                &finished_version,
            ],
        )
        .await?;
    let raw_id = get::<i64, _>(&inserted, 0)?;
    let cursor = ResponseCursor(
        u32::try_from(raw_id)
            .map_err(|_| consistency_db_err("t_join_set_response.id must not be negative"))?,
    );

    let combined_state = get_combined_state(tx, execution_id).await?;
    debug!("previous_pending_state: {combined_state:?}");
    let mut notifier =
        match PendingStateMergedPause::from(combined_state.execution_with_state.pending_state) {
            PendingStateMergedPause::BlockedByJoinSet {
                state:
                    PendingStateBlockedByJoinSet {
                        join_set_id: blocking_join_set_id,
                        lock_expires_at,
                        closing: _,
                    },
                paused: _,
            } if *join_set_id == blocking_join_set_id => {
                let scheduled_at = std::cmp::max(lock_expires_at, event.created_at);
                update_state_pending_after_response_appended(
                    tx,
                    execution_id,
                    scheduled_at,
                    combined_state.execution_with_state.component_digest,
                )
                .await?
            }
            _ => AppendNotifier::default(),
        };

    if let JoinSetResponse::DelayFinished {
        delay_id,
        result: _,
    } = &event.event.event
    {
        debug!(%join_set_id, %delay_id, "Deleting from `t_delay`");
        tx.execute(
            "DELETE FROM t_delay WHERE execution_id = $1 AND join_set_id = $2 AND delay_id = $3",
            &[
                &execution_id.to_string(),
                &join_set_id.to_string(),
                &delay_id.to_string(),
            ],
        )
        .await?;
    }
    notifier.response = Some((execution_id.clone(), ResponseWithCursor { cursor, event }));
    Ok(notifier)
}
/// Stores a WASM backtrace and links it to a version range of an execution.
///
/// The backtrace body is deduplicated by hash in `t_wasm_backtrace`: an existing hash is left
/// untouched. A `t_execution_backtrace` row then maps
/// `[version_min_including, version_max_excluding)` of the execution to that hash.
///
/// # Errors
/// Returns [`DbErrorWrite`] if either insert fails.
async fn append_backtrace(
    tx: &Transaction<'_>,
    backtrace_info: &BacktraceInfo,
) -> Result<(), DbErrorWrite> {
    let hash = backtrace_info.wasm_backtrace.hash();
    let hash_bytes = hash.as_slice();
    let backtrace_json = Json(&backtrace_info.wasm_backtrace);
    tx.execute(
        "INSERT INTO t_wasm_backtrace (backtrace_hash, wasm_backtrace) \
        VALUES ($1, $2) \
        ON CONFLICT (backtrace_hash) DO NOTHING",
        &[&hash_bytes, &backtrace_json],
    )
    .await?;

    let execution_id = backtrace_info.execution_id.to_string();
    let component_json = Json(&backtrace_info.component_id);
    let version_from = i64::from(backtrace_info.version_min_including.0);
    let version_until = i64::from(backtrace_info.version_max_excluding.0);
    tx.execute(
        "INSERT INTO t_execution_backtrace \
        (execution_id, component_id, version_min_including, version_max_excluding, backtrace_hash) \
        VALUES ($1, $2, $3, $4, $5)",
        &[
            &execution_id,
            &component_json,
            &version_from,
            &version_until,
            &hash_bytes,
        ],
    )
    .await?;
    Ok(())
}
/// Inserts one log entry into `t_log`.
///
/// A `LogEntry::Log` fills `level` and `message`; a `LogEntry::Stream` fills `stream_type` and
/// `payload`. The unused pair is stored as `NULL`. Enum values are stored as their integer
/// discriminants.
///
/// # Errors
/// Returns [`DbErrorWrite`] if the insert fails.
async fn append_log(tx: &Transaction<'_>, row: &LogInfoAppendRow) -> Result<(), DbErrorWrite> {
    let (created_at, level, message, stream_type, payload) = match &row.log_entry {
        LogEntry::Log {
            created_at,
            level,
            message,
        } => (
            created_at,
            Some(*level as i32),
            Some(message.as_str()),
            None::<i32>,
            None::<&[u8]>,
        ),
        LogEntry::Stream {
            created_at,
            payload,
            stream_type,
        } => (
            created_at,
            None::<i32>,
            None::<&str>,
            Some(*stream_type as i32),
            Some(payload.as_slice()),
        ),
    };
    let execution_id = row.execution_id.to_string();
    let run_id = row.run_id.to_string();
    tx.execute(
        "INSERT INTO t_log (
execution_id,
run_id,
created_at,
level,
message,
stream_type,
payload
) VALUES ($1, $2, $3, $4, $5, $6, $7)",
        &[
            &execution_id,
            &run_id,
            &created_at,
            &level,
            &message,
            &stream_type,
            &payload,
        ],
    )
    .await?;
    Ok(())
}
async fn get_execution_log(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
) -> Result<concepts::storage::ExecutionLog, DbErrorRead> {
let rows = tx
.query(
"SELECT created_at, json_value, version FROM t_execution_log WHERE \
execution_id = $1 ORDER BY version",
&[&execution_id.to_string()],
)
.await
.map_err(DbErrorRead::from)?;
if rows.is_empty() {
return Err(DbErrorRead::NotFound);
}
let mut events = Vec::with_capacity(rows.len());
for row in rows {
let created_at: DateTime<Utc> = get(&row, "created_at")?;
let event: Json<ExecutionRequest> = get(&row, "json_value")?;
let event = event.0;
let version: i64 = get(&row, "version")?;
let version = Version::try_from(version)
.map_err(|_| consistency_db_err("version must be non-negative"))?;
events.push(ExecutionEvent {
created_at,
event,
backtrace_id: None,
version,
});
}
let combined_state = get_combined_state(tx, execution_id).await?;
let responses = list_responses(tx, execution_id, None).await?;
Ok(concepts::storage::ExecutionLog {
execution_id: execution_id.clone(),
events,
responses,
next_version: combined_state.get_next_version_or_finished(),
pending_state: combined_state.execution_with_state.pending_state,
component_digest: combined_state.execution_with_state.component_digest,
component_type: combined_state.execution_with_state.component_type,
deployment_id: combined_state.execution_with_state.deployment_id,
})
}
async fn get_max_version(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
) -> Result<Version, DbErrorRead> {
let row = tx
.query_one(
"SELECT MAX(version) as version FROM t_execution_log WHERE execution_id = $1",
&[&execution_id.to_string()],
)
.await?;
let max_version: i64 = get(&row, "version")?;
let max_version = Version::try_from(max_version)
.map_err(|_| consistency_db_err("version must be non-negative"))?;
Ok(max_version)
}
/// Returns the cursor of the newest `t_join_set_response` row of `execution_id`,
/// i.e. its largest `id`, or `ResponseCursor(0)` when the execution has no responses.
///
/// # Errors
/// A consistency error when the stored id does not fit into `u32`; any query error.
async fn get_max_response_cursor(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
) -> Result<ResponseCursor, DbErrorRead> {
    let row = tx
        .query_one(
            "SELECT MAX(id) as id FROM t_join_set_response WHERE execution_id = $1",
            &[&execution_id.to_string()],
        )
        .await?;
    // MAX over zero rows is NULL, which means "no responses yet".
    let max_id: Option<i64> = get(&row, "id")?;
    let cursor_value = u32::try_from(max_id.unwrap_or(0))
        .map_err(|_| consistency_db_err("id must not be negative"))?;
    Ok(ResponseCursor(cursor_value))
}
/// Lists one page of execution-log events of `execution_id`, always returned in
/// ascending `version` order.
///
/// - `Pagination::NewerThan` selects versions greater than the cursor, or equal to it
///   when `including_cursor` is set, lowest first, at most `length` of them.
/// - `Pagination::OlderThan` selects the `length` versions closest below the cursor
///   (or at it, with `including_cursor`) and re-sorts them ascending.
///
/// With `include_backtrace_id`, each event's `backtrace_id` is the
/// `version_min_including` of the `t_execution_backtrace` range that contains its
/// version, if any. Otherwise `backtrace_id` is `None`.
///
/// # Errors
/// Query/decoding errors, and a consistency error when a stored version is negative.
async fn list_execution_events(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    pagination: Pagination<VersionType>,
    include_backtrace_id: bool,
) -> Result<Vec<ExecutionEvent>, DbErrorRead> {
    let mut params: Vec<Box<dyn ToSql + Send + Sync>> = Vec::new();
    // Registers a bind parameter and returns its `$n` placeholder.
    let mut add_param = |p: Box<dyn ToSql + Send + Sync>| {
        params.push(p);
        format!("${}", params.len())
    };
    let p_execution_id = add_param(Box::new(execution_id.to_string()));
    let (cursor, length, rel, is_desc) = match &pagination {
        Pagination::NewerThan {
            cursor,
            length,
            including_cursor,
        } => (
            *cursor,
            *length,
            if *including_cursor { ">=" } else { ">" },
            false,
        ),
        Pagination::OlderThan {
            cursor,
            length,
            including_cursor,
        } => (
            *cursor,
            *length,
            if *including_cursor { "<=" } else { "<" },
            true,
        ),
    };
    let p_cursor = add_param(Box::new(i64::from(cursor)));
    let p_limit = add_param(Box::new(i64::from(length)));
    let base_select = if include_backtrace_id {
        format!(
            "SELECT
log.created_at,
log.json_value,
log.version,
bt.version_min_including AS backtrace_id
FROM
t_execution_log AS log
LEFT OUTER JOIN
t_execution_backtrace AS bt ON log.execution_id = bt.execution_id
AND log.version >= bt.version_min_including
AND log.version < bt.version_max_excluding
WHERE
log.execution_id = {p_execution_id}
AND log.version {rel} {p_cursor}"
        )
    } else {
        format!(
            "SELECT
created_at, json_value, NULL::BIGINT as backtrace_id, version
FROM t_execution_log WHERE
execution_id = {p_execution_id} AND version {rel} {p_cursor}"
        )
    };
    let order = if is_desc { "DESC" } else { "ASC" };
    let mut sql = format!("{base_select} ORDER BY version {order} LIMIT {p_limit}");
    if is_desc {
        // The LIMIT must apply to the rows nearest the cursor (descending), but
        // callers always receive ascending order.
        sql = format!("SELECT * FROM ({sql}) AS sub ORDER BY version ASC");
    }
    let params_refs: Vec<&(dyn ToSql + Sync)> = params
        .iter()
        .map(|p| p.as_ref() as &(dyn ToSql + Sync))
        .collect();
    let rows = tx
        .query(&sql, &params_refs)
        .await
        .map_err(DbErrorRead::from)?;
    let mut events = Vec::with_capacity(rows.len());
    for row in rows {
        let created_at: DateTime<Utc> = get(&row, "created_at")?;
        let backtrace_id = get::<Option<i64>, _>(&row, "backtrace_id")?
            .map(Version::try_from)
            .transpose()
            .map_err(|_| consistency_db_err("version must be non-negative"))?;
        let version = get::<i64, _>(&row, "version")?;
        let version = Version::new(
            VersionType::try_from(version)
                .map_err(|_| consistency_db_err("version must be non-negative"))?,
        );
        let event_req: Json<ExecutionRequest> = get(&row, "json_value")?;
        let event_req = event_req.0;
        events.push(ExecutionEvent {
            created_at,
            event: event_req,
            backtrace_id,
            version,
        });
    }
    Ok(events)
}
/// Fetches the `t_execution_log` row of `execution_id` at exactly `version`.
///
/// The returned event carries no backtrace id.
///
/// # Errors
/// Fails when the query does not return exactly one row (`query_one`), when a
/// column cannot be decoded, or when the stored version is negative.
async fn get_execution_event(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    version: VersionType,
) -> Result<ExecutionEvent, DbErrorRead> {
    let requested_version = i64::from(version);
    let found = tx
        .query_one(
            "SELECT created_at, json_value, version FROM t_execution_log WHERE \
             execution_id = $1 AND version = $2",
            &[&execution_id.to_string(), &requested_version],
        )
        .await?;
    let created_at: DateTime<Utc> = get(&found, "created_at")?;
    let stored: Json<ExecutionRequest> = get(&found, "json_value")?;
    let raw_version: i64 = get(&found, "version")?;
    let stored_version = Version::try_from(raw_version)
        .map_err(|_| consistency_db_err("version must be non-negative"))?;
    Ok(ExecutionEvent {
        created_at,
        event: stored.0,
        backtrace_id: None,
        version: stored_version,
    })
}
/// Returns the newest event (highest `version`) of `execution_id` from
/// `t_execution_log`. The event carries no backtrace id.
///
/// # Errors
/// Fails when the execution has no log rows (`query_one` needs exactly one), when a
/// column cannot be decoded, or when the stored version is negative.
async fn get_last_execution_event(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
) -> Result<ExecutionEvent, DbErrorRead> {
    let newest = tx
        .query_one(
            "SELECT created_at, json_value, version FROM t_execution_log WHERE \
             execution_id = $1 ORDER BY version DESC LIMIT 1",
            &[&execution_id.to_string()],
        )
        .await?;
    let created_at: DateTime<Utc> = get(&newest, "created_at")?;
    let payload: Json<ExecutionRequest> = get(&newest, "json_value")?;
    let raw_version: i64 = get(&newest, "version")?;
    let version = Version::try_from(raw_version)
        .map_err(|_| consistency_db_err("version must be non-negative"))?;
    Ok(ExecutionEvent {
        created_at,
        event: payload.0,
        backtrace_id: None,
        version,
    })
}
/// Looks up the stored outcome of delay `delay_id` of `execution_id` in
/// `t_join_set_response`.
///
/// Returns `Some(delay_success)` when a response row exists, `None` otherwise.
///
/// # Errors
/// Query errors (including more than one matching row, per `query_opt`) and
/// decoding errors.
async fn delay_response(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    delay_id: &DelayId,
) -> Result<Option<bool>, DbErrorRead> {
    let maybe_row = tx
        .query_opt(
            "SELECT delay_success \
             FROM t_join_set_response \
             WHERE \
             execution_id = $1 AND delay_id = $2",
            &[&execution_id.to_string(), &delay_id.to_string()],
        )
        .await?;
    let Some(found) = maybe_row else {
        return Ok(None);
    };
    let delay_success = get::<bool, _>(&found, "delay_success")?;
    Ok(Some(delay_success))
}
#[instrument(level = Level::TRACE, skip_all)]
async fn get_responses_after(
tx: &Transaction<'_>,
execution_id: &ExecutionId,
last_response: ResponseCursor,
) -> Result<Vec<ResponseWithCursor>, DbErrorRead> {
let rows = tx
.query(
"SELECT r.id, r.created_at, r.join_set_id, \
r.delay_id, r.delay_success, \
r.child_execution_id, r.finished_version, child.json_value \
FROM t_join_set_response r LEFT OUTER JOIN t_execution_log child ON r.child_execution_id = child.execution_id \
WHERE \
r.id > $1 AND \
r.execution_id = $2 AND \
( \
r.finished_version = child.version \
OR \
r.child_execution_id IS NULL \
) \
ORDER BY id \
",
&[
&i64::from(last_response.0),
&execution_id.to_string(),
]
)
.await?;
let mut results = Vec::with_capacity(rows.len());
for row in rows {
let resp = parse_response_with_cursor(&row)?;
results.push(resp);
}
Ok(results)
}
async fn get_pending_of_single_ffqn(
tx: &Transaction<'_>,
batch_size: u32,
pending_at_or_sooner: DateTime<Utc>,
ffqn: &FunctionFqn,
select_strategy: SelectStrategy,
) -> Result<Vec<(ExecutionId, Version)>, ()> {
let rows = tx
.query(
&format!(
r"
SELECT execution_id, corresponding_version FROM t_state
WHERE
state = '{STATE_PENDING_AT}' AND
pending_expires_finished <= $1 AND ffqn = $2
AND is_paused = false
ORDER BY pending_expires_finished
{}
LIMIT $3
",
if select_strategy == SelectStrategy::LockForUpdate {
"FOR UPDATE SKIP LOCKED"
} else {
""
}
),
&[
&pending_at_or_sooner,
&ffqn.to_string(),
&(i64::from(batch_size)),
],
)
.await
.map_err(|err| {
warn!("Ignoring consistency error {err:?}");
})?;
let mut result = Vec::with_capacity(rows.len());
for row in rows {
let unpack = || -> Result<(ExecutionId, Version), DbErrorGeneric> {
let eid_str: String = get(&row, "execution_id")?;
let corresponding_version: i64 = get(&row, "corresponding_version")?;
let corresponding_version = Version::try_from(corresponding_version)
.map_err(|_| consistency_db_err("version must be non-negative"))?;
if let Ok(eid) = ExecutionId::from_str(&eid_str) {
return Ok((eid, corresponding_version.increment()));
}
Err(consistency_db_err("invalid execution_id"))
};
match unpack() {
Ok(val) => result.push(val),
Err(err) => warn!("Ignoring corrupted row in pending check: {err:?}"),
}
}
Ok(result)
}
async fn get_pending_by_ffqns(
tx: &Transaction<'_>,
batch_size: u32,
pending_at_or_sooner: DateTime<Utc>,
ffqns: &[FunctionFqn],
select_strategy: SelectStrategy,
) -> Result<Vec<(ExecutionId, Version)>, DbErrorGeneric> {
let batch_size = usize::try_from(batch_size).expect("16 bit systems are unsupported");
let mut execution_ids_versions = Vec::with_capacity(batch_size);
for ffqn in ffqns {
let needed = batch_size - execution_ids_versions.len();
if needed == 0 {
break;
}
let needed =
u32::try_from(needed).expect("`batch_size`:u32 - usize cannot overflow an u32");
if let Ok(execs) =
get_pending_of_single_ffqn(tx, needed, pending_at_or_sooner, ffqn, select_strategy)
.await
{
execution_ids_versions.extend(execs);
}
}
Ok(execution_ids_versions)
}
/// Collects up to `batch_size` pending, due (at or before `pending_at_or_sooner`),
/// non-paused executions across `ffqns`.
///
/// Executions whose `incompatible_digest` equals `current_digest` are skipped.
/// The `ffqns` are queried in order, each limited to the remaining capacity. With
/// `SelectStrategy::LockForUpdate`, rows are locked with `FOR UPDATE SKIP LOCKED`.
/// Each returned version is the row's `corresponding_version` plus one. Rows with an
/// unparsable execution id are logged and skipped.
///
/// # Errors
/// Query failures, undecodable columns and negative stored versions.
async fn get_pending_by_ffqns_auto(
    tx: &Transaction<'_>,
    batch_size: u32,
    pending_at_or_sooner: DateTime<Utc>,
    ffqns: &[FunctionFqn],
    current_digest: &ComponentDigest,
    select_strategy: SelectStrategy,
) -> Result<Vec<(ExecutionId, Version)>, DbErrorGeneric> {
    let batch_size = usize::try_from(batch_size).expect("16 bit systems are unsupported");
    let mut execution_ids_versions = Vec::with_capacity(batch_size);
    for ffqn in ffqns {
        let needed = batch_size - execution_ids_versions.len();
        if needed == 0 {
            break;
        }
        let rows = tx
            .query(
                &format!(
                    r"
SELECT execution_id, corresponding_version FROM t_state
WHERE
state = '{STATE_PENDING_AT}' AND
pending_expires_finished <= $1 AND ffqn = $2
AND is_paused = false
AND (incompatible_digest IS NULL OR incompatible_digest <> $3)
ORDER BY pending_expires_finished
{}
LIMIT $4
",
                    if select_strategy == SelectStrategy::LockForUpdate {
                        "FOR UPDATE SKIP LOCKED"
                    } else {
                        ""
                    }
                ),
                &[
                    &pending_at_or_sooner,
                    &ffqn.to_string(),
                    &current_digest,
                    &i64::try_from(needed).expect("`needed` is <= `batch_size` which is u32"),
                ],
            )
            .await
            .map_err(DbErrorGeneric::from)?;
        for row in rows {
            let eid_str: String = get(&row, "execution_id")?;
            let corresponding_version: i64 = get(&row, "corresponding_version")?;
            let corresponding_version = Version::try_from(corresponding_version)
                .map_err(|_| consistency_db_err("version must be non-negative"))?;
            if let Ok(eid) = ExecutionId::from_str(&eid_str) {
                execution_ids_versions.push((eid, corresponding_version.increment()));
            } else {
                warn!("Ignoring corrupted row in pending auto check: invalid execution_id");
            }
        }
    }
    Ok(execution_ids_versions)
}
/// How the pending-execution queries read `t_state` rows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SelectStrategy {
    /// Plain `SELECT` without a locking clause.
    Read,
    /// Append `FOR UPDATE SKIP LOCKED` to the query.
    LockForUpdate,
}
/// Selects up to `batch_size` pending, non-paused executions whose
/// `component_id_input_digest` equals `input_digest` and that are due at or before
/// `pending_at_or_sooner`, earliest due time first.
///
/// With `SelectStrategy::LockForUpdate`, rows are locked with
/// `FOR UPDATE SKIP LOCKED`. Each returned version is the row's
/// `corresponding_version` plus one. Rows that cannot be decoded are logged and
/// skipped.
///
/// # Errors
/// Query failures.
async fn get_pending_by_component_input_digest(
    tx: &Transaction<'_>,
    batch_size: u32,
    pending_at_or_sooner: DateTime<Utc>,
    input_digest: &ComponentDigest,
    select_strategy: SelectStrategy,
) -> Result<Vec<(ExecutionId, Version)>, DbErrorGeneric> {
    let locking_clause = match select_strategy {
        SelectStrategy::LockForUpdate => "FOR UPDATE SKIP LOCKED",
        SelectStrategy::Read => "",
    };
    let sql = format!(
        r"
SELECT execution_id, corresponding_version FROM t_state WHERE
state = '{STATE_PENDING_AT}' AND
pending_expires_finished <= $1 AND
component_id_input_digest = $2
AND is_paused = false
ORDER BY pending_expires_finished
{}
LIMIT $3
",
        locking_clause
    );
    let limit = i64::from(batch_size);
    let rows = tx
        .query(&sql, &[&pending_at_or_sooner, &input_digest, &limit])
        .await?;
    let mut pending = Vec::with_capacity(rows.len());
    for row in &rows {
        let decode = || -> Result<(ExecutionId, Version), DbErrorGeneric> {
            let eid_str: String = get(row, "execution_id")?;
            let raw_version: i64 = get(row, "corresponding_version")?;
            let version = Version::try_from(raw_version)
                .map_err(|_| consistency_db_err("version must be non-negative"))?;
            let execution_id = ExecutionId::from_str(&eid_str)
                .map_err(|err| consistency_db_err(err.to_string()))?;
            Ok((execution_id, version.increment()))
        };
        if let Err(err) = decode().map(|entry| pending.push(entry)) {
            warn!("Skipping corrupted row in get_pending_by_component_input_digest: {err:?}");
        }
    }
    Ok(pending)
}
/// Forwards `notifier` to the pending-execution subscribers, but only when its
/// `scheduled_at` is not later than `current_time`.
///
/// The caller passes in the subscriber holder it has already locked.
fn notify_pending_locked(
    notifier: &NotifierPendingAt,
    current_time: DateTime<Utc>,
    ffqn_to_pending_subscription: &std::sync::MutexGuard<PendingFfqnSubscribersHolder>,
) {
    let is_due = notifier.scheduled_at <= current_time;
    if !is_due {
        return;
    }
    ffqn_to_pending_subscription.notify(notifier);
}
/// Records that the component of `execution_id` was upgraded from `old` to `new`.
///
/// Appends an `ExecutionRequest::ComponentUpgradeFinished` event with outcome
/// `ComponentUpgradeOutcome::Success { reason }` at the execution's next version.
/// The event is timestamped with `Utc::now()` and keeps the execution's current
/// `deployment_id`. The notifier returned by `append` is discarded here.
///
/// # Errors
/// - `DbErrorWrite::NotFound` when the current component digest is not `old`.
/// - Whatever `get_next_version_fail_if_finished` reports (presumably when the
///   execution is already finished — verify).
/// - Any error from `append`.
async fn upgrade_execution_component(
    tx: &Transaction<'_>,
    execution_id: &ExecutionId,
    old: &ComponentDigest,
    new: &ComponentDigest,
    reason: ComponentUpgradeReason,
) -> Result<(), DbErrorWrite> {
    let state = get_combined_state(tx, execution_id).await?;
    let digest_matches = state.execution_with_state.component_digest == *old;
    if !digest_matches {
        return Err(DbErrorWrite::NotFound);
    }
    let appending_version = state.get_next_version_fail_if_finished()?;
    let request = AppendRequest {
        created_at: Utc::now(),
        event: ExecutionRequest::ComponentUpgradeFinished {
            component_digest: new.clone(),
            deployment_id: state.execution_with_state.deployment_id,
            outcome: ComponentUpgradeOutcome::Success { reason },
        },
    };
    let (_new_version, _notifier) = append(tx, execution_id, request, appending_version).await?;
    Ok(())
}
impl PostgresConnection {
    /// Dispatches the notifications produced by already committed appends.
    ///
    /// - Pending-at notifiers go to `notify_pending_locked`, which forwards them only
    ///   when they are due at `current_time`.
    /// - A finished execution wakes, and removes, every subscriber registered for its
    ///   id in `execution_finished_subscribers`.
    /// - A response is sent to, and removes, the subscriber registered for its
    ///   execution id in `response_subscribers`.
    ///
    /// Send failures (the receiver is gone) are ignored. Each subscriber mutex is
    /// locked only when there is something to deliver through it.
    #[instrument(level = Level::TRACE, skip_all)]
    fn notify_all(&self, notifiers: Vec<AppendNotifier>, current_time: DateTime<Utc>) {
        let mut pending_ats = Vec::new();
        let mut finished_execs = Vec::new();
        let mut responses = Vec::new();
        for notifier in notifiers {
            pending_ats.extend(notifier.pending_at);
            finished_execs.extend(notifier.execution_finished);
            responses.extend(notifier.response);
        }

        if !pending_ats.is_empty() {
            let subscribers = self.pending_subscribers.lock().unwrap();
            for pending_at in &pending_ats {
                notify_pending_locked(pending_at, current_time, &subscribers);
            }
        }

        if !finished_execs.is_empty() {
            let mut subscribers = self.execution_finished_subscribers.lock().unwrap();
            for finished in finished_execs {
                let Some(listeners) = subscribers.remove(&finished.execution_id) else {
                    continue;
                };
                for (_tag, sender) in listeners {
                    let _ = sender.send(finished.retval.clone());
                }
            }
        }

        if !responses.is_empty() {
            let mut subscribers = self.response_subscribers.lock().unwrap();
            for (execution_id, response) in responses {
                if let Some((sender, _tag)) = subscribers.remove(&execution_id) {
                    let _ = sender.send(response);
                }
            }
        }
    }
}
#[async_trait]
impl DbExecutor for PostgresConnection {
#[instrument(level = Level::TRACE, skip(self))]
async fn lock_pending_by_ffqns(
&self,
batch_size: u32,
pending_at_or_sooner: DateTime<Utc>,
ffqns: Arc<[FunctionFqn]>,
created_at: DateTime<Utc>,
component_id: ComponentId,
deployment_id: DeploymentId,
executor_id: ExecutorId,
lock_expires_at: DateTime<Utc>,
run_id: RunId,
retry_config: ComponentRetryConfig,
) -> Result<LockPendingResponse, DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let execution_ids_versions = get_pending_by_ffqns(
&tx,
batch_size,
pending_at_or_sooner,
&ffqns,
SelectStrategy::LockForUpdate,
)
.await?;
if execution_ids_versions.is_empty() {
tx.commit().await?;
return Ok(vec![]);
}
debug!("Locking {execution_ids_versions:?}");
let mut locked_execs = Vec::with_capacity(execution_ids_versions.len());
for (execution_id, version) in execution_ids_versions {
match lock_single_execution(
&tx,
created_at,
&component_id,
true, deployment_id,
&execution_id,
run_id,
&version,
executor_id,
lock_expires_at,
retry_config,
)
.await
{
Ok(locked) => locked_execs.push(locked),
Err(err) => {
tx.rollback().await?; debug!("Locking row {execution_id} failed - {err:?}");
return Err(err);
}
}
}
tx.commit().await?;
Ok(locked_execs)
}
#[instrument(level = Level::TRACE, skip(self))]
async fn lock_pending_by_ffqns_auto(
&self,
batch_size: u32,
pending_at_or_sooner: DateTime<Utc>,
ffqns: Arc<[FunctionFqn]>,
created_at: DateTime<Utc>,
component_id: ComponentId,
deployment_id: DeploymentId,
executor_id: ExecutorId,
lock_expires_at: DateTime<Utc>,
run_id: RunId,
retry_config: ComponentRetryConfig,
) -> Result<LockPendingResponse, DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let current_digest = component_id.component_digest.clone();
let execution_ids_versions = get_pending_by_ffqns_auto(
&tx,
batch_size,
pending_at_or_sooner,
&ffqns,
¤t_digest,
SelectStrategy::LockForUpdate,
)
.await?;
if execution_ids_versions.is_empty() {
tx.commit().await?;
return Ok(vec![]);
}
debug!("Auto-locking {execution_ids_versions:?}");
let mut locked_execs = Vec::with_capacity(execution_ids_versions.len());
for (execution_id, version) in execution_ids_versions {
match lock_single_execution(
&tx,
created_at,
&component_id,
false, deployment_id,
&execution_id,
run_id,
&version,
executor_id,
lock_expires_at,
retry_config,
)
.await
{
Ok(locked) => locked_execs.push(locked),
Err(err) => {
tx.rollback().await?;
debug!("Auto-locking row {execution_id} failed - {err:?}");
return Err(err);
}
}
}
tx.commit().await?;
Ok(locked_execs)
}
#[instrument(level = Level::TRACE, skip(self))]
async fn lock_pending_by_component_digest(
&self,
batch_size: u32,
pending_at_or_sooner: DateTime<Utc>,
component_id: &ComponentId,
deployment_id: DeploymentId,
created_at: DateTime<Utc>,
executor_id: ExecutorId,
lock_expires_at: DateTime<Utc>,
run_id: RunId,
retry_config: ComponentRetryConfig,
) -> Result<LockPendingResponse, DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let execution_ids_versions = get_pending_by_component_input_digest(
&tx,
batch_size,
pending_at_or_sooner,
&component_id.component_digest,
SelectStrategy::LockForUpdate,
)
.await?;
if execution_ids_versions.is_empty() {
tx.commit().await?;
return Ok(vec![]);
}
debug!("Locking {execution_ids_versions:?}");
let mut locked_execs = Vec::with_capacity(execution_ids_versions.len());
for (execution_id, version) in execution_ids_versions {
match lock_single_execution(
&tx,
created_at,
component_id,
true, deployment_id,
&execution_id,
run_id,
&version,
executor_id,
lock_expires_at,
retry_config,
)
.await
{
Ok(locked) => locked_execs.push(locked),
Err(err) => {
tx.rollback().await?; debug!("Locking row {execution_id} failed - {err:?}");
return Err(err);
}
}
}
tx.commit().await?;
Ok(locked_execs)
}
#[cfg(feature = "test")]
#[instrument(level = Level::DEBUG, skip(self))]
async fn lock_one(
&self,
created_at: DateTime<Utc>,
component_id: ComponentId,
deployment_id: DeploymentId,
execution_id: &ExecutionId,
run_id: RunId,
version: Version,
executor_id: ExecutorId,
lock_expires_at: DateTime<Utc>,
retry_config: ComponentRetryConfig,
) -> Result<LockedExecution, DbErrorWrite> {
debug!(%execution_id, "lock_one");
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let res = lock_single_execution(
&tx,
created_at,
&component_id,
true, deployment_id,
execution_id,
run_id,
&version,
executor_id,
lock_expires_at,
retry_config,
)
.await?;
tx.commit().await?;
Ok(res)
}
#[instrument(level = Level::DEBUG, skip(self, req))]
async fn append(
&self,
execution_id: ExecutionId,
version: Version,
req: AppendRequest,
) -> Result<AppendResponse, DbErrorWrite> {
debug!(%req, "append");
trace!(?req, "append");
let created_at = req.created_at;
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let (new_version, notifier) = append(&tx, &execution_id, req, version).await?;
tx.commit().await?;
drop(client_guard);
self.notify_all(vec![notifier], created_at);
Ok(new_version)
}
#[instrument(level = Level::DEBUG, skip_all)]
async fn append_batch_respond_to_parent(
&self,
events: AppendEventsToExecution,
response: AppendResponseToExecution,
current_time: DateTime<Utc>,
) -> Result<AppendBatchResponse, DbErrorWrite> {
debug!("append_batch_respond_to_parent");
if events.execution_id == response.parent_execution_id {
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::ValidationFailed(
"Parameters `execution_id` and `parent_execution_id` cannot be the same".into(),
),
));
}
if events.batch.is_empty() {
return Err(DbErrorWrite::NonRetriable(
DbErrorWriteNonRetriable::ValidationFailed("batch cannot be empty".into()),
));
}
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let mut version = events.version;
let mut notifiers = Vec::new();
for append_request in events.batch {
let (v, n) = append(&tx, &events.execution_id, append_request, version).await?;
version = v;
notifiers.push(n);
}
let pending_at_parent = append_response(
&tx,
&response.parent_execution_id,
JoinSetResponseEventOuter {
created_at: response.created_at,
event: JoinSetResponseEvent {
join_set_id: response.join_set_id,
event: JoinSetResponse::ChildExecutionFinished {
child_execution_id: response.child_execution_id,
finished_version: response.finished_version,
result: response.result,
},
},
},
)
.await?;
notifiers.push(pending_at_parent);
tx.commit().await?;
drop(client_guard);
self.notify_all(notifiers, current_time);
Ok(version)
}
#[instrument(level = Level::TRACE, skip(self, timeout_fut))]
async fn wait_for_pending_by_ffqn(
&self,
pending_at_or_sooner: DateTime<Utc>,
ffqns: Arc<[FunctionFqn]>,
current_digest: Option<ComponentDigest>,
timeout_fut: Pin<Box<dyn Future<Output = ()> + Send>>,
) {
let unique_tag: u64 = rand::random();
let (sender, mut receiver) = mpsc::channel(1);
{
let mut pending_subscribers = self.pending_subscribers.lock().unwrap();
for ffqn in ffqns.as_ref() {
pending_subscribers.insert_ffqn(ffqn.clone(), (sender.clone(), unique_tag));
}
}
async {
let mut db_has_pending = false;
{
let mut client_guard = self.client.lock().await;
if let Ok(tx) = client_guard.transaction().await {
let res = if let Some(current_digest) = current_digest {
get_pending_by_ffqns_auto(
&tx,
1,
pending_at_or_sooner,
&ffqns,
¤t_digest,
SelectStrategy::Read,
)
.await
} else {
get_pending_by_ffqns(
&tx,
1,
pending_at_or_sooner,
&ffqns,
SelectStrategy::Read,
)
.await
};
if let Ok(res) = res
&& !res.is_empty()
{
db_has_pending = true;
}
let _ = tx.commit().await;
}
}
if db_has_pending {
trace!("Not waiting, database already contains new pending executions");
return;
}
tokio::select! {
_ = receiver.recv() => {
trace!("Received a notification");
}
() = timeout_fut => {
}
}
}
.await;
{
let mut pending_subscribers = self.pending_subscribers.lock().unwrap();
for ffqn in ffqns.as_ref() {
match pending_subscribers.remove_ffqn(ffqn) {
Some((_, tag)) if tag == unique_tag => {}
Some(other) => {
pending_subscribers.insert_ffqn(ffqn.clone(), other);
}
None => {}
}
}
}
}
#[instrument(level = Level::DEBUG, skip(self, timeout_fut))]
async fn wait_for_pending_by_component_digest(
&self,
pending_at_or_sooner: DateTime<Utc>,
component_digest: &ComponentDigest,
timeout_fut: Pin<Box<dyn Future<Output = ()> + Send>>,
) {
let unique_tag: u64 = rand::random();
let (sender, mut receiver) = mpsc::channel(1);
{
let mut pending_subscribers = self.pending_subscribers.lock().unwrap();
pending_subscribers
.insert_by_component(component_digest.clone(), (sender.clone(), unique_tag));
}
async {
let mut db_has_pending = false;
{
let mut client_guard = self.client.lock().await;
if let Ok(tx) = client_guard.transaction().await {
if let Ok(res) = get_pending_by_component_input_digest(
&tx,
1,
pending_at_or_sooner,
component_digest,
SelectStrategy::Read,
)
.await
&& !res.is_empty()
{
db_has_pending = true;
}
let _ = tx.commit().await;
}
}
if db_has_pending {
trace!("Not waiting, database already contains new pending executions");
return;
}
tokio::select! {
_ = receiver.recv() => {
trace!("Received a notification");
}
() = timeout_fut => {
}
}
}
.await;
{
let mut pending_subscribers = self.pending_subscribers.lock().unwrap();
match pending_subscribers.remove_by_component(component_digest) {
Some((_, tag)) if tag == unique_tag => {}
Some(other) => {
pending_subscribers.insert_by_component(component_digest.clone(), other);
}
None => {}
}
}
}
async fn get_last_execution_event(
&self,
execution_id: &ExecutionId,
) -> Result<ExecutionEvent, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let event = get_last_execution_event(&tx, execution_id).await?;
tx.commit().await?;
Ok(event)
}
}
#[async_trait]
impl DbConnection for PostgresConnection {
/// Creates a new execution from `req` in its own transaction. After commit, and
/// after releasing the client lock, the resulting notifier is dispatched.
///
/// Returns the version produced by `create_inner`.
#[instrument(level = Level::DEBUG, skip_all, fields(execution_id = %req.execution_id))]
async fn create(&self, req: CreateRequest) -> Result<AppendResponse, DbErrorWrite> {
    debug!("create");
    trace!(?req, "create");
    let created_at = req.created_at;
    let mut client_guard = self.client.lock().await;
    let tx = client_guard.transaction().await?;
    // `req` is not used afterwards, so it is moved instead of cloned.
    let (version, notifier) = create_inner(&tx, req).await?;
    tx.commit().await?;
    drop(client_guard);
    self.notify_all(vec![notifier], created_at);
    Ok(version)
}
/// Reads the complete execution log of `execution_id` in its own transaction.
#[instrument(level = Level::DEBUG, skip(self))]
async fn get(
    &self,
    execution_id: &ExecutionId,
) -> Result<concepts::storage::ExecutionLog, DbErrorRead> {
    trace!("get");
    let mut connection = self.client.lock().await;
    let transaction = connection.transaction().await?;
    let log = get_execution_log(&transaction, execution_id).await?;
    transaction.commit().await?;
    Ok(log)
}
#[instrument(level = Level::DEBUG, skip(self, batch))]
async fn append_batch(
&self,
current_time: DateTime<Utc>,
batch: Vec<AppendRequest>,
execution_id: ExecutionId,
version: Version,
) -> Result<AppendBatchResponse, DbErrorWrite> {
debug!("append_batch");
trace!(?batch, "append_batch");
assert!(!batch.is_empty(), "Empty batch request");
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let mut version = version;
let mut notifier = None;
for append_request in batch {
let (v, n) = append(&tx, &execution_id, append_request, version).await?;
version = v;
notifier = Some(n);
}
tx.commit().await?;
drop(client_guard);
self.notify_all(
vec![notifier.expect("checked that the batch is not empty")],
current_time,
);
Ok(version)
}
#[instrument(level = Level::DEBUG, skip_all, fields(%execution_id, %version))]
async fn append_batch_create_new_execution(
&self,
current_time: DateTime<Utc>,
batch: Vec<AppendRequest>,
execution_id: ExecutionId,
version: Version,
child_req: Vec<CreateRequest>,
backtraces: Vec<BacktraceInfo>,
) -> Result<AppendBatchResponse, DbErrorWrite> {
debug!("append_batch_create_new_execution");
trace!(?batch, ?child_req, "append_batch_create_new_execution");
assert!(!batch.is_empty(), "Empty batch request");
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let mut version = version;
let mut notifier = None;
for append_request in batch {
let (v, n) = append(&tx, &execution_id, append_request, version).await?;
version = v;
notifier = Some(n);
}
let mut notifiers = Vec::new();
notifiers.push(notifier.expect("checked that the batch is not empty"));
for req in child_req {
let (_, n) = create_inner(&tx, req).await?;
notifiers.push(n);
}
for backtrace in backtraces {
append_backtrace(&tx, &backtrace).await?;
}
tx.commit().await?;
drop(client_guard);
self.notify_all(notifiers, current_time);
Ok(version)
}
/// Returns the responses of `execution_id` newer than `last_response`, waiting for
/// the next one if none are stored yet.
///
/// A subscriber is registered in `response_subscribers` before the database is
/// queried, so a response committed in between is not missed. If stored responses
/// exist, they are returned immediately. Otherwise this waits for the next notified
/// response or for `timeout_fut`. The subscriber is removed on every exit path,
/// including query errors; a replacement registered by another call is kept.
///
/// # Errors
/// `DbErrorReadWithTimeout::Timeout` when `timeout_fut` completes first; a `Close`
/// error when the sender is dropped (for example, replaced by another subscriber);
/// read and transaction errors.
#[instrument(level = Level::DEBUG, skip(self, timeout_fut))]
async fn subscribe_to_next_responses(
    &self,
    execution_id: &ExecutionId,
    last_response: ResponseCursor,
    timeout_fut: Pin<Box<dyn Future<Output = TimeoutOutcome> + Send>>,
) -> Result<Vec<ResponseWithCursor>, DbErrorReadWithTimeout> {
    debug!("next_responses");
    let unique_tag: u64 = rand::random();
    let execution_id_clone = execution_id.clone();
    // Removes this call's subscriber, but restores one registered by another call.
    let cleanup = || {
        let mut guard = self.response_subscribers.lock().unwrap();
        match guard.remove(&execution_id_clone) {
            Some((_, tag)) if tag == unique_tag => {}
            Some(other) => {
                guard.insert(execution_id_clone.clone(), other);
            }
            None => {}
        }
    };
    let receiver = {
        let mut client_guard = self.client.lock().await;
        let tx = client_guard.transaction().await?;
        let (sender, receiver) = oneshot::channel();
        self.response_subscribers
            .lock()
            .unwrap()
            .insert(execution_id.clone(), (sender, unique_tag));
        // The subscriber is registered from here on: every exit must unregister it.
        let responses = match get_responses_after(&tx, execution_id, last_response).await {
            Ok(responses) => responses,
            Err(err) => {
                cleanup();
                return Err(err.into());
            }
        };
        if responses.is_empty() {
            tx.commit().await.map_err(|err| {
                cleanup();
                DbErrorRead::from(err)
            })?;
            receiver
        } else {
            cleanup();
            tx.commit().await?;
            return Ok(responses);
        }
    };
    let res = tokio::select! {
        resp = receiver => {
            match resp {
                Ok(resp) => Ok(vec![resp]),
                Err(_) => Err(DbErrorReadWithTimeout::from(DbErrorGeneric::Close)),
            }
        }
        outcome = timeout_fut => Err(DbErrorReadWithTimeout::Timeout(outcome)),
    };
    cleanup();
    res
}
/// Waits for the return value of `execution_id`.
///
/// A subscriber is registered in `execution_finished_subscribers` before `t_state`
/// is checked, so a completion committed in between is not missed. If the execution
/// is already finished, its `Finished` event is read and its return value returned
/// immediately. Otherwise this waits for the finish notification or for
/// `timeout_fut` (`None` means wait without a timeout).
///
/// The subscriber is removed on every exit path, including errors. When it was the
/// last subscriber for this execution, the per-execution entry is removed as well.
///
/// # Errors
/// `DbErrorReadWithTimeout::Timeout` when `timeout_fut` completes first; a `Close`
/// error when the sender is dropped; read, transaction or consistency errors.
#[instrument(level = Level::DEBUG, skip(self, timeout_fut))]
async fn wait_for_finished_result(
    &self,
    execution_id: &ExecutionId,
    timeout_fut: Option<Pin<Box<dyn Future<Output = TimeoutOutcome> + Send>>>,
) -> Result<SupportedFunctionReturnValue, DbErrorReadWithTimeout> {
    let unique_tag: u64 = rand::random();
    let execution_id_clone = execution_id.clone();
    let cleanup = || {
        let mut guard = self.execution_finished_subscribers.lock().unwrap();
        let now_empty = match guard.get_mut(&execution_id_clone) {
            Some(subscribers) => {
                subscribers.remove(&unique_tag);
                subscribers.is_empty()
            }
            None => false,
        };
        // Otherwise an empty entry lingers until a later finish notification for
        // this id, which may never come (e.g. the execution had already finished).
        if now_empty {
            guard.remove(&execution_id_clone);
        }
    };
    let receiver = {
        let mut client_guard = self.client.lock().await;
        let tx = client_guard.transaction().await?;
        let (sender, receiver) = oneshot::channel();
        {
            let mut guard = self.execution_finished_subscribers.lock().unwrap();
            guard
                .entry(execution_id.clone())
                .or_default()
                .insert(unique_tag, sender);
        }
        // The subscriber is registered from here on, so every exit must run
        // `cleanup`. The fallible checks are gathered into one future for that.
        let already_finished = async move {
            let pending_state = get_combined_state(&tx, execution_id)
                .await?
                .execution_with_state
                .pending_state;
            let PendingState::Finished(finished) = pending_state else {
                tx.commit().await?;
                return Ok(None);
            };
            let event = get_execution_event(&tx, execution_id, finished.version).await?;
            tx.commit().await?;
            if let ExecutionRequest::Finished { retval, .. } = event.event {
                return Ok(Some(retval));
            }
            error!("Mismatch, expected Finished row: {event:?} based on t_state {finished}");
            Err(DbErrorReadWithTimeout::from(consistency_db_err(
                "cannot get finished event based on t_state version",
            )))
        }
        .await;
        match already_finished {
            Ok(None) => receiver,
            Ok(Some(retval)) => {
                cleanup();
                return Ok(retval);
            }
            Err(err) => {
                cleanup();
                return Err(err);
            }
        }
    };
    let timeout_fut = timeout_fut.unwrap_or_else(|| Box::pin(std::future::pending()));
    let res = tokio::select! {
        resp = receiver => {
            match resp {
                Ok(retval) => Ok(retval),
                Err(_recv_err) => Err(DbErrorGeneric::Close.into())
            }
        }
        outcome = timeout_fut => Err(DbErrorReadWithTimeout::Timeout(outcome)),
    };
    cleanup();
    res
}
/// Stores the outcome of delay `delay_id` as a `JoinSetResponse::DelayFinished`
/// response of `execution_id`.
///
/// On success, the notifier is dispatched after commit. If the insert reports
/// `Conflict`, the transaction is rolled back and the stored outcome is read in a
/// fresh transaction: `AlreadyFinished` when it recorded success,
/// `AlreadyCancelled` otherwise.
///
/// # Errors
/// A consistency error when a conflict is reported but no stored response is
/// found; other insert, read or transaction errors.
#[instrument(level = Level::DEBUG, skip_all, fields(%join_set_id, %execution_id))]
async fn append_delay_response(
    &self,
    created_at: DateTime<Utc>,
    execution_id: ExecutionId,
    join_set_id: JoinSetId,
    delay_id: DelayId,
    result: Result<(), ()>,
) -> Result<AppendDelayResponseOutcome, DbErrorWrite> {
    debug!("append_delay_response");
    let response_event = JoinSetResponseEventOuter {
        created_at,
        event: JoinSetResponseEvent {
            join_set_id,
            event: JoinSetResponse::DelayFinished {
                delay_id: delay_id.clone(),
                result,
            },
        },
    };
    let mut client_guard = self.client.lock().await;
    let tx = client_guard.transaction().await?;
    let notifier = match append_response(&tx, &execution_id, response_event).await {
        Ok(notifier) => notifier,
        Err(DbErrorWrite::NonRetriable(DbErrorWriteNonRetriable::Conflict)) => {
            tx.rollback().await?;
            let read_tx = client_guard.transaction().await?;
            let delay_success = delay_response(&read_tx, &execution_id, &delay_id).await?;
            read_tx.commit().await?;
            return match delay_success {
                Some(true) => Ok(AppendDelayResponseOutcome::AlreadyFinished),
                Some(false) => Ok(AppendDelayResponseOutcome::AlreadyCancelled),
                None => Err(DbErrorWrite::Generic(consistency_db_err(
                    "insert failed yet select did not find the response",
                ))),
            };
        }
        Err(other) => {
            let _ = tx.rollback().await;
            return Err(other);
        }
    };
    tx.commit().await?;
    drop(client_guard);
    self.notify_all(vec![notifier], created_at);
    Ok(AppendDelayResponseOutcome::Success)
}
/// Stores a single backtrace in its own transaction.
#[instrument(level = Level::DEBUG, skip_all)]
async fn append_backtrace(&self, append: BacktraceInfo) -> Result<(), DbErrorWrite> {
    debug!("append_backtrace");
    let mut connection = self.client.lock().await;
    let transaction = connection.transaction().await?;
    append_backtrace(&transaction, &append).await?;
    transaction.commit().await?;
    Ok(())
}
/// Stores all backtraces of `batch` in one transaction. The first failure aborts
/// the batch, and nothing is committed.
#[instrument(level = Level::DEBUG, skip_all)]
async fn append_backtrace_batch(&self, batch: Vec<BacktraceInfo>) -> Result<(), DbErrorWrite> {
    debug!("append_backtrace_batch");
    let mut connection = self.client.lock().await;
    let transaction = connection.transaction().await?;
    for backtrace in &batch {
        append_backtrace(&transaction, backtrace).await?;
    }
    transaction.commit().await?;
    Ok(())
}
/// Stores a single log row in its own transaction.
#[instrument(level = Level::DEBUG, skip_all)]
async fn append_log(&self, row: LogInfoAppendRow) -> Result<(), DbErrorWrite> {
    trace!("append_log");
    let mut connection = self.client.lock().await;
    let transaction = connection.transaction().await?;
    append_log(&transaction, &row).await?;
    transaction.commit().await?;
    Ok(())
}
/// Stores all log rows of `batch` in one transaction. The first failure aborts the
/// batch, and nothing is committed.
#[instrument(level = Level::DEBUG, skip_all)]
async fn append_log_batch(&self, batch: &[LogInfoAppendRow]) -> Result<(), DbErrorWrite> {
    trace!("append_log_batch");
    let mut connection = self.client.lock().await;
    let transaction = connection.transaction().await?;
    for log_row in batch.iter() {
        append_log(&transaction, log_row).await?;
    }
    transaction.commit().await?;
    Ok(())
}
/// Collects every timer that has expired at or before `at`.
///
/// Two sources are scanned inside one transaction:
/// 1. `t_delay` rows with `expires_at <= at` that are not paused, reported as
///    `ExpiredTimer::Delay`.
/// 2. `t_state` rows in the locked state whose `pending_expires_finished <= at`,
///    reported as `ExpiredTimer::Lock`.
///
/// Rows that fail to decode are skipped with a `warn!` instead of failing the
/// whole scan. Locked rows flagged `is_paused` are logged via `error!` and skipped.
///
/// # Errors
/// Only connection/query/commit failures are returned.
#[instrument(level = Level::TRACE, skip(self))]
async fn get_expired_timers(
&self,
at: DateTime<Utc>,
) -> Result<Vec<ExpiredTimer>, DbErrorGeneric> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
// Pass 1: expired, non-paused delays.
let rows = tx
.query(
"SELECT execution_id, join_set_id, delay_id FROM t_delay WHERE expires_at <= $1 AND NOT is_paused",
&[&at],
)
.await?;
// Capacity covers only the delay rows; lock rows found below may grow the vector.
let mut expired_timers = Vec::with_capacity(rows.len());
for row in rows {
// Decode in a closure so any single-column failure can be handled per row.
let unpack = || -> Result<ExpiredTimer, DbErrorGeneric> {
let execution_id: String = get(&row, "execution_id")?;
let execution_id = ExecutionId::from_str(&execution_id)?;
let join_set_id: String = get(&row, "join_set_id")?;
let join_set_id = JoinSetId::from_str(&join_set_id)?;
let delay_id: String = get(&row, "delay_id")?;
let delay_id = DelayId::from_str(&delay_id)?;
Ok(ExpiredTimer::Delay(ExpiredDelay {
execution_id,
join_set_id,
delay_id,
}))
};
// Best-effort: a corrupted row is logged and skipped, not fatal.
match unpack() {
Ok(timer) => expired_timers.push(timer),
Err(err) => warn!("Skipping corrupted row in get_expired_timers (delays): {err:?}"),
}
}
// Pass 2: locked executions whose lock has expired.
let rows = tx.query(
&format!(
"SELECT execution_id, last_lock_version, corresponding_version, intermittent_event_count, max_retries, retry_exp_backoff_millis, executor_id, run_id, is_paused \
FROM t_state \
WHERE pending_expires_finished <= $1 AND state = '{STATE_LOCKED}'"
),
&[&at]
).await?;
for row in rows {
// `Ok(None)` means "intentionally skipped", as opposed to a decode error.
let unpack = || -> Result<Option<ExpiredTimer>, DbErrorGeneric> {
let execution_id: String = get(&row, "execution_id")?;
let execution_id = ExecutionId::from_str(&execution_id)?;
let is_paused: bool = get(&row, "is_paused")?;
// A paused execution is not expected to be in the locked state.
if is_paused {
error!(%execution_id, "encountered invalid paused locked execution while scanning expired locks");
return Ok(None);
}
let last_lock_version: i64 = get(&row, "last_lock_version")?;
let last_lock_version = Version::try_from(last_lock_version)?;
let corresponding_version: i64 = get(&row, "corresponding_version")?;
let corresponding_version = Version::try_from(corresponding_version)?;
// Columns are stored as signed BIGINT; negative values indicate corruption.
let intermittent_event_count =
u32::try_from(get::<i64, _>(&row, "intermittent_event_count")?).map_err(
|_| consistency_db_err("`intermittent_event_count` must not be negative"),
)?;
// NULL `max_retries` stays `None`.
let max_retries = get::<Option<i64>, _>(&row, "max_retries")?
.map(u32::try_from)
.transpose()
.map_err(|_| consistency_db_err("`max_retries` must not be negative"))?;
let retry_exp_backoff_millis =
u32::try_from(get::<i64, _>(&row, "retry_exp_backoff_millis")?).map_err(
|_| consistency_db_err("`retry_exp_backoff_millis` must not be negative"),
)?;
let executor_id: String = get(&row, "executor_id")?;
let executor_id = ExecutorId::from_str(&executor_id)?;
let run_id: String = get(&row, "run_id")?;
let run_id = RunId::from_str(&run_id)?;
Ok(Some(ExpiredTimer::Lock(ExpiredLock {
execution_id,
locked_at_version: last_lock_version,
// The version after `corresponding_version`, via `Version::increment`.
next_version: corresponding_version.increment(),
intermittent_event_count,
max_retries,
retry_exp_backoff: Duration::from_millis(u64::from(retry_exp_backoff_millis)),
locked_by: LockedBy {
executor_id,
run_id,
},
})))
};
match unpack() {
Ok(Some(timer)) => expired_timers.push(timer),
Ok(None) => {}
Err(err) => warn!("Skipping corrupted row in get_expired_timers (locks): {err:?}"),
}
}
tx.commit().await?;
if !expired_timers.is_empty() {
debug!("get_expired_timers found {expired_timers:?}");
}
Ok(expired_timers)
}
/// Reads the event stored at `version` of `execution_id`.
async fn get_execution_event(
    &self,
    execution_id: &ExecutionId,
    version: &Version,
) -> Result<ExecutionEvent, DbErrorRead> {
    let mut conn = self.client.lock().await;
    let transaction = conn.transaction().await?;
    // Bare name resolves to the free helper taking the raw version number.
    let found = get_execution_event(&transaction, execution_id, version.0).await?;
    transaction.commit().await?;
    Ok(found)
}
/// Finishes a stubbed child execution and delivers its result to the parent.
///
/// In one transaction this appends `req` to the child at `version`, then
/// appends a `ChildExecutionFinished` response to the parent's join set.
/// After commit (and after releasing the connection lock), both notifiers are fired.
///
/// Idempotency: if the child is already finished, the event stored at `version`
/// is read back. An identical `Finished` return value counts as success and no
/// notifications are sent. A different event, or a failed read, yields
/// `StubConflict`.
///
/// # Errors
/// - `DbErrorStubResponse::StubConflict` on a mismatching prior finish.
/// - `DbErrorStubResponse::Write` for any other append failure.
#[instrument(level = Level::DEBUG, skip_all)]
async fn upsert_stub_response(
&self,
execution_id: ExecutionIdDerived,
version: Version,
req: AppendRequest,
response: AppendResponseToExecution,
current_time: DateTime<Utc>,
) -> Result<(), DbErrorStubResponse> {
debug!("upsert_stub_response");
// Debug-only sanity check: the derived child id must encode the same parent
// and join set that the response targets.
#[cfg(debug_assertions)]
{
let (expected_parent, expected_join_set) = execution_id.split_to_parts();
debug_assert_eq!(expected_parent, response.parent_execution_id);
debug_assert_eq!(expected_join_set, response.join_set_id);
debug_assert_eq!(execution_id, response.child_execution_id);
}
let execution_id = ExecutionId::Derived(execution_id);
// Cloned up front because `response.result` is moved into the parent response below.
let expected_retval = response.result.clone();
let version_for_read = version.0;
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let notifiers = match append(&tx, &execution_id, req, version).await {
Ok((_next_version, notifier_of_child)) => {
let pending_at_parent = append_response(
&tx,
&response.parent_execution_id,
JoinSetResponseEventOuter {
created_at: response.created_at,
event: JoinSetResponseEvent {
join_set_id: response.join_set_id,
event: JoinSetResponse::ChildExecutionFinished {
child_execution_id: response.child_execution_id,
finished_version: response.finished_version,
result: response.result,
},
},
},
)
.await
.map_err(DbErrorStubResponse::Write)?;
Some(vec![notifier_of_child, pending_at_parent])
}
// Already finished: accept only if it finished with the very same value.
Err(DbErrorWrite::NonRetriable(DbErrorWriteNonRetriable::AlreadyFinished)) => {
let found = get_execution_event(&tx, &execution_id, version_for_read)
.await
.map_err(|_| DbErrorStubResponse::StubConflict)?;
match found.event {
ExecutionRequest::Finished { retval, .. } if retval == expected_retval => None,
_ => return Err(DbErrorStubResponse::StubConflict),
}
}
// Early returns drop `tx` without committing.
Err(other) => return Err(DbErrorStubResponse::Write(other)),
};
tx.commit().await?;
// Release the connection before notifying subscribers.
drop(client_guard);
if let Some(notifiers) = notifiers {
self.notify_all(notifiers, current_time);
}
Ok(())
}
/// Returns the current execution-with-state snapshot of `execution_id`.
async fn get_pending_state(
    &self,
    execution_id: &ExecutionId,
) -> Result<ExecutionWithState, DbErrorRead> {
    let mut conn = self.client.lock().await;
    let transaction = conn.transaction().await?;
    let state = get_combined_state(&transaction, execution_id).await?;
    transaction.commit().await?;
    let snapshot = state.execution_with_state;
    Ok(snapshot)
}
}
#[async_trait]
impl DbExternalApi for PostgresConnection {
#[instrument(skip(self))]
async fn get_backtrace(
&self,
execution_id: &ExecutionId,
filter: BacktraceFilter,
) -> Result<BacktraceInfo, DbErrorRead> {
debug!("get_backtrace");
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
params.push(Box::new(execution_id.to_string())); let p_execution_id_idx = format!("${}", params.len());
let mut sql = String::new();
write!(
&mut sql,
"SELECT component_id, version_min_including, version_max_excluding, wasm_backtrace \
FROM t_execution_backtrace e INNER JOIN t_wasm_backtrace w ON e.backtrace_hash = w.backtrace_hash \
WHERE execution_id = {p_execution_id_idx}"
)
.unwrap();
match &filter {
BacktraceFilter::Specific(version) => {
params.push(Box::new(i64::from(version.0))); let p_ver_idx = format!("${}", params.len()); write!(
&mut sql,
" AND version_min_including <= {p_ver_idx} AND version_max_excluding > {p_ver_idx}"
)
.unwrap();
}
BacktraceFilter::First => {
sql.push_str(" ORDER BY version_min_including LIMIT 1");
}
BacktraceFilter::Last => {
sql.push_str(" ORDER BY version_min_including DESC LIMIT 1");
}
}
let params_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
params.iter().map(|p| p.as_ref() as _).collect();
let row = tx.query_one(&sql, ¶ms_refs).await?;
let component_id: Json<ComponentId> = get(&row, "component_id")?;
let component_id = component_id.0;
let version_min_including =
Version::try_from(get::<i64, _>(&row, "version_min_including")?)?;
let version_max_excluding =
Version::try_from(get::<i64, _>(&row, "version_max_excluding")?)?;
let wasm_backtrace: Json<WasmBacktrace> = get(&row, "wasm_backtrace")?;
let wasm_backtrace = wasm_backtrace.0;
tx.commit().await?;
Ok(BacktraceInfo {
execution_id: execution_id.clone(),
component_id,
version_min_including,
version_max_excluding,
wasm_backtrace,
})
}
#[instrument(skip_all)]
async fn upsert_source_file(
&self,
component_digest: &ComponentDigest,
frame_key: &str,
is_suffix: bool,
content: &str,
) -> Result<(), DbErrorWrite> {
let content_hash: [u8; 32] = Sha256::digest(content.as_bytes()).into();
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
tx.execute(
"INSERT INTO t_source_file (content_hash, content) \
VALUES ($1, $2) \
ON CONFLICT (content_hash) DO NOTHING",
&[&content_hash.as_slice(), &content],
)
.await?;
tx.execute(
"INSERT INTO t_component_source \
(component_digest, frame_key, is_suffix, content_hash) \
VALUES ($1, $2, $3, $4) \
ON CONFLICT (component_digest, frame_key, is_suffix) DO NOTHING",
&[
&component_digest.as_slice(),
&frame_key,
&is_suffix,
&content_hash.as_slice(),
],
)
.await?;
tx.commit().await?;
Ok(())
}
#[instrument(skip_all)]
async fn get_source_file(
&self,
component_digest: &ComponentDigest,
file: &str,
) -> Result<Option<String>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let rows = tx
.query(
"SELECT s.content \
FROM t_component_source cs \
JOIN t_source_file s ON cs.content_hash = s.content_hash \
WHERE cs.component_digest = $1 \
AND ( \
(NOT cs.is_suffix AND cs.frame_key = $2) \
OR (cs.is_suffix AND right($2, length(cs.frame_key)) = cs.frame_key) \
)",
&[&component_digest.as_slice(), &file],
)
.await?;
tx.commit().await?;
match rows.len() {
0 => Ok(None),
1 => Ok(Some(get::<String, _>(&rows[0], "content")?)),
_ => {
warn!("Multiple suffix matches for '{file}', returning None");
Ok(None)
}
}
}
#[instrument(skip_all)]
async fn upsert_component_metadata(
&self,
records: Vec<ComponentMetadataRecord>,
) -> Result<(), DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
for record in records {
tx.execute(
"INSERT INTO t_component_metadata \
(component_digest, imports_json, exports_json, wit, wit_origin) \
VALUES ($1, $2, $3, $4, $5) \
ON CONFLICT (component_digest) DO NOTHING",
&[
&record.component_digest.as_slice(),
&Json(&record.imports),
&Json(&record.exports),
&record.wit,
&record.wit_origin,
],
)
.await?;
}
tx.commit().await?;
Ok(())
}
#[instrument(skip_all)]
async fn insert_deployment_components(
&self,
deployment_id: DeploymentId,
records: Vec<DeploymentComponentRecord>,
) -> Result<(), DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
for record in records {
debug_assert_eq!(record.deployment_id, deployment_id);
tx.execute(
"INSERT INTO t_deployment_component \
(deployment_id, component_name, component_type, component_digest) \
VALUES ($1, $2, $3, $4) \
ON CONFLICT (deployment_id, component_name) DO NOTHING",
&[
&deployment_id.to_string(),
&record.component_name.to_string(),
&record.component_type.to_string(),
&record.component_digest.as_slice(),
],
)
.await?;
}
tx.commit().await?;
Ok(())
}
#[instrument(skip_all)]
async fn list_deployment_components(
&self,
deployment_id: DeploymentId,
) -> Result<Vec<DeploymentComponentDetail>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let rows = tx
.query(
"SELECT dc.component_name, dc.component_type, dc.component_digest, \
cm.imports_json, cm.exports_json, cm.wit \
FROM t_deployment_component dc \
JOIN t_component_metadata cm ON dc.component_digest = cm.component_digest \
WHERE dc.deployment_id = $1 \
ORDER BY dc.component_type, dc.component_name",
&[&deployment_id.to_string()],
)
.await?;
tx.commit().await?;
rows.iter()
.map(deployment_component_detail_from_pg_row)
.collect()
}
#[instrument(skip_all)]
async fn get_deployment_component_wit(
&self,
deployment_id: DeploymentId,
component_digest: &ComponentDigest,
) -> Result<Option<String>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let row = tx
.query_opt(
"SELECT cm.wit \
FROM t_deployment_component dc \
JOIN t_component_metadata cm ON dc.component_digest = cm.component_digest \
WHERE dc.deployment_id = $1 AND dc.component_digest = $2 \
LIMIT 1",
&[&deployment_id.to_string(), &component_digest.as_slice()],
)
.await?;
tx.commit().await?;
row.map(|row| get::<String, _>(&row, "wit").map_err(DbErrorRead::Generic))
.transpose()
}
#[instrument(skip(self))]
async fn list_executions(
&self,
filter: ListExecutionsFilter,
pagination: ExecutionListPagination,
) -> Result<Vec<ExecutionWithState>, DbErrorGeneric> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let result = list_executions(&tx, filter, &pagination).await?;
tx.commit().await?;
Ok(result)
}
#[instrument(skip(self))]
async fn list_execution_events(
&self,
execution_id: &ExecutionId,
pagination: Pagination<VersionType>,
include_backtrace_id: bool,
) -> Result<ListExecutionEventsResponse, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let events =
list_execution_events(&tx, execution_id, pagination, include_backtrace_id).await?;
let max_version = get_max_version(&tx, execution_id).await?;
tx.commit().await?;
Ok(ListExecutionEventsResponse {
events,
max_version,
})
}
#[instrument(skip(self))]
async fn list_responses(
&self,
execution_id: &ExecutionId,
pagination: Pagination<u32>,
) -> Result<ListResponsesResponse, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let responses = list_responses(&tx, execution_id, Some(pagination)).await?;
let max_cursor = get_max_response_cursor(&tx, execution_id).await?;
tx.commit().await?;
Ok(ListResponsesResponse {
responses,
max_cursor,
})
}
#[instrument(skip(self))]
async fn list_execution_events_responses(
&self,
execution_id: &ExecutionId,
req_since: &Version,
req_max_length: VersionType,
req_include_backtrace_id: bool,
resp_pagination: Pagination<u32>,
) -> Result<ExecutionWithStateRequestsResponses, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let combined_state = get_combined_state(&tx, execution_id).await?;
let events = list_execution_events(
&tx,
execution_id,
Pagination::NewerThan {
length: req_max_length
.try_into()
.expect("req_max_length fits in u16"),
cursor: req_since.0,
including_cursor: true,
},
req_include_backtrace_id,
)
.await?;
let responses = list_responses(&tx, execution_id, Some(resp_pagination)).await?;
let max_version = get_max_version(&tx, execution_id).await?;
let max_cursor = get_max_response_cursor(&tx, execution_id).await?;
tx.commit().await?;
Ok(ExecutionWithStateRequestsResponses {
execution_with_state: combined_state.execution_with_state,
events,
responses,
max_version,
max_cursor,
})
}
#[instrument(skip(self))]
async fn upgrade_execution_component(
&self,
execution_id: &ExecutionId,
old: &ComponentDigest,
new: &ComponentDigest,
reason: ComponentUpgradeReason,
) -> Result<(), DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
upgrade_execution_component(&tx, execution_id, old, new, reason).await?;
tx.commit().await?;
Ok(())
}
#[instrument(skip(self))]
async fn list_logs(
&self,
execution_id: &ExecutionId,
show_derived: bool,
filter: LogFilter,
pagination: Pagination<DateTime<Utc>>,
) -> Result<ListLogsResponse, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let responses = list_logs_tx(&tx, execution_id, show_derived, &filter, &pagination).await?;
tx.commit().await?;
Ok(responses)
}
#[instrument(skip(self))]
async fn list_deployment_states(
&self,
current_time: DateTime<Utc>,
pagination: Pagination<Option<DeploymentId>>,
include_config_json: bool,
) -> Result<Vec<DeploymentState>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let deployments =
list_deployment_states(&tx, current_time, pagination, include_config_json).await?;
tx.commit().await?;
Ok(deployments)
}
#[instrument(skip(self))]
async fn insert_deployment(&self, record: DeploymentRecord) -> Result<(), DbErrorWrite> {
assert_eq!(
record.status,
DeploymentStatus::Inactive,
"insert_deployment requires Inactive status"
);
assert!(
record.last_active_at.is_none(),
"insert_deployment requires last_active_at == None"
);
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
tx.execute(
"INSERT INTO t_deployment \
(deployment_id, created_at, status, config_json, obelisk_version, created_by) \
VALUES ($1, $2, $3, $4, $5, $6)",
&[
&record.deployment_id.to_string(), &record.created_at, &record.status.as_str(), &record.config_json, &record.obelisk_version, &record.created_by, ],
)
.await?;
tx.commit().await?;
Ok(())
}
#[instrument(skip(self))]
async fn activate_deployment(
&self,
deployment_id: DeploymentId,
now: DateTime<Utc>,
) -> Result<(), DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
tx.execute(
"UPDATE t_deployment SET status = 'inactive' WHERE status IN ('active', 'enqueued')",
&[],
)
.await?;
let rows = tx
.execute(
"UPDATE t_deployment SET status = 'active', last_active_at = $1 WHERE deployment_id = $2",
&[&now, &deployment_id.to_string()],
)
.await?;
tx.commit().await?;
if rows == 0 {
return Err(DbErrorWrite::NotFound);
}
Ok(())
}
async fn enqueue_deployment(&self, deployment_id: DeploymentId) -> Result<(), DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let status_opt = tx
.query_opt(
"SELECT status FROM t_deployment WHERE deployment_id = $1",
&[&deployment_id.to_string()],
)
.await?;
match status_opt.as_ref().map(|r| r.get::<_, &str>("status")) {
None => return Err(DbErrorWrite::NotFound),
Some("active") => return Err(DbErrorWriteNonRetriable::Conflict.into()),
_ => {}
}
tx.execute(
"UPDATE t_deployment SET status = 'inactive' WHERE status = 'enqueued'",
&[],
)
.await?;
let rows = tx
.execute(
"UPDATE t_deployment SET status = 'enqueued' WHERE deployment_id = $1",
&[&deployment_id.to_string()],
)
.await?;
tx.commit().await?;
if rows == 0 {
return Err(DbErrorWrite::NotFound);
}
Ok(())
}
#[instrument(skip(self))]
async fn get_deployment(
&self,
deployment_id: DeploymentId,
) -> Result<Option<DeploymentRecord>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let row = tx
.query_opt(
"SELECT deployment_id, created_at, last_active_at, status, config_json, obelisk_version, created_by \
FROM t_deployment WHERE deployment_id = $1",
&[&deployment_id.to_string()],
)
.await?;
tx.commit().await?;
match row {
None => Ok(None),
Some(r) => Ok(Some(deployment_record_from_pg_row(&r)?)),
}
}
#[instrument(skip(self))]
#[cfg(feature = "test")]
async fn get_active_deployment(&self) -> Result<Option<DeploymentRecord>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let row = tx
.query_opt(
"SELECT deployment_id, created_at, last_active_at, status, config_json, obelisk_version, created_by \
FROM t_deployment WHERE status = 'active' LIMIT 1",
&[],
)
.await?;
tx.commit().await?;
match row {
None => Ok(None),
Some(r) => Ok(Some(deployment_record_from_pg_row(&r)?)),
}
}
async fn get_current_deployment(&self) -> Result<Option<DeploymentRecord>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let row = tx
.query_opt(
"SELECT deployment_id, created_at, last_active_at, status, config_json, obelisk_version, created_by \
FROM t_deployment WHERE status IN ('enqueued', 'active') \
ORDER BY CASE status WHEN 'enqueued' THEN 0 ELSE 1 END LIMIT 1",
&[],
)
.await?;
tx.commit().await?;
match row {
None => Ok(None),
Some(r) => Ok(Some(deployment_record_from_pg_row(&r)?)),
}
}
#[instrument(skip(self))]
async fn list_deployments(
&self,
pagination: Pagination<Option<DeploymentId>>,
) -> Result<Vec<DeploymentRecord>, DbErrorRead> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
let mut add_param = |p: Box<dyn tokio_postgres::types::ToSql + Sync + Send>| {
params.push(p);
format!("${}", params.len())
};
let mut sql = String::from(
"SELECT deployment_id, created_at, last_active_at, status, config_json, obelisk_version, created_by \
FROM t_deployment",
);
if let Some(cursor) = pagination.cursor() {
let p_cursor = add_param(Box::new(cursor.to_string()));
write!(
sql,
" WHERE deployment_id {rel} {p_cursor}",
rel = pagination.rel()
)
.expect("writing to string");
}
let (inner_order, outer_order) = if pagination.is_desc() {
("DESC", "")
} else {
("ASC", "DESC")
};
write!(
sql,
" ORDER BY deployment_id {inner_order} LIMIT {limit}",
limit = pagination.length()
)
.expect("writing to string");
let final_sql = if outer_order.is_empty() {
sql
} else {
format!("SELECT * FROM ({sql}) AS sub ORDER BY deployment_id {outer_order}")
};
let params_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
params.iter().map(|p| p.as_ref() as _).collect();
let rows = tx.query(&final_sql, ¶ms_refs).await?;
tx.commit().await?;
rows.iter()
.map(deployment_record_from_pg_row)
.collect::<Result<Vec<_>, _>>()
}
#[instrument(skip(self))]
async fn pause_execution(
&self,
execution_id: &ExecutionId,
paused_at: DateTime<Utc>,
) -> Result<AppendResponse, DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let combined_state = get_combined_state(&tx, execution_id).await?;
let appending_version = combined_state.get_next_version_fail_if_finished()?;
debug!("Pausing with {appending_version}");
let next_version = if matches!(
combined_state.execution_with_state.pending_state,
PendingState::Locked(_)
) {
let (next_version, _notifier) = append(
&tx,
execution_id,
AppendRequest {
created_at: paused_at,
event: ExecutionRequest::Unlocked(Unlocked {
backoff_expires_at: paused_at,
reason: "paused".into(),
}),
},
appending_version,
)
.await?;
next_version
} else {
appending_version
};
let (next_version, _notifier) = append(
&tx,
execution_id,
AppendRequest {
created_at: paused_at,
event: ExecutionRequest::Paused,
},
next_version,
)
.await?;
tx.commit().await?;
Ok(next_version)
}
#[instrument(skip(self))]
async fn unpause_execution(
&self,
execution_id: &ExecutionId,
unpaused_at: DateTime<Utc>,
) -> Result<AppendResponse, DbErrorWrite> {
let mut client_guard = self.client.lock().await;
let tx = client_guard.transaction().await?;
let combined_state = get_combined_state(&tx, execution_id).await?;
let appending_version = combined_state.get_next_version_fail_if_finished()?;
debug!("Unpausing with {appending_version}");
let (next_version, _) = append(
&tx,
execution_id,
AppendRequest {
created_at: unpaused_at,
event: ExecutionRequest::Unpaused,
},
appending_version,
)
.await?;
tx.commit().await?;
Ok(next_version)
}
#[instrument(skip(self))]
async fn pause_delay(&self, delay_id: &DelayId) -> Result<(), DbErrorWrite> {
let (execution_id, join_set_id) = delay_id.split_to_parts();
let client_guard = self.client.lock().await;
let rows_modified = client_guard
.execute(
"UPDATE t_delay SET is_paused = TRUE \
WHERE execution_id = $1 AND join_set_id = $2 AND delay_id = $3",
&[
&execution_id.to_string(),
&join_set_id.to_string(),
&delay_id.to_string(),
],
)
.await?;
if rows_modified == 0 {
return Err(DbErrorWrite::NotFound);
}
Ok(())
}
#[instrument(skip(self))]
async fn unpause_delay(&self, delay_id: &DelayId) -> Result<(), DbErrorWrite> {
let (execution_id, join_set_id) = delay_id.split_to_parts();
let client_guard = self.client.lock().await;
let rows_modified = client_guard
.execute(
"UPDATE t_delay SET is_paused = FALSE \
WHERE execution_id = $1 AND join_set_id = $2 AND delay_id = $3",
&[
&execution_id.to_string(),
&join_set_id.to_string(),
&delay_id.to_string(),
],
)
.await?;
if rows_modified == 0 {
return Err(DbErrorWrite::NotFound);
}
Ok(())
}
}
/// Closing a `PostgresPool` closes its underlying deadpool connection pool.
#[async_trait]
impl DbPoolCloseable for PostgresPool {
    async fn close(&self) {
        let pool = &self.pool;
        pool.close();
    }
}
#[cfg(feature = "test")]
#[async_trait]
impl concepts::storage::DbConnectionTest for PostgresConnection {
    /// Test helper: appends `response_event` to `execution_id` in its own
    /// transaction, then fires the returned notifier.
    ///
    /// The connection lock is released before notifying.
    #[instrument(level = Level::DEBUG, skip(self, response_event), fields(join_set_id = %response_event.join_set_id))]
    async fn append_response(
        &self,
        created_at: DateTime<Utc>,
        execution_id: ExecutionId,
        response_event: JoinSetResponseEvent,
    ) -> Result<(), DbErrorWrite> {
        debug!("append_response");
        // The scope ends (dropping the connection guard) before notifying.
        let notifier = {
            let mut conn = self.client.lock().await;
            let transaction = conn.transaction().await?;
            let notifier = append_response(
                &transaction,
                &execution_id,
                JoinSetResponseEventOuter {
                    created_at,
                    event: response_event,
                },
            )
            .await?;
            transaction.commit().await?;
            notifier
        };
        self.notify_all(vec![notifier], created_at);
        Ok(())
    }
}
#[cfg(feature = "test")]
impl PostgresPool {
    /// Test helper: drops this pool's database.
    ///
    /// Connects to the admin database (`ADMIN_DB_NAME`) using the same host,
    /// user and password. Makes up to three attempts with no wait in between.
    /// Failures are logged at `debug`, and a `warn` is emitted if none succeeded.
    ///
    /// # Panics
    /// Panics if the admin pool cannot be created or no connection can be
    /// obtained from it; the underlying error is logged first.
    pub async fn drop_database(&self) {
        let admin_pool = {
            let mut cfg = deadpool_postgres::Config::new();
            cfg.host = Some(self.config.host.clone());
            cfg.user = Some(self.config.user.clone());
            cfg.password = Some(self.config.password.expose_secret().to_string());
            cfg.dbname = Some(ADMIN_DB_NAME.into());
            cfg.manager = Some(ManagerConfig {
                recycling_method: RecyclingMethod::Fast,
            });
            cfg.create_pool(None, NoTls)
                .map_err(|err| {
                    error!("Cannot create the default pool - {err:?}");
                    InitializationError
                })
                .unwrap()
        };
        let admin_client = admin_pool
            .get()
            .await
            .map_err(|err| {
                error!("Cannot get a connection from the default pool - {err:?}");
                InitializationError
            })
            .unwrap();
        let drop_statement = format!("DROP DATABASE {}", self.config.db_name);
        for _attempt in 0..3 {
            let res = admin_client.execute(&drop_statement, &[]).await;
            if res.is_err() {
                debug!("Dropping db failed - {res:?}",);
                continue;
            }
            debug!("Database '{}' dropped.", self.config.db_name);
            return;
        }
        warn!("Did not drop database {}", self.config.db_name);
    }
}