1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10 config::Config,
11 error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 pub(crate) http_client: reqwest::Client,
17 pub(crate) config: Config,
18 pub(crate) backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn files(&self) -> crate::operation::file::File<'_> {
51 crate::operation::file::File::new(self)
52 }
53
54 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
63 crate::operation::generation::Generation::new(self)
64 }
65
66 pub fn multi_modal_conversation(
73 &self,
74 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
75 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
76 }
77
78 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
80 crate::operation::audio::Audio::new(self)
81 }
82
83 pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
97 crate::operation::image2image::Image2Image::new(self)
98 }
99
100 pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
112 crate::operation::text2image::Text2Image::new(self)
113 }
114
115 pub fn file(&self) -> crate::operation::file::File<'_> {
127 crate::operation::file::File::new(self)
128 }
129
130 pub fn http_client(&self) -> reqwest::Client {
131 self.http_client.clone()
132 }
133
134 pub fn task(&self) -> crate::operation::task::Task<'_> {
139 crate::operation::task::Task::new(self)
140 }
141
142 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
150 crate::operation::embeddings::Embeddings::new(self)
151 }
152
153 pub(crate) async fn post_stream<I, O>(
154 &self,
155 path: &str,
156 request: I,
157 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
158 where
159 I: Serialize + Debug,
160 O: DeserializeOwned + std::marker::Send + 'static,
161 {
162 self.post_stream_with_headers(path, request, self.config.headers())
163 .await
164 }
165
166 pub(crate) async fn post_stream_with_headers<I, O>(
167 &self,
168 path: &str,
169 request: I,
170 headers: reqwest::header::HeaderMap,
171 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
172 where
173 I: Serialize + Debug,
174 O: DeserializeOwned + std::marker::Send + 'static,
175 {
176 let event_source = self
177 .http_client
178 .post(self.config.url(path))
179 .headers(headers)
180 .json(&request)
181 .eventsource()?;
182
183 Ok(stream(event_source).await)
184 }
185
186 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
187 where
188 I: Serialize + Debug,
189 O: DeserializeOwned,
190 {
191 self.post_with_headers(path, request, self.config().headers())
192 .await
193 }
194
195 pub(crate) async fn post_with_headers<I, O>(
211 &self,
212 path: &str,
213 request: I,
214 headers: reqwest::header::HeaderMap,
215 ) -> Result<O, DashScopeError>
216 where
217 I: Serialize + Debug,
218 O: DeserializeOwned,
219 {
220 let request_maker = || async {
221 Ok(self
222 .http_client
223 .post(self.config.url(path))
224 .headers(headers.clone())
225 .json(&request)
226 .build()?)
227 };
228
229 self.execute(request_maker).await
230 }
231
232 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
233 where
234 O: DeserializeOwned,
235 M: Fn() -> Fut,
236 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
237 {
238 let bytes = self.execute_raw(request_maker).await?;
239
240 let response: O = serde_json::from_slice(bytes.as_ref())
241 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
242
243 Ok(response)
244 }
245
246 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
247 where
248 M: Fn() -> Fut,
249 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
250 {
251 let client = self.http_client.clone();
252
253 backoff::future::retry(self.backoff.clone(), || async {
254 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
255 let response = client
256 .execute(request)
257 .await
258 .map_err(DashScopeError::Reqwest)
259 .map_err(backoff::Error::Permanent)?;
260
261 let status = response.status();
262 let bytes = response
263 .bytes()
264 .await
265 .map_err(DashScopeError::Reqwest)
266 .map_err(backoff::Error::Permanent)?;
267
268 if !status.is_success() {
270 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
271 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
272 .map_err(backoff::Error::Permanent)?;
273
274 if status.as_u16() == 429 {
275 tracing::warn!("Rate limited: {}", api_error.message);
277 return Err(backoff::Error::Transient {
278 err: DashScopeError::ApiError(api_error),
279 retry_after: None,
280 });
281 } else {
282 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
283 api_error,
284 )));
285 }
286 }
287
288 Ok(bytes)
289 })
290 .await
291 }
292
293 pub fn config(&self) -> &Config {
294 &self.config
295 }
296
297 pub(crate) async fn post_multipart<O, F>(
312 &self,
313 path: &str,
314 form_fn: F,
315 ) -> Result<O, DashScopeError>
316 where
317 O: DeserializeOwned,
318 F: Fn() -> reqwest::multipart::Form,
319 {
320 let request_maker = || async {
321 let mut headers = self.config.headers();
322 headers.remove("Content-Type");
323 headers.remove("X-DashScope-OssResourceResolve");
324 Ok(self
325 .http_client
326 .post(self.config.url(path))
327 .headers(headers)
328 .multipart(form_fn())
329 .build()?)
330 };
331
332 self.execute(request_maker).await
333 }
334
335 pub(crate) async fn get_with_params<O, P>(&self, path: &str, params: &P) -> Result<O, DashScopeError>
350 where
351 O: DeserializeOwned,
352 P: serde::Serialize + ?Sized,
353 {
354 let request_maker = || async {
355 Ok(self
356 .http_client
357 .get(self.config.url(path))
358 .headers(self.config.headers())
359 .query(params)
360 .build()?)
361 };
362
363 self.execute(request_maker).await
364 }
365
366 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, DashScopeError>
380 where
381 O: DeserializeOwned,
382 {
383 let request_maker = || async {
384 Ok(self
385 .http_client
386 .delete(self.config.url(path))
387 .headers(self.config.headers())
388 .build()?)
389 };
390
391 self.execute(request_maker).await
392 }
393}
394
395pub(crate) async fn stream<O>(
396 mut event_source: EventSource,
397) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
398where
399 O: DeserializeOwned + std::marker::Send + 'static,
400{
401 let stream = try_stream! {
402 while let Some(ev) = event_source.next().await {
403 match ev {
404 Err(e) => {
405 Err(DashScopeError::StreamError(e.to_string()))?;
406 }
407 Ok(Event::Open) => continue,
408 Ok(Event::Message(message)) => {
409 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
411 Ok(val) => val,
412 Err(e) => {
413 Err(map_deserialization_error(e, message.data.as_bytes()))?;
414 continue;
415 }
416 };
417
418 let response = serde_json::from_value::<O>(json_value.clone())
420 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
421
422 yield response;
424
425 let finish_reason = json_value
428 .pointer("/output/choices/0/finish_reason")
429 .and_then(|v| v.as_str());
430
431 if let Some("stop") = finish_reason {
432 break;
433 }
434 }
435 }
436 }
437 event_source.close();
438 };
439
440 Box::pin(stream)
441}
442
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key given to the config builder must show up as a bearer token in the
    /// `authorization` header of the client's configuration.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config.headers();
        // Look the header up directly: asserting inside a loop over all headers would
        // pass silently if the header were missing altogether.
        let authorization = headers
            .get("authorization")
            .expect("authorization header should be derived from the API key");
        assert_eq!(authorization, "Bearer test key");
    }
}