// hightower_client/connection.rs

1use crate::error::ClientError;
2use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
3use crate::transport::TransportServer;
4use crate::types::{PeerInfo, RegistrationRequest, RegistrationResponse};
5use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
6use hightower_wireguard::transport::Server;
7use reqwest::StatusCode;
8use std::net::SocketAddr;
9use std::path::PathBuf;
10use tracing::{debug, error, info, warn};
11
/// Gateway base URL used by [`HightowerConnection::connect_with_auth_token`].
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";

/// Path appended to a gateway base URL to form the node endpoint that is used both to
/// register (`POST`) and to deregister (`DELETE <endpoint>/<token>`).
const API_PATH: &str = "/api/nodes";
14
15/// Main connection to Hightower gateway with integrated WireGuard transport
16pub struct HightowerConnection {
17    transport: TransportServer,
18    node_id: String,
19    assigned_ip: String,
20    token: String,
21    endpoint: String,
22    gateway_url: String,
23    gateway_endpoint: SocketAddr,
24    gateway_public_key: PublicKey25519,
25    storage: Option<ConnectionStorage>,
26}
27
28impl HightowerConnection {
29    /// Connect to a Hightower gateway
30    ///
31    /// This method handles everything:
32    /// - Checks for existing stored connection and restores if available
33    /// - Otherwise: Generates WireGuard keypair, registers with gateway
34    /// - Creates transport server on 0.0.0.0:0
35    /// - Discovers network info via STUN using actual bound port
36    /// - Adds gateway as peer
37    /// - Persists connection info to storage (default: ~/.hightower-client/data)
38    ///
39    /// Returns a ready-to-use connection with working transport
40    pub async fn connect(
41        gateway_url: impl Into<String>,
42        auth_token: impl Into<String>,
43    ) -> Result<Self, ClientError> {
44        Self::connect_internal(gateway_url, auth_token, None, false).await
45    }
46
47    /// Connect without using persistent storage
48    pub async fn connect_ephemeral(
49        gateway_url: impl Into<String>,
50        auth_token: impl Into<String>,
51    ) -> Result<Self, ClientError> {
52        Self::connect_internal(gateway_url, auth_token, None, true).await
53    }
54
55    /// Connect with custom storage directory
56    pub async fn connect_with_storage(
57        gateway_url: impl Into<String>,
58        auth_token: impl Into<String>,
59        storage_dir: impl Into<PathBuf>,
60    ) -> Result<Self, ClientError> {
61        Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
62    }
63
64    /// Force a fresh registration even if stored connection exists
65    pub async fn connect_fresh(
66        gateway_url: impl Into<String>,
67        auth_token: impl Into<String>,
68    ) -> Result<Self, ClientError> {
69        let gateway_url = gateway_url.into();
70        let auth_token = auth_token.into();
71
72        // Delete any stored connection first
73        if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
74            let _ = storage.delete_connection();
75        }
76
77        Self::connect_internal(gateway_url, auth_token, None, false).await
78    }
79
80    async fn connect_internal(
81        gateway_url: impl Into<String>,
82        auth_token: impl Into<String>,
83        storage_dir: Option<PathBuf>,
84        ephemeral: bool,
85    ) -> Result<Self, ClientError> {
86        let gateway_url = gateway_url.into();
87        let auth_token = auth_token.into();
88
89        if gateway_url.is_empty() {
90            return Err(ClientError::Configuration(
91                "gateway_url cannot be empty".into(),
92            ));
93        }
94
95        if auth_token.is_empty() {
96            return Err(ClientError::Configuration(
97                "auth_token cannot be empty".into(),
98            ));
99        }
100
101        let endpoint = build_endpoint(&gateway_url)?;
102
103        // Initialize storage if not ephemeral
104        let storage = if ephemeral {
105            None
106        } else if let Some(dir) = storage_dir {
107            // Use custom storage directory
108            match ConnectionStorage::new(dir) {
109                Ok(s) => Some(s),
110                Err(e) => {
111                    warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
112                    None
113                }
114            }
115        } else {
116            // Use default gateway-specific storage
117            match ConnectionStorage::for_gateway(&gateway_url) {
118                Ok(s) => Some(s),
119                Err(e) => {
120                    warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
121                    None
122                }
123            }
124        };
125
126        // Check for existing stored connection
127        if let Some(ref storage) = storage {
128            if let Ok(Some(stored)) = storage.get_connection() {
129                info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
130
131                match Self::restore_from_stored(stored, storage.clone()).await {
132                    Ok(conn) => {
133                        info!(node_id = %conn.node_id, "Successfully restored connection from storage");
134                        return Ok(conn);
135                    }
136                    Err(e) => {
137                        warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
138                        // Continue to create fresh connection
139                    }
140                }
141            }
142        }
143
144        // No stored connection or restore failed - create fresh connection
145        info!("Creating fresh connection to gateway");
146
147        // 1. Generate WireGuard keypair
148        let (private_key, public_key) = dh_generate();
149        let public_key_hex = hex::encode(public_key);
150        let private_key_hex = hex::encode(private_key);
151
152        debug!("Generated WireGuard keypair");
153
154        // 2. Create WireGuard transport server on 0.0.0.0:0
155        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
156            ClientError::Configuration(format!("invalid bind address: {}", e))
157        })?;
158
159        let server = Server::new(bind_addr, private_key)
160            .await
161            .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
162
163        debug!("Created transport server");
164
165        // Spawn background processor
166        let server_clone = server.clone();
167        tokio::spawn(async move {
168            if let Err(e) = server_clone.run().await {
169                error!(error = ?e, "Transport server error");
170            }
171        });
172
173        // Spawn maintenance task
174        let server_maintenance = server.clone();
175        tokio::spawn(async move {
176            server_maintenance.run_maintenance().await
177        });
178
179        // Wait for transport to be ready
180        server
181            .wait_until_ready()
182            .await
183            .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
184
185        debug!("Transport server ready");
186
187        // 3. Discover network info using actual bound port
188        let local_addr = server
189            .local_addr()
190            .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
191
192        let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
193            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
194
195        debug!(
196            public_ip = %network_info.public_ip,
197            public_port = network_info.public_port,
198            local_ip = %network_info.local_ip,
199            local_port = network_info.local_port,
200            "Discovered network information"
201        );
202
203        // 4. Register with gateway
204        let registration = register_with_gateway(
205            &endpoint,
206            &auth_token,
207            &public_key_hex,
208            &network_info,
209        )
210        .await?;
211
212        debug!(
213            node_id = %registration.node_id,
214            assigned_ip = %registration.assigned_ip,
215            "Registered with gateway"
216        );
217
218        // 5. Add gateway as peer
219        let gateway_public_key_bytes = hex::decode(&registration.gateway_public_key_hex)
220            .map_err(|e| {
221                ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
222            })?;
223
224        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
225            .as_slice()
226            .try_into()
227            .map_err(|e| {
228                ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
229            })?;
230
231        server
232            .add_peer(gateway_public_key, None)
233            .await
234            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
235
236        debug!("Added gateway as peer");
237
238        // 6. Store connection info
239        if let Some(ref storage) = storage {
240            let now = current_timestamp();
241            let stored = StoredConnection {
242                node_id: registration.node_id.clone(),
243                token: registration.token.clone(),
244                gateway_url: gateway_url.clone(),
245                assigned_ip: registration.assigned_ip.clone(),
246                private_key_hex,
247                public_key_hex,
248                gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
249                created_at: now,
250                last_connected_at: now,
251            };
252
253            if let Err(e) = storage.store_connection(&stored) {
254                warn!(error = ?e, "Failed to persist connection to storage");
255            } else {
256                debug!("Persisted connection to storage");
257            }
258        }
259
260        // Gateway WireGuard endpoint (hardcoded for now - could be returned by registration)
261        let gateway_endpoint: SocketAddr = "127.0.0.1:51820".parse()
262            .map_err(|e| ClientError::Configuration(format!("invalid gateway endpoint: {}", e)))?;
263
264        Ok(Self {
265            transport: TransportServer::new(server),
266            node_id: registration.node_id,
267            assigned_ip: registration.assigned_ip,
268            token: registration.token,
269            endpoint,
270            gateway_url,
271            gateway_endpoint,
272            gateway_public_key,
273            storage,
274        })
275    }
276
277    /// Restore a connection from stored credentials
278    async fn restore_from_stored(
279        stored: StoredConnection,
280        storage: ConnectionStorage,
281    ) -> Result<Self, ClientError> {
282        let endpoint = build_endpoint(&stored.gateway_url)?;
283
284        // Parse stored keys
285        let private_key_bytes = hex::decode(&stored.private_key_hex)
286            .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
287        let private_key: PrivateKey = private_key_bytes
288            .as_slice()
289            .try_into()
290            .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
291
292        let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
293            .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
294        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
295            .as_slice()
296            .try_into()
297            .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
298
299        // Create transport server with stored private key
300        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
301            ClientError::Configuration(format!("invalid bind address: {}", e))
302        })?;
303
304        let server = Server::new(bind_addr, private_key)
305            .await
306            .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
307
308        debug!("Created transport server with stored keys");
309
310        // Spawn background processor
311        let server_clone = server.clone();
312        tokio::spawn(async move {
313            if let Err(e) = server_clone.run().await {
314                error!(error = ?e, "Transport server error");
315            }
316        });
317
318        // Spawn maintenance task
319        let server_maintenance = server.clone();
320        tokio::spawn(async move {
321            server_maintenance.run_maintenance().await
322        });
323
324        // Wait for transport to be ready
325        server
326            .wait_until_ready()
327            .await
328            .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
329
330        // Discover current network info (may have changed)
331        let local_addr = server
332            .local_addr()
333            .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
334
335        let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
336            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
337
338        debug!("Rediscovered network information");
339
340        // Add gateway as peer
341        server
342            .add_peer(gateway_public_key, None)
343            .await
344            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
345
346        debug!("Added gateway as peer");
347
348        // Update last_connected_at timestamp
349        if let Err(e) = storage.update_last_connected() {
350            warn!(error = ?e, "Failed to update last_connected timestamp");
351        }
352
353        // Gateway WireGuard endpoint (hardcoded for now - could be stored)
354        let gateway_endpoint: SocketAddr = "127.0.0.1:51820".parse()
355            .map_err(|e| ClientError::Configuration(format!("invalid gateway endpoint: {}", e)))?;
356
357        Ok(Self {
358            transport: TransportServer::new(server),
359            node_id: stored.node_id,
360            assigned_ip: stored.assigned_ip,
361            token: stored.token,
362            endpoint,
363            gateway_url: stored.gateway_url,
364            gateway_endpoint,
365            gateway_public_key,
366            storage: Some(storage),
367        })
368    }
369
370    /// Connect using default gateway (http://127.0.0.1:8008)
371    pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
372        Self::connect(DEFAULT_GATEWAY, auth_token).await
373    }
374
375    /// Get the node ID assigned by the gateway
376    pub fn node_id(&self) -> &str {
377        &self.node_id
378    }
379
380    /// Get the IP address assigned by the gateway
381    pub fn assigned_ip(&self) -> &str {
382        &self.assigned_ip
383    }
384
385    /// Get the transport for sending/receiving data
386    pub fn transport(&self) -> &TransportServer {
387        &self.transport
388    }
389
390    /// Ping the gateway over WireGuard to verify connectivity
391    pub async fn ping_gateway(&self) -> Result<(), ClientError> {
392        debug!("Pinging gateway over WireGuard");
393
394        // Dial the gateway's WireGuard endpoint
395        let conn = self
396            .transport
397            .server()
398            .dial("tcp", &self.gateway_endpoint.to_string(), self.gateway_public_key)
399            .await
400            .map_err(|e| ClientError::Transport(format!("failed to dial gateway: {}", e)))?;
401
402        debug!("WireGuard connection established to gateway");
403
404        // Send HTTP GET request to /ping
405        let request = b"GET /ping HTTP/1.1\r\nHost: gateway\r\nConnection: close\r\n\r\n";
406        conn.send(request)
407            .await
408            .map_err(|e| ClientError::Transport(format!("failed to send ping request: {}", e)))?;
409
410        // Receive response
411        let mut buf = vec![0u8; 8192];
412        let n = conn
413            .recv(&mut buf)
414            .await
415            .map_err(|e| ClientError::Transport(format!("failed to receive ping response: {}", e)))?;
416
417        let response = String::from_utf8_lossy(&buf[..n]);
418
419        if response.contains("200 OK") && response.contains("Pong") {
420            debug!("Successfully pinged gateway");
421            Ok(())
422        } else {
423            Err(ClientError::GatewayError {
424                status: 500,
425                message: format!("Unexpected ping response: {}", response),
426            })
427        }
428    }
429
430    /// Get peer information from the gateway
431    ///
432    /// Accepts either a node_id (e.g., "ht-festive-penguin-abc123") or
433    /// an assigned IP (e.g., "100.64.0.5")
434    pub async fn get_peer_info(&self, node_id_or_ip: &str) -> Result<PeerInfo, ClientError> {
435        debug!(peer = %node_id_or_ip, "Fetching peer info from gateway");
436
437        // Query gateway API: GET /api/peers/{node_id_or_ip}
438        let url = format!("{}/api/peers/{}", self.gateway_url.trim_end_matches('/'), node_id_or_ip);
439
440        let client = reqwest::Client::new();
441        let response = client
442            .get(&url)
443            .send()
444            .await?;
445
446        let status = response.status();
447
448        if status.is_success() {
449            let peer_info: PeerInfo = response.json().await.map_err(|e| {
450                ClientError::InvalidResponse(format!("failed to parse peer info: {}", e))
451            })?;
452
453            debug!(
454                node_id = %peer_info.node_id,
455                assigned_ip = %peer_info.assigned_ip,
456                "Retrieved peer info from gateway"
457            );
458
459            Ok(peer_info)
460        } else {
461            let message = response
462                .text()
463                .await
464                .unwrap_or_else(|_| "unknown error".to_string());
465            Err(ClientError::GatewayError {
466                status: status.as_u16(),
467                message: format!("Failed to get peer info: {}", message),
468            })
469        }
470    }
471
472    /// Dial a peer by node ID or assigned IP
473    ///
474    /// This method:
475    /// 1. Fetches peer info from gateway (public key, endpoint, etc.)
476    /// 2. Adds peer to WireGuard if not already present
477    /// 3. Dials the peer over the WireGuard network
478    ///
479    /// # Arguments
480    /// * `peer` - Node ID (e.g., "ht-festive-penguin") or assigned IP (e.g., "100.64.0.5")
481    /// * `port` - Port to connect to on the peer
482    ///
483    /// # Example
484    /// ```no_run
485    /// # async fn example(conn: &hightower_client::HightowerConnection) -> Result<(), Box<dyn std::error::Error>> {
486    /// let connection = conn.dial("ht-festive-penguin-abc123", 8080).await?;
487    /// connection.send(b"Hello, peer!").await?;
488    /// # Ok(())
489    /// # }
490    /// ```
491    pub async fn dial(&self, peer: &str, port: u16) -> Result<hightower_wireguard::transport::Conn, ClientError> {
492        // 1. Get peer info from gateway
493        let peer_info = self.get_peer_info(peer).await?;
494
495        // 2. Parse peer's public key
496        let peer_public_key_bytes = hex::decode(&peer_info.public_key_hex)
497            .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key hex: {}", e)))?;
498
499        let peer_public_key: PublicKey25519 = peer_public_key_bytes
500            .as_slice()
501            .try_into()
502            .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key format: {:?}", e)))?;
503
504        // 3. Add peer to WireGuard (idempotent - safe to call multiple times)
505        self.transport
506            .server()
507            .add_peer(peer_public_key, peer_info.endpoint)
508            .await
509            .map_err(|e| ClientError::Transport(format!("failed to add peer: {}", e)))?;
510
511        debug!(
512            peer_id = %peer_info.node_id,
513            peer_ip = %peer_info.assigned_ip,
514            port = port,
515            "Added peer and dialing"
516        );
517
518        // 4. Dial using the peer's assigned IP on the WireGuard network
519        let addr = format!("{}:{}", peer_info.assigned_ip, port);
520        let conn = self
521            .transport
522            .server()
523            .dial("tcp", &addr, peer_public_key)
524            .await
525            .map_err(|e| ClientError::Transport(format!("failed to dial peer: {}", e)))?;
526
527        debug!(
528            peer_id = %peer_info.node_id,
529            addr = %addr,
530            "Successfully dialed peer"
531        );
532
533        Ok(conn)
534    }
535
536    /// Disconnect from the gateway and deregister
537    pub async fn disconnect(self) -> Result<(), ClientError> {
538        let url = format!("{}/{}", self.endpoint, self.token);
539
540        let client = reqwest::Client::new();
541        let response = client.delete(&url).send().await?;
542
543        let status = response.status();
544
545        if status.is_success() || status == StatusCode::NO_CONTENT {
546            debug!("Successfully deregistered from gateway");
547
548            // Remove stored connection
549            if let Some(storage) = self.storage {
550                if let Err(e) = storage.delete_connection() {
551                    warn!(error = ?e, "Failed to delete stored connection");
552                } else {
553                    debug!("Deleted stored connection");
554                }
555            }
556
557            Ok(())
558        } else {
559            let message = response
560                .text()
561                .await
562                .unwrap_or_else(|_| "unknown error".to_string());
563            Err(ClientError::GatewayError {
564                status: status.as_u16(),
565                message,
566            })
567        }
568    }
569}
570
571fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
572    let gateway_url = gateway_url.trim();
573
574    if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
575        return Err(ClientError::Configuration(
576            "gateway_url must start with http:// or https://".into(),
577        ));
578    }
579
580    Ok(format!(
581        "{}{}",
582        gateway_url.trim_end_matches('/'),
583        API_PATH
584    ))
585}
586
587async fn register_with_gateway(
588    endpoint: &str,
589    auth_token: &str,
590    public_key_hex: &str,
591    network_info: &crate::types::NetworkInfo,
592) -> Result<RegistrationResponse, ClientError> {
593    let payload = RegistrationRequest {
594        public_key_hex,
595        public_ip: Some(network_info.public_ip.as_str()),
596        public_port: Some(network_info.public_port),
597        local_ip: Some(network_info.local_ip.as_str()),
598        local_port: Some(network_info.local_port),
599    };
600
601    let client = reqwest::Client::new();
602    let response = client
603        .post(endpoint)
604        .header("X-HT-Auth", auth_token)
605        .json(&payload)
606        .send()
607        .await?;
608
609    let status = response.status();
610
611    if status.is_success() {
612        let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
613            ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
614        })?;
615
616        Ok(registration_response)
617    } else {
618        let message = response
619            .text()
620            .await
621            .unwrap_or_else(|_| "unknown error".to_string());
622        Err(ClientError::GatewayError {
623            status: status.as_u16(),
624            message,
625        })
626    }
627}
628
/// Unit tests for the pure URL-handling helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_endpoint_requires_scheme() {
        let result = build_endpoint("gateway.example.com:8008");
        assert!(matches!(result, Err(ClientError::Configuration(_))));
    }

    #[test]
    fn build_endpoint_rejects_empty_url() {
        let result = build_endpoint("");
        assert!(matches!(result, Err(ClientError::Configuration(_))));
    }

    #[test]
    fn build_endpoint_accepts_http() {
        let endpoint = build_endpoint("http://gateway.example.com:8008").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    #[test]
    fn build_endpoint_accepts_https() {
        let endpoint = build_endpoint("https://gateway.example.com:8443").unwrap();
        assert_eq!(endpoint, "https://gateway.example.com:8443/api/nodes");
    }

    #[test]
    fn build_endpoint_strips_trailing_slash() {
        let endpoint = build_endpoint("http://gateway.example.com:8008/").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    #[test]
    fn build_endpoint_strips_multiple_trailing_slashes() {
        let endpoint = build_endpoint("https://gateway.example.com:8443///").unwrap();
        assert_eq!(endpoint, "https://gateway.example.com:8443/api/nodes");
    }

    #[test]
    fn build_endpoint_trims_surrounding_whitespace() {
        let endpoint = build_endpoint("  http://gateway.example.com:8008  ").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }
}