async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10    config::Config,
11    error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16    http_client: reqwest::Client,
17    config: Config,
18    backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn with_config(config: Config) -> Self {
27        Self {
28            http_client: reqwest::Client::new(),
29            config,
30            backoff: backoff::ExponentialBackoff::default(),
31        }
32    }
33    pub fn with_api_key(mut self, api_key: String) -> Self {
34        self.config.set_api_key(api_key.into());
35        self
36    }
37
38    pub fn build(
39        http_client: reqwest::Client,
40        config: Config,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// 获取当前实例的生成(Generation)信息
51    ///
52    /// 此方法属于操作级别,用于创建一个`Generation`对象,
53    /// 该对象表示当前实例的某一特定生成(代)信息
54    ///
55    /// # Returns
56    ///
57    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
58    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59        crate::operation::generation::Generation::new(self)
60    }
61
62    /// 启发多模态对话的功能
63    ///
64    /// 该函数提供了与多模态对话相关的操作入口
65    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
66    ///
67    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
68    pub fn multi_modal_conversation(
69        &self,
70    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72    }
73
74    /// 获取音频处理功能
75    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76        crate::operation::audio::Audio::new(self)
77    }
78
79    /// 创建一个新的 Image2Image 操作实例
80    ///
81    /// 返回一个可用于执行图像到图像转换操作的 Image2Image 构建器
82    ///
83    /// # Examples
84    /// ```
85    /// let client = Client::new();
86    /// let image2image = client.image2image();
87    /// ```
88    ///
89    /// # Errors
90    /// 此方法本身不会返回错误,但后续操作可能会返回错误
91    pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
92        crate::operation::image2image::Image2Image::new(self)
93    }
94
95    /// 创建一个新的文本转图像操作实例
96    ///
97    /// # 返回
98    /// 返回一个 `Text2Image` 操作构建器,用于配置和执行文本转图像请求
99    ///
100    /// # 示例
101    /// ```
102    /// let client = Client::new();
103    /// let text2image = client.text2image();
104    /// ```
105    pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
106        crate::operation::text2image::Text2Image::new(self)
107    }
108
109    pub fn http_client(&self) -> reqwest::Client {
110        self.http_client.clone()
111    }
112
113    /// 创建一个新的任务操作实例
114    ///
115    /// # 返回
116    /// 返回一个绑定到当前客户端的 `Task` 实例,用于执行任务相关操作
117    pub fn task(&self) -> crate::operation::task::Task<'_> {
118        crate::operation::task::Task::new(self)
119    }
120
121    /// 获取文本嵌入表示
122    ///
123    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
124    /// 它利用当前实例的上下文来生成文本的嵌入表示
125    ///
126    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
127    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
128    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
129        crate::operation::embeddings::Embeddings::new(self)
130    }
131
132    pub(crate) async fn post_stream<I, O>(
133        &self,
134        path: &str,
135        request: I,
136    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
137    where
138        I: Serialize + Debug,
139        O: DeserializeOwned + std::marker::Send + 'static,
140    {
141        self.post_stream_with_headers(path, request, self.config.headers())
142            .await
143    }
144
145    pub(crate) async fn post_stream_with_headers<I, O>(
146        &self,
147        path: &str,
148        request: I,
149        headers: reqwest::header::HeaderMap,
150    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
151    where
152        I: Serialize + Debug,
153        O: DeserializeOwned + std::marker::Send + 'static,
154    {
155        let event_source = self
156            .http_client
157            .post(self.config.url(path))
158            .headers(headers)
159            .json(&request)
160            .eventsource()?;
161
162        Ok(stream(event_source).await)
163    }
164
165    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
166    where
167        I: Serialize + Debug,
168        O: DeserializeOwned,
169    {
170        self.post_with_headers(path, request, self.config().headers())
171            .await
172    }
173
174    /// 发送带有自定义请求头的 POST 请求
175    ///
176    /// # 参数
177    /// * `path` - API 路径
178    /// * `request` - 要发送的请求体,需要实现 Serialize 和 Debug trait
179    /// * `headers` - 自定义请求头
180    ///
181    /// # 返回值
182    /// 返回解析后的响应数据,类型由调用方指定
183    ///
184    /// # Errors
185    /// 返回 DashScopeError 如果请求失败或响应解析失败
186    ///
187    /// # 注意事项
188    /// 此函数是 crate 内部使用的工具函数,不对外公开
189    pub(crate) async fn post_with_headers<I, O>(
190        &self,
191        path: &str,
192        request: I,
193        headers: reqwest::header::HeaderMap,
194    ) -> Result<O, DashScopeError>
195    where
196        I: Serialize + Debug,
197        O: DeserializeOwned,
198    {
199        let request_maker = || async {
200            Ok(self
201                .http_client
202                .post(self.config.url(path))
203                .headers(headers.clone())
204                .json(&request)
205                .build()?)
206        };
207
208        self.execute(request_maker).await
209    }
210
211    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
212    where
213        O: DeserializeOwned,
214        M: Fn() -> Fut,
215        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
216    {
217        let bytes = self.execute_raw(request_maker).await?;
218
219        let response: O = serde_json::from_slice(bytes.as_ref())
220            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
221
222        Ok(response)
223    }
224
225    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
226    where
227        M: Fn() -> Fut,
228        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
229    {
230        let client = self.http_client.clone();
231
232        backoff::future::retry(self.backoff.clone(), || async {
233            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
234            let response = client
235                .execute(request)
236                .await
237                .map_err(DashScopeError::Reqwest)
238                .map_err(backoff::Error::Permanent)?;
239
240            let status = response.status();
241            let bytes = response
242                .bytes()
243                .await
244                .map_err(DashScopeError::Reqwest)
245                .map_err(backoff::Error::Permanent)?;
246
247            // Deserialize response body from either error object or actual response object
248            if !status.is_success() {
249                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
250                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
251                    .map_err(backoff::Error::Permanent)?;
252
253                if status.as_u16() == 429 {
254                    // Rate limited retry...
255                    tracing::warn!("Rate limited: {}", api_error.message);
256                    return Err(backoff::Error::Transient {
257                        err: DashScopeError::ApiError(api_error),
258                        retry_after: None,
259                    });
260                } else {
261                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
262                        api_error,
263                    )));
264                }
265            }
266
267            Ok(bytes)
268        })
269        .await
270    }
271
272    pub fn config(&self) -> &Config {
273        &self.config
274    }
275}
276
277pub(crate) async fn stream<O>(
278    mut event_source: EventSource,
279) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
280where
281    O: DeserializeOwned + std::marker::Send + 'static,
282{
283    let stream = try_stream! {
284        while let Some(ev) = event_source.next().await {
285            match ev {
286                Err(e) => {
287                    Err(DashScopeError::StreamError(e.to_string()))?;
288                }
289                Ok(Event::Open) => continue,
290                Ok(Event::Message(message)) => {
291                    // First, deserialize to a generic JSON Value to inspect it without failing.
292                    let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
293                        Ok(val) => val,
294                        Err(e) => {
295                            Err(map_deserialization_error(e, message.data.as_bytes()))?;
296                            continue;
297                        }
298                    };
299
300                    // Now, deserialize from the `Value` to the target type `O`.
301                    let response = serde_json::from_value::<O>(json_value.clone())
302                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
303
304                    // Yield the successful message
305                    yield response;
306
307                    // Check for finish reason after sending the message.
308                    // This ensures the final message with "stop" is delivered.
309                    let finish_reason = json_value
310                        .pointer("/output/choices/0/finish_reason")
311                        .and_then(|v| v.as_str());
312
313                    if let Some("stop") = finish_reason {
314                        break;
315                    }
316                }
317            }
318        }
319        event_source.close();
320    };
321
322    Box::pin(stream)
323}
324
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// A client built from a config with an API key must send it as a bearer
    /// token in the `authorization` header.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Look the header up directly so the test fails when it is missing,
        // instead of passing vacuously because no `authorization` entry exists.
        let headers = client.config.headers();
        let authorization = headers
            .get(reqwest::header::AUTHORIZATION)
            .expect("an API key should produce an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}
345}