Skip to main content

ios_core/tunnel/
handshake.rs

1use std::time::Duration;
2
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use crate::tunnel::TunnelError;
6
/// Magic prefix that starts every CDTunnel handshake frame, both the request
/// this module sends and the response it reads back.
const MAGIC: &[u8] = b"CDTunnel";
/// Length of a CDTunnel frame header: the magic followed by a big-endian `u16`
/// body length. Derived from `MAGIC` so the two can never drift apart.
const HEADER_LEN: usize = MAGIC.len() + std::mem::size_of::<u16>();
9
/// Information returned by the CDTunnel handshake.
///
/// Every field is validated while parsing the server's JSON response. Both
/// addresses must be present as strings, and both numeric fields must be
/// non-zero and fit their integer type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TunnelInfo {
    /// The response's `serverAddress` string (an IPv6 literal such as
    /// `fd59:2381:6956::1` in the tests; the format is not validated).
    pub server_address: String,
    /// The response's `serverRSDPort`, guaranteed to be in `1..=65535`.
    pub server_rsd_port: u16,
    /// The response's `clientParameters.address` string (not validated beyond
    /// being a string).
    pub client_address: String,
    /// The response's `clientParameters.mtu`, guaranteed to be non-zero.
    pub client_mtu: u32,
}
18
19fn parse_nonzero_u16(raw: &serde_json::Value, field: &str) -> Result<u16, TunnelError> {
20    let value = raw
21        .as_u64()
22        .ok_or_else(|| TunnelError::Protocol(format!("missing {field}")))?;
23    u16::try_from(value)
24        .ok()
25        .filter(|value| *value != 0)
26        .ok_or_else(|| {
27            TunnelError::Protocol(format!(
28                "invalid {field}: expected integer in 1..={}",
29                u16::MAX
30            ))
31        })
32}
33
34fn parse_nonzero_u32(raw: &serde_json::Value, field: &str) -> Result<u32, TunnelError> {
35    let value = raw
36        .as_u64()
37        .ok_or_else(|| TunnelError::Protocol(format!("missing {field}")))?;
38    u32::try_from(value)
39        .ok()
40        .filter(|value| *value != 0)
41        .ok_or_else(|| {
42            TunnelError::Protocol(format!(
43                "invalid {field}: expected integer in 1..={}",
44                u32::MAX
45            ))
46        })
47}
48
49pub fn encode_handshake_request(mtu: u32) -> Result<Vec<u8>, TunnelError> {
50    let json = serde_json::json!({
51        "type": "clientHandshakeRequest",
52        "mtu": mtu,
53    });
54    let json_bytes = serde_json::to_vec(&json)
55        .map_err(|e| TunnelError::Protocol(format!("failed to serialize handshake: {e}")))?;
56    if json_bytes.len() > u16::MAX as usize {
57        return Err(TunnelError::Protocol(
58            "handshake JSON exceeds 65535 bytes".into(),
59        ));
60    }
61    let mut buf = Vec::new();
62    buf.extend_from_slice(MAGIC);
63    buf.extend_from_slice(&(json_bytes.len() as u16).to_be_bytes());
64    buf.extend_from_slice(&json_bytes);
65    Ok(buf)
66}
67
68pub async fn exchange_tunnel_parameters_with_timeout<S>(
69    stream: &mut S,
70    timeout: Duration,
71) -> Result<TunnelInfo, TunnelError>
72where
73    S: AsyncRead + AsyncWrite + Unpin,
74{
75    tokio::time::timeout(timeout, exchange_tunnel_parameters_inner(stream))
76        .await
77        .map_err(|_| {
78            TunnelError::Protocol(format!(
79                "CDTunnel handshake timed out after {} ms",
80                timeout.as_millis()
81            ))
82        })?
83}
84
85async fn exchange_tunnel_parameters_inner<S>(stream: &mut S) -> Result<TunnelInfo, TunnelError>
86where
87    S: AsyncRead + AsyncWrite + Unpin,
88{
89    let req = encode_handshake_request(1280)?;
90    stream.write_all(&req).await?;
91    stream.flush().await?;
92
93    // Response: "CDTunnel" (8 bytes) + body_len (u16, big-endian) = 10 bytes header
94    let mut header = [0u8; HEADER_LEN];
95    stream.read_exact(&mut header).await?;
96
97    if &header[..MAGIC.len()] != MAGIC {
98        return Err(TunnelError::Protocol(format!(
99            "invalid CDTunnel magic: {:?}",
100            &header[..MAGIC.len()]
101        )));
102    }
103
104    let body_len = u16::from_be_bytes([header[8], header[9]]) as usize;
105    let mut body = vec![0u8; body_len];
106    stream.read_exact(&mut body).await?;
107
108    let raw: serde_json::Value = serde_json::from_slice(&body)
109        .map_err(|e| TunnelError::Protocol(format!("invalid CDTunnel JSON: {e}")))?;
110
111    Ok(TunnelInfo {
112        server_address: raw["serverAddress"]
113            .as_str()
114            .ok_or_else(|| TunnelError::Protocol("missing serverAddress".into()))?
115            .to_string(),
116        server_rsd_port: parse_nonzero_u16(&raw["serverRSDPort"], "serverRSDPort")?,
117        client_address: raw["clientParameters"]["address"]
118            .as_str()
119            .ok_or_else(|| TunnelError::Protocol("missing clientParameters.address".into()))?
120            .to_string(),
121        client_mtu: parse_nonzero_u32(&raw["clientParameters"]["mtu"], "clientParameters.mtu")?,
122    })
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response frame: `magic`, big-endian `u16` body length, `body`.
    fn frame(magic: &[u8], body: &[u8]) -> Vec<u8> {
        let len = u16::try_from(body.len()).expect("test body fits in a u16 length");
        let mut out = Vec::with_capacity(magic.len() + 2 + body.len());
        out.extend_from_slice(magic);
        out.extend_from_slice(&len.to_be_bytes());
        out.extend_from_slice(body);
        out
    }

    /// Runs the handshake against an in-memory peer that replies with `response`
    /// verbatim.
    async fn exchange_with_raw_response(response: Vec<u8>) -> Result<TunnelInfo, TunnelError> {
        let (mut client, mut server) = tokio::io::duplex(4096);
        tokio::spawn(async move {
            // Consume the client's request before replying; its contents are
            // covered by `test_encode_cdtunnel_request`.
            let mut buf = vec![0u8; 256];
            let _ = server.read(&mut buf).await;
            server.write_all(&response).await.unwrap();
        });

        exchange_tunnel_parameters_with_timeout(&mut client, Duration::from_secs(5)).await
    }

    /// Frames `response_json` as a well-formed CDTunnel response and runs the
    /// handshake against it.
    async fn exchange_with_response_json(
        response_json: serde_json::Value,
    ) -> Result<TunnelInfo, TunnelError> {
        let response_bytes = serde_json::to_vec(&response_json).unwrap();
        exchange_with_raw_response(frame(b"CDTunnel", &response_bytes)).await
    }

    /// Asserts that `err` is a protocol error whose message contains `needle`.
    fn assert_protocol_error(err: TunnelError, needle: &str) {
        match err {
            TunnelError::Protocol(message) => {
                assert!(message.contains(needle), "unexpected error: {message}");
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn test_encode_cdtunnel_request() {
        let bytes = encode_handshake_request(1280).unwrap();
        assert_eq!(&bytes[..8], b"CDTunnel");
        let json_len = u16::from_be_bytes([bytes[8], bytes[9]]) as usize;
        assert_eq!(bytes.len(), 10 + json_len);
        let json: serde_json::Value = serde_json::from_slice(&bytes[10..10 + json_len]).unwrap();
        assert_eq!(json["type"], "clientHandshakeRequest");
        assert_eq!(json["mtu"], 1280);
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_with_16_bit_length() {
        let response_json = serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 58783u16,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 1280u32,
                "padding": "x".repeat(300),
            }
        });
        let response_bytes = serde_json::to_vec(&response_json).unwrap();
        assert!(response_bytes.len() > 255);

        let params = exchange_with_response_json(response_json).await.unwrap();
        assert_eq!(params.server_address, "fd59:2381:6956::1");
        assert_eq!(params.server_rsd_port, 58783);
        assert_eq!(params.client_address, "fd59:2381:6956::2");
        assert_eq!(params.client_mtu, 1280);
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_bad_magic() {
        let body = serde_json::to_vec(&serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 58783,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 1280,
            }
        }))
        .unwrap();
        let err = exchange_with_raw_response(frame(b"CDTunneX", &body))
            .await
            .unwrap_err();
        assert_protocol_error(err, "invalid CDTunnel magic");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_malformed_json() {
        let err = exchange_with_raw_response(frame(b"CDTunnel", b"not json"))
            .await
            .unwrap_err();
        assert_protocol_error(err, "invalid CDTunnel JSON");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_missing_server_address() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverRSDPort": 58783,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 1280,
            }
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "missing serverAddress");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_missing_server_rsd_port() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 1280,
            }
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "missing serverRSDPort");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_missing_client_parameters() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 58783,
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "missing clientParameters.address");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_zero_server_rsd_port() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 0,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 1280,
            }
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "invalid serverRSDPort");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_out_of_range_server_rsd_port() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 65536u64,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 1280,
            }
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "invalid serverRSDPort");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_zero_client_mtu() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 58783,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 0,
            }
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "invalid clientParameters.mtu");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_rejects_out_of_range_client_mtu() {
        let err = exchange_with_response_json(serde_json::json!({
            "serverAddress": "fd59:2381:6956::1",
            "serverRSDPort": 58783,
            "clientParameters": {
                "address": "fd59:2381:6956::2",
                "mtu": 4294967296u64,
            }
        }))
        .await
        .unwrap_err();
        assert_protocol_error(err, "invalid clientParameters.mtu");
    }

    #[tokio::test]
    async fn test_exchange_tunnel_parameters_timeout() {
        // Keep the server half alive but silent so the client blocks on its read.
        let (mut client, _server) = tokio::io::duplex(4096);
        let err = exchange_tunnel_parameters_with_timeout(&mut client, Duration::from_millis(20))
            .await
            .unwrap_err();
        assert_protocol_error(err, "timed out");
    }
}