async_dashscope/
client.rs1use std::{fmt::Debug, pin::Pin};
2
3use bytes::Bytes;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tokio_stream::{Stream, StreamExt as _};
7
8use crate::{
9 config::Config,
10 error::{map_deserialization_error, ApiError, DashScopeError},
11};
12
13#[derive(Debug, Default, Clone)]
14pub struct Client {
15 http_client: reqwest::Client,
16 config: Config,
17 backoff: backoff::ExponentialBackoff,
18}
19
20impl Client {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn with_config(config: Config) -> Self {
26 Self {
27 http_client: reqwest::Client::new(),
28 config,
29 backoff: backoff::ExponentialBackoff::default(),
30 }
31 }
32 pub fn with_api_key(mut self, api_key: String) -> Self {
33 self.config.set_api_key(api_key.into());
34 self
35 }
36
37 pub fn build(
38 http_client: reqwest::Client,
39 config: Config,
40 backoff: backoff::ExponentialBackoff,
41 ) -> Self {
42 Self {
43 http_client,
44 config,
45 backoff,
46 }
47 }
48
49 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
58 crate::operation::generation::Generation::new(self)
59 }
60
61 pub fn multi_modal_conversation(
68 &self,
69 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
70 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
71 }
72
73 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
81 crate::operation::embeddings::Embeddings::new(self)
82 }
83
84 pub(crate) async fn post_stream<I, O>(
85 &self,
86 path: &str,
87 request: I,
88 ) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
89 where
90 I: Serialize,
91 O: DeserializeOwned + std::marker::Send + 'static,
92 {
93 let event_source = self
94 .http_client
95 .post(self.config.url(path))
96 .headers(self.config.headers())
97 .json(&request)
98 .eventsource()
99 .unwrap();
100
101 stream(event_source).await
102 }
103
104 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
105 where
106 I: Serialize + Debug,
107 O: DeserializeOwned,
108 {
109 let request_maker = || async {
111 Ok(self
112 .http_client
113 .post(self.config.url(path))
114 .headers(self.config.headers())
115 .json(&request)
116 .build()?)
117 };
118
119 self.execute(request_maker).await
120 }
121
122 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
123 where
124 O: DeserializeOwned,
125 M: Fn() -> Fut,
126 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
127 {
128 let bytes = self.execute_raw(request_maker).await?;
129
130 let response: O = serde_json::from_slice(bytes.as_ref())
134 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
135
136 Ok(response)
137 }
138
139 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
140 where
141 M: Fn() -> Fut,
142 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
143 {
144 let client = self.http_client.clone();
145
146 backoff::future::retry(self.backoff.clone(), || async {
147 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
148 let response = client
149 .execute(request)
150 .await
151 .map_err(DashScopeError::Reqwest)
152 .map_err(backoff::Error::Permanent)?;
153
154 let status = response.status();
155 let bytes = response
156 .bytes()
157 .await
158 .map_err(DashScopeError::Reqwest)
159 .map_err(backoff::Error::Permanent)?;
160
161 if !status.is_success() {
163 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
166 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
167 .map_err(backoff::Error::Permanent)?;
168
169 if status.as_u16() == 429 {
170 tracing::warn!("Rate limited: {}", api_error.message);
172 return Err(backoff::Error::Transient {
173 err: DashScopeError::ApiError(api_error),
174 retry_after: None,
175 });
176 } else {
177 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
178 api_error,
179 )));
180 }
181 }
182
183 Ok(bytes)
184 })
185 .await
186 }
187}
188
189pub(crate) async fn stream<O>(
190 mut event_source: EventSource,
191) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
192where
193 O: DeserializeOwned + std::marker::Send + 'static,
194{
195 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
196
197 tokio::spawn(async move {
198 while let Some(ev) = event_source.next().await {
199 match ev {
200 Err(e) => {
201 if let Err(_e) = tx.send(Err(DashScopeError::StreamError(e.to_string()))) {
202 break;
204 }
205 }
206 Ok(event) => match event {
207 Event::Message(message) => {
208 #[derive(Deserialize, Debug)]
209 struct Result {
210 output: Output,
211 }
212 #[derive(Deserialize, Debug)]
213 struct Output {
214 choices: Vec<Choices>,
215 }
216 #[derive(Deserialize, Debug)]
217 struct Choices {
218 finish_reason: Option<String>,
219 }
220
221 let r = match serde_json::from_str::<Result>(&message.data) {
222 Ok(r) => r,
223 Err(e) => {
224 if let Err(_e) = tx.send(Err(map_deserialization_error(
225 e,
226 message.data.as_bytes(),
227 ))) {
228 break;
229 }
230 continue;
231 }
232 };
233 if let Some(finish_reason) = r.output.choices[0].finish_reason.clone() {
234 if finish_reason == "stop" {
235 break;
236 }
237 }
238
239 let response = match serde_json::from_str::<O>(&message.data) {
240 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
241 Ok(output) => Ok(output),
242 };
243
244 if let Err(_e) = tx.send(response) {
245 break;
247 }
248 }
249 Event::Open => continue,
250 },
251 }
252 }
253
254 event_source.close();
255 });
256
257 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
258}
259
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The configured API key must be sent as a bearer token in the `Authorization` header.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config.headers();
        // Look the header up directly so a missing header fails the test; scanning all
        // headers and asserting only on a match would pass vacuously.
        let authorization = headers
            .get(reqwest::header::AUTHORIZATION)
            .expect("Config::headers() should include an Authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}