async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10    config::Config,
11    error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16    http_client: reqwest::Client,
17    config: Config,
18    backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn with_config(config: Config) -> Self {
27        Self {
28            http_client: reqwest::Client::new(),
29            config,
30            backoff: backoff::ExponentialBackoff::default(),
31        }
32    }
33    pub fn with_api_key(mut self, api_key: String) -> Self {
34        self.config.set_api_key(api_key.into());
35        self
36    }
37
38    pub fn build(
39        http_client: reqwest::Client,
40        config: Config,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// 获取当前实例的生成(Generation)信息
51    ///
52    /// 此方法属于操作级别,用于创建一个`Generation`对象,
53    /// 该对象表示当前实例的某一特定生成(代)信息
54    ///
55    /// # Returns
56    ///
57    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
58    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59        crate::operation::generation::Generation::new(self)
60    }
61
62    /// 启发多模态对话的功能
63    ///
64    /// 该函数提供了与多模态对话相关的操作入口
65    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
66    ///
67    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
68    pub fn multi_modal_conversation(
69        &self,
70    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72    }
73
74    /// 获取音频处理功能
75    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76        crate::operation::audio::Audio::new(self)
77    }
78
79    /// 创建一个新的 Image2Image 操作实例
80    ///
81    /// 返回一个可用于执行图像到图像转换操作的 Image2Image 构建器
82    ///
83    /// # Examples
84    /// ```
85    /// let client = Client::new();
86    /// let image2image = client.image2image();
87    /// ```
88    ///
89    /// # Errors
90    /// 此方法本身不会返回错误,但后续操作可能会返回错误
91    pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
92        crate::operation::image2image::Image2Image::new(self)
93    }
94
95    /// 创建一个新的文本转图像操作实例
96    ///
97    /// # 返回
98    /// 返回一个 `Text2Image` 操作构建器,用于配置和执行文本转图像请求
99    ///
100    /// # 示例
101    /// ```
102    /// let client = Client::new();
103    /// let text2image = client.text2image();
104    /// ```
105    pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
106        crate::operation::text2image::Text2Image::new(self)
107    }
108
109    pub fn http_client(&self) -> reqwest::Client {
110        self.http_client.clone()
111    }
112
113    /// 创建一个新的任务操作实例
114    ///
115    /// # 返回
116    /// 返回一个绑定到当前客户端的 `Task` 实例,用于执行任务相关操作
117    pub fn task(&self) -> crate::operation::task::Task<'_> {
118        crate::operation::task::Task::new(self)
119    }
120
121    /// 获取文本嵌入表示
122    ///
123    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
124    /// 它利用当前实例的上下文来生成文本的嵌入表示
125    ///
126    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
127    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
128    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
129        crate::operation::embeddings::Embeddings::new(self)
130    }
131
132    pub(crate) async fn post_stream<I, O>(
133        &self,
134        path: &str,
135        request: I,
136    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
137    where
138        I: Serialize,
139        O: DeserializeOwned + std::marker::Send + 'static,
140    {
141        let event_source = self
142            .http_client
143            .post(self.config.url(path))
144            .headers(self.config.headers())
145            .json(&request)
146            .eventsource()?;
147
148        Ok(stream(event_source).await)
149    }
150
151    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
152    where
153        I: Serialize + Debug,
154        O: DeserializeOwned,
155    {
156        self.post_with_headers(path, request, self.config().headers())
157            .await
158    }
159
160    /// 发送带有自定义请求头的 POST 请求
161    ///
162    /// # 参数
163    /// * `path` - API 路径
164    /// * `request` - 要发送的请求体,需要实现 Serialize 和 Debug trait
165    /// * `headers` - 自定义请求头
166    ///
167    /// # 返回值
168    /// 返回解析后的响应数据,类型由调用方指定
169    ///
170    /// # Errors
171    /// 返回 DashScopeError 如果请求失败或响应解析失败
172    ///
173    /// # 注意事项
174    /// 此函数是 crate 内部使用的工具函数,不对外公开
175    pub(crate) async fn post_with_headers<I, O>(
176        &self,
177        path: &str,
178        request: I,
179        headers: reqwest::header::HeaderMap,
180    ) -> Result<O, DashScopeError>
181    where
182        I: Serialize + Debug,
183        O: DeserializeOwned,
184    {
185        let request_maker = || async {
186            Ok(self
187                .http_client
188                .post(self.config.url(path))
189                .headers(headers.clone())
190                .json(&request)
191                .build()?)
192        };
193
194        self.execute(request_maker).await
195    }
196
197    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
198    where
199        O: DeserializeOwned,
200        M: Fn() -> Fut,
201        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
202    {
203        let bytes = self.execute_raw(request_maker).await?;
204
205        let response: O = serde_json::from_slice(bytes.as_ref())
206            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
207
208        Ok(response)
209    }
210
211    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
212    where
213        M: Fn() -> Fut,
214        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
215    {
216        let client = self.http_client.clone();
217
218        backoff::future::retry(self.backoff.clone(), || async {
219            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
220            let response = client
221                .execute(request)
222                .await
223                .map_err(DashScopeError::Reqwest)
224                .map_err(backoff::Error::Permanent)?;
225
226            let status = response.status();
227            let bytes = response
228                .bytes()
229                .await
230                .map_err(DashScopeError::Reqwest)
231                .map_err(backoff::Error::Permanent)?;
232
233            // Deserialize response body from either error object or actual response object
234            if !status.is_success() {
235                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
236                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
237                    .map_err(backoff::Error::Permanent)?;
238
239                if status.as_u16() == 429 {
240                    // Rate limited retry...
241                    tracing::warn!("Rate limited: {}", api_error.message);
242                    return Err(backoff::Error::Transient {
243                        err: DashScopeError::ApiError(api_error),
244                        retry_after: None,
245                    });
246                } else {
247                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
248                        api_error,
249                    )));
250                }
251            }
252
253            Ok(bytes)
254        })
255        .await
256    }
257
258    pub fn config(&self) -> &Config {
259        &self.config
260    }
261}
262
263pub(crate) async fn stream<O>(
264    mut event_source: EventSource,
265) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
266where
267    O: DeserializeOwned + std::marker::Send + 'static,
268{
269    let stream = try_stream! {
270        while let Some(ev) = event_source.next().await {
271            match ev {
272                Err(e) => {
273                    Err(DashScopeError::StreamError(e.to_string()))?;
274                }
275                Ok(Event::Open) => continue,
276                Ok(Event::Message(message)) => {
277                    // First, deserialize to a generic JSON Value to inspect it without failing.
278                    let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
279                        Ok(val) => val,
280                        Err(e) => {
281                            Err(map_deserialization_error(e, message.data.as_bytes()))?;
282                            continue;
283                        }
284                    };
285
286                    // Now, deserialize from the `Value` to the target type `O`.
287                    let response = serde_json::from_value::<O>(json_value.clone())
288                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
289
290                    // Yield the successful message
291                    yield response;
292
293                    // Check for finish reason after sending the message.
294                    // This ensures the final message with "stop" is delivered.
295                    let finish_reason = json_value
296                        .pointer("/output/choices/0/finish_reason")
297                        .and_then(|v| v.as_str());
298
299                    if let Some("stop") = finish_reason {
300                        break;
301                    }
302                }
303            }
304        }
305        event_source.close();
306    };
307
308    Box::pin(stream)
309}
310
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// A client built from a config carrying an API key must send that key as a bearer
    /// token. The header is required to be present, so the test cannot pass vacuously.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("config headers should contain an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}