1use crate::control::ControlStream;
21use crate::error::NegotiationError;
22use crate::Config;
23use quic_reverse_control::{Features, Hello, HelloAck, ProtocolMessage, PROTOCOL_VERSION};
24use quic_reverse_transport::{RecvStream, SendStream};
25use tracing::{debug, instrument, trace, warn};
26
27#[derive(Debug, Clone)]
29pub struct NegotiatedParams {
30 pub version: u16,
32 pub features: Features,
34 pub remote_agent: Option<String>,
36}
37
38#[instrument(skip_all, name = "negotiate_client")]
43pub async fn negotiate_client<S: SendStream, R: RecvStream>(
44 control: &mut ControlStream<S, R>,
45 config: &Config,
46) -> Result<NegotiatedParams, NegotiationError> {
47 let mut our_hello = Hello::new(config.features);
49 if let Some(ref agent) = config.agent {
50 our_hello = our_hello.with_agent(agent.clone());
51 }
52
53 trace!(version = PROTOCOL_VERSION, features = ?config.features, "sending Hello");
55 control
56 .write_message(&ProtocolMessage::Hello(our_hello))
57 .await
58 .map_err(|_| NegotiationError::Timeout)?;
59 control
60 .flush()
61 .await
62 .map_err(|_| NegotiationError::Timeout)?;
63
64 let their_hello = match control.read_message().await {
66 Ok(Some(ProtocolMessage::Hello(h))) => {
67 trace!(
68 version = h.protocol_version,
69 features = ?h.features,
70 agent = ?h.agent,
71 "received server Hello"
72 );
73 h
74 }
75 Ok(Some(_)) => {
76 warn!("received unexpected message during negotiation");
77 return Err(NegotiationError::UnexpectedMessage);
78 }
79 Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
80 };
81
82 if !config
84 .supported_versions
85 .contains(&their_hello.protocol_version)
86 {
87 warn!(
88 local = ?config.supported_versions,
89 remote = their_hello.protocol_version,
90 "version mismatch"
91 );
92 return Err(NegotiationError::VersionMismatch {
93 local: config.supported_versions.clone(),
94 remote: their_hello.protocol_version,
95 });
96 }
97
98 let negotiated_version = their_hello.protocol_version.min(PROTOCOL_VERSION);
100 let negotiated_features = config.features & their_hello.features;
101
102 trace!(version = negotiated_version, features = ?negotiated_features, "sending HelloAck");
104 let ack = HelloAck {
105 selected_version: negotiated_version,
106 selected_features: negotiated_features,
107 };
108 control
109 .write_message(&ProtocolMessage::HelloAck(ack))
110 .await
111 .map_err(|_| NegotiationError::Timeout)?;
112 control
113 .flush()
114 .await
115 .map_err(|_| NegotiationError::Timeout)?;
116
117 match control.read_message().await {
119 Ok(Some(ProtocolMessage::HelloAck(their_ack))) => {
120 trace!(
121 version = their_ack.selected_version,
122 features = ?their_ack.selected_features,
123 "received server HelloAck"
124 );
125 if their_ack.selected_version != negotiated_version {
127 warn!(
128 expected = negotiated_version,
129 received = their_ack.selected_version,
130 "server HelloAck version mismatch"
131 );
132 return Err(NegotiationError::VersionMismatch {
133 local: vec![negotiated_version],
134 remote: their_ack.selected_version,
135 });
136 }
137 }
138 Ok(Some(_)) => {
139 warn!("received unexpected message instead of HelloAck");
140 return Err(NegotiationError::UnexpectedMessage);
141 }
142 Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
143 }
144
145 debug!(
146 version = negotiated_version,
147 features = ?negotiated_features,
148 remote_agent = ?their_hello.agent,
149 "client negotiation complete"
150 );
151
152 Ok(NegotiatedParams {
153 version: negotiated_version,
154 features: negotiated_features,
155 remote_agent: their_hello.agent,
156 })
157}
158
159#[instrument(skip_all, name = "negotiate_server")]
164pub async fn negotiate_server<S: SendStream, R: RecvStream>(
165 control: &mut ControlStream<S, R>,
166 config: &Config,
167) -> Result<NegotiatedParams, NegotiationError> {
168 let their_hello = match control.read_message().await {
170 Ok(Some(ProtocolMessage::Hello(h))) => {
171 trace!(
172 version = h.protocol_version,
173 features = ?h.features,
174 agent = ?h.agent,
175 "received client Hello"
176 );
177 h
178 }
179 Ok(Some(_)) => {
180 warn!("received unexpected message during negotiation");
181 return Err(NegotiationError::UnexpectedMessage);
182 }
183 Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
184 };
185
186 if !config
188 .supported_versions
189 .contains(&their_hello.protocol_version)
190 {
191 warn!(
192 local = ?config.supported_versions,
193 remote = their_hello.protocol_version,
194 "version mismatch"
195 );
196 return Err(NegotiationError::VersionMismatch {
197 local: config.supported_versions.clone(),
198 remote: their_hello.protocol_version,
199 });
200 }
201
202 trace!(version = PROTOCOL_VERSION, features = ?config.features, "sending Hello");
204 let mut our_hello = Hello::new(config.features);
205 if let Some(ref agent) = config.agent {
206 our_hello = our_hello.with_agent(agent.clone());
207 }
208 control
209 .write_message(&ProtocolMessage::Hello(our_hello))
210 .await
211 .map_err(|_| NegotiationError::Timeout)?;
212 control
213 .flush()
214 .await
215 .map_err(|_| NegotiationError::Timeout)?;
216
217 let their_ack = match control.read_message().await {
219 Ok(Some(ProtocolMessage::HelloAck(ack))) => {
220 trace!(
221 version = ack.selected_version,
222 features = ?ack.selected_features,
223 "received client HelloAck"
224 );
225 ack
226 }
227 Ok(Some(_)) => {
228 warn!("received unexpected message instead of HelloAck");
229 return Err(NegotiationError::UnexpectedMessage);
230 }
231 Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
232 };
233
234 let negotiated_version = their_hello.protocol_version.min(PROTOCOL_VERSION);
236 let negotiated_features = config.features & their_hello.features;
237
238 if their_ack.selected_version != negotiated_version {
239 warn!(
240 expected = negotiated_version,
241 received = their_ack.selected_version,
242 "client HelloAck version mismatch"
243 );
244 return Err(NegotiationError::VersionMismatch {
245 local: vec![negotiated_version],
246 remote: their_ack.selected_version,
247 });
248 }
249
250 trace!(version = negotiated_version, features = ?negotiated_features, "sending HelloAck");
252 let our_ack = HelloAck {
253 selected_version: negotiated_version,
254 selected_features: negotiated_features,
255 };
256 control
257 .write_message(&ProtocolMessage::HelloAck(our_ack))
258 .await
259 .map_err(|_| NegotiationError::Timeout)?;
260 control
261 .flush()
262 .await
263 .map_err(|_| NegotiationError::Timeout)?;
264
265 debug!(
266 version = negotiated_version,
267 features = ?negotiated_features,
268 remote_agent = ?their_hello.agent,
269 "server negotiation complete"
270 );
271
272 Ok(NegotiatedParams {
273 version: negotiated_version,
274 features: negotiated_features,
275 remote_agent: their_hello.agent,
276 })
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::control::ControlStream;
    use quic_reverse_transport::{mock_connection_pair, Connection};

    /// Both sides agree on the version, on the intersection of their features,
    /// and on each other's agent strings.
    #[tokio::test]
    async fn successful_negotiation() {
        let (conn_client, conn_server) = mock_connection_pair();

        let client_config = Config::new()
            .with_features(Features::PING_PONG | Features::STRUCTURED_METADATA)
            .with_agent("test-client/1.0");

        let server_config = Config::new()
            .with_features(Features::PING_PONG)
            .with_agent("test-server/1.0");

        let client_handle = tokio::spawn(async move {
            let (send, recv) = conn_client.open_bi().await.expect("open");
            let mut control = ControlStream::new(send, recv);
            negotiate_client(&mut control, &client_config).await
        });

        let (send, recv) = conn_server
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");
        let mut server_control = ControlStream::new(send, recv);
        let server_result = negotiate_server(&mut server_control, &server_config).await;

        let client_result = client_handle.await.expect("client task");

        let client_params = client_result.expect("client negotiation");
        let server_params = server_result.expect("server negotiation");

        assert_eq!(client_params.version, PROTOCOL_VERSION);
        assert_eq!(server_params.version, PROTOCOL_VERSION);

        assert_eq!(client_params.features, Features::PING_PONG);
        assert_eq!(server_params.features, Features::PING_PONG);

        assert_eq!(
            client_params.remote_agent.as_deref(),
            Some("test-server/1.0")
        );
        assert_eq!(
            server_params.remote_agent.as_deref(),
            Some("test-client/1.0")
        );
    }

    /// A client that supports no version the server speaks rejects the server's Hello.
    #[tokio::test]
    async fn version_mismatch() {
        let (conn_client, conn_server) = mock_connection_pair();

        let client_config = Config::new().with_versions(vec![99]);
        let server_config = Config::new();

        let (client_send, client_recv) = conn_client.open_bi().await.expect("open");
        let mut client_control = ControlStream::new(client_send, client_recv);

        let (server_send, server_recv) = conn_server
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");
        let mut server_control = ControlStream::new(server_send, server_recv);

        let client_handle =
            tokio::spawn(
                async move { negotiate_client(&mut client_control, &client_config).await },
            );

        let server_handle =
            tokio::spawn(
                async move { negotiate_server(&mut server_control, &server_config).await },
            );

        let client_result = client_handle.await.expect("client task");
        // Compare against PROTOCOL_VERSION instead of a literal so the test survives
        // a protocol version bump.
        // NOTE(review): assumes the server's Hello advertises PROTOCOL_VERSION.
        assert!(
            matches!(
                client_result,
                Err(NegotiationError::VersionMismatch { remote, .. })
                    if remote == PROTOCOL_VERSION
            ),
            "expected version mismatch, got: {client_result:?}"
        );

        // The server never receives a HelloAck; its outcome is not asserted here.
        server_handle.abort();
    }

    /// A server that supports no version the client speaks rejects the client's Hello.
    #[tokio::test]
    async fn server_rejects_unsupported_client_version() {
        let (conn_client, conn_server) = mock_connection_pair();

        let client_config = Config::new();
        let server_config = Config::new().with_versions(vec![99]);

        let (client_send, client_recv) = conn_client.open_bi().await.expect("open");
        let mut client_control = ControlStream::new(client_send, client_recv);

        let (server_send, server_recv) = conn_server
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");
        let mut server_control = ControlStream::new(server_send, server_recv);

        let client_handle =
            tokio::spawn(
                async move { negotiate_client(&mut client_control, &client_config).await },
            );

        let server_handle =
            tokio::spawn(
                async move { negotiate_server(&mut server_control, &server_config).await },
            );

        let server_result = server_handle.await.expect("server task");
        // NOTE(review): assumes the client's Hello advertises PROTOCOL_VERSION.
        assert!(
            matches!(
                server_result,
                Err(NegotiationError::VersionMismatch { remote, .. })
                    if remote == PROTOCOL_VERSION
            ),
            "expected version mismatch, got: {server_result:?}"
        );

        // The client never receives the server's Hello; its outcome is not asserted here.
        client_handle.abort();
    }

    /// Disjoint feature sets negotiate to an empty feature set rather than failing.
    #[tokio::test]
    async fn no_common_features_still_succeeds() {
        let (conn_client, conn_server) = mock_connection_pair();

        let client_config = Config::new().with_features(Features::PING_PONG);
        let server_config = Config::new().with_features(Features::STRUCTURED_METADATA);

        let client_handle = tokio::spawn(async move {
            let (send, recv) = conn_client.open_bi().await.expect("open");
            let mut control = ControlStream::new(send, recv);
            negotiate_client(&mut control, &client_config).await
        });

        let (send, recv) = conn_server
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");
        let mut server_control = ControlStream::new(send, recv);
        let server_result = negotiate_server(&mut server_control, &server_config).await;

        let client_result = client_handle.await.expect("client task");

        let client_params = client_result.expect("client negotiation");
        let server_params = server_result.expect("server negotiation");

        assert!(client_params.features.is_empty());
        assert!(server_params.features.is_empty());
    }
}