async_dashscope/
client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use bytes::Bytes;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tokio_stream::{Stream, StreamExt as _};
7
8use crate::{
9    config::Config,
10    error::{map_deserialization_error, ApiError, DashScopeError},
11};
12
13#[derive(Debug, Default)]
14pub struct Client {
15    http_client: reqwest::Client,
16    config: Config,
17    backoff: backoff::ExponentialBackoff,
18}
19
20impl Client {
21    pub fn new() -> Self {
22        Self::default()
23    }
24    pub fn build(
25        http_client: reqwest::Client,
26        config: Config,
27        backoff: backoff::ExponentialBackoff,
28    ) -> Self {
29        Self {
30            http_client,
31            config,
32            backoff,
33        }
34    }
35
36    /// 获取当前实例的生成(Generation)信息
37    ///
38    /// 此方法属于操作级别,用于创建一个`Generation`对象,
39    /// 该对象表示当前实例的某一特定生成(代)信息
40    ///
41    /// # Returns
42    ///
43    /// 返回一个`Generation`对象,用于表示当前实例的生成信息
44    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
45        crate::operation::generation::Generation::new(self)
46    }
47
48    /// 启发多模态对话的功能
49    ///
50    /// 该函数提供了与多模态对话相关的操作入口
51    /// 它创建并返回一个MultiModalConversation实例,用于执行多模态对话操作
52    ///
53    /// # Returns
54    ///
55    /// 返回一个`MultiModalConversation`实例,用于进行多模态对话操作
56    pub fn multi_modal_conversation(
57        &self,
58    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
59        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
60    }
61
62    /// 获取文本嵌入表示
63    ///
64    /// 此函数提供了一个接口,用于将文本转换为嵌入表示
65    /// 它利用当前实例的上下文来生成文本的嵌入表示
66    ///
67    /// # 返回值
68    ///
69    /// 返回一个`Embeddings`实例,该实例封装了文本嵌入相关的操作和数据
70    /// `Embeddings`类型提供了进一步处理文本数据的能力,如计算文本相似度或进行文本分类等
71    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
72        crate::operation::embeddings::Embeddings::new(self)
73    }
74
75    pub(crate) async fn post_stream<I, O>(
76        &self,
77        path: &str,
78        request: I,
79    ) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
80    where
81        I: Serialize,
82        O: DeserializeOwned + std::marker::Send + 'static,
83    {
84        let event_source = self
85            .http_client
86            .post(self.config.url(path))
87            .headers(self.config.headers())
88            .json(&request)
89            .eventsource()
90            .unwrap();
91
92        stream(event_source).await
93    }
94
95    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
96    where
97        I: Serialize + Debug,
98        O: DeserializeOwned,
99    {
100        dbg!(&request);
101        let request_maker = || async {
102            Ok(self
103                .http_client
104                .post(self.config.url(path))
105                .headers(self.config.headers())
106                .json(&request)
107                .build()?)
108        };
109
110        self.execute(request_maker).await
111    }
112
113    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
114    where
115        O: DeserializeOwned,
116        M: Fn() -> Fut,
117        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
118    {
119        let bytes = self.execute_raw(request_maker).await?;
120
121        let response: O = serde_json::from_slice(bytes.as_ref())
122            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
123
124        Ok(response)
125    }
126
127    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
128    where
129        M: Fn() -> Fut,
130        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
131    {
132        let client = self.http_client.clone();
133
134        backoff::future::retry(self.backoff.clone(), || async {
135            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
136            let response = client
137                .execute(request)
138                .await
139                .map_err(DashScopeError::Reqwest)
140                .map_err(backoff::Error::Permanent)?;
141
142            let status = response.status();
143            let bytes = response
144                .bytes()
145                .await
146                .map_err(DashScopeError::Reqwest)
147                .map_err(backoff::Error::Permanent)?;
148
149            // Deserialize response body from either error object or actual response object
150            if !status.is_success() {
151                // bytes to string
152
153                let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
154                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
155                    .map_err(backoff::Error::Permanent)?;
156
157                if status.as_u16() == 429 {
158                    // Rate limited retry...
159                    tracing::warn!("Rate limited: {}", api_error.message);
160                    return Err(backoff::Error::Transient {
161                        err: DashScopeError::ApiError(api_error),
162                        retry_after: None,
163                    });
164                } else {
165                    return Err(backoff::Error::Permanent(DashScopeError::ApiError(
166                        api_error,
167                    )));
168                }
169            }
170
171            Ok(bytes)
172        })
173        .await
174    }
175}
176
177pub(crate) async fn stream<O>(
178    mut event_source: EventSource,
179) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
180where
181    O: DeserializeOwned + std::marker::Send + 'static,
182{
183    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
184
185    tokio::spawn(async move {
186        while let Some(ev) = event_source.next().await {
187            match ev {
188                Err(e) => {
189                    if let Err(_e) = tx.send(Err(DashScopeError::StreamError(e.to_string()))) {
190                        // rx dropped
191                        break;
192                    }
193                }
194                Ok(event) => match event {
195                    Event::Message(message) => {
196                        #[derive(Deserialize, Debug)]
197                        struct Result {
198                            output: Output,
199                        }
200                        #[derive(Deserialize, Debug)]
201                        struct Output {
202                            choices: Vec<Choices>,
203                        }
204                        #[derive(Deserialize, Debug)]
205                        struct Choices {
206                            finish_reason: Option<String>,
207                        }
208
209                        let r = match serde_json::from_str::<Result>(&message.data) {
210                            Ok(r) => r,
211                            Err(e) => {
212                                if let Err(_e) = tx.send(Err(map_deserialization_error(
213                                    e,
214                                    message.data.as_bytes(),
215                                ))) {
216                                    break;
217                                }
218                                continue;
219                            }
220                        };
221                        if let Some(finish_reason) = r.output.choices[0].finish_reason.clone() {
222                            if finish_reason == "stop" {
223                                break;
224                            }
225                        }
226
227                        let response = match serde_json::from_str::<O>(&message.data) {
228                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
229                            Ok(output) => Ok(output),
230                        };
231
232                        if let Err(_e) = tx.send(response) {
233                            // rx dropped
234                            break;
235                        }
236                    }
237                    Event::Open => continue,
238                },
239            }
240        }
241
242        event_source.close();
243    });
244
245    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
246}