async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10    config::Config,
11    error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16    http_client: reqwest::Client,
17    config: Config,
18    backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn with_config(config: Config) -> Self {
27        Self {
28            http_client: reqwest::Client::new(),
29            config,
30            backoff: backoff::ExponentialBackoff::default(),
31        }
32    }
33    pub fn with_api_key(mut self, api_key: String) -> Self {
34        self.config.set_api_key(api_key.into());
35        self
36    }
37
38    pub fn build(
39        http_client: reqwest::Client,
40        config: Config,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// 获取当前实例的生成(Generation)信息
51    ///
52    /// 此方法属于操作级别,用于创建一个`Generation`对象,
53    /// 该对象表示当前实例的某一特定生成(代)信息
54    ///
55    /// # Returns
56    ///
57    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
58    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59        crate::operation::generation::Generation::new(self)
60    }
61
62    /// 启发多模态对话的功能
63    ///
64    /// 该函数提供了与多模态对话相关的操作入口
65    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
66    ///
67    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
68    pub fn multi_modal_conversation(
69        &self,
70    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72    }
73
74    /// 获取音频处理功能
75    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76        crate::operation::audio::Audio::new(self)
77    }
78
79    /// 创建一个新的 Image2Image 操作实例
80    ///
81    /// 返回一个可用于执行图像到图像转换操作的 Image2Image 构建器
82    ///
83    /// # Examples
84    /// ```
85    /// let client = Client::new();
86    /// let image2image = client.image2image();
87    /// ```
88    ///
89    /// # Errors
90    /// 此方法本身不会返回错误,但后续操作可能会返回错误
91    pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
92        crate::operation::image2image::Image2Image::new(self)
93    }
94
95    pub fn http_client(&self) -> reqwest::Client {
96        self.http_client.clone()
97    }
98
99    /// 创建一个新的任务操作实例
100    ///
101    /// # 返回
102    /// 返回一个绑定到当前客户端的 `Task` 实例,用于执行任务相关操作
103    pub fn task(&self) -> crate::operation::task::Task<'_> {
104        crate::operation::task::Task::new(self)
105    }
106
107    /// 获取文本嵌入表示
108    ///
109    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
110    /// 它利用当前实例的上下文来生成文本的嵌入表示
111    ///
112    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
113    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
114    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
115        crate::operation::embeddings::Embeddings::new(self)
116    }
117
118    pub(crate) async fn post_stream<I, O>(
119        &self,
120        path: &str,
121        request: I,
122    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
123    where
124        I: Serialize,
125        O: DeserializeOwned + std::marker::Send + 'static,
126    {
127        let event_source = self
128            .http_client
129            .post(self.config.url(path))
130            .headers(self.config.headers())
131            .json(&request)
132            .eventsource()?;
133
134        Ok(stream(event_source).await)
135    }
136
137    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
138    where
139        I: Serialize + Debug,
140        O: DeserializeOwned,
141    {
142        self.post_with_headers(path, request, self.config().headers())
143            .await
144    }
145
146    /// 发送带有自定义请求头的 POST 请求
147    ///
148    /// # 参数
149    /// * `path` - API 路径
150    /// * `request` - 要发送的请求体,需要实现 Serialize 和 Debug trait
151    /// * `headers` - 自定义请求头
152    ///
153    /// # 返回值
154    /// 返回解析后的响应数据,类型由调用方指定
155    ///
156    /// # Errors
157    /// 返回 DashScopeError 如果请求失败或响应解析失败
158    ///
159    /// # 注意事项
160    /// 此函数是 crate 内部使用的工具函数,不对外公开
161    pub(crate) async fn post_with_headers<I, O>(
162        &self,
163        path: &str,
164        request: I,
165        headers: reqwest::header::HeaderMap,
166    ) -> Result<O, DashScopeError>
167    where
168        I: Serialize + Debug,
169        O: DeserializeOwned,
170    {
171        let request_maker = || async {
172            Ok(self
173                .http_client
174                .post(self.config.url(path))
175                .headers(headers.clone())
176                .json(&request)
177                .build()?)
178        };
179
180        self.execute(request_maker).await
181    }
182
183    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
184    where
185        O: DeserializeOwned,
186        M: Fn() -> Fut,
187        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
188    {
189        let bytes = self.execute_raw(request_maker).await?;
190
191        let response: O = serde_json::from_slice(bytes.as_ref())
192            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
193
194        Ok(response)
195    }
196
197    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
198    where
199        M: Fn() -> Fut,
200        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
201    {
202        let client = self.http_client.clone();
203
204        backoff::future::retry(self.backoff.clone(), || async {
205            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
206            let response = client
207                .execute(request)
208                .await
209                .map_err(DashScopeError::Reqwest)
210                .map_err(backoff::Error::Permanent)?;
211
212            let status = response.status();
213            let bytes = response
214                .bytes()
215                .await
216                .map_err(DashScopeError::Reqwest)
217                .map_err(backoff::Error::Permanent)?;
218
219            // Deserialize response body from either error object or actual response object
220            if !status.is_success() {
221                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
222                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
223                    .map_err(backoff::Error::Permanent)?;
224
225                if status.as_u16() == 429 {
226                    // Rate limited retry...
227                    tracing::warn!("Rate limited: {}", api_error.message);
228                    return Err(backoff::Error::Transient {
229                        err: DashScopeError::ApiError(api_error),
230                        retry_after: None,
231                    });
232                } else {
233                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
234                        api_error,
235                    )));
236                }
237            }
238
239            Ok(bytes)
240        })
241        .await
242    }
243
244    pub fn config(&self) -> &Config {
245        &self.config
246    }
247}
248
249pub(crate) async fn stream<O>(
250    mut event_source: EventSource,
251) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
252where
253    O: DeserializeOwned + std::marker::Send + 'static,
254{
255    let stream = try_stream! {
256        while let Some(ev) = event_source.next().await {
257            match ev {
258                Err(e) => {
259                    Err(DashScopeError::StreamError(e.to_string()))?;
260                }
261                Ok(Event::Open) => continue,
262                Ok(Event::Message(message)) => {
263                    // First, deserialize to a generic JSON Value to inspect it without failing.
264                    let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
265                        Ok(val) => val,
266                        Err(e) => {
267                            Err(map_deserialization_error(e, message.data.as_bytes()))?;
268                            continue;
269                        }
270                    };
271
272                    // Now, deserialize from the `Value` to the target type `O`.
273                    let response = serde_json::from_value::<O>(json_value.clone())
274                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
275
276                    // Yield the successful message
277                    yield response;
278
279                    // Check for finish reason after sending the message.
280                    // This ensures the final message with "stop" is delivered.
281                    let finish_reason = json_value
282                        .pointer("/output/choices/0/finish_reason")
283                        .and_then(|v| v.as_str());
284
285                    if let Some("stop") = finish_reason {
286                        break;
287                    }
288                }
289            }
290        }
291        event_source.close();
292    };
293
294    Box::pin(stream)
295}
296
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key given to `ConfigBuilder` must appear as a bearer `Authorization` header.
    ///
    /// Unlike a loop over all headers, this fails when the header is missing entirely.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config().headers();
        let authorization = headers
            .get(reqwest::header::AUTHORIZATION)
            .expect("config headers must contain an Authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}