hightower_client/
connection.rs

1use crate::error::ClientError;
2use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
3use crate::transport::TransportServer;
4use crate::types::{PeerInfo, RegistrationRequest, RegistrationResponse};
5use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
6use hightower_wireguard::connection::Connection;
7use reqwest::StatusCode;
8use std::net::SocketAddr;
9use std::path::PathBuf;
10use tracing::{debug, info, warn};
11
/// Gateway base URL used by `HightowerConnection::connect_with_auth_token`.
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";
/// Path that `build_endpoint` appends to the gateway base URL to form the node API endpoint.
const API_PATH: &str = "/api/nodes";
14
/// Main connection to Hightower gateway with integrated WireGuard transport
pub struct HightowerConnection {
    /// Transport wrapper built from the WireGuard `Connection` (`TransportServer::new`).
    transport: TransportServer,
    /// Node ID taken from the registration response, or from storage when restored.
    node_id: String,
    /// IP address assigned by the gateway (registration response or stored value).
    assigned_ip: String,
    /// Token from the registration response; `disconnect` appends it to `endpoint`
    /// to build the deregistration URL.
    token: String,
    /// Node API endpoint produced by `build_endpoint` (gateway URL + `/api/nodes`).
    endpoint: String,
    /// Gateway base URL as supplied by the caller (or as read back from storage).
    gateway_url: String,
    /// Resolved socket address of the gateway's WireGuard endpoint (gateway host, port 51820).
    gateway_endpoint: SocketAddr,
    /// Gateway public key decoded from its hex form (registration response or stored value).
    gateway_public_key: PublicKey25519,
    /// Persistent storage handle; `None` for ephemeral connections or when storage
    /// could not be initialized.
    storage: Option<ConnectionStorage>,
}
27
28impl HightowerConnection {
29    /// Connect to a Hightower gateway
30    ///
31    /// This method handles everything:
32    /// - Checks for existing stored connection and restores if available
33    /// - Otherwise: Generates WireGuard keypair, registers with gateway
34    /// - Creates transport server on 0.0.0.0:0
35    /// - Discovers network info via STUN using actual bound port
36    /// - Adds gateway as peer
37    /// - Persists connection info to storage (default: ~/.hightower-client/data)
38    ///
39    /// Returns a ready-to-use connection with working transport
40    pub async fn connect(
41        gateway_url: impl Into<String>,
42        auth_token: impl Into<String>,
43    ) -> Result<Self, ClientError> {
44        Self::connect_internal(gateway_url, auth_token, None, false).await
45    }
46
47    /// Connect without using persistent storage
48    pub async fn connect_ephemeral(
49        gateway_url: impl Into<String>,
50        auth_token: impl Into<String>,
51    ) -> Result<Self, ClientError> {
52        Self::connect_internal(gateway_url, auth_token, None, true).await
53    }
54
55    /// Connect with custom storage directory
56    pub async fn connect_with_storage(
57        gateway_url: impl Into<String>,
58        auth_token: impl Into<String>,
59        storage_dir: impl Into<PathBuf>,
60    ) -> Result<Self, ClientError> {
61        Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
62    }
63
64    /// Force a fresh registration even if stored connection exists
65    pub async fn connect_fresh(
66        gateway_url: impl Into<String>,
67        auth_token: impl Into<String>,
68    ) -> Result<Self, ClientError> {
69        let gateway_url = gateway_url.into();
70        let auth_token = auth_token.into();
71
72        // Delete any stored connection first
73        if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
74            let _ = storage.delete_connection();
75        }
76
77        Self::connect_internal(gateway_url, auth_token, None, false).await
78    }
79
80    async fn connect_internal(
81        gateway_url: impl Into<String>,
82        auth_token: impl Into<String>,
83        storage_dir: Option<PathBuf>,
84        ephemeral: bool,
85    ) -> Result<Self, ClientError> {
86        let gateway_url = gateway_url.into();
87        let auth_token = auth_token.into();
88
89        if gateway_url.is_empty() {
90            return Err(ClientError::Configuration(
91                "gateway_url cannot be empty".into(),
92            ));
93        }
94
95        if auth_token.is_empty() {
96            return Err(ClientError::Configuration(
97                "auth_token cannot be empty".into(),
98            ));
99        }
100
101        let endpoint = build_endpoint(&gateway_url)?;
102
103        // Initialize storage if not ephemeral
104        let storage = if ephemeral {
105            None
106        } else if let Some(dir) = storage_dir {
107            // Use custom storage directory
108            match ConnectionStorage::new(dir) {
109                Ok(s) => Some(s),
110                Err(e) => {
111                    warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
112                    None
113                }
114            }
115        } else {
116            // Use default gateway-specific storage
117            match ConnectionStorage::for_gateway(&gateway_url) {
118                Ok(s) => Some(s),
119                Err(e) => {
120                    warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
121                    None
122                }
123            }
124        };
125
126        // Check for existing stored connection
127        if let Some(ref storage) = storage {
128            if let Ok(Some(stored)) = storage.get_connection() {
129                info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
130
131                match Self::restore_from_stored(stored, storage.clone()).await {
132                    Ok(conn) => {
133                        info!(node_id = %conn.node_id, "Successfully restored connection from storage");
134                        return Ok(conn);
135                    }
136                    Err(e) => {
137                        warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
138                        // Continue to create fresh connection
139                    }
140                }
141            }
142        }
143
144        // No stored connection or restore failed - create fresh connection
145        info!("Creating fresh connection to gateway");
146
147        // 1. Generate WireGuard keypair
148        let (private_key, public_key) = dh_generate();
149        let public_key_hex = hex::encode(public_key);
150        let private_key_hex = hex::encode(private_key);
151
152        debug!("Generated WireGuard keypair");
153
154        // 2. Create WireGuard transport server on 0.0.0.0:0
155        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
156            ClientError::Configuration(format!("invalid bind address: {}", e))
157        })?;
158
159        let connection = Connection::new(bind_addr, private_key)
160            .await
161            .map_err(|e| ClientError::Transport(format!("failed to create transport connection: {}", e)))?;
162
163        debug!("Created transport connection");
164
165        // 3. Discover network info using actual bound port
166        let local_addr = connection.local_addr();
167
168        let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
169            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
170
171        debug!(
172            public_ip = %network_info.public_ip,
173            public_port = network_info.public_port,
174            local_ip = %network_info.local_ip,
175            local_port = network_info.local_port,
176            "Discovered network information"
177        );
178
179        // 4. Register with gateway
180        let registration = register_with_gateway(
181            &endpoint,
182            &auth_token,
183            &public_key_hex,
184            &network_info,
185        )
186        .await?;
187
188        debug!(
189            node_id = %registration.node_id,
190            assigned_ip = %registration.assigned_ip,
191            "Registered with gateway"
192        );
193
194        // 5. Add gateway as peer
195        let gateway_public_key_bytes = hex::decode(&registration.gateway_public_key_hex)
196            .map_err(|e| {
197                ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
198            })?;
199
200        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
201            .as_slice()
202            .try_into()
203            .map_err(|e| {
204                ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
205            })?;
206
207        connection
208            .add_peer(gateway_public_key, None)
209            .await
210            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
211
212        debug!("Added gateway as peer");
213
214        // 6. Store connection info
215        if let Some(ref storage) = storage {
216            let now = current_timestamp();
217            let stored = StoredConnection {
218                node_id: registration.node_id.clone(),
219                token: registration.token.clone(),
220                gateway_url: gateway_url.clone(),
221                assigned_ip: registration.assigned_ip.clone(),
222                private_key_hex,
223                public_key_hex,
224                gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
225                created_at: now,
226                last_connected_at: now,
227            };
228
229            if let Err(e) = storage.store_connection(&stored) {
230                warn!(error = ?e, "Failed to persist connection to storage");
231            } else {
232                debug!("Persisted connection to storage");
233            }
234        }
235
236        // Parse gateway WireGuard endpoint from the gateway URL
237        let gateway_endpoint = parse_gateway_wireguard_endpoint(&gateway_url).await?;
238
239        Ok(Self {
240            transport: TransportServer::new(connection),
241            node_id: registration.node_id,
242            assigned_ip: registration.assigned_ip,
243            token: registration.token,
244            endpoint,
245            gateway_url,
246            gateway_endpoint,
247            gateway_public_key,
248            storage,
249        })
250    }
251
252    /// Restore a connection from stored credentials
253    async fn restore_from_stored(
254        stored: StoredConnection,
255        storage: ConnectionStorage,
256    ) -> Result<Self, ClientError> {
257        let endpoint = build_endpoint(&stored.gateway_url)?;
258
259        // Parse stored keys
260        let private_key_bytes = hex::decode(&stored.private_key_hex)
261            .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
262        let private_key: PrivateKey = private_key_bytes
263            .as_slice()
264            .try_into()
265            .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
266
267        let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
268            .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
269        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
270            .as_slice()
271            .try_into()
272            .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
273
274        // Create transport connection with stored private key
275        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
276            ClientError::Configuration(format!("invalid bind address: {}", e))
277        })?;
278
279        let connection = Connection::new(bind_addr, private_key)
280            .await
281            .map_err(|e| ClientError::Transport(format!("failed to create transport connection: {}", e)))?;
282
283        debug!("Created transport connection with stored keys");
284
285        // Discover current network info (may have changed)
286        let local_addr = connection.local_addr();
287
288        let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
289            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
290
291        debug!("Rediscovered network information");
292
293        // Add gateway as peer
294        connection
295            .add_peer(gateway_public_key, None)
296            .await
297            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
298
299        debug!("Added gateway as peer");
300
301        // Update last_connected_at timestamp
302        if let Err(e) = storage.update_last_connected() {
303            warn!(error = ?e, "Failed to update last_connected timestamp");
304        }
305
306        // Parse gateway WireGuard endpoint from the gateway URL
307        let gateway_endpoint = parse_gateway_wireguard_endpoint(&stored.gateway_url).await?;
308
309        Ok(Self {
310            transport: TransportServer::new(connection),
311            node_id: stored.node_id,
312            assigned_ip: stored.assigned_ip,
313            token: stored.token,
314            endpoint,
315            gateway_url: stored.gateway_url,
316            gateway_endpoint,
317            gateway_public_key,
318            storage: Some(storage),
319        })
320    }
321
322    /// Connect using default gateway (http://127.0.0.1:8008)
323    pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
324        Self::connect(DEFAULT_GATEWAY, auth_token).await
325    }
326
327    /// Get the node ID assigned by the gateway
328    pub fn node_id(&self) -> &str {
329        &self.node_id
330    }
331
332    /// Get the IP address assigned by the gateway
333    pub fn assigned_ip(&self) -> &str {
334        &self.assigned_ip
335    }
336
337    /// Get the transport for sending/receiving data
338    pub fn transport(&self) -> &TransportServer {
339        &self.transport
340    }
341
342    /// Ping the gateway over WireGuard to verify connectivity
343    pub async fn ping_gateway(&self) -> Result<(), ClientError> {
344        debug!("Pinging gateway over WireGuard");
345
346        // Connect to the gateway's WireGuard endpoint
347        let mut stream = self
348            .transport
349            .connection()
350            .connect(self.gateway_endpoint, self.gateway_public_key)
351            .await
352            .map_err(|e| ClientError::Transport(format!("failed to connect to gateway: {}", e)))?;
353
354        debug!("WireGuard connection established to gateway");
355
356        // Send HTTP GET request to /ping
357        let request = b"GET /ping HTTP/1.1\r\nHost: gateway\r\nConnection: close\r\n\r\n";
358        stream.send(request)
359            .await
360            .map_err(|e| ClientError::Transport(format!("failed to send ping request: {}", e)))?;
361
362        // Receive response
363        let response_bytes = stream
364            .recv()
365            .await
366            .map_err(|e| ClientError::Transport(format!("failed to receive ping response: {}", e)))?;
367
368        let response = String::from_utf8_lossy(&response_bytes);
369
370        if response.contains("200 OK") && response.contains("Pong") {
371            debug!("Successfully pinged gateway");
372            Ok(())
373        } else {
374            Err(ClientError::GatewayError {
375                status: 500,
376                message: format!("Unexpected ping response: {}", response),
377            })
378        }
379    }
380
381    /// Get peer information from the gateway
382    ///
383    /// Accepts either a node_id (e.g., "ht-festive-penguin-abc123") or
384    /// an assigned IP (e.g., "100.64.0.5")
385    pub async fn get_peer_info(&self, node_id_or_ip: &str) -> Result<PeerInfo, ClientError> {
386        debug!(peer = %node_id_or_ip, "Fetching peer info from gateway");
387
388        // Query gateway API: GET /api/peers/{node_id_or_ip}
389        let url = format!("{}/api/peers/{}", self.gateway_url.trim_end_matches('/'), node_id_or_ip);
390
391        let client = reqwest::Client::new();
392        let response = client
393            .get(&url)
394            .send()
395            .await?;
396
397        let status = response.status();
398
399        if status.is_success() {
400            let peer_info: PeerInfo = response.json().await.map_err(|e| {
401                ClientError::InvalidResponse(format!("failed to parse peer info: {}", e))
402            })?;
403
404            debug!(
405                node_id = %peer_info.node_id,
406                assigned_ip = %peer_info.assigned_ip,
407                "Retrieved peer info from gateway"
408            );
409
410            Ok(peer_info)
411        } else {
412            let message = response
413                .text()
414                .await
415                .unwrap_or_else(|_| "unknown error".to_string());
416            Err(ClientError::GatewayError {
417                status: status.as_u16(),
418                message: format!("Failed to get peer info: {}", message),
419            })
420        }
421    }
422
423    /// Dial a peer by node ID or assigned IP
424    ///
425    /// This method:
426    /// 1. Fetches peer info from gateway (public key, endpoint, etc.)
427    /// 2. Adds peer to WireGuard if not already present
428    /// 3. Dials the peer over the WireGuard network
429    ///
430    /// # Arguments
431    /// * `peer` - Node ID (e.g., "ht-festive-penguin") or assigned IP (e.g., "100.64.0.5")
432    /// * `port` - Port to connect to on the peer
433    ///
434    /// # Example
435    /// ```no_run
436    /// # async fn example(conn: &hightower_client::HightowerConnection) -> Result<(), Box<dyn std::error::Error>> {
437    /// let connection = conn.dial("ht-festive-penguin-abc123", 8080).await?;
438    /// connection.send(b"Hello, peer!").await?;
439    /// # Ok(())
440    /// # }
441    /// ```
442    pub async fn dial(&self, peer: &str, port: u16) -> Result<hightower_wireguard::connection::Stream, ClientError> {
443        // 1. Get peer info from gateway
444        let peer_info = self.get_peer_info(peer).await?;
445
446        // 2. Parse peer's public key
447        let peer_public_key_bytes = hex::decode(&peer_info.public_key_hex)
448            .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key hex: {}", e)))?;
449
450        let peer_public_key: PublicKey25519 = peer_public_key_bytes
451            .as_slice()
452            .try_into()
453            .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key format: {:?}", e)))?;
454
455        // 3. Add peer to WireGuard (idempotent - safe to call multiple times)
456        self.transport
457            .connection()
458            .add_peer(peer_public_key, peer_info.endpoint)
459            .await
460            .map_err(|e| ClientError::Transport(format!("failed to add peer: {}", e)))?;
461
462        debug!(
463            peer_id = %peer_info.node_id,
464            peer_ip = %peer_info.assigned_ip,
465            port = port,
466            "Added peer and connecting"
467        );
468
469        // 4. Connect using the peer's assigned IP on the WireGuard network
470        let peer_addr: SocketAddr = format!("{}:{}", peer_info.assigned_ip, port)
471            .parse()
472            .map_err(|e| ClientError::Transport(format!("invalid peer address: {}", e)))?;
473
474        let stream = self
475            .transport
476            .connection()
477            .connect(peer_addr, peer_public_key)
478            .await
479            .map_err(|e| ClientError::Transport(format!("failed to connect to peer: {}", e)))?;
480
481        debug!(
482            peer_id = %peer_info.node_id,
483            addr = %peer_addr,
484            "Successfully connected to peer"
485        );
486
487        Ok(stream)
488    }
489
490    /// Disconnect from the gateway and deregister
491    pub async fn disconnect(self) -> Result<(), ClientError> {
492        let url = format!("{}/{}", self.endpoint, self.token);
493
494        let client = reqwest::Client::new();
495        let response = client.delete(&url).send().await?;
496
497        let status = response.status();
498
499        if status.is_success() || status == StatusCode::NO_CONTENT {
500            debug!("Successfully deregistered from gateway");
501
502            // Remove stored connection
503            if let Some(storage) = self.storage {
504                if let Err(e) = storage.delete_connection() {
505                    warn!(error = ?e, "Failed to delete stored connection");
506                } else {
507                    debug!("Deleted stored connection");
508                }
509            }
510
511            Ok(())
512        } else {
513            let message = response
514                .text()
515                .await
516                .unwrap_or_else(|_| "unknown error".to_string());
517            Err(ClientError::GatewayError {
518                status: status.as_u16(),
519                message,
520            })
521        }
522    }
523}
524
525fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
526    let gateway_url = gateway_url.trim();
527
528    if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
529        return Err(ClientError::Configuration(
530            "gateway_url must start with http:// or https://".into(),
531        ));
532    }
533
534    Ok(format!(
535        "{}{}",
536        gateway_url.trim_end_matches('/'),
537        API_PATH
538    ))
539}
540
541async fn parse_gateway_wireguard_endpoint(gateway_url: &str) -> Result<SocketAddr, ClientError> {
542    let parsed_url = url::Url::parse(gateway_url)
543        .map_err(|e| ClientError::Configuration(format!("invalid gateway URL: {}", e)))?;
544
545    let host = parsed_url.host_str()
546        .ok_or_else(|| ClientError::Configuration("gateway URL has no host".into()))?;
547
548    // Construct WireGuard endpoint using the gateway's host and standard WireGuard port
549    let endpoint_str = format!("{}:51820", host);
550
551    // Use tokio's DNS resolution to handle both hostnames and IP addresses
552    let mut addrs = tokio::net::lookup_host(&endpoint_str)
553        .await
554        .map_err(|e| ClientError::Configuration(format!("failed to resolve gateway endpoint {}: {}", endpoint_str, e)))?;
555
556    addrs.next()
557        .ok_or_else(|| ClientError::Configuration(format!("no addresses found for gateway endpoint: {}", endpoint_str)))
558}
559
560async fn register_with_gateway(
561    endpoint: &str,
562    auth_token: &str,
563    public_key_hex: &str,
564    network_info: &crate::types::NetworkInfo,
565) -> Result<RegistrationResponse, ClientError> {
566    let payload = RegistrationRequest {
567        public_key_hex,
568        public_ip: Some(network_info.public_ip.as_str()),
569        public_port: Some(network_info.public_port),
570        local_ip: Some(network_info.local_ip.as_str()),
571        local_port: Some(network_info.local_port),
572    };
573
574    let client = reqwest::Client::new();
575    let response = client
576        .post(endpoint)
577        .header("X-HT-Auth", auth_token)
578        .json(&payload)
579        .send()
580        .await?;
581
582    let status = response.status();
583
584    if status.is_success() {
585        let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
586            ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
587        })?;
588
589        Ok(registration_response)
590    } else {
591        let message = response
592            .text()
593            .await
594            .unwrap_or_else(|_| "unknown error".to_string());
595        Err(ClientError::GatewayError {
596            status: status.as_u16(),
597            message,
598        })
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// A gateway URL without an http(s) scheme is a configuration error.
    #[test]
    fn build_endpoint_requires_scheme() {
        assert!(matches!(
            build_endpoint("gateway.example.com:8008"),
            Err(ClientError::Configuration(_))
        ));
    }

    /// Plain http URLs get the node API path appended.
    #[test]
    fn build_endpoint_accepts_http() {
        let gateway = "http://gateway.example.com:8008";
        assert_eq!(
            build_endpoint(gateway).unwrap(),
            "http://gateway.example.com:8008/api/nodes"
        );
    }

    /// https URLs get the node API path appended.
    #[test]
    fn build_endpoint_accepts_https() {
        let gateway = "https://gateway.example.com:8443";
        assert_eq!(
            build_endpoint(gateway).unwrap(),
            "https://gateway.example.com:8443/api/nodes"
        );
    }

    /// A trailing slash on the gateway URL does not produce a double slash.
    #[test]
    fn build_endpoint_strips_trailing_slash() {
        let gateway = "http://gateway.example.com:8008/";
        assert_eq!(
            build_endpoint(gateway).unwrap(),
            "http://gateway.example.com:8008/api/nodes"
        );
    }
}