rig_tei/
client.rs

1use rig::client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use rig::http_client::{self};
3
4use super::TEI_DEFAULT_BASE_URL;
5use super::embedding::EmbeddingModel;
6
7/// Provider client: Client<T>
8/// Note: base_url is resolved into concrete endpoints during build, so we don't store base_url.
9#[derive(Clone, Debug)]
10pub struct Client<T = reqwest::Client> {
11    pub(crate) http_client: T,
12    pub(crate) endpoints: Endpoints,
13}
14
/// Absolute URLs for each TEI feature the client can target.
#[derive(Clone, Debug)]
pub struct Endpoints {
    /// Endpoint URL for embeddings (defaults to `<base>/embed`).
    pub embed: String,
    /// Endpoint URL for reranking (defaults to `<base>/rerank`).
    pub rerank: String,
    /// Endpoint URL for prediction (defaults to `<base>/predict`).
    pub predict: String,
}

impl Endpoints {
    /// Derives every endpoint by appending its route name to `base_url`.
    ///
    /// All trailing `/` characters are stripped from `base_url` first, so
    /// `"http://host/"` and `"http://host"` produce the same URLs.
    pub fn with_base(base_url: &str) -> Self {
        let root = base_url.trim_end_matches('/');
        let route = |suffix: &str| format!("{root}/{suffix}");
        Endpoints {
            embed: route("embed"),
            rerank: route("rerank"),
            predict: route("predict"),
        }
    }
}
33
34/// Client builder: ClientBuilder<'a, T>
35pub struct ClientBuilder<'a, T = reqwest::Client> {
36    base_url: &'a str,
37    http_client: T,
38    // Optional endpoint overrides
39    embed_endpoint: Option<&'a str>,
40    rerank_endpoint: Option<&'a str>,
41    predict_endpoint: Option<&'a str>,
42}
43
44impl<'a, T> ClientBuilder<'a, T>
45where
46    T: Default,
47{
48    pub fn new() -> Self {
49        Self {
50            base_url: TEI_DEFAULT_BASE_URL,
51            http_client: Default::default(),
52            embed_endpoint: None,
53            rerank_endpoint: None,
54            predict_endpoint: None,
55        }
56    }
57}
58
59impl<'a, T> ClientBuilder<'a, T> {
60    pub fn base_url(mut self, base_url: &'a str) -> Self {
61        self.base_url = base_url;
62        self
63    }
64
65    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
66        ClientBuilder {
67            base_url: self.base_url,
68            http_client,
69            embed_endpoint: self.embed_endpoint,
70            rerank_endpoint: self.rerank_endpoint,
71            predict_endpoint: self.predict_endpoint,
72        }
73    }
74
75    // Custom endpoint overrides
76    pub fn embed_endpoint(mut self, url: &'a str) -> Self {
77        self.embed_endpoint = Some(url);
78        self
79    }
80
81    pub fn rerank_endpoint(mut self, url: &'a str) -> Self {
82        self.rerank_endpoint = Some(url);
83        self
84    }
85
86    pub fn predict_endpoint(mut self, url: &'a str) -> Self {
87        self.predict_endpoint = Some(url);
88        self
89    }
90
91    pub fn build(self) -> Client<T> {
92        let mut endpoints = Endpoints::with_base(self.base_url);
93        if let Some(url) = self.embed_endpoint {
94            endpoints.embed = url.to_string();
95        }
96        if let Some(url) = self.rerank_endpoint {
97            endpoints.rerank = url.to_string();
98        }
99        if let Some(url) = self.predict_endpoint {
100            endpoints.predict = url.to_string();
101        }
102
103        Client {
104            http_client: self.http_client,
105            endpoints,
106        }
107    }
108}
109
110impl<T> Default for Client<T>
111where
112    T: Default,
113{
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl<T> Client<T>
120where
121    T: Default,
122{
123    pub fn builder<'a>() -> ClientBuilder<'a, T> {
124        ClientBuilder::new()
125    }
126
127    pub fn new() -> Self {
128        Self::builder().build()
129    }
130}
131
132// Build a POST request using a full URL (used when endpoints are overridden).
133impl<T> Client<T> {
134    pub(crate) fn post_full(&self, url: &str) -> http_client::Builder {
135        http_client::Builder::new()
136            .method(http_client::Method::POST)
137            .uri(url.to_string())
138    }
139}
140
141impl ProviderClient for Client<reqwest::Client> {
142    type Input = String;
143
144    fn from_env() -> Self {
145        let base_url =
146            std::env::var("TEI_BASE_URL").unwrap_or_else(|_| TEI_DEFAULT_BASE_URL.to_string());
147        Self::builder().base_url(&base_url).build()
148    }
149
150    fn from_val(input: String) -> Self {
151        ClientBuilder::new().base_url(&input).build()
152    }
153}
154
155impl VerifyClient for Client<reqwest::Client> {
156    async fn verify(&self) -> Result<(), VerifyError> {
157        // TEI local router often has no auth and no health endpoint needed.
158        Ok(())
159    }
160}
161
162impl EmbeddingsClient for Client<reqwest::Client> {
163    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
164
165    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
166        EmbeddingModel::new(self.clone(), &model.into(), 0)
167    }
168
169    fn embedding_model_with_ndims(
170        &self,
171        model: impl Into<String>,
172        ndims: usize,
173    ) -> Self::EmbeddingModel {
174        EmbeddingModel::new(self.clone(), &model.into(), ndims)
175    }
176}