1use std::io;
2use tokio::io::AsyncRead;
3use tokio::io::AsyncWrite;
4#[allow(unused_imports)]
5use tracing::{debug, error, info, instrument, span, trace, warn, Level};
6use ate_crypto::KeySize;
7use ate_crypto::NodeId;
8use ate_crypto::SerializationFormat;
9use serde::{Deserialize, Serialize};
10
11use super::protocol::MessageProtocolVersion;
12use super::protocol::MessageProtocolApi;
13
14#[derive(Serialize, Deserialize, Debug, Clone)]
15pub struct HelloMetadata {
16 pub client_id: NodeId,
17 pub server_id: NodeId,
18 pub path: String,
19 pub encryption: Option<KeySize>,
20 pub wire_format: SerializationFormat,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone)]
24struct SenderHello {
25 pub id: NodeId,
26 pub path: String,
27 pub domain: String,
28 pub key_size: Option<KeySize>,
29 #[serde(default = "default_stream_protocol_version")]
30 pub version: MessageProtocolVersion,
31}
32
33fn default_stream_protocol_version() -> MessageProtocolVersion {
34 MessageProtocolVersion::V1
35}
36
37#[derive(Serialize, Deserialize, Debug, Clone)]
38struct ReceiverHello {
39 pub id: NodeId,
40 pub encryption: Option<KeySize>,
41 pub wire_format: SerializationFormat,
42 #[serde(default = "default_stream_protocol_version")]
43 pub version: MessageProtocolVersion,
44}
45
46pub async fn mesh_hello_exchange_sender(
47 stream_rx: Box<dyn AsyncRead + Send + Sync + Unpin + 'static>,
48 stream_tx: Box<dyn AsyncWrite + Send + Sync + Unpin + 'static>,
49 client_id: NodeId,
50 hello_path: String,
51 domain: String,
52 key_size: Option<KeySize>,
53) -> tokio::io::Result<(
54 Box<dyn MessageProtocolApi + Send + Sync + 'static>,
55 HelloMetadata
56)> {
57 trace!("client sending hello");
59 let hello_client = SenderHello {
60 id: client_id,
61 path: hello_path.clone(),
62 domain,
63 key_size,
64 version: MessageProtocolVersion::default(),
65 };
66 let hello_client_bytes = serde_json::to_vec(&hello_client)?;
67 let mut proto = MessageProtocolVersion::V1.create(
68 Some(stream_rx),
69 Some(stream_tx)
70 );
71 proto
72 .write_with_fixed_16bit_header(&hello_client_bytes[..], false)
73 .await?;
74
75 let hello_server_bytes = proto.read_with_fixed_16bit_header().await?;
77 trace!("client received hello from server");
78 trace!("{}", String::from_utf8_lossy(&hello_server_bytes[..]));
79 let hello_server: ReceiverHello = serde_json::from_slice(&hello_server_bytes[..])?;
80
81 if let Some(needed_size) = &key_size {
83 match &hello_server.encryption {
84 None => {
85 return Err(io::Error::new(io::ErrorKind::ConnectionRefused, "the server encryption strength is too weak"));
86 }
87 Some(a) if *a < *needed_size => {
88 return Err(io::Error::new(io::ErrorKind::ConnectionRefused, "the server encryption strength is too weak"));
89 }
90 _ => {}
91 }
92 }
93
94 let version = hello_server.version.min(hello_client.version);
96 proto = version.upgrade(proto);
97
98 trace!(
100 "client encryption={}",
101 match &hello_server.encryption {
102 Some(a) => a.as_str(),
103 None => "none",
104 }
105 );
106 trace!("client wire_format={}", hello_server.wire_format);
107
108 Ok((
109 proto,
110 HelloMetadata {
111 client_id,
112 server_id: hello_server.id,
113 path: hello_path,
114 encryption: hello_server.encryption,
115 wire_format: hello_server.wire_format,
116 }
117 ))
118}
119
120pub async fn mesh_hello_exchange_receiver(
121 stream_rx: Box<dyn AsyncRead + Send + Sync + Unpin + 'static>,
122 stream_tx: Box<dyn AsyncWrite + Send + Sync + Unpin + 'static>,
123 server_id: NodeId,
124 key_size: Option<KeySize>,
125 wire_format: SerializationFormat,
126) -> tokio::io::Result<(
127 Box<dyn MessageProtocolApi + Send + Sync + 'static>,
128 HelloMetadata
129)>
130{
131 let mut proto = MessageProtocolVersion::V1.create(
133 Some(stream_rx),
134 Some(stream_tx)
135 );
136 let hello_client_bytes = proto
137 .read_with_fixed_16bit_header()
138 .await?;
139 trace!("server received hello from client");
140 let hello_client: SenderHello = serde_json::from_slice(&hello_client_bytes[..])?;
142
143 let encryption = mesh_hello_upgrade_key(key_size, hello_client.key_size);
145
146 trace!("server sending hello (wire_format={})", wire_format);
148 let hello_server = ReceiverHello {
149 id: server_id,
150 encryption,
151 wire_format,
152 version: MessageProtocolVersion::default(),
153 };
154 let hello_server_bytes = serde_json::to_vec(&hello_server)?;
155 proto
156 .write_with_fixed_16bit_header(&hello_server_bytes[..], false)
157 .await?;
158
159 proto = hello_server.version
161 .min(hello_client.version)
162 .upgrade(proto);
163
164 Ok((
165 proto,
166 HelloMetadata {
167 client_id: hello_client.id,
168 server_id,
169 path: hello_client.path,
170 encryption,
171 wire_format,
172 }
173 ))
174}
175
176fn mesh_hello_upgrade_key(key1: Option<KeySize>, key2: Option<KeySize>) -> Option<KeySize> {
177 if key1.is_none() && key2.is_none() {
179 return None;
180 }
181
182 let key1 = match key1 {
184 Some(a) => a,
185 None => {
186 trace!("upgrading to {}bit shared secret", key2.unwrap());
187 return key2;
188 }
189 };
190 let key2 = match key2 {
191 Some(a) => a,
192 None => {
193 trace!("upgrading to {}bit shared secret", key1);
194 return Some(key1);
195 }
196 };
197
198 if key2 > key1 {
200 trace!("upgrading to {}bit shared secret", key2);
201 return Some(key2);
202 }
203 if key1 > key2 {
204 trace!("upgrading to {}bit shared secret", key2);
205 return Some(key1);
206 }
207
208 return Some(key1);
210}