use crate::error::{CoreError, Result};
use redis_cloud::tasks::TaskStateUpdate;
use redis_cloud::types::TaskStatus;
use redis_cloud::{CloudClient, TaskHandler};
use std::time::{Duration, Instant};
/// Progress notification handed to a [`ProgressCallback`] while `poll_task` runs.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// Emitted once, before the first poll.
    Started {
        /// Identifier of the task being polled.
        task_id: String,
    },
    /// Emitted after each successful fetch of the task state.
    Polling {
        /// Identifier of the task being polled.
        task_id: String,
        /// Human-readable label for the status that was just observed.
        status: String,
        /// Time elapsed since polling started, sampled before the fetch.
        elapsed: Duration,
    },
    /// Emitted when the task reports successful completion.
    Completed {
        /// Identifier of the task that completed.
        task_id: String,
        /// Resource id taken from the task response, when one is present.
        resource_id: Option<i32>,
    },
    /// Emitted when the task reports a processing error.
    Failed {
        /// Identifier of the task that failed.
        task_id: String,
        /// Error text describing the failure.
        error: String,
    },
}
/// Boxed, thread-safe (`Send + Sync`) callback that receives each [`ProgressEvent`] by value.
pub type ProgressCallback = Box<dyn Fn(ProgressEvent) + Send + Sync>;
/// Polls a Redis Cloud task until it completes, fails, or `timeout` elapses.
///
/// The task is fetched through a [`TaskHandler`] built from a clone of `client`. After a
/// non-terminal poll the function sleeps for `interval`, shortened when necessary so the
/// wait never runs past the `timeout` deadline.
///
/// When `on_progress` is provided it receives [`ProgressEvent::Started`] once up front,
/// [`ProgressEvent::Polling`] after every successful fetch, and
/// [`ProgressEvent::Completed`] or [`ProgressEvent::Failed`] when the task reaches a
/// terminal status. No event is emitted for a timeout or for a failed fetch.
///
/// # Errors
///
/// - [`CoreError::TaskTimeout`] when more than `timeout` has elapsed at the start of a poll
///   iteration. The deadline is only checked between polls; an in-flight request is not
///   interrupted.
/// - [`CoreError::TaskFailed`] when the task reports `TaskStatus::ProcessingError`. It
///   carries the response's error message when one is available, otherwise a generic
///   message containing the status label.
/// - Any error produced while fetching the task, propagated with `?`.
pub async fn poll_task(
    client: &CloudClient,
    task_id: &str,
    timeout: Duration,
    interval: Duration,
    on_progress: Option<ProgressCallback>,
) -> Result<TaskStateUpdate> {
    let start = Instant::now();
    let handler = TaskHandler::new(client.clone());

    emit(
        &on_progress,
        ProgressEvent::Started {
            task_id: task_id.to_string(),
        },
    );

    loop {
        let elapsed = start.elapsed();
        if elapsed > timeout {
            return Err(CoreError::TaskTimeout(timeout));
        }

        let task = handler.get_task_by_id(task_id.to_string()).await?;
        let status = task.status.clone();
        let status_label = task_status_label(status.as_ref());
        emit(
            &on_progress,
            ProgressEvent::Polling {
                task_id: task_id.to_string(),
                status: status_label.clone(),
                elapsed,
            },
        );

        match status {
            Some(TaskStatus::ProcessingCompleted) => {
                let resource_id = task.response.as_ref().and_then(|r| r.resource_id);
                emit(
                    &on_progress,
                    ProgressEvent::Completed {
                        task_id: task_id.to_string(),
                        resource_id,
                    },
                );
                return Ok(task);
            }
            Some(TaskStatus::ProcessingError) => {
                let error = task
                    .response
                    .as_ref()
                    .and_then(|r| r.error_message())
                    .unwrap_or_else(|| format!("Task failed with status: {}", status_label));
                emit(
                    &on_progress,
                    ProgressEvent::Failed {
                        task_id: task_id.to_string(),
                        error: error.clone(),
                    },
                );
                return Err(CoreError::TaskFailed(error));
            }
            // Any other status (including a missing one) means the task is still running.
            _ => {
                // Clamp the wait to the time left before the deadline; sleeping the full
                // interval could otherwise overrun `timeout` by up to one interval before
                // the timeout error is reported.
                let remaining = timeout.saturating_sub(start.elapsed());
                tokio::time::sleep(interval.min(remaining)).await;
            }
        }
    }
}
/// Forwards `event` to the progress callback when one is configured; otherwise does nothing.
fn emit(callback: &Option<ProgressCallback>, event: ProgressEvent) {
    // Borrow the boxed closure as a plain `&dyn Fn` so it can be called directly.
    if let Some(notify) = callback.as_deref() {
        notify(event);
    }
}
fn task_status_label(status: Option<&TaskStatus>) -> String {
match status {
Some(TaskStatus::Initialized) => "initialized",
Some(TaskStatus::Received) => "received",
Some(TaskStatus::ProcessingInProgress) => "processing-in-progress",
Some(TaskStatus::ProcessingCompleted) => "processing-completed",
Some(TaskStatus::ProcessingError) => "processing-error",
Some(TaskStatus::Unknown) | None => "unknown",
}
.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a `CloudClient` with dummy credentials that talks to `uri`.
    fn test_client(uri: String) -> CloudClient {
        CloudClient::builder()
            .api_key("test-key".to_string())
            .api_secret("test-secret".to_string())
            .base_url(uri)
            .build()
            .unwrap()
    }

    /// A failed task whose `response.error` is an object must surface that object's
    /// `description` as the `TaskFailed` message.
    #[tokio::test]
    async fn poll_task_surfaces_object_error_description() {
        let server = MockServer::start().await;

        // Task payload reporting a processing error with a structured error object.
        let failed_task = json!({
            "taskId": "task-backup",
            "commandType": "DATABASE_BACKUP",
            "status": "processing-error",
            "response": {
                "error": {
                    "type": "BACKUP_FAILED",
                    "status": "400 BAD_REQUEST",
                    "description": "Remote backup location is not configured"
                }
            }
        });
        let task_endpoint = Mock::given(method("GET"))
            .and(path("/tasks/task-backup"))
            .respond_with(ResponseTemplate::new(200).set_body_json(failed_task));
        task_endpoint.mount(&server).await;

        let client = test_client(server.uri());
        let wait_limit = Duration::from_secs(5);
        let poll_every = Duration::from_millis(10);
        let outcome = poll_task(&client, "task-backup", wait_limit, poll_every, None).await;

        match outcome {
            Err(CoreError::TaskFailed(msg)) => {
                assert_eq!(msg, "Remote backup location is not configured");
            }
            other => panic!("expected TaskFailed with description, got {other:?}"),
        }
    }
}