1pub mod error;
2pub mod types;
3
4use error::ClientError;
5use hessra_token_core::PublicKey;
6use tokio::sync::OnceCell;
7use types::*;
8
/// Splits a server address into its host and optional port.
///
/// Surrounding whitespace, a leading `https://` or `http://` scheme, and
/// anything from the first `/` onward are discarded before parsing.
///
/// * `[v6-addr]` / `[v6-addr]:port` — the brackets are removed from the host;
///   a missing or unparsable port yields `None`.
/// * `host:port` (exactly one colon) — split when the port is a valid `u16`.
/// * Anything else (no colon, several colons, unparsable port) is returned
///   unchanged as the host with no port.
fn parse_server_address(address: &str) -> (String, Option<u16>) {
    let trimmed = address.trim();

    // Drop the scheme, trying `https://` before `http://`.
    let schemeless = ["https://", "http://"]
        .iter()
        .find_map(|scheme| trimmed.strip_prefix(*scheme))
        .unwrap_or(trimmed);

    // Keep only the authority part, i.e. everything before the first slash.
    let authority = match schemeless.find('/') {
        Some(slash) => &schemeless[..slash],
        None => schemeless,
    };

    // Bracketed IPv6 literal, optionally followed by `:port`.
    if let Some(inner) = authority.strip_prefix('[') {
        return match inner.split_once(']') {
            Some((host, tail)) => {
                let port = tail
                    .strip_prefix(':')
                    .and_then(|digits| digits.parse::<u16>().ok());
                (host.to_string(), port)
            }
            // No closing bracket: just shed the leading bracket(s).
            None => (authority.trim_start_matches('[').to_string(), None),
        };
    }

    // Only a single colon can separate host and port; more colons mean an
    // unbracketed IPv6 address, which is passed through untouched.
    match authority.split_once(':') {
        Some((host, port_text)) if !port_text.contains(':') => match port_text.parse::<u16>() {
            Ok(port) => (host.to_string(), Some(port)),
            Err(_) => (authority.to_string(), None),
        },
        _ => (authority.to_string(), None),
    }
}
54
55fn format_base_url(base_url: &str, port: Option<u16>) -> String {
57 let (host, embedded_port) = parse_server_address(base_url);
58 let resolved_port = port.or(embedded_port);
59 match resolved_port {
60 Some(p) => format!("https://{host}:{p}"),
61 None => format!("https://{host}"),
62 }
63}
64
/// HTTP client for the service endpoints reachable under `base_url`.
///
/// Instances are created through [`HessraClient::builder`].
pub struct HessraClient {
    /// Underlying `reqwest` client used to send every request.
    client: reqwest::Client,
    /// Base URL that endpoint paths are appended to (e.g. `{base_url}/health`).
    base_url: String,
    /// Server public key, stored after the first successful
    /// `fetch_public_key` call and reused afterwards.
    public_key: OnceCell<PublicKey>,
}
71
72impl HessraClient {
73 pub fn builder() -> HessraClientBuilder {
75 HessraClientBuilder::default()
76 }
77
78 pub async fn fetch_public_key(&self) -> Result<PublicKey, ClientError> {
83 self.public_key
84 .get_or_try_init(|| async {
85 let url = format!("{}/public_key", self.base_url);
86 let response = self
87 .client
88 .get(&url)
89 .send()
90 .await
91 .map_err(ClientError::Http)?;
92
93 if !response.status().is_success() {
94 let status = response.status();
95 let text = response.text().await.unwrap_or_default();
96 return Err(ClientError::InvalidResponse(format!(
97 "HTTP {status}: {text}"
98 )));
99 }
100
101 let body: PublicKeyResponse = response.json().await.map_err(ClientError::Http)?;
102
103 PublicKey::from_pem(&body.public_key).map_err(|e| {
104 ClientError::InvalidResponse(format!("Failed to parse public key PEM: {e}"))
105 })
106 })
107 .await
108 .copied()
109 }
110
111 pub async fn request_token(
113 &self,
114 request: &TokenRequest,
115 ) -> Result<TokenResponse, ClientError> {
116 self.post("request_token", request).await
117 }
118
119 pub async fn request_token_with_identity(
121 &self,
122 request: &TokenRequest,
123 identity_token: &str,
124 ) -> Result<TokenResponse, ClientError> {
125 self.post_with_auth("request_token", request, identity_token)
126 .await
127 }
128
129 pub async fn verify_token(
131 &self,
132 request: &VerifyTokenRequest,
133 ) -> Result<VerifyTokenResponse, ClientError> {
134 self.post("verify_token", request).await
135 }
136
137 pub async fn request_identity_token(
139 &self,
140 request: &IdentityTokenRequest,
141 ) -> Result<IdentityTokenResponse, ClientError> {
142 self.post("request_identity_token", request).await
143 }
144
145 pub async fn refresh_identity_token(
147 &self,
148 request: &RefreshIdentityTokenRequest,
149 ) -> Result<IdentityTokenResponse, ClientError> {
150 self.post("refresh_identity_token", request).await
151 }
152
153 pub async fn health(&self) -> Result<HealthResponse, ClientError> {
155 let url = format!("{}/health", self.base_url);
156 let response = self
157 .client
158 .get(&url)
159 .send()
160 .await
161 .map_err(ClientError::Http)?;
162
163 if !response.status().is_success() {
164 let status = response.status();
165 let text = response.text().await.unwrap_or_default();
166 return Err(ClientError::InvalidResponse(format!(
167 "HTTP {status}: {text}"
168 )));
169 }
170
171 response.json().await.map_err(ClientError::Http)
172 }
173
174 async fn post<T: serde::Serialize, R: serde::de::DeserializeOwned>(
176 &self,
177 endpoint: &str,
178 body: &T,
179 ) -> Result<R, ClientError> {
180 let url = format!("{}/{endpoint}", self.base_url);
181 let response = self
182 .client
183 .post(&url)
184 .json(body)
185 .send()
186 .await
187 .map_err(ClientError::Http)?;
188
189 if !response.status().is_success() {
190 let status = response.status();
191 let text = response.text().await.unwrap_or_default();
192 return Err(ClientError::InvalidResponse(format!(
193 "HTTP {status}: {text}"
194 )));
195 }
196
197 response.json().await.map_err(ClientError::Http)
198 }
199
200 async fn post_with_auth<T: serde::Serialize, R: serde::de::DeserializeOwned>(
202 &self,
203 endpoint: &str,
204 body: &T,
205 bearer_token: &str,
206 ) -> Result<R, ClientError> {
207 let url = format!("{}/{endpoint}", self.base_url);
208 let response = self
209 .client
210 .post(&url)
211 .header("Authorization", format!("Bearer {bearer_token}"))
212 .json(body)
213 .send()
214 .await
215 .map_err(ClientError::Http)?;
216
217 if !response.status().is_success() {
218 let status = response.status();
219 let text = response.text().await.unwrap_or_default();
220 return Err(ClientError::InvalidResponse(format!(
221 "HTTP {status}: {text}"
222 )));
223 }
224
225 response.json().await.map_err(ClientError::Http)
226 }
227}
228
/// Builder for [`HessraClient`].
///
/// Every field starts unset; `server_ca` must be provided before `build`
/// is called.
#[derive(Default)]
pub struct HessraClientBuilder {
    /// Server address; a scheme, port, or path it contains is normalized
    /// away when the final base URL is formed.
    base_url: String,
    /// Explicit port; takes precedence over a port embedded in `base_url`.
    port: Option<u16>,
    /// Client certificate in PEM form for mutual TLS; paired with `mtls_key`.
    mtls_cert: Option<String>,
    /// Client private key in PEM form for mutual TLS; paired with `mtls_cert`.
    mtls_key: Option<String>,
    /// PEM bundle of CA certificates trusted for the server connection.
    server_ca: Option<String>,
}
238
239impl HessraClientBuilder {
240 pub fn base_url(mut self, url: impl Into<String>) -> Self {
242 self.base_url = url.into();
243 self
244 }
245
246 pub fn port(mut self, port: u16) -> Self {
248 self.port = Some(port);
249 self
250 }
251
252 pub fn mtls_cert(mut self, cert: impl Into<String>) -> Self {
254 self.mtls_cert = Some(cert.into());
255 self
256 }
257
258 pub fn mtls_key(mut self, key: impl Into<String>) -> Self {
260 self.mtls_key = Some(key.into());
261 self
262 }
263
264 pub fn server_ca(mut self, ca: impl Into<String>) -> Self {
266 self.server_ca = Some(ca.into());
267 self
268 }
269
270 pub fn build(self) -> Result<HessraClient, ClientError> {
272 let server_ca = self
273 .server_ca
274 .ok_or_else(|| ClientError::Config("server_ca is required".into()))?;
275
276 let certs = reqwest::Certificate::from_pem_bundle(server_ca.as_bytes()).map_err(|e| {
277 ClientError::TlsConfig(format!("Failed to parse CA certificate chain: {e}"))
278 })?;
279
280 let mut builder = reqwest::ClientBuilder::new();
281
282 for cert in certs {
283 builder = builder.add_root_certificate(cert);
284 }
285
286 if let (Some(cert), Some(key)) = (&self.mtls_cert, &self.mtls_key) {
287 let identity_pem = format!("{cert}{key}");
288 let identity = reqwest::Identity::from_pem(identity_pem.as_bytes()).map_err(|e| {
289 ClientError::TlsConfig(format!("Failed to create identity from cert and key: {e}"))
290 })?;
291 builder = builder.identity(identity);
292 }
293
294 let client = builder
295 .build()
296 .map_err(|e| ClientError::TlsConfig(e.to_string()))?;
297
298 let base_url = format_base_url(&self.base_url, self.port);
299
300 Ok(HessraClient {
301 client,
302 base_url,
303 public_key: OnceCell::new(),
304 })
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `(host, port)` tuple that `parse_server_address` returns.
    fn parsed(host: &str, port: Option<u16>) -> (String, Option<u16>) {
        (host.to_string(), port)
    }

    #[test]
    fn test_parse_server_address_ip_with_port() {
        assert_eq!(
            parse_server_address("127.0.0.1:4433"),
            parsed("127.0.0.1", Some(4433))
        );
    }

    #[test]
    fn test_parse_server_address_hostname_only() {
        assert_eq!(
            parse_server_address("test.hessra.net"),
            parsed("test.hessra.net", None)
        );
    }

    #[test]
    fn test_parse_server_address_with_protocol() {
        assert_eq!(
            parse_server_address("https://example.com:8443/path"),
            parsed("example.com", Some(8443))
        );
    }

    #[test]
    fn test_parse_server_address_ipv6() {
        assert_eq!(parse_server_address("[::1]:8443"), parsed("::1", Some(8443)));
    }

    #[test]
    fn test_format_base_url() {
        // (address, explicit port, expected base URL)
        let cases = [
            ("infra.hessra.net", None, "https://infra.hessra.net"),
            ("infra.hessra.net", Some(443), "https://infra.hessra.net:443"),
            ("127.0.0.1:4433", Some(8080), "https://127.0.0.1:8080"),
            ("127.0.0.1:4433", None, "https://127.0.0.1:4433"),
        ];

        for (address, port, expected) in cases {
            assert_eq!(format_base_url(address, port), expected);
        }
    }
}