1use rig::client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
4use rig::http_client::{self, HttpClientExt};
5
6use super::BAILIAN_API_BASE_URL;
7use super::completion::CompletionModel;
8use super::embedding::EmbeddingModel;
9use super::rerank::RerankModel;
10
11#[derive(Clone)]
13pub struct Client<T = reqwest::Client> {
14 pub(crate) base_url: String,
15 pub(crate) api_key: String,
16 pub(crate) http_client: T,
17}
18
19impl<T> std::fmt::Debug for Client<T>
20where
21 T: std::fmt::Debug,
22{
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("Client")
25 .field("base_url", &self.base_url)
26 .field("http_client", &self.http_client)
27 .field("api_key", &"<REDACTED>")
28 .finish()
29 }
30}
31
32#[derive(Clone)]
34pub struct ClientBuilder<'a, T = reqwest::Client> {
35 api_key: &'a str,
36 base_url: &'a str,
37 http_client: T,
38}
39
40impl<'a, T> ClientBuilder<'a, T>
41where
42 T: Default,
43{
44 pub fn new(api_key: &'a str) -> Self {
45 Self {
46 api_key,
47 base_url: BAILIAN_API_BASE_URL,
48 http_client: Default::default(),
49 }
50 }
51}
52
53impl<'a, T> ClientBuilder<'a, T> {
54 pub fn base_url(mut self, base_url: &'a str) -> Self {
55 self.base_url = base_url;
56 self
57 }
58
59 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
60 ClientBuilder {
61 api_key: self.api_key,
62 base_url: self.base_url,
63 http_client,
64 }
65 }
66
67 pub fn build(self) -> Client<T> {
68 Client {
69 base_url: self.base_url.to_string(),
70 api_key: self.api_key.to_string(),
71 http_client: self.http_client,
72 }
73 }
74}
75
76impl<T> Client<T>
77where
78 T: Default,
79{
80 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
81 ClientBuilder::new(api_key)
82 }
83
84 pub fn new(api_key: &str) -> Self {
85 Self::builder(api_key).build()
86 }
87}
88
89impl<T> Client<T>
90where
91 T: HttpClientExt,
92{
93 pub(crate) fn url(&self, path: &str) -> String {
94 format!("{}/{}", self.base_url, path.trim_start_matches('/'))
95 }
96
97 fn req(
98 &self,
99 method: http_client::Method,
100 path: &str,
101 ) -> http_client::Result<http_client::Builder> {
102 let url = self.url(path);
103 http_client::with_bearer_auth(
104 http_client::Builder::new().method(method).uri(url),
105 &self.api_key,
106 )
107 }
108
109 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
110 self.req(http_client::Method::GET, path)
111 }
112
113 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
114 self.req(http_client::Method::POST, path)
115 }
116}
117
118impl Client<reqwest::Client> {
119 pub fn rerank_model(&self, model: &str, endpoint: Option<String>) -> RerankModel {
121 RerankModel::new(self.clone(), model, endpoint)
122 }
123}
124
125impl ProviderClient for Client<reqwest::Client> {
126 type Input = String;
127
128 fn from_env() -> Self {
129 let api_key = std::env::var("BAILIAN_API_KEY").expect("BAILIAN_API_KEY not set");
130 let base_url = std::env::var("BAILIAN_BASE_URL")
131 .ok()
132 .unwrap_or_else(|| BAILIAN_API_BASE_URL.to_string());
133 Self::builder(&api_key).base_url(&base_url).build()
134 }
135
136 fn from_val(input: String) -> Self {
137 Self::new(&input)
138 }
139}
140
141impl CompletionClient for Client<reqwest::Client> {
142 type CompletionModel = CompletionModel<reqwest::Client>;
143
144 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
145 CompletionModel::new(self.clone(), &model.into())
146 }
147}
148
149impl EmbeddingsClient for Client<reqwest::Client> {
150 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
151
152 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
153 EmbeddingModel::new(self.clone(), &model.into(), 0)
154 }
155
156 fn embedding_model_with_ndims(
157 &self,
158 model: impl Into<String>,
159 ndims: usize,
160 ) -> Self::EmbeddingModel {
161 EmbeddingModel::new(self.clone(), &model.into(), ndims)
162 }
163}
164
165impl VerifyClient for Client<reqwest::Client> {
166 async fn verify(&self) -> Result<(), VerifyError> {
167 let req = self
168 .get("/models")?
169 .body(rig::http_client::NoBody)
170 .map_err(rig::http_client::Error::from)?;
171
172 let response = HttpClientExt::send(&self.http_client, req).await?;
173
174 match response.status() {
175 reqwest::StatusCode::OK => Ok(()),
176 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
177 reqwest::StatusCode::INTERNAL_SERVER_ERROR
178 | reqwest::StatusCode::SERVICE_UNAVAILABLE
179 | reqwest::StatusCode::BAD_GATEWAY => {
180 let text = rig::http_client::text(response).await?;
181 Err(VerifyError::ProviderError(text))
182 }
183 _ => Ok(()),
184 }
185 }
186}