async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10    config::Config,
11    error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16    http_client: reqwest::Client,
17    config: Config,
18    backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn with_config(config: Config) -> Self {
27        Self {
28            http_client: reqwest::Client::new(),
29            config,
30            backoff: backoff::ExponentialBackoff::default(),
31        }
32    }
33    pub fn with_api_key(mut self, api_key: String) -> Self {
34        self.config.set_api_key(api_key.into());
35        self
36    }
37
38    pub fn build(
39        http_client: reqwest::Client,
40        config: Config,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    pub fn files(&self) -> crate::operation::file::File<'_> {
51        crate::operation::file::File::new(self)
52    }
53
54    /// 获取当前实例的生成(Generation)信息
55    ///
56    /// 此方法属于操作级别,用于创建一个`Generation`对象,
57    /// 该对象表示当前实例的某一特定生成(代)信息
58    ///
59    /// # Returns
60    ///
61    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
62    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
63        crate::operation::generation::Generation::new(self)
64    }
65
66    /// 启发多模态对话的功能
67    ///
68    /// 该函数提供了与多模态对话相关的操作入口
69    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
70    ///
71    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
72    pub fn multi_modal_conversation(
73        &self,
74    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
75        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
76    }
77
78    /// 获取音频处理功能
79    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
80        crate::operation::audio::Audio::new(self)
81    }
82
83    /// 创建一个新的 Image2Image 操作实例
84    ///
85    /// 返回一个可用于执行图像到图像转换操作的 Image2Image 构建器
86    ///
87    /// # Examples
88    /// ```
89    /// let client = Client::new();
90    /// let image2image = client.image2image();
91    /// ```
92    ///
93    /// # Errors
94    /// 此方法本身不会返回错误,但后续操作可能会返回错误
95    pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
96        crate::operation::image2image::Image2Image::new(self)
97    }
98
99    /// 创建一个新的文本转图像操作实例
100    ///
101    /// # 返回
102    /// 返回一个 `Text2Image` 操作构建器,用于配置和执行文本转图像请求
103    ///
104    /// # 示例
105    /// ```
106    /// let client = Client::new();
107    /// let text2image = client.text2image();
108    /// ```
109    pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
110        crate::operation::text2image::Text2Image::new(self)
111    }
112
113    /// 创建一个新的文件操作实例
114    ///
115    /// # 返回
116    /// 返回一个可用于执行文件操作的 File 实例
117    ///
118    /// # 示例
119    /// ```
120    /// let client = Client::new();
121    /// let file = client.file();
122    /// ```
123    pub fn file(&self) -> crate::operation::file::File<'_> {
124        crate::operation::file::File::new(self)
125    }
126
127    pub fn http_client(&self) -> reqwest::Client {
128        self.http_client.clone()
129    }
130
131    /// 创建一个新的任务操作实例
132    ///
133    /// # 返回
134    /// 返回一个绑定到当前客户端的 `Task` 实例,用于执行任务相关操作
135    pub fn task(&self) -> crate::operation::task::Task<'_> {
136        crate::operation::task::Task::new(self)
137    }
138
139    /// 获取文本嵌入表示
140    ///
141    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
142    /// 它利用当前实例的上下文来生成文本的嵌入表示
143    ///
144    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
145    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
146    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
147        crate::operation::embeddings::Embeddings::new(self)
148    }
149
150    pub(crate) async fn post_stream<I, O>(
151        &self,
152        path: &str,
153        request: I,
154    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
155    where
156        I: Serialize + Debug,
157        O: DeserializeOwned + std::marker::Send + 'static,
158    {
159        self.post_stream_with_headers(path, request, self.config.headers())
160            .await
161    }
162
163    pub(crate) async fn post_stream_with_headers<I, O>(
164        &self,
165        path: &str,
166        request: I,
167        headers: reqwest::header::HeaderMap,
168    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
169    where
170        I: Serialize + Debug,
171        O: DeserializeOwned + std::marker::Send + 'static,
172    {
173        let event_source = self
174            .http_client
175            .post(self.config.url(path))
176            .headers(headers)
177            .json(&request)
178            .eventsource()?;
179
180        Ok(stream(event_source).await)
181    }
182
183    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
184    where
185        I: Serialize + Debug,
186        O: DeserializeOwned,
187    {
188        self.post_with_headers(path, request, self.config().headers())
189            .await
190    }
191
192    /// 发送带有自定义请求头的 POST 请求
193    ///
194    /// # 参数
195    /// * `path` - API 路径
196    /// * `request` - 要发送的请求体,需要实现 Serialize 和 Debug trait
197    /// * `headers` - 自定义请求头
198    ///
199    /// # 返回值
200    /// 返回解析后的响应数据,类型由调用方指定
201    ///
202    /// # Errors
203    /// 返回 DashScopeError 如果请求失败或响应解析失败
204    ///
205    /// # 注意事项
206    /// 此函数是 crate 内部使用的工具函数,不对外公开
207    pub(crate) async fn post_with_headers<I, O>(
208        &self,
209        path: &str,
210        request: I,
211        headers: reqwest::header::HeaderMap,
212    ) -> Result<O, DashScopeError>
213    where
214        I: Serialize + Debug,
215        O: DeserializeOwned,
216    {
217        let request_maker = || async {
218            Ok(self
219                .http_client
220                .post(self.config.url(path))
221                .headers(headers.clone())
222                .json(&request)
223                .build()?)
224        };
225
226        self.execute(request_maker).await
227    }
228
229    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
230    where
231        O: DeserializeOwned,
232        M: Fn() -> Fut,
233        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
234    {
235        let bytes = self.execute_raw(request_maker).await?;
236
237        let response: O = serde_json::from_slice(bytes.as_ref())
238            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
239
240        Ok(response)
241    }
242
243    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
244    where
245        M: Fn() -> Fut,
246        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
247    {
248        let client = self.http_client.clone();
249
250        backoff::future::retry(self.backoff.clone(), || async {
251            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
252            let response = client
253                .execute(request)
254                .await
255                .map_err(DashScopeError::Reqwest)
256                .map_err(backoff::Error::Permanent)?;
257
258            let status = response.status();
259            let bytes = response
260                .bytes()
261                .await
262                .map_err(DashScopeError::Reqwest)
263                .map_err(backoff::Error::Permanent)?;
264
265            // Deserialize response body from either error object or actual response object
266            if !status.is_success() {
267                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
268                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
269                    .map_err(backoff::Error::Permanent)?;
270
271                if status.as_u16() == 429 {
272                    // Rate limited retry...
273                    tracing::warn!("Rate limited: {}", api_error.message);
274                    return Err(backoff::Error::Transient {
275                        err: DashScopeError::ApiError(api_error),
276                        retry_after: None,
277                    });
278                } else {
279                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
280                        api_error,
281                    )));
282                }
283            }
284
285            Ok(bytes)
286        })
287        .await
288    }
289
290    pub fn config(&self) -> &Config {
291        &self.config
292    }
293    
294    /// 发送 multipart 表单 POST 请求
295    ///
296    /// # 参数
297    /// * `path` - API 路径
298    /// * `form_fn` - 返回 multipart 表单数据的函数
299    ///
300    /// # 返回值
301    /// 返回响应结果,类型由调用方指定
302    ///
303    /// # 错误
304    /// 返回 DashScopeError 如果请求失败或响应解析失败
305    ///
306    /// # 注意事项
307    /// 此函数是 crate 内部使用的工具函数,不对外公开
308    pub(crate) async fn post_multipart<O, F>(
309        &self,
310        path: &str,
311        form_fn: F,
312    ) -> Result<O, DashScopeError>
313    where
314        O: DeserializeOwned,
315        F: Fn() -> reqwest::multipart::Form,
316    {
317        let request_maker = || async {
318            let mut headers = self.config.headers();
319            headers.remove("Content-Type");
320            headers.remove("X-DashScope-OssResourceResolve");
321            Ok(self
322                .http_client
323                .post(self.config.url(path))
324                .headers(headers)
325                .multipart(form_fn())
326                .build()?)
327        };
328
329        self.execute(request_maker).await
330    }
331
332    /// 发送带查询参数的 GET 请求
333    ///
334    /// # 参数
335    /// * `path` - API 路径
336    /// * `params` - 查询参数
337    ///
338    /// # 返回值
339    /// 返回响应结果,类型由调用方指定
340    ///
341    /// # 错误
342    /// 返回 DashScopeError 如果请求失败或响应解析失败
343    ///
344    /// # 注意事项
345    /// 此函数是 crate 内部使用的工具函数,不对外公开
346    pub(crate) async fn get_with_params<O, P>(&self, path: &str, params: &P) -> Result<O, DashScopeError>
347    where
348        O: DeserializeOwned,
349        P: serde::Serialize + ?Sized,
350    {
351        let request_maker = || async {
352            Ok(self
353                .http_client
354                .get(self.config.url(path))
355                .headers(self.config.headers())
356                .query(params)
357                .build()?)
358        };
359
360        self.execute(request_maker).await
361    }
362
363    /// 发送 DELETE 请求
364    ///
365    /// # 参数
366    /// * `path` - API 路径
367    ///
368    /// # 返回值
369    /// 返回响应结果,类型由调用方指定
370    ///
371    /// # 错误
372    /// 返回 DashScopeError 如果请求失败或响应解析失败
373    ///
374    /// # 注意事项
375    /// 此函数是 crate 内部使用的工具函数,不对外公开
376    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, DashScopeError>
377    where
378        O: DeserializeOwned,
379    {
380        let request_maker = || async {
381            Ok(self
382                .http_client
383                .delete(self.config.url(path))
384                .headers(self.config.headers())
385                .build()?)
386        };
387
388        self.execute(request_maker).await
389    }
390}
391
392pub(crate) async fn stream<O>(
393    mut event_source: EventSource,
394) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
395where
396    O: DeserializeOwned + std::marker::Send + 'static,
397{
398    let stream = try_stream! {
399        while let Some(ev) = event_source.next().await {
400            match ev {
401                Err(e) => {
402                    Err(DashScopeError::StreamError(e.to_string()))?;
403                }
404                Ok(Event::Open) => continue,
405                Ok(Event::Message(message)) => {
406                    // First, deserialize to a generic JSON Value to inspect it without failing.
407                    let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
408                        Ok(val) => val,
409                        Err(e) => {
410                            Err(map_deserialization_error(e, message.data.as_bytes()))?;
411                            continue;
412                        }
413                    };
414
415                    // Now, deserialize from the `Value` to the target type `O`.
416                    let response = serde_json::from_value::<O>(json_value.clone())
417                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
418
419                    // Yield the successful message
420                    yield response;
421
422                    // Check for finish reason after sending the message.
423                    // This ensures the final message with "stop" is delivered.
424                    let finish_reason = json_value
425                        .pointer("/output/choices/0/finish_reason")
426                        .and_then(|v| v.as_str());
427
428                    if let Some("stop") = finish_reason {
429                        break;
430                    }
431                }
432            }
433        }
434        event_source.close();
435    };
436
437    Box::pin(stream)
438}
439
440#[cfg(test)]
441mod tests {
442    use crate::config::ConfigBuilder;
443
444    use super::*;
445
446    #[test]
447    pub fn test_config() {
448        let config = ConfigBuilder::default()
449            .api_key("test key")
450            .build()
451            .unwrap();
452        let client = Client::with_config(config);
453
454        for header in client.config.headers().iter() {
455            if header.0 == "authorization" {
456                assert_eq!(header.1, "Bearer test key");
457            }
458        }
459    }
460}