1use rustls::internal::msgs::codec::Reader;
2use rustls::internal::msgs::handshake::{
3 ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload,
4 ServerExtension, ServerHelloPayload,
5};
6use rustls::internal::msgs::message::{Message, MessagePayload, OpaqueMessage};
7use rustls::{Error as RustlsError, ProtocolVersion};
8
9use std::collections::hash_map::DefaultHasher;
10use std::fmt;
11use std::hash::Hasher;
12use std::str::FromStr;
13use std::time::{SystemTime, UNIX_EPOCH};
14
15#[allow(dead_code)]
16pub fn get_server_tls_version(shp: &ServerHelloPayload) -> Option<ProtocolVersion> {
17 shp.extensions
18 .iter()
19 .filter_map(|ext| {
20 if let ServerExtension::SupportedVersions(vers) = ext {
21 Some(vers)
22 } else {
23 None
24 }
25 })
26 .next()
27 .cloned()
28}
29
30pub fn get_client_tls_versions(shp: &ClientHelloPayload) -> Option<&Vec<ProtocolVersion>> {
31 shp.extensions
32 .iter()
33 .filter_map(|ext| {
34 if let ClientExtension::SupportedVersions(vers) = ext {
35 Some(vers)
36 } else {
37 None
38 }
39 })
40 .next()
41}
42
/// Consuming conversions from a TLS message into a specific hello payload.
pub trait TlsMessageExt {
    /// Consumes the message and returns its ClientHello payload, or `None`
    /// when the message does not contain one.
    fn into_client_hello_payload(self) -> Option<ClientHelloPayload>;
    /// Consumes the message and returns its ServerHello payload, or `None`
    /// when the message does not contain one.
    fn into_server_hello_payload(self) -> Option<ServerHelloPayload>;
}
47
48impl TlsMessageExt for Message {
49 fn into_client_hello_payload(self) -> Option<ClientHelloPayload> {
50 if let MessagePayload::Handshake {
51 parsed:
52 HandshakeMessagePayload {
53 payload: HandshakePayload::ClientHello(chp),
54 ..
55 },
56 ..
57 } = self.payload
58 {
59 Some(chp)
60 } else {
61 None
62 }
63 }
64
65 fn into_server_hello_payload(self) -> Option<ServerHelloPayload> {
66 if let MessagePayload::Handshake {
67 parsed:
68 HandshakeMessagePayload {
69 payload: HandshakePayload::ServerHello(shp),
70 ..
71 },
72 ..
73 } = self.payload
74 {
75 Some(shp)
76 } else {
77 None
78 }
79 }
80}
81
82pub fn parse_tls_plain_message(buf: &[u8]) -> Result<Message, RustlsError> {
83 OpaqueMessage::read(&mut Reader::init(buf))
84 .map(|om| om.into_plain_message())
85 .map_err(|_e| RustlsError::CorruptMessage) .and_then(Message::try_from)
87}
88
/// Display adapter that prints the items of a slice joined by the character
/// `CHAR`.
///
/// No separator is written before the first item or after the last one, so an
/// empty slice renders as an empty string.
pub struct ConcatenatingFormatter<'a, T: fmt::Display, const CHAR: char>(&'a [T]);

impl<'s, T: fmt::Display, const CHAR: char> fmt::Display for ConcatenatingFormatter<'s, T, CHAR> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        for (position, item) in self.0.iter().enumerate() {
            // The separator goes in front of every item except the first.
            if position != 0 {
                write!(f, "{}", CHAR)?;
            }
            write!(f, "{}", item)?;
        }
        Ok(())
    }
}
103
104pub fn fmtconcat<T: fmt::Display, const CHAR: char>(
105 slice: &'_ [T],
106) -> ConcatenatingFormatter<'_, T, CHAR> {
107 ConcatenatingFormatter(slice)
108}
109
/// Parses a string of `CHAR`-separated values into a `Vec<T>`.
///
/// The input is split on every occurrence of `CHAR` and each piece is parsed
/// with `T::from_str`. Pieces are not trimmed, and empty pieces (including the
/// single empty piece produced by an empty input) are parsed like any other.
pub struct ConcatenatedParser<T: FromStr, const CHAR: char>(pub Vec<T>);

impl<T: FromStr, const CHAR: char> FromStr for ConcatenatedParser<T, CHAR> {
    type Err = &'static str;

    /// Splits `s` on `CHAR` and parses every piece.
    ///
    /// Stops at the first piece that fails to parse and reports it with a
    /// fixed message; the underlying parse error is discarded.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.split(CHAR)
            .map(|piece| piece.parse::<T>().map_err(|_| "Not valid value"))
            .collect::<Result<Vec<T>, _>>()
            .map(Self)
    }
}

impl<T: FromStr, const CHAR: char> ConcatenatedParser<T, CHAR> {
    /// Consumes the parser and returns the parsed values.
    pub fn into_inner(self) -> Vec<T> {
        let Self(values) = self;
        values
    }
}
129
/// Returns a pseudo-random value in the half-open range `S..E`.
///
/// With the `rand` feature enabled the value comes from `rand::thread_rng`.
/// Without it, the value is derived by hashing the sub-second nanoseconds of
/// the current system time, which is cheap but predictable and must not be
/// used for anything security-sensitive.
///
/// # Panics
///
/// Panics if the range is empty, i.e. when `S >= E`.
pub fn rand_in<const S: usize, const E: usize>() -> usize {
    // Fail with a clear message instead of a division-by-zero or subtraction
    // overflow further down.
    assert!(S < E, "rand_in: empty range {}..{}", S, E);
    #[cfg(not(feature = "rand"))]
    {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // A clock set before the epoch is reported as an error that still
            // carries the distance to the epoch; that is good enough as a seed
            // and avoids panicking on a misconfigured clock.
            .unwrap_or_else(|err| err.duration())
            .subsec_nanos();
        let mut h = DefaultHasher::new();
        h.write_u32(nanos);
        (h.finish() as usize) % (E - S) + S
    }
    #[cfg(feature = "rand")]
    {
        use rand::{thread_rng, Rng};
        thread_rng().gen_range(S..E)
    }
}