1use std::collections::BTreeSet;
60use std::time::{SystemTime, UNIX_EPOCH};
61
62use ff_core::caps::{matches as caps_matches, CapabilityRequirement};
63use ff_core::contracts::ClaimGrant;
64use ff_core::engine_error::{EngineError, ValidationKind};
65use ff_core::partition::{Partition, PartitionFamily, PartitionKey};
66use ff_core::types::{ExecutionId, LaneId, WorkerId, WorkerInstanceId};
67use serde_json::Value as JsonValue;
68use sqlx::{PgPool, Row};
69use uuid::Uuid;
70
71use crate::error::map_sqlx_error;
72use crate::signal::{current_active_kid, hmac_sign};
73
/// Maximum number of candidate rows fetched per partition probe.
///
/// Bound as the `LIMIT` of the eligibility query in
/// `try_claim_in_partition`; fetching more than one row lets the capability
/// filter skip candidates the worker cannot run.
const ELIGIBLE_OVERFETCH: i64 = 10;
78
/// Claim scheduler that issues signed claim grants out of `ff_exec_core`.
pub struct PostgresScheduler {
    /// Connection pool used for the keystore lookup and for every claim
    /// transaction.
    pool: PgPool,
}
84
85impl PostgresScheduler {
86 pub fn new(pool: PgPool) -> Self {
87 Self { pool }
88 }
89
90 pub async fn claim_for_worker(
94 &self,
95 lane: &LaneId,
96 worker_id: &WorkerId,
97 worker_instance_id: &WorkerInstanceId,
98 worker_capabilities: &BTreeSet<String>,
99 grant_ttl_ms: u64,
100 ) -> Result<Option<ClaimGrant>, EngineError> {
101 let (kid, secret) = match current_active_kid(&self.pool).await? {
105 Some(v) => v,
106 None => {
107 return Err(EngineError::Unavailable {
108 op: "claim_for_worker: ff_waitpoint_hmac keystore empty",
109 });
110 }
111 };
112
113 const TOTAL_PARTITIONS: i16 = 256;
118 for part in 0..TOTAL_PARTITIONS {
119 if let Some(grant) = self
120 .try_claim_in_partition(
121 part,
122 lane,
123 worker_id,
124 worker_instance_id,
125 worker_capabilities,
126 grant_ttl_ms,
127 &kid,
128 &secret,
129 )
130 .await?
131 {
132 return Ok(Some(grant));
133 }
134 }
135 Ok(None)
136 }
137
138 #[allow(clippy::too_many_arguments)]
139 async fn try_claim_in_partition(
140 &self,
141 part: i16,
142 lane: &LaneId,
143 worker_id: &WorkerId,
144 worker_instance_id: &WorkerInstanceId,
145 worker_capabilities: &BTreeSet<String>,
146 grant_ttl_ms: u64,
147 kid: &str,
148 secret: &[u8],
149 ) -> Result<Option<ClaimGrant>, EngineError> {
150 let mut tx = self.pool.begin().await.map_err(map_sqlx_error)?;
151
152 let rows = sqlx::query(
154 r#"
155 SELECT execution_id, required_capabilities, raw_fields
156 FROM ff_exec_core
157 WHERE partition_key = $1
158 AND lane_id = $2
159 AND lifecycle_phase = 'runnable'
160 AND eligibility_state = 'eligible_now'
161 ORDER BY priority DESC, created_at_ms ASC
162 FOR UPDATE SKIP LOCKED
163 LIMIT $3
164 "#,
165 )
166 .bind(part)
167 .bind(lane.as_str())
168 .bind(ELIGIBLE_OVERFETCH)
169 .fetch_all(&mut *tx)
170 .await
171 .map_err(map_sqlx_error)?;
172
173 if rows.is_empty() {
174 tx.rollback().await.map_err(map_sqlx_error)?;
175 return Ok(None);
176 }
177
178 let mut picked: Option<(Uuid, JsonValue)> = None;
180 for row in &rows {
181 let required: Vec<String> = row
182 .try_get::<Vec<String>, _>("required_capabilities")
183 .map_err(map_sqlx_error)?;
184 let req = CapabilityRequirement::new(required);
185 let worker_set = ff_core::backend::CapabilitySet::new(worker_capabilities.iter().cloned());
186 if !caps_matches(&req, &worker_set) {
187 continue;
188 }
189 let eid: Uuid = row.try_get("execution_id").map_err(map_sqlx_error)?;
190 let raw: JsonValue = row.try_get("raw_fields").map_err(map_sqlx_error)?;
191 picked = Some((eid, raw));
192 break;
193 }
194 let Some((exec_uuid, raw_fields)) = picked else {
195 tx.rollback().await.map_err(map_sqlx_error)?;
196 return Ok(None);
197 };
198
199 let budget_ids: Vec<String> = raw_fields
204 .get("budget_ids")
205 .and_then(JsonValue::as_str)
206 .map(|s| {
207 s.split(',')
208 .map(str::trim)
209 .filter(|s| !s.is_empty())
210 .map(str::to_owned)
211 .collect()
212 })
213 .unwrap_or_default();
214
215 for bid in &budget_ids {
216 if !admit_budget(&mut tx, bid).await? {
217 tx.rollback().await.map_err(map_sqlx_error)?;
220 return Ok(None);
221 }
222 }
223
224 let _quota_skipped_no_schema = ();
231
232 let now = now_ms();
234 let expires_at_ms = now.saturating_add_unsigned(grant_ttl_ms.min(i64::MAX as u64));
235
236 let partition = Partition {
239 family: PartitionFamily::Execution,
240 index: part as u16,
241 };
242 let hash_tag = partition.hash_tag();
243 let message = format!(
244 "{hash_tag}|{exec_uuid}|{wid}|{wiid}|{exp}",
245 wid = worker_id.as_str(),
246 wiid = worker_instance_id.as_str(),
247 exp = expires_at_ms,
248 );
249 let sig = hmac_sign(secret, kid, message.as_bytes());
250 let grant_key = format!("pg:{hash_tag}:{exec_uuid}:{expires_at_ms}:{sig}");
251
252 let grant_patch = serde_json::json!({
256 "claim_grant": {
257 "grant_key": grant_key,
258 "worker_id": worker_id.as_str(),
259 "worker_instance_id": worker_instance_id.as_str(),
260 "expires_at_ms": expires_at_ms,
261 "issued_at_ms": now,
262 "kid": kid,
263 }
264 });
265 sqlx::query(
266 r#"
267 UPDATE ff_exec_core
268 SET raw_fields = raw_fields || $1::jsonb,
269 eligibility_state = 'pending_claim'
270 WHERE partition_key = $2 AND execution_id = $3
271 "#,
272 )
273 .bind(grant_patch)
274 .bind(part)
275 .bind(exec_uuid)
276 .execute(&mut *tx)
277 .await
278 .map_err(map_sqlx_error)?;
279
280 tx.commit().await.map_err(map_sqlx_error)?;
281
282 let eid = ExecutionId::parse(&format!("{{fp:{part}}}:{exec_uuid}")).map_err(|e| {
284 EngineError::Validation {
285 kind: ValidationKind::Corruption,
286 detail: format!("scheduler: reassembling exec id: {e}"),
287 }
288 })?;
289 Ok(Some(ClaimGrant {
290 execution_id: eid,
291 partition_key: PartitionKey::from(&partition),
292 grant_key,
293 expires_at_ms: expires_at_ms as u64,
294 }))
295 }
296}
297
298pub async fn verify_grant(pool: &PgPool, grant: &ClaimGrant) -> Result<(), GrantVerifyError> {
306 let s = grant.grant_key.as_str();
308 let rest = s.strip_prefix("pg:").ok_or(GrantVerifyError::Malformed)?;
309 let mut parts: Vec<&str> = rest.rsplitn(4, ':').collect(); if parts.len() != 4 {
317 return Err(GrantVerifyError::Malformed);
318 }
319 let hex_part = parts.remove(0);
320 let kid = parts.remove(0);
321 let expires_str = parts.remove(0);
322 let left = parts.remove(0); let expires_at_ms: i64 = expires_str.parse().map_err(|_| GrantVerifyError::Malformed)?;
324 if expires_at_ms <= now_ms() {
325 return Err(GrantVerifyError::Expired);
326 }
327 let close = left.find("}:").ok_or(GrantVerifyError::Malformed)?;
329 let hash_tag = &left[..=close]; let uuid_str = &left[close + 2..];
331
332 let secret = crate::signal::fetch_kid(pool, kid)
334 .await
335 .map_err(|_| GrantVerifyError::Transport)?
336 .ok_or(GrantVerifyError::UnknownKid)?;
337
338 let wid_wiid = read_grant_identity(pool, grant).await?;
340 let message = format!(
341 "{hash_tag}|{uuid_str}|{wid}|{wiid}|{expires_at_ms}",
342 wid = wid_wiid.0,
343 wiid = wid_wiid.1,
344 );
345 let token = format!("{kid}:{hex_part}");
346 crate::signal::hmac_verify(&secret, kid, message.as_bytes(), &token)
347 .map_err(|_| GrantVerifyError::SignatureMismatch)?;
348 Ok(())
349}
350
351async fn read_grant_identity(
355 pool: &PgPool,
356 grant: &ClaimGrant,
357) -> Result<(String, String), GrantVerifyError> {
358 let partition = grant.partition().map_err(|_| GrantVerifyError::Malformed)?;
359 let part = partition.index as i16;
360 let uuid_str = grant
361 .execution_id
362 .as_str()
363 .split_once("}:")
364 .map(|(_, u)| u)
365 .ok_or(GrantVerifyError::Malformed)?;
366 let exec_uuid = Uuid::parse_str(uuid_str).map_err(|_| GrantVerifyError::Malformed)?;
367 let row = sqlx::query(
368 "SELECT raw_fields FROM ff_exec_core WHERE partition_key = $1 AND execution_id = $2",
369 )
370 .bind(part)
371 .bind(exec_uuid)
372 .fetch_optional(pool)
373 .await
374 .map_err(|_| GrantVerifyError::Transport)?
375 .ok_or(GrantVerifyError::UnknownGrant)?;
376 let raw: JsonValue = row.try_get("raw_fields").map_err(|_| GrantVerifyError::Transport)?;
377 let cg = raw.get("claim_grant").ok_or(GrantVerifyError::UnknownGrant)?;
378 let wid = cg
379 .get("worker_id")
380 .and_then(JsonValue::as_str)
381 .ok_or(GrantVerifyError::Malformed)?
382 .to_owned();
383 let wiid = cg
384 .get("worker_instance_id")
385 .and_then(JsonValue::as_str)
386 .ok_or(GrantVerifyError::Malformed)?
387 .to_owned();
388 Ok((wid, wiid))
389}
390
/// Reasons a claim grant can fail verification.
#[derive(Debug, thiserror::Error)]
pub enum GrantVerifyError {
    /// The grant key, execution id, or stored identity fields could not be
    /// parsed into the expected shape.
    #[error("grant_key malformed")]
    Malformed,
    /// The expiry embedded in the grant key is not in the future.
    #[error("grant expired")]
    Expired,
    /// The key id named in the grant key was not found in the keystore.
    #[error("unknown kid in grant")]
    UnknownKid,
    /// No execution row matched, or the row carries no `claim_grant` object.
    #[error("unknown grant — no row with matching claim_grant in exec_core")]
    UnknownGrant,
    /// The HMAC in the grant key did not verify against the rebuilt message.
    #[error("signature verification failed")]
    SignatureMismatch,
    /// A database call made during verification failed.
    #[error("transport error while verifying grant")]
    Transport,
}
407
408async fn admit_budget(
413 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
414 budget_id: &str,
415) -> Result<bool, EngineError> {
416 let partition_key: i16 = ff_core::types::BudgetId::parse(budget_id)
420 .map(|bid| {
421 ff_core::partition::budget_partition(&bid, &ff_core::partition::PartitionConfig::default())
422 .index as i16
423 })
424 .unwrap_or(0);
425
426 let policy: Option<JsonValue> = sqlx::query_scalar(
429 r#"
430 SELECT policy_json FROM ff_budget_policy
431 WHERE partition_key = $1 AND budget_id = $2
432 FOR SHARE
433 "#,
434 )
435 .bind(partition_key)
436 .bind(budget_id)
437 .fetch_optional(&mut **tx)
438 .await
439 .map_err(map_sqlx_error)?;
440 let Some(policy) = policy else {
441 return Ok(true);
443 };
444
445 let hard_limit = policy
449 .get("hard_limit")
450 .and_then(JsonValue::as_u64)
451 .or_else(|| {
452 policy
453 .get("hard")
454 .and_then(JsonValue::as_object)
455 .and_then(|o| o.values().next())
456 .and_then(JsonValue::as_u64)
457 });
458 let dimension = policy
459 .get("dimension")
460 .and_then(JsonValue::as_str)
461 .map(str::to_owned)
462 .unwrap_or_else(|| "default".to_owned());
463 let Some(hard_limit) = hard_limit else {
464 return Ok(true);
465 };
466
467 let current: Option<i64> = sqlx::query_scalar(
470 r#"
471 SELECT current_value FROM ff_budget_usage
472 WHERE partition_key = $1 AND budget_id = $2 AND dimensions_key = $3
473 FOR SHARE
474 "#,
475 )
476 .bind(partition_key)
477 .bind(budget_id)
478 .bind(&dimension)
479 .fetch_optional(&mut **tx)
480 .await
481 .map_err(map_sqlx_error)?;
482 let current = current.unwrap_or(0).max(0) as u64;
483
484 Ok(current < hard_limit)
487}
488
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reads earlier than the epoch and
/// `i64::MAX` if the millisecond count does not fit in an `i64`.
fn now_ms() -> i64 {
    let millis = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis(),
        Err(_) => 0,
    };
    i64::try_from(millis).unwrap_or(i64::MAX)
}