1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10 config::Config,
11 error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 pub(crate) http_client: reqwest::Client,
17 pub(crate) config: Config,
18 pub(crate) backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn files(&self) -> crate::operation::file::File<'_> {
51 crate::operation::file::File::new(self)
52 }
53
54 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
63 crate::operation::generation::Generation::new(self)
64 }
65
66 pub fn multi_modal_conversation(
73 &self,
74 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
75 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
76 }
77
78 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
80 crate::operation::audio::Audio::new(self)
81 }
82
83 pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
96 crate::operation::image2image::Image2Image::new(self)
97 }
98
99 pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
110 crate::operation::text2image::Text2Image::new(self)
111 }
112
113 pub fn file(&self) -> crate::operation::file::File<'_> {
124 crate::operation::file::File::new(self)
125 }
126
127 pub fn http_client(&self) -> reqwest::Client {
128 self.http_client.clone()
129 }
130
131 pub fn task(&self) -> crate::operation::task::Task<'_> {
136 crate::operation::task::Task::new(self)
137 }
138
139 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
147 crate::operation::embeddings::Embeddings::new(self)
148 }
149
150 pub(crate) async fn post_stream<I, O>(
151 &self,
152 path: &str,
153 request: I,
154 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
155 where
156 I: Serialize + Debug,
157 O: DeserializeOwned + std::marker::Send + 'static,
158 {
159 self.post_stream_with_headers(path, request, self.config.headers())
160 .await
161 }
162
163 pub(crate) async fn post_stream_with_headers<I, O>(
164 &self,
165 path: &str,
166 request: I,
167 headers: reqwest::header::HeaderMap,
168 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
169 where
170 I: Serialize + Debug,
171 O: DeserializeOwned + std::marker::Send + 'static,
172 {
173 let event_source = self
174 .http_client
175 .post(self.config.url(path))
176 .headers(headers)
177 .json(&request)
178 .eventsource()?;
179
180 Ok(stream(event_source).await)
181 }
182
183 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
184 where
185 I: Serialize + Debug,
186 O: DeserializeOwned,
187 {
188 self.post_with_headers(path, request, self.config().headers())
189 .await
190 }
191
192 pub(crate) async fn post_with_headers<I, O>(
208 &self,
209 path: &str,
210 request: I,
211 headers: reqwest::header::HeaderMap,
212 ) -> Result<O, DashScopeError>
213 where
214 I: Serialize + Debug,
215 O: DeserializeOwned,
216 {
217 let request_maker = || async {
218 Ok(self
219 .http_client
220 .post(self.config.url(path))
221 .headers(headers.clone())
222 .json(&request)
223 .build()?)
224 };
225
226 self.execute(request_maker).await
227 }
228
229 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
230 where
231 O: DeserializeOwned,
232 M: Fn() -> Fut,
233 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
234 {
235 let bytes = self.execute_raw(request_maker).await?;
236
237 let response: O = serde_json::from_slice(bytes.as_ref())
238 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
239
240 Ok(response)
241 }
242
243 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
244 where
245 M: Fn() -> Fut,
246 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
247 {
248 let client = self.http_client.clone();
249
250 backoff::future::retry(self.backoff.clone(), || async {
251 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
252 let response = client
253 .execute(request)
254 .await
255 .map_err(DashScopeError::Reqwest)
256 .map_err(backoff::Error::Permanent)?;
257
258 let status = response.status();
259 let bytes = response
260 .bytes()
261 .await
262 .map_err(DashScopeError::Reqwest)
263 .map_err(backoff::Error::Permanent)?;
264
265 if !status.is_success() {
267 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
268 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
269 .map_err(backoff::Error::Permanent)?;
270
271 if status.as_u16() == 429 {
272 tracing::warn!("Rate limited: {}", api_error.message);
274 return Err(backoff::Error::Transient {
275 err: DashScopeError::ApiError(api_error),
276 retry_after: None,
277 });
278 } else {
279 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
280 api_error,
281 )));
282 }
283 }
284
285 Ok(bytes)
286 })
287 .await
288 }
289
290 pub fn config(&self) -> &Config {
291 &self.config
292 }
293
294 pub(crate) async fn post_multipart<O, F>(
309 &self,
310 path: &str,
311 form_fn: F,
312 ) -> Result<O, DashScopeError>
313 where
314 O: DeserializeOwned,
315 F: Fn() -> reqwest::multipart::Form,
316 {
317 let request_maker = || async {
318 let mut headers = self.config.headers();
319 headers.remove("Content-Type");
320 headers.remove("X-DashScope-OssResourceResolve");
321 Ok(self
322 .http_client
323 .post(self.config.url(path))
324 .headers(headers)
325 .multipart(form_fn())
326 .build()?)
327 };
328
329 self.execute(request_maker).await
330 }
331
332 pub(crate) async fn get_with_params<O, P>(&self, path: &str, params: &P) -> Result<O, DashScopeError>
347 where
348 O: DeserializeOwned,
349 P: serde::Serialize + ?Sized,
350 {
351 let request_maker = || async {
352 Ok(self
353 .http_client
354 .get(self.config.url(path))
355 .headers(self.config.headers())
356 .query(params)
357 .build()?)
358 };
359
360 self.execute(request_maker).await
361 }
362
363 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, DashScopeError>
377 where
378 O: DeserializeOwned,
379 {
380 let request_maker = || async {
381 Ok(self
382 .http_client
383 .delete(self.config.url(path))
384 .headers(self.config.headers())
385 .build()?)
386 };
387
388 self.execute(request_maker).await
389 }
390}
391
392pub(crate) async fn stream<O>(
393 mut event_source: EventSource,
394) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
395where
396 O: DeserializeOwned + std::marker::Send + 'static,
397{
398 let stream = try_stream! {
399 while let Some(ev) = event_source.next().await {
400 match ev {
401 Err(e) => {
402 Err(DashScopeError::StreamError(e.to_string()))?;
403 }
404 Ok(Event::Open) => continue,
405 Ok(Event::Message(message)) => {
406 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
408 Ok(val) => val,
409 Err(e) => {
410 Err(map_deserialization_error(e, message.data.as_bytes()))?;
411 continue;
412 }
413 };
414
415 let response = serde_json::from_value::<O>(json_value.clone())
417 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
418
419 yield response;
421
422 let finish_reason = json_value
425 .pointer("/output/choices/0/finish_reason")
426 .and_then(|v| v.as_str());
427
428 if let Some("stop") = finish_reason {
429 break;
430 }
431 }
432 }
433 }
434 event_source.close();
435 };
436
437 Box::pin(stream)
438}
439
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key set through `ConfigBuilder` must appear as a bearer token in
    /// the client's default headers. The header's presence is asserted so the test
    /// cannot pass trivially when the header is missing.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("default headers should include an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}