use std::collections::BTreeSet;
use std::time::{SystemTime, UNIX_EPOCH};
use ff_core::caps::{matches as caps_matches, CapabilityRequirement};
use ff_core::contracts::ClaimGrant;
use ff_core::engine_error::{EngineError, ValidationKind};
use ff_core::partition::{Partition, PartitionFamily, PartitionKey};
use ff_core::types::{ExecutionId, LaneId, WorkerId, WorkerInstanceId};
use serde_json::Value as JsonValue;
use sqlx::{PgPool, Row};
use uuid::Uuid;
use crate::error::map_sqlx_error;
use crate::signal::{current_active_kid, hmac_sign};
/// Upper bound on eligible rows fetched (and row-locked) per partition in one
/// claim attempt; bound as `LIMIT $3` of the candidate query. Rows whose
/// capability requirements the worker does not meet are skipped, so a matching
/// execution ranked below this many non-matching ones is not seen in that attempt.
const ELIGIBLE_OVERFETCH: i64 = 10;
/// Postgres-backed claim scheduler: hands out HMAC-signed claim grants for
/// runnable executions stored in `ff_exec_core`.
pub struct PostgresScheduler {
/// Pool from which every scheduler transaction is opened.
pool: PgPool,
}
impl PostgresScheduler {
    /// Creates a scheduler that issues claim grants using connections from `pool`.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Claims one eligible execution on `lane` for the given worker, if any exists.
    ///
    /// Partitions `0..256` are tried in ascending order, each in its own
    /// transaction; the first successful claim is returned. Returns `Ok(None)`
    /// when no partition yields a claimable execution.
    ///
    /// # Errors
    ///
    /// - [`EngineError::Unavailable`] if the HMAC keystore has no active kid.
    /// - Database errors (mapped through `map_sqlx_error`), or a
    ///   [`ValidationKind::Corruption`] validation error if a claimed execution id
    ///   cannot be reassembled.
    pub async fn claim_for_worker(
        &self,
        lane: &LaneId,
        worker_id: &WorkerId,
        worker_instance_id: &WorkerInstanceId,
        worker_capabilities: &BTreeSet<String>,
        grant_ttl_ms: u64,
    ) -> Result<Option<ClaimGrant>, EngineError> {
        const TOTAL_PARTITIONS: i16 = 256;

        let Some((kid, secret)) = current_active_kid(&self.pool).await? else {
            return Err(EngineError::Unavailable {
                op: "claim_for_worker: ff_waitpoint_hmac keystore empty",
            });
        };

        for part in 0..TOTAL_PARTITIONS {
            let claimed = self
                .try_claim_in_partition(
                    part,
                    lane,
                    worker_id,
                    worker_instance_id,
                    worker_capabilities,
                    grant_ttl_ms,
                    &kid,
                    &secret,
                )
                .await?;
            if claimed.is_some() {
                return Ok(claimed);
            }
        }
        Ok(None)
    }

    /// Attempts a claim inside partition `part` within a single transaction.
    ///
    /// 1. Locks up to [`ELIGIBLE_OVERFETCH`] runnable, `eligible_now` rows on
    ///    `lane` (highest priority first, then oldest), skipping rows already
    ///    locked by concurrent claimers.
    /// 2. Picks the first of those rows whose required capabilities the worker
    ///    satisfies.
    /// 3. Admits every budget listed in the row's comma-separated `budget_ids`
    ///    field; any denial rolls back and yields `Ok(None)` for this partition.
    /// 4. Signs a grant, merges it into `raw_fields.claim_grant`, marks the row
    ///    `pending_claim`, and commits.
    // Internal helper: the arguments are the caller's inputs passed straight
    // through; bundling them into a struct would add noise without benefit.
    #[allow(clippy::too_many_arguments)]
    async fn try_claim_in_partition(
        &self,
        part: i16,
        lane: &LaneId,
        worker_id: &WorkerId,
        worker_instance_id: &WorkerInstanceId,
        worker_capabilities: &BTreeSet<String>,
        grant_ttl_ms: u64,
        kid: &str,
        secret: &[u8],
    ) -> Result<Option<ClaimGrant>, EngineError> {
        let mut tx = self.pool.begin().await.map_err(map_sqlx_error)?;
        let rows = sqlx::query(
            r#"
SELECT execution_id, required_capabilities, raw_fields
FROM ff_exec_core
WHERE partition_key = $1
AND lane_id = $2
AND lifecycle_phase = 'runnable'
AND eligibility_state = 'eligible_now'
ORDER BY priority DESC, created_at_ms ASC
FOR UPDATE SKIP LOCKED
LIMIT $3
"#,
        )
        .bind(part)
        .bind(lane.as_str())
        .bind(ELIGIBLE_OVERFETCH)
        .fetch_all(&mut *tx)
        .await
        .map_err(map_sqlx_error)?;

        let Some((exec_uuid, raw_fields)) =
            Self::pick_capable_candidate(&rows, worker_capabilities)?
        else {
            tx.rollback().await.map_err(map_sqlx_error)?;
            return Ok(None);
        };

        // `budget_ids` is stored as a comma-separated string inside `raw_fields`.
        let budget_ids: Vec<String> = raw_fields
            .get("budget_ids")
            .and_then(JsonValue::as_str)
            .map(|s| {
                s.split(',')
                    .map(str::trim)
                    .filter(|s| !s.is_empty())
                    .map(str::to_owned)
                    .collect()
            })
            .unwrap_or_default();
        for bid in &budget_ids {
            if !admit_budget(&mut tx, bid).await? {
                tx.rollback().await.map_err(map_sqlx_error)?;
                return Ok(None);
            }
        }
        // Quota admission is not enforced here; presumably no quota schema exists
        // in this backend yet — TODO confirm.

        let now = now_ms();
        // `saturating_add_unsigned` clamps at `i64::MAX` for oversized TTLs.
        let expires_at_ms = now.saturating_add_unsigned(grant_ttl_ms);
        let partition = Partition {
            family: PartitionFamily::Execution,
            // `part` ranges over 0..256, so this cast is lossless.
            index: part as u16,
        };
        let hash_tag = partition.hash_tag();
        let message = format!(
            "{hash_tag}|{exec_uuid}|{wid}|{wiid}|{exp}",
            wid = worker_id.as_str(),
            wiid = worker_instance_id.as_str(),
            exp = expires_at_ms,
        );
        let sig = hmac_sign(secret, kid, message.as_bytes());
        let grant_key = format!("pg:{hash_tag}:{exec_uuid}:{expires_at_ms}:{sig}");

        // Build the public execution id BEFORE writing anything. If it fails,
        // the early return drops `tx` uncommitted (rolled back). The row is then
        // not left in `pending_claim` holding a grant that no worker received.
        let eid = ExecutionId::parse(&format!("{{fp:{part}}}:{exec_uuid}")).map_err(|e| {
            EngineError::Validation {
                kind: ValidationKind::Corruption,
                detail: format!("scheduler: reassembling exec id: {e}"),
            }
        })?;

        let grant_patch = serde_json::json!({
            "claim_grant": {
                "grant_key": grant_key,
                "worker_id": worker_id.as_str(),
                "worker_instance_id": worker_instance_id.as_str(),
                "expires_at_ms": expires_at_ms,
                "issued_at_ms": now,
                "kid": kid,
            }
        });
        sqlx::query(
            r#"
UPDATE ff_exec_core
SET raw_fields = raw_fields || $1::jsonb,
eligibility_state = 'pending_claim'
WHERE partition_key = $2 AND execution_id = $3
"#,
        )
        .bind(grant_patch)
        .bind(part)
        .bind(exec_uuid)
        .execute(&mut *tx)
        .await
        .map_err(map_sqlx_error)?;
        tx.commit().await.map_err(map_sqlx_error)?;

        Ok(Some(ClaimGrant {
            execution_id: eid,
            partition_key: PartitionKey::from(&partition),
            grant_key,
            // Non-negative: `now_ms()` never returns a negative value and the TTL only adds.
            expires_at_ms: expires_at_ms as u64,
        }))
    }

    /// Returns `(execution_id, raw_fields)` of the first row, in query order,
    /// whose `required_capabilities` are satisfied by `worker_capabilities`.
    ///
    /// The worker's capability set is built once for all rows instead of once
    /// per row. This is a synchronous helper, so the set is never held across
    /// an `.await`.
    fn pick_capable_candidate(
        rows: &[sqlx::postgres::PgRow],
        worker_capabilities: &BTreeSet<String>,
    ) -> Result<Option<(Uuid, JsonValue)>, EngineError> {
        if rows.is_empty() {
            return Ok(None);
        }
        let worker_set =
            ff_core::backend::CapabilitySet::new(worker_capabilities.iter().cloned());
        for row in rows {
            let required: Vec<String> = row
                .try_get("required_capabilities")
                .map_err(map_sqlx_error)?;
            if !caps_matches(&CapabilityRequirement::new(required), &worker_set) {
                continue;
            }
            let eid: Uuid = row.try_get("execution_id").map_err(map_sqlx_error)?;
            let raw: JsonValue = row.try_get("raw_fields").map_err(map_sqlx_error)?;
            return Ok(Some((eid, raw)));
        }
        Ok(None)
    }
}
pub async fn verify_grant(pool: &PgPool, grant: &ClaimGrant) -> Result<(), GrantVerifyError> {
let s = grant.grant_key.as_str();
let rest = s.strip_prefix("pg:").ok_or(GrantVerifyError::Malformed)?;
let mut parts: Vec<&str> = rest.rsplitn(4, ':').collect(); if parts.len() != 4 {
return Err(GrantVerifyError::Malformed);
}
let hex_part = parts.remove(0);
let kid = parts.remove(0);
let expires_str = parts.remove(0);
let left = parts.remove(0); let expires_at_ms: i64 = expires_str.parse().map_err(|_| GrantVerifyError::Malformed)?;
if expires_at_ms <= now_ms() {
return Err(GrantVerifyError::Expired);
}
let close = left.find("}:").ok_or(GrantVerifyError::Malformed)?;
let hash_tag = &left[..=close]; let uuid_str = &left[close + 2..];
let secret = crate::signal::fetch_kid(pool, kid)
.await
.map_err(|_| GrantVerifyError::Transport)?
.ok_or(GrantVerifyError::UnknownKid)?;
let wid_wiid = read_grant_identity(pool, grant).await?;
let message = format!(
"{hash_tag}|{uuid_str}|{wid}|{wiid}|{expires_at_ms}",
wid = wid_wiid.0,
wiid = wid_wiid.1,
);
let token = format!("{kid}:{hex_part}");
crate::signal::hmac_verify(&secret, kid, message.as_bytes(), &token)
.map_err(|_| GrantVerifyError::SignatureMismatch)?;
Ok(())
}
/// Loads the worker identity recorded in an execution's `claim_grant`.
///
/// The `ff_exec_core` row is addressed by the grant's partition index and the
/// UUID portion of its execution id (the text after the first `"}:"`). Returns
/// `(worker_id, worker_instance_id)` read from `raw_fields.claim_grant`.
///
/// # Errors
///
/// - [`GrantVerifyError::Malformed`] if the partition or execution id can't be
///   parsed, or either identity field is missing or not a string.
/// - [`GrantVerifyError::UnknownGrant`] if the row or its `claim_grant` is absent.
/// - [`GrantVerifyError::Transport`] if the query fails or `raw_fields` can't be decoded.
async fn read_grant_identity(
    pool: &PgPool,
    grant: &ClaimGrant,
) -> Result<(String, String), GrantVerifyError> {
    let part = grant
        .partition()
        .map(|p| p.index as i16)
        .map_err(|_| GrantVerifyError::Malformed)?;

    let exec_uuid = match grant.execution_id.as_str().split_once("}:") {
        Some((_, tail)) => Uuid::parse_str(tail).map_err(|_| GrantVerifyError::Malformed)?,
        None => return Err(GrantVerifyError::Malformed),
    };

    let maybe_row = sqlx::query(
        "SELECT raw_fields FROM ff_exec_core WHERE partition_key = $1 AND execution_id = $2",
    )
    .bind(part)
    .bind(exec_uuid)
    .fetch_optional(pool)
    .await
    .map_err(|_| GrantVerifyError::Transport)?;
    let Some(row) = maybe_row else {
        return Err(GrantVerifyError::UnknownGrant);
    };

    let raw_fields: JsonValue = row
        .try_get("raw_fields")
        .map_err(|_| GrantVerifyError::Transport)?;
    let Some(claim) = raw_fields.get("claim_grant") else {
        return Err(GrantVerifyError::UnknownGrant);
    };

    // Extracts a required string field from the stored claim.
    let string_field = |key: &str| -> Result<String, GrantVerifyError> {
        claim
            .get(key)
            .and_then(JsonValue::as_str)
            .map(str::to_owned)
            .ok_or(GrantVerifyError::Malformed)
    };
    let worker = string_field("worker_id")?;
    let instance = string_field("worker_instance_id")?;
    Ok((worker, instance))
}
#[derive(Debug, thiserror::Error)]
pub enum GrantVerifyError {
#[error("grant_key malformed")]
Malformed,
#[error("grant expired")]
Expired,
#[error("unknown kid in grant")]
UnknownKid,
#[error("unknown grant — no row with matching claim_grant in exec_core")]
UnknownGrant,
#[error("signature verification failed")]
SignatureMismatch,
#[error("transport error while verifying grant")]
Transport,
}
/// Decides whether a claim may proceed under budget `budget_id`.
///
/// Reads the budget's policy and its current usage from the budget's partition,
/// both with `FOR SHARE` inside the caller's transaction. Returns `Ok(true)` if:
///
/// - there is no policy,
/// - the policy has no recognizable hard limit, or
/// - usage for the policy's dimension is strictly below the hard limit.
///
/// Otherwise returns `Ok(false)`.
///
/// # Errors
///
/// Database errors, mapped through `map_sqlx_error`.
async fn admit_budget(
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    budget_id: &str,
) -> Result<bool, EngineError> {
    // Budget ids that fail to parse fall back to partition 0.
    let partition_key: i16 = ff_core::types::BudgetId::parse(budget_id).map_or(0, |parsed| {
        let config = ff_core::partition::PartitionConfig::default();
        ff_core::partition::budget_partition(&parsed, &config).index as i16
    });

    let maybe_policy: Option<JsonValue> = sqlx::query_scalar(
        r#"
SELECT policy_json FROM ff_budget_policy
WHERE partition_key = $1 AND budget_id = $2
FOR SHARE
"#,
    )
    .bind(partition_key)
    .bind(budget_id)
    .fetch_optional(&mut **tx)
    .await
    .map_err(map_sqlx_error)?;
    let Some(policy) = maybe_policy else {
        // No policy row: nothing limits this budget.
        return Ok(true);
    };

    // The hard limit comes from a top-level `hard_limit`, falling back to the
    // first value of a `hard` object.
    // NOTE(review): if `hard` is keyed by dimension, its first value may not be the
    // limit for `dimension` below — confirm the policy schema.
    let top_level_limit = policy.get("hard_limit").and_then(JsonValue::as_u64);
    let hard_limit = match top_level_limit {
        Some(limit) => Some(limit),
        None => policy
            .get("hard")
            .and_then(JsonValue::as_object)
            .and_then(|limits| limits.values().next())
            .and_then(JsonValue::as_u64),
    };
    let Some(hard_limit) = hard_limit else {
        return Ok(true);
    };

    let dimension: String = policy
        .get("dimension")
        .and_then(JsonValue::as_str)
        .unwrap_or("default")
        .to_owned();

    let usage: Option<i64> = sqlx::query_scalar(
        r#"
SELECT current_value FROM ff_budget_usage
WHERE partition_key = $1 AND budget_id = $2 AND dimensions_key = $3
FOR SHARE
"#,
    )
    .bind(partition_key)
    .bind(budget_id)
    .bind(&dimension)
    .fetch_optional(&mut **tx)
    .await
    .map_err(map_sqlx_error)?;

    // A missing usage row, or a negative value, counts as zero usage.
    let used = u64::try_from(usage.unwrap_or(0)).unwrap_or(0);
    Ok(used < hard_limit)
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, and saturates at
/// `i64::MAX` if the millisecond count does not fit in an `i64`.
fn now_ms() -> i64 {
    let millis: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis(),
        Err(_) => 0,
    };
    i64::try_from(millis).unwrap_or(i64::MAX)
}