use std::sync::Arc;
use ff_core::engine_backend::EngineBackend;
use ff_core::partition::{Partition, PartitionFamily, PartitionKey};
use ff_core::types::{ExecutionId, WaitpointId, WaitpointToken};
use subtle::ConstantTimeEq;
use crate::task::{Signal, SignalOutcome};
use crate::worker::FlowFabricWorker;
use crate::SdkError;
/// Reasons [`verify_and_deliver`] can refuse or fail to deliver a signal.
#[derive(Debug)]
pub enum SignalBridgeError {
    /// The backend returned no stored token for the waitpoint (`read_waitpoint_token`
    /// yielded `None`).
    UnknownWaitpoint,
    /// A stored token exists, but the presented token's bytes do not equal it.
    TokenMismatch,
    /// Reading the stored token or delivering the signal failed in the SDK/backend.
    Backend(SdkError),
}

impl std::fmt::Display for SignalBridgeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownWaitpoint => {
                f.write_str("unknown waitpoint (consumed, expired, or unknown id)")
            }
            Self::TokenMismatch => f.write_str("token mismatch (presented HMAC does not verify)"),
            // NOTE(review): the inner message is rendered here *and* returned from
            // `source()`, so reporters that walk the source chain print it twice.
            Self::Backend(e) => write!(f, "backend: {e}"),
        }
    }
}

impl std::error::Error for SignalBridgeError {
    /// Exposes the wrapped [`SdkError`] for `Backend`; the other variants originate here.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Backend(e) => Some(e),
            // Listed explicitly (no `_` arm) so that adding a variant forces a decision
            // about whether it carries an underlying source.
            Self::UnknownWaitpoint | Self::TokenMismatch => None,
        }
    }
}

/// Wraps any SDK error as [`SignalBridgeError::Backend`], enabling `?` on `SdkError` results.
impl From<SdkError> for SignalBridgeError {
    fn from(e: SdkError) -> Self {
        Self::Backend(e)
    }
}
/// Checks a presented waitpoint token against the backend's stored token and, if they
/// match, delivers `signal` through `worker`.
///
/// The stored token is read from the `PartitionFamily::Flow` partition whose index is
/// `execution_id.partition()`. On a match, `signal.waitpoint_token` is overwritten with
/// `presented` before the signal is handed to [`FlowFabricWorker::deliver_signal`].
///
/// # Errors
///
/// - [`SignalBridgeError::UnknownWaitpoint`] if the backend has no stored token for
///   `waitpoint_id`.
/// - [`SignalBridgeError::TokenMismatch`] if the presented token's bytes differ from the
///   stored token's bytes.
/// - [`SignalBridgeError::Backend`] if reading the stored token or delivering the signal
///   fails.
pub async fn verify_and_deliver(
    backend: &dyn EngineBackend,
    worker: &FlowFabricWorker,
    execution_id: &ExecutionId,
    waitpoint_id: &WaitpointId,
    presented: &WaitpointToken,
    mut signal: Signal,
) -> Result<SignalOutcome, SignalBridgeError> {
    let partition = PartitionKey::from(&Partition {
        family: PartitionFamily::Flow,
        index: execution_id.partition(),
    });

    // Backend errors are converted to `SdkError`; `?` then lifts that into
    // `SignalBridgeError::Backend` through the `From<SdkError>` impl.
    let stored = backend
        .read_waitpoint_token(partition, waitpoint_id)
        .await
        .map_err(SdkError::from)?
        .ok_or(SignalBridgeError::UnknownWaitpoint)?;

    // Constant-time comparison, so response timing does not reveal how many leading
    // bytes of the presented token are correct. Per the `subtle` docs, the slice impl
    // short-circuits when lengths differ, so token *length* is not hidden.
    let tokens_equal: bool = stored
        .as_bytes()
        .ct_eq(presented.as_str().as_bytes())
        .into();
    if !tokens_equal {
        return Err(SignalBridgeError::TokenMismatch);
    }

    signal.waitpoint_token = presented.clone();
    worker
        .deliver_signal(execution_id, waitpoint_id, signal)
        .await
        .map_err(SignalBridgeError::Backend)
}
/// Convenience form of [`verify_and_deliver`] for callers that hold the backend as an
/// `Arc<dyn EngineBackend>`.
///
/// It only borrows the trait object out of the `Arc` and delegates. The verification
/// steps and error conditions are exactly those of [`verify_and_deliver`].
pub async fn verify_and_deliver_arc(
    backend: &Arc<dyn EngineBackend>,
    worker: &FlowFabricWorker,
    execution_id: &ExecutionId,
    waitpoint_id: &WaitpointId,
    presented: &WaitpointToken,
    signal: Signal,
) -> Result<SignalOutcome, SignalBridgeError> {
    let engine: &dyn EngineBackend = &**backend;
    verify_and_deliver(engine, worker, execution_id, waitpoint_id, presented, signal).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway `SdkError` to wrap in `SignalBridgeError::Backend`.
    fn sample_sdk_error() -> SdkError {
        SdkError::Config {
            context: "test".into(),
            field: None,
            message: "boom".into(),
        }
    }

    #[test]
    fn display_each_variant() {
        let unknown = SignalBridgeError::UnknownWaitpoint.to_string();
        assert!(unknown.contains("unknown waitpoint"));

        let mismatch = SignalBridgeError::TokenMismatch.to_string();
        assert!(mismatch.contains("token mismatch"));

        let rendered = SignalBridgeError::Backend(sample_sdk_error()).to_string();
        assert!(rendered.starts_with("backend: "), "got: {rendered}");
    }

    #[test]
    fn error_source_wraps_backend() {
        use std::error::Error as _;

        let wrapped = SignalBridgeError::Backend(sample_sdk_error());
        assert!(
            wrapped.source().is_some(),
            "Backend variant must expose source"
        );
        assert!(
            SignalBridgeError::UnknownWaitpoint.source().is_none(),
            "UnknownWaitpoint has no underlying source"
        );
    }

    #[test]
    fn from_sdk_error() {
        let converted: SignalBridgeError = sample_sdk_error().into();
        assert!(
            matches!(converted, SignalBridgeError::Backend(_)),
            "expected Backend, got {converted:?}"
        );
    }
}