wireguard_netstack/tunnel.rs
1//! High-level managed WireGuard tunnel.
2//!
3//! This module provides `ManagedTunnel`, a convenient abstraction that handles
4//! all the background tasks required to run a WireGuard tunnel.
5
6use crate::error::{Error, Result};
7use crate::netstack::NetStack;
8use crate::wireguard::{WireGuardConfig, WireGuardTunnel};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::task::JoinSet;
12
13/// A managed WireGuard tunnel that handles all background tasks automatically.
14///
15/// This is the main entry point for library users. It:
16/// - Creates and configures the WireGuard tunnel
17/// - Creates the userspace network stack
18/// - Spawns all required background tasks
19/// - Performs the WireGuard handshake
20/// - Provides access to the `NetStack` for making TCP connections
21///
22/// # Example
23///
24/// ```no_run
25/// use wireguard_netstack::{ManagedTunnel, WgConfigFile};
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29/// // Load config and connect
30/// let config = WgConfigFile::from_file("wg.conf")?
31/// .into_wireguard_config()
32/// .await?;
33///
34/// let tunnel = ManagedTunnel::connect(config).await?;
35///
36/// // Use tunnel.netstack() to create TCP connections
37/// // ...
38///
39/// // Graceful shutdown
40/// tunnel.shutdown().await;
41/// Ok(())
42/// }
43/// ```
44pub struct ManagedTunnel {
45 /// The underlying WireGuard tunnel.
46 wg_tunnel: Arc<WireGuardTunnel>,
47 /// The userspace network stack.
48 netstack: Arc<NetStack>,
49 /// Background task handles.
50 tasks: JoinSet<()>,
51}
52
53impl ManagedTunnel {
54 /// Connect to a WireGuard peer using the provided configuration.
55 ///
56 /// This will:
57 /// 1. Create the WireGuard tunnel
58 /// 2. Create the userspace network stack
59 /// 3. Spawn all background tasks
60 /// 4. Initiate and wait for the WireGuard handshake
61 ///
62 /// # Arguments
63 ///
64 /// * `config` - WireGuard configuration
65 ///
66 /// # Returns
67 ///
68 /// A `ManagedTunnel` ready to use for making TCP connections.
69 pub async fn connect(config: WireGuardConfig) -> Result<Self> {
70 Self::connect_with_timeout(config, Duration::from_secs(10)).await
71 }
72
73 /// Connect with a custom handshake timeout.
74 pub async fn connect_with_timeout(
75 config: WireGuardConfig,
76 handshake_timeout: Duration,
77 ) -> Result<Self> {
78 log::info!("Creating WireGuard tunnel...");
79 let wg_tunnel = WireGuardTunnel::new(config)
80 .await
81 .map_err(|e| Error::TunnelCreation(e.to_string()))?;
82
83 // Take the incoming receiver before starting tasks
84 let incoming_rx = wg_tunnel
85 .take_incoming_receiver()
86 .ok_or_else(|| Error::TunnelCreation("Failed to get incoming receiver".into()))?;
87
88 // Create the network stack
89 log::info!("Creating userspace network stack...");
90 let netstack = NetStack::new(wg_tunnel.clone());
91
92 // Spawn background tasks
93 log::info!("Starting background tasks...");
94 let mut tasks = JoinSet::new();
95
96 // WireGuard receive loop
97 let wg = wg_tunnel.clone();
98 tasks.spawn(async move {
99 if let Err(e) = wg.run_receive_loop().await {
100 log::error!("WireGuard receive loop error: {}", e);
101 }
102 });
103
104 // WireGuard send loop
105 let wg = wg_tunnel.clone();
106 tasks.spawn(async move {
107 if let Err(e) = wg.run_send_loop().await {
108 log::error!("WireGuard send loop error: {}", e);
109 }
110 });
111
112 // WireGuard timer loop
113 let wg = wg_tunnel.clone();
114 tasks.spawn(async move {
115 if let Err(e) = wg.run_timer_loop().await {
116 log::error!("WireGuard timer loop error: {}", e);
117 }
118 });
119
120 // Network stack poll loop
121 let ns = netstack.clone();
122 tasks.spawn(async move {
123 if let Err(e) = ns.run_poll_loop().await {
124 log::error!("Network stack poll loop error: {}", e);
125 }
126 });
127
128 // Network stack RX loop
129 let ns = netstack.clone();
130 tasks.spawn(async move {
131 if let Err(e) = ns.run_rx_loop(incoming_rx).await {
132 log::error!("Network stack RX loop error: {}", e);
133 }
134 });
135
136 // Give tasks time to start
137 tokio::time::sleep(Duration::from_millis(100)).await;
138
139 // Initiate handshake
140 log::info!("Initiating WireGuard handshake...");
141 wg_tunnel
142 .initiate_handshake()
143 .await
144 .map_err(|e| Error::TunnelCreation(e.to_string()))?;
145
146 // Wait for handshake
147 log::info!("Waiting for WireGuard handshake to complete...");
148 wg_tunnel.wait_for_handshake(handshake_timeout).await?;
149
150 log::info!("WireGuard tunnel established!");
151
152 Ok(Self {
153 wg_tunnel,
154 netstack,
155 tasks,
156 })
157 }
158
159 /// Get the network stack for creating TCP connections.
160 pub fn netstack(&self) -> Arc<NetStack> {
161 self.netstack.clone()
162 }
163
164 /// Get the underlying WireGuard tunnel.
165 pub fn wg_tunnel(&self) -> Arc<WireGuardTunnel> {
166 self.wg_tunnel.clone()
167 }
168
169 /// Returns the time elapsed since the last successful WireGuard handshake.
170 ///
171 /// Returns `Some(duration)` if a handshake has completed, or `None` if no
172 /// handshake has occurred yet. This is useful for health-checking the tunnel:
173 /// WireGuard re-handshakes every ~120s on an active session, so a value
174 /// exceeding ~180s typically indicates the tunnel is stale.
175 ///
176 /// # Example
177 ///
178 /// ```no_run
179 /// use std::time::Duration;
180 /// use wireguard_netstack::ManagedTunnel;
181 ///
182 /// fn check_health(tunnel: &ManagedTunnel) -> bool {
183 /// match tunnel.time_since_last_handshake() {
184 /// Some(elapsed) => elapsed < Duration::from_secs(180),
185 /// None => false,
186 /// }
187 /// }
188 /// ```
189 pub fn time_since_last_handshake(&self) -> Option<Duration> {
190 self.wg_tunnel.time_since_last_handshake()
191 }
192
193 /// Gracefully shutdown the tunnel.
194 ///
195 /// This aborts all background tasks and waits for them to complete.
196 pub async fn shutdown(mut self) {
197 log::info!("Shutting down WireGuard tunnel...");
198 self.tasks.abort_all();
199 while self.tasks.join_next().await.is_some() {}
200 log::info!("WireGuard tunnel shutdown complete.");
201 }
202}
203
204impl Drop for ManagedTunnel {
205 fn drop(&mut self) {
206 // Abort all tasks on drop
207 self.tasks.abort_all();
208 }
209}