async_dashscope/
client.rs1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{de::DeserializeOwned, Serialize};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10 config::Config,
11 error::{map_deserialization_error, ApiError, DashScopeError},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 http_client: reqwest::Client,
17 config: Config,
18 backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59 crate::operation::generation::Generation::new(self)
60 }
61
62 pub fn multi_modal_conversation(
69 &self,
70 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72 }
73
74 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76 crate::operation::audio::Audio::new(self)
77 }
78
79 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
87 crate::operation::embeddings::Embeddings::new(self)
88 }
89
90 pub(crate) async fn post_stream<I, O>(
91 &self,
92 path: &str,
93 request: I,
94 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
95 where
96 I: Serialize,
97 O: DeserializeOwned + std::marker::Send + 'static,
98 {
99 let event_source = self
100 .http_client
101 .post(self.config.url(path))
102 .headers(self.config.headers())
103 .json(&request)
104 .eventsource()?;
105
106 Ok(stream(event_source).await)
107 }
108
109 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
110 where
111 I: Serialize + Debug,
112 O: DeserializeOwned,
113 {
114 let request_maker = || async {
115 Ok(self
116 .http_client
117 .post(self.config.url(path))
118 .headers(self.config.headers())
119 .json(&request)
120 .build()?)
121 };
122
123 self.execute(request_maker).await
124 }
125
126 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
127 where
128 O: DeserializeOwned,
129 M: Fn() -> Fut,
130 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
131 {
132 let bytes = self.execute_raw(request_maker).await?;
133
134 let response: O = serde_json::from_slice(bytes.as_ref())
135 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
136
137 Ok(response)
138 }
139
140 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
141 where
142 M: Fn() -> Fut,
143 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
144 {
145 let client = self.http_client.clone();
146
147 backoff::future::retry(self.backoff.clone(), || async {
148 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
149 let response = client
150 .execute(request)
151 .await
152 .map_err(DashScopeError::Reqwest)
153 .map_err(backoff::Error::Permanent)?;
154
155 let status = response.status();
156 let bytes = response
157 .bytes()
158 .await
159 .map_err(DashScopeError::Reqwest)
160 .map_err(backoff::Error::Permanent)?;
161
162 if !status.is_success() {
164 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
165 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
166 .map_err(backoff::Error::Permanent)?;
167
168 if status.as_u16() == 429 {
169 tracing::warn!("Rate limited: {}", api_error.message);
171 return Err(backoff::Error::Transient {
172 err: DashScopeError::ApiError(api_error),
173 retry_after: None,
174 });
175 } else {
176 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
177 api_error,
178 )));
179 }
180 }
181
182 Ok(bytes)
183 })
184 .await
185 }
186}
187
188pub(crate) async fn stream<O>(
189 mut event_source: EventSource,
190) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
191where
192 O: DeserializeOwned + std::marker::Send + 'static,
193{
194 let stream = try_stream! {
195 while let Some(ev) = event_source.next().await {
196 match ev {
197 Err(e) => {
198 Err(DashScopeError::StreamError(e.to_string()))?;
199 }
200 Ok(Event::Open) => continue,
201 Ok(Event::Message(message)) => {
202 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
204 Ok(val) => val,
205 Err(e) => {
206 Err(map_deserialization_error(e, message.data.as_bytes()))?;
207 continue;
208 }
209 };
210
211 let response = serde_json::from_value::<O>(json_value.clone())
213 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
214
215 yield response;
217
218 let finish_reason = json_value
221 .pointer("/output/choices/0/finish_reason")
222 .and_then(|v| v.as_str());
223
224 if let Some("stop") = finish_reason {
225 break;
226 }
227 }
228 }
229 }
230 event_source.close();
231 };
232
233 Box::pin(stream)
234}
235
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key set on the config must appear as a bearer `authorization` header.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Look the header up directly: the previous loop-and-compare passed
        // vacuously when no authorization header was produced at all.
        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("Config::headers() must include an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}