async_dashscope/operation/
task.rs

1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3use tokio::time::sleep;
4
5use crate::error::{DashScopeError, Result};
6use crate::{Client, operation::common::TaskStatus};
7const TASK_PATH: &str = "/tasks";
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct TaskOutput {
11    pub task_id: String,
12    pub task_status: TaskStatus,
13    pub submit_time: String,
14    pub scheduled_time: Option<String>,
15    pub end_time: Option<String>,
16    pub image_url: Option<String>,
17    pub code: Option<String>,
18    pub message: Option<String>,
19}
20
21#[derive(Serialize, Deserialize, Debug, Clone)]
22pub struct TaskResult {
23    pub request_id: String,
24    pub output: TaskOutput,
25    pub usage: Option<ImageUsage>,
26}
27
28#[derive(Serialize, Deserialize, Debug, Clone)]
29pub struct ImageUsage {
30    pub image_count: u32,
31}
32
33pub struct Task<'a> {
34    client: &'a Client,
35}
36
37impl<'a> Task<'a> {
38    pub fn new(client: &'a Client) -> Self {
39        Self { client }
40    }
41
42    pub(crate) async fn query(&self, task_id: &str) -> Result<TaskResult> {
43        let http_client = self.client.http_client();
44        let headers = self.client.config().headers();
45        let req = http_client
46            .get(
47                self.client
48                    .config()
49                    .url(format!("{}/{}", TASK_PATH, task_id).as_str()),
50            )
51            .headers(headers)
52            .build()?;
53
54        let resp = http_client.execute(req).await?.bytes().await?;
55
56        // 检查响应是否为空
57        if resp.is_empty() {
58            return Err(DashScopeError::ApiError(crate::error::ApiError {
59                message: "API returned empty response".to_string(),
60                request_id: None,
61                code: Some("EmptyResponse".to_string()),
62            }));
63        }
64
65        let raw_response_str = String::from_utf8_lossy(resp.as_ref());
66        println!("Raw API response: {}", raw_response_str);
67
68        let resp_json = serde_json::from_slice::<TaskResult>(resp.as_ref()).map_err(|e| {
69            crate::error::DashScopeError::JSONDeserialize {
70                source: e,
71                raw_response: resp.to_vec(),
72            }
73        })?;
74
75        Ok(resp_json)
76    }
77
78    /// 轮询任务状态
79    ///
80    /// 该方法会定期查询任务状态,直到任务完成、失败或达到最大轮询次数。
81    ///
82    /// # Arguments
83    /// * `task_id` - 要轮询的任务ID
84    /// * `interval` - 每次轮询之间的间隔时间(秒)
85    /// * `max_attempts` - 最大轮询尝试次数
86    ///
87    /// # Returns
88    /// 返回 `Result<TaskResult>`,包含最终任务结果或错误
89    ///
90    /// # Errors
91    /// - 当任务在最大轮询次数内未完成时返回 `TimeoutError`
92    /// - 当遇到不可重试的错误(如配置错误)时返回相应错误
93    /// - 当API返回空响应或格式错误时会继续重试
94    ///
95    /// # Notes
96    /// - 对于可恢复的错误(如网络问题、临时API错误)会自动重试
97    /// - 每次轮询会打印当前状态信息到标准输出
98    pub async fn poll_task_status(
99        &self,
100        task_id: &str,
101        interval: u64,
102        max_attempts: u32,
103    ) -> Result<TaskResult> {
104        for attempt in 1..=max_attempts {
105            println!("第 {} 次轮询...", attempt);
106
107            match self.query(task_id).await {
108                Ok(result) => {
109                    let task_status = &result.output.task_status;
110
111                    // 如果任务完成或失败,返回结果
112                    match task_status {
113                        TaskStatus::Succeeded => {
114                            return Ok(result);
115                        }
116                        TaskStatus::Failed => {
117                            return Ok(result);
118                        }
119                        TaskStatus::Pending | TaskStatus::Running => {
120                            // 继续轮询
121                            sleep(Duration::from_secs(interval)).await;
122                        }
123                        TaskStatus::Canceled | TaskStatus::Unknown => {
124                            sleep(Duration::from_secs(interval)).await;
125                        }
126                    }
127                }
128                Err(e) => {
129                    // 区分不同类型的错误
130                    match &e {
131                        DashScopeError::JSONDeserialize {
132                            source: _,
133                            raw_response: _,
134                        } => {
135                            // JSON 反序列化错误,可能是 API 响应格式问题
136                            // 继续重试,可能是临时问题
137                            sleep(Duration::from_secs(interval)).await;
138                        }
139                        DashScopeError::Reqwest(_) => {
140                            // 网络错误,继续重试
141                            sleep(Duration::from_secs(interval)).await;
142                        }
143                        DashScopeError::ApiError(api_error) => {
144                            // API 错误,检查是否是空响应错误
145                            if api_error.code.as_deref() == Some("EmptyResponse") {
146                                sleep(Duration::from_secs(interval)).await;
147                            } else {
148                                // 其他 API 错误,可能是配置问题,直接返回错误
149                                return Err(e);
150                            }
151                        }
152                        _ => {
153                            // 其他错误,可能是配置问题,直接返回错误
154                            return Err(e);
155                        }
156                    }
157                }
158            }
159        }
160
161        // 超过最大轮询次数
162        Err(DashScopeError::TimeoutError(
163            "轮询超时,任务未在预期时间内完成".to_string(),
164        ))
165    }
166}