rama_net/tls/client/config.rs

1use std::sync::Arc;
2
3use rama_core::{Context, combinators::Either3};
4
5use super::{ClientHelloExtension, merge_client_hello_lists};
6use crate::tls::{CipherSuite, CompressionAlgorithm, DataEncoding, KeyLogIntent, ProtocolVersion};
7
8#[derive(Debug, Clone, Default)]
9pub struct ClientConfigChain {
10    configs: Vec<Arc<ClientConfig>>,
11}
12
13#[derive(Debug)]
14pub struct ClientConfigChainRef<'a> {
15    data: ClientConfigChainRefData<'a>,
16}
17
18impl ClientConfigChainRef<'_> {
19    pub fn append(&mut self, cfg: impl Into<Arc<ClientConfig>>) {
20        let mut data = ClientConfigChainRefData::Dummy;
21        std::mem::swap(&mut self.data, &mut data);
22        self.data = data.append(cfg);
23    }
24
25    pub fn prepend(&mut self, cfg: impl Into<Arc<ClientConfig>>) {
26        let mut data = ClientConfigChainRefData::Dummy;
27        std::mem::swap(&mut self.data, &mut data);
28        self.data = data.prepend(cfg);
29    }
30
31    pub fn into_owned(self) -> ClientConfigChain {
32        match self.data {
33            ClientConfigChainRefData::Chain(client_config_chain) => ClientConfigChain {
34                configs: client_config_chain.configs.clone(),
35            },
36            ClientConfigChainRefData::Single(client_config) => ClientConfigChain {
37                configs: vec![client_config.clone()],
38            },
39            ClientConfigChainRefData::Owned(configs) => ClientConfigChain { configs },
40            ClientConfigChainRefData::Dummy => unreachable!(),
41        }
42    }
43    pub fn iter(&self) -> impl Iterator<Item = &ClientConfig> {
44        match &self.data {
45            ClientConfigChainRefData::Chain(client_config_chain) => {
46                Either3::A(client_config_chain.configs.iter().map(|a| a.as_ref()))
47            }
48            ClientConfigChainRefData::Single(client_config) => {
49                Either3::B(std::iter::once(client_config.as_ref()))
50            }
51            ClientConfigChainRefData::Owned(configs) => {
52                Either3::C(configs.iter().map(|a| a.as_ref()))
53            }
54            ClientConfigChainRefData::Dummy => unreachable!(),
55        }
56    }
57}
58
59#[derive(Debug)]
60enum ClientConfigChainRefData<'a> {
61    Chain(&'a ClientConfigChain),
62    Single(&'a Arc<ClientConfig>),
63    Owned(Vec<Arc<ClientConfig>>),
64    Dummy,
65}
66
67impl ClientConfigChainRefData<'_> {
68    fn append(self, cfg: impl Into<Arc<ClientConfig>>) -> Self {
69        let mut configs = match self {
70            ClientConfigChainRefData::Chain(client_config_chain) => {
71                client_config_chain.configs.clone()
72            }
73            ClientConfigChainRefData::Single(client_config) => vec![client_config.clone()],
74            ClientConfigChainRefData::Owned(client_configs) => client_configs,
75            ClientConfigChainRefData::Dummy => unreachable!(),
76        };
77        configs.push(cfg.into());
78        ClientConfigChainRefData::Owned(configs)
79    }
80
81    fn prepend(self, cfg: impl Into<Arc<ClientConfig>>) -> Self {
82        ClientConfigChainRefData::Owned(match self {
83            ClientConfigChainRefData::Chain(client_config_chain) => {
84                let mut v = Vec::with_capacity(client_config_chain.configs.len() + 1);
85                v.push(cfg.into());
86                v.extend(client_config_chain.configs.iter().cloned());
87                v
88            }
89            ClientConfigChainRefData::Single(client_config) => {
90                vec![cfg.into(), client_config.clone()]
91            }
92            ClientConfigChainRefData::Owned(client_configs) => {
93                let mut v = Vec::with_capacity(client_configs.len() + 1);
94                v.push(cfg.into());
95                v.extend(client_configs);
96                v
97            }
98            ClientConfigChainRefData::Dummy => unreachable!(),
99        })
100    }
101}
102
103pub fn extract_client_config_from_ctx<State>(
104    ctx: &Context<State>,
105) -> Option<ClientConfigChainRef<'_>> {
106    match ctx.get::<ClientConfigChain>() {
107        Some(chain) => Some(ClientConfigChainRef {
108            data: ClientConfigChainRefData::Chain(chain),
109        }),
110        None => ctx
111            .get::<Arc<ClientConfig>>()
112            .map(|cfg| ClientConfigChainRef {
113                data: ClientConfigChainRefData::Single(cfg),
114            }),
115    }
116}
117
118pub fn append_client_config_to_ctx<State>(
119    ctx: &mut Context<State>,
120    cfg: impl Into<Arc<ClientConfig>>,
121) {
122    match ctx.get_mut::<ClientConfigChain>() {
123        Some(chain) => {
124            chain.configs.push(cfg.into());
125        }
126        None => match ctx.remove::<Arc<ClientConfig>>() {
127            Some(old_cfg) => {
128                ctx.insert(ClientConfigChain {
129                    configs: vec![old_cfg, cfg.into()],
130                });
131            }
132            None => {
133                ctx.insert(ClientConfigChain::from(cfg.into()));
134            }
135        },
136    }
137}
138
139pub fn append_all_client_configs_to_ctx<State>(
140    ctx: &mut Context<State>,
141    cfg_it: impl IntoIterator<Item: Into<Arc<ClientConfig>>>,
142) {
143    let cfg_it = cfg_it.into_iter();
144    match ctx.get_mut::<ClientConfigChain>() {
145        Some(chain) => {
146            chain.configs.extend(cfg_it.map(Into::into));
147        }
148        None => match ctx.remove::<Arc<ClientConfig>>() {
149            Some(old_cfg) => {
150                let (lb, _) = cfg_it.size_hint();
151                assert!(lb < usize::MAX);
152
153                let mut configs = Vec::with_capacity(lb + 1);
154                configs.push(old_cfg);
155                configs.extend(cfg_it.map(Into::into));
156
157                ctx.insert(ClientConfigChain { configs });
158            }
159            None => {
160                let chain: ClientConfigChain = cfg_it.collect();
161                ctx.insert(chain);
162            }
163        },
164    }
165}
166
167impl From<ClientConfig> for ClientConfigChain {
168    fn from(value: ClientConfig) -> Self {
169        ClientConfigChain {
170            configs: vec![Arc::new(value)],
171        }
172    }
173}
174
175impl From<Arc<ClientConfig>> for ClientConfigChain {
176    fn from(value: Arc<ClientConfig>) -> Self {
177        ClientConfigChain {
178            configs: vec![value],
179        }
180    }
181}
182
183impl<Item> FromIterator<Item> for ClientConfigChain
184where
185    Item: Into<Arc<ClientConfig>>,
186{
187    fn from_iter<T: IntoIterator<Item = Item>>(iter: T) -> Self {
188        ClientConfigChain {
189            configs: iter.into_iter().map(Into::into).collect(),
190        }
191    }
192}
193
194#[derive(Debug, Clone, Default)]
195/// Common API to configure a Proxy TLS Client
196///
197/// See [`ClientConfig`] for more information,
198/// this is only a new-type wrapper to be able to differentiate
199/// the info found in context for a dynamic https client.
200pub struct ProxyClientConfig(pub Arc<ClientConfig>);
201
202#[derive(Debug, Clone, Default)]
203/// Common API to configure a TLS Client
204pub struct ClientConfig {
205    /// optional intent for cipher suites to be used by client
206    pub cipher_suites: Option<Vec<CipherSuite>>,
207    /// optional intent for compression algorithms to be used by client
208    pub compression_algorithms: Option<Vec<CompressionAlgorithm>>,
209    /// optional intent for extensions to be used by client
210    ///
211    /// Commpon examples are:
212    ///
213    /// - [`super::ClientHelloExtension::ApplicationLayerProtocolNegotiation`]
214    /// - [`super::ClientHelloExtension::SupportedVersions`]
215    pub extensions: Option<Vec<ClientHelloExtension>>,
216    /// optionally define how server should be verified by client
217    pub server_verify_mode: Option<ServerVerifyMode>,
218    /// optionally define raw (PEM-encoded) client auth certs
219    pub client_auth: Option<ClientAuth>,
220    /// key log intent
221    pub key_logger: Option<KeyLogIntent>,
222    /// if enabled server certificates will be stored in [`NegotiatedTlsParameters`]
223    pub store_server_certificate_chain: bool,
224}
225
226impl ClientConfig {
227    /// Merge this [`ClientConfig`] with aother one.
228    pub fn merge(&mut self, other: ClientConfig) {
229        if let Some(cipher_suites) = other.cipher_suites {
230            self.cipher_suites = Some(cipher_suites);
231        }
232
233        if let Some(compression_algorithms) = other.compression_algorithms {
234            self.compression_algorithms = Some(compression_algorithms);
235        }
236
237        self.extensions = match (self.extensions.take(), other.extensions) {
238            (Some(our_ext), Some(other_ext)) => Some(merge_client_hello_lists(our_ext, other_ext)),
239            (None, Some(other_ext)) => Some(other_ext),
240            (maybe_our_ext, None) => maybe_our_ext,
241        };
242
243        if let Some(server_verify_mode) = other.server_verify_mode {
244            self.server_verify_mode = Some(server_verify_mode);
245        }
246
247        if let Some(client_auth) = other.client_auth {
248            self.client_auth = Some(client_auth);
249        }
250
251        if let Some(key_logger) = other.key_logger {
252            self.key_logger = Some(key_logger);
253        }
254    }
255}
256
257#[derive(Debug, Clone)]
258/// The kind of client auth to be used.
259pub enum ClientAuth {
260    /// Request the tls implementation to generate self-signed single data
261    SelfSigned,
262    /// Single data provided by the configurator
263    Single(ClientAuthData),
264}
265
266#[derive(Debug, Clone)]
267/// Raw private key and certificate data to facilitate client authentication.
268pub struct ClientAuthData {
269    /// private key used by client
270    pub private_key: DataEncoding,
271    /// certificate chain as a companion to the private key
272    pub cert_chain: DataEncoding,
273}
274
/// How a (TLS) client verifies the server it connects to.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ServerVerifyMode {
    /// Use the default verification approach, as defined by the
    /// implementation of the (TLS) client in use.
    #[default]
    Auto,
    /// Explicitly disable server verification (if possible).
    Disable,
}
285
286impl From<super::ClientHello> for ClientConfig {
287    fn from(value: super::ClientHello) -> Self {
288        Self {
289            cipher_suites: (!value.cipher_suites.is_empty()).then_some(value.cipher_suites),
290            compression_algorithms: (!value.compression_algorithms.is_empty())
291                .then_some(value.compression_algorithms),
292            extensions: (!value.extensions.is_empty()).then_some(value.extensions),
293            ..Default::default()
294        }
295    }
296}
297
298impl From<ClientConfig> for super::ClientHello {
299    fn from(value: ClientConfig) -> Self {
300        super::ClientHello {
301            protocol_version: ProtocolVersion::TLSv1_2,
302            cipher_suites: value.cipher_suites.unwrap_or_default(),
303            compression_algorithms: value.compression_algorithms.unwrap_or_default(),
304            extensions: value.extensions.unwrap_or_default(),
305        }
306    }
307}