1use libc::c_char;
2
3use crate::{backends, key::Key, Backend, KeyPair, PeerConfigBuilder};
4
5use std::{
6 borrow::Cow,
7 ffi::CStr,
8 fmt, io,
9 net::{IpAddr, SocketAddr},
10 str::FromStr,
11 time::SystemTime,
12};
13
/// One allowed-IP entry: a network address paired with its prefix length.
#[derive(Clone, Eq, PartialEq)]
pub struct AllowedIp {
    /// Network address part of the entry (before the `/`).
    pub address: IpAddr,
    /// Prefix length in bits (after the `/`).
    pub cidr: u8,
}
22
23impl fmt::Debug for AllowedIp {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 write!(f, "{}/{}", self.address, self.cidr)
26 }
27}
28
29impl std::str::FromStr for AllowedIp {
30 type Err = ();
31
32 fn from_str(s: &str) -> Result<Self, Self::Err> {
33 let parts: Vec<_> = s.split('/').collect();
34 if parts.len() != 2 {
35 return Err(());
36 }
37
38 Ok(AllowedIp {
39 address: parts[0].parse().map_err(|_| ())?,
40 cidr: parts[1].parse().map_err(|_| ())?,
41 })
42 }
43}
44
45#[derive(Debug, PartialEq, Eq, Clone)]
49#[non_exhaustive]
50pub struct PeerConfig {
51 pub public_key: Key,
53 pub preshared_key: Option<Key>,
55 pub endpoint: Option<SocketAddr>,
57 pub persistent_keepalive_interval: Option<u16>,
59 pub allowed_ips: Vec<AllowedIp>,
61}
62
/// Runtime counters reported for a peer.
///
/// The derived `Default` yields no handshake time and zeroed byte counters.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PeerStats {
    /// Time of the last handshake.
    /// NOTE(review): `None` presumably means no handshake has happened yet — confirm.
    pub last_handshake_time: Option<SystemTime>,
    /// Bytes received from the peer.
    pub rx_bytes: u64,
    /// Bytes transmitted to the peer.
    pub tx_bytes: u64,
}
76
77#[derive(Debug, PartialEq, Eq, Clone)]
82pub struct PeerInfo {
83 pub config: PeerConfig,
84 pub stats: PeerStats,
85}
86
87impl PeerInfo {
88 pub fn from_public_key(public_key: Key) -> PeerInfo {
89 PeerInfo {
90 config: PeerConfig {
91 public_key,
92 preshared_key: None,
93 endpoint: None,
94 persistent_keepalive_interval: None,
95 allowed_ips: vec![],
96 },
97 stats: PeerStats {
98 last_handshake_time: None,
99 rx_bytes: 0,
100 tx_bytes: 0,
101 },
102 }
103 }
104}
105
106#[derive(Debug, PartialEq, Eq, Clone)]
113#[non_exhaustive]
114pub struct Device {
115 pub name: InterfaceName,
117 pub public_key: Option<Key>,
119 pub private_key: Option<Key>,
121 pub fwmark: Option<u32>,
123 pub listen_port: Option<u16>,
125 pub peers: Vec<PeerInfo>,
127 pub linked_name: Option<String>,
129 pub backend: Backend,
131}
132
133type RawInterfaceName = [c_char; libc::IFNAMSIZ];
134
135#[derive(PartialEq, Eq, Clone, Copy)]
137pub struct InterfaceName(RawInterfaceName);
138
139impl FromStr for InterfaceName {
140 type Err = InvalidInterfaceName;
141
142 fn from_str(name: &str) -> Result<Self, InvalidInterfaceName> {
146 let len = name.len();
147 if len == 0 {
148 return Err(InvalidInterfaceName::Empty);
149 }
150
151 if len > (libc::IFNAMSIZ - 1) {
153 return Err(InvalidInterfaceName::TooLong);
154 }
155
156 let mut buf = [c_char::default(); libc::IFNAMSIZ];
157 for (out, b) in buf.iter_mut().zip(name.as_bytes().iter()) {
159 if *b == 0 || *b == b'/' || b.is_ascii_whitespace() {
160 return Err(InvalidInterfaceName::InvalidChars);
161 }
162
163 *out = *b as c_char;
164 }
165
166 Ok(Self(buf))
167 }
168}
169
170impl InterfaceName {
171 pub fn as_str_lossy(&self) -> Cow<'_, str> {
175 unsafe { CStr::from_ptr(self.0.as_ptr()) }.to_string_lossy()
177 }
178
179 #[cfg(target_os = "linux")]
180 pub fn as_ptr(&self) -> *const c_char {
182 self.0.as_ptr()
183 }
184}
185
186impl fmt::Debug for InterfaceName {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 f.write_str(&self.as_str_lossy())
189 }
190}
191
192impl fmt::Display for InterfaceName {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 f.write_str(&self.as_str_lossy())
195 }
196}
197
/// Reasons an interface name can fail validation.
#[derive(Eq, PartialEq, Debug)]
pub enum InvalidInterfaceName {
    /// The name does not fit in `IFNAMSIZ - 1` bytes (room is kept for the NUL).
    TooLong,

    /// The name was the empty string.
    Empty,

    /// The name contained a NUL byte, a `/`, or ASCII whitespace.
    InvalidChars,
}
212
213impl fmt::Display for InvalidInterfaceName {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 match self {
216 Self::TooLong => write!(
217 f,
218 "interface name longer than system max of {} chars",
219 libc::IFNAMSIZ
220 ),
221 Self::Empty => f.write_str("an empty interface name was provided"),
222 Self::InvalidChars => f.write_str("interface name contained slash or space characters"),
223 }
224 }
225}
226
227impl From<InvalidInterfaceName> for std::io::Error {
228 fn from(e: InvalidInterfaceName) -> Self {
229 std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
230 }
231}
232
233impl std::error::Error for InvalidInterfaceName {}
234
impl Device {
    /// Lists the interface names reported by `backend`.
    ///
    /// Dispatches to the matching backend module's `enumerate()`. The
    /// `Kernel` arm is compiled only on Linux and the `OpenBSD` arm only on
    /// OpenBSD.
    ///
    /// # Errors
    /// Propagates the `std::io::Error` returned by the backend.
    pub fn list(backend: Backend) -> Result<Vec<InterfaceName>, std::io::Error> {
        match backend {
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::enumerate(),
            #[cfg(target_os = "openbsd")]
            Backend::OpenBSD => backends::openbsd::enumerate(),
            Backend::Userspace => backends::userspace::enumerate(),
        }
    }

    /// Reads the device named `name` from `backend`.
    ///
    /// # Errors
    /// Propagates the `std::io::Error` returned by the backend's `get_by_name`.
    pub fn get(name: &InterfaceName, backend: Backend) -> Result<Self, std::io::Error> {
        match backend {
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::get_by_name(name),
            #[cfg(target_os = "openbsd")]
            Backend::OpenBSD => backends::openbsd::get_by_name(name),
            Backend::Userspace => backends::userspace::get_by_name(name),
        }
    }

    /// Consumes this device and deletes its interface.
    ///
    /// Uses the backend stored in `self.backend`, so the interface is removed
    /// through the same backend it was read from.
    ///
    /// # Errors
    /// Propagates the `std::io::Error` returned by the backend's `delete_interface`.
    pub fn delete(self) -> Result<(), std::io::Error> {
        match self.backend {
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::delete_interface(&self.name),
            #[cfg(target_os = "openbsd")]
            Backend::OpenBSD => backends::openbsd::delete_interface(&self.name),
            Backend::Userspace => backends::userspace::delete_interface(&self.name),
        }
    }
}
271
272#[derive(Debug, PartialEq, Eq, Clone)]
307pub struct DeviceUpdate {
308 pub(crate) public_key: Option<Key>,
309 pub(crate) private_key: Option<Key>,
310 pub(crate) fwmark: Option<u32>,
311 pub(crate) listen_port: Option<u16>,
312 pub(crate) peers: Vec<PeerConfigBuilder>,
313 pub(crate) replace_peers: bool,
314}
315
316impl DeviceUpdate {
317 #[must_use]
319 pub fn new() -> Self {
320 DeviceUpdate {
321 public_key: None,
322 private_key: None,
323 fwmark: None,
324 listen_port: None,
325 peers: vec![],
326 replace_peers: false,
327 }
328 }
329
330 #[must_use]
336 pub fn set_keypair(self, keypair: KeyPair) -> Self {
337 self.set_public_key(keypair.public)
338 .set_private_key(keypair.private)
339 }
340
341 #[must_use]
343 pub fn set_public_key(mut self, key: Key) -> Self {
344 self.public_key = Some(key);
345 self
346 }
347
348 #[must_use]
350 pub fn unset_public_key(self) -> Self {
351 self.set_public_key(Key::zero())
352 }
353
354 #[must_use]
356 pub fn set_private_key(mut self, key: Key) -> Self {
357 self.private_key = Some(key);
358 self
359 }
360
361 #[must_use]
363 pub fn unset_private_key(self) -> Self {
364 self.set_private_key(Key::zero())
365 }
366
367 #[must_use]
369 pub fn set_fwmark(mut self, fwmark: u32) -> Self {
370 self.fwmark = Some(fwmark);
371 self
372 }
373
374 #[must_use]
376 pub fn unset_fwmark(self) -> Self {
377 self.set_fwmark(0)
378 }
379
380 #[must_use]
384 pub fn set_listen_port(mut self, port: u16) -> Self {
385 self.listen_port = Some(port);
386 self
387 }
388
389 #[must_use]
393 pub fn randomize_listen_port(self) -> Self {
394 self.set_listen_port(0)
395 }
396
397 #[must_use]
403 pub fn add_peer(mut self, peer: PeerConfigBuilder) -> Self {
404 self.peers.push(peer);
405 self
406 }
407
408 #[must_use]
414 pub fn add_peer_with(
415 self,
416 pubkey: &Key,
417 builder: impl Fn(PeerConfigBuilder) -> PeerConfigBuilder,
418 ) -> Self {
419 self.add_peer(builder(PeerConfigBuilder::new(pubkey)))
420 }
421
422 #[must_use]
424 pub fn add_peers(mut self, peers: &[PeerConfigBuilder]) -> Self {
425 self.peers.extend_from_slice(peers);
426 self
427 }
428
429 #[must_use]
432 pub fn replace_peers(mut self) -> Self {
433 self.replace_peers = true;
434 self
435 }
436
437 #[must_use]
439 pub fn remove_peer_by_key(self, public_key: &Key) -> Self {
440 let mut peer = PeerConfigBuilder::new(public_key);
441 peer.remove_me = true;
442 self.add_peer(peer)
443 }
444
445 pub fn apply(self, iface: &InterfaceName, backend: Backend) -> io::Result<()> {
449 match backend {
450 #[cfg(target_os = "linux")]
451 Backend::Kernel => backends::kernel::apply(&self, iface),
452 #[cfg(target_os = "openbsd")]
453 Backend::OpenBSD => backends::openbsd::apply(&self, iface),
454 Backend::Userspace => backends::userspace::apply(&self, iface),
455 }
456 }
457}
458
459impl Default for DeviceUpdate {
460 fn default() -> Self {
461 Self::new()
462 }
463}
464
#[cfg(test)]
mod tests {
    // The parent module already has every name used here in scope
    // (including its private imports), so a single glob import suffices.
    use super::*;

    /// Name of the throwaway interface created by the root-only test.
    const TEST_INTERFACE: &str = "wgctrl-test";

    #[test]
    fn test_add_peers() {
        // Creating a userspace interface needs root; skip silently otherwise.
        // SAFETY: `getuid` takes no arguments and has no preconditions.
        if unsafe { libc::getuid() } != 0 {
            return;
        }

        let keypairs: Vec<_> = (0..10).map(|_| KeyPair::generate()).collect();
        let builder = keypairs.iter().fold(DeviceUpdate::new(), |update, keypair| {
            update.add_peer(PeerConfigBuilder::new(&keypair.public))
        });
        let interface = TEST_INTERFACE.parse().unwrap();
        builder.apply(&interface, Backend::Userspace).unwrap();

        let device = Device::get(&interface, Backend::Userspace).unwrap();
        let missing = keypairs
            .iter()
            .filter(|keypair| {
                !device
                    .peers
                    .iter()
                    .any(|p| p.config.public_key == keypair.public)
            })
            .count();

        // Tear the interface down before asserting so a failure doesn't leak it.
        device.delete().unwrap();
        assert_eq!(
            missing,
            0,
            "{} of {} peers were not applied",
            missing,
            keypairs.len()
        );
    }

    #[test]
    fn test_interface_names() {
        assert_eq!(
            "wg-01".parse::<InterfaceName>().unwrap().as_str_lossy(),
            "wg-01"
        );
        assert!("longer-nul\0".parse::<InterfaceName>().is_err());

        // Boundary: IFNAMSIZ - 1 bytes fit (one slot is kept for the NUL);
        // IFNAMSIZ bytes do not.
        let max_len = "a".repeat(libc::IFNAMSIZ - 1);
        assert_eq!(
            max_len.parse::<InterfaceName>().unwrap().as_str_lossy(),
            max_len
        );
        let too_long = "a".repeat(libc::IFNAMSIZ);
        assert_eq!(
            too_long.parse::<InterfaceName>().as_ref(),
            Err(&InvalidInterfaceName::TooLong)
        );

        let invalid_names = &[
            ("", InvalidInterfaceName::Empty),
            ("\0", InvalidInterfaceName::InvalidChars),
            ("ifname\0nul", InvalidInterfaceName::InvalidChars),
            ("if name", InvalidInterfaceName::InvalidChars),
            ("if\tname", InvalidInterfaceName::InvalidChars),
            ("ifna/me", InvalidInterfaceName::InvalidChars),
            ("if na/me", InvalidInterfaceName::InvalidChars),
            ("interfacelongname", InvalidInterfaceName::TooLong),
        ];

        for (name, expected) in invalid_names {
            assert_eq!(
                name.parse::<InterfaceName>().as_ref(),
                Err(expected),
                "input {:?}",
                name
            );
        }
    }
}