1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10 config::Config,
11 error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 http_client: reqwest::Client,
17 config: Config,
18 backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59 crate::operation::generation::Generation::new(self)
60 }
61
62 pub fn multi_modal_conversation(
69 &self,
70 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72 }
73
74 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76 crate::operation::audio::Audio::new(self)
77 }
78
79 pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
92 crate::operation::image2image::Image2Image::new(self)
93 }
94
95 pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
106 crate::operation::text2image::Text2Image::new(self)
107 }
108
109 pub fn http_client(&self) -> reqwest::Client {
110 self.http_client.clone()
111 }
112
113 pub fn task(&self) -> crate::operation::task::Task<'_> {
118 crate::operation::task::Task::new(self)
119 }
120
121 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
129 crate::operation::embeddings::Embeddings::new(self)
130 }
131
132 pub(crate) async fn post_stream<I, O>(
133 &self,
134 path: &str,
135 request: I,
136 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
137 where
138 I: Serialize + Debug,
139 O: DeserializeOwned + std::marker::Send + 'static,
140 {
141 self.post_stream_with_headers(path, request, self.config.headers())
142 .await
143 }
144
145 pub(crate) async fn post_stream_with_headers<I, O>(
146 &self,
147 path: &str,
148 request: I,
149 headers: reqwest::header::HeaderMap,
150 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
151 where
152 I: Serialize + Debug,
153 O: DeserializeOwned + std::marker::Send + 'static,
154 {
155 let event_source = self
156 .http_client
157 .post(self.config.url(path))
158 .headers(headers)
159 .json(&request)
160 .eventsource()?;
161
162 Ok(stream(event_source).await)
163 }
164
165 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
166 where
167 I: Serialize + Debug,
168 O: DeserializeOwned,
169 {
170 self.post_with_headers(path, request, self.config().headers())
171 .await
172 }
173
174 pub(crate) async fn post_with_headers<I, O>(
190 &self,
191 path: &str,
192 request: I,
193 headers: reqwest::header::HeaderMap,
194 ) -> Result<O, DashScopeError>
195 where
196 I: Serialize + Debug,
197 O: DeserializeOwned,
198 {
199 let request_maker = || async {
200 Ok(self
201 .http_client
202 .post(self.config.url(path))
203 .headers(headers.clone())
204 .json(&request)
205 .build()?)
206 };
207
208 self.execute(request_maker).await
209 }
210
211 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
212 where
213 O: DeserializeOwned,
214 M: Fn() -> Fut,
215 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
216 {
217 let bytes = self.execute_raw(request_maker).await?;
218
219 let response: O = serde_json::from_slice(bytes.as_ref())
220 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
221
222 Ok(response)
223 }
224
225 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
226 where
227 M: Fn() -> Fut,
228 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
229 {
230 let client = self.http_client.clone();
231
232 backoff::future::retry(self.backoff.clone(), || async {
233 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
234 let response = client
235 .execute(request)
236 .await
237 .map_err(DashScopeError::Reqwest)
238 .map_err(backoff::Error::Permanent)?;
239
240 let status = response.status();
241 let bytes = response
242 .bytes()
243 .await
244 .map_err(DashScopeError::Reqwest)
245 .map_err(backoff::Error::Permanent)?;
246
247 if !status.is_success() {
249 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
250 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
251 .map_err(backoff::Error::Permanent)?;
252
253 if status.as_u16() == 429 {
254 tracing::warn!("Rate limited: {}", api_error.message);
256 return Err(backoff::Error::Transient {
257 err: DashScopeError::ApiError(api_error),
258 retry_after: None,
259 });
260 } else {
261 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
262 api_error,
263 )));
264 }
265 }
266
267 Ok(bytes)
268 })
269 .await
270 }
271
272 pub fn config(&self) -> &Config {
273 &self.config
274 }
275}
276
277pub(crate) async fn stream<O>(
278 mut event_source: EventSource,
279) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
280where
281 O: DeserializeOwned + std::marker::Send + 'static,
282{
283 let stream = try_stream! {
284 while let Some(ev) = event_source.next().await {
285 match ev {
286 Err(e) => {
287 Err(DashScopeError::StreamError(e.to_string()))?;
288 }
289 Ok(Event::Open) => continue,
290 Ok(Event::Message(message)) => {
291 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
293 Ok(val) => val,
294 Err(e) => {
295 Err(map_deserialization_error(e, message.data.as_bytes()))?;
296 continue;
297 }
298 };
299
300 let response = serde_json::from_value::<O>(json_value.clone())
302 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
303
304 yield response;
306
307 let finish_reason = json_value
310 .pointer("/output/choices/0/finish_reason")
311 .and_then(|v| v.as_str());
312
313 if let Some("stop") = finish_reason {
314 break;
315 }
316 }
317 }
318 }
319 event_source.close();
320 };
321
322 Box::pin(stream)
323}
324
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// A client built from a config with an API key must expose that key as
    /// a bearer `authorization` header.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Track whether the header was seen at all; without this the test
        // would pass vacuously if no `authorization` header were produced.
        let mut saw_authorization = false;
        for (name, value) in client.config.headers().iter() {
            if name == "authorization" {
                assert_eq!(value, "Bearer test key");
                saw_authorization = true;
            }
        }

        assert!(
            saw_authorization,
            "expected an authorization header when an API key is configured"
        );
    }
}