use std::{fmt::Debug, pin::Pin};
use async_stream::try_stream;
use bytes::Bytes;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt as _};
use serde::{Serialize, de::DeserializeOwned};
use tokio_stream::{Stream, StreamExt as _};
use crate::{
config::Config,
error::{ApiError, DashScopeError, map_deserialization_error},
};
/// Asynchronous DashScope API client.
///
/// Bundles the HTTP client used to send every request, the [`Config`] that builds request
/// URLs (`Config::url`) and supplies default headers (`Config::headers`), and the
/// exponential-backoff policy used to retry non-streaming requests that receive an
/// HTTP 429 (rate limited) response.
#[derive(Debug, Default, Clone)]
pub struct Client {
/// Underlying HTTP client; cloning a `Client` shares it.
pub(crate) http_client: reqwest::Client,
/// API configuration: request URLs and default headers (including the API key).
pub(crate) config: Config,
/// Retry policy applied by the non-streaming request paths.
pub(crate) backoff: backoff::ExponentialBackoff,
}
impl Client {
    /// Creates a client with the default configuration, HTTP client and retry policy.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a client from `config`, using a fresh [`reqwest::Client`] and the default
    /// exponential-backoff retry policy.
    pub fn with_config(config: Config) -> Self {
        Self {
            http_client: reqwest::Client::new(),
            config,
            backoff: backoff::ExponentialBackoff::default(),
        }
    }

    /// Sets the API key on this client's [`Config`] (via `Config::set_api_key`) and returns
    /// the client, builder style.
    pub fn with_api_key(mut self, api_key: String) -> Self {
        self.config.set_api_key(api_key.into());
        self
    }

    /// Assembles a client from an explicit HTTP client, configuration and retry policy.
    pub fn build(
        http_client: reqwest::Client,
        config: Config,
        backoff: backoff::ExponentialBackoff,
    ) -> Self {
        Self {
            http_client,
            config,
            backoff,
        }
    }

    /// Returns the file operation group bound to this client.
    pub fn files(&self) -> crate::operation::file::File<'_> {
        crate::operation::file::File::new(self)
    }

    /// Returns the text-generation operation group bound to this client.
    pub fn generation(&self) -> crate::operation::generation::Generation<'_> {
        crate::operation::generation::Generation::new(self)
    }

    /// Returns the multi-modal conversation operation group bound to this client.
    pub fn multi_modal_conversation(
        &self,
    ) -> crate::operation::multi_modal_conversation::MultiModalConversation<'_> {
        crate::operation::multi_modal_conversation::MultiModalConversation::new(self)
    }

    /// Returns the audio operation group bound to this client.
    pub fn audio(&self) -> crate::operation::audio::Audio<'_> {
        crate::operation::audio::Audio::new(self)
    }

    /// Returns the image-to-image operation group bound to this client.
    pub fn image2image(&self) -> crate::operation::image2image::Image2Image<'_> {
        crate::operation::image2image::Image2Image::new(self)
    }

    /// Returns the text-to-image operation group bound to this client.
    pub fn text2image(&self) -> crate::operation::text2image::Text2Image<'_> {
        crate::operation::text2image::Text2Image::new(self)
    }

    /// Alias of [`Client::files`].
    pub fn file(&self) -> crate::operation::file::File<'_> {
        self.files()
    }

    /// Returns a clone of the underlying HTTP client.
    ///
    /// Cloning is cheap: a `reqwest::Client` shares its connection pool between clones.
    pub fn http_client(&self) -> reqwest::Client {
        self.http_client.clone()
    }

    /// Returns the task operation group bound to this client.
    pub fn task(&self) -> crate::operation::task::Task<'_> {
        crate::operation::task::Task::new(self)
    }

    /// Returns the text-embeddings operation group bound to this client.
    pub fn text_embeddings(&self) -> crate::operation::embeddings::Embeddings<'_> {
        crate::operation::embeddings::Embeddings::new(self)
    }

    /// Sends `request` as a JSON `POST` to `path` with the configured default headers and
    /// returns the server-sent-event response as a stream of deserialized chunks.
    ///
    /// Unlike [`Client::post`], a streaming request is not retried with the backoff policy.
    pub(crate) async fn post_stream<I, O>(
        &self,
        path: &str,
        request: I,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
    where
        I: Serialize + Debug,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        self.post_stream_with_headers(path, request, self.config.headers())
            .await
    }

    /// Like [`Client::post_stream`], but sends `headers` instead of the configured defaults.
    ///
    /// # Errors
    /// Returns an error if the request cannot be turned into an event source. Errors that
    /// happen while the response is streaming are yielded as items of the returned stream.
    pub(crate) async fn post_stream_with_headers<I, O>(
        &self,
        path: &str,
        request: I,
        headers: reqwest::header::HeaderMap,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>, DashScopeError>
    where
        I: Serialize + Debug,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .post(self.config.url(path))
            .headers(headers)
            .json(&request)
            .eventsource()?;
        Ok(stream(event_source).await)
    }

    /// Sends `request` as a JSON `POST` to `path` with the configured default headers and
    /// deserializes the response body as `O`.
    ///
    /// # Errors
    /// See [`Client::post_with_headers`].
    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, DashScopeError>
    where
        I: Serialize + Debug,
        O: DeserializeOwned,
    {
        self.post_with_headers(path, request, self.config.headers())
            .await
    }

    /// Sends `request` as a JSON `POST` to `path` with `headers` and deserializes the
    /// response body as `O`. Rate-limited (HTTP 429) responses are retried with the
    /// client's backoff policy.
    ///
    /// # Errors
    /// Returns [`DashScopeError::Reqwest`] on transport failures, [`DashScopeError::ApiError`]
    /// for a non-success status whose body is a DashScope error, and a deserialization error
    /// when a body cannot be decoded.
    pub(crate) async fn post_with_headers<I, O>(
        &self,
        path: &str,
        request: I,
        headers: reqwest::header::HeaderMap,
    ) -> Result<O, DashScopeError>
    where
        I: Serialize + Debug,
        O: DeserializeOwned,
    {
        // The request is rebuilt for every attempt, so the headers are cloned per call.
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .headers(headers.clone())
                .json(&request)
                .build()?)
        };
        self.execute(request_maker).await
    }

    /// Executes the request produced by `request_maker` (with retries, see
    /// [`Client::execute_raw`]) and deserializes the successful body as `O`.
    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, DashScopeError>
    where
        O: DeserializeOwned,
        M: Fn() -> Fut,
        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
    {
        let bytes = self.execute_raw(request_maker).await?;
        let response: O = serde_json::from_slice(bytes.as_ref())
            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
        Ok(response)
    }

    /// Executes the request produced by `request_maker` and returns the raw body of a
    /// successful response.
    ///
    /// `request_maker` is invoked once per attempt. HTTP 429 responses are treated as
    /// transient and retried according to `self.backoff`. Every other failure (request
    /// construction, transport, non-success status) is permanent and returned immediately.
    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, DashScopeError>
    where
        M: Fn() -> Fut,
        Fut: core::future::Future<Output = Result<reqwest::Request, DashScopeError>>,
    {
        backoff::future::retry(self.backoff.clone(), || async {
            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
            let response = self
                .http_client
                .execute(request)
                .await
                .map_err(DashScopeError::Reqwest)
                .map_err(backoff::Error::Permanent)?;
            let status = response.status();
            let bytes = response
                .bytes()
                .await
                .map_err(DashScopeError::Reqwest)
                .map_err(backoff::Error::Permanent)?;
            if status.is_success() {
                return Ok(bytes);
            }

            let rate_limited = status == reqwest::StatusCode::TOO_MANY_REQUESTS;
            // Decode the structured error body when possible. A malformed body (e.g. a proxy's
            // HTML error page) must not turn a retryable 429 into a permanent failure.
            let err = match serde_json::from_slice::<ApiError>(bytes.as_ref()) {
                Ok(api_error) => {
                    if rate_limited {
                        tracing::warn!("Rate limited: {}", api_error.message);
                    }
                    DashScopeError::ApiError(api_error)
                }
                Err(e) => {
                    if rate_limited {
                        tracing::warn!(
                            "Rate limited (status {}) with an unparseable error body",
                            status
                        );
                    }
                    map_deserialization_error(e, bytes.as_ref())
                }
            };

            if rate_limited {
                Err(backoff::Error::Transient {
                    err,
                    retry_after: None,
                })
            } else {
                Err(backoff::Error::Permanent(err))
            }
        })
        .await
    }

    /// Returns the client's configuration.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Sends a multipart `POST` to `path` and deserializes the response body as `O`.
    ///
    /// `form_fn` is called once per attempt, because a multipart form is consumed when sent.
    /// 429 responses are retried like [`Client::post_with_headers`].
    pub(crate) async fn post_multipart<O, F>(
        &self,
        path: &str,
        form_fn: F,
    ) -> Result<O, DashScopeError>
    where
        O: DeserializeOwned,
        F: Fn() -> reqwest::multipart::Form,
    {
        let request_maker = || async {
            let mut headers = self.config.headers();
            // `.multipart()` sets its own `multipart/form-data; boundary=...` Content-Type.
            // Drop the configured one so the two are not both sent.
            headers.remove("Content-Type");
            // NOTE(review): presumably only meaningful for JSON requests that reference oss://
            // resources; confirm why uploads must omit it.
            headers.remove("X-DashScope-OssResourceResolve");
            Ok(self
                .http_client
                .post(self.config.url(path))
                .headers(headers)
                .multipart(form_fn())
                .build()?)
        };
        self.execute(request_maker).await
    }

    /// Sends a `GET` to `path` with `params` serialized as the query string and deserializes
    /// the response body as `O`. 429 responses are retried like [`Client::post`].
    pub(crate) async fn get_with_params<O, P>(&self, path: &str, params: &P) -> Result<O, DashScopeError>
    where
        O: DeserializeOwned,
        P: serde::Serialize + ?Sized,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .headers(self.config.headers())
                .query(params)
                .build()?)
        };
        self.execute(request_maker).await
    }

    /// Sends a `DELETE` to `path` and deserializes the response body as `O`.
    /// 429 responses are retried like [`Client::post`].
    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, DashScopeError>
    where
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .delete(self.config.url(path))
                .headers(self.config.headers())
                .build()?)
        };
        self.execute(request_maker).await
    }
}
/// Adapts a server-sent-event [`EventSource`] into a stream of deserialized responses.
///
/// The `data` of each message event is parsed as JSON and deserialized into `O`. The stream
/// ends in either of two cases:
/// - after yielding a chunk whose `/output/choices/0/finish_reason` is `"stop"`;
/// - when the server closes the connection.
///
/// A transport or protocol error is yielded once as [`DashScopeError::StreamError`] and then
/// ends the stream. A payload that fails to parse or deserialize likewise yields a single
/// error and ends the stream.
pub(crate) async fn stream<O>(
    mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, DashScopeError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let stream = try_stream! {
        while let Some(ev) = event_source.next().await {
            match ev {
                // reqwest-eventsource reports a normal server-side close as `StreamEnded`
                // (e.g. when the final chunk's finish_reason is "length" rather than "stop").
                // That is the end of the response, not a failure.
                Err(reqwest_eventsource::Error::StreamEnded) => break,
                Err(e) => {
                    Err(DashScopeError::StreamError(e.to_string()))?;
                }
                Ok(Event::Open) => continue,
                Ok(Event::Message(message)) => {
                    let json_value: serde_json::Value = serde_json::from_str(&message.data)
                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
                    // Read the finish marker before `json_value` is moved into `from_value`,
                    // which avoids cloning the whole document.
                    let finished = json_value
                        .pointer("/output/choices/0/finish_reason")
                        .and_then(serde_json::Value::as_str)
                        == Some("stop");
                    let response = serde_json::from_value::<O>(json_value)
                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))?;
                    yield response;
                    if finished {
                        break;
                    }
                }
            }
        }
        // Closing explicitly stops reqwest-eventsource from reconnecting.
        event_source.close();
    };
    Box::pin(stream)
}
#[cfg(test)]
mod tests {
    use crate::config::ConfigBuilder;

    use super::*;

    /// The API key given to `ConfigBuilder` must appear as a bearer `authorization` header
    /// in the client's configured headers.
    #[test]
    pub fn test_config() {
        let config = ConfigBuilder::default()
            .api_key("test key")
            .build()
            .unwrap();
        let client = Client::with_config(config);
        let headers = client.config.headers();
        // Require the header. The previous loop-and-compare passed vacuously when it was missing.
        let authorization = headers
            .get("authorization")
            .expect("config headers should contain an authorization header");
        assert_eq!(authorization, "Bearer test key");
    }
}