1pub mod error;
2pub mod types;
3
4use error::ClientError;
5use hessra_token_core::PublicKey;
6use tokio::sync::OnceCell;
7use types::*;
8
/// Splits a server address into its host and optional port.
///
/// Surrounding whitespace is ignored. Accepted shapes:
/// - `host` and `host:port`
/// - `[ipv6]` and `[ipv6]:port` (the brackets are removed from the returned host)
/// - any of the above prefixed with `https://` or `http://` and/or followed by a
///   path, query or fragment, all of which are discarded
///
/// Text after the colon that does not parse as a `u16` is not treated as a port:
/// for bracketed input it is dropped, for unbracketed input the address is
/// returned whole as the host. Unbracketed input with several colons (for
/// example a bare IPv6 address) is likewise returned whole with no port.
fn parse_server_address(address: &str) -> (String, Option<u16>) {
    let address = address.trim();

    let without_scheme = address
        .strip_prefix("https://")
        .or_else(|| address.strip_prefix("http://"))
        .unwrap_or(address);

    // The authority component ends at the first '/', '?' or '#'.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(without_scheme);

    if let Some(rest) = authority.strip_prefix('[') {
        return match rest.split_once(']') {
            Some((host, after_bracket)) => {
                let port = after_bracket
                    .strip_prefix(':')
                    .and_then(|port_str| port_str.parse::<u16>().ok());
                (host.to_string(), port)
            }
            // Unterminated bracket: keep whatever follows the opening bracket(s).
            None => (rest.trim_start_matches('[').to_string(), None),
        };
    }

    // Splitting at the first colon is enough: when more colons follow, the
    // remainder still contains ':' and therefore cannot parse as a port.
    if let Some((host, port_str)) = authority.split_once(':') {
        if let Ok(port) = port_str.parse::<u16>() {
            return (host.to_string(), Some(port));
        }
    }

    (authority.to_string(), None)
}

/// Builds the `https://host[:port]` base URL for a server address.
///
/// `base_url` is parsed with [`parse_server_address`]; any scheme or path it
/// carries is discarded and the result always uses `https`. An explicit `port`
/// takes precedence over a port embedded in `base_url`. A host that is an IPv6
/// address is wrapped in brackets so that its colons cannot be confused with the
/// port separator.
fn format_base_url(base_url: &str, port: Option<u16>) -> String {
    let (host, embedded_port) = parse_server_address(base_url);

    let host = if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("[{host}]")
    } else {
        host
    };

    match port.or(embedded_port) {
        Some(p) => format!("https://{host}:{p}"),
        None => format!("https://{host}"),
    }
}
64
65pub struct HessraClient {
67 client: reqwest::Client,
68 base_url: String,
69 public_key: OnceCell<PublicKey>,
70}
71
72impl HessraClient {
73 pub fn builder() -> HessraClientBuilder {
75 HessraClientBuilder::default()
76 }
77
78 pub async fn fetch_public_key(&self) -> Result<PublicKey, ClientError> {
83 self.public_key
84 .get_or_try_init(|| async {
85 let url = format!("{}/public_key", self.base_url);
86 let response = self
87 .client
88 .get(&url)
89 .send()
90 .await
91 .map_err(ClientError::Http)?;
92
93 if !response.status().is_success() {
94 let status = response.status();
95 let text = response.text().await.unwrap_or_default();
96 return Err(ClientError::InvalidResponse(format!(
97 "HTTP {status}: {text}"
98 )));
99 }
100
101 let body: PublicKeyResponse = response.json().await.map_err(ClientError::Http)?;
102
103 PublicKey::from_pem(&body.public_key).map_err(|e| {
104 ClientError::InvalidResponse(format!("Failed to parse public key PEM: {e}"))
105 })
106 })
107 .await
108 .copied()
109 }
110
111 pub async fn request_token(&self, request: &TokenRequest) -> Result<TokenResponse, ClientError> {
113 self.post("request_token", request).await
114 }
115
116 pub async fn request_token_with_identity(
118 &self,
119 request: &TokenRequest,
120 identity_token: &str,
121 ) -> Result<TokenResponse, ClientError> {
122 self.post_with_auth("request_token", request, identity_token)
123 .await
124 }
125
126 pub async fn verify_token(
128 &self,
129 request: &VerifyTokenRequest,
130 ) -> Result<VerifyTokenResponse, ClientError> {
131 self.post("verify_token", request).await
132 }
133
134 pub async fn mint_identity_token(
136 &self,
137 request: &MintIdentityTokenRequest,
138 ) -> Result<MintIdentityTokenResponse, ClientError> {
139 self.post("mint_identity_token", request).await
140 }
141
142 pub async fn request_identity_token(
144 &self,
145 request: &IdentityTokenRequest,
146 ) -> Result<IdentityTokenResponse, ClientError> {
147 self.post("request_identity_token", request).await
148 }
149
150 pub async fn refresh_identity_token(
152 &self,
153 request: &RefreshIdentityTokenRequest,
154 ) -> Result<IdentityTokenResponse, ClientError> {
155 self.post("refresh_identity_token", request).await
156 }
157
158 pub async fn health(&self) -> Result<HealthResponse, ClientError> {
160 let url = format!("{}/health", self.base_url);
161 let response = self
162 .client
163 .get(&url)
164 .send()
165 .await
166 .map_err(ClientError::Http)?;
167
168 if !response.status().is_success() {
169 let status = response.status();
170 let text = response.text().await.unwrap_or_default();
171 return Err(ClientError::InvalidResponse(format!(
172 "HTTP {status}: {text}"
173 )));
174 }
175
176 response.json().await.map_err(ClientError::Http)
177 }
178
179 async fn post<T: serde::Serialize, R: serde::de::DeserializeOwned>(
181 &self,
182 endpoint: &str,
183 body: &T,
184 ) -> Result<R, ClientError> {
185 let url = format!("{}/{endpoint}", self.base_url);
186 let response = self
187 .client
188 .post(&url)
189 .json(body)
190 .send()
191 .await
192 .map_err(ClientError::Http)?;
193
194 if !response.status().is_success() {
195 let status = response.status();
196 let text = response.text().await.unwrap_or_default();
197 return Err(ClientError::InvalidResponse(format!(
198 "HTTP {status}: {text}"
199 )));
200 }
201
202 response.json().await.map_err(ClientError::Http)
203 }
204
205 async fn post_with_auth<T: serde::Serialize, R: serde::de::DeserializeOwned>(
207 &self,
208 endpoint: &str,
209 body: &T,
210 bearer_token: &str,
211 ) -> Result<R, ClientError> {
212 let url = format!("{}/{endpoint}", self.base_url);
213 let response = self
214 .client
215 .post(&url)
216 .header("Authorization", format!("Bearer {bearer_token}"))
217 .json(body)
218 .send()
219 .await
220 .map_err(ClientError::Http)?;
221
222 if !response.status().is_success() {
223 let status = response.status();
224 let text = response.text().await.unwrap_or_default();
225 return Err(ClientError::InvalidResponse(format!(
226 "HTTP {status}: {text}"
227 )));
228 }
229
230 response.json().await.map_err(ClientError::Http)
231 }
232}
233
234#[derive(Default)]
236pub struct HessraClientBuilder {
237 base_url: String,
238 port: Option<u16>,
239 mtls_cert: Option<String>,
240 mtls_key: Option<String>,
241 server_ca: Option<String>,
242}
243
244impl HessraClientBuilder {
245 pub fn base_url(mut self, url: impl Into<String>) -> Self {
247 self.base_url = url.into();
248 self
249 }
250
251 pub fn port(mut self, port: u16) -> Self {
253 self.port = Some(port);
254 self
255 }
256
257 pub fn mtls_cert(mut self, cert: impl Into<String>) -> Self {
259 self.mtls_cert = Some(cert.into());
260 self
261 }
262
263 pub fn mtls_key(mut self, key: impl Into<String>) -> Self {
265 self.mtls_key = Some(key.into());
266 self
267 }
268
269 pub fn server_ca(mut self, ca: impl Into<String>) -> Self {
271 self.server_ca = Some(ca.into());
272 self
273 }
274
275 pub fn build(self) -> Result<HessraClient, ClientError> {
277 let server_ca = self
278 .server_ca
279 .ok_or_else(|| ClientError::Config("server_ca is required".into()))?;
280
281 let certs = reqwest::Certificate::from_pem_bundle(server_ca.as_bytes()).map_err(|e| {
282 ClientError::TlsConfig(format!("Failed to parse CA certificate chain: {e}"))
283 })?;
284
285 let mut builder = reqwest::ClientBuilder::new();
286
287 for cert in certs {
288 builder = builder.add_root_certificate(cert);
289 }
290
291 if let (Some(cert), Some(key)) = (&self.mtls_cert, &self.mtls_key) {
292 let identity_pem = format!("{cert}{key}");
293 let identity =
294 reqwest::Identity::from_pem(identity_pem.as_bytes()).map_err(|e| {
295 ClientError::TlsConfig(format!(
296 "Failed to create identity from cert and key: {e}"
297 ))
298 })?;
299 builder = builder.identity(identity);
300 }
301
302 let client = builder
303 .build()
304 .map_err(|e| ClientError::TlsConfig(e.to_string()))?;
305
306 let base_url = format_base_url(&self.base_url, self.port);
307
308 Ok(HessraClient {
309 client,
310 base_url,
311 public_key: OnceCell::new(),
312 })
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected `(host, port)` pair from a borrowed host.
    fn expected(host: &str, port: Option<u16>) -> (String, Option<u16>) {
        (host.to_string(), port)
    }

    #[test]
    fn test_parse_server_address_ip_with_port() {
        assert_eq!(
            parse_server_address("127.0.0.1:4433"),
            expected("127.0.0.1", Some(4433))
        );
    }

    #[test]
    fn test_parse_server_address_hostname_only() {
        assert_eq!(
            parse_server_address("test.hessra.net"),
            expected("test.hessra.net", None)
        );
    }

    #[test]
    fn test_parse_server_address_with_protocol() {
        assert_eq!(
            parse_server_address("https://example.com:8443/path"),
            expected("example.com", Some(8443))
        );
    }

    #[test]
    fn test_parse_server_address_ipv6() {
        assert_eq!(
            parse_server_address("[::1]:8443"),
            expected("::1", Some(8443))
        );
    }

    #[test]
    fn test_format_base_url() {
        // (address, explicit port, expected base URL)
        let cases: [(&str, Option<u16>, &str); 4] = [
            ("infra.hessra.net", None, "https://infra.hessra.net"),
            ("infra.hessra.net", Some(443), "https://infra.hessra.net:443"),
            // An explicit port wins over the embedded one.
            ("127.0.0.1:4433", Some(8080), "https://127.0.0.1:8080"),
            ("127.0.0.1:4433", None, "https://127.0.0.1:4433"),
        ];

        for (address, port, want) in cases {
            assert_eq!(
                format_base_url(address, port),
                want,
                "address={address:?}, port={port:?}"
            );
        }
    }
}