1use rig::client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use rig::http_client::{self};
3
4use super::TEI_DEFAULT_BASE_URL;
5use super::embedding::EmbeddingModel;
6
7#[derive(Clone, Debug)]
10pub struct Client<T = reqwest::Client> {
11 pub(crate) http_client: T,
12 pub(crate) endpoints: Endpoints,
13}
14
/// Fully-qualified URLs for the server routes used by the client.
#[derive(Clone, Debug)]
pub struct Endpoints {
    /// URL of the `embed` route.
    pub embed: String,
    /// URL of the `rerank` route.
    pub rerank: String,
    /// URL of the `predict` route.
    pub predict: String,
}

impl Endpoints {
    /// Derives every route URL from a single base URL.
    ///
    /// All trailing `/` characters are stripped from `base_url` before the
    /// route name is appended, so `"http://host/"` and `"http://host"` both
    /// produce `"http://host/embed"`. An empty base yields `"/embed"` etc.
    pub fn with_base(base_url: &str) -> Self {
        let root = base_url.trim_end_matches('/');
        let route = |name: &str| format!("{root}/{name}");
        Self {
            embed: route("embed"),
            rerank: route("rerank"),
            predict: route("predict"),
        }
    }
}
33
34pub struct ClientBuilder<'a, T = reqwest::Client> {
36 base_url: &'a str,
37 http_client: T,
38 embed_endpoint: Option<&'a str>,
40 rerank_endpoint: Option<&'a str>,
41 predict_endpoint: Option<&'a str>,
42}
43
44impl<'a, T> ClientBuilder<'a, T>
45where
46 T: Default,
47{
48 pub fn new() -> Self {
49 Self {
50 base_url: TEI_DEFAULT_BASE_URL,
51 http_client: Default::default(),
52 embed_endpoint: None,
53 rerank_endpoint: None,
54 predict_endpoint: None,
55 }
56 }
57}
58
59impl<'a, T> ClientBuilder<'a, T> {
60 pub fn base_url(mut self, base_url: &'a str) -> Self {
61 self.base_url = base_url;
62 self
63 }
64
65 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
66 ClientBuilder {
67 base_url: self.base_url,
68 http_client,
69 embed_endpoint: self.embed_endpoint,
70 rerank_endpoint: self.rerank_endpoint,
71 predict_endpoint: self.predict_endpoint,
72 }
73 }
74
75 pub fn embed_endpoint(mut self, url: &'a str) -> Self {
77 self.embed_endpoint = Some(url);
78 self
79 }
80
81 pub fn rerank_endpoint(mut self, url: &'a str) -> Self {
82 self.rerank_endpoint = Some(url);
83 self
84 }
85
86 pub fn predict_endpoint(mut self, url: &'a str) -> Self {
87 self.predict_endpoint = Some(url);
88 self
89 }
90
91 pub fn build(self) -> Client<T> {
92 let mut endpoints = Endpoints::with_base(self.base_url);
93 if let Some(url) = self.embed_endpoint {
94 endpoints.embed = url.to_string();
95 }
96 if let Some(url) = self.rerank_endpoint {
97 endpoints.rerank = url.to_string();
98 }
99 if let Some(url) = self.predict_endpoint {
100 endpoints.predict = url.to_string();
101 }
102
103 Client {
104 http_client: self.http_client,
105 endpoints,
106 }
107 }
108}
109
110impl<T> Default for Client<T>
111where
112 T: Default,
113{
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl<T> Client<T>
120where
121 T: Default,
122{
123 pub fn builder<'a>() -> ClientBuilder<'a, T> {
124 ClientBuilder::new()
125 }
126
127 pub fn new() -> Self {
128 Self::builder().build()
129 }
130}
131
132impl<T> Client<T> {
134 pub(crate) fn post_full(&self, url: &str) -> http_client::Builder {
135 http_client::Builder::new()
136 .method(http_client::Method::POST)
137 .uri(url.to_string())
138 }
139}
140
141impl ProviderClient for Client<reqwest::Client> {
142 type Input = String;
143
144 fn from_env() -> Self {
145 let base_url =
146 std::env::var("TEI_BASE_URL").unwrap_or_else(|_| TEI_DEFAULT_BASE_URL.to_string());
147 Self::builder().base_url(&base_url).build()
148 }
149
150 fn from_val(input: String) -> Self {
151 ClientBuilder::new().base_url(&input).build()
152 }
153}
154
155impl VerifyClient for Client<reqwest::Client> {
156 async fn verify(&self) -> Result<(), VerifyError> {
157 Ok(())
159 }
160}
161
162impl EmbeddingsClient for Client<reqwest::Client> {
163 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
164
165 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
166 EmbeddingModel::new(self.clone(), &model.into(), 0)
167 }
168
169 fn embedding_model_with_ndims(
170 &self,
171 model: impl Into<String>,
172 ndims: usize,
173 ) -> Self::EmbeddingModel {
174 EmbeddingModel::new(self.clone(), &model.into(), ndims)
175 }
176}