1use rig::client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use rig::http_client::{self, HttpClientExt};
3
4use super::VOLCENGINE_API_BASE_URL;
5use super::completion::CompletionModel;
6use super::embedding::EmbeddingModel;
7
8#[derive(Clone)]
10pub struct Client<T = reqwest::Client> {
11 pub(crate) base_url: String,
12 pub(crate) api_key: String,
13 pub(crate) http_client: T,
14}
15
16impl<T> std::fmt::Debug for Client<T>
17where
18 T: std::fmt::Debug,
19{
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("Client")
22 .field("base_url", &self.base_url)
23 .field("http_client", &self.http_client)
24 .field("api_key", &"<REDACTED>")
25 .finish()
26 }
27}
28
29#[derive(Clone)]
31pub struct ClientBuilder<'a, T = reqwest::Client> {
32 api_key: &'a str,
33 base_url: &'a str,
34 http_client: T,
35}
36
37impl<'a, T> ClientBuilder<'a, T>
38where
39 T: Default,
40{
41 pub fn new(api_key: &'a str) -> Self {
42 Self {
43 api_key,
44 base_url: VOLCENGINE_API_BASE_URL,
45 http_client: Default::default(),
46 }
47 }
48}
49
50impl<'a, T> ClientBuilder<'a, T> {
51 pub fn base_url(mut self, base_url: &'a str) -> Self {
52 self.base_url = base_url;
53 self
54 }
55
56 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
57 ClientBuilder {
58 api_key: self.api_key,
59 base_url: self.base_url,
60 http_client,
61 }
62 }
63
64 pub fn build(self) -> Client<T> {
65 Client {
66 base_url: self.base_url.to_string(),
67 api_key: self.api_key.to_string(),
68 http_client: self.http_client,
69 }
70 }
71}
72
73impl<T> Client<T>
74where
75 T: Default,
76{
77 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
78 ClientBuilder::new(api_key)
79 }
80
81 pub fn new(api_key: &str) -> Self {
82 Self::builder(api_key).build()
83 }
84}
85
86impl<T> Client<T>
87where
88 T: HttpClientExt,
89{
90 pub(crate) fn url(&self, path: &str) -> String {
91 format!("{}/{}", self.base_url, path.trim_start_matches('/'))
92 }
93
94 fn req(
95 &self,
96 method: http_client::Method,
97 path: &str,
98 ) -> http_client::Result<http_client::Builder> {
99 let url = self.url(path);
100 http_client::with_bearer_auth(
101 http_client::Builder::new().method(method).uri(url),
102 &self.api_key,
103 )
104 }
105
106 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
107 self.req(http_client::Method::GET, path)
108 }
109
110 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
111 self.req(http_client::Method::POST, path)
112 }
113}
114
115impl ProviderClient for Client<reqwest::Client> {
116 type Input = String;
117
118 fn from_env() -> Self {
119 let api_key = std::env::var("VOLCENGINE_API_KEY").expect("VOLCENGINE_API_KEY not set");
120 let base_url = std::env::var("VOLCENGINE_BASE_URL")
121 .ok()
122 .unwrap_or_else(|| VOLCENGINE_API_BASE_URL.to_string());
123 Self::builder(&api_key).base_url(&base_url).build()
124 }
125
126 fn from_val(input: String) -> Self {
127 Self::new(&input)
128 }
129}
130
131impl CompletionClient for Client<reqwest::Client> {
132 type CompletionModel = CompletionModel<reqwest::Client>;
133
134 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
135 CompletionModel::new(self.clone(), &model.into())
136 }
137}
138
139impl EmbeddingsClient for Client<reqwest::Client> {
140 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
141
142 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
143 EmbeddingModel::new(self.clone(), &model.into(), 0)
144 }
145
146 fn embedding_model_with_ndims(
147 &self,
148 model: impl Into<String>,
149 ndims: usize,
150 ) -> Self::EmbeddingModel {
151 EmbeddingModel::new(self.clone(), &model.into(), ndims)
152 }
153}
154
155impl VerifyClient for Client<reqwest::Client> {
156 async fn verify(&self) -> Result<(), VerifyError> {
157 let req = self
158 .get("/models")?
159 .body(rig::http_client::NoBody)
160 .map_err(rig::http_client::Error::from)?;
161
162 let response = HttpClientExt::send(&self.http_client, req).await?;
163
164 match response.status() {
165 reqwest::StatusCode::OK => Ok(()),
166 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
167 reqwest::StatusCode::INTERNAL_SERVER_ERROR
168 | reqwest::StatusCode::SERVICE_UNAVAILABLE
169 | reqwest::StatusCode::BAD_GATEWAY => {
170 let text = rig::http_client::text(response).await?;
171 Err(VerifyError::ProviderError(text))
172 }
173 _ => Ok(()),
174 }
175 }
176}