use ff_core::engine_error::EngineError;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use sqlx::PgPool;
use crate::error::map_sqlx_error;
pub const SERIALIZABLE_RETRY_BUDGET: usize = 3;
pub fn is_retryable_serialization(err: &sqlx::Error) -> bool {
if let Some(db) = err.as_database_error()
&& let Some(code) = db.code()
{
matches!(code.as_ref(), "40001" | "40P01")
} else {
false
}
}
/// Produces a `"{kid}:{hex}"` token authenticating `message` under `secret`.
///
/// The MAC is HMAC-SHA256 over the byte sequence `kid ":" message`, so the key
/// id is bound into the signature. The tag is hex-encoded with `hex::encode`.
pub fn hmac_sign(secret: &[u8], kid: &str, message: &[u8]) -> String {
    let tag = {
        let mut hasher = <Hmac<Sha256> as Mac>::new_from_slice(secret)
            .expect("HMAC-SHA256 accepts any key length");
        for chunk in [kid.as_bytes(), b":".as_slice(), message] {
            hasher.update(chunk);
        }
        hasher.finalize().into_bytes()
    };
    format!("{kid}:{}", hex::encode(tag))
}
/// Verifies a `kid:hex` token against `message` using HMAC-SHA256 keyed by `secret`.
///
/// The MAC input is `kid ":" message`, the same sequence [`hmac_sign`] uses, so
/// a token is only accepted for the key id it was minted under.
///
/// # Errors
///
/// * [`HmacVerifyError::Malformed`] — the token has no `:` separator, or its
///   MAC part is not valid hex.
/// * [`HmacVerifyError::WrongKid`] — the kid embedded in the token is not `kid`.
/// * [`HmacVerifyError::SignatureMismatch`] — the MAC does not match. The
///   comparison is performed by `Mac::verify_slice`.
pub fn hmac_verify(
    secret: &[u8],
    kid: &str,
    message: &[u8],
    token: &str,
) -> Result<(), HmacVerifyError> {
    // Split at the LAST ':'. The hex-encoded MAC can never contain ':', but a
    // kid may (seeding only requires it to be non-empty), and `hmac_sign`
    // emits the kid verbatim. Splitting at the first ':' would reject every
    // token signed under a kid such as "tenant:2024".
    let (tok_kid, tok_hex) = token
        .rsplit_once(':')
        .ok_or(HmacVerifyError::Malformed)?;
    if tok_kid != kid {
        return Err(HmacVerifyError::WrongKid {
            expected: kid.to_owned(),
            actual: tok_kid.to_owned(),
        });
    }
    let expected = hex::decode(tok_hex).map_err(|_| HmacVerifyError::Malformed)?;
    // HMAC accepts keys of any length, so this should not fail; it is mapped
    // to `Malformed` rather than panicking inside a verification path.
    let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(secret)
        .map_err(|_| HmacVerifyError::Malformed)?;
    mac.update(kid.as_bytes());
    mac.update(b":");
    mac.update(message);
    mac.verify_slice(&expected)
        .map_err(|_| HmacVerifyError::SignatureMismatch)
}
#[derive(Debug, thiserror::Error)]
pub enum HmacVerifyError {
#[error("token malformed; expected kid:hex shape")]
Malformed,
#[error("token kid mismatch; expected {expected}, got {actual}")]
WrongKid { expected: String, actual: String },
#[error("HMAC signature mismatch")]
SignatureMismatch,
}
/// Loads the active waitpoint HMAC key as `(kid, secret)`.
///
/// Reads rows of `ff_waitpoint_hmac` with `active = TRUE`. If several match,
/// the one with the largest `rotated_at_ms` wins. Returns `Ok(None)` when no
/// row is active.
///
/// # Errors
///
/// Database failures are converted with `map_sqlx_error`.
pub async fn current_active_kid(
    pool: &PgPool,
) -> Result<Option<(String, Vec<u8>)>, EngineError> {
    const ACTIVE_KEY_SQL: &str = "SELECT kid, secret FROM ff_waitpoint_hmac \
                                  WHERE active = TRUE \
                                  ORDER BY rotated_at_ms DESC LIMIT 1";
    sqlx::query_as::<_, (String, Vec<u8>)>(ACTIVE_KEY_SQL)
        .fetch_optional(pool)
        .await
        .map_err(map_sqlx_error)
}
/// Looks up the stored secret for `kid`, whether or not that key is active.
///
/// Returns `Ok(None)` when `ff_waitpoint_hmac` has no row for `kid`.
///
/// # Errors
///
/// Database failures are converted with `map_sqlx_error`.
pub async fn fetch_kid(pool: &PgPool, kid: &str) -> Result<Option<Vec<u8>>, EngineError> {
    let found = sqlx::query_as::<_, (Vec<u8>,)>(
        "SELECT secret FROM ff_waitpoint_hmac WHERE kid = $1",
    )
    .bind(kid)
    .fetch_optional(pool)
    .await
    .map_err(map_sqlx_error)?;
    Ok(found.map(|(secret,)| secret))
}
/// Rotates the waitpoint HMAC signing key to `args.new_kid` / `args.new_secret_hex`.
///
/// Input is validated before any database access, and validation failures are
/// returned as this function's own `Err`:
///
/// * `new_kid` must be non-empty (the same rule `seed_waitpoint_hmac_secret_impl`
///   applies to `kid`);
/// * `new_secret_hex` must be valid hex and decode to at least one byte. An
///   empty string decodes successfully to a zero-length key, which must never
///   become the active signing secret.
///
/// The rotation itself runs in one transaction. Its outcome — success or
/// failure — is wrapped in a one-element `RotateWaitpointHmacSecretAllResult`
/// whose entry is built with index `0`:
///
/// * `Noop { kid }` — a row for `new_kid` already holds exactly this secret;
///   nothing is written.
/// * `Err(Conflict(RotationConflict))` — a row for `new_kid` exists with a
///   different secret.
/// * `Rotated { previous_kid, new_kid, gc_count: 0 }` — every active row was set
///   to `active = FALSE`, and `new_kid` was inserted as active at `now_ms`.
///   `previous_kid` is the newest previously active kid, if any.
///
/// This function only deactivates previous keys and never deletes them, so
/// `fetch_kid` can still resolve them. Nothing is garbage-collected, hence
/// `gc_count` is always `0`.
///
/// # Errors
///
/// Returns `EngineError::Validation` (`ValidationKind::InvalidInput`) for the
/// input problems listed above. Database errors during the rotation are
/// carried inside the result entry, not returned directly.
pub async fn rotate_waitpoint_hmac_secret_all_impl(
    pool: &PgPool,
    args: ff_core::contracts::RotateWaitpointHmacSecretAllArgs,
    now_ms: i64,
) -> Result<ff_core::contracts::RotateWaitpointHmacSecretAllResult, EngineError> {
    use ff_core::contracts::{
        RotateWaitpointHmacSecretAllEntry, RotateWaitpointHmacSecretAllResult,
        RotateWaitpointHmacSecretOutcome,
    };

    if args.new_kid.is_empty() {
        return Err(EngineError::Validation {
            kind: ff_core::engine_error::ValidationKind::InvalidInput,
            detail: "new_kid must be non-empty".into(),
        });
    }
    let secret_bytes = hex::decode(&args.new_secret_hex).map_err(|_| {
        EngineError::Validation {
            kind: ff_core::engine_error::ValidationKind::InvalidInput,
            detail: "new_secret_hex is not valid hex".into(),
        }
    })?;
    if secret_bytes.is_empty() {
        return Err(EngineError::Validation {
            kind: ff_core::engine_error::ValidationKind::InvalidInput,
            detail: "new_secret_hex must be non-empty".into(),
        });
    }

    let outcome_res: Result<RotateWaitpointHmacSecretOutcome, EngineError> = async {
        let mut tx = pool.begin().await.map_err(map_sqlx_error)?;

        // Idempotency / conflict check on the requested kid.
        let existing: Option<(Vec<u8>,)> = sqlx::query_as(
            "SELECT secret FROM ff_waitpoint_hmac WHERE kid = $1",
        )
        .bind(&args.new_kid)
        .fetch_optional(&mut *tx)
        .await
        .map_err(map_sqlx_error)?;
        if let Some((prior,)) = existing {
            if prior == secret_bytes {
                // Replay of an already-applied rotation: report, don't write.
                // NOTE(review): this also fires when `new_kid` exists but is no
                // longer active; it is NOT re-activated. Confirm intended.
                tx.commit().await.map_err(map_sqlx_error)?;
                return Ok(RotateWaitpointHmacSecretOutcome::Noop {
                    kid: args.new_kid.clone(),
                });
            }
            // Best-effort explicit rollback; sqlx also rolls back an
            // uncommitted transaction when it is dropped.
            tx.rollback().await.ok();
            return Err(EngineError::Conflict(
                ff_core::engine_error::ConflictKind::RotationConflict(format!(
                    "kid {} already installed with a different secret",
                    args.new_kid
                )),
            ));
        }

        let prior_active: Option<(String,)> = sqlx::query_as(
            "SELECT kid FROM ff_waitpoint_hmac \
             WHERE active = TRUE \
             ORDER BY rotated_at_ms DESC LIMIT 1",
        )
        .fetch_optional(&mut *tx)
        .await
        .map_err(map_sqlx_error)?;

        // `grace_ms` is accepted but not applied: previous keys are only
        // deactivated below, never deleted, so there is no expiry to schedule.
        let _ = args.grace_ms;

        // NOTE(review): unless this transaction runs at SERIALIZABLE isolation,
        // two concurrent rotations to different kids can each end up inserting
        // an active row. Readers ordering by `rotated_at_ms DESC` then pick the
        // newer one. Confirm whether that is acceptable.
        sqlx::query("UPDATE ff_waitpoint_hmac SET active = FALSE WHERE active = TRUE")
            .execute(&mut *tx)
            .await
            .map_err(map_sqlx_error)?;
        sqlx::query(
            "INSERT INTO ff_waitpoint_hmac (kid, secret, rotated_at_ms, active) \
             VALUES ($1, $2, $3, TRUE)",
        )
        .bind(&args.new_kid)
        .bind(&secret_bytes)
        .bind(now_ms)
        .execute(&mut *tx)
        .await
        .map_err(map_sqlx_error)?;
        tx.commit().await.map_err(map_sqlx_error)?;

        Ok(RotateWaitpointHmacSecretOutcome::Rotated {
            previous_kid: prior_active.map(|(k,)| k),
            new_kid: args.new_kid.clone(),
            gc_count: 0,
        })
    }
    .await;

    Ok(RotateWaitpointHmacSecretAllResult::new(vec![
        RotateWaitpointHmacSecretAllEntry::new(0, outcome_res),
    ]))
}
/// Installs the first waitpoint HMAC key when no key is active yet.
///
/// Validation happens before any database access:
/// * `secret_hex` must be exactly 64 ASCII hex digits (a 256-bit secret);
/// * `kid` must be non-empty.
///
/// Then, inside one transaction:
/// * if a row for `kid` already exists, nothing is written and
///   `AlreadySeeded { kid, same_secret }` is returned. `same_secret` reports
///   whether the stored secret equals the supplied one;
/// * if some other kid is active, the call fails with a validation error that
///   points the caller at `rotate_waitpoint_hmac_secret_all`;
/// * otherwise the key is inserted as active at `now_ms` and `Seeded { kid }`
///   is returned.
///
/// # Errors
///
/// `EngineError::Validation` (`ValidationKind::InvalidInput`) for bad input or
/// an already-active different kid; database failures via `map_sqlx_error`.
pub async fn seed_waitpoint_hmac_secret_impl(
    pool: &PgPool,
    args: ff_core::contracts::SeedWaitpointHmacSecretArgs,
    now_ms: i64,
) -> Result<ff_core::contracts::SeedOutcome, EngineError> {
    use ff_core::contracts::SeedOutcome;
    use ff_core::engine_error::ValidationKind;

    let secret_is_256_bit_hex = args.secret_hex.len() == 64
        && args.secret_hex.chars().all(|c| c.is_ascii_hexdigit());
    if !secret_is_256_bit_hex {
        return Err(EngineError::Validation {
            kind: ValidationKind::InvalidInput,
            detail: "secret_hex must be 64 hex characters (256-bit secret)".into(),
        });
    }
    if args.kid.is_empty() {
        return Err(EngineError::Validation {
            kind: ValidationKind::InvalidInput,
            detail: "kid must be non-empty".into(),
        });
    }
    // Cannot fail after the hex-digit check above; the error path is kept
    // instead of unwrapping.
    let secret_bytes = hex::decode(&args.secret_hex).map_err(|_| EngineError::Validation {
        kind: ValidationKind::InvalidInput,
        detail: "secret_hex is not valid hex".into(),
    })?;

    let mut tx = pool.begin().await.map_err(map_sqlx_error)?;

    let stored: Option<(Vec<u8>,)> =
        sqlx::query_as("SELECT secret FROM ff_waitpoint_hmac WHERE kid = $1")
            .bind(&args.kid)
            .fetch_optional(&mut *tx)
            .await
            .map_err(map_sqlx_error)?;
    if let Some((stored_secret,)) = stored {
        // Re-seeding an existing kid only reports; it never overwrites.
        let same_secret = stored_secret == secret_bytes;
        tx.commit().await.map_err(map_sqlx_error)?;
        return Ok(SeedOutcome::AlreadySeeded {
            kid: args.kid,
            same_secret,
        });
    }

    let currently_active: Option<(String,)> = sqlx::query_as(
        "SELECT kid FROM ff_waitpoint_hmac WHERE active = TRUE \
         ORDER BY rotated_at_ms DESC LIMIT 1",
    )
    .fetch_optional(&mut *tx)
    .await
    .map_err(map_sqlx_error)?;

    match currently_active {
        Some((active_kid,)) => {
            // Best-effort explicit rollback; sqlx also rolls back an
            // uncommitted transaction when it is dropped.
            let _ = tx.rollback().await;
            Err(EngineError::Validation {
                kind: ValidationKind::InvalidInput,
                detail: format!(
                    "seed_waitpoint_hmac_secret: a different kid {active_kid:?} is already active; \
                     use rotate_waitpoint_hmac_secret_all to change kid"
                ),
            })
        }
        None => {
            sqlx::query(
                "INSERT INTO ff_waitpoint_hmac (kid, secret, rotated_at_ms, active) \
                 VALUES ($1, $2, $3, TRUE)",
            )
            .bind(&args.kid)
            .bind(&secret_bytes)
            .bind(now_ms)
            .execute(&mut *tx)
            .await
            .map_err(map_sqlx_error)?;
            tx.commit().await.map_err(map_sqlx_error)?;
            Ok(SeedOutcome::Seeded { kid: args.kid })
        }
    }
}
/// Unit tests for the token signing / verification helpers.
#[cfg(test)]
mod hmac_tests {
    use super::*;

    /// A token minted by `hmac_sign` carries its kid prefix and verifies.
    #[test]
    fn sign_then_verify_round_trip() {
        let key: &[u8] = b"super-secret-key";
        let payload: &[u8] = b"exec-id:wp-id";
        let token = hmac_sign(key, "kid1", payload);
        assert!(token.starts_with("kid1:"));
        hmac_verify(key, "kid1", payload, &token).expect("verify ok");
    }

    /// Changing the message invalidates the MAC.
    #[test]
    fn verify_rejects_tampered_message() {
        let token = hmac_sign(b"s", "k", b"msg");
        let outcome = hmac_verify(b"s", "k", b"tampered", &token);
        assert!(matches!(outcome, Err(HmacVerifyError::SignatureMismatch)));
    }

    /// A token for one kid is refused when verifying against another.
    #[test]
    fn verify_rejects_wrong_kid() {
        let token = hmac_sign(b"s", "k1", b"msg");
        let outcome = hmac_verify(b"s", "k2", b"msg", &token);
        assert!(matches!(outcome, Err(HmacVerifyError::WrongKid { .. })));
    }

    /// Tokens without a separator, or with a non-hex MAC, are malformed.
    #[test]
    fn verify_rejects_malformed() {
        for bad_token in ["no-colon-token", "k:not-hex-zzzz"] {
            assert!(matches!(
                hmac_verify(b"s", "k", b"msg", bad_token),
                Err(HmacVerifyError::Malformed)
            ));
        }
    }
}