// hightower_client/connection.rs

use crate::error::ClientError;
use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
use crate::transport::TransportServer;
use crate::types::{RegistrationRequest, RegistrationResponse};
use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
use hightower_wireguard::transport::Server;
use reqwest::StatusCode;
use std::net::SocketAddr;
use std::path::PathBuf;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
11
/// Gateway base URL used when the caller does not supply one
/// (see `HightowerConnection::connect_with_auth_token`).
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";
/// Path appended to the gateway base URL to form the node registration endpoint.
const API_PATH: &str = "/api/nodes";
14
15/// Main connection to Hightower gateway with integrated WireGuard transport
16pub struct HightowerConnection {
17    transport: TransportServer,
18    node_id: String,
19    assigned_ip: String,
20    token: String,
21    endpoint: String,
22    #[allow(dead_code)] // Used when constructing from stored connection
23    gateway_url: String,
24    storage: Option<ConnectionStorage>,
25}
26
27impl HightowerConnection {
28    /// Connect to a Hightower gateway
29    ///
30    /// This method handles everything:
31    /// - Checks for existing stored connection and restores if available
32    /// - Otherwise: Generates WireGuard keypair, registers with gateway
33    /// - Creates transport server on 0.0.0.0:0
34    /// - Discovers network info via STUN using actual bound port
35    /// - Adds gateway as peer
36    /// - Persists connection info to storage (default: ~/.hightower-client/data)
37    ///
38    /// Returns a ready-to-use connection with working transport
39    pub async fn connect(
40        gateway_url: impl Into<String>,
41        auth_token: impl Into<String>,
42    ) -> Result<Self, ClientError> {
43        Self::connect_internal(gateway_url, auth_token, None, false).await
44    }
45
46    /// Connect without using persistent storage
47    pub async fn connect_ephemeral(
48        gateway_url: impl Into<String>,
49        auth_token: impl Into<String>,
50    ) -> Result<Self, ClientError> {
51        Self::connect_internal(gateway_url, auth_token, None, true).await
52    }
53
54    /// Connect with custom storage directory
55    pub async fn connect_with_storage(
56        gateway_url: impl Into<String>,
57        auth_token: impl Into<String>,
58        storage_dir: impl Into<PathBuf>,
59    ) -> Result<Self, ClientError> {
60        Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
61    }
62
63    /// Force a fresh registration even if stored connection exists
64    pub async fn connect_fresh(
65        gateway_url: impl Into<String>,
66        auth_token: impl Into<String>,
67    ) -> Result<Self, ClientError> {
68        let gateway_url = gateway_url.into();
69        let auth_token = auth_token.into();
70
71        // Delete any stored connection first
72        if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
73            let _ = storage.delete_connection();
74        }
75
76        Self::connect_internal(gateway_url, auth_token, None, false).await
77    }
78
79    async fn connect_internal(
80        gateway_url: impl Into<String>,
81        auth_token: impl Into<String>,
82        storage_dir: Option<PathBuf>,
83        ephemeral: bool,
84    ) -> Result<Self, ClientError> {
85        let gateway_url = gateway_url.into();
86        let auth_token = auth_token.into();
87
88        if gateway_url.is_empty() {
89            return Err(ClientError::Configuration(
90                "gateway_url cannot be empty".into(),
91            ));
92        }
93
94        if auth_token.is_empty() {
95            return Err(ClientError::Configuration(
96                "auth_token cannot be empty".into(),
97            ));
98        }
99
100        let endpoint = build_endpoint(&gateway_url)?;
101
102        // Initialize storage if not ephemeral
103        let storage = if ephemeral {
104            None
105        } else if let Some(dir) = storage_dir {
106            // Use custom storage directory
107            match ConnectionStorage::new(dir) {
108                Ok(s) => Some(s),
109                Err(e) => {
110                    warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
111                    None
112                }
113            }
114        } else {
115            // Use default gateway-specific storage
116            match ConnectionStorage::for_gateway(&gateway_url) {
117                Ok(s) => Some(s),
118                Err(e) => {
119                    warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
120                    None
121                }
122            }
123        };
124
125        // Check for existing stored connection
126        if let Some(ref storage) = storage {
127            if let Ok(Some(stored)) = storage.get_connection() {
128                info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
129
130                match Self::restore_from_stored(stored, storage.clone()).await {
131                    Ok(conn) => {
132                        info!(node_id = %conn.node_id, "Successfully restored connection from storage");
133                        return Ok(conn);
134                    }
135                    Err(e) => {
136                        warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
137                        // Continue to create fresh connection
138                    }
139                }
140            }
141        }
142
143        // No stored connection or restore failed - create fresh connection
144        info!("Creating fresh connection to gateway");
145
146        // 1. Generate WireGuard keypair
147        let (private_key, public_key) = dh_generate();
148        let public_key_hex = hex::encode(public_key);
149        let private_key_hex = hex::encode(private_key);
150
151        debug!("Generated WireGuard keypair");
152
153        // 2. Create WireGuard transport server on 0.0.0.0:0
154        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
155            ClientError::Configuration(format!("invalid bind address: {}", e))
156        })?;
157
158        let server = Server::new(bind_addr, private_key)
159            .await
160            .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
161
162        debug!("Created transport server");
163
164        // Spawn background processor
165        let server_clone = server.clone();
166        tokio::spawn(async move {
167            if let Err(e) = server_clone.run().await {
168                error!(error = ?e, "Transport server error");
169            }
170        });
171
172        // Spawn maintenance task
173        let server_maintenance = server.clone();
174        tokio::spawn(async move {
175            server_maintenance.run_maintenance().await
176        });
177
178        // Wait for transport to be ready
179        server
180            .wait_until_ready()
181            .await
182            .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
183
184        debug!("Transport server ready");
185
186        // 3. Discover network info using actual bound port
187        let local_addr = server
188            .local_addr()
189            .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
190
191        let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
192            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
193
194        debug!(
195            public_ip = %network_info.public_ip,
196            public_port = network_info.public_port,
197            local_ip = %network_info.local_ip,
198            local_port = network_info.local_port,
199            "Discovered network information"
200        );
201
202        // 4. Register with gateway
203        let registration = register_with_gateway(
204            &endpoint,
205            &auth_token,
206            &public_key_hex,
207            &network_info,
208        )
209        .await?;
210
211        debug!(
212            node_id = %registration.node_id,
213            assigned_ip = %registration.assigned_ip,
214            "Registered with gateway"
215        );
216
217        // 5. Add gateway as peer
218        let gateway_public_key_bytes = hex::decode(&registration.gateway_public_key_hex)
219            .map_err(|e| {
220                ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
221            })?;
222
223        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
224            .as_slice()
225            .try_into()
226            .map_err(|e| {
227                ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
228            })?;
229
230        server
231            .add_peer(gateway_public_key, None)
232            .await
233            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
234
235        debug!("Added gateway as peer");
236
237        // 6. Store connection info
238        if let Some(ref storage) = storage {
239            let now = current_timestamp();
240            let stored = StoredConnection {
241                node_id: registration.node_id.clone(),
242                token: registration.token.clone(),
243                gateway_url: gateway_url.clone(),
244                assigned_ip: registration.assigned_ip.clone(),
245                private_key_hex,
246                public_key_hex,
247                gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
248                created_at: now,
249                last_connected_at: now,
250            };
251
252            if let Err(e) = storage.store_connection(&stored) {
253                warn!(error = ?e, "Failed to persist connection to storage");
254            } else {
255                debug!("Persisted connection to storage");
256            }
257        }
258
259        Ok(Self {
260            transport: TransportServer::new(server),
261            node_id: registration.node_id,
262            assigned_ip: registration.assigned_ip,
263            token: registration.token,
264            endpoint,
265            gateway_url,
266            storage,
267        })
268    }
269
270    /// Restore a connection from stored credentials
271    async fn restore_from_stored(
272        stored: StoredConnection,
273        storage: ConnectionStorage,
274    ) -> Result<Self, ClientError> {
275        let endpoint = build_endpoint(&stored.gateway_url)?;
276
277        // Parse stored keys
278        let private_key_bytes = hex::decode(&stored.private_key_hex)
279            .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
280        let private_key: PrivateKey = private_key_bytes
281            .as_slice()
282            .try_into()
283            .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
284
285        let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
286            .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
287        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
288            .as_slice()
289            .try_into()
290            .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
291
292        // Create transport server with stored private key
293        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
294            ClientError::Configuration(format!("invalid bind address: {}", e))
295        })?;
296
297        let server = Server::new(bind_addr, private_key)
298            .await
299            .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
300
301        debug!("Created transport server with stored keys");
302
303        // Spawn background processor
304        let server_clone = server.clone();
305        tokio::spawn(async move {
306            if let Err(e) = server_clone.run().await {
307                error!(error = ?e, "Transport server error");
308            }
309        });
310
311        // Spawn maintenance task
312        let server_maintenance = server.clone();
313        tokio::spawn(async move {
314            server_maintenance.run_maintenance().await
315        });
316
317        // Wait for transport to be ready
318        server
319            .wait_until_ready()
320            .await
321            .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
322
323        // Discover current network info (may have changed)
324        let local_addr = server
325            .local_addr()
326            .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
327
328        let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
329            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
330
331        debug!("Rediscovered network information");
332
333        // Add gateway as peer
334        server
335            .add_peer(gateway_public_key, None)
336            .await
337            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
338
339        debug!("Added gateway as peer");
340
341        // Update last_connected_at timestamp
342        if let Err(e) = storage.update_last_connected() {
343            warn!(error = ?e, "Failed to update last_connected timestamp");
344        }
345
346        Ok(Self {
347            transport: TransportServer::new(server),
348            node_id: stored.node_id,
349            assigned_ip: stored.assigned_ip,
350            token: stored.token,
351            endpoint,
352            gateway_url: stored.gateway_url,
353            storage: Some(storage),
354        })
355    }
356
357    /// Connect using default gateway (http://127.0.0.1:8008)
358    pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
359        Self::connect(DEFAULT_GATEWAY, auth_token).await
360    }
361
362    /// Get the node ID assigned by the gateway
363    pub fn node_id(&self) -> &str {
364        &self.node_id
365    }
366
367    /// Get the IP address assigned by the gateway
368    pub fn assigned_ip(&self) -> &str {
369        &self.assigned_ip
370    }
371
372    /// Get the transport for sending/receiving data
373    pub fn transport(&self) -> &TransportServer {
374        &self.transport
375    }
376
377    /// Disconnect from the gateway and deregister
378    pub async fn disconnect(self) -> Result<(), ClientError> {
379        let url = format!("{}/{}", self.endpoint, self.token);
380
381        let client = reqwest::Client::new();
382        let response = client.delete(&url).send().await?;
383
384        let status = response.status();
385
386        if status.is_success() || status == StatusCode::NO_CONTENT {
387            debug!("Successfully deregistered from gateway");
388
389            // Remove stored connection
390            if let Some(storage) = self.storage {
391                if let Err(e) = storage.delete_connection() {
392                    warn!(error = ?e, "Failed to delete stored connection");
393                } else {
394                    debug!("Deleted stored connection");
395                }
396            }
397
398            Ok(())
399        } else {
400            let message = response
401                .text()
402                .await
403                .unwrap_or_else(|_| "unknown error".to_string());
404            Err(ClientError::GatewayError {
405                status: status.as_u16(),
406                message,
407            })
408        }
409    }
410}
411
412fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
413    let gateway_url = gateway_url.trim();
414
415    if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
416        return Err(ClientError::Configuration(
417            "gateway_url must start with http:// or https://".into(),
418        ));
419    }
420
421    Ok(format!(
422        "{}{}",
423        gateway_url.trim_end_matches('/'),
424        API_PATH
425    ))
426}
427
428async fn register_with_gateway(
429    endpoint: &str,
430    auth_token: &str,
431    public_key_hex: &str,
432    network_info: &crate::types::NetworkInfo,
433) -> Result<RegistrationResponse, ClientError> {
434    let payload = RegistrationRequest {
435        public_key_hex,
436        public_ip: Some(network_info.public_ip.as_str()),
437        public_port: Some(network_info.public_port),
438        local_ip: Some(network_info.local_ip.as_str()),
439        local_port: Some(network_info.local_port),
440    };
441
442    let client = reqwest::Client::new();
443    let response = client
444        .post(endpoint)
445        .header("X-HT-Auth", auth_token)
446        .json(&payload)
447        .send()
448        .await?;
449
450    let status = response.status();
451
452    if status.is_success() {
453        let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
454            ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
455        })?;
456
457        Ok(registration_response)
458    } else {
459        let message = response
460            .text()
461            .await
462            .unwrap_or_else(|_| "unknown error".to_string());
463        Err(ClientError::GatewayError {
464            status: status.as_u16(),
465            message,
466        })
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_endpoint_requires_scheme() {
        let result = build_endpoint("gateway.example.com:8008");
        assert!(matches!(result, Err(ClientError::Configuration(_))));
    }

    /// Only `http://` and `https://` are accepted; other schemes are configuration errors.
    #[test]
    fn build_endpoint_rejects_unsupported_scheme() {
        let result = build_endpoint("ftp://gateway.example.com");
        assert!(matches!(result, Err(ClientError::Configuration(_))));
    }

    #[test]
    fn build_endpoint_accepts_http() {
        let endpoint = build_endpoint("http://gateway.example.com:8008").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    #[test]
    fn build_endpoint_accepts_https() {
        let endpoint = build_endpoint("https://gateway.example.com:8443").unwrap();
        assert_eq!(endpoint, "https://gateway.example.com:8443/api/nodes");
    }

    #[test]
    fn build_endpoint_strips_trailing_slash() {
        let endpoint = build_endpoint("http://gateway.example.com:8008/").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    /// Leading/trailing whitespace (e.g. from config files or env vars) is ignored.
    #[test]
    fn build_endpoint_trims_surrounding_whitespace() {
        let endpoint = build_endpoint("  https://gateway.example.com:8443/ \n").unwrap();
        assert_eq!(endpoint, "https://gateway.example.com:8443/api/nodes");
    }
}