use crate::error::{CoreError, Result};
use redis_enterprise::EnterpriseClient;
use redis_enterprise::actions::Action;
use std::time::{Duration, Instant};
/// Progress notification emitted by `poll_action` while it waits on a
/// Redis Enterprise action to reach a terminal status.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnterpriseProgressEvent {
    /// Polling has begun; emitted once, before the first status request.
    Started { action_uid: String },
    /// A status snapshot fetched from the server. Emitted after every
    /// successful fetch, including the one that observes a terminal status.
    Polling {
        action_uid: String,
        /// The `status` string reported by the server, unmodified.
        status: String,
        /// The `progress` field reported by the server, if present.
        progress: Option<String>,
        /// Time since polling started, measured just before this fetch.
        elapsed: Duration,
    },
    /// The action reported status `"completed"`.
    Completed { action_uid: String },
    /// The action reported status `"failed"` or `"cancelled"`. `error` is the
    /// server-supplied error text, or `"Action <status>"` when none was given.
    Failed { action_uid: String, error: String },
}

/// Callback receiving each [`EnterpriseProgressEvent`].
///
/// It is invoked inline on the polling task, so it should return quickly.
/// `Send + Sync` lets the polling future be moved across threads.
pub type EnterpriseProgressCallback = Box<dyn Fn(EnterpriseProgressEvent) + Send + Sync>;
/// Polls a Redis Enterprise action until it reaches a terminal status.
///
/// The action identified by `action_uid` is fetched roughly every `interval`.
/// - Status `"completed"` returns the action as `Ok`.
/// - Status `"failed"` or `"cancelled"` returns [`CoreError::TaskFailed`].
/// - Any other status is treated as still in progress.
///
/// When `on_progress` is provided, it receives:
/// - a `Started` event once,
/// - a `Polling` event after every successful fetch,
/// - a final `Completed` or `Failed` event for terminal statuses.
///
/// No terminal event is emitted on timeout or on a request error.
///
/// # Errors
///
/// - [`CoreError::TaskTimeout`] once `timeout` has elapsed. This includes the
///   case where a single in-flight request would outlive the deadline.
/// - [`CoreError::TaskFailed`] when the action ends as `"failed"` or
///   `"cancelled"`. The message is the action's `error` field, or
///   `"Action <status>"` when that field is absent.
/// - Any error returned while fetching the action, converted via `?`.
pub async fn poll_action(
    client: &EnterpriseClient,
    action_uid: &str,
    timeout: Duration,
    interval: Duration,
    on_progress: Option<EnterpriseProgressCallback>,
) -> Result<Action> {
    let start = Instant::now();
    let handler = client.actions();
    emit(
        &on_progress,
        EnterpriseProgressEvent::Started {
            action_uid: action_uid.to_string(),
        },
    );
    loop {
        let elapsed = start.elapsed();
        // Time left before the deadline; `None` or zero means it has passed.
        let remaining = match timeout.checked_sub(elapsed) {
            Some(left) if !left.is_zero() => left,
            _ => return Err(CoreError::TaskTimeout(timeout)),
        };
        // Bound the request itself by the remaining budget so a stalled HTTP
        // call cannot keep the caller waiting past `timeout`.
        let action = match tokio::time::timeout(remaining, handler.get(action_uid)).await {
            Ok(fetched) => fetched?,
            Err(_elapsed) => return Err(CoreError::TaskTimeout(timeout)),
        };
        let status = action.status.clone();
        emit(
            &on_progress,
            EnterpriseProgressEvent::Polling {
                action_uid: action_uid.to_string(),
                status: status.clone(),
                progress: action.progress.clone(),
                elapsed,
            },
        );
        match status.as_str() {
            "completed" => {
                emit(
                    &on_progress,
                    EnterpriseProgressEvent::Completed {
                        action_uid: action_uid.to_string(),
                    },
                );
                return Ok(action);
            }
            "failed" | "cancelled" => {
                // Fall back to a status-derived message when the server gave none.
                let error = action
                    .error
                    .clone()
                    .unwrap_or_else(|| format!("Action {}", status));
                emit(
                    &on_progress,
                    EnterpriseProgressEvent::Failed {
                        action_uid: action_uid.to_string(),
                        error: error.clone(),
                    },
                );
                return Err(CoreError::TaskFailed(error));
            }
            _ => {
                // Never sleep past the deadline. A long `interval` would
                // otherwise delay the timeout error by up to one interval.
                let left = timeout.saturating_sub(start.elapsed());
                tokio::time::sleep(interval.min(left)).await;
            }
        }
    }
}
/// Hands `event` to the progress callback when one was supplied.
///
/// When `callback` is `None`, the event is simply dropped.
fn emit(callback: &Option<EnterpriseProgressCallback>, event: EnterpriseProgressEvent) {
    if let Some(notify) = callback.as_deref() {
        notify(event);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client pointed at the mock server with dummy credentials.
    fn test_client(uri: String) -> EnterpriseClient {
        EnterpriseClient::builder()
            .base_url(uri)
            .username("test-user".to_string())
            .password("test-pass".to_string())
            .insecure(true)
            .build()
            .unwrap()
    }

    /// Returns a callback that records one label per progress event, plus the
    /// shared log it writes into.
    ///
    /// `Polling` events are labelled `polling:<status>` and `Failed` events
    /// `failed:<error>`, so tests can assert the exact sequence observed.
    fn recording_callback() -> (EnterpriseProgressCallback, Arc<Mutex<Vec<String>>>) {
        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);
        let callback: EnterpriseProgressCallback = Box::new(move |event| {
            let label = match event {
                EnterpriseProgressEvent::Started { .. } => "started".to_string(),
                EnterpriseProgressEvent::Polling { status, .. } => format!("polling:{status}"),
                EnterpriseProgressEvent::Completed { .. } => "completed".to_string(),
                EnterpriseProgressEvent::Failed { error, .. } => format!("failed:{error}"),
            };
            sink.lock().unwrap().push(label);
        });
        (callback, events)
    }

    #[tokio::test]
    async fn poll_action_immediate_success() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            })))
            .mount(&mock_server)
            .await;
        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            None,
        )
        .await;
        match result {
            Ok(action) => assert_eq!(action.status, "completed"),
            other => panic!("expected Ok(completed action), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn poll_action_polls_then_succeeds() {
        let mock_server = MockServer::start().await;
        // Higher-priority mock answers the first two requests with "running".
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "running",
                "progress": "50"
            })))
            .up_to_n_times(2)
            .with_priority(1)
            .mount(&mock_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            })))
            .mount(&mock_server)
            .await;
        let (callback, events) = recording_callback();
        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            Some(callback),
        )
        .await;
        match result {
            Ok(action) => assert_eq!(action.status, "completed"),
            other => panic!("expected Ok(completed action), got {other:?}"),
        }
        // Both intermediate "running" polls must actually have been observed.
        assert_eq!(
            *events.lock().unwrap(),
            [
                "started",
                "polling:running",
                "polling:running",
                "polling:completed",
                "completed",
            ]
        );
    }

    #[tokio::test]
    async fn poll_action_failure_surfaces_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "upgrade",
                "status": "failed",
                "error": "upgrade failed: version conflict"
            })))
            .mount(&mock_server)
            .await;
        let (callback, events) = recording_callback();
        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            Some(callback),
        )
        .await;
        match result {
            Err(CoreError::TaskFailed(msg)) => {
                assert_eq!(msg, "upgrade failed: version conflict");
            }
            other => panic!("expected TaskFailed, got {other:?}"),
        }
        assert_eq!(
            *events.lock().unwrap(),
            [
                "started",
                "polling:failed",
                "failed:upgrade failed: version conflict",
            ]
        );
    }

    #[tokio::test]
    async fn poll_action_cancelled_surfaces_as_failed() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "cancelled"
            })))
            .mount(&mock_server)
            .await;
        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            None,
        )
        .await;
        match result {
            Err(CoreError::TaskFailed(msg)) => {
                assert!(msg.contains("cancelled"), "unexpected message: {msg}");
            }
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn poll_action_times_out() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "running",
                "progress": "10"
            })))
            .mount(&mock_server)
            .await;
        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_millis(1),
            Duration::from_millis(5),
            None,
        )
        .await;
        match result {
            Err(CoreError::TaskTimeout(_)) => {}
            other => panic!("expected TaskTimeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn poll_action_emits_progress_events() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            })))
            .mount(&mock_server)
            .await;
        let (callback, events) = recording_callback();
        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            Some(callback),
        )
        .await;
        assert!(result.is_ok(), "expected Ok, got {result:?}");
        // Exact order matters: Started first, one Polling, then Completed.
        assert_eq!(
            *events.lock().unwrap(),
            ["started", "polling:completed", "completed"]
        );
    }
}