1use anyhow::{anyhow, Error};
2use clap::{
3 builder::{PossibleValuesParser, TypedValueParser},
4 Args,
5};
6use ipnet::IpNet;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::{
11 fmt::{self, Display, Formatter},
12 io,
13 net::{IpAddr, SocketAddr, ToSocketAddrs},
14 ops::{Deref, DerefMut},
15 path::{Path, PathBuf},
16 str::FromStr,
17 time::{Duration, SystemTime},
18 vec,
19};
20use url::Host;
21use wireguard_control::{
22 AllowedIp, Backend, InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder,
23 PeerInfo,
24};
25
26use crate::wg::PeerInfoExt;
27
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct Interface {
30 name: InterfaceName,
31}
32
33impl FromStr for Interface {
34 type Err = InvalidInterfaceName;
35
36 fn from_str(name: &str) -> Result<Self, Self::Err> {
37 if !Hostname::is_valid(name) {
38 Err(InvalidInterfaceName::InvalidChars)
39 } else {
40 Ok(Self {
41 name: name.parse()?,
42 })
43 }
44 }
45}
46
47impl Deref for Interface {
48 type Target = InterfaceName;
49
50 fn deref(&self) -> &Self::Target {
51 &self.name
52 }
53}
54
55impl Display for Interface {
56 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
57 f.write_str(&self.name.to_string())
58 }
59}
60
61#[derive(Clone, Debug, PartialEq, Eq)]
62pub struct Endpoint {
64 host: Host,
65 port: u16,
66}
67
68impl From<SocketAddr> for Endpoint {
69 fn from(addr: SocketAddr) -> Self {
70 match addr {
71 SocketAddr::V4(v4addr) => Self {
72 host: Host::Ipv4(*v4addr.ip()),
73 port: v4addr.port(),
74 },
75 SocketAddr::V6(v6addr) => Self {
76 host: Host::Ipv6(*v6addr.ip()),
77 port: v6addr.port(),
78 },
79 }
80 }
81}
82
83impl FromStr for Endpoint {
84 type Err = &'static str;
85
86 fn from_str(s: &str) -> Result<Self, Self::Err> {
87 match s.rsplitn(2, ':').collect::<Vec<&str>>().as_slice() {
88 [port, host] => {
89 let port = port.parse().map_err(|_| "couldn't parse port")?;
90 let host = Host::parse(host).map_err(|_| "couldn't parse host")?;
91 Ok(Endpoint { host, port })
92 },
93 _ => Err("couldn't parse in form of 'host:port'"),
94 }
95 }
96}
97
98impl Serialize for Endpoint {
99 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
100 where
101 S: serde::Serializer,
102 {
103 serializer.serialize_str(&self.to_string())
104 }
105}
106
107impl<'de> Deserialize<'de> for Endpoint {
108 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109 where
110 D: serde::Deserializer<'de>,
111 {
112 struct EndpointVisitor;
113 impl serde::de::Visitor<'_> for EndpointVisitor {
114 type Value = Endpoint;
115
116 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
117 formatter.write_str("a valid host:port endpoint")
118 }
119
120 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
121 where
122 E: serde::de::Error,
123 {
124 s.parse().map_err(serde::de::Error::custom)
125 }
126 }
127 deserializer.deserialize_str(EndpointVisitor)
128 }
129}
130
131impl Display for Endpoint {
132 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
133 self.host.fmt(f)?;
134 f.write_str(":")?;
135 self.port.fmt(f)
136 }
137}
138
139impl Endpoint {
140 pub fn resolve(&self) -> Result<SocketAddr, io::Error> {
141 let mut addrs = self.to_string().to_socket_addrs()?;
142 addrs.next().ok_or_else(|| {
143 io::Error::new(
144 io::ErrorKind::AddrNotAvailable,
145 "failed to resolve address".to_string(),
146 )
147 })
148 }
149
150 pub fn is_host_unspecified(&self) -> bool {
152 match self.host {
153 Host::Ipv4(ip) => ip.is_unspecified(),
154 Host::Ipv6(ip) => ip.is_unspecified(),
155 Host::Domain(_) => false,
156 }
157 }
158
159 pub fn port(&self) -> u16 {
160 self.port
161 }
162}
163
164#[derive(Deserialize, Serialize, Debug)]
165#[serde(tag = "option", content = "content")]
166pub enum EndpointContents {
167 Set(Endpoint),
168 Unset,
169}
170
171impl From<EndpointContents> for Option<Endpoint> {
172 fn from(endpoint: EndpointContents) -> Self {
173 match endpoint {
174 EndpointContents::Set(addr) => Some(addr),
175 EndpointContents::Unset => None,
176 }
177 }
178}
179
180impl From<Option<Endpoint>> for EndpointContents {
181 fn from(option: Option<Endpoint>) -> Self {
182 match option {
183 Some(addr) => Self::Set(addr),
184 None => Self::Unset,
185 }
186 }
187}
188
189#[derive(Deserialize, Serialize, Debug)]
190pub struct AssociationContents {
191 pub cidr_id_1: i64,
192 pub cidr_id_2: i64,
193}
194
195#[derive(Deserialize, Serialize, Debug)]
196pub struct Association {
197 pub id: i64,
198
199 #[serde(flatten)]
200 pub contents: AssociationContents,
201}
202
203impl Deref for Association {
204 type Target = AssociationContents;
205
206 fn deref(&self) -> &Self::Target {
207 &self.contents
208 }
209}
210
211#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd, Eq, Ord)]
212pub struct CidrContents {
213 pub name: String,
214 pub cidr: IpNet,
215 pub parent: Option<i64>,
216}
217
218impl Deref for CidrContents {
219 type Target = IpNet;
220
221 fn deref(&self) -> &Self::Target {
222 &self.cidr
223 }
224}
225
226#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd, Eq, Ord)]
227pub struct Cidr {
228 pub id: i64,
229
230 #[serde(flatten)]
231 pub contents: CidrContents,
232}
233
234impl Display for Cidr {
235 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
236 write!(f, "{} ({})", self.name, self.cidr)
237 }
238}
239
240impl Deref for Cidr {
241 type Target = CidrContents;
242
243 fn deref(&self) -> &Self::Target {
244 &self.contents
245 }
246}
247
248#[derive(Clone, PartialEq, PartialOrd, Eq, Ord)]
249pub struct CidrTree<'a> {
250 cidrs: &'a [Cidr],
251 contents: &'a Cidr,
252}
253
254impl std::ops::Deref for CidrTree<'_> {
255 type Target = Cidr;
256
257 fn deref(&self) -> &Self::Target {
258 self.contents
259 }
260}
261
262impl<'a> CidrTree<'a> {
263 pub fn new(cidrs: &'a [Cidr]) -> Self {
264 let root = cidrs
265 .iter()
266 .min_by_key(|c| c.cidr.prefix_len())
267 .expect("failed to find root CIDR");
268 Self::with_root(cidrs, root)
269 }
270
271 pub fn with_root(cidrs: &'a [Cidr], root: &'a Cidr) -> Self {
272 Self {
273 cidrs,
274 contents: root,
275 }
276 }
277
278 pub fn children(&self) -> impl Iterator<Item = CidrTree<'_>> {
279 self.cidrs
280 .iter()
281 .filter(move |c| c.parent == Some(self.contents.id))
282 .map(move |c| CidrTree {
283 cidrs: self.cidrs,
284 contents: c,
285 })
286 }
287
288 pub fn leaves(&self) -> Vec<Cidr> {
289 if !self.cidrs.iter().any(|cidr| cidr.parent == Some(self.id)) {
290 vec![self.contents.clone()]
291 } else {
292 self.children().flat_map(|child| child.leaves()).collect()
293 }
294 }
295}
296
297#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
298pub struct RedeemContents {
299 pub public_key: String,
300}
301
// CLI options for the install step (name-based reading; confirm against the caller).
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct InstallOpts {
    // `--name <NAME>`: explicit name to use; cannot be combined with `--default-name`.
    #[clap(long, conflicts_with = "default_name")]
    pub name: Option<String>,

    // `--default-name`: flag asking to use a default name instead of `--name`
    // (what that default is gets decided by the caller, not here).
    #[clap(long = "default-name")]
    pub default_name: bool,

    // `-d` / `--delete-invite`: flag; presumably removes the invitation afterwards — confirm.
    #[clap(short, long)]
    pub delete_invite: bool,
}
316
// CLI options for adding a peer. Fields left as `None`/`false` are presumably
// filled in interactively by the caller — confirm.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct AddPeerOpts {
    // `--name`: peer name, validated as a `Hostname` at parse time.
    #[clap(long)]
    pub name: Option<Hostname>,

    // `--ip`: explicit peer IP; mutually exclusive with `--auto-ip`.
    #[clap(long, conflicts_with = "auto_ip")]
    pub ip: Option<IpAddr>,

    // `--auto-ip`: flag requesting an automatically chosen IP.
    #[clap(long = "auto-ip")]
    pub auto_ip: bool,

    // `--cidr`: CIDR to place the peer in, given as a string (likely its name — confirm).
    #[clap(long)]
    pub cidr: Option<String>,

    // `--admin <true|false>`: takes an explicit boolean value (it is an `Option<bool>`, not a flag).
    #[clap(long)]
    pub admin: Option<bool>,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,

    // `--save-config <PATH?>`: where to write the generated config (string; semantics per caller).
    #[clap(long)]
    pub save_config: Option<String>,

    // `--invite-expires`: invitation lifetime as a `Timestring`, e.g. `15m`, `2d`.
    #[clap(long)]
    pub invite_expires: Option<Timestring>,
}
351
// CLI options for renaming a peer.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct RenamePeerOpts {
    // `--name`: current name of the peer to rename.
    #[clap(long)]
    pub name: Option<Hostname>,

    // `--new-name`: the name to change it to.
    #[clap(long)]
    pub new_name: Option<Hostname>,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
366
// CLI options shared by the enable-peer and disable-peer commands (name-based reading — confirm).
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct EnableDisablePeerOpts {
    // `--name`: the peer to act on.
    #[clap(long)]
    pub name: Option<Hostname>,

    // `--yes`: only accepted together with `--name`, so a confirmation can't be skipped
    // while the peer is still being chosen interactively.
    #[clap(long, requires("name"))]
    pub yes: bool,
}
377
// CLI options for adding a CIDR.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct AddCidrOpts {
    // `--name`: name for the new CIDR, validated as a `Hostname`.
    #[clap(long)]
    pub name: Option<Hostname>,

    // `--cidr`: the network, e.g. `10.42.1.0/24`.
    #[clap(long)]
    pub cidr: Option<IpNet>,

    // `--parent`: parent CIDR, given as a string (likely its name — confirm).
    #[clap(long)]
    pub parent: Option<String>,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
396
// CLI options for renaming a CIDR.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct RenameCidrOpts {
    // `--name`: current CIDR name (plain string; not hostname-validated here).
    #[clap(long)]
    pub name: Option<String>,

    // `--new-name`: the name to change it to (plain string).
    #[clap(long)]
    pub new_name: Option<String>,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
411
// CLI options for deleting a CIDR.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct DeleteCidrOpts {
    // `--name`: the CIDR to delete.
    #[clap(long)]
    pub name: Option<String>,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
422
// CLI options shared by the add-association and delete-association commands.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct AddDeleteAssociationOpts {
    // First CIDR of the pair. With no `#[clap]` attribute this is an optional
    // *positional* argument, not a `--flag`.
    pub cidr1: Option<String>,

    // Second CIDR of the pair (optional positional argument).
    pub cidr2: Option<String>,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
435
// CLI options for setting or clearing the listen port.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct ListenPortOpts {
    // `-l` / `--listen-port`: the port to set.
    #[clap(short, long)]
    pub listen_port: Option<u16>,

    // `-u` / `--unset`: clear the port instead; mutually exclusive with `--listen-port`.
    #[clap(short, long, conflicts_with = "listen_port")]
    pub unset: bool,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
450
// CLI options for setting or clearing an override endpoint.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, PartialEq, Eq, Args)]
pub struct OverrideEndpointOpts {
    // `-e` / `--endpoint`: `host:port`, parsed with `Endpoint::from_str`.
    #[clap(short, long)]
    pub endpoint: Option<Endpoint>,

    // `-u` / `--unset`: clear the override instead; mutually exclusive with `--endpoint`.
    #[clap(short, long, conflicts_with = "endpoint")]
    pub unset: bool,

    // `--yes`: flag; presumably skips confirmation prompts — confirm.
    #[clap(long)]
    pub yes: bool,
}
468
// CLI options controlling NAT traversal and which local addresses may be NAT candidates.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, Args)]
pub struct NatOpts {
    // `--no-nat-traversal`: flag disabling NAT traversal.
    #[clap(long)]
    pub no_nat_traversal: bool,

    // `--exclude-nat-candidates <NET>`: networks whose addresses must never be
    // offered as candidates; accepted repeatedly (it is a `Vec`).
    #[clap(long)]
    pub exclude_nat_candidates: Vec<IpNet>,

    // `--no-nat-candidates`: offer no candidates at all; conflicts with the exclusion list.
    #[clap(long, conflicts_with = "exclude_nat_candidates")]
    pub no_nat_candidates: bool,
}
486
487impl NatOpts {
488 pub fn all_disabled() -> Self {
489 Self {
490 no_nat_traversal: true,
491 exclude_nat_candidates: vec![],
492 no_nat_candidates: true,
493 }
494 }
495
496 pub fn is_excluded(&self, ip: IpAddr) -> bool {
498 self.no_nat_candidates
499 || self
500 .exclude_nat_candidates
501 .iter()
502 .any(|network| network.contains(&ip))
503 }
504}
505
// CLI options for interface networking.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Debug, Clone, Copy, Args)]
pub struct NetworkOpts {
    // `--no-routing`: flag; presumably skips installing routes — confirm.
    #[clap(long)]
    pub no_routing: bool,

    // `--backend`: restricted to `Backend::variants()`. Because of that restriction,
    // the `unwrap` on `parse` cannot fail for any value clap lets through.
    // Defaults to `Backend::default()`.
    #[clap(long, default_value_t, value_parser = PossibleValuesParser::new(Backend::variants()).map(|s| s.parse::<Backend>().unwrap()))]
    pub backend: Backend,

    // `--mtu`: optional MTU for the interface.
    #[clap(long)]
    pub mtu: Option<u32>,
}
522
// CLI options for where (or whether) to write a hosts file.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments into `--help` text.
#[derive(Clone, Debug, Args)]
pub struct HostsOpt {
    // `--hosts-path`: hosts file to write; defaults to `/etc/hosts`.
    #[clap(long = "hosts-path", default_value = "/etc/hosts")]
    pub hosts_path: PathBuf,

    // `--no-write-hosts`: do not write a hosts file; conflicts with an explicit `--hosts-path`.
    #[clap(long = "no-write-hosts", conflicts_with = "hosts_path")]
    pub no_write_hosts: bool,
}
533
534impl From<HostsOpt> for Option<PathBuf> {
535 fn from(opt: HostsOpt) -> Self {
536 (!opt.no_write_hosts).then_some(opt.hosts_path)
537 }
538}
539
540#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
541pub struct PeerContents {
542 pub name: Hostname,
543 pub ip: IpAddr,
544 pub cidr_id: i64,
545 pub public_key: String,
546 pub endpoint: Option<Endpoint>,
547 pub persistent_keepalive_interval: Option<u16>,
548 pub is_admin: bool,
549 pub is_disabled: bool,
550 pub is_redeemed: bool,
551 pub invite_expires: Option<SystemTime>,
552 #[serde(default)]
553 pub candidates: Vec<Endpoint>,
554}
555
556#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
557pub struct Peer {
558 pub id: i64,
559
560 #[serde(flatten)]
561 pub contents: PeerContents,
562}
563
564impl AsRef<Peer> for Peer {
565 fn as_ref(&self) -> &Peer {
566 self
567 }
568}
569
570impl Deref for Peer {
571 type Target = PeerContents;
572
573 fn deref(&self) -> &Self::Target {
574 &self.contents
575 }
576}
577
578impl DerefMut for Peer {
579 fn deref_mut(&mut self) -> &mut Self::Target {
580 &mut self.contents
581 }
582}
583
584impl Display for Peer {
585 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586 write!(f, "{} ({})", &self.name, &self.public_key)
587 }
588}
589
590#[derive(Debug, Clone, PartialEq, Eq)]
591pub enum PeerChange {
592 AllowedIPs {
593 old: Vec<AllowedIp>,
594 new: Vec<AllowedIp>,
595 },
596 PersistentKeepalive {
597 old: Option<u16>,
598 new: Option<u16>,
599 },
600 Endpoint {
601 old: Option<SocketAddr>,
602 new: Option<SocketAddr>,
603 },
604 NatTraverseReattempt,
605}
606
607impl Display for PeerChange {
608 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
609 match self {
610 Self::AllowedIPs { old, new } => write!(f, "Allowed IPs: {old:?} => {new:?}"),
611 Self::PersistentKeepalive { old, new } => write!(
612 f,
613 "Persistent Keepalive: {} => {}",
614 old.display_string(),
615 new.display_string()
616 ),
617 Self::Endpoint { old, new } => write!(
618 f,
619 "Endpoint: {} => {}",
620 old.display_string(),
621 new.display_string()
622 ),
623 Self::NatTraverseReattempt => write!(f, "NAT Traversal Reattempt"),
624 }
625 }
626}
627
/// Helper for rendering optional values in change descriptions.
trait OptionExt {
    /// Returns the wrapped value formatted with `{:?}`, or `[none]` when absent.
    fn display_string(&self) -> String;
}

impl<T: std::fmt::Debug> OptionExt for Option<T> {
    fn display_string(&self) -> String {
        self.as_ref()
            .map_or_else(|| "[none]".to_string(), |value| format!("{value:?}"))
    }
}
642
643#[derive(Clone, Debug, PartialEq, Eq)]
646pub struct PeerDiff<'a> {
647 pub old: Option<&'a PeerConfig>,
648 pub new: Option<&'a Peer>,
649 builder: PeerConfigBuilder,
650 changes: Vec<PeerChange>,
651}
652
653impl<'a> PeerDiff<'a> {
654 pub fn new(
655 old_info: Option<&'a PeerInfo>,
656 new: Option<&'a Peer>,
657 ) -> Result<Option<Self>, Error> {
658 let old = old_info.map(|p| &p.config);
659 match (old_info, new) {
660 (Some(old), Some(new)) if old.config.public_key.to_base64() != new.public_key => Err(
661 anyhow!("old and new peer configs have different public keys"),
662 ),
663 (None, None) => Ok(None),
664 _ => Ok(
665 Self::peer_config_builder(old_info, new).map(|(builder, changes)| Self {
666 old,
667 new,
668 builder,
669 changes,
670 }),
671 ),
672 }
673 }
674
675 pub fn public_key(&self) -> &Key {
676 self.builder.public_key()
677 }
678
679 pub fn changes(&self) -> &[PeerChange] {
680 &self.changes
681 }
682
683 fn peer_config_builder(
684 old_info: Option<&PeerInfo>,
685 new: Option<&Peer>,
686 ) -> Option<(PeerConfigBuilder, Vec<PeerChange>)> {
687 let old = old_info.map(|p| &p.config);
688 let public_key = match (old, new) {
689 (Some(old), _) => old.public_key.clone(),
690 (_, Some(new)) => Key::from_base64(&new.public_key).unwrap(),
691 _ => return None,
692 };
693 let mut builder = PeerConfigBuilder::new(&public_key);
694 let mut changes = vec![];
695
696 if new.is_none() || matches!(new, Some(new) if new.is_disabled) {
698 return Some((builder.remove(), changes));
699 }
700 let new = new.unwrap();
702
703 let new_allowed_ips = &[AllowedIp {
704 address: new.ip,
705 cidr: if new.ip.is_ipv4() { 32 } else { 128 },
706 }];
707 if old.is_none() || matches!(old, Some(old) if old.allowed_ips != new_allowed_ips) {
708 builder = builder
709 .replace_allowed_ips()
710 .add_allowed_ips(new_allowed_ips);
711 changes.push(PeerChange::AllowedIPs {
712 old: old.map(|o| o.allowed_ips.clone()).unwrap_or_default(),
713 new: new_allowed_ips.to_vec(),
714 });
715 }
716
717 if old.is_none()
718 || matches!(old, Some(old) if old.persistent_keepalive_interval != new.persistent_keepalive_interval)
719 {
720 builder = match new.persistent_keepalive_interval {
721 Some(interval) => builder.set_persistent_keepalive_interval(interval),
722 None => builder.unset_persistent_keepalive(),
723 };
724 changes.push(PeerChange::PersistentKeepalive {
725 old: old.and_then(|p| p.persistent_keepalive_interval),
726 new: new.persistent_keepalive_interval,
727 });
728 }
729
730 if !old_info
732 .map(|info| info.is_recently_connected())
733 .unwrap_or_default()
734 {
735 let mut endpoint_changed = false;
736 let resolved = new.endpoint.as_ref().and_then(|e| e.resolve().ok());
737 if let Some(addr) = resolved {
738 if old.is_none() || matches!(old, Some(old) if old.endpoint != resolved) {
739 builder = builder.set_endpoint(addr);
740 changes.push(PeerChange::Endpoint {
741 old: old.and_then(|p| p.endpoint),
742 new: Some(addr),
743 });
744 endpoint_changed = true;
745 }
746 }
747 if !endpoint_changed && !new.candidates.is_empty() {
748 changes.push(PeerChange::NatTraverseReattempt)
749 }
750 }
751
752 if !changes.is_empty() {
753 Some((builder, changes))
754 } else {
755 None
756 }
757 }
758}
759
760impl From<&Peer> for PeerConfigBuilder {
761 fn from(peer: &Peer) -> Self {
762 PeerDiff::new(None, Some(peer))
763 .expect("No Err on explicitly set peer data")
764 .expect("None -> Some(peer) will always create a PeerDiff")
765 .into()
766 }
767}
768
769impl From<PeerDiff<'_>> for PeerConfigBuilder {
770 fn from(diff: PeerDiff) -> Self {
773 diff.builder
774 }
775}
776
777#[derive(Debug, Clone, Deserialize, Serialize)]
780pub struct State {
781 pub peers: Vec<Peer>,
785
786 pub cidrs: Vec<Cidr>,
789}
790
791#[derive(Debug, Default, Clone, Deserialize, Serialize)]
793pub struct ServerCapabilities {
794 #[serde(default)]
795 pub unspecified_ip_in_override_endpoint: bool,
796}
797
/// A human-friendly duration such as `15m` or `2d`: an unsigned integer followed
/// by a single unit suffix (`s`, `m`, `h`, `d`, or `w`).
///
/// The original text is kept for display; the value is stored in whole seconds.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Timestring {
    timestring: String,
    seconds: u64,
}

impl Display for Timestring {
    /// Writes the timestring exactly as it was parsed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.timestring)
    }
}

impl FromStr for Timestring {
    type Err = &'static str;

    /// Parses `<number><unit>`.
    ///
    /// # Errors
    /// Returns a static message in four cases:
    /// * the input is shorter than two bytes;
    /// * the number part is not a valid `u64`;
    /// * the unit is unknown;
    /// * the total number of seconds would overflow `u64`.
    fn from_str(timestring: &str) -> Result<Self, Self::Err> {
        if timestring.len() < 2 {
            return Err("timestring isn't long enough!");
        }

        // Split off the last *character*, not the last byte: slicing at `len - 1`
        // panics when the input ends in a multi-byte UTF-8 character (e.g. "5é").
        let suffix_start = timestring
            .char_indices()
            .next_back()
            .map_or(0, |(index, _)| index);
        let (n, suffix) = timestring.split_at(suffix_start);

        let n: u64 = n.parse().map_err(|_| {
            "invalid timestring (a number followed by a time unit character, eg. '15m')"
        })?;
        let multiplier: u64 = match suffix {
            "s" => 1,
            "m" => 60,
            "h" => 60 * 60,
            "d" => 60 * 60 * 24,
            "w" => 60 * 60 * 24 * 7,
            _ => {
                return Err("invalid timestring suffix (must be one of 's', 'm', 'h', 'd', or 'w')")
            },
        };
        // Huge counts would otherwise overflow, e.g. "18446744073709551615w".
        // That panics in debug builds and silently wraps in release builds.
        let seconds = n
            .checked_mul(multiplier)
            .ok_or("timestring is too large")?;

        Ok(Self {
            timestring: timestring.to_string(),
            seconds,
        })
    }
}

impl From<Timestring> for Duration {
    fn from(timestring: Timestring) -> Self {
        Duration::from_secs(timestring.seconds)
    }
}
843
844#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
845pub struct Hostname(String);
846
847static HOSTNAME_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"^([a-z0-9]-?)*[a-z0-9]$").unwrap());
850
851impl Hostname {
852 pub fn is_valid(name: &str) -> bool {
853 name.len() < 64 && HOSTNAME_REGEX.is_match(name)
854 }
855}
856
857impl FromStr for Hostname {
858 type Err = &'static str;
859
860 fn from_str(name: &str) -> Result<Self, Self::Err> {
861 if Self::is_valid(name) {
862 Ok(Self(name.to_string()))
863 } else {
864 Err("invalid hostname string (only alphanumeric with dashes)")
865 }
866 }
867}
868
869impl Deref for Hostname {
870 type Target = str;
871
872 fn deref(&self) -> &Self::Target {
873 &self.0
874 }
875}
876
877impl Display for Hostname {
878 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
879 f.write_str(&self.0)
880 }
881}
882
/// Extension trait for attaching a human-readable context string to an
/// [`std::io::Error`], turning it into a [`WrappedIoError`].
pub trait IoErrorContext<T> {
    /// Uses `path` (converted lossily to UTF-8) as the error context.
    fn with_path<P: AsRef<Path>>(self, path: P) -> Result<T, WrappedIoError>;
    /// Uses an arbitrary string as the error context.
    fn with_str<S: Into<String>>(self, context: S) -> Result<T, WrappedIoError>;
}

impl<T> IoErrorContext<T> for Result<T, std::io::Error> {
    fn with_path<P: AsRef<Path>>(self, path: P) -> Result<T, WrappedIoError> {
        let context = path.as_ref().to_string_lossy();
        self.with_str(context)
    }

    fn with_str<S: Into<String>>(self, context: S) -> Result<T, WrappedIoError> {
        match self {
            Ok(value) => Ok(value),
            Err(io_error) => Err(WrappedIoError {
                io_error,
                context: context.into(),
            }),
        }
    }
}

/// An I/O error paired with the context (e.g. a path) it occurred in.
#[derive(Debug)]
pub struct WrappedIoError {
    io_error: std::io::Error,
    context: String,
}

impl Display for WrappedIoError {
    /// Renders as `<context> - <io error>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { io_error, context } = self;
        write!(f, "{context} - {io_error}")
    }
}

impl Deref for WrappedIoError {
    type Target = std::io::Error;

    fn deref(&self) -> &std::io::Error {
        &self.io_error
    }
}

impl std::error::Error for WrappedIoError {}
922
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use wireguard_control::{Key, PeerConfigBuilder, PeerStats};

    const PUBKEY: &str = "4CNZorWVtohO64n6AAaH/JyFjIIgBFrfJK2SGtKjzEE=";

    /// The single address assigned to every fixture peer.
    fn peer_ip() -> IpAddr {
        "10.0.0.1".parse().unwrap()
    }

    /// An enabled, redeemed, non-admin fixture peer with the given keepalive and endpoint.
    fn test_peer(persistent_keepalive_interval: Option<u16>, endpoint: Option<Endpoint>) -> Peer {
        Peer {
            id: 1,
            contents: PeerContents {
                name: "peer1".parse().unwrap(),
                ip: peer_ip(),
                cidr_id: 1,
                public_key: PUBKEY.to_owned(),
                endpoint,
                persistent_keepalive_interval,
                is_admin: false,
                is_disabled: false,
                is_redeemed: true,
                invite_expires: None,
                candidates: vec![],
            },
        }
    }

    /// Interface-side info for the fixture peer: its key plus a single /32 allowed IP.
    fn test_info(stats: PeerStats) -> PeerInfo {
        let config = PeerConfigBuilder::new(&Key::from_base64(PUBKEY).unwrap())
            .add_allowed_ip(peer_ip(), 32)
            .into_peer_config();
        PeerInfo { config, stats }
    }

    #[test]
    fn test_peer_no_diff() {
        let peer = test_peer(None, None);
        let info = test_info(Default::default());

        let diff = PeerDiff::new(Some(&info), Some(&peer)).unwrap();

        println!("{diff:?}");
        assert_eq!(diff, None);
    }

    #[test]
    fn test_peer_diff() {
        // Only the persistent keepalive differs from the interface state.
        let peer = test_peer(Some(15), None);
        let info = test_info(Default::default());
        let diff = PeerDiff::new(Some(&info), Some(&peer)).unwrap();

        println!("{peer:?}");
        println!("{:?}", info.config);
        assert!(diff.is_some());
    }

    #[test]
    fn test_peer_diff_handshake_time() {
        let peer = test_peer(None, Some("1.1.1.1:1111".parse().unwrap()));
        let mut info = test_info(PeerStats {
            last_handshake_time: Some(SystemTime::now() - Duration::from_secs(200)),
            ..Default::default()
        });

        // A 200s-old handshake does not count as recently connected, so the endpoint is applied.
        assert!(matches!(
            PeerDiff::new(Some(&info), Some(&peer)),
            Ok(Some(_))
        ));

        // A fresh handshake leaves the endpoint alone, and nothing else differs.
        info.stats.last_handshake_time = Some(SystemTime::now());
        assert!(matches!(PeerDiff::new(Some(&info), Some(&peer)), Ok(None)));
    }
}