async_dashscope/
client.rs1use std::{fmt::Debug, pin::Pin};
2
3use async_stream::try_stream;
4use bytes::Bytes;
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
6use serde::{de::DeserializeOwned, Serialize};
7use tokio_stream::{Stream, StreamExt as _};
8
9use crate::{
10 config::Config,
11 error::{map_deserialization_error, ApiError, DashScopeError},
12};
13
14#[derive(Debug, Default, Clone)]
15pub struct Client {
16 http_client: reqwest::Client,
17 config: Config,
18 backoff: backoff::ExponentialBackoff,
19}
20
21impl Client {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_config(config: Config) -> Self {
27 Self {
28 http_client: reqwest::Client::new(),
29 config,
30 backoff: backoff::ExponentialBackoff::default(),
31 }
32 }
33 pub fn with_api_key(mut self, api_key: String) -> Self {
34 self.config.set_api_key(api_key.into());
35 self
36 }
37
38 pub fn build(
39 http_client: reqwest::Client,
40 config: Config,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
59 crate::operation::generation::Generation::new(self)
60 }
61
62 pub fn multi_modal_conversation(
69 &self,
70 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
71 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
72 }
73
74 pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
76 crate::operation::audio::Audio::new(self)
77 }
78
79 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
87 crate::operation::embeddings::Embeddings::new(self)
88 }
89
90 pub(crate) async fn post_stream<I, O>(
91 &self,
92 path: &str,
93 request: I,
94 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
95 where
96 I: Serialize,
97 O: DeserializeOwned + std::marker::Send + 'static,
98 {
99 let event_source = self
100 .http_client
101 .post(self.config.url(path))
102 .headers(self.config.headers())
103 .json(&request)
104 .eventsource()?;
105
106 Ok(stream(event_source).await)
107 }
108
109 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
110 where
111 I: Serialize + Debug,
112 O: DeserializeOwned,
113 {
114 let request_maker = || async {
115 Ok(self
116 .http_client
117 .post(self.config.url(path))
118 .headers(self.config.headers())
119 .json(&request)
120 .build()?)
121 };
122
123 self.execute(request_maker).await
124 }
125
126 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
127 where
128 O: DeserializeOwned,
129 M: Fn() -> Fut,
130 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
131 {
132 let bytes = self.execute_raw(request_maker).await?;
133
134 let response: O = serde_json::from_slice(bytes.as_ref())
135 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
136
137 Ok(response)
138 }
139
140 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
141 where
142 M: Fn() -> Fut,
143 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
144 {
145 let client = self.http_client.clone();
146
147 backoff::future::retry(self.backoff.clone(), || async {
148 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
149 let response = client
150 .execute(request)
151 .await
152 .map_err(DashScopeError::Reqwest)
153 .map_err(backoff::Error::Permanent)?;
154
155 let status = response.status();
156 let bytes = response
157 .bytes()
158 .await
159 .map_err(DashScopeError::Reqwest)
160 .map_err(backoff::Error::Permanent)?;
161
162 if !status.is_success() {
164 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
165 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
166 .map_err(backoff::Error::Permanent)?;
167
168 if status.as_u16() == 429 {
169 tracing::warn!("Rate limited: {}", api_error.message);
171 return Err(backoff::Error::Transient {
172 err: DashScopeError::ApiError(api_error),
173 retry_after: None,
174 });
175 } else {
176 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
177 api_error,
178 )));
179 }
180 }
181
182 Ok(bytes)
183 })
184 .await
185 }
186
187 pub fn config(&self) -> &Config {
188 &self.config
189 }
190}
191
192pub(crate) async fn stream<O>(
193 mut event_source: EventSource,
194) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
195where
196 O: DeserializeOwned + std::marker::Send + 'static,
197{
198 let stream = try_stream! {
199 while let Some(ev) = event_source.next().await {
200 match ev {
201 Err(e) => {
202 Err(DashScopeError::StreamError(e.to_string()))?;
203 }
204 Ok(Event::Open) => continue,
205 Ok(Event::Message(message)) => {
206 let json_value: serde_json::Value = match serde_json::from_str(&message.data) {
208 Ok(val) => val,
209 Err(e) => {
210 Err(map_deserialization_error(e, message.data.as_bytes()))?;
211 continue;
212 }
213 };
214
215 let response = serde_json::from_value::<O>(json_value.clone())
217 .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
218
219 yield response;
221
222 let finish_reason = json_value
225 .pointer("/output/choices/0/finish_reason")
226 .and_then(|v| v.as_str());
227
228 if let Some("stop") = finish_reason {
229 break;
230 }
231 }
232 }
233 }
234 event_source.close();
235 };
236
237 Box::pin(stream)
238}
239
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// A client built with `Client::with_config` must send the configured API
    /// key as a bearer `authorization` header.
    #[test]
    fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);

        let headers = client.config().headers();
        // Look the header up directly so a missing header fails the test,
        // instead of the check being silently skipped.
        let authorization = headers
            .get("authorization")
            .expect("authorization header should be set when an API key is configured");
        assert_eq!(authorization, "Bearer test key");
    }
}