// rig_volcengine/client.rs

1use rig::client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use rig::http_client::{self, HttpClientExt};
3
4use super::VOLCENGINE_API_BASE_URL;
5use super::completion::CompletionModel;
6use super::embedding::EmbeddingModel;
7
8/// Provider client: Client<T>
9#[derive(Clone)]
10pub struct Client<T = reqwest::Client> {
11    pub(crate) base_url: String,
12    pub(crate) api_key: String,
13    pub(crate) http_client: T,
14}
15
16impl<T> std::fmt::Debug for Client<T>
17where
18    T: std::fmt::Debug,
19{
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("Client")
22            .field("base_url", &self.base_url)
23            .field("http_client", &self.http_client)
24            .field("api_key", &"<REDACTED>")
25            .finish()
26    }
27}
28
29/// Client builder: ClientBuilder<'a, T>
30#[derive(Clone)]
31pub struct ClientBuilder<'a, T = reqwest::Client> {
32    api_key: &'a str,
33    base_url: &'a str,
34    http_client: T,
35}
36
37impl<'a, T> ClientBuilder<'a, T>
38where
39    T: Default,
40{
41    pub fn new(api_key: &'a str) -> Self {
42        Self {
43            api_key,
44            base_url: VOLCENGINE_API_BASE_URL,
45            http_client: Default::default(),
46        }
47    }
48}
49
50impl<'a, T> ClientBuilder<'a, T> {
51    pub fn base_url(mut self, base_url: &'a str) -> Self {
52        self.base_url = base_url;
53        self
54    }
55
56    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
57        ClientBuilder {
58            api_key: self.api_key,
59            base_url: self.base_url,
60            http_client,
61        }
62    }
63
64    pub fn build(self) -> Client<T> {
65        Client {
66            base_url: self.base_url.to_string(),
67            api_key: self.api_key.to_string(),
68            http_client: self.http_client,
69        }
70    }
71}
72
73impl<T> Client<T>
74where
75    T: Default,
76{
77    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
78        ClientBuilder::new(api_key)
79    }
80
81    pub fn new(api_key: &str) -> Self {
82        Self::builder(api_key).build()
83    }
84}
85
86impl<T> Client<T>
87where
88    T: HttpClientExt,
89{
90    pub(crate) fn url(&self, path: &str) -> String {
91        format!("{}/{}", self.base_url, path.trim_start_matches('/'))
92    }
93
94    fn req(
95        &self,
96        method: http_client::Method,
97        path: &str,
98    ) -> http_client::Result<http_client::Builder> {
99        let url = self.url(path);
100        http_client::with_bearer_auth(
101            http_client::Builder::new().method(method).uri(url),
102            &self.api_key,
103        )
104    }
105
106    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
107        self.req(http_client::Method::GET, path)
108    }
109
110    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
111        self.req(http_client::Method::POST, path)
112    }
113}
114
115impl ProviderClient for Client<reqwest::Client> {
116    type Input = String;
117
118    fn from_env() -> Self {
119        let api_key = std::env::var("VOLCENGINE_API_KEY").expect("VOLCENGINE_API_KEY not set");
120        let base_url = std::env::var("VOLCENGINE_BASE_URL")
121            .ok()
122            .unwrap_or_else(|| VOLCENGINE_API_BASE_URL.to_string());
123        Self::builder(&api_key).base_url(&base_url).build()
124    }
125
126    fn from_val(input: String) -> Self {
127        Self::new(&input)
128    }
129}
130
131impl CompletionClient for Client<reqwest::Client> {
132    type CompletionModel = CompletionModel<reqwest::Client>;
133
134    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
135        CompletionModel::new(self.clone(), &model.into())
136    }
137}
138
139impl EmbeddingsClient for Client<reqwest::Client> {
140    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
141
142    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
143        EmbeddingModel::new(self.clone(), &model.into(), 0)
144    }
145
146    fn embedding_model_with_ndims(
147        &self,
148        model: impl Into<String>,
149        ndims: usize,
150    ) -> Self::EmbeddingModel {
151        EmbeddingModel::new(self.clone(), &model.into(), ndims)
152    }
153}
154
155impl VerifyClient for Client<reqwest::Client> {
156    async fn verify(&self) -> Result<(), VerifyError> {
157        let req = self
158            .get("/models")?
159            .body(rig::http_client::NoBody)
160            .map_err(rig::http_client::Error::from)?;
161
162        let response = HttpClientExt::send(&self.http_client, req).await?;
163
164        match response.status() {
165            reqwest::StatusCode::OK => Ok(()),
166            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
167            reqwest::StatusCode::INTERNAL_SERVER_ERROR
168            | reqwest::StatusCode::SERVICE_UNAVAILABLE
169            | reqwest::StatusCode::BAD_GATEWAY => {
170                let text = rig::http_client::text(response).await?;
171                Err(VerifyError::ProviderError(text))
172            }
173            _ => Ok(()),
174        }
175    }
176}