// async_dashscope/client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use bytes::Bytes;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tokio_stream::{Stream, StreamExt as _};
7
8use crate::{
9    config::Config,
10    error::{map_deserialization_error, ApiError, DashScopeError},
11};
12
13#[derive(Debug, Default, Clone)]
14pub struct Client {
15    http_client: reqwest::Client,
16    config: Config,
17    backoff: backoff::ExponentialBackoff,
18}
19
20impl Client {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn with_config(config: Config) -> Self {
26        Self {
27            http_client: reqwest::Client::new(),
28            config,
29            backoff: backoff::ExponentialBackoff::default(),
30        }
31    }
32    pub fn with_api_key(mut self, api_key: String) -> Self {
33        self.config.set_api_key(api_key.into());
34        self
35    }
36
37    pub fn build(
38        http_client: reqwest::Client,
39        config: Config,
40        backoff: backoff::ExponentialBackoff,
41    ) -> Self {
42        Self {
43            http_client,
44            config,
45            backoff,
46        }
47    }
48
49    /// 获取当前实例的生成(Generation)信息
50    ///
51    /// 此方法属于操作级别,用于创建一个`Generation`对象,
52    /// 该对象表示当前实例的某一特定生成(代)信息
53    ///
54    /// # Returns
55    ///
56    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
57    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
58        crate::operation::generation::Generation::new(self)
59    }
60
61    /// 启发多模态对话的功能
62    ///
63    /// 该函数提供了与多模态对话相关的操作入口
64    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
65    ///
66    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
67    pub fn multi_modal_conversation(
68        &self,
69    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
70        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
71    }
72
73    /// 获取文本嵌入表示
74    ///
75    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
76    /// 它利用当前实例的上下文来生成文本的嵌入表示
77    ///
78    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
79    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
80    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
81        crate::operation::embeddings::Embeddings::new(self)
82    }
83
84    pub(crate) async fn post_stream<I, O>(
85        &self,
86        path: &str,
87        request: I,
88    ) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
89    where
90        I: Serialize,
91        O: DeserializeOwned + std::marker::Send + 'static,
92    {
93        let event_source = self
94            .http_client
95            .post(self.config.url(path))
96            .headers(self.config.headers())
97            .json(&request)
98            .eventsource()
99            .unwrap();
100
101        stream(event_source).await
102    }
103
104    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
105    where
106        I: Serialize + Debug,
107        O: DeserializeOwned,
108    {
109        // dbg!(&request);
110        let request_maker = || async {
111            Ok(self
112                .http_client
113                .post(self.config.url(path))
114                .headers(self.config.headers())
115                .json(&request)
116                .build()?)
117        };
118
119        self.execute(request_maker).await
120    }
121
122    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
123    where
124        O: DeserializeOwned,
125        M: Fn() -> Fut,
126        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
127    {
128        let bytes = self.execute_raw(request_maker).await?;
129
130        // bytes to string
131        // let s = String::from_utf8(bytes.to_vec()).unwrap();
132
133        let response: O = serde_json::from_slice(bytes.as_ref())
134            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
135
136        Ok(response)
137    }
138
139    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
140    where
141        M: Fn() -> Fut,
142        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
143    {
144        let client = self.http_client.clone();
145
146        backoff::future::retry(self.backoff.clone(), || async {
147            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
148            let response = client
149                .execute(request)
150                .await
151                .map_err(DashScopeError::Reqwest)
152                .map_err(backoff::Error::Permanent)?;
153
154            let status = response.status();
155            let bytes = response
156                .bytes()
157                .await
158                .map_err(DashScopeError::Reqwest)
159                .map_err(backoff::Error::Permanent)?;
160
161            // Deserialize response body from either error object or actual response object
162            if !status.is_success() {
163                // bytes to string
164
165                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
166                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
167                    .map_err(backoff::Error::Permanent)?;
168
169                if status.as_u16() == 429 {
170                    // Rate limited retry...
171                    tracing::warn!("Rate limited: {}", api_error.message);
172                    return Err(backoff::Error::Transient {
173                        err: DashScopeError::ApiError(api_error),
174                        retry_after: None,
175                    });
176                } else {
177                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
178                        api_error,
179                    )));
180                }
181            }
182
183            Ok(bytes)
184        })
185        .await
186    }
187}
188
189pub(crate) async fn stream<O>(
190    mut event_source: EventSource,
191) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
192where
193    O: DeserializeOwned + std::marker::Send + 'static,
194{
195    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
196
197    tokio::spawn(async move {
198        while let Some(ev) = event_source.next().await {
199            match ev {
200                Err(e) => {
201                    if let Err(_e) = tx.send(Err(DashScopeError::StreamError(e.to_string()))) {
202                        // rx dropped
203                        break;
204                    }
205                }
206                Ok(event) => match event {
207                    Event::Message(message) => {
208                        #[derive(Deserialize, Debug)]
209                        struct Result {
210                            output: Output,
211                        }
212                        #[derive(Deserialize, Debug)]
213                        struct Output {
214                            choices: Vec<Choices>,
215                        }
216                        #[derive(Deserialize, Debug)]
217                        struct Choices {
218                            finish_reason: Option<String>,
219                        }
220
221                        let r = match serde_json::from_str::<Result>(&message.data) {
222                            Ok(r) => r,
223                            Err(e) => {
224                                if let Err(_e) = tx.send(Err(map_deserialization_error(
225                                    e,
226                                    message.data.as_bytes(),
227                                ))) {
228                                    break;
229                                }
230                                continue;
231                            }
232                        };
233                        if let Some(finish_reason) = r.output.choices[0].finish_reason.clone() {
234                            if finish_reason == "stop" {
235                                break;
236                            }
237                        }
238
239                        let response = match serde_json::from_str::<O>(&message.data) {
240                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
241                            Ok(output) => Ok(output),
242                        };
243
244                        if let Err(_e) = tx.send(response) {
245                            // rx dropped
246                            break;
247                        }
248                    }
249                    Event::Open => continue,
250                },
251            }
252        }
253
254        event_source.close();
255    });
256
257    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
258}
259
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// A client built from a config with an API key must expose that key as a
    /// bearer token in the `authorization` header.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config.headers();
        // Require the header to exist so the test cannot pass vacuously.
        let authorization = headers
            .get("authorization")
            .expect("config headers should contain an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}