stakpak_shared/
cert_utils.rs

use anyhow::Result;
use rcgen::{
    BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, IsCa,
    KeyUsagePurpose, SanType,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::{ClientConfig, RootCertStore, ServerConfig};
use std::sync::Arc;
use time::OffsetDateTime;
9
/// An in-memory mTLS certificate set: one CA plus a server and a client leaf certificate.
///
/// Created by [`CertificateChain::generate`]. The leaf certificates are signed by the CA
/// only when they are serialized (`create_server_config`, `create_client_config` and the
/// `get_*_cert_pem` helpers), not when the struct is built.
pub struct CertificateChain {
    /// Certificate authority; serialized self-signed and used to sign both leaf certificates.
    pub ca_cert: rcgen::Certificate,
    /// Server leaf certificate and its key pair; signed by `ca_cert` when serialized.
    pub server_cert: rcgen::Certificate,
    /// Client leaf certificate and its key pair; signed by `ca_cert` when serialized.
    pub client_cert: rcgen::Certificate,
}
15
16impl CertificateChain {
17    pub fn generate() -> Result<Self> {
18        // Generate CA certificate
19        let mut ca_params = CertificateParams::default();
20        ca_params.distinguished_name = DistinguishedName::new();
21        ca_params
22            .distinguished_name
23            .push(DnType::CommonName, "Stakpak MCP CA");
24        ca_params
25            .distinguished_name
26            .push(DnType::OrganizationName, "Stakpak");
27        ca_params.distinguished_name.push(DnType::CountryName, "US");
28
29        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
30        ca_params.key_usages = vec![
31            KeyUsagePurpose::KeyCertSign,
32            KeyUsagePurpose::CrlSign,
33            KeyUsagePurpose::DigitalSignature,
34        ];
35
36        ca_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
37        ca_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
38
39        let ca_cert = rcgen::Certificate::from_params(ca_params)?;
40
41        // Generate server certificate
42        let mut server_params = CertificateParams::default();
43        server_params.distinguished_name = DistinguishedName::new();
44        server_params
45            .distinguished_name
46            .push(DnType::CommonName, "Stakpak MCP Server");
47        server_params
48            .distinguished_name
49            .push(DnType::OrganizationName, "Stakpak");
50        server_params
51            .distinguished_name
52            .push(DnType::CountryName, "US");
53
54        server_params.subject_alt_names = vec![
55            SanType::DnsName("localhost".to_string()),
56            SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0))),
57            SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
58        ];
59
60        server_params.key_usages = vec![
61            KeyUsagePurpose::DigitalSignature,
62            KeyUsagePurpose::KeyEncipherment,
63        ];
64
65        server_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
66        server_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
67
68        let server_cert = rcgen::Certificate::from_params(server_params)?;
69
70        // Generate client certificate
71        let mut client_params = CertificateParams::default();
72        client_params.distinguished_name = DistinguishedName::new();
73        client_params
74            .distinguished_name
75            .push(DnType::CommonName, "Stakpak MCP Client");
76        client_params
77            .distinguished_name
78            .push(DnType::OrganizationName, "Stakpak");
79        client_params
80            .distinguished_name
81            .push(DnType::CountryName, "US");
82
83        client_params.key_usages = vec![
84            KeyUsagePurpose::DigitalSignature,
85            KeyUsagePurpose::KeyEncipherment,
86        ];
87
88        client_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
89        client_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
90
91        let client_cert = rcgen::Certificate::from_params(client_params)?;
92
93        Ok(CertificateChain {
94            ca_cert,
95            server_cert,
96            client_cert,
97        })
98    }
99
100    pub fn create_server_config(&self) -> Result<ServerConfig> {
101        // Sign server certificate with CA
102        let server_cert_der = self.server_cert.serialize_der_with_signer(&self.ca_cert)?;
103        let server_key_der = self.server_cert.serialize_private_key_der();
104
105        let server_cert_chain = vec![CertificateDer::from(server_cert_der)];
106        let server_private_key = PrivateKeyDer::try_from(server_key_der)
107            .map_err(|e| anyhow::anyhow!("Failed to convert server private key: {:?}", e))?;
108
109        // Set up root certificate store to trust our CA (for client cert validation)
110        let mut root_cert_store = RootCertStore::empty();
111        let ca_cert_der = self.ca_cert.serialize_der()?;
112        root_cert_store.add(CertificateDer::from(ca_cert_der))?;
113
114        // Create client certificate verifier that requires client certificates
115        let client_cert_verifier =
116            rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
117                .build()
118                .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
119
120        let config = ServerConfig::builder()
121            .with_client_cert_verifier(client_cert_verifier)
122            .with_single_cert(server_cert_chain, server_private_key)?;
123
124        Ok(config)
125    }
126
127    pub fn create_client_config(&self) -> Result<ClientConfig> {
128        // Sign client certificate with CA
129        let client_cert_der = self.client_cert.serialize_der_with_signer(&self.ca_cert)?;
130        let client_key_der = self.client_cert.serialize_private_key_der();
131
132        let client_cert_chain = vec![CertificateDer::from(client_cert_der)];
133        let client_private_key = PrivateKeyDer::try_from(client_key_der)
134            .map_err(|e| anyhow::anyhow!("Failed to convert client private key: {:?}", e))?;
135
136        // Set up root certificate store to trust our CA (for server cert validation)
137        let mut root_cert_store = RootCertStore::empty();
138        let ca_cert_der = self.ca_cert.serialize_der()?;
139        root_cert_store.add(CertificateDer::from(ca_cert_der))?;
140
141        let config = ClientConfig::builder()
142            .with_root_certificates(root_cert_store)
143            .with_client_auth_cert(client_cert_chain, client_private_key)?;
144
145        Ok(config)
146    }
147
148    pub fn get_ca_cert_pem(&self) -> Result<String> {
149        Ok(self.ca_cert.serialize_pem()?)
150    }
151
152    pub fn get_server_cert_pem(&self) -> Result<String> {
153        Ok(self.server_cert.serialize_pem_with_signer(&self.ca_cert)?)
154    }
155
156    pub fn get_client_cert_pem(&self) -> Result<String> {
157        Ok(self.client_cert.serialize_pem_with_signer(&self.ca_cert)?)
158    }
159
160    pub fn get_server_key_pem(&self) -> Result<String> {
161        Ok(self.server_cert.serialize_private_key_pem())
162    }
163
164    pub fn get_client_key_pem(&self) -> Result<String> {
165        Ok(self.client_cert.serialize_private_key_pem())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, response::Json, routing::get};
    use axum_server::tls_rustls::RustlsConfig;
    use reqwest::Client;
    use serde_json::json;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::net::TcpStream;
    use tokio::task::JoinHandle;
    use tokio::time::{Duration, timeout};

    /// Installs the aws-lc-rs rustls crypto provider once per test binary.
    ///
    /// The rustls config builders used by `CertificateChain` need a process-default provider.
    fn init_crypto_provider() {
        use std::sync::Once;
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            // `install_default` only fails when a provider is already installed (e.g. by
            // another test module in this crate). That still satisfies our requirement, so
            // it is not treated as a test failure.
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    /// Returns a loopback address whose ephemeral port was free a moment ago.
    ///
    /// The probe listener is dropped before returning so the test server can bind the same
    /// address. Unlike a hard-coded port, this does not collide with tests running in
    /// parallel or with local services.
    fn free_local_addr() -> SocketAddr {
        let probe =
            std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind listener");
        probe.local_addr().expect("Failed to get local address")
    }

    /// Router exposing `GET /test`, which returns `{"status": "success"}`.
    fn status_app() -> Router {
        Router::new().route(
            "/test",
            get(|| async { Json(json!({"status": "success"})) }),
        )
    }

    /// Starts an mTLS `axum_server` for `app` on a free loopback port.
    ///
    /// Waits until the server accepts TCP connections, then returns its address and task
    /// handle. Callers `abort()` the handle when done.
    async fn spawn_mtls_server(
        server_config: ServerConfig,
        app: Router,
    ) -> (SocketAddr, JoinHandle<()>) {
        let server_addr = free_local_addr();
        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));

        let server_handle = tokio::spawn(async move {
            if let Err(err) = axum_server::bind_rustls(server_addr, rustls_config)
                .serve(app.into_make_service())
                .await
            {
                eprintln!("mTLS test server on {server_addr} exited: {err}");
            }
        });

        // Poll for readiness instead of sleeping a fixed amount of time.
        let ready = timeout(Duration::from_secs(5), async {
            while TcpStream::connect(server_addr).await.is_err() {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await;
        assert!(
            ready.is_ok(),
            "mTLS test server did not start listening on {server_addr}"
        );
        assert!(
            !server_handle.is_finished(),
            "mTLS test server exited during startup"
        );

        (server_addr, server_handle)
    }

    #[tokio::test]
    async fn test_mtls_handshake_success() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let (server_addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        let url = format!("https://127.0.0.1:{}/test", server_addr.port());
        println!("Testing mTLS connection to: {}", url);

        let response = timeout(Duration::from_secs(10), client.get(&url).send())
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

        assert!(
            response.status().is_success(),
            "Request should succeed with valid mTLS"
        );

        let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
        assert_eq!(body["status"], "success");

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_no_client_cert() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        // The server config requires client certificates.
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");

        let (server_addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        // Accept the self-signed server cert, so the only thing missing is a client cert.
        let client = Client::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build client");

        let result = timeout(
            Duration::from_secs(5),
            client
                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
                .send(),
        )
        .await;

        // If the server had died, a refused connection would satisfy the next assertion
        // without ever exercising client-certificate enforcement.
        assert!(
            !server_handle.is_finished(),
            "Test server should still be running"
        );
        assert!(
            !matches!(result, Ok(Ok(_))),
            "Request should fail without client certificate"
        );

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_wrong_ca() {
        init_crypto_provider();
        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config = cert_chain1
            .create_server_config()
            .expect("Failed to create server config");
        // The client trusts (and is issued by) a different CA than the server.
        let client_config = cert_chain2
            .create_client_config()
            .expect("Failed to create client config");

        let (server_addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        let result = timeout(
            Duration::from_secs(5),
            client
                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
                .send(),
        )
        .await;

        // Guard against a vacuous pass caused by the server not running.
        assert!(
            !server_handle.is_finished(),
            "Test server should still be running"
        );
        assert!(
            !matches!(result, Ok(Ok(_))),
            "Request should fail with wrong CA certificates"
        );

        server_handle.abort();
    }

    #[test]
    fn test_certificate_chain_generation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let ca_pem = cert_chain.get_ca_cert_pem().expect("Failed to get CA PEM");
        let server_pem = cert_chain
            .get_server_cert_pem()
            .expect("Failed to get server PEM");
        let client_pem = cert_chain
            .get_client_cert_pem()
            .expect("Failed to get client PEM");
        let server_key_pem = cert_chain
            .get_server_key_pem()
            .expect("Failed to get server key PEM");
        let client_key_pem = cert_chain
            .get_client_key_pem()
            .expect("Failed to get client key PEM");

        for cert_pem in [&ca_pem, &server_pem, &client_pem] {
            assert!(cert_pem.contains("-----BEGIN CERTIFICATE-----"));
            assert!(cert_pem.contains("-----END CERTIFICATE-----"));
        }
        for key_pem in [&server_key_pem, &client_key_pem] {
            assert!(key_pem.contains("-----BEGIN PRIVATE KEY-----"));
            assert!(key_pem.contains("-----END PRIVATE KEY-----"));
        }
    }

    #[test]
    fn test_server_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        // Success means the call does not error; `expect` enforces that.
        let _server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
    }

    #[test]
    fn test_client_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        // Success means the call does not error; `expect` enforces that.
        let _client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");
    }

    #[tokio::test]
    async fn test_mtls_multiple_requests() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let app = Router::new()
            .route(
                "/test1",
                get(|| async { Json(json!({"endpoint": "test1"})) }),
            )
            .route(
                "/test2",
                get(|| async { Json(json!({"endpoint": "test2"})) }),
            )
            .route(
                "/test3",
                get(|| async { Json(json!({"endpoint": "test3"})) }),
            );

        let (server_addr, server_handle) = spawn_mtls_server(server_config, app).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        for endpoint in ["test1", "test2", "test3"] {
            let response = timeout(
                Duration::from_secs(10),
                client
                    .get(format!(
                        "https://127.0.0.1:{}/{}",
                        server_addr.port(),
                        endpoint
                    ))
                    .send(),
            )
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

            assert!(
                response.status().is_success(),
                "Request to {} should succeed",
                endpoint
            );

            let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
            assert_eq!(body["endpoint"], endpoint);
        }

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_configuration_compatibility() {
        init_crypto_provider();

        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        // The client config must be accepted by reqwest...
        let _client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build reqwest client with mTLS config");

        // ...and the server config by axum-server.
        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config));

        assert!(cert_chain.get_ca_cert_pem().is_ok());
        assert!(cert_chain.get_server_cert_pem().is_ok());
        assert!(cert_chain.get_client_cert_pem().is_ok());
        assert!(cert_chain.get_server_key_pem().is_ok());
        assert!(cert_chain.get_client_key_pem().is_ok());

        println!("✅ mTLS configuration successfully created");
        println!("✅ Reqwest client can be configured with client certificates");
        println!("✅ Axum server can be configured with server certificates");
        println!("✅ Certificate chain includes CA, server, and client certificates");
    }

    #[tokio::test]
    async fn test_mtls_certificate_validation() {
        init_crypto_provider();

        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config1 = cert_chain1
            .create_server_config()
            .expect("Failed to create server config 1");
        let client_config2 = cert_chain2
            .create_client_config()
            .expect("Failed to create client config 2");

        // Configs from different chains build fine on their own. The handshake failure
        // between them is exercised by `test_mtls_handshake_failure_wrong_ca`.
        let _client = Client::builder()
            .use_preconfigured_tls(client_config2)
            .build()
            .expect("Failed to build client with different CA");

        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config1));

        println!("✅ Different certificate chains create valid but incompatible configurations");
    }
}