1use std::net::Ipv4Addr;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
8use crate::{ProtocolError, Result};
9
10fn validate_ipv4(s: &str) -> Result<()> {
12 s.parse::<Ipv4Addr>()
13 .map_err(|_| ProtocolError::InvalidPacket(format!("invalid IPv4 address: {}", s)))?;
14 Ok(())
15}
16
/// A decoded control-channel message.
///
/// NOTE(review): the wire framing of each variant is not defined in this
/// section — confirm against the code that produces/consumes these values.
#[derive(Debug, Clone)]
pub enum ControlMessage {
    /// Opaque payload bytes (presumably TLS records carried on the control
    /// channel — confirm).
    TlsData(Bytes),
    /// A push request; carries no payload.
    PushRequest,
    /// Pushed configuration, see [`PushReply`].
    PushReply(PushReply),
    /// Username/password credentials, see [`AuthMessage`].
    Auth(AuthMessage),
    /// Free-form text message.
    Info(String),
    /// Exit message; carries no payload.
    Exit,
}
33
/// A control-channel message paired with its packet identifier.
#[derive(Debug, Clone)]
pub struct ControlPacket {
    /// Packet identifier. NOTE(review): presumably used by a reliability /
    /// acknowledgement layer — confirm with the caller.
    pub packet_id: u32,
    /// The decoded message body.
    pub message: ControlMessage,
}
42
43impl ControlPacket {
44 pub fn new(packet_id: u32, message: ControlMessage) -> Self {
46 Self { packet_id, message }
47 }
48}
49
/// Configuration carried by a `PUSH_REPLY` control message.
///
/// Field order matters for the derived serde representation; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushReply {
    /// `route` directives, in the order they were parsed or added.
    pub routes: Vec<PushRoute>,
    /// Arguments of the `ifconfig` directive: `(local address, second argument)`.
    /// The encoder and parser call the second element a mask.
    /// NOTE(review): in non-subnet topologies this argument may be a peer
    /// address rather than a netmask — confirm.
    pub ifconfig: Option<(String, String)>,
    /// Argument of the `ifconfig-ipv6` directive (the parser keeps only its
    /// first token).
    pub ifconfig_ipv6: Option<String>,
    /// `dhcp-option DNS` server addresses.
    pub dns: Vec<String>,
    /// `dhcp-option DOMAIN` values.
    pub dns_search: Vec<String>,
    /// Whether to emit `redirect-gateway def1` when encoding.
    pub redirect_gateway: bool,
    /// Argument of the `route-gateway` directive.
    pub route_gateway: Option<String>,
    /// Value of the `topology` directive (always emitted when encoding).
    pub topology: Topology,
    /// Value of the `ping` directive (always emitted; `Default` uses 10).
    pub ping: u32,
    /// Value of the `ping-restart` directive (always emitted; `Default` uses 60).
    pub ping_restart: u32,
    /// Directives kept verbatim and re-emitted after the recognized ones.
    pub options: Vec<String>,
}
76
77impl Default for PushReply {
78 fn default() -> Self {
79 Self {
80 routes: vec![],
81 ifconfig: None,
82 ifconfig_ipv6: None,
83 dns: vec![],
84 dns_search: vec![],
85 redirect_gateway: false,
86 route_gateway: None,
87 topology: Topology::Subnet,
88 ping: 10,
89 ping_restart: 60,
90 options: vec![],
91 }
92 }
93}
94
95impl PushReply {
96 pub fn encode(&self) -> String {
98 let mut parts = vec!["PUSH_REPLY".to_string()];
99
100 parts.push(format!("topology {}", self.topology.as_str()));
102
103 if let Some((ip, mask)) = &self.ifconfig {
105 parts.push(format!("ifconfig {} {}", ip, mask));
106 }
107
108 if let Some(ipv6) = &self.ifconfig_ipv6 {
110 parts.push(format!("ifconfig-ipv6 {}", ipv6));
111 }
112
113 for route in &self.routes {
115 parts.push(route.encode());
116 }
117
118 if let Some(gw) = &self.route_gateway {
120 parts.push(format!("route-gateway {}", gw));
121 }
122
123 if self.redirect_gateway {
125 parts.push("redirect-gateway def1".to_string());
126 }
127
128 for dns in &self.dns {
130 parts.push(format!("dhcp-option DNS {}", dns));
131 }
132
133 for domain in &self.dns_search {
135 parts.push(format!("dhcp-option DOMAIN {}", domain));
136 }
137
138 parts.push(format!("ping {}", self.ping));
140 parts.push(format!("ping-restart {}", self.ping_restart));
141
142 for opt in &self.options {
144 parts.push(opt.clone());
145 }
146
147 parts.join(",")
148 }
149
150 pub fn parse(s: &str) -> Result<Self> {
152 let mut reply = Self::default();
153
154 let s = s.strip_prefix("PUSH_REPLY,").unwrap_or(s);
156
157 for part in s.split(',') {
158 let part = part.trim();
159 if part.is_empty() {
160 continue;
161 }
162
163 let mut tokens = part.split_whitespace();
164 match tokens.next() {
165 Some("topology") => {
166 if let Some(topo) = tokens.next() {
167 reply.topology = Topology::parse(topo);
168 }
169 }
170 Some("ifconfig") => {
171 let ip = tokens.next().unwrap_or("").to_string();
172 let mask = tokens.next().unwrap_or("").to_string();
173 reply.ifconfig = Some((ip, mask));
174 }
175 Some("ifconfig-ipv6") => {
176 if let Some(ipv6) = tokens.next() {
177 reply.ifconfig_ipv6 = Some(ipv6.to_string());
178 }
179 }
180 Some("route") => {
181 if let Ok(route) = PushRoute::parse(part) {
182 reply.routes.push(route);
183 }
184 }
185 Some("route-gateway") => {
186 if let Some(gw) = tokens.next() {
187 reply.route_gateway = Some(gw.to_string());
188 }
189 }
190 Some("redirect-gateway") => {
191 reply.redirect_gateway = true;
192 }
193 Some("dhcp-option") => {
194 match tokens.next() {
195 Some("DNS") => {
196 if let Some(dns) = tokens.next() {
197 reply.dns.push(dns.to_string());
198 }
199 }
200 Some("DOMAIN") => {
201 if let Some(domain) = tokens.next() {
202 reply.dns_search.push(domain.to_string());
203 }
204 }
205 _ => {}
206 }
207 }
208 Some("ping") => {
209 if let Some(Ok(p)) = tokens.next().map(|s| s.parse()) {
210 reply.ping = p;
211 }
212 }
213 Some("ping-restart") => {
214 if let Some(Ok(p)) = tokens.next().map(|s| s.parse()) {
215 reply.ping_restart = p;
216 }
217 }
218 _ => {
219 reply.options.push(part.to_string());
220 }
221 }
222 }
223
224 Ok(reply)
225 }
226}
227
/// A single pushed `route` directive.
///
/// Field order matters for the derived serde representation; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushRoute {
    /// Destination network (an IPv4 literal when produced by `PushRoute::parse`).
    pub network: String,
    /// Netmask for `network` (an IPv4 literal when produced by `PushRoute::parse`).
    pub netmask: String,
    /// Gateway token; `None` is encoded as `vpn_gateway` by `PushRoute::encode`.
    pub gateway: Option<String>,
    /// Optional route metric, appended after the gateway when encoding.
    pub metric: Option<u32>,
}
240
241impl PushRoute {
242 pub fn new(network: &str, netmask: &str) -> Self {
244 Self {
245 network: network.to_string(),
246 netmask: netmask.to_string(),
247 gateway: None,
248 metric: None,
249 }
250 }
251
252 pub fn encode(&self) -> String {
254 let mut s = format!("route {} {}", self.network, self.netmask);
255 if let Some(gw) = &self.gateway {
256 s.push_str(&format!(" {}", gw));
257 } else {
258 s.push_str(" vpn_gateway");
259 }
260 if let Some(metric) = self.metric {
261 s.push_str(&format!(" {}", metric));
262 }
263 s
264 }
265
266 pub fn parse(s: &str) -> Result<Self> {
268 let mut tokens = s.split_whitespace();
269 tokens.next(); let network_str = tokens
272 .next()
273 .ok_or_else(|| ProtocolError::InvalidPacket("missing network in route".into()))?;
274
275 validate_ipv4(network_str)?;
277 let network = network_str.to_string();
278
279 let netmask_str = tokens
280 .next()
281 .ok_or_else(|| ProtocolError::InvalidPacket("missing netmask in route".into()))?;
282
283 validate_ipv4(netmask_str)?;
285 let netmask = netmask_str.to_string();
286
287 let gateway = tokens.next().and_then(|g| {
288 if g == "vpn_gateway" {
289 None
290 } else {
291 validate_ipv4(g).ok().map(|_| g.to_string())
293 }
294 });
295
296 let metric = tokens.next().and_then(|m| {
297 m.parse::<u32>().ok().filter(|&m| m <= 9999) });
299
300 Ok(Self {
301 network,
302 netmask,
303 gateway,
304 metric,
305 })
306 }
307}
308
/// Interface topology named by the `topology` push directive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Topology {
    /// Spelled `net30`.
    Net30,
    /// Spelled `p2p`.
    P2P,
    /// Spelled `subnet`. This is the `Default` variant and the fallback that
    /// `Topology::parse` returns for unrecognized names.
    #[default]
    Subnet,
}
320
321impl Topology {
322 pub fn parse(s: &str) -> Self {
324 match s.to_lowercase().as_str() {
325 "net30" => Topology::Net30,
326 "p2p" => Topology::P2P,
327 "subnet" => Topology::Subnet,
328 _ => Topology::Subnet,
329 }
330 }
331
332 pub fn as_str(&self) -> &'static str {
334 match self {
335 Topology::Net30 => "net30",
336 Topology::P2P => "p2p",
337 Topology::Subnet => "subnet",
338 }
339 }
340}
341
/// Username/password credentials.
///
/// `Debug` is implemented by hand so the password never reaches logs. That
/// includes logs of any `ControlMessage::Auth` wrapping this value.
#[derive(Clone)]
pub struct AuthMessage {
    /// Account name.
    pub username: String,
    /// Secret; shown as `<redacted>` in `Debug` output.
    pub password: String,
}

impl std::fmt::Debug for AuthMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthMessage")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
350
351impl AuthMessage {
352 const MAX_USERNAME_LEN: usize = 256;
354 const MAX_PASSWORD_LEN: usize = 1024;
356
357 pub fn parse(data: &[u8]) -> Result<Self> {
359 if data.len() > Self::MAX_USERNAME_LEN + Self::MAX_PASSWORD_LEN + 2 {
362 return Err(ProtocolError::InvalidPacket("auth data too long".into()));
363 }
364
365 let s = std::str::from_utf8(data)
366 .map_err(|_| ProtocolError::InvalidPacket("invalid UTF-8 in auth".into()))?;
367
368 let parts: Vec<&str> = s.split('\0').collect();
369 if parts.len() < 2 {
370 return Err(ProtocolError::InvalidPacket("missing auth fields".into()));
371 }
372
373 let username = parts[0];
374 let password = parts[1];
375
376 if username.len() > Self::MAX_USERNAME_LEN {
378 return Err(ProtocolError::InvalidPacket(
379 format!("username too long (max {} bytes)", Self::MAX_USERNAME_LEN).into(),
380 ));
381 }
382 if password.len() > Self::MAX_PASSWORD_LEN {
383 return Err(ProtocolError::InvalidPacket(
384 format!("password too long (max {} bytes)", Self::MAX_PASSWORD_LEN).into(),
385 ));
386 }
387
388 Ok(Self {
389 username: username.to_string(),
390 password: password.to_string(),
391 })
392 }
393
394 pub fn encode(&self) -> Vec<u8> {
396 let mut data = Vec::new();
397 data.extend_from_slice(self.username.as_bytes());
398 data.push(0);
399 data.extend_from_slice(self.password.as_bytes());
400 data.push(0);
401 data
402 }
403}
404
/// Contents of a key-method-2 key-exchange message.
///
/// `Debug` is implemented by hand so that `pre_master` and `password` are
/// never written to logs.
#[derive(Clone)]
pub struct KeyMethodV2 {
    /// 48 bytes of pre-master key material. `KeyMethodV2::encode` omits it
    /// when `is_server` is true.
    pub pre_master: [u8; 48],
    /// First 32-byte random value.
    pub random1: [u8; 32],
    /// Second 32-byte random value.
    pub random2: [u8; 32],
    /// Options string (the first length-prefixed string in the message).
    pub options: String,
    /// Optional username; an empty string on the wire decodes to `None`.
    pub username: Option<String>,
    /// Optional password; an empty string on the wire decodes to `None`.
    /// Redacted in `Debug` output.
    pub password: Option<String>,
    /// Optional peer-info string; written only when present.
    pub peer_info: Option<String>,
}

impl std::fmt::Debug for KeyMethodV2 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyMethodV2")
            .field("pre_master", &"<redacted>")
            .field("random1", &self.random1)
            .field("random2", &self.random2)
            .field("options", &self.options)
            .field("username", &self.username)
            // Keep the Some/None distinction without revealing the secret.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("peer_info", &self.peer_info)
            .finish()
    }
}
423
424impl KeyMethodV2 {
425 pub fn parse(data: &[u8]) -> Result<Self> {
438 if data.len() < 119 {
440 return Err(ProtocolError::PacketTooShort {
441 expected: 119,
442 got: data.len(),
443 });
444 }
445
446 let mut offset = 0;
447
448 offset += 4;
450
451 let key_method = data[offset];
453 offset += 1;
454 if key_method != 2 {
455 return Err(ProtocolError::InvalidPacket(
456 format!("unsupported key method: {}", key_method),
457 ));
458 }
459
460 let mut pre_master = [0u8; 48];
462 pre_master.copy_from_slice(&data[offset..offset + 48]);
463 offset += 48;
464
465 let mut random1 = [0u8; 32];
467 random1.copy_from_slice(&data[offset..offset + 32]);
468 offset += 32;
469
470 let mut random2 = [0u8; 32];
472 random2.copy_from_slice(&data[offset..offset + 32]);
473 offset += 32;
474
475 let options = Self::read_length_prefixed_string(data, &mut offset)?;
477
478 let username = if offset + 2 <= data.len() {
480 let s = Self::read_length_prefixed_string(data, &mut offset)?;
481 if s.is_empty() { None } else { Some(s) }
482 } else {
483 None
484 };
485
486 let password = if offset + 2 <= data.len() {
488 let s = Self::read_length_prefixed_string(data, &mut offset)?;
489 if s.is_empty() { None } else { Some(s) }
490 } else {
491 None
492 };
493
494 let peer_info = if offset + 2 <= data.len() {
496 let s = Self::read_length_prefixed_string(data, &mut offset)?;
497 if s.is_empty() { None } else { Some(s) }
498 } else {
499 None
500 };
501
502 Ok(Self {
503 pre_master,
504 random1,
505 random2,
506 options,
507 username,
508 password,
509 peer_info,
510 })
511 }
512
513 fn read_length_prefixed_string(data: &[u8], offset: &mut usize) -> Result<String> {
515 if *offset + 2 > data.len() {
516 return Err(ProtocolError::PacketTooShort {
517 expected: *offset + 2,
518 got: data.len(),
519 });
520 }
521 let len = u16::from_be_bytes([data[*offset], data[*offset + 1]]) as usize;
522 *offset += 2;
523 if *offset + len > data.len() {
524 return Err(ProtocolError::PacketTooShort {
525 expected: *offset + len,
526 got: data.len(),
527 });
528 }
529 let s = std::str::from_utf8(&data[*offset..*offset + len])
530 .map_err(|_| ProtocolError::InvalidPacket("invalid UTF-8 in key method v2".into()))?;
531 *offset += len;
532 Ok(s.trim_end_matches('\0').to_string())
534 }
535
536 fn write_string(buf: &mut Vec<u8>, s: &str) {
539 let len = s.len() + 1; buf.extend_from_slice(&(len as u16).to_be_bytes());
541 buf.extend_from_slice(s.as_bytes());
542 buf.push(0); }
544
545 pub fn encode(&self, is_server: bool) -> Vec<u8> {
552 let mut buf = Vec::new();
553
554 buf.extend_from_slice(&[0u8; 4]);
556
557 buf.push(2);
559
560 if !is_server {
564 buf.extend_from_slice(&self.pre_master);
565 }
566
567 buf.extend_from_slice(&self.random1);
569
570 buf.extend_from_slice(&self.random2);
572
573 Self::write_string(&mut buf, &self.options);
575
576 if let Some(username) = &self.username {
578 Self::write_string(&mut buf, username);
579 } else {
580 Self::write_string(&mut buf, "");
581 }
582
583 if let Some(password) = &self.password {
585 Self::write_string(&mut buf, password);
586 } else {
587 Self::write_string(&mut buf, "");
588 }
589
590 if let Some(peer_info) = &self.peer_info {
592 Self::write_string(&mut buf, peer_info);
593 }
594
595 buf
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_push_reply_roundtrip() {
        let mut reply = PushReply::default();
        reply.ifconfig = Some(("10.8.0.2".to_string(), "255.255.255.0".to_string()));
        reply.dns.push("1.1.1.1".to_string());
        reply.routes.push(PushRoute::new("192.168.1.0", "255.255.255.0"));
        reply.redirect_gateway = true;

        let encoded = reply.encode();
        assert!(encoded.starts_with("PUSH_REPLY,"));
        let parsed = PushReply::parse(&encoded).unwrap();

        assert_eq!(parsed.ifconfig, reply.ifconfig);
        assert_eq!(parsed.dns, reply.dns);
        assert!(parsed.redirect_gateway);
        assert_eq!(parsed.topology, Topology::Subnet);
        assert_eq!(parsed.ping, 10);
        assert_eq!(parsed.ping_restart, 60);
        assert!(parsed.options.is_empty());

        assert_eq!(parsed.routes.len(), 1);
        let route = &parsed.routes[0];
        assert_eq!(route.network, "192.168.1.0");
        assert_eq!(route.netmask, "255.255.255.0");
        assert_eq!(route.gateway, None);
        assert_eq!(route.metric, None);
    }

    #[test]
    fn test_push_reply_keeps_unknown_directives() {
        let parsed = PushReply::parse("PUSH_REPLY,comp-lzo no,ping 5").unwrap();
        assert_eq!(parsed.options, vec!["comp-lzo no".to_string()]);
        assert_eq!(parsed.ping, 5);
        assert!(parsed.encode().ends_with(",comp-lzo no"));
    }

    #[test]
    fn test_push_route_encode_parse() {
        let mut route = PushRoute::new("10.0.0.0", "255.0.0.0");
        route.gateway = Some("10.8.0.1".to_string());
        route.metric = Some(100);

        let encoded = route.encode();
        assert_eq!(encoded, "route 10.0.0.0 255.0.0.0 10.8.0.1 100");

        let parsed = PushRoute::parse(&encoded).unwrap();
        assert_eq!(parsed.network, "10.0.0.0");
        assert_eq!(parsed.netmask, "255.0.0.0");
        assert_eq!(parsed.gateway.as_deref(), Some("10.8.0.1"));
        assert_eq!(parsed.metric, Some(100));

        assert!(PushRoute::parse("route not-an-ip 255.0.0.0").is_err());
    }

    #[test]
    fn test_topology_parse() {
        assert_eq!(Topology::parse("NET30"), Topology::Net30);
        assert_eq!(Topology::parse("p2p"), Topology::P2P);
        assert_eq!(Topology::parse("bogus"), Topology::Subnet);
        assert_eq!(Topology::Net30.as_str(), "net30");
    }

    #[test]
    fn test_auth_message() {
        let auth = AuthMessage {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let encoded = auth.encode();
        let parsed = AuthMessage::parse(&encoded).unwrap();

        assert_eq!(parsed.username, "user");
        assert_eq!(parsed.password, "pass");
    }

    #[test]
    fn test_auth_message_rejects_bad_input() {
        assert!(AuthMessage::parse(b"user").is_err());
        assert!(AuthMessage::parse(&[0xff, 0x00]).is_err());
        assert!(AuthMessage::parse(&vec![b'a'; 2000]).is_err());
    }

    #[test]
    fn test_key_method_v2_client_roundtrip() {
        let msg = KeyMethodV2 {
            pre_master: [7u8; 48],
            random1: [1u8; 32],
            random2: [2u8; 32],
            options: "V4,dev-type tun".to_string(),
            username: Some("user".to_string()),
            password: None,
            peer_info: Some("IV_VER=2.6".to_string()),
        };

        let encoded = msg.encode(false);
        let parsed = KeyMethodV2::parse(&encoded).unwrap();

        assert_eq!(parsed.pre_master, msg.pre_master);
        assert_eq!(parsed.random1, msg.random1);
        assert_eq!(parsed.random2, msg.random2);
        assert_eq!(parsed.options, msg.options);
        assert_eq!(parsed.username, msg.username);
        assert_eq!(parsed.password, None);
        assert_eq!(parsed.peer_info, msg.peer_info);

        let mut bad_method = encoded.clone();
        bad_method[4] = 1;
        assert!(KeyMethodV2::parse(&bad_method).is_err());
        assert!(KeyMethodV2::parse(&encoded[..100]).is_err());
    }
}