use ferriskey::Client;
use crate::{LIBRARY_SOURCE, LIBRARY_VERSION};
/// Errors produced while loading or verifying the flowfabric function library.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
/// An error reported by the `ferriskey` client. The `#[from]` attribute
/// lets `?` convert a `ferriskey::Error` into this variant automatically.
#[error("valkey: {0}")]
Valkey(#[from] ferriskey::Error),
/// The version reported after a load did not match the expected one.
/// `expected` is the version the caller wanted; `got` is what was reported.
#[error("version mismatch after load: expected {expected}, got {got}")]
VersionMismatch { expected: String, got: String },
}
impl LoadError {
    /// Returns the `ferriskey::ErrorKind` of the wrapped client error.
    ///
    /// Yields `None` for every variant that does not wrap a
    /// `ferriskey::Error`.
    pub fn valkey_kind(&self) -> Option<ferriskey::ErrorKind> {
        if let Self::Valkey(source) = self {
            Some(source.kind())
        } else {
            None
        }
    }
}
pub async fn ensure_library(client: &Client) -> Result<(), LoadError> {
match check_version(client).await {
Ok(true) => {
tracing::debug!("flowfabric library already loaded at version {LIBRARY_VERSION}");
return Ok(());
}
Ok(false) => {
tracing::info!("flowfabric library version mismatch, reloading");
}
Err(_) => {
tracing::info!("flowfabric library not loaded, loading");
}
}
const MAX_ATTEMPTS: u32 = 6;
let backoff_ms: [u64; 5] = [500, 1_000, 2_000, 4_000, 7_000];
let mut last_err = None;
for attempt in 1..=MAX_ATTEMPTS {
match client.function_load_replace(LIBRARY_SOURCE).await {
Ok(_name) => {
last_err = None;
break;
}
Err(e) => {
if is_permanent_load_error(&e) {
tracing::error!(attempt, error = %e, "FUNCTION LOAD failed with permanent error");
return Err(LoadError::Valkey(e));
}
if attempt < MAX_ATTEMPTS {
let backoff = backoff_ms
.get((attempt as usize).saturating_sub(1))
.copied()
.unwrap_or(4_000);
if let Err(refresh_err) = client.force_cluster_slot_refresh().await {
tracing::debug!(
attempt,
error = %refresh_err,
"force_cluster_slot_refresh failed between FUNCTION LOAD retries"
);
}
tracing::warn!(
attempt,
max_attempts = MAX_ATTEMPTS,
backoff_ms = backoff,
error = %e,
"FUNCTION LOAD failed (transient), refreshed slot map, retrying"
);
last_err = Some(e);
tokio::time::sleep(std::time::Duration::from_millis(backoff)).await;
} else {
last_err = Some(e);
}
}
}
}
if let Some(e) = last_err {
return Err(LoadError::Valkey(e));
}
match check_version(client).await {
Ok(true) => {
tracing::info!("flowfabric library loaded successfully (version {LIBRARY_VERSION})");
Ok(())
}
Ok(false) => {
let got = get_version_string(client).await.unwrap_or_default();
Err(LoadError::VersionMismatch {
expected: LIBRARY_VERSION.to_string(),
got,
})
}
Err(e) => Err(LoadError::Valkey(e)),
}
}
/// Reports whether a `FUNCTION LOAD` failure should be treated as permanent.
///
/// Classification is a substring match on the error's `Display` text: the
/// error is permanent when the text contains any of the markers below.
/// Everything else is considered transient and eligible for retry.
fn is_permanent_load_error(e: &ferriskey::Error) -> bool {
    const PERMANENT_MARKERS: [&str; 3] = ["Error compiling", "syntax error", "ERR Error"];

    let rendered = e.to_string();
    PERMANENT_MARKERS
        .iter()
        .any(|marker| rendered.contains(*marker))
}
/// Calls the `ff_version` function and reports whether the returned string
/// equals `LIBRARY_VERSION`.
///
/// Any error from the call is propagated unchanged.
async fn check_version(client: &Client) -> Result<bool, ferriskey::Error> {
    // `ff_version` is invoked with two empty string slices.
    let empty: &[&str] = &[];
    let reported: String = client.fcall("ff_version", empty, empty).await?;
    let matches = reported == LIBRARY_VERSION;
    Ok(matches)
}
/// Calls the `ff_version` function and returns the version string it reports.
///
/// Any error from the call is propagated unchanged.
async fn get_version_string(client: &Client) -> Result<String, ferriskey::Error> {
    // `ff_version` is invoked with two empty string slices.
    let empty: &[&str] = &[];
    let version: String = client.fcall("ff_version", empty, empty).await?;
    Ok(version)
}
#[cfg(test)]
mod tests {
    use super::is_permanent_load_error;
    use ferriskey::{Error, ErrorKind};

    /// Builds a `ferriskey::Error` of the given kind with the fixed
    /// description "server error" and `detail` as its detail text.
    fn server_err(kind: ErrorKind, detail: &str) -> Error {
        let detail = detail.to_string();
        Error::from((kind, "server error", detail))
    }

    /// A READONLY reply must be classified as transient.
    #[test]
    fn readonly_is_not_permanent() {
        let e = server_err(
            ErrorKind::ReadOnly,
            "You can't write against a read only replica.",
        );
        let permanent = is_permanent_load_error(&e);
        assert!(
            !permanent,
            "READONLY must be transient so loader retries on replica-claim races; \
             got permanent for {e}"
        );
    }

    /// A MOVED redirect must be classified as transient.
    #[test]
    fn moved_is_not_permanent() {
        let e = server_err(ErrorKind::Moved, "3999 127.0.0.1:7002");
        let permanent = is_permanent_load_error(&e);
        assert!(!permanent, "MOVED must be transient; got {e}");
    }

    /// A TRYAGAIN reply must be classified as transient.
    #[test]
    fn tryagain_is_not_permanent() {
        let e = server_err(ErrorKind::TryAgain, "resharding in progress");
        let permanent = is_permanent_load_error(&e);
        assert!(!permanent, "TRYAGAIN must be transient; got {e}");
    }

    /// A CLUSTERDOWN reply must be classified as transient.
    #[test]
    fn clusterdown_is_not_permanent() {
        let e = server_err(ErrorKind::ClusterDown, "The cluster is down");
        let permanent = is_permanent_load_error(&e);
        assert!(!permanent, "CLUSTERDOWN must be transient; got {e}");
    }

    /// A Lua compilation failure must be classified as permanent.
    #[test]
    fn lua_syntax_error_is_permanent() {
        let e = server_err(
            ErrorKind::ResponseError,
            "Error compiling function: user_script:12: syntax error near 'end'",
        );
        let permanent = is_permanent_load_error(&e);
        assert!(
            permanent,
            "Lua compilation failure must be permanent; got transient for {e}"
        );
    }
}