async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{de::DeserializeOwned, Serialize};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10    config::Config,
11    error::{map_deserialization_error, ApiError, DashScopeError},
12};
13
/// Asynchronous DashScope API client.
///
/// Bundles the HTTP client that sends requests, the [`Config`] that supplies request URLs and
/// headers, and the exponential backoff policy applied when a request is retried.
#[derive(Debug, Default, Clone)]
pub struct Client {
    /// HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Source of request URLs (`Config::url`) and request headers (`Config::headers`).
    config: Config,
    /// Retry policy; a fresh clone of it is used for each non-streaming request.
    backoff: backoff::ExponentialBackoff,
}
20
21impl Client {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn with_config(config: Config) -> Self {
27        Self {
28            http_client: reqwest::Client::new(),
29            config,
30            backoff: backoff::ExponentialBackoff::default(),
31        }
32    }
33    pub fn with_api_key(mut self, api_key: String) -> Self {
34        self.config.set_api_key(api_key.into());
35        self
36    }
37
38    pub fn build(
39        http_client: reqwest::Client,
40        config: Config,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// 获取当前实例的生成(Generation)信息
51    ///
52    /// 此方法属于操作级别,用于创建一个`Generation`对象,
53    /// 该对象表示当前实例的某一特定生成(代)信息
54    ///
55    /// # Returns
56    ///
57    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
58    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59        crate::operation::generation::Generation::new(self)
60    }
61
62    /// 启发多模态对话的功能
63    ///
64    /// 该函数提供了与多模态对话相关的操作入口
65    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
66    ///
67    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
68    pub fn multi_modal_conversation(
69        &self,
70    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72    }
73
74    /// 获取音频处理功能
75    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76        crate::operation::audio::Audio::new(self)
77    }
78
79    /// 获取文本嵌入表示
80    ///
81    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
82    /// 它利用当前实例的上下文来生成文本的嵌入表示
83    ///
84    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
85    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
86    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
87        crate::operation::embeddings::Embeddings::new(self)
88    }
89
90    pub(crate) async fn post_stream<I, O>(
91        &self,
92        path: &str,
93        request: I,
94    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
95    where
96        I: Serialize,
97        O: DeserializeOwned + std::marker::Send + 'static,
98    {
99        let event_source = self
100            .http_client
101            .post(self.config.url(path))
102            .headers(self.config.headers())
103            .json(&request)
104            .eventsource()?;
105
106        Ok(stream(event_source).await)
107    }
108
109    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
110    where
111        I: Serialize + Debug,
112        O: DeserializeOwned,
113    {
114        let request_maker = || async {
115            Ok(self
116                .http_client
117                .post(self.config.url(path))
118                .headers(self.config.headers())
119                .json(&request)
120                .build()?)
121        };
122
123        self.execute(request_maker).await
124    }
125
126    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
127    where
128        O: DeserializeOwned,
129        M: Fn() -> Fut,
130        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
131    {
132        let bytes = self.execute_raw(request_maker).await?;
133
134        let response: O = serde_json::from_slice(bytes.as_ref())
135            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
136
137        Ok(response)
138    }
139
140    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
141    where
142        M: Fn() -> Fut,
143        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
144    {
145        let client = self.http_client.clone();
146
147        backoff::future::retry(self.backoff.clone(), || async {
148            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
149            let response = client
150                .execute(request)
151                .await
152                .map_err(DashScopeError::Reqwest)
153                .map_err(backoff::Error::Permanent)?;
154
155            let status = response.status();
156            let bytes = response
157                .bytes()
158                .await
159                .map_err(DashScopeError::Reqwest)
160                .map_err(backoff::Error::Permanent)?;
161
162            // Deserialize response body from either error object or actual response object
163            if !status.is_success() {
164                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
165                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
166                    .map_err(backoff::Error::Permanent)?;
167
168                if status.as_u16() == 429 {
169                    // Rate limited retry...
170                    tracing::warn!("Rate limited: {}", api_error.message);
171                    return Err(backoff::Error::Transient {
172                        err: DashScopeError::ApiError(api_error),
173                        retry_after: None,
174                    });
175                } else {
176                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
177                        api_error,
178                    )));
179                }
180            }
181
182            Ok(bytes)
183        })
184        .await
185    }
186    
187    pub fn config(&self) -> &Config {
188        &self.config
189    }
190}
191
192pub(crate) async fn stream<O>(
193    mut event_source: EventSource,
194) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
195where
196    O: DeserializeOwned + std::marker::Send + 'static,
197{
198    let stream = try_stream! {
199        while let Some(ev) = event_source.next().await {
200            match ev {
201                Err(e) => {
202                    Err(DashScopeError::StreamError(e.to_string()))?;
203                }
204                Ok(Event::Open) => continue,
205                Ok(Event::Message(message)) => {
206                    // First, deserialize to a generic JSON Value to inspect it without failing.
207                    let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
208                        Ok(val) => val,
209                        Err(e) => {
210                            Err(map_deserialization_error(e, message.data.as_bytes()))?;
211                            continue;
212                        }
213                    };
214
215                    // Now, deserialize from the `Value` to the target type `O`.
216                    let response = serde_json::from_value::<O>(json_value.clone())
217                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
218
219                    // Yield the successful message
220                    yield response;
221
222                    // Check for finish reason after sending the message.
223                    // This ensures the final message with "stop" is delivered.
224                    let finish_reason = json_value
225                        .pointer("/output/choices/0/finish_reason")
226                        .and_then(|v| v.as_str());
227
228                    if let Some("stop") = finish_reason {
229                        break;
230                    }
231                }
232            }
233        }
234        event_source.close();
235    };
236
237    Box::pin(stream)
238}
239
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key from the config must be sent as a bearer `authorization` header.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Look the header up directly so a missing header fails the test instead of the
        // assertion being silently skipped.
        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("config headers should contain an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}