quic_reverse/
negotiation.rs

1// Copyright 2024-2026 Farlight Networks, LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Protocol negotiation.
16//!
17//! Handles the Hello/HelloAck handshake to establish protocol version
18//! and feature set.
19
20use crate::control::ControlStream;
21use crate::error::NegotiationError;
22use crate::Config;
23use quic_reverse_control::{Features, Hello, HelloAck, ProtocolMessage, PROTOCOL_VERSION};
24use quic_reverse_transport::{RecvStream, SendStream};
25use tracing::{debug, instrument, trace, warn};
26
27/// Result of successful negotiation.
28#[derive(Debug, Clone)]
29pub struct NegotiatedParams {
30    /// The negotiated protocol version.
31    pub version: u16,
32    /// The negotiated feature set (intersection of both peers).
33    pub features: Features,
34    /// The remote peer's agent string, if provided.
35    pub remote_agent: Option<String>,
36}
37
38/// Performs the client-side negotiation.
39///
40/// The client sends `Hello` first, then waits for the server's `Hello`,
41/// then sends `HelloAck`.
42#[instrument(skip_all, name = "negotiate_client")]
43pub async fn negotiate_client<S: SendStream, R: RecvStream>(
44    control: &mut ControlStream<S, R>,
45    config: &Config,
46) -> Result<NegotiatedParams, NegotiationError> {
47    // Build our Hello
48    let mut our_hello = Hello::new(config.features);
49    if let Some(ref agent) = config.agent {
50        our_hello = our_hello.with_agent(agent.clone());
51    }
52
53    // Send our Hello
54    trace!(version = PROTOCOL_VERSION, features = ?config.features, "sending Hello");
55    control
56        .write_message(&ProtocolMessage::Hello(our_hello))
57        .await
58        .map_err(|_| NegotiationError::Timeout)?;
59    control
60        .flush()
61        .await
62        .map_err(|_| NegotiationError::Timeout)?;
63
64    // Wait for server's Hello
65    let their_hello = match control.read_message().await {
66        Ok(Some(ProtocolMessage::Hello(h))) => {
67            trace!(
68                version = h.protocol_version,
69                features = ?h.features,
70                agent = ?h.agent,
71                "received server Hello"
72            );
73            h
74        }
75        Ok(Some(_)) => {
76            warn!("received unexpected message during negotiation");
77            return Err(NegotiationError::UnexpectedMessage);
78        }
79        Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
80    };
81
82    // Validate version compatibility
83    if !config
84        .supported_versions
85        .contains(&their_hello.protocol_version)
86    {
87        warn!(
88            local = ?config.supported_versions,
89            remote = their_hello.protocol_version,
90            "version mismatch"
91        );
92        return Err(NegotiationError::VersionMismatch {
93            local: config.supported_versions.clone(),
94            remote: their_hello.protocol_version,
95        });
96    }
97
98    // Compute negotiated parameters
99    let negotiated_version = their_hello.protocol_version.min(PROTOCOL_VERSION);
100    let negotiated_features = config.features & their_hello.features;
101
102    // Send HelloAck
103    trace!(version = negotiated_version, features = ?negotiated_features, "sending HelloAck");
104    let ack = HelloAck {
105        selected_version: negotiated_version,
106        selected_features: negotiated_features,
107    };
108    control
109        .write_message(&ProtocolMessage::HelloAck(ack))
110        .await
111        .map_err(|_| NegotiationError::Timeout)?;
112    control
113        .flush()
114        .await
115        .map_err(|_| NegotiationError::Timeout)?;
116
117    // Wait for their HelloAck to confirm
118    match control.read_message().await {
119        Ok(Some(ProtocolMessage::HelloAck(their_ack))) => {
120            trace!(
121                version = their_ack.selected_version,
122                features = ?their_ack.selected_features,
123                "received server HelloAck"
124            );
125            // Verify consistency
126            if their_ack.selected_version != negotiated_version {
127                warn!(
128                    expected = negotiated_version,
129                    received = their_ack.selected_version,
130                    "server HelloAck version mismatch"
131                );
132                return Err(NegotiationError::VersionMismatch {
133                    local: vec![negotiated_version],
134                    remote: their_ack.selected_version,
135                });
136            }
137        }
138        Ok(Some(_)) => {
139            warn!("received unexpected message instead of HelloAck");
140            return Err(NegotiationError::UnexpectedMessage);
141        }
142        Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
143    }
144
145    debug!(
146        version = negotiated_version,
147        features = ?negotiated_features,
148        remote_agent = ?their_hello.agent,
149        "client negotiation complete"
150    );
151
152    Ok(NegotiatedParams {
153        version: negotiated_version,
154        features: negotiated_features,
155        remote_agent: their_hello.agent,
156    })
157}
158
159/// Performs the server-side negotiation.
160///
161/// The server waits for the client's `Hello`, sends its own `Hello`,
162/// waits for `HelloAck`, then sends its own `HelloAck`.
163#[instrument(skip_all, name = "negotiate_server")]
164pub async fn negotiate_server<S: SendStream, R: RecvStream>(
165    control: &mut ControlStream<S, R>,
166    config: &Config,
167) -> Result<NegotiatedParams, NegotiationError> {
168    // Wait for client's Hello
169    let their_hello = match control.read_message().await {
170        Ok(Some(ProtocolMessage::Hello(h))) => {
171            trace!(
172                version = h.protocol_version,
173                features = ?h.features,
174                agent = ?h.agent,
175                "received client Hello"
176            );
177            h
178        }
179        Ok(Some(_)) => {
180            warn!("received unexpected message during negotiation");
181            return Err(NegotiationError::UnexpectedMessage);
182        }
183        Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
184    };
185
186    // Validate version compatibility
187    if !config
188        .supported_versions
189        .contains(&their_hello.protocol_version)
190    {
191        warn!(
192            local = ?config.supported_versions,
193            remote = their_hello.protocol_version,
194            "version mismatch"
195        );
196        return Err(NegotiationError::VersionMismatch {
197            local: config.supported_versions.clone(),
198            remote: their_hello.protocol_version,
199        });
200    }
201
202    // Build and send our Hello
203    trace!(version = PROTOCOL_VERSION, features = ?config.features, "sending Hello");
204    let mut our_hello = Hello::new(config.features);
205    if let Some(ref agent) = config.agent {
206        our_hello = our_hello.with_agent(agent.clone());
207    }
208    control
209        .write_message(&ProtocolMessage::Hello(our_hello))
210        .await
211        .map_err(|_| NegotiationError::Timeout)?;
212    control
213        .flush()
214        .await
215        .map_err(|_| NegotiationError::Timeout)?;
216
217    // Wait for their HelloAck
218    let their_ack = match control.read_message().await {
219        Ok(Some(ProtocolMessage::HelloAck(ack))) => {
220            trace!(
221                version = ack.selected_version,
222                features = ?ack.selected_features,
223                "received client HelloAck"
224            );
225            ack
226        }
227        Ok(Some(_)) => {
228            warn!("received unexpected message instead of HelloAck");
229            return Err(NegotiationError::UnexpectedMessage);
230        }
231        Ok(None) | Err(_) => return Err(NegotiationError::Timeout),
232    };
233
234    // Verify the selected parameters make sense
235    let negotiated_version = their_hello.protocol_version.min(PROTOCOL_VERSION);
236    let negotiated_features = config.features & their_hello.features;
237
238    if their_ack.selected_version != negotiated_version {
239        warn!(
240            expected = negotiated_version,
241            received = their_ack.selected_version,
242            "client HelloAck version mismatch"
243        );
244        return Err(NegotiationError::VersionMismatch {
245            local: vec![negotiated_version],
246            remote: their_ack.selected_version,
247        });
248    }
249
250    // Send our HelloAck
251    trace!(version = negotiated_version, features = ?negotiated_features, "sending HelloAck");
252    let our_ack = HelloAck {
253        selected_version: negotiated_version,
254        selected_features: negotiated_features,
255    };
256    control
257        .write_message(&ProtocolMessage::HelloAck(our_ack))
258        .await
259        .map_err(|_| NegotiationError::Timeout)?;
260    control
261        .flush()
262        .await
263        .map_err(|_| NegotiationError::Timeout)?;
264
265    debug!(
266        version = negotiated_version,
267        features = ?negotiated_features,
268        remote_agent = ?their_hello.agent,
269        "server negotiation complete"
270    );
271
272    Ok(NegotiatedParams {
273        version: negotiated_version,
274        features: negotiated_features,
275        remote_agent: their_hello.agent,
276    })
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::control::ControlStream;
    use quic_reverse_transport::{mock_connection_pair, Connection};

    /// Full handshake over a mock connection: both sides must end up with
    /// the same version, the shared feature subset, and each other's agent.
    #[tokio::test]
    async fn successful_negotiation() {
        let (client_conn, server_conn) = mock_connection_pair();

        let client_config = Config::new()
            .with_features(Features::PING_PONG | Features::STRUCTURED_METADATA)
            .with_agent("test-client/1.0");

        let server_config = Config::new()
            .with_features(Features::PING_PONG)
            .with_agent("test-server/1.0");

        // The client opens the control stream and negotiates on its own task.
        let client_task = tokio::spawn(async move {
            let (tx, rx) = client_conn.open_bi().await.expect("open");
            let mut stream = ControlStream::new(tx, rx);
            negotiate_client(&mut stream, &client_config).await
        });

        // The server side runs on the test task itself.
        let (tx, rx) = server_conn
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");
        let mut stream = ControlStream::new(tx, rx);
        let server_outcome = negotiate_server(&mut stream, &server_config).await;

        let client_outcome = client_task.await.expect("client task");

        // Neither side may report an error.
        let client = client_outcome.expect("client negotiation");
        let server = server_outcome.expect("server negotiation");

        // Same version on both ends.
        assert_eq!(
            (client.version, server.version),
            (PROTOCOL_VERSION, PROTOCOL_VERSION)
        );

        // Only PING_PONG is common to the two configurations.
        assert_eq!(
            (client.features, server.features),
            (Features::PING_PONG, Features::PING_PONG)
        );

        // Each side learns the other's agent string.
        assert_eq!(client.remote_agent.as_deref(), Some("test-server/1.0"));
        assert_eq!(server.remote_agent.as_deref(), Some("test-client/1.0"));
    }

    /// A client whose accepted-version list excludes the version carried in
    /// the server's Hello must fail with `VersionMismatch`.
    ///
    /// The version placed in a Hello is not configurable here; only the list
    /// used to validate the received version is (`with_versions`).
    #[tokio::test]
    async fn version_mismatch() {
        let (client_conn, server_conn) = mock_connection_pair();

        // The client accepts nothing but version 99; the server keeps the
        // default configuration.
        let client_config = Config::new().with_versions(vec![99]);
        let server_config = Config::new();

        // Open on the client side first, then accept on the server side.
        let (client_tx, client_rx) = client_conn.open_bi().await.expect("open");
        let (server_tx, server_rx) = server_conn
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");

        let mut client_stream = ControlStream::new(client_tx, client_rx);
        let mut server_stream = ControlStream::new(server_tx, server_rx);

        // Drive the two halves concurrently.
        let client_task = tokio::spawn(async move {
            negotiate_client(&mut client_stream, &client_config).await
        });
        let server_task = tokio::spawn(async move {
            negotiate_server(&mut server_stream, &server_config).await
        });

        // The server's Hello carries version 1, which the client rejects.
        let client_outcome = client_task.await.expect("client task");
        let rejected = matches!(
            client_outcome,
            Err(NegotiationError::VersionMismatch { remote: 1, .. })
        );
        assert!(
            rejected,
            "expected version mismatch, got: {client_outcome:?}"
        );

        // The server half is not expected to finish on its own; cancel it.
        server_task.abort();
    }

    /// Disjoint feature sets are not an error: the handshake completes with
    /// an empty negotiated feature set on both sides.
    #[tokio::test]
    async fn no_common_features_still_succeeds() {
        let (client_conn, server_conn) = mock_connection_pair();

        // The two configurations share no feature.
        let client_config = Config::new().with_features(Features::PING_PONG);
        let server_config = Config::new().with_features(Features::STRUCTURED_METADATA);

        let client_task = tokio::spawn(async move {
            let (tx, rx) = client_conn.open_bi().await.expect("open");
            let mut stream = ControlStream::new(tx, rx);
            negotiate_client(&mut stream, &client_config).await
        });

        let (tx, rx) = server_conn
            .accept_bi()
            .await
            .expect("accept")
            .expect("stream");
        let mut stream = ControlStream::new(tx, rx);
        let server_outcome = negotiate_server(&mut stream, &server_config).await;

        let client_outcome = client_task.await.expect("client task");

        // Both halves succeed and agree on "no features".
        let client = client_outcome.expect("client negotiation");
        let server = server_outcome.expect("server negotiation");

        assert!(client.features.is_empty());
        assert!(server.features.is_empty());
    }
}