rig_volcengine/
client.rs

1use rig::client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use rig::http_client::{self, HttpClientExt};
3
4use super::VOLCENGINE_API_BASE_URL;
5use super::completion::CompletionModel;
6use super::embedding::EmbeddingModel;
7
8/// Provider client: Client<T>
9#[derive(Clone)]
10pub struct Client<T = reqwest::Client> {
11    pub(crate) base_url: String,
12    pub(crate) api_key: String,
13    pub(crate) http_client: T,
14}
15
16impl<T> std::fmt::Debug for Client<T>
17where
18    T: std::fmt::Debug,
19{
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("Client")
22            .field("base_url", &self.base_url)
23            .field("http_client", &self.http_client)
24            .field("api_key", &"<REDACTED>")
25            .finish()
26    }
27}
28
29/// Client builder: ClientBuilder<'a, T>
30#[derive(Clone)]
31pub struct ClientBuilder<'a, T = reqwest::Client> {
32    api_key: &'a str,
33    base_url: &'a str,
34    http_client: T,
35}
36
37impl<'a, T> ClientBuilder<'a, T>
38where
39    T: Default,
40{
41    pub fn new(api_key: &'a str) -> Self {
42        Self {
43            api_key,
44            base_url: VOLCENGINE_API_BASE_URL,
45            http_client: Default::default(),
46        }
47    }
48}
49
50impl<'a, T> ClientBuilder<'a, T> {
51    pub fn base_url(mut self, base_url: &'a str) -> Self {
52        self.base_url = base_url;
53        self
54    }
55
56    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
57        ClientBuilder {
58            api_key: self.api_key,
59            base_url: self.base_url,
60            http_client,
61        }
62    }
63
64    pub fn build(self) -> Client<T> {
65        Client {
66            base_url: self.base_url.to_string(),
67            api_key: self.api_key.to_string(),
68            http_client: self.http_client,
69        }
70    }
71}
72
73impl<T> Client<T>
74where
75    T: Default,
76{
77    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
78        ClientBuilder::new(api_key)
79    }
80
81    pub fn new(api_key: &str) -> Self {
82        Self::builder(api_key).build()
83    }
84}
85
86impl<T> Client<T>
87where
88    T: HttpClientExt,
89{
90    pub(crate) fn url(&self, path: &str) -> String {
91        format!("{}/{}", self.base_url, path.trim_start_matches('/'))
92    }
93
94    fn req(
95        &self,
96        method: http_client::Method,
97        path: &str,
98    ) -> http_client::Result<http_client::Builder> {
99        let url = self.url(path);
100        http_client::with_bearer_auth(
101            http_client::Builder::new().method(method).uri(url),
102            &self.api_key,
103        )
104    }
105
106    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
107        self.req(http_client::Method::GET, path)
108    }
109
110    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
111        self.req(http_client::Method::POST, path)
112    }
113}
114
115impl Client<reqwest::Client> {
116    pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
117        self.http_client
118            .post(self.url(path))
119            .bearer_auth(&self.api_key)
120    }
121}
122
123impl ProviderClient for Client<reqwest::Client> {
124    fn from_env() -> Self
125    where
126        Self: Sized,
127    {
128        let api_key = std::env::var("VOLCENGINE_API_KEY").expect("VOLCENGINE_API_KEY not set");
129        let base_url = std::env::var("VOLCENGINE_BASE_URL")
130            .ok()
131            .unwrap_or_else(|| VOLCENGINE_API_BASE_URL.to_string());
132        Self::builder(&api_key).base_url(&base_url).build()
133    }
134
135    fn from_val(input: rig::client::ProviderValue) -> Self
136    where
137        Self: Sized,
138    {
139        let rig::client::ProviderValue::Simple(api_key) = input else {
140            panic!("Incorrect provider value type")
141        };
142        Self::new(&api_key)
143    }
144}
145
146impl CompletionClient for Client<reqwest::Client> {
147    type CompletionModel = CompletionModel<reqwest::Client>;
148
149    fn completion_model(&self, model: &str) -> Self::CompletionModel {
150        CompletionModel::new(self.clone(), model)
151    }
152}
153
154impl EmbeddingsClient for Client<reqwest::Client> {
155    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
156
157    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
158        EmbeddingModel::new(self.clone(), model, 0)
159    }
160
161    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
162        EmbeddingModel::new(self.clone(), model, ndims)
163    }
164}
165
166impl VerifyClient for Client<reqwest::Client> {
167    async fn verify(&self) -> Result<(), VerifyError> {
168        let req = self
169            .get("/models")?
170            .body(rig::http_client::NoBody)
171            .map_err(rig::http_client::Error::from)?;
172
173        let response = HttpClientExt::send(&self.http_client, req).await?;
174
175        match response.status() {
176            reqwest::StatusCode::OK => Ok(()),
177            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
178            reqwest::StatusCode::INTERNAL_SERVER_ERROR
179            | reqwest::StatusCode::SERVICE_UNAVAILABLE
180            | reqwest::StatusCode::BAD_GATEWAY => {
181                let text = rig::http_client::text(response).await?;
182                Err(VerifyError::ProviderError(text))
183            }
184            _ => Ok(()),
185        }
186    }
187}