Skip to main content

redisctl_core/enterprise/
progress.rs

//! Progress tracking and action polling for async Enterprise operations.
//!
//! Asynchronous Enterprise API operations return an `Action` that must be
//! polled until it reaches a terminal state. This module provides that
//! polling loop, with optional progress callbacks for UI updates.
6
7use crate::error::{CoreError, Result};
8use redis_enterprise::EnterpriseClient;
9use redis_enterprise::actions::Action;
10use std::time::{Duration, Instant};
11
/// Progress events emitted while an async Enterprise action is being polled.
///
/// [`poll_action`] emits one `Started`, then one `Polling` after every
/// successful fetch of the action, then at most one of `Completed` or
/// `Failed` when a terminal status is observed.
///
/// Events implement `PartialEq`/`Eq` so consumers and tests can compare them
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnterpriseProgressEvent {
    /// Polling of the action is about to begin.
    Started {
        /// Action UID being polled.
        action_uid: String,
    },
    /// Polling iteration with current status
    Polling {
        /// Action UID being polled.
        action_uid: String,
        /// Current status string (e.g. `"running"`, `"completed"`).
        status: String,
        /// Percent complete as reported by the API.
        ///
        /// Typed as `Option<String>` because the Redis Enterprise REST
        /// API emits this as a string (e.g. `"100"`), not a float.
        /// Callers that need a numeric value can `.parse::<f32>().ok()`.
        progress: Option<String>,
        /// Time since polling started.
        elapsed: Duration,
    },
    /// The action reached the `"completed"` status.
    Completed {
        /// Action UID that completed.
        action_uid: String,
    },
    /// The action reached a `"failed"` or `"cancelled"` status.
    Failed {
        /// Action UID that failed.
        action_uid: String,
        /// Error reported by the API, or a message derived from the status.
        error: String,
    },
}
37
38/// Callback type for Enterprise progress updates
39///
40/// CLI can use this to update spinners/progress bars.
41/// MCP typically doesn't need this.
42pub type EnterpriseProgressCallback = Box<dyn Fn(EnterpriseProgressEvent) + Send + Sync>;
43
44/// Poll an Enterprise action until completion
45///
46/// # Arguments
47///
48/// * `client` - The Enterprise API client
49/// * `action_uid` - The action UID to poll
50/// * `timeout` - Maximum time to wait for completion
51/// * `interval` - Time between polling attempts
52/// * `on_progress` - Optional callback for progress updates
53///
54/// # Returns
55///
56/// The completed action, or an error if the action failed or timed out.
57///
58/// # Example
59///
60/// ```rust,ignore
61/// use redisctl_core::enterprise::{poll_action, EnterpriseProgressEvent};
62/// use std::time::Duration;
63///
64/// // Start an async operation (returns an action_uid)
65/// let action_uid = "some-action-uid";
66///
67/// // Poll with progress callback
68/// let completed = poll_action(
69///     &client,
70///     action_uid,
71///     Duration::from_secs(600),
72///     Duration::from_secs(5),
73///     Some(Box::new(|event| {
74///         match event {
75///             EnterpriseProgressEvent::Polling { status, progress, elapsed, .. } => {
76///                 println!("Status: {} ({:?}%) ({:.0}s)", status, progress, elapsed.as_secs());
77///             }
78///             EnterpriseProgressEvent::Completed { .. } => {
79///                 println!("Done!");
80///             }
81///             _ => {}
82///         }
83///     })),
84/// ).await?;
85/// ```
86pub async fn poll_action(
87    client: &EnterpriseClient,
88    action_uid: &str,
89    timeout: Duration,
90    interval: Duration,
91    on_progress: Option<EnterpriseProgressCallback>,
92) -> Result<Action> {
93    let start = Instant::now();
94    let handler = client.actions();
95
96    emit(
97        &on_progress,
98        EnterpriseProgressEvent::Started {
99            action_uid: action_uid.to_string(),
100        },
101    );
102
103    loop {
104        let elapsed = start.elapsed();
105        if elapsed > timeout {
106            return Err(CoreError::TaskTimeout(timeout));
107        }
108
109        let action = handler.get(action_uid).await?;
110        let status = action.status.clone();
111
112        emit(
113            &on_progress,
114            EnterpriseProgressEvent::Polling {
115                action_uid: action_uid.to_string(),
116                status: status.clone(),
117                progress: action.progress.clone(),
118                elapsed,
119            },
120        );
121
122        match status.as_str() {
123            "completed" => {
124                emit(
125                    &on_progress,
126                    EnterpriseProgressEvent::Completed {
127                        action_uid: action_uid.to_string(),
128                    },
129                );
130                return Ok(action);
131            }
132            "failed" | "cancelled" => {
133                let error = action
134                    .error
135                    .clone()
136                    .unwrap_or_else(|| format!("Action {}", status));
137
138                emit(
139                    &on_progress,
140                    EnterpriseProgressEvent::Failed {
141                        action_uid: action_uid.to_string(),
142                        error: error.clone(),
143                    },
144                );
145                return Err(CoreError::TaskFailed(error));
146            }
147            // 'queued', 'starting', 'running', 'cancelling' - still in progress
148            _ => {
149                tokio::time::sleep(interval).await;
150            }
151        }
152    }
153}
154
155/// Helper to emit progress events
156fn emit(callback: &Option<EnterpriseProgressCallback>, event: EnterpriseProgressEvent) {
157    if let Some(cb) = callback {
158        cb(event);
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Shared log of event labels filled by [`recording_callback`].
    type EventLog = Arc<Mutex<Vec<&'static str>>>;

    /// Build a client pointed at the mock server (dummy credentials).
    fn test_client(uri: String) -> EnterpriseClient {
        EnterpriseClient::builder()
            .base_url(uri)
            .username("test-user".to_string())
            .password("test-pass".to_string())
            .insecure(true)
            .build()
            .unwrap()
    }

    /// Mount a mock answering every `GET /v1/actions/action-1` with `body`.
    async fn mount_action(server: &MockServer, body: serde_json::Value) {
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(server)
            .await;
    }

    /// Poll `action-1` with the generous timeout and short interval shared by
    /// every test that is not exercising the timeout itself.
    async fn poll(
        server: &MockServer,
        on_progress: Option<EnterpriseProgressCallback>,
    ) -> Result<Action> {
        let client = test_client(server.uri());
        poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            on_progress,
        )
        .await
    }

    /// A progress callback that appends one label per event to the returned
    /// log, so tests can assert on the exact event sequence.
    fn recording_callback() -> (EnterpriseProgressCallback, EventLog) {
        let log: EventLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let callback: EnterpriseProgressCallback = Box::new(move |event| {
            let label = match event {
                EnterpriseProgressEvent::Started { .. } => "started",
                EnterpriseProgressEvent::Polling { .. } => "polling",
                EnterpriseProgressEvent::Completed { .. } => "completed",
                EnterpriseProgressEvent::Failed { .. } => "failed",
            };
            sink.lock().unwrap().push(label);
        });
        (callback, log)
    }

    // An action that is already in a terminal "completed" state on the first
    // poll should return Ok immediately.
    #[tokio::test]
    async fn poll_action_immediate_success() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Ok(action) => assert_eq!(action.status, "completed"),
            other => panic!("expected Ok(completed action), got {other:?}"),
        }
    }

    // An action that reports "running" twice before completing should keep
    // polling and ultimately return Ok after exactly three polls.
    #[tokio::test]
    async fn poll_action_polls_then_succeeds() {
        let mock_server = MockServer::start().await;

        // Higher priority (lower number) + a call limit means the "running"
        // response is served for the first two polls, then exhausted.
        Mock::given(method("GET"))
            .and(path("/v1/actions/action-1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "running",
                "progress": "50"
            })))
            .up_to_n_times(2)
            .with_priority(1)
            .mount(&mock_server)
            .await;

        // Default-priority fallback that takes over once the "running" mock is
        // exhausted.
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            }),
        )
        .await;

        let (callback, log) = recording_callback();
        match poll(&mock_server, Some(callback)).await {
            Ok(action) => assert_eq!(action.status, "completed"),
            other => panic!("expected Ok(completed action), got {other:?}"),
        }

        // Two "running" polls, then the completing poll.
        assert_eq!(
            *log.lock().unwrap(),
            vec!["started", "polling", "polling", "polling", "completed"]
        );
    }

    // A "failed" status surfaces the action's error message as TaskFailed.
    #[tokio::test]
    async fn poll_action_failure_surfaces_error() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "upgrade",
                "status": "failed",
                "error": "upgrade failed: version conflict"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Err(CoreError::TaskFailed(msg)) => {
                assert_eq!(msg, "upgrade failed: version conflict");
            }
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    // A "cancelled" status is also terminal and surfaces as TaskFailed. When no
    // error is provided, the message falls back to "Action <status>".
    #[tokio::test]
    async fn poll_action_cancelled_surfaces_as_failed() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "cancelled"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Err(CoreError::TaskFailed(msg)) => assert_eq!(msg, "Action cancelled"),
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    // With a 1ms timeout and an action that never leaves "running", polling
    // must give up with TaskTimeout rather than loop forever.
    #[tokio::test]
    async fn poll_action_times_out() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "running",
                "progress": "10"
            }),
        )
        .await;

        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_millis(1),
            Duration::from_millis(5),
            None,
        )
        .await;

        match result {
            Err(CoreError::TaskTimeout(_)) => {}
            other => panic!("expected TaskTimeout, got {other:?}"),
        }
    }

    // The progress callback must observe the lifecycle in order: Started,
    // one Polling event, then Completed.
    #[tokio::test]
    async fn poll_action_emits_progress_events() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            }),
        )
        .await;

        let (callback, log) = recording_callback();
        let result = poll(&mock_server, Some(callback)).await;
        assert!(result.is_ok(), "expected Ok, got {result:?}");

        assert_eq!(*log.lock().unwrap(), vec!["started", "polling", "completed"]);
    }

    // A failing action must end the event stream with Failed, not Completed.
    #[tokio::test]
    async fn poll_action_emits_failed_event() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "upgrade",
                "status": "failed",
                "error": "boom"
            }),
        )
        .await;

        let (callback, log) = recording_callback();
        let result = poll(&mock_server, Some(callback)).await;
        assert!(
            matches!(result, Err(CoreError::TaskFailed(_))),
            "expected TaskFailed, got {result:?}"
        );

        assert_eq!(*log.lock().unwrap(), vec!["started", "polling", "failed"]);
    }
}