ate_comms/
hello.rs

1use std::io;
2use tokio::io::AsyncRead;
3use tokio::io::AsyncWrite;
4#[allow(unused_imports)]
5use tracing::{debug, error, info, instrument, span, trace, warn, Level};
6use ate_crypto::KeySize;
7use ate_crypto::NodeId;
8use ate_crypto::SerializationFormat;
9use serde::{Deserialize, Serialize};
10
11use super::protocol::MessageProtocolVersion;
12use super::protocol::MessageProtocolApi;
13
14#[derive(Serialize, Deserialize, Debug, Clone)]
15pub struct HelloMetadata {
16    pub client_id: NodeId,
17    pub server_id: NodeId,
18    pub path: String,
19    pub encryption: Option<KeySize>,
20    pub wire_format: SerializationFormat,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone)]
24struct SenderHello {
25    pub id: NodeId,
26    pub path: String,
27    pub domain: String,
28    pub key_size: Option<KeySize>,
29    #[serde(default = "default_stream_protocol_version")]
30    pub version: MessageProtocolVersion,
31}
32
33fn default_stream_protocol_version() -> MessageProtocolVersion {
34    MessageProtocolVersion::V1
35}
36
37#[derive(Serialize, Deserialize, Debug, Clone)]
38struct ReceiverHello {
39    pub id: NodeId,
40    pub encryption: Option<KeySize>,
41    pub wire_format: SerializationFormat,
42    #[serde(default = "default_stream_protocol_version")]
43    pub version: MessageProtocolVersion,
44}
45
46pub async fn mesh_hello_exchange_sender(
47    stream_rx: Box<dyn AsyncRead + Send + Sync + Unpin + 'static>,
48    stream_tx: Box<dyn AsyncWrite + Send + Sync + Unpin + 'static>,
49    client_id: NodeId,
50    hello_path: String,
51    domain: String,
52    key_size: Option<KeySize>,
53) -> tokio::io::Result<(
54    Box<dyn MessageProtocolApi + Send + Sync + 'static>,
55    HelloMetadata
56)> {
57    // Send over the hello message and wait for a response
58    trace!("client sending hello");
59    let hello_client = SenderHello {
60        id: client_id,
61        path: hello_path.clone(),
62        domain,
63        key_size,
64        version: MessageProtocolVersion::default(),
65    };
66    let hello_client_bytes = serde_json::to_vec(&hello_client)?;
67    let mut proto = MessageProtocolVersion::V1.create(
68        Some(stream_rx),
69        Some(stream_tx)
70    );
71    proto
72        .write_with_fixed_16bit_header(&hello_client_bytes[..], false)
73        .await?;
74
75    // Read the hello message from the other side
76    let hello_server_bytes = proto.read_with_fixed_16bit_header().await?;
77    trace!("client received hello from server");
78    trace!("{}", String::from_utf8_lossy(&hello_server_bytes[..]));
79    let hello_server: ReceiverHello = serde_json::from_slice(&hello_server_bytes[..])?;
80
81    // Validate the encryption is strong enough
82    if let Some(needed_size) = &key_size {
83        match &hello_server.encryption {
84            None => {
85                return Err(io::Error::new(io::ErrorKind::ConnectionRefused, "the server encryption strength is too weak"));
86            }
87            Some(a) if *a < *needed_size => {
88                return Err(io::Error::new(io::ErrorKind::ConnectionRefused, "the server encryption strength is too weak"));
89            }
90            _ => {}
91        }
92    }
93
94    // Switch to the correct protocol version
95    let version = hello_server.version.min(hello_client.version);
96    proto = version.upgrade(proto);
97    
98    // Upgrade the key_size if the server is bigger
99    trace!(
100        "client encryption={}",
101        match &hello_server.encryption {
102            Some(a) => a.as_str(),
103            None => "none",
104        }
105    );
106    trace!("client wire_format={}", hello_server.wire_format);
107
108    Ok((
109        proto,
110        HelloMetadata {
111            client_id,
112            server_id: hello_server.id,
113            path: hello_path,
114            encryption: hello_server.encryption,
115            wire_format: hello_server.wire_format,
116        }
117    ))
118}
119
120pub async fn mesh_hello_exchange_receiver(
121    stream_rx: Box<dyn AsyncRead + Send + Sync + Unpin + 'static>,
122    stream_tx: Box<dyn AsyncWrite + Send + Sync + Unpin + 'static>,
123    server_id: NodeId,
124    key_size: Option<KeySize>,
125    wire_format: SerializationFormat,
126) -> tokio::io::Result<(
127    Box<dyn MessageProtocolApi + Send + Sync + 'static>,
128    HelloMetadata
129)>
130{
131    // Read the hello message from the other side
132    let mut proto = MessageProtocolVersion::V1.create(
133        Some(stream_rx),
134        Some(stream_tx)
135    );
136    let hello_client_bytes = proto
137        .read_with_fixed_16bit_header()
138        .await?;
139    trace!("server received hello from client");
140    //trace!("server received hello from client: {}", String::from_utf8_lossy(&hello_client_bytes[..]));
141    let hello_client: SenderHello = serde_json::from_slice(&hello_client_bytes[..])?;
142
143    // Upgrade the key_size if the client is bigger
144    let encryption = mesh_hello_upgrade_key(key_size, hello_client.key_size);
145
146    // Send over the hello message and wait for a response
147    trace!("server sending hello (wire_format={})", wire_format);
148    let hello_server = ReceiverHello {
149        id: server_id,
150        encryption,
151        wire_format,
152        version: MessageProtocolVersion::default(),
153    };
154    let hello_server_bytes = serde_json::to_vec(&hello_server)?;
155    proto
156        .write_with_fixed_16bit_header(&hello_server_bytes[..], false)
157        .await?;
158
159    // Switch to the correct protocol version
160    proto = hello_server.version
161        .min(hello_client.version)
162        .upgrade(proto);
163
164    Ok((
165        proto,
166        HelloMetadata {
167            client_id: hello_client.id,
168            server_id,
169            path: hello_client.path,
170            encryption,
171            wire_format,
172        }
173    ))
174}
175
176fn mesh_hello_upgrade_key(key1: Option<KeySize>, key2: Option<KeySize>) -> Option<KeySize> {
177    // If both don't want encryption then who are we to argue about that?
178    if key1.is_none() && key2.is_none() {
179        return None;
180    }
181
182    // Wanting encryption takes priority over not wanting encyption
183    let key1 = match key1 {
184        Some(a) => a,
185        None => {
186            trace!("upgrading to {}bit shared secret", key2.unwrap());
187            return key2;
188        }
189    };
190    let key2 = match key2 {
191        Some(a) => a,
192        None => {
193            trace!("upgrading to {}bit shared secret", key1);
194            return Some(key1);
195        }
196    };
197
198    // Upgrade the key_size if the client is bigger
199    if key2 > key1 {
200        trace!("upgrading to {}bit shared secret", key2);
201        return Some(key2);
202    }
203    if key1 > key2 {
204        trace!("upgrading to {}bit shared secret", key2);
205        return Some(key1);
206    }
207
208    // They are identical
209    return Some(key1);
210}