async_dashscope/
client.rs1use std::{fmt::Debug, pin::Pin};
2
3use bytes::Bytes;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tokio_stream::{Stream, StreamExt as _};
7
8use crate::{
9 config::Config,
10 error::{map_deserialization_error, ApiError, DashScopeError},
11};
12
13#[derive(Debug, Default, Clone)]
14pub struct Client {
15 http_client: reqwest::Client,
16 config: Config,
17 backoff: backoff::ExponentialBackoff,
18}
19
20impl Client {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn with_config(config: Config) -> Self {
26 Self {
27 http_client: reqwest::Client::new(),
28 config,
29 backoff: backoff::ExponentialBackoff::default(),
30 }
31 }
32 pub fn with_api_key(mut self, api_key: String) -> Self {
33 self.config.set_api_key(api_key.into());
34 self
35 }
36
37 pub fn build(
38 http_client: reqwest::Client,
39 config: Config,
40 backoff: backoff::ExponentialBackoff,
41 ) -> Self {
42 Self {
43 http_client,
44 config,
45 backoff,
46 }
47 }
48
49 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
58 crate::operation::generation::Generation::new(self)
59 }
60
61 pub fn multi_modal_conversation(
68 &self,
69 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
70 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
71 }
72
73 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
81 crate::operation::embeddings::Embeddings::new(self)
82 }
83
84 pub(crate) async fn post_stream<I, O>(
85 &self,
86 path: &str,
87 request: I,
88 ) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
89 where
90 I: Serialize,
91 O: DeserializeOwned + std::marker::Send + 'static,
92 {
93 let event_source = self
94 .http_client
95 .post(self.config.url(path))
96 .headers(self.config.headers())
97 .json(&request)
98 .eventsource()
99 .unwrap();
100
101 stream(event_source).await
102 }
103
104 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
105 where
106 I: Serialize + Debug,
107 O: DeserializeOwned,
108 {
109 let request_maker = || async {
111 Ok(self
112 .http_client
113 .post(self.config.url(path))
114 .headers(self.config.headers())
115 .json(&request)
116 .build()?)
117 };
118
119 self.execute(request_maker).await
120 }
121
122 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
123 where
124 O: DeserializeOwned,
125 M: Fn() -> Fut,
126 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
127 {
128 let bytes = self.execute_raw(request_maker).await?;
129
130 let s = String::from_utf8(bytes.to_vec()).unwrap();
132 dbg!(s);
133
134 let response: O = serde_json::from_slice(bytes.as_ref())
135 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
136
137 Ok(response)
138 }
139
140 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
141 where
142 M: Fn() -> Fut,
143 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
144 {
145 let client = self.http_client.clone();
146
147 backoff::future::retry(self.backoff.clone(), || async {
148 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
149 let response = client
150 .execute(request)
151 .await
152 .map_err(DashScopeError::Reqwest)
153 .map_err(backoff::Error::Permanent)?;
154
155 let status = response.status();
156 let bytes = response
157 .bytes()
158 .await
159 .map_err(DashScopeError::Reqwest)
160 .map_err(backoff::Error::Permanent)?;
161
162 if !status.is_success() {
164 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
167 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
168 .map_err(backoff::Error::Permanent)?;
169
170 if status.as_u16() == 429 {
171 tracing::warn!("Rate limited: {}", api_error.message);
173 return Err(backoff::Error::Transient {
174 err: DashScopeError::ApiError(api_error),
175 retry_after: None,
176 });
177 } else {
178 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
179 api_error,
180 )));
181 }
182 }
183
184 Ok(bytes)
185 })
186 .await
187 }
188}
189
190pub(crate) async fn stream<O>(
191 mut event_source: EventSource,
192) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
193where
194 O: DeserializeOwned + std::marker::Send + 'static,
195{
196 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
197
198 tokio::spawn(async move {
199 while let Some(ev) = event_source.next().await {
200 match ev {
201 Err(e) => {
202 if let Err(_e) = tx.send(Err(DashScopeError::StreamError(e.to_string()))) {
203 break;
205 }
206 }
207 Ok(event) => match event {
208 Event::Message(message) => {
209 #[derive(Deserialize, Debug)]
210 struct Result {
211 output: Output,
212 }
213 #[derive(Deserialize, Debug)]
214 struct Output {
215 choices: Vec<Choices>,
216 }
217 #[derive(Deserialize, Debug)]
218 struct Choices {
219 finish_reason: Option<String>,
220 }
221
222 let r = match serde_json::from_str::<Result>(&message.data) {
223 Ok(r) => r,
224 Err(e) => {
225 if let Err(_e) = tx.send(Err(map_deserialization_error(
226 e,
227 message.data.as_bytes(),
228 ))) {
229 break;
230 }
231 continue;
232 }
233 };
234 if let Some(finish_reason) = r.output.choices[0].finish_reason.clone() {
235 if finish_reason == "stop" {
236 break;
237 }
238 }
239
240 let response = match serde_json::from_str::<O>(&message.data) {
241 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
242 Ok(output) => Ok(output),
243 };
244
245 if let Err(_e) = tx.send(response) {
246 break;
248 }
249 }
250 Event::Open => continue,
251 },
252 }
253 }
254
255 event_source.close();
256 });
257
258 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
259}
260
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key given to the config builder must show up as a bearer
    /// `authorization` header on the client's configuration.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        // Look the header up directly so a missing header fails the test
        // instead of letting it pass without asserting anything.
        let headers = client.config.headers();
        let authorization = headers
            .get("authorization")
            .expect("config headers should include an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}