ios_core/tunnel/
handshake.rs1use std::time::Duration;
2
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use crate::tunnel::TunnelError;
6
/// Magic prefix that opens every CDTunnel handshake frame.
const MAGIC: &[u8] = b"CDTunnel";
/// Size of a frame header: the magic followed by a big-endian `u16` body length.
const HEADER_LEN: usize = MAGIC.len() + std::mem::size_of::<u16>();
9
/// Tunnel parameters extracted from the peer's CDTunnel handshake response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TunnelInfo {
    /// Value of the response's `serverAddress` field.
    pub server_address: String,
    /// Value of the response's `serverRSDPort` field (the handshake parser
    /// rejects zero).
    pub server_rsd_port: u16,
    /// Value of the response's `clientParameters.address` field.
    pub client_address: String,
    /// Value of the response's `clientParameters.mtu` field (the handshake
    /// parser rejects zero).
    pub client_mtu: u32,
}
18
19fn parse_nonzero_u16(raw: &serde_json::Value, field: &str) -> Result<u16, TunnelError> {
20 let value = raw
21 .as_u64()
22 .ok_or_else(|| TunnelError::Protocol(format!("missing {field}")))?;
23 u16::try_from(value)
24 .ok()
25 .filter(|value| *value != 0)
26 .ok_or_else(|| {
27 TunnelError::Protocol(format!(
28 "invalid {field}: expected integer in 1..={}",
29 u16::MAX
30 ))
31 })
32}
33
34fn parse_nonzero_u32(raw: &serde_json::Value, field: &str) -> Result<u32, TunnelError> {
35 let value = raw
36 .as_u64()
37 .ok_or_else(|| TunnelError::Protocol(format!("missing {field}")))?;
38 u32::try_from(value)
39 .ok()
40 .filter(|value| *value != 0)
41 .ok_or_else(|| {
42 TunnelError::Protocol(format!(
43 "invalid {field}: expected integer in 1..={}",
44 u32::MAX
45 ))
46 })
47}
48
49pub fn encode_handshake_request(mtu: u32) -> Result<Vec<u8>, TunnelError> {
50 let json = serde_json::json!({
51 "type": "clientHandshakeRequest",
52 "mtu": mtu,
53 });
54 let json_bytes = serde_json::to_vec(&json)
55 .map_err(|e| TunnelError::Protocol(format!("failed to serialize handshake: {e}")))?;
56 if json_bytes.len() > u16::MAX as usize {
57 return Err(TunnelError::Protocol(
58 "handshake JSON exceeds 65535 bytes".into(),
59 ));
60 }
61 let mut buf = Vec::new();
62 buf.extend_from_slice(MAGIC);
63 buf.extend_from_slice(&(json_bytes.len() as u16).to_be_bytes());
64 buf.extend_from_slice(&json_bytes);
65 Ok(buf)
66}
67
68pub async fn exchange_tunnel_parameters_with_timeout<S>(
69 stream: &mut S,
70 timeout: Duration,
71) -> Result<TunnelInfo, TunnelError>
72where
73 S: AsyncRead + AsyncWrite + Unpin,
74{
75 tokio::time::timeout(timeout, exchange_tunnel_parameters_inner(stream))
76 .await
77 .map_err(|_| {
78 TunnelError::Protocol(format!(
79 "CDTunnel handshake timed out after {} ms",
80 timeout.as_millis()
81 ))
82 })?
83}
84
85async fn exchange_tunnel_parameters_inner<S>(stream: &mut S) -> Result<TunnelInfo, TunnelError>
86where
87 S: AsyncRead + AsyncWrite + Unpin,
88{
89 let req = encode_handshake_request(1280)?;
90 stream.write_all(&req).await?;
91 stream.flush().await?;
92
93 let mut header = [0u8; HEADER_LEN];
95 stream.read_exact(&mut header).await?;
96
97 if &header[..MAGIC.len()] != MAGIC {
98 return Err(TunnelError::Protocol(format!(
99 "invalid CDTunnel magic: {:?}",
100 &header[..MAGIC.len()]
101 )));
102 }
103
104 let body_len = u16::from_be_bytes([header[8], header[9]]) as usize;
105 let mut body = vec![0u8; body_len];
106 stream.read_exact(&mut body).await?;
107
108 let raw: serde_json::Value = serde_json::from_slice(&body)
109 .map_err(|e| TunnelError::Protocol(format!("invalid CDTunnel JSON: {e}")))?;
110
111 Ok(TunnelInfo {
112 server_address: raw["serverAddress"]
113 .as_str()
114 .ok_or_else(|| TunnelError::Protocol("missing serverAddress".into()))?
115 .to_string(),
116 server_rsd_port: parse_nonzero_u16(&raw["serverRSDPort"], "serverRSDPort")?,
117 client_address: raw["clientParameters"]["address"]
118 .as_str()
119 .ok_or_else(|| TunnelError::Protocol("missing clientParameters.address".into()))?
120 .to_string(),
121 client_mtu: parse_nonzero_u32(&raw["clientParameters"]["mtu"], "clientParameters.mtu")?,
122 })
123}
124
125#[cfg(test)]
126mod tests {
127 use super::*;
128
129 async fn exchange_with_response_json(
130 response_json: serde_json::Value,
131 ) -> Result<TunnelInfo, TunnelError> {
132 let response_bytes = serde_json::to_vec(&response_json).unwrap();
133 let mut response = Vec::new();
134 response.extend_from_slice(b"CDTunnel");
135 response.extend_from_slice(&(response_bytes.len() as u16).to_be_bytes());
136 response.extend_from_slice(&response_bytes);
137
138 let (mut client, mut server) = tokio::io::duplex(4096);
139 tokio::spawn(async move {
140 let mut buf = vec![0u8; 256];
141 let _ = server.read(&mut buf).await;
142 server.write_all(&response).await.unwrap();
143 });
144
145 exchange_tunnel_parameters_with_timeout(&mut client, Duration::from_secs(5)).await
146 }
147
148 #[test]
149 fn test_encode_cdtunnel_request() {
150 let bytes = encode_handshake_request(1280).unwrap();
151 assert_eq!(&bytes[..8], b"CDTunnel");
152 let json_len = u16::from_be_bytes([bytes[8], bytes[9]]) as usize;
153 let json: serde_json::Value = serde_json::from_slice(&bytes[10..10 + json_len]).unwrap();
154 assert_eq!(json["type"], "clientHandshakeRequest");
155 assert_eq!(json["mtu"], 1280);
156 }
157
158 #[tokio::test]
159 async fn test_exchange_tunnel_parameters_with_16_bit_length() {
160 let response_json = serde_json::json!({
161 "serverAddress": "fd59:2381:6956::1",
162 "serverRSDPort": 58783u16,
163 "clientParameters": {
164 "address": "fd59:2381:6956::2",
165 "mtu": 1280u32,
166 "padding": "x".repeat(300),
167 }
168 });
169 let response_bytes = serde_json::to_vec(&response_json).unwrap();
170 assert!(response_bytes.len() > 255);
171
172 let params = exchange_with_response_json(response_json).await.unwrap();
173 assert_eq!(params.server_address, "fd59:2381:6956::1");
174 assert_eq!(params.server_rsd_port, 58783);
175 assert_eq!(params.client_address, "fd59:2381:6956::2");
176 assert_eq!(params.client_mtu, 1280);
177 }
178
179 #[tokio::test]
180 async fn test_exchange_tunnel_parameters_rejects_zero_server_rsd_port() {
181 let err = exchange_with_response_json(serde_json::json!({
182 "serverAddress": "fd59:2381:6956::1",
183 "serverRSDPort": 0,
184 "clientParameters": {
185 "address": "fd59:2381:6956::2",
186 "mtu": 1280,
187 }
188 }))
189 .await
190 .unwrap_err();
191
192 match err {
193 TunnelError::Protocol(message) => {
194 assert!(
195 message.contains("invalid serverRSDPort"),
196 "unexpected error: {message}"
197 );
198 }
199 other => panic!("unexpected error variant: {other:?}"),
200 }
201 }
202
203 #[tokio::test]
204 async fn test_exchange_tunnel_parameters_rejects_out_of_range_server_rsd_port() {
205 let err = exchange_with_response_json(serde_json::json!({
206 "serverAddress": "fd59:2381:6956::1",
207 "serverRSDPort": 65536u64,
208 "clientParameters": {
209 "address": "fd59:2381:6956::2",
210 "mtu": 1280,
211 }
212 }))
213 .await
214 .unwrap_err();
215
216 match err {
217 TunnelError::Protocol(message) => {
218 assert!(
219 message.contains("invalid serverRSDPort"),
220 "unexpected error: {message}"
221 );
222 }
223 other => panic!("unexpected error variant: {other:?}"),
224 }
225 }
226
227 #[tokio::test]
228 async fn test_exchange_tunnel_parameters_rejects_zero_client_mtu() {
229 let err = exchange_with_response_json(serde_json::json!({
230 "serverAddress": "fd59:2381:6956::1",
231 "serverRSDPort": 58783,
232 "clientParameters": {
233 "address": "fd59:2381:6956::2",
234 "mtu": 0,
235 }
236 }))
237 .await
238 .unwrap_err();
239
240 match err {
241 TunnelError::Protocol(message) => {
242 assert!(
243 message.contains("invalid clientParameters.mtu"),
244 "unexpected error: {message}"
245 );
246 }
247 other => panic!("unexpected error variant: {other:?}"),
248 }
249 }
250
251 #[tokio::test]
252 async fn test_exchange_tunnel_parameters_rejects_out_of_range_client_mtu() {
253 let err = exchange_with_response_json(serde_json::json!({
254 "serverAddress": "fd59:2381:6956::1",
255 "serverRSDPort": 58783,
256 "clientParameters": {
257 "address": "fd59:2381:6956::2",
258 "mtu": 4294967296u64,
259 }
260 }))
261 .await
262 .unwrap_err();
263
264 match err {
265 TunnelError::Protocol(message) => {
266 assert!(
267 message.contains("invalid clientParameters.mtu"),
268 "unexpected error: {message}"
269 );
270 }
271 other => panic!("unexpected error variant: {other:?}"),
272 }
273 }
274
275 #[tokio::test]
276 async fn test_exchange_tunnel_parameters_timeout() {
277 let (mut client, _server) = tokio::io::duplex(4096);
278 let err = exchange_tunnel_parameters_with_timeout(&mut client, Duration::from_millis(20))
279 .await
280 .unwrap_err();
281 match err {
282 TunnelError::Protocol(message) => {
283 assert!(message.contains("timed out"), "unexpected error: {message}");
284 }
285 other => panic!("unexpected error variant: {other:?}"),
286 }
287 }
288}