Skip to main content

async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10    config::Config,
11    error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16    pub(crate) http_client: reqwest::Client,
17    pub(crate) config: Config,
18    pub(crate) backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn with_config(config: Config) -> Self {
27        Self {
28            http_client: reqwest::Client::new(),
29            config,
30            backoff: backoff::ExponentialBackoff::default(),
31        }
32    }
33    pub fn with_api_key(mut self, api_key: String) -> Self {
34        self.config.set_api_key(api_key.into());
35        self
36    }
37
38    pub fn build(
39        http_client: reqwest::Client,
40        config: Config,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    pub fn files(&self) -> crate::operation::file::File<'_> {
51        crate::operation::file::File::new(self)
52    }
53
54    /// 获取当前实例的生成(Generation)信息
55    ///
56    /// 此方法属于操作级别,用于创建一个`Generation`对象,
57    /// 该对象表示当前实例的某一特定生成(代)信息
58    ///
59    /// # Returns
60    ///
61    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
62    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
63        crate::operation::generation::Generation::new(self)
64    }
65
66    /// 启发多模态对话的功能
67    ///
68    /// 该函数提供了与多模态对话相关的操作入口
69    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
70    ///
71    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
72    pub fn multi_modal_conversation(
73        &self,
74    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
75        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
76    }
77
78    /// 获取音频处理功能
79    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
80        crate::operation::audio::Audio::new(self)
81    }
82
83    /// 创建一个新的 Image2Image 操作实例
84    ///
85    /// 返回一个可用于执行图像到图像转换操作的 Image2Image 构建器
86    ///
87    /// # Examples
88    /// ```
89    /// use async_dashscope::Client;
90    /// let client = Client::new();
91    /// let image2image = client.image2image();
92    /// ```
93    ///
94    /// # Errors
95    /// 此方法本身不会返回错误,但后续操作可能会返回错误
96    pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
97        crate::operation::image2image::Image2Image::new(self)
98    }
99
100    /// 创建一个新的文本转图像操作实例
101    ///
102    /// # 返回
103    /// 返回一个 `Text2Image` 操作构建器,用于配置和执行文本转图像请求
104    ///
105    /// # 示例
106    /// ```
107    /// use async_dashscope::Client;
108    /// let client = Client::new();
109    /// let text2image = client.text2image();
110    /// ```
111    pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
112        crate::operation::text2image::Text2Image::new(self)
113    }
114
115    /// 创建一个新的文件操作实例
116    ///
117    /// # 返回
118    /// 返回一个可用于执行文件操作的 File 实例
119    ///
120    /// # 示例
121    /// ```
122    /// use async_dashscope::Client;
123    /// let client = Client::new();
124    /// let file = client.file();
125    /// ```
126    pub fn file(&self) -> crate::operation::file::File<'_> {
127        crate::operation::file::File::new(self)
128    }
129
130    pub fn http_client(&self) -> reqwest::Client {
131        self.http_client.clone()
132    }
133
134    /// 创建一个新的任务操作实例
135    ///
136    /// # 返回
137    /// 返回一个绑定到当前客户端的 `Task` 实例,用于执行任务相关操作
138    pub fn task(&self) -> crate::operation::task::Task<'_> {
139        crate::operation::task::Task::new(self)
140    }
141
142    /// 获取文本嵌入表示
143    ///
144    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
145    /// 它利用当前实例的上下文来生成文本的嵌入表示
146    ///
147    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
148    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
149    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
150        crate::operation::embeddings::Embeddings::new(self)
151    }
152
153    pub(crate) async fn post_stream<I, O>(
154        &self,
155        path: &str,
156        request: I,
157    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
158    where
159        I: Serialize + Debug,
160        O: DeserializeOwned + std::marker::Send + 'static,
161    {
162        self.post_stream_with_headers(path, request, self.config.headers())
163            .await
164    }
165
166    pub(crate) async fn post_stream_with_headers<I, O>(
167        &self,
168        path: &str,
169        request: I,
170        headers: reqwest::header::HeaderMap,
171    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
172    where
173        I: Serialize + Debug,
174        O: DeserializeOwned + std::marker::Send + 'static,
175    {
176        let event_source = self
177            .http_client
178            .post(self.config.url(path))
179            .headers(headers)
180            .json(&request)
181            .eventsource()?;
182
183        Ok(stream(event_source).await)
184    }
185
186    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
187    where
188        I: Serialize + Debug,
189        O: DeserializeOwned,
190    {
191        self.post_with_headers(path, request, self.config().headers())
192            .await
193    }
194
195    /// 发送带有自定义请求头的 POST 请求
196    ///
197    /// # 参数
198    /// * `path` - API 路径
199    /// * `request` - 要发送的请求体,需要实现 Serialize 和 Debug trait
200    /// * `headers` - 自定义请求头
201    ///
202    /// # 返回值
203    /// 返回解析后的响应数据,类型由调用方指定
204    ///
205    /// # Errors
206    /// 返回 DashScopeError 如果请求失败或响应解析失败
207    ///
208    /// # 注意事项
209    /// 此函数是 crate 内部使用的工具函数,不对外公开
210    pub(crate) async fn post_with_headers<I, O>(
211        &self,
212        path: &str,
213        request: I,
214        headers: reqwest::header::HeaderMap,
215    ) -> Result<O, DashScopeError>
216    where
217        I: Serialize + Debug,
218        O: DeserializeOwned,
219    {
220        let request_maker = || async {
221            Ok(self
222                .http_client
223                .post(self.config.url(path))
224                .headers(headers.clone())
225                .json(&request)
226                .build()?)
227        };
228
229        self.execute(request_maker).await
230    }
231
232    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
233    where
234        O: DeserializeOwned,
235        M: Fn() -> Fut,
236        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
237    {
238        let bytes = self.execute_raw(request_maker).await?;
239
240        let response: O = serde_json::from_slice(bytes.as_ref())
241            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
242
243        Ok(response)
244    }
245
246    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
247    where
248        M: Fn() -> Fut,
249        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
250    {
251        let client = self.http_client.clone();
252
253        backoff::future::retry(self.backoff.clone(), || async {
254            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
255            let response = client
256                .execute(request)
257                .await
258                .map_err(DashScopeError::Reqwest)
259                .map_err(backoff::Error::Permanent)?;
260
261            let status = response.status();
262            let bytes = response
263                .bytes()
264                .await
265                .map_err(DashScopeError::Reqwest)
266                .map_err(backoff::Error::Permanent)?;
267
268            // Deserialize response body from either error object or actual response object
269            if !status.is_success() {
270                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
271                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
272                    .map_err(backoff::Error::Permanent)?;
273
274                if status.as_u16() == 429 {
275                    // Rate limited retry...
276                    tracing::warn!("Rate limited: {}", api_error.message);
277                    return Err(backoff::Error::Transient {
278                        err: DashScopeError::ApiError(api_error),
279                        retry_after: None,
280                    });
281                } else {
282                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
283                        api_error,
284                    )));
285                }
286            }
287
288            Ok(bytes)
289        })
290        .await
291    }
292
293    pub fn config(&self) -> &Config {
294        &self.config
295    }
296    
297    /// 发送 multipart 表单 POST 请求
298    ///
299    /// # 参数
300    /// * `path` - API 路径
301    /// * `form_fn` - 返回 multipart 表单数据的函数
302    ///
303    /// # 返回值
304    /// 返回响应结果,类型由调用方指定
305    ///
306    /// # 错误
307    /// 返回 DashScopeError 如果请求失败或响应解析失败
308    ///
309    /// # 注意事项
310    /// 此函数是 crate 内部使用的工具函数,不对外公开
311    pub(crate) async fn post_multipart<O, F>(
312        &self,
313        path: &str,
314        form_fn: F,
315    ) -> Result<O, DashScopeError>
316    where
317        O: DeserializeOwned,
318        F: Fn() -> reqwest::multipart::Form,
319    {
320        let request_maker = || async {
321            let mut headers = self.config.headers();
322            headers.remove("Content-Type");
323            headers.remove("X-DashScope-OssResourceResolve");
324            Ok(self
325                .http_client
326                .post(self.config.url(path))
327                .headers(headers)
328                .multipart(form_fn())
329                .build()?)
330        };
331
332        self.execute(request_maker).await
333    }
334
335    /// 发送带查询参数的 GET 请求
336    ///
337    /// # 参数
338    /// * `path` - API 路径
339    /// * `params` - 查询参数
340    ///
341    /// # 返回值
342    /// 返回响应结果,类型由调用方指定
343    ///
344    /// # 错误
345    /// 返回 DashScopeError 如果请求失败或响应解析失败
346    ///
347    /// # 注意事项
348    /// 此函数是 crate 内部使用的工具函数,不对外公开
349    pub(crate) async fn get_with_params<O, P>(&self, path: &str, params: &P) -> Result<O, DashScopeError>
350    where
351        O: DeserializeOwned,
352        P: serde::Serialize + ?Sized,
353    {
354        let request_maker = || async {
355            Ok(self
356                .http_client
357                .get(self.config.url(path))
358                .headers(self.config.headers())
359                .query(params)
360                .build()?)
361        };
362
363        self.execute(request_maker).await
364    }
365
366    /// 发送 DELETE 请求
367    ///
368    /// # 参数
369    /// * `path` - API 路径
370    ///
371    /// # 返回值
372    /// 返回响应结果,类型由调用方指定
373    ///
374    /// # 错误
375    /// 返回 DashScopeError 如果请求失败或响应解析失败
376    ///
377    /// # 注意事项
378    /// 此函数是 crate 内部使用的工具函数,不对外公开
379    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, DashScopeError>
380    where
381        O: DeserializeOwned,
382    {
383        let request_maker = || async {
384            Ok(self
385                .http_client
386                .delete(self.config.url(path))
387                .headers(self.config.headers())
388                .build()?)
389        };
390
391        self.execute(request_maker).await
392    }
393}
394
395pub(crate) async fn stream<O>(
396    mut event_source: EventSource,
397) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
398where
399    O: DeserializeOwned + std::marker::Send + 'static,
400{
401    let stream = try_stream! {
402        while let Some(ev) = event_source.next().await {
403            match ev {
404                Err(e) => {
405                    Err(DashScopeError::StreamError(e.to_string()))?;
406                }
407                Ok(Event::Open) => continue,
408                Ok(Event::Message(message)) => {
409                    // First, deserialize to a generic JSON Value to inspect it without failing.
410                    let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
411                        Ok(val) => val,
412                        Err(e) => {
413                            Err(map_deserialization_error(e, message.data.as_bytes()))?;
414                            continue;
415                        }
416                    };
417
418                    // Now, deserialize from the `Value` to the target type `O`.
419                    let response = serde_json::from_value::<O>(json_value.clone())
420                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
421
422                    // Yield the successful message
423                    yield response;
424
425                    // Check for finish reason after sending the message.
426                    // This ensures the final message with "stop" is delivered.
427                    let finish_reason = json_value
428                        .pointer("/output/choices/0/finish_reason")
429                        .and_then(|v| v.as_str());
430
431                    if let Some("stop") = finish_reason {
432                        break;
433                    }
434                }
435            }
436        }
437        event_source.close();
438    };
439
440    Box::pin(stream)
441}
442
443#[cfg(test)]
444mod tests {
445    use crate::config::ConfigBuilder;
446
447    use super::*;
448
449    #[test]
450    pub fn test_config() {
451        let config = ConfigBuilder::default()
452            .api_key("test key")
453            .build()
454            .unwrap();
455        let client = Client::with_config(config);
456
457        for header in client.config.headers().iter() {
458            if header.0 == "authorization" {
459                assert_eq!(header.1, "Bearer test key");
460            }
461        }
462    }
463}