async_dashscope/
client.rs1use std::{fmt::Debug, pin::Pin};
2
3use bytes::Bytes;
4use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tokio_stream::{Stream, StreamExt as _};
7
8use crate::{
9 config::Config,
10 error::{map_deserialization_error, ApiError, DashScopeError},
11};
12
13#[derive(Debug, Default)]
14pub struct Client {
15 http_client: reqwest::Client,
16 config: Config,
17 backoff: backoff::ExponentialBackoff,
18}
19
20impl Client {
21 pub fn new() -> Self {
22 Self::default()
23 }
24 pub fn build(
25 http_client: reqwest::Client,
26 config: Config,
27 backoff: backoff::ExponentialBackoff,
28 ) -> Self {
29 Self {
30 http_client,
31 config,
32 backoff,
33 }
34 }
35
36 pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
45 crate::operation::generation::Generation::new(self)
46 }
47
48 pub fn multi_modal_conversation(
57 &self,
58 ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
59 crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
60 }
61
62 pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
72 crate::operation::embeddings::Embeddings::new(self)
73 }
74
75 pub(crate) async fn post_stream<I, O>(
76 &self,
77 path: &str,
78 request: I,
79 ) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
80 where
81 I: Serialize,
82 O: DeserializeOwned + std::marker::Send + 'static,
83 {
84 let event_source = self
85 .http_client
86 .post(self.config.url(path))
87 .headers(self.config.headers())
88 .json(&request)
89 .eventsource()
90 .unwrap();
91
92 stream(event_source).await
93 }
94
95 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
96 where
97 I: Serialize + Debug,
98 O: DeserializeOwned,
99 {
100 dbg!(&request);
101 let request_maker = || async {
102 Ok(self
103 .http_client
104 .post(self.config.url(path))
105 .headers(self.config.headers())
106 .json(&request)
107 .build()?)
108 };
109
110 self.execute(request_maker).await
111 }
112
113 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
114 where
115 O: DeserializeOwned,
116 M: Fn() -> Fut,
117 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
118 {
119 let bytes = self.execute_raw(request_maker).await?;
120
121 let response: O = serde_json::from_slice(bytes.as_ref())
122 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
123
124 Ok(response)
125 }
126
127 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
128 where
129 M: Fn() -> Fut,
130 Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
131 {
132 let client = self.http_client.clone();
133
134 backoff::future::retry(self.backoff.clone(), || async {
135 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
136 let response = client
137 .execute(request)
138 .await
139 .map_err(DashScopeError::Reqwest)
140 .map_err(backoff::Error::Permanent)?;
141
142 let status = response.status();
143 let bytes = response
144 .bytes()
145 .await
146 .map_err(DashScopeError::Reqwest)
147 .map_err(backoff::Error::Permanent)?;
148
149 if !status.is_success() {
151 let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
154 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
155 .map_err(backoff::Error::Permanent)?;
156
157 if status.as_u16() == 429 {
158 tracing::warn!("Rate limited: {}", api_error.message);
160 return Err(backoff::Error::Transient {
161 err: DashScopeError::ApiError(api_error),
162 retry_after: None,
163 });
164 } else {
165 return Err(backoff::Error::Permanent(DashScopeError::ApiError(
166 api_error,
167 )));
168 }
169 }
170
171 Ok(bytes)
172 })
173 .await
174 }
175}
176
177pub(crate) async fn stream<O>(
178 mut event_source: EventSource,
179) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
180where
181 O: DeserializeOwned + std::marker::Send + 'static,
182{
183 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
184
185 tokio::spawn(async move {
186 while let Some(ev) = event_source.next().await {
187 match ev {
188 Err(e) => {
189 if let Err(_e) = tx.send(Err(DashScopeError::StreamError(e.to_string()))) {
190 break;
192 }
193 }
194 Ok(event) => match event {
195 Event::Message(message) => {
196 #[derive(Deserialize, Debug)]
197 struct Result {
198 output: Output,
199 }
200 #[derive(Deserialize, Debug)]
201 struct Output {
202 choices: Vec<Choices>,
203 }
204 #[derive(Deserialize, Debug)]
205 struct Choices {
206 finish_reason: Option<String>,
207 }
208
209 let r = match serde_json::from_str::<Result>(&message.data) {
210 Ok(r) => r,
211 Err(e) => {
212 if let Err(_e) = tx.send(Err(map_deserialization_error(
213 e,
214 message.data.as_bytes(),
215 ))) {
216 break;
217 }
218 continue;
219 }
220 };
221 if let Some(finish_reason) = r.output.choices[0].finish_reason.clone() {
222 if finish_reason == "stop" {
223 break;
224 }
225 }
226
227 let response = match serde_json::from_str::<O>(&message.data) {
228 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
229 Ok(output) => Ok(output),
230 };
231
232 if let Err(_e) = tx.send(response) {
233 break;
235 }
236 }
237 Event::Open => continue,
238 },
239 }
240 }
241
242 event_source.close();
243 });
244
245 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
246}