// rocket_community/tls/resolver.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::ops::Deref;
4use std::sync::Arc;
5
6pub use rustls::server::{ClientHello, ServerConfig};
7
8use crate::fairing::{self, Info, Kind};
9use crate::{Build, Ignite, Rocket};
10
11/// Proxy type to get PartialEq + Debug impls.
12#[derive(Clone)]
13pub(crate) struct DynResolver(Arc<dyn Resolver>);
14
/// Fairing returned by [`Resolver::fairing()`].
///
/// On ignite it runs `T::init` and stores the resulting resolver as managed
/// `Arc<dyn Resolver>` state. It carries no data: `T` only selects which
/// resolver type to construct.
pub struct Fairing<T: ?Sized>(PhantomData<T>);
16
17/// A dynamic TLS configuration resolver.
18///
19/// # Example
20///
21/// This is an async trait. Implement it as follows:
22///
23/// ```rust
24/// # #[macro_use] extern crate rocket_community as rocket;
25/// use std::sync::Arc;
26/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
27/// use rocket::{Rocket, Build};
28///
29/// struct MyResolver(Arc<ServerConfig>);
30///
31/// #[rocket::async_trait]
32/// impl Resolver for MyResolver {
33///     async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
34///         // This is equivalent to what the default resolver would do.
35///         let config: TlsConfig = rocket.figment().extract_inner("tls")?;
36///         let server_config = config.server_config().await?;
37///         Ok(MyResolver(Arc::new(server_config)))
38///     }
39///
40///     async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
41///         // return a `ServerConfig` based on `hello`; here we ignore it
42///         Some(self.0.clone())
43///     }
44/// }
45///
46/// #[launch]
47/// fn rocket() -> _ {
48///     rocket::build().attach(MyResolver::fairing())
49/// }
50/// ```
51#[crate::async_trait]
52pub trait Resolver: Send + Sync + 'static {
53    async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self>
54    where
55        Self: Sized,
56    {
57        let _rocket = rocket;
58        let type_name = std::any::type_name::<Self>();
59        Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into())
60    }
61
62    async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;
63
64    fn fairing() -> Fairing<Self>
65    where
66        Self: Sized,
67    {
68        Fairing(PhantomData)
69    }
70}
71
72#[crate::async_trait]
73impl<T: Resolver> fairing::Fairing for Fairing<T> {
74    fn info(&self) -> Info {
75        Info {
76            name: "Resolver Fairing",
77            kind: Kind::Ignite | Kind::Singleton,
78        }
79    }
80
81    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
82        let result = T::init(&rocket).await;
83        match result {
84            Ok(resolver) => Ok(rocket.manage(Arc::new(resolver) as Arc<dyn Resolver>)),
85            Err(e) => {
86                let type_name = std::any::type_name::<T>();
87                error!(type_name, reason = %e, "TLS resolver failed to initialize");
88                Err(rocket)
89            }
90        }
91    }
92}
93
94impl DynResolver {
95    pub fn extract(rocket: &Rocket<Ignite>) -> Option<Self> {
96        rocket.state::<Arc<dyn Resolver>>().map(|r| Self(r.clone()))
97    }
98}
99
100impl fmt::Debug for DynResolver {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        f.debug_tuple("Resolver").finish()
103    }
104}
105
106impl PartialEq for DynResolver {
107    fn eq(&self, _: &Self) -> bool {
108        false
109    }
110}
111
112impl Deref for DynResolver {
113    type Target = dyn Resolver;
114
115    fn deref(&self) -> &Self::Target {
116        &*self.0
117    }
118}