1use rig::client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use rig::http_client::{self, HttpClientExt};
3
4use super::VOLCENGINE_API_BASE_URL;
5use super::completion::CompletionModel;
6use super::embedding::EmbeddingModel;
7
8#[derive(Clone)]
10pub struct Client<T = reqwest::Client> {
11 pub(crate) base_url: String,
12 pub(crate) api_key: String,
13 pub(crate) http_client: T,
14}
15
16impl<T> std::fmt::Debug for Client<T>
17where
18 T: std::fmt::Debug,
19{
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("Client")
22 .field("base_url", &self.base_url)
23 .field("http_client", &self.http_client)
24 .field("api_key", &"<REDACTED>")
25 .finish()
26 }
27}
28
29#[derive(Clone)]
31pub struct ClientBuilder<'a, T = reqwest::Client> {
32 api_key: &'a str,
33 base_url: &'a str,
34 http_client: T,
35}
36
37impl<'a, T> ClientBuilder<'a, T>
38where
39 T: Default,
40{
41 pub fn new(api_key: &'a str) -> Self {
42 Self {
43 api_key,
44 base_url: VOLCENGINE_API_BASE_URL,
45 http_client: Default::default(),
46 }
47 }
48}
49
50impl<'a, T> ClientBuilder<'a, T> {
51 pub fn base_url(mut self, base_url: &'a str) -> Self {
52 self.base_url = base_url;
53 self
54 }
55
56 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
57 ClientBuilder {
58 api_key: self.api_key,
59 base_url: self.base_url,
60 http_client,
61 }
62 }
63
64 pub fn build(self) -> Client<T> {
65 Client {
66 base_url: self.base_url.to_string(),
67 api_key: self.api_key.to_string(),
68 http_client: self.http_client,
69 }
70 }
71}
72
73impl<T> Client<T>
74where
75 T: Default,
76{
77 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
78 ClientBuilder::new(api_key)
79 }
80
81 pub fn new(api_key: &str) -> Self {
82 Self::builder(api_key).build()
83 }
84}
85
86impl<T> Client<T>
87where
88 T: HttpClientExt,
89{
90 pub(crate) fn url(&self, path: &str) -> String {
91 format!("{}/{}", self.base_url, path.trim_start_matches('/'))
92 }
93
94 fn req(
95 &self,
96 method: http_client::Method,
97 path: &str,
98 ) -> http_client::Result<http_client::Builder> {
99 let url = self.url(path);
100 http_client::with_bearer_auth(
101 http_client::Builder::new().method(method).uri(url),
102 &self.api_key,
103 )
104 }
105
106 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
107 self.req(http_client::Method::GET, path)
108 }
109
110 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
111 self.req(http_client::Method::POST, path)
112 }
113}
114
115impl Client<reqwest::Client> {
116 pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
117 self.http_client
118 .post(self.url(path))
119 .bearer_auth(&self.api_key)
120 }
121}
122
123impl ProviderClient for Client<reqwest::Client> {
124 fn from_env() -> Self
125 where
126 Self: Sized,
127 {
128 let api_key = std::env::var("VOLCENGINE_API_KEY").expect("VOLCENGINE_API_KEY not set");
129 let base_url = std::env::var("VOLCENGINE_BASE_URL")
130 .ok()
131 .unwrap_or_else(|| VOLCENGINE_API_BASE_URL.to_string());
132 Self::builder(&api_key).base_url(&base_url).build()
133 }
134
135 fn from_val(input: rig::client::ProviderValue) -> Self
136 where
137 Self: Sized,
138 {
139 let rig::client::ProviderValue::Simple(api_key) = input else {
140 panic!("Incorrect provider value type")
141 };
142 Self::new(&api_key)
143 }
144}
145
146impl CompletionClient for Client<reqwest::Client> {
147 type CompletionModel = CompletionModel<reqwest::Client>;
148
149 fn completion_model(&self, model: &str) -> Self::CompletionModel {
150 CompletionModel::new(self.clone(), model)
151 }
152}
153
154impl EmbeddingsClient for Client<reqwest::Client> {
155 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
156
157 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
158 EmbeddingModel::new(self.clone(), model, 0)
159 }
160
161 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
162 EmbeddingModel::new(self.clone(), model, ndims)
163 }
164}
165
166impl VerifyClient for Client<reqwest::Client> {
167 async fn verify(&self) -> Result<(), VerifyError> {
168 let req = self
169 .get("/models")?
170 .body(rig::http_client::NoBody)
171 .map_err(rig::http_client::Error::from)?;
172
173 let response = HttpClientExt::send(&self.http_client, req).await?;
174
175 match response.status() {
176 reqwest::StatusCode::OK => Ok(()),
177 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
178 reqwest::StatusCode::INTERNAL_SERVER_ERROR
179 | reqwest::StatusCode::SERVICE_UNAVAILABLE
180 | reqwest::StatusCode::BAD_GATEWAY => {
181 let text = rig::http_client::text(response).await?;
182 Err(VerifyError::ProviderError(text))
183 }
184 _ => Ok(()),
185 }
186 }
187}