//! Progress polling for Redis Enterprise actions (`redisctl_core/enterprise/progress.rs`).
use crate::error::{CoreError, Result};
use redis_enterprise::EnterpriseClient;
use redis_enterprise::actions::Action;
use std::time::{Duration, Instant};

/// A progress notification emitted while waiting on a Redis Enterprise action.
///
/// Derives `PartialEq`/`Eq` so consumers (and tests) can compare observed
/// events by value instead of pattern-matching every field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnterpriseProgressEvent {
    /// Emitted once, before the first status fetch.
    Started {
        /// UID of the action being polled.
        action_uid: String,
    },
    /// Emitted after every status fetch, whatever the observed status.
    Polling {
        /// UID of the action being polled.
        action_uid: String,
        /// Raw status string reported by the server.
        status: String,
        /// Progress value reported by the server, if any.
        progress: Option<String>,
        /// Time since polling started, measured just before this fetch.
        elapsed: Duration,
    },
    /// Emitted when the action reports status `"completed"`.
    Completed {
        /// UID of the action that completed.
        action_uid: String,
    },
    /// Emitted when the action reports status `"failed"` or `"cancelled"`.
    Failed {
        /// UID of the action that failed.
        action_uid: String,
        /// Server-provided error message, or a fallback naming the status.
        error: String,
    },
}

/// Boxed callback that receives each [`EnterpriseProgressEvent`] synchronously,
/// on the task that is polling.
pub type EnterpriseProgressCallback = Box<dyn Fn(EnterpriseProgressEvent) + Send + Sync>;

44pub async fn poll_action(
87 client: &EnterpriseClient,
88 action_uid: &str,
89 timeout: Duration,
90 interval: Duration,
91 on_progress: Option<EnterpriseProgressCallback>,
92) -> Result<Action> {
93 let start = Instant::now();
94 let handler = client.actions();
95
96 emit(
97 &on_progress,
98 EnterpriseProgressEvent::Started {
99 action_uid: action_uid.to_string(),
100 },
101 );
102
103 loop {
104 let elapsed = start.elapsed();
105 if elapsed > timeout {
106 return Err(CoreError::TaskTimeout(timeout));
107 }
108
109 let action = handler.get(action_uid).await?;
110 let status = action.status.clone();
111
112 emit(
113 &on_progress,
114 EnterpriseProgressEvent::Polling {
115 action_uid: action_uid.to_string(),
116 status: status.clone(),
117 progress: action.progress.clone(),
118 elapsed,
119 },
120 );
121
122 match status.as_str() {
123 "completed" => {
124 emit(
125 &on_progress,
126 EnterpriseProgressEvent::Completed {
127 action_uid: action_uid.to_string(),
128 },
129 );
130 return Ok(action);
131 }
132 "failed" | "cancelled" => {
133 let error = action
134 .error
135 .clone()
136 .unwrap_or_else(|| format!("Action {}", status));
137
138 emit(
139 &on_progress,
140 EnterpriseProgressEvent::Failed {
141 action_uid: action_uid.to_string(),
142 error: error.clone(),
143 },
144 );
145 return Err(CoreError::TaskFailed(error));
146 }
147 _ => {
149 tokio::time::sleep(interval).await;
150 }
151 }
152 }
153}
154
155fn emit(callback: &Option<EnterpriseProgressCallback>, event: EnterpriseProgressEvent) {
157 if let Some(cb) = callback {
158 cb(event);
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// REST path of the action every test polls.
    const ACTION_PATH: &str = "/v1/actions/action-1";

    /// Builds a client pointed at the mock server with throwaway credentials.
    fn test_client(uri: String) -> EnterpriseClient {
        EnterpriseClient::builder()
            .base_url(uri)
            .username("test-user".to_string())
            .password("test-pass".to_string())
            .insecure(true)
            .build()
            .unwrap()
    }

    /// Mounts a mock that answers every `GET` of the test action with `body`.
    async fn mount_action(server: &MockServer, body: serde_json::Value) {
        Mock::given(method("GET"))
            .and(path(ACTION_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(server)
            .await;
    }

    /// Polls `action-1` against `server` with a generous timeout and a short interval.
    async fn poll(
        server: &MockServer,
        on_progress: Option<EnterpriseProgressCallback>,
    ) -> Result<Action> {
        let client = test_client(server.uri());
        poll_action(
            &client,
            "action-1",
            Duration::from_secs(5),
            Duration::from_millis(10),
            on_progress,
        )
        .await
    }

    /// Returns a callback that records one label per event, in order, plus
    /// the shared log it appends to.
    fn recording_callback() -> (EnterpriseProgressCallback, Arc<Mutex<Vec<&'static str>>>) {
        let events = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);
        let callback: EnterpriseProgressCallback = Box::new(move |event| {
            let label = match event {
                EnterpriseProgressEvent::Started { .. } => "started",
                EnterpriseProgressEvent::Polling { .. } => "polling",
                EnterpriseProgressEvent::Completed { .. } => "completed",
                EnterpriseProgressEvent::Failed { .. } => "failed",
            };
            sink.lock().unwrap().push(label);
        });
        (callback, events)
    }

    /// An action that is already complete resolves on the first poll.
    #[tokio::test]
    async fn poll_action_immediate_success() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Ok(action) => assert_eq!(action.status, "completed"),
            other => panic!("expected Ok(completed action), got {other:?}"),
        }
    }

    /// A running action is re-polled until it completes; both "running"
    /// responses must actually be consumed.
    #[tokio::test]
    async fn poll_action_polls_then_succeeds() {
        let mock_server = MockServer::start().await;

        // Higher-priority mock serves "running" exactly twice, then falls
        // through to the "completed" mock below.
        Mock::given(method("GET"))
            .and(path(ACTION_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "running",
                "progress": "50"
            })))
            .up_to_n_times(2)
            .expect(2)
            .with_priority(1)
            .mount(&mock_server)
            .await;

        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Ok(action) => assert_eq!(action.status, "completed"),
            other => panic!("expected Ok(completed action), got {other:?}"),
        }
    }

    /// A failed action surfaces the server's error message verbatim.
    #[tokio::test]
    async fn poll_action_failure_surfaces_error() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "upgrade",
                "status": "failed",
                "error": "upgrade failed: version conflict"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Err(CoreError::TaskFailed(msg)) => {
                assert_eq!(msg, "upgrade failed: version conflict");
            }
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    /// A cancelled action without an error field is reported as a failure
    /// that names the status.
    #[tokio::test]
    async fn poll_action_cancelled_surfaces_as_failed() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "cancelled"
            }),
        )
        .await;

        match poll(&mock_server, None).await {
            Err(CoreError::TaskFailed(msg)) => {
                assert!(msg.contains("cancelled"), "unexpected message: {msg}");
            }
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    /// An action that never leaves "running" hits the timeout.
    #[tokio::test]
    async fn poll_action_times_out() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "running",
                "progress": "10"
            }),
        )
        .await;

        let client = test_client(mock_server.uri());
        let result = poll_action(
            &client,
            "action-1",
            Duration::from_millis(1),
            Duration::from_millis(5),
            None,
        )
        .await;

        match result {
            Err(CoreError::TaskTimeout(_)) => {}
            other => panic!("expected TaskTimeout, got {other:?}"),
        }
    }

    /// A successful poll emits exactly started → polling → completed, in order.
    #[tokio::test]
    async fn poll_action_emits_progress_events() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "flush",
                "status": "completed",
                "progress": "100"
            }),
        )
        .await;

        let (callback, events) = recording_callback();
        let result = poll(&mock_server, Some(callback)).await;

        assert!(result.is_ok(), "expected Ok, got {result:?}");
        assert_eq!(
            *events.lock().unwrap(),
            vec!["started", "polling", "completed"]
        );
    }

    /// A failing poll emits exactly started → polling → failed, in order.
    #[tokio::test]
    async fn poll_action_emits_failed_event() {
        let mock_server = MockServer::start().await;
        mount_action(
            &mock_server,
            json!({
                "action_uid": "action-1",
                "name": "upgrade",
                "status": "failed",
                "error": "upgrade failed: version conflict"
            }),
        )
        .await;

        let (callback, events) = recording_callback();
        let result = poll(&mock_server, Some(callback)).await;

        assert!(
            matches!(result, Err(CoreError::TaskFailed(_))),
            "expected TaskFailed, got {result:?}"
        );
        assert_eq!(*events.lock().unwrap(), vec!["started", "polling", "failed"]);
    }
}