async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use bytes::Bytes;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tokio_stream::{Stream, StreamExt as _};
7
8use crate::{
9    config::Config,
10    error::{map_deserialization_error, ApiError, DashScopeError},
11};
12
13#[derive(Debug, Default, Clone)]
14pub struct Client {
15    http_client: reqwest::Client,
16    config: Config,
17    backoff: backoff::ExponentialBackoff,
18}
19
20impl Client {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn with_config(config: Config) -> Self {
26        Self {
27            http_client: reqwest::Client::new(),
28            config,
29            backoff: backoff::ExponentialBackoff::default(),
30        }
31    }
32    pub fn with_api_key(mut self, api_key: String) -> Self {
33        self.config.set_api_key(api_key.into());
34        self
35    }
36
37    pub fn build(
38        http_client: reqwest::Client,
39        config: Config,
40        backoff: backoff::ExponentialBackoff,
41    ) -> Self {
42        Self {
43            http_client,
44            config,
45            backoff,
46        }
47    }
48
49    /// 获取当前实例的生成(Generation)信息
50    ///
51    /// 此方法属于操作级别,用于创建一个`Generation`对象,
52    /// 该对象表示当前实例的某一特定生成(代)信息
53    ///
54    /// # Returns
55    ///
56    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
57    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
58        crate::operation::generation::Generation::new(self)
59    }
60
61    /// 启发多模态对话的功能
62    ///
63    /// 该函数提供了与多模态对话相关的操作入口
64    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
65    ///
66    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
67    pub fn multi_modal_conversation(
68        &self,
69    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
70        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
71    }
72
73    /// 获取文本嵌入表示
74    ///
75    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
76    /// 它利用当前实例的上下文来生成文本的嵌入表示
77    ///
78    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
79    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
80    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
81        crate::operation::embeddings::Embeddings::new(self)
82    }
83
84    pub(crate) async fn post_stream<I, O>(
85        &self,
86        path: &str,
87        request: I,
88    ) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
89    where
90        I: Serialize,
91        O: DeserializeOwned + std::marker::Send + 'static,
92    {
93        let event_source = self
94            .http_client
95            .post(self.config.url(path))
96            .headers(self.config.headers())
97            .json(&request)
98            .eventsource()
99            .unwrap();
100
101        stream(event_source).await
102    }
103
104    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
105    where
106        I: Serialize + Debug,
107        O: DeserializeOwned,
108    {
109        // dbg!(&request);
110        let request_maker = || async {
111            Ok(self
112                .http_client
113                .post(self.config.url(path))
114                .headers(self.config.headers())
115                .json(&request)
116                .build()?)
117        };
118
119        self.execute(request_maker).await
120    }
121
122    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
123    where
124        O: DeserializeOwned,
125        M: Fn() -> Fut,
126        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
127    {
128        let bytes = self.execute_raw(request_maker).await?;
129
130        // bytes to string
131        let s = String::from_utf8(bytes.to_vec()).unwrap();
132        dbg!(s);
133
134        let response: O = serde_json::from_slice(bytes.as_ref())
135            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
136
137        Ok(response)
138    }
139
140    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
141    where
142        M: Fn() -> Fut,
143        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
144    {
145        let client = self.http_client.clone();
146
147        backoff::future::retry(self.backoff.clone(), || async {
148            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
149            let response = client
150                .execute(request)
151                .await
152                .map_err(DashScopeError::Reqwest)
153                .map_err(backoff::Error::Permanent)?;
154
155            let status = response.status();
156            let bytes = response
157                .bytes()
158                .await
159                .map_err(DashScopeError::Reqwest)
160                .map_err(backoff::Error::Permanent)?;
161
162            // Deserialize response body from either error object or actual response object
163            if !status.is_success() {
164                // bytes to string
165
166                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
167                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
168                    .map_err(backoff::Error::Permanent)?;
169
170                if status.as_u16() == 429 {
171                    // Rate limited retry...
172                    tracing::warn!("Rate limited: {}", api_error.message);
173                    return Err(backoff::Error::Transient {
174                        err: DashScopeError::ApiError(api_error),
175                        retry_after: None,
176                    });
177                } else {
178                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
179                        api_error,
180                    )));
181                }
182            }
183
184            Ok(bytes)
185        })
186        .await
187    }
188}
189
190pub(crate) async fn stream<O>(
191    mut event_source: EventSource,
192) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
193where
194    O: DeserializeOwned + std::marker::Send + 'static,
195{
196    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
197
198    tokio::spawn(async move {
199        while let Some(ev) = event_source.next().await {
200            match ev {
201                Err(e) => {
202                    if let Err(_e) = tx.send(Err(DashScopeError::StreamError(e.to_string()))) {
203                        // rx dropped
204                        break;
205                    }
206                }
207                Ok(event) => match event {
208                    Event::Message(message) => {
209                        #[derive(Deserialize, Debug)]
210                        struct Result {
211                            output: Output,
212                        }
213                        #[derive(Deserialize, Debug)]
214                        struct Output {
215                            choices: Vec<Choices>,
216                        }
217                        #[derive(Deserialize, Debug)]
218                        struct Choices {
219                            finish_reason: Option<String>,
220                        }
221
222                        let r = match serde_json::from_str::<Result>(&message.data) {
223                            Ok(r) => r,
224                            Err(e) => {
225                                if let Err(_e) = tx.send(Err(map_deserialization_error(
226                                    e,
227                                    message.data.as_bytes(),
228                                ))) {
229                                    break;
230                                }
231                                continue;
232                            }
233                        };
234                        if let Some(finish_reason) = r.output.choices[0].finish_reason.clone() {
235                            if finish_reason == "stop" {
236                                break;
237                            }
238                        }
239
240                        let response = match serde_json::from_str::<O>(&message.data) {
241                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
242                            Ok(output) => Ok(output),
243                        };
244
245                        if let Err(_e) = tx.send(response) {
246                            // rx dropped
247                            break;
248                        }
249                    }
250                    Event::Open => continue,
251                },
252            }
253        }
254
255        event_source.close();
256    });
257
258    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
259}
260
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key given to the config builder must appear as a bearer token in the
    /// `authorization` header of the client's configuration.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Look the header up directly. A missing header must fail the test rather than
        // pass vacuously.
        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("config headers must contain an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}
281}