rig_bailian/
client.rs

1//! Category: client.rs (Client and Builder; implements Provider/Verify/Completion/Embedding)
2
3use rig::client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
4use rig::http_client::{self, HttpClientExt};
5
6use super::BAILIAN_API_BASE_URL;
7use super::completion::CompletionModel;
8use super::embedding::EmbeddingModel;
9use super::rerank::RerankModel;
10
11/// Provider client: Client<T>
12#[derive(Clone)]
13pub struct Client<T = reqwest::Client> {
14    pub(crate) base_url: String,
15    pub(crate) api_key: String,
16    pub(crate) http_client: T,
17}
18
19impl<T> std::fmt::Debug for Client<T>
20where
21    T: std::fmt::Debug,
22{
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("Client")
25            .field("base_url", &self.base_url)
26            .field("http_client", &self.http_client)
27            .field("api_key", &"<REDACTED>")
28            .finish()
29    }
30}
31
32/// Client builder: ClientBuilder<'a, T>
33#[derive(Clone)]
34pub struct ClientBuilder<'a, T = reqwest::Client> {
35    api_key: &'a str,
36    base_url: &'a str,
37    http_client: T,
38}
39
40impl<'a, T> ClientBuilder<'a, T>
41where
42    T: Default,
43{
44    pub fn new(api_key: &'a str) -> Self {
45        Self {
46            api_key,
47            base_url: BAILIAN_API_BASE_URL,
48            http_client: Default::default(),
49        }
50    }
51}
52
53impl<'a, T> ClientBuilder<'a, T> {
54    pub fn base_url(mut self, base_url: &'a str) -> Self {
55        self.base_url = base_url;
56        self
57    }
58
59    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
60        ClientBuilder {
61            api_key: self.api_key,
62            base_url: self.base_url,
63            http_client,
64        }
65    }
66
67    pub fn build(self) -> Client<T> {
68        Client {
69            base_url: self.base_url.to_string(),
70            api_key: self.api_key.to_string(),
71            http_client: self.http_client,
72        }
73    }
74}
75
76impl<T> Client<T>
77where
78    T: Default,
79{
80    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
81        ClientBuilder::new(api_key)
82    }
83
84    pub fn new(api_key: &str) -> Self {
85        Self::builder(api_key).build()
86    }
87}
88
89impl<T> Client<T>
90where
91    T: HttpClientExt,
92{
93    pub(crate) fn url(&self, path: &str) -> String {
94        format!("{}/{}", self.base_url, path.trim_start_matches('/'))
95    }
96
97    fn req(
98        &self,
99        method: http_client::Method,
100        path: &str,
101    ) -> http_client::Result<http_client::Builder> {
102        let url = self.url(path);
103        http_client::with_bearer_auth(
104            http_client::Builder::new().method(method).uri(url),
105            &self.api_key,
106        )
107    }
108
109    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
110        self.req(http_client::Method::GET, path)
111    }
112
113    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
114        self.req(http_client::Method::POST, path)
115    }
116}
117
118impl Client<reqwest::Client> {
119    /// Create a rerank model bound to this client (DashScope endpoint).
120    pub fn rerank_model(&self, model: &str, endpoint: Option<String>) -> RerankModel {
121        RerankModel::new(self.clone(), model, endpoint)
122    }
123}
124
125impl ProviderClient for Client<reqwest::Client> {
126    type Input = String;
127
128    fn from_env() -> Self {
129        let api_key = std::env::var("BAILIAN_API_KEY").expect("BAILIAN_API_KEY not set");
130        let base_url = std::env::var("BAILIAN_BASE_URL")
131            .ok()
132            .unwrap_or_else(|| BAILIAN_API_BASE_URL.to_string());
133        Self::builder(&api_key).base_url(&base_url).build()
134    }
135
136    fn from_val(input: String) -> Self {
137        Self::new(&input)
138    }
139}
140
141impl CompletionClient for Client<reqwest::Client> {
142    type CompletionModel = CompletionModel<reqwest::Client>;
143
144    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
145        CompletionModel::new(self.clone(), &model.into())
146    }
147}
148
149impl EmbeddingsClient for Client<reqwest::Client> {
150    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
151
152    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
153        EmbeddingModel::new(self.clone(), &model.into(), 0)
154    }
155
156    fn embedding_model_with_ndims(
157        &self,
158        model: impl Into<String>,
159        ndims: usize,
160    ) -> Self::EmbeddingModel {
161        EmbeddingModel::new(self.clone(), &model.into(), ndims)
162    }
163}
164
165impl VerifyClient for Client<reqwest::Client> {
166    async fn verify(&self) -> Result<(), VerifyError> {
167        let req = self
168            .get("/models")?
169            .body(rig::http_client::NoBody)
170            .map_err(rig::http_client::Error::from)?;
171
172        let response = HttpClientExt::send(&self.http_client, req).await?;
173
174        match response.status() {
175            reqwest::StatusCode::OK => Ok(()),
176            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
177            reqwest::StatusCode::INTERNAL_SERVER_ERROR
178            | reqwest::StatusCode::SERVICE_UNAVAILABLE
179            | reqwest::StatusCode::BAD_GATEWAY => {
180                let text = rig::http_client::text(response).await?;
181                Err(VerifyError::ProviderError(text))
182            }
183            _ => Ok(()),
184        }
185    }
186}