Skip to main content

hessra_client/
lib.rs

1pub mod error;
2pub mod types;
3
4use error::ClientError;
5use hessra_token_core::PublicKey;
6use tokio::sync::OnceCell;
7use types::*;
8
/// Parse a server address string into `(host, port)` components.
///
/// Accepted forms include `IP:port`, `hostname:port`, bracketed IPv6
/// (`[::1]:8443`), bare IPv6 without a port (`fe80::1`), and URLs such as
/// `https://example.com:8443/path?query#frag`. An `http://` or `https://`
/// prefix (in any letter case) is removed, and everything from the first
/// `/`, `?` or `#` onward is discarded.
///
/// Brackets around an IPv6 literal are removed from the returned host. The
/// port is `None` when no port is present or when it does not parse as a
/// `u16`; in the latter case the unparsed text stays in the host string.
fn parse_server_address(address: &str) -> (String, Option<u16>) {
    let address = address.trim();

    // URL schemes are case-insensitive, so "HTTPS://" must be stripped too.
    let without_scheme = ["https://", "http://"]
        .iter()
        .find_map(|prefix| {
            let head = address.get(..prefix.len())?;
            if head.eq_ignore_ascii_case(prefix) {
                address.get(prefix.len()..)
            } else {
                None
            }
        })
        .unwrap_or(address);

    // The authority ends at the first path, query, or fragment delimiter.
    let host_port = without_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or(without_scheme);

    // Bracketed IPv6 literal: "[addr]" or "[addr]:port".
    if let Some(bracketed) = host_port.strip_prefix('[') {
        return match bracketed.split_once(']') {
            Some((host, after_bracket)) => {
                let port = after_bracket
                    .strip_prefix(':')
                    .and_then(|p| p.parse::<u16>().ok());
                (host.to_string(), port)
            }
            // Unterminated bracket: best effort, drop the leading '['s.
            None => (host_port.trim_start_matches('[').to_string(), None),
        };
    }

    // Exactly one colon means "host:port"; more than one is an unbracketed
    // IPv6 literal, which cannot carry a port unambiguously.
    let host_and_port = host_port
        .split_once(':')
        .filter(|(_, port)| !port.contains(':'))
        .and_then(|(host, port)| port.parse::<u16>().ok().map(|port| (host, port)));

    match host_and_port {
        Some((host, port)) => (host.to_string(), Some(port)),
        None => (host_port.to_string(), None),
    }
}

/// Build the HTTPS base URL (`https://host[:port]`, no trailing slash) for a
/// server address.
///
/// `port`, when given, overrides any port embedded in `base_url`. Any path in
/// `base_url` is dropped. IPv6 literals are wrapped in brackets so the result
/// is a valid URL.
fn format_base_url(base_url: &str, port: Option<u16>) -> String {
    let (host, embedded_port) = parse_server_address(base_url);

    // RFC 3986 §3.2.2: an IPv6 literal must be bracketed inside a URL;
    // "https://::1:8443" is not parseable, "https://[::1]:8443" is.
    let host = if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("[{host}]")
    } else {
        host
    };

    match port.or(embedded_port) {
        Some(p) => format!("https://{host}:{p}"),
        None => format!("https://{host}"),
    }
}
64
65/// HTTP client for communicating with a Hessra authorization node.
66pub struct HessraClient {
67    client: reqwest::Client,
68    base_url: String,
69    public_key: OnceCell<PublicKey>,
70}
71
72impl HessraClient {
73    /// Create a new builder for constructing a client.
74    pub fn builder() -> HessraClientBuilder {
75        HessraClientBuilder::default()
76    }
77
78    /// Fetch and cache the server's public key (PEM format).
79    ///
80    /// The key is fetched once and cached for the lifetime of the client.
81    /// Returns the parsed `PublicKey` suitable for local token verification.
82    pub async fn fetch_public_key(&self) -> Result<PublicKey, ClientError> {
83        self.public_key
84            .get_or_try_init(|| async {
85                let url = format!("{}/public_key", self.base_url);
86                let response = self
87                    .client
88                    .get(&url)
89                    .send()
90                    .await
91                    .map_err(ClientError::Http)?;
92
93                if !response.status().is_success() {
94                    let status = response.status();
95                    let text = response.text().await.unwrap_or_default();
96                    return Err(ClientError::InvalidResponse(format!(
97                        "HTTP {status}: {text}"
98                    )));
99                }
100
101                let body: PublicKeyResponse = response.json().await.map_err(ClientError::Http)?;
102
103                PublicKey::from_pem(&body.public_key).map_err(|e| {
104                    ClientError::InvalidResponse(format!("Failed to parse public key PEM: {e}"))
105                })
106            })
107            .await
108            .copied()
109    }
110
111    /// Request a capability token (mTLS-authenticated).
112    pub async fn request_token(
113        &self,
114        request: &TokenRequest,
115    ) -> Result<TokenResponse, ClientError> {
116        self.post("request_token", request).await
117    }
118
119    /// Request a capability token using an identity token for authentication.
120    pub async fn request_token_with_identity(
121        &self,
122        request: &TokenRequest,
123        identity_token: &str,
124    ) -> Result<TokenResponse, ClientError> {
125        self.post_with_auth("request_token", request, identity_token)
126            .await
127    }
128
129    /// Verify a token remotely via the authorization service.
130    pub async fn verify_token(
131        &self,
132        request: &VerifyTokenRequest,
133    ) -> Result<VerifyTokenResponse, ClientError> {
134        self.post("verify_token", request).await
135    }
136
137    /// Request an identity token (mTLS-authenticated).
138    pub async fn request_identity_token(
139        &self,
140        request: &IdentityTokenRequest,
141    ) -> Result<IdentityTokenResponse, ClientError> {
142        self.post("request_identity_token", request).await
143    }
144
145    /// Refresh an existing identity token.
146    pub async fn refresh_identity_token(
147        &self,
148        request: &RefreshIdentityTokenRequest,
149    ) -> Result<IdentityTokenResponse, ClientError> {
150        self.post("refresh_identity_token", request).await
151    }
152
153    /// Health check.
154    pub async fn health(&self) -> Result<HealthResponse, ClientError> {
155        let url = format!("{}/health", self.base_url);
156        let response = self
157            .client
158            .get(&url)
159            .send()
160            .await
161            .map_err(ClientError::Http)?;
162
163        if !response.status().is_success() {
164            let status = response.status();
165            let text = response.text().await.unwrap_or_default();
166            return Err(ClientError::InvalidResponse(format!(
167                "HTTP {status}: {text}"
168            )));
169        }
170
171        response.json().await.map_err(ClientError::Http)
172    }
173
174    /// POST a JSON request body to an endpoint and deserialize the response.
175    async fn post<T: serde::Serialize, R: serde::de::DeserializeOwned>(
176        &self,
177        endpoint: &str,
178        body: &T,
179    ) -> Result<R, ClientError> {
180        let url = format!("{}/{endpoint}", self.base_url);
181        let response = self
182            .client
183            .post(&url)
184            .json(body)
185            .send()
186            .await
187            .map_err(ClientError::Http)?;
188
189        if !response.status().is_success() {
190            let status = response.status();
191            let text = response.text().await.unwrap_or_default();
192            return Err(ClientError::InvalidResponse(format!(
193                "HTTP {status}: {text}"
194            )));
195        }
196
197        response.json().await.map_err(ClientError::Http)
198    }
199
200    /// POST with a Bearer token in the Authorization header.
201    async fn post_with_auth<T: serde::Serialize, R: serde::de::DeserializeOwned>(
202        &self,
203        endpoint: &str,
204        body: &T,
205        bearer_token: &str,
206    ) -> Result<R, ClientError> {
207        let url = format!("{}/{endpoint}", self.base_url);
208        let response = self
209            .client
210            .post(&url)
211            .header("Authorization", format!("Bearer {bearer_token}"))
212            .json(body)
213            .send()
214            .await
215            .map_err(ClientError::Http)?;
216
217        if !response.status().is_success() {
218            let status = response.status();
219            let text = response.text().await.unwrap_or_default();
220            return Err(ClientError::InvalidResponse(format!(
221                "HTTP {status}: {text}"
222            )));
223        }
224
225        response.json().await.map_err(ClientError::Http)
226    }
227}
228
229/// Builder for constructing an `HessraClient`.
230#[derive(Default)]
231pub struct HessraClientBuilder {
232    base_url: String,
233    port: Option<u16>,
234    mtls_cert: Option<String>,
235    mtls_key: Option<String>,
236    server_ca: Option<String>,
237}
238
239impl HessraClientBuilder {
240    /// Set the base URL (e.g., "infra.hessra.net").
241    pub fn base_url(mut self, url: impl Into<String>) -> Self {
242        self.base_url = url.into();
243        self
244    }
245
246    /// Set the port (overrides any port embedded in the URL).
247    pub fn port(mut self, port: u16) -> Self {
248        self.port = Some(port);
249        self
250    }
251
252    /// Set the mTLS client certificate (PEM).
253    pub fn mtls_cert(mut self, cert: impl Into<String>) -> Self {
254        self.mtls_cert = Some(cert.into());
255        self
256    }
257
258    /// Set the mTLS client private key (PEM).
259    pub fn mtls_key(mut self, key: impl Into<String>) -> Self {
260        self.mtls_key = Some(key.into());
261        self
262    }
263
264    /// Set the server CA certificate (PEM).
265    pub fn server_ca(mut self, ca: impl Into<String>) -> Self {
266        self.server_ca = Some(ca.into());
267        self
268    }
269
270    /// Build the client.
271    pub fn build(self) -> Result<HessraClient, ClientError> {
272        let server_ca = self
273            .server_ca
274            .ok_or_else(|| ClientError::Config("server_ca is required".into()))?;
275
276        let certs = reqwest::Certificate::from_pem_bundle(server_ca.as_bytes()).map_err(|e| {
277            ClientError::TlsConfig(format!("Failed to parse CA certificate chain: {e}"))
278        })?;
279
280        let mut builder = reqwest::ClientBuilder::new();
281
282        for cert in certs {
283            builder = builder.add_root_certificate(cert);
284        }
285
286        if let (Some(cert), Some(key)) = (&self.mtls_cert, &self.mtls_key) {
287            let identity_pem = format!("{cert}{key}");
288            let identity = reqwest::Identity::from_pem(identity_pem.as_bytes()).map_err(|e| {
289                ClientError::TlsConfig(format!("Failed to create identity from cert and key: {e}"))
290            })?;
291            builder = builder.identity(identity);
292        }
293
294        let client = builder
295            .build()
296            .map_err(|e| ClientError::TlsConfig(e.to_string()))?;
297
298        let base_url = format_base_url(&self.base_url, self.port);
299
300        Ok(HessraClient {
301            client,
302            base_url,
303            public_key: OnceCell::new(),
304        })
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_server_address_ip_with_port() {
        let (host, port) = parse_server_address("127.0.0.1:4433");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, Some(4433));
    }

    #[test]
    fn test_parse_server_address_hostname_only() {
        let (host, port) = parse_server_address("test.hessra.net");
        assert_eq!(host, "test.hessra.net");
        assert_eq!(port, None);
    }

    #[test]
    fn test_parse_server_address_with_protocol() {
        let (host, port) = parse_server_address("https://example.com:8443/path");
        assert_eq!(host, "example.com");
        assert_eq!(port, Some(8443));
    }

    #[test]
    fn test_parse_server_address_http_with_trailing_slash() {
        let (host, port) = parse_server_address("http://localhost:8080/");
        assert_eq!(host, "localhost");
        assert_eq!(port, Some(8080));
    }

    #[test]
    fn test_parse_server_address_ipv6() {
        let (host, port) = parse_server_address("[::1]:8443");
        assert_eq!(host, "::1");
        assert_eq!(port, Some(8443));
    }

    #[test]
    fn test_parse_server_address_ipv6_without_port() {
        // Bracketed without a port, and bare (multiple colons => no port split).
        assert_eq!(parse_server_address("[::1]"), ("::1".to_string(), None));
        assert_eq!(
            parse_server_address("fe80::1"),
            ("fe80::1".to_string(), None)
        );
    }

    #[test]
    fn test_parse_server_address_invalid_port_stays_in_host() {
        assert_eq!(
            parse_server_address("example.com:notaport"),
            ("example.com:notaport".to_string(), None)
        );
        // 70000 overflows u16, so it is not treated as a port.
        assert_eq!(
            parse_server_address("example.com:70000"),
            ("example.com:70000".to_string(), None)
        );
    }

    #[test]
    fn test_parse_server_address_trims_whitespace() {
        assert_eq!(
            parse_server_address("  example.com:443  "),
            ("example.com".to_string(), Some(443))
        );
    }

    #[test]
    fn test_format_base_url() {
        assert_eq!(
            format_base_url("infra.hessra.net", None),
            "https://infra.hessra.net"
        );
        assert_eq!(
            format_base_url("infra.hessra.net", Some(443)),
            "https://infra.hessra.net:443"
        );
        assert_eq!(
            format_base_url("127.0.0.1:4433", Some(8080)),
            "https://127.0.0.1:8080"
        );
        assert_eq!(
            format_base_url("127.0.0.1:4433", None),
            "https://127.0.0.1:4433"
        );
    }

    #[test]
    fn test_format_base_url_drops_path() {
        assert_eq!(
            format_base_url("https://example.com:8443/api", None),
            "https://example.com:8443"
        );
    }
}