1use std::sync::Arc;
2
3use rama_core::{Context, combinators::Either3};
4
5use super::{ClientHelloExtension, merge_client_hello_lists};
6use crate::tls::{CipherSuite, CompressionAlgorithm, DataEncoding, KeyLogIntent, ProtocolVersion};
7
/// An ordered list of shared [`ClientConfig`] layers.
///
/// Entries are kept in insertion order. Only the `Arc`s are stored, so
/// cloning a chain or copying it into a [`ClientConfigChainRef`] never
/// deep-copies a config. How the layers are finally combined is up to the
/// consumer — presumably via [`ClientConfig::merge`] (TODO confirm).
#[derive(Debug, Clone, Default)]
pub struct ClientConfigChain {
    // Chain entries, front to back.
    configs: Vec<Arc<ClientConfig>>,
}
12
/// A possibly-borrowed view over the client configs stored in a context.
///
/// Returned by [`extract_client_config_from_ctx`]. It borrows from the
/// context until `append`/`prepend` is called. At that point the configs
/// are copied into an owned list (cloning only their `Arc`s), so the context
/// itself is never modified through this view.
#[derive(Debug)]
pub struct ClientConfigChainRef<'a> {
    // Borrowed chain, borrowed single config, or owned copy.
    data: ClientConfigChainRefData<'a>,
}
17
18impl ClientConfigChainRef<'_> {
19 pub fn append(&mut self, cfg: impl Into<Arc<ClientConfig>>) {
20 let mut data = ClientConfigChainRefData::Dummy;
21 std::mem::swap(&mut self.data, &mut data);
22 self.data = data.append(cfg);
23 }
24
25 pub fn prepend(&mut self, cfg: impl Into<Arc<ClientConfig>>) {
26 let mut data = ClientConfigChainRefData::Dummy;
27 std::mem::swap(&mut self.data, &mut data);
28 self.data = data.prepend(cfg);
29 }
30
31 pub fn into_owned(self) -> ClientConfigChain {
32 match self.data {
33 ClientConfigChainRefData::Chain(client_config_chain) => ClientConfigChain {
34 configs: client_config_chain.configs.clone(),
35 },
36 ClientConfigChainRefData::Single(client_config) => ClientConfigChain {
37 configs: vec![client_config.clone()],
38 },
39 ClientConfigChainRefData::Owned(configs) => ClientConfigChain { configs },
40 ClientConfigChainRefData::Dummy => unreachable!(),
41 }
42 }
43 pub fn iter(&self) -> impl Iterator<Item = &ClientConfig> {
44 match &self.data {
45 ClientConfigChainRefData::Chain(client_config_chain) => {
46 Either3::A(client_config_chain.configs.iter().map(|a| a.as_ref()))
47 }
48 ClientConfigChainRefData::Single(client_config) => {
49 Either3::B(std::iter::once(client_config.as_ref()))
50 }
51 ClientConfigChainRefData::Owned(configs) => {
52 Either3::C(configs.iter().map(|a| a.as_ref()))
53 }
54 ClientConfigChainRefData::Dummy => unreachable!(),
55 }
56 }
57}
58
/// Backing storage for [`ClientConfigChainRef`].
#[derive(Debug)]
enum ClientConfigChainRefData<'a> {
    /// Borrows a [`ClientConfigChain`] found in the context.
    Chain(&'a ClientConfigChain),
    /// Borrows a lone `Arc<ClientConfig>` found in the context.
    Single(&'a Arc<ClientConfig>),
    /// Owned copy, produced once the view has been appended/prepended to.
    Owned(Vec<Arc<ClientConfig>>),
    /// Temporary placeholder used while the real value is moved out by value.
    /// It must never be observed: every match on it is `unreachable!()`.
    Dummy,
}
66
67impl ClientConfigChainRefData<'_> {
68 fn append(self, cfg: impl Into<Arc<ClientConfig>>) -> Self {
69 let mut configs = match self {
70 ClientConfigChainRefData::Chain(client_config_chain) => {
71 client_config_chain.configs.clone()
72 }
73 ClientConfigChainRefData::Single(client_config) => vec![client_config.clone()],
74 ClientConfigChainRefData::Owned(client_configs) => client_configs,
75 ClientConfigChainRefData::Dummy => unreachable!(),
76 };
77 configs.push(cfg.into());
78 ClientConfigChainRefData::Owned(configs)
79 }
80
81 fn prepend(self, cfg: impl Into<Arc<ClientConfig>>) -> Self {
82 ClientConfigChainRefData::Owned(match self {
83 ClientConfigChainRefData::Chain(client_config_chain) => {
84 let mut v = Vec::with_capacity(client_config_chain.configs.len() + 1);
85 v.push(cfg.into());
86 v.extend(client_config_chain.configs.iter().cloned());
87 v
88 }
89 ClientConfigChainRefData::Single(client_config) => {
90 vec![cfg.into(), client_config.clone()]
91 }
92 ClientConfigChainRefData::Owned(client_configs) => {
93 let mut v = Vec::with_capacity(client_configs.len() + 1);
94 v.push(cfg.into());
95 v.extend(client_configs);
96 v
97 }
98 ClientConfigChainRefData::Dummy => unreachable!(),
99 })
100 }
101}
102
103pub fn extract_client_config_from_ctx<State>(
104 ctx: &Context<State>,
105) -> Option<ClientConfigChainRef<'_>> {
106 match ctx.get::<ClientConfigChain>() {
107 Some(chain) => Some(ClientConfigChainRef {
108 data: ClientConfigChainRefData::Chain(chain),
109 }),
110 None => ctx
111 .get::<Arc<ClientConfig>>()
112 .map(|cfg| ClientConfigChainRef {
113 data: ClientConfigChainRefData::Single(cfg),
114 }),
115 }
116}
117
118pub fn append_client_config_to_ctx<State>(
119 ctx: &mut Context<State>,
120 cfg: impl Into<Arc<ClientConfig>>,
121) {
122 match ctx.get_mut::<ClientConfigChain>() {
123 Some(chain) => {
124 chain.configs.push(cfg.into());
125 }
126 None => match ctx.remove::<Arc<ClientConfig>>() {
127 Some(old_cfg) => {
128 ctx.insert(ClientConfigChain {
129 configs: vec![old_cfg, cfg.into()],
130 });
131 }
132 None => {
133 ctx.insert(ClientConfigChain::from(cfg.into()));
134 }
135 },
136 }
137}
138
139pub fn append_all_client_configs_to_ctx<State>(
140 ctx: &mut Context<State>,
141 cfg_it: impl IntoIterator<Item: Into<Arc<ClientConfig>>>,
142) {
143 let cfg_it = cfg_it.into_iter();
144 match ctx.get_mut::<ClientConfigChain>() {
145 Some(chain) => {
146 chain.configs.extend(cfg_it.map(Into::into));
147 }
148 None => match ctx.remove::<Arc<ClientConfig>>() {
149 Some(old_cfg) => {
150 let (lb, _) = cfg_it.size_hint();
151 assert!(lb < usize::MAX);
152
153 let mut configs = Vec::with_capacity(lb + 1);
154 configs.push(old_cfg);
155 configs.extend(cfg_it.map(Into::into));
156
157 ctx.insert(ClientConfigChain { configs });
158 }
159 None => {
160 let chain: ClientConfigChain = cfg_it.collect();
161 ctx.insert(chain);
162 }
163 },
164 }
165}
166
167impl From<ClientConfig> for ClientConfigChain {
168 fn from(value: ClientConfig) -> Self {
169 ClientConfigChain {
170 configs: vec![Arc::new(value)],
171 }
172 }
173}
174
175impl From<Arc<ClientConfig>> for ClientConfigChain {
176 fn from(value: Arc<ClientConfig>) -> Self {
177 ClientConfigChain {
178 configs: vec![value],
179 }
180 }
181}
182
183impl<Item> FromIterator<Item> for ClientConfigChain
184where
185 Item: Into<Arc<ClientConfig>>,
186{
187 fn from_iter<T: IntoIterator<Item = Item>>(iter: T) -> Self {
188 ClientConfigChain {
189 configs: iter.into_iter().map(Into::into).collect(),
190 }
191 }
192}
193
/// Newtype wrapper around a shared [`ClientConfig`].
///
/// NOTE(review): judging by the name, this is the config for the TLS
/// connection to a proxy. It is presumably a distinct type so it can be
/// stored in a context alongside the regular client config. Confirm this
/// against its users.
#[derive(Debug, Clone, Default)]
pub struct ProxyClientConfig(pub Arc<ClientConfig>);
201
/// TLS client configuration, expressed as a set of optional settings.
///
/// In every `Option` field, `None` means "not specified by this config".
/// [`ClientConfig::merge`] lets a `Some` from another config override the
/// value here, or combine with it in the case of `extensions`.
#[derive(Debug, Clone, Default)]
pub struct ClientConfig {
    /// Cipher suites to offer. Converting from a `ClientHello` maps an empty
    /// list to `None`.
    pub cipher_suites: Option<Vec<CipherSuite>>,
    /// Compression algorithms to offer. Converting from a `ClientHello` maps
    /// an empty list to `None`.
    pub compression_algorithms: Option<Vec<CompressionAlgorithm>>,
    /// ClientHello extensions. When both sides are `Some`, merging combines
    /// the lists via `merge_client_hello_lists` instead of replacing them.
    pub extensions: Option<Vec<ClientHelloExtension>>,
    /// How the server certificate should be verified.
    pub server_verify_mode: Option<ServerVerifyMode>,
    /// Client authentication material, if any.
    pub client_auth: Option<ClientAuth>,
    /// Key-logging intent. Presumably controls TLS secret logging for
    /// debugging (TODO confirm against `KeyLogIntent`).
    pub key_logger: Option<KeyLogIntent>,
    /// Whether the server's certificate chain should be stored, presumably
    /// so it can be inspected after the handshake (TODO confirm).
    pub store_server_certificate_chain: bool,
}
225
226impl ClientConfig {
227 pub fn merge(&mut self, other: ClientConfig) {
229 if let Some(cipher_suites) = other.cipher_suites {
230 self.cipher_suites = Some(cipher_suites);
231 }
232
233 if let Some(compression_algorithms) = other.compression_algorithms {
234 self.compression_algorithms = Some(compression_algorithms);
235 }
236
237 self.extensions = match (self.extensions.take(), other.extensions) {
238 (Some(our_ext), Some(other_ext)) => Some(merge_client_hello_lists(our_ext, other_ext)),
239 (None, Some(other_ext)) => Some(other_ext),
240 (maybe_our_ext, None) => maybe_our_ext,
241 };
242
243 if let Some(server_verify_mode) = other.server_verify_mode {
244 self.server_verify_mode = Some(server_verify_mode);
245 }
246
247 if let Some(client_auth) = other.client_auth {
248 self.client_auth = Some(client_auth);
249 }
250
251 if let Some(key_logger) = other.key_logger {
252 self.key_logger = Some(key_logger);
253 }
254 }
255}
256
/// Client authentication settings for a [`ClientConfig`].
#[derive(Debug, Clone)]
pub enum ClientAuth {
    /// Presumably: authenticate with a self-signed certificate generated on
    /// the fly (TODO confirm against the TLS backend).
    SelfSigned,
    /// Authenticate with the explicitly provided key and certificate chain.
    Single(ClientAuthData),
}
265
/// Explicit key material used by [`ClientAuth::Single`].
#[derive(Debug, Clone)]
pub struct ClientAuthData {
    /// Private key of the client certificate. The concrete format is
    /// whatever `DataEncoding` describes (not visible here; TODO confirm,
    /// e.g. PEM vs DER).
    pub private_key: DataEncoding,
    /// Client certificate chain matching `private_key`, in the same
    /// `DataEncoding` representation.
    pub cert_chain: DataEncoding,
}
274
/// How the server's certificate is to be verified.
///
/// Variant order matters: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ServerVerifyMode {
    /// The default mode. The exact checks performed are presumably left to
    /// the TLS backend (TODO confirm).
    #[default]
    Auto,
    /// Presumably turns server certificate verification off entirely.
    /// NOTE(review): this is security-sensitive; confirm the semantics
    /// before relying on it.
    Disable,
}
285
286impl From<super::ClientHello> for ClientConfig {
287 fn from(value: super::ClientHello) -> Self {
288 Self {
289 cipher_suites: (!value.cipher_suites.is_empty()).then_some(value.cipher_suites),
290 compression_algorithms: (!value.compression_algorithms.is_empty())
291 .then_some(value.compression_algorithms),
292 extensions: (!value.extensions.is_empty()).then_some(value.extensions),
293 ..Default::default()
294 }
295 }
296}
297
298impl From<ClientConfig> for super::ClientHello {
299 fn from(value: ClientConfig) -> Self {
300 super::ClientHello {
301 protocol_version: ProtocolVersion::TLSv1_2,
302 cipher_suites: value.cipher_suites.unwrap_or_default(),
303 compression_algorithms: value.compression_algorithms.unwrap_or_default(),
304 extensions: value.extensions.unwrap_or_default(),
305 }
306 }
307}