async_dashscope/
client.rs1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{Serialize, de::DeserializeOwned};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10 config::Config,
11 error::{ApiError, DashScopeError, map_deserialization_error},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 http_client: reqwest::Client,
17 config: Config,
18 backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59 crate::operation::generation::Generation::new(self)
60 }
61
62 pub fn multi_modal_conversation(
69 &self,
70 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72 }
73
74 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76 crate::operation::audio::Audio::new(self)
77 }
78
79 pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
92 crate::operation::image2image::Image2Image::new(self)
93 }
94
95 pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
106 crate::operation::text2image::Text2Image::new(self)
107 }
108
109 pub fn http_client(&self) -> reqwest::Client {
110 self.http_client.clone()
111 }
112
113 pub fn task(&self) -> crate::operation::task::Task<'_> {
118 crate::operation::task::Task::new(self)
119 }
120
121 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
129 crate::operation::embeddings::Embeddings::new(self)
130 }
131
132 pub(crate) async fn post_stream<I, O>(
133 &self,
134 path: &str,
135 request: I,
136 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
137 where
138 I: Serialize,
139 O: DeserializeOwned + std::marker::Send + 'static,
140 {
141 let event_source = self
142 .http_client
143 .post(self.config.url(path))
144 .headers(self.config.headers())
145 .json(&request)
146 .eventsource()?;
147
148 Ok(stream(event_source).await)
149 }
150
151 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
152 where
153 I: Serialize + Debug,
154 O: DeserializeOwned,
155 {
156 self.post_with_headers(path, request, self.config().headers())
157 .await
158 }
159
160 pub(crate) async fn post_with_headers<I, O>(
176 &self,
177 path: &str,
178 request: I,
179 headers: reqwest::header::HeaderMap,
180 ) -> Result<O, DashScopeError>
181 where
182 I: Serialize + Debug,
183 O: DeserializeOwned,
184 {
185 let request_maker = || async {
186 Ok(self
187 .http_client
188 .post(self.config.url(path))
189 .headers(headers.clone())
190 .json(&request)
191 .build()?)
192 };
193
194 self.execute(request_maker).await
195 }
196
197 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
198 where
199 O: DeserializeOwned,
200 M: Fn() -> Fut,
201 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
202 {
203 let bytes = self.execute_raw(request_maker).await?;
204
205 let response: O = serde_json::from_slice(bytes.as_ref())
206 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
207
208 Ok(response)
209 }
210
211 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
212 where
213 M: Fn() -> Fut,
214 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
215 {
216 let client = self.http_client.clone();
217
218 backoff::future::retry(self.backoff.clone(), || async {
219 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
220 let response = client
221 .execute(request)
222 .await
223 .map_err(DashScopeError::Reqwest)
224 .map_err(backoff::Error::Permanent)?;
225
226 let status = response.status();
227 let bytes = response
228 .bytes()
229 .await
230 .map_err(DashScopeError::Reqwest)
231 .map_err(backoff::Error::Permanent)?;
232
233 if !status.is_success() {
235 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
236 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
237 .map_err(backoff::Error::Permanent)?;
238
239 if status.as_u16() == 429 {
240 tracing::warn!("Rate limited: {}", api_error.message);
242 return Err(backoff::Error::Transient {
243 err: DashScopeError::ApiError(api_error),
244 retry_after: None,
245 });
246 } else {
247 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
248 api_error,
249 )));
250 }
251 }
252
253 Ok(bytes)
254 })
255 .await
256 }
257
258 pub fn config(&self) -> &Config {
259 &self.config
260 }
261}
262
263pub(crate) async fn stream<O>(
264 mut event_source: EventSource,
265) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
266where
267 O: DeserializeOwned + std::marker::Send + 'static,
268{
269 let stream = try_stream! {
270 while let Some(ev) = event_source.next().await {
271 match ev {
272 Err(e) => {
273 Err(DashScopeError::StreamError(e.to_string()))?;
274 }
275 Ok(Event::Open) => continue,
276 Ok(Event::Message(message)) => {
277 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
279 Ok(val) => val,
280 Err(e) => {
281 Err(map_deserialization_error(e, message.data.as_bytes()))?;
282 continue;
283 }
284 };
285
286 let response = serde_json::from_value::<O>(json_value.clone())
288 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
289
290 yield response;
292
293 let finish_reason = json_value
296 .pointer("/output/choices/0/finish_reason")
297 .and_then(|v| v.as_str());
298
299 if let Some("stop") = finish_reason {
300 break;
301 }
302 }
303 }
304 }
305 event_source.close();
306 };
307
308 Box::pin(stream)
309}
310
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// A client built from a config with an API key must send that key as a
    /// bearer token in the `authorization` header.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Look the header up directly: scanning and asserting only on a match
        // would let the test pass when the header is missing altogether.
        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("a config with an API key must produce an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}