// async_dashscope/client.rs
use std::{fmt::Debug, pin::Pin};

use async_stream::try_stream;
use bytes::Bytes;
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt as _};
use serde::{Serialize, de::DeserializeOwned};
use tokio_stream::{Stream, StreamExt as _};

use crate::{
    config::Config,
    error::{ApiError, DashScopeError, map_deserialization_error},
};

14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 http_client: reqwest::Client,
17 config: Config,
18 backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59 crate::operation::generation::Generation::new(self)
60 }
61
62 pub fn multi_modal_conversation(
69 &self,
70 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72 }
73
74 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76 crate::operation::audio::Audio::new(self)
77 }
78
79 pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
92 crate::operation::image2image::Image2Image::new(self)
93 }
94
95 pub fn http_client(&self) -> reqwest::Client {
96 self.http_client.clone()
97 }
98
99 pub fn task(&self) -> crate::operation::task::Task<'_> {
104 crate::operation::task::Task::new(self)
105 }
106
107 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
115 crate::operation::embeddings::Embeddings::new(self)
116 }
117
118 pub(crate) async fn post_stream<I, O>(
119 &self,
120 path: &str,
121 request: I,
122 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
123 where
124 I: Serialize,
125 O: DeserializeOwned + std::marker::Send + 'static,
126 {
127 let event_source = self
128 .http_client
129 .post(self.config.url(path))
130 .headers(self.config.headers())
131 .json(&request)
132 .eventsource()?;
133
134 Ok(stream(event_source).await)
135 }
136
137 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
138 where
139 I: Serialize + Debug,
140 O: DeserializeOwned,
141 {
142 self.post_with_headers(path, request, self.config().headers())
143 .await
144 }
145
146 pub(crate) async fn post_with_headers<I, O>(
162 &self,
163 path: &str,
164 request: I,
165 headers: reqwest::header::HeaderMap,
166 ) -> Result<O, DashScopeError>
167 where
168 I: Serialize + Debug,
169 O: DeserializeOwned,
170 {
171 let request_maker = || async {
172 Ok(self
173 .http_client
174 .post(self.config.url(path))
175 .headers(headers.clone())
176 .json(&request)
177 .build()?)
178 };
179
180 self.execute(request_maker).await
181 }
182
183 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
184 where
185 O: DeserializeOwned,
186 M: Fn() -> Fut,
187 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
188 {
189 let bytes = self.execute_raw(request_maker).await?;
190
191 let response: O = serde_json::from_slice(bytes.as_ref())
192 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
193
194 Ok(response)
195 }
196
197 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
198 where
199 M: Fn() -> Fut,
200 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
201 {
202 let client = self.http_client.clone();
203
204 backoff::future::retry(self.backoff.clone(), || async {
205 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
206 let response = client
207 .execute(request)
208 .await
209 .map_err(DashScopeError::Reqwest)
210 .map_err(backoff::Error::Permanent)?;
211
212 let status = response.status();
213 let bytes = response
214 .bytes()
215 .await
216 .map_err(DashScopeError::Reqwest)
217 .map_err(backoff::Error::Permanent)?;
218
219 if !status.is_success() {
221 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
222 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
223 .map_err(backoff::Error::Permanent)?;
224
225 if status.as_u16() == 429 {
226 tracing::warn!("Rate limited: {}", api_error.message);
228 return Err(backoff::Error::Transient {
229 err: DashScopeError::ApiError(api_error),
230 retry_after: None,
231 });
232 } else {
233 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
234 api_error,
235 )));
236 }
237 }
238
239 Ok(bytes)
240 })
241 .await
242 }
243
244 pub fn config(&self) -> &Config {
245 &self.config
246 }
247}
248
249pub(crate) async fn stream<O>(
250 mut event_source: EventSource,
251) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
252where
253 O: DeserializeOwned + std::marker::Send + 'static,
254{
255 let stream = try_stream! {
256 while let Some(ev) = event_source.next().await {
257 match ev {
258 Err(e) => {
259 Err(DashScopeError::StreamError(e.to_string()))?;
260 }
261 Ok(Event::Open) => continue,
262 Ok(Event::Message(message)) => {
263 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
265 Ok(val) => val,
266 Err(e) => {
267 Err(map_deserialization_error(e, message.data.as_bytes()))?;
268 continue;
269 }
270 };
271
272 let response = serde_json::from_value::<O>(json_value.clone())
274 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
275
276 yield response;
278
279 let finish_reason = json_value
282 .pointer("/output/choices/0/finish_reason")
283 .and_then(|v| v.as_str());
284
285 if let Some("stop") = finish_reason {
286 break;
287 }
288 }
289 }
290 }
291 event_source.close();
292 };
293
294 Box::pin(stream)
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ConfigBuilder;

    /// Any `authorization` header produced from a config carrying an API key must
    /// present that key as a bearer token.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config.headers();
        headers
            .iter()
            .filter(|(name, _)| *name == "authorization")
            .for_each(|(_, value)| assert_eq!(value, "Bearer test key"));
    }
}