// hightower_client/connection.rs

use crate::error::ClientError;
use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
use crate::transport::TransportServer;
use crate::types::{PeerInfo, RegistrationRequest, RegistrationResponse};
use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
use hightower_wireguard::connection::Connection;
use reqwest::StatusCode;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use tracing::{debug, info, warn};

/// Gateway base URL used by `HightowerConnection::connect_with_auth_token`.
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";

/// Node API path appended to the gateway base URL; registration POSTs here and
/// deregistration DELETEs `<endpoint>/<token>`.
const API_PATH: &str = "/api/nodes";

15/// Main connection to Hightower gateway with integrated WireGuard transport
16pub struct HightowerConnection {
17    transport: TransportServer,
18    node_id: String,
19    assigned_ip: String,
20    token: String,
21    endpoint: String,
22    gateway_url: String,
23    gateway_endpoint: SocketAddr,
24    gateway_public_key: PublicKey25519,
25    storage: Option<ConnectionStorage>,
26}
27
28impl HightowerConnection {
29    /// Connect to a Hightower gateway
30    ///
31    /// This method handles everything:
32    /// - Checks for existing stored connection and restores if available
33    /// - Otherwise: Generates WireGuard keypair, registers with gateway
34    /// - Creates transport server on 0.0.0.0:0
35    /// - Discovers network info via STUN using actual bound port
36    /// - Adds gateway as peer
37    /// - Persists connection info to storage (default: ~/.hightower-client/data)
38    ///
39    /// Returns a ready-to-use connection with working transport
40    pub async fn connect(
41        gateway_url: impl Into<String>,
42        auth_token: impl Into<String>,
43    ) -> Result<Self, ClientError> {
44        Self::connect_internal(gateway_url, auth_token, None, false).await
45    }
46
47    /// Connect without using persistent storage
48    pub async fn connect_ephemeral(
49        gateway_url: impl Into<String>,
50        auth_token: impl Into<String>,
51    ) -> Result<Self, ClientError> {
52        Self::connect_internal(gateway_url, auth_token, None, true).await
53    }
54
55    /// Connect with custom storage directory
56    pub async fn connect_with_storage(
57        gateway_url: impl Into<String>,
58        auth_token: impl Into<String>,
59        storage_dir: impl Into<PathBuf>,
60    ) -> Result<Self, ClientError> {
61        Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
62    }
63
64    /// Force a fresh registration even if stored connection exists
65    pub async fn connect_fresh(
66        gateway_url: impl Into<String>,
67        auth_token: impl Into<String>,
68    ) -> Result<Self, ClientError> {
69        let gateway_url = gateway_url.into();
70        let auth_token = auth_token.into();
71
72        // Delete any stored connection first
73        if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
74            let _ = storage.delete_connection();
75        }
76
77        Self::connect_internal(gateway_url, auth_token, None, false).await
78    }
79
80    async fn connect_internal(
81        gateway_url: impl Into<String>,
82        auth_token: impl Into<String>,
83        storage_dir: Option<PathBuf>,
84        ephemeral: bool,
85    ) -> Result<Self, ClientError> {
86        let gateway_url = gateway_url.into();
87        let auth_token = auth_token.into();
88
89        if gateway_url.is_empty() {
90            return Err(ClientError::Configuration(
91                "gateway_url cannot be empty".into(),
92            ));
93        }
94
95        if auth_token.is_empty() {
96            return Err(ClientError::Configuration(
97                "auth_token cannot be empty".into(),
98            ));
99        }
100
101        let endpoint = build_endpoint(&gateway_url)?;
102
103        // Initialize storage if not ephemeral
104        let storage = if ephemeral {
105            None
106        } else if let Some(dir) = storage_dir {
107            // Use custom storage directory
108            match ConnectionStorage::new(dir) {
109                Ok(s) => Some(s),
110                Err(e) => {
111                    warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
112                    None
113                }
114            }
115        } else {
116            // Use default gateway-specific storage
117            match ConnectionStorage::for_gateway(&gateway_url) {
118                Ok(s) => Some(s),
119                Err(e) => {
120                    warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
121                    None
122                }
123            }
124        };
125
126        // Check for existing stored connection
127        if let Some(ref storage) = storage {
128            if let Ok(Some(stored)) = storage.get_connection() {
129                info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
130
131                match Self::restore_from_stored(stored, storage.clone()).await {
132                    Ok(conn) => {
133                        info!(node_id = %conn.node_id, "Successfully restored connection from storage");
134                        return Ok(conn);
135                    }
136                    Err(e) => {
137                        warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
138                        // Continue to create fresh connection
139                    }
140                }
141            }
142        }
143
144        // No stored connection or restore failed - create fresh connection
145        info!("Creating fresh connection to gateway");
146
147        // 1. Generate WireGuard keypair
148        let (private_key, public_key) = dh_generate();
149        let public_key_hex = hex::encode(public_key);
150        let private_key_hex = hex::encode(private_key);
151
152        debug!("Generated WireGuard keypair");
153
154        // 2. Create WireGuard transport server on 0.0.0.0:0
155        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
156            ClientError::Configuration(format!("invalid bind address: {}", e))
157        })?;
158
159        let connection = Connection::new(bind_addr, private_key)
160            .await
161            .map_err(|e| ClientError::Transport(format!("failed to create transport connection: {}", e)))?;
162
163        debug!("Created transport connection");
164
165        // 3. Discover network info using actual bound port
166        let local_addr = connection.local_addr();
167
168        let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
169            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
170
171        debug!(
172            public_ip = %network_info.public_ip,
173            public_port = network_info.public_port,
174            local_ip = %network_info.local_ip,
175            local_port = network_info.local_port,
176            "Discovered network information"
177        );
178
179        // 4. Register with gateway
180        let registration = register_with_gateway(
181            &endpoint,
182            &auth_token,
183            &public_key_hex,
184            &network_info,
185        )
186        .await?;
187
188        debug!(
189            node_id = %registration.node_id,
190            assigned_ip = %registration.assigned_ip,
191            "Registered with gateway"
192        );
193
194        // 5. Add gateway as peer
195        let gateway_public_key_bytes = hex::decode(&registration.gateway_public_key_hex)
196            .map_err(|e| {
197                ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
198            })?;
199
200        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
201            .as_slice()
202            .try_into()
203            .map_err(|e| {
204                ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
205            })?;
206
207        connection
208            .add_peer(gateway_public_key, None)
209            .await
210            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
211
212        debug!("Added gateway as peer");
213
214        // 6. Store connection info
215        if let Some(ref storage) = storage {
216            let now = current_timestamp();
217            let stored = StoredConnection {
218                node_id: registration.node_id.clone(),
219                token: registration.token.clone(),
220                gateway_url: gateway_url.clone(),
221                assigned_ip: registration.assigned_ip.clone(),
222                private_key_hex,
223                public_key_hex,
224                gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
225                created_at: now,
226                last_connected_at: now,
227            };
228
229            if let Err(e) = storage.store_connection(&stored) {
230                warn!(error = ?e, "Failed to persist connection to storage");
231            } else {
232                debug!("Persisted connection to storage");
233            }
234        }
235
236        // Gateway WireGuard endpoint (hardcoded for now - could be returned by registration)
237        let gateway_endpoint: SocketAddr = "127.0.0.1:51820".parse()
238            .map_err(|e| ClientError::Configuration(format!("invalid gateway endpoint: {}", e)))?;
239
240        Ok(Self {
241            transport: TransportServer::new(connection),
242            node_id: registration.node_id,
243            assigned_ip: registration.assigned_ip,
244            token: registration.token,
245            endpoint,
246            gateway_url,
247            gateway_endpoint,
248            gateway_public_key,
249            storage,
250        })
251    }
252
253    /// Restore a connection from stored credentials
254    async fn restore_from_stored(
255        stored: StoredConnection,
256        storage: ConnectionStorage,
257    ) -> Result<Self, ClientError> {
258        let endpoint = build_endpoint(&stored.gateway_url)?;
259
260        // Parse stored keys
261        let private_key_bytes = hex::decode(&stored.private_key_hex)
262            .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
263        let private_key: PrivateKey = private_key_bytes
264            .as_slice()
265            .try_into()
266            .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
267
268        let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
269            .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
270        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
271            .as_slice()
272            .try_into()
273            .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
274
275        // Create transport connection with stored private key
276        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
277            ClientError::Configuration(format!("invalid bind address: {}", e))
278        })?;
279
280        let connection = Connection::new(bind_addr, private_key)
281            .await
282            .map_err(|e| ClientError::Transport(format!("failed to create transport connection: {}", e)))?;
283
284        debug!("Created transport connection with stored keys");
285
286        // Discover current network info (may have changed)
287        let local_addr = connection.local_addr();
288
289        let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
290            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
291
292        debug!("Rediscovered network information");
293
294        // Add gateway as peer
295        connection
296            .add_peer(gateway_public_key, None)
297            .await
298            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
299
300        debug!("Added gateway as peer");
301
302        // Update last_connected_at timestamp
303        if let Err(e) = storage.update_last_connected() {
304            warn!(error = ?e, "Failed to update last_connected timestamp");
305        }
306
307        // Gateway WireGuard endpoint (hardcoded for now - could be stored)
308        let gateway_endpoint: SocketAddr = "127.0.0.1:51820".parse()
309            .map_err(|e| ClientError::Configuration(format!("invalid gateway endpoint: {}", e)))?;
310
311        Ok(Self {
312            transport: TransportServer::new(connection),
313            node_id: stored.node_id,
314            assigned_ip: stored.assigned_ip,
315            token: stored.token,
316            endpoint,
317            gateway_url: stored.gateway_url,
318            gateway_endpoint,
319            gateway_public_key,
320            storage: Some(storage),
321        })
322    }
323
324    /// Connect using default gateway (http://127.0.0.1:8008)
325    pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
326        Self::connect(DEFAULT_GATEWAY, auth_token).await
327    }
328
329    /// Get the node ID assigned by the gateway
330    pub fn node_id(&self) -> &str {
331        &self.node_id
332    }
333
334    /// Get the IP address assigned by the gateway
335    pub fn assigned_ip(&self) -> &str {
336        &self.assigned_ip
337    }
338
339    /// Get the transport for sending/receiving data
340    pub fn transport(&self) -> &TransportServer {
341        &self.transport
342    }
343
344    /// Ping the gateway over WireGuard to verify connectivity
345    pub async fn ping_gateway(&self) -> Result<(), ClientError> {
346        debug!("Pinging gateway over WireGuard");
347
348        // Connect to the gateway's WireGuard endpoint
349        let mut stream = self
350            .transport
351            .connection()
352            .connect(self.gateway_endpoint, self.gateway_public_key)
353            .await
354            .map_err(|e| ClientError::Transport(format!("failed to connect to gateway: {}", e)))?;
355
356        debug!("WireGuard connection established to gateway");
357
358        // Send HTTP GET request to /ping
359        let request = b"GET /ping HTTP/1.1\r\nHost: gateway\r\nConnection: close\r\n\r\n";
360        stream.send(request)
361            .await
362            .map_err(|e| ClientError::Transport(format!("failed to send ping request: {}", e)))?;
363
364        // Receive response
365        let response_bytes = stream
366            .recv()
367            .await
368            .map_err(|e| ClientError::Transport(format!("failed to receive ping response: {}", e)))?;
369
370        let response = String::from_utf8_lossy(&response_bytes);
371
372        if response.contains("200 OK") && response.contains("Pong") {
373            debug!("Successfully pinged gateway");
374            Ok(())
375        } else {
376            Err(ClientError::GatewayError {
377                status: 500,
378                message: format!("Unexpected ping response: {}", response),
379            })
380        }
381    }
382
383    /// Get peer information from the gateway
384    ///
385    /// Accepts either a node_id (e.g., "ht-festive-penguin-abc123") or
386    /// an assigned IP (e.g., "100.64.0.5")
387    pub async fn get_peer_info(&self, node_id_or_ip: &str) -> Result<PeerInfo, ClientError> {
388        debug!(peer = %node_id_or_ip, "Fetching peer info from gateway");
389
390        // Query gateway API: GET /api/peers/{node_id_or_ip}
391        let url = format!("{}/api/peers/{}", self.gateway_url.trim_end_matches('/'), node_id_or_ip);
392
393        let client = reqwest::Client::new();
394        let response = client
395            .get(&url)
396            .send()
397            .await?;
398
399        let status = response.status();
400
401        if status.is_success() {
402            let peer_info: PeerInfo = response.json().await.map_err(|e| {
403                ClientError::InvalidResponse(format!("failed to parse peer info: {}", e))
404            })?;
405
406            debug!(
407                node_id = %peer_info.node_id,
408                assigned_ip = %peer_info.assigned_ip,
409                "Retrieved peer info from gateway"
410            );
411
412            Ok(peer_info)
413        } else {
414            let message = response
415                .text()
416                .await
417                .unwrap_or_else(|_| "unknown error".to_string());
418            Err(ClientError::GatewayError {
419                status: status.as_u16(),
420                message: format!("Failed to get peer info: {}", message),
421            })
422        }
423    }
424
425    /// Dial a peer by node ID or assigned IP
426    ///
427    /// This method:
428    /// 1. Fetches peer info from gateway (public key, endpoint, etc.)
429    /// 2. Adds peer to WireGuard if not already present
430    /// 3. Dials the peer over the WireGuard network
431    ///
432    /// # Arguments
433    /// * `peer` - Node ID (e.g., "ht-festive-penguin") or assigned IP (e.g., "100.64.0.5")
434    /// * `port` - Port to connect to on the peer
435    ///
436    /// # Example
437    /// ```no_run
438    /// # async fn example(conn: &hightower_client::HightowerConnection) -> Result<(), Box<dyn std::error::Error>> {
439    /// let connection = conn.dial("ht-festive-penguin-abc123", 8080).await?;
440    /// connection.send(b"Hello, peer!").await?;
441    /// # Ok(())
442    /// # }
443    /// ```
444    pub async fn dial(&self, peer: &str, port: u16) -> Result<hightower_wireguard::connection::Stream, ClientError> {
445        // 1. Get peer info from gateway
446        let peer_info = self.get_peer_info(peer).await?;
447
448        // 2. Parse peer's public key
449        let peer_public_key_bytes = hex::decode(&peer_info.public_key_hex)
450            .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key hex: {}", e)))?;
451
452        let peer_public_key: PublicKey25519 = peer_public_key_bytes
453            .as_slice()
454            .try_into()
455            .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key format: {:?}", e)))?;
456
457        // 3. Add peer to WireGuard (idempotent - safe to call multiple times)
458        self.transport
459            .connection()
460            .add_peer(peer_public_key, peer_info.endpoint)
461            .await
462            .map_err(|e| ClientError::Transport(format!("failed to add peer: {}", e)))?;
463
464        debug!(
465            peer_id = %peer_info.node_id,
466            peer_ip = %peer_info.assigned_ip,
467            port = port,
468            "Added peer and connecting"
469        );
470
471        // 4. Connect using the peer's assigned IP on the WireGuard network
472        let peer_addr: SocketAddr = format!("{}:{}", peer_info.assigned_ip, port)
473            .parse()
474            .map_err(|e| ClientError::Transport(format!("invalid peer address: {}", e)))?;
475
476        let stream = self
477            .transport
478            .connection()
479            .connect(peer_addr, peer_public_key)
480            .await
481            .map_err(|e| ClientError::Transport(format!("failed to connect to peer: {}", e)))?;
482
483        debug!(
484            peer_id = %peer_info.node_id,
485            addr = %peer_addr,
486            "Successfully connected to peer"
487        );
488
489        Ok(stream)
490    }
491
492    /// Disconnect from the gateway and deregister
493    pub async fn disconnect(self) -> Result<(), ClientError> {
494        let url = format!("{}/{}", self.endpoint, self.token);
495
496        let client = reqwest::Client::new();
497        let response = client.delete(&url).send().await?;
498
499        let status = response.status();
500
501        if status.is_success() || status == StatusCode::NO_CONTENT {
502            debug!("Successfully deregistered from gateway");
503
504            // Remove stored connection
505            if let Some(storage) = self.storage {
506                if let Err(e) = storage.delete_connection() {
507                    warn!(error = ?e, "Failed to delete stored connection");
508                } else {
509                    debug!("Deleted stored connection");
510                }
511            }
512
513            Ok(())
514        } else {
515            let message = response
516                .text()
517                .await
518                .unwrap_or_else(|_| "unknown error".to_string());
519            Err(ClientError::GatewayError {
520                status: status.as_u16(),
521                message,
522            })
523        }
524    }
525}
526
527fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
528    let gateway_url = gateway_url.trim();
529
530    if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
531        return Err(ClientError::Configuration(
532            "gateway_url must start with http:// or https://".into(),
533        ));
534    }
535
536    Ok(format!(
537        "{}{}",
538        gateway_url.trim_end_matches('/'),
539        API_PATH
540    ))
541}
542
543async fn register_with_gateway(
544    endpoint: &str,
545    auth_token: &str,
546    public_key_hex: &str,
547    network_info: &crate::types::NetworkInfo,
548) -> Result<RegistrationResponse, ClientError> {
549    let payload = RegistrationRequest {
550        public_key_hex,
551        public_ip: Some(network_info.public_ip.as_str()),
552        public_port: Some(network_info.public_port),
553        local_ip: Some(network_info.local_ip.as_str()),
554        local_port: Some(network_info.local_port),
555    };
556
557    let client = reqwest::Client::new();
558    let response = client
559        .post(endpoint)
560        .header("X-HT-Auth", auth_token)
561        .json(&payload)
562        .send()
563        .await?;
564
565    let status = response.status();
566
567    if status.is_success() {
568        let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
569            ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
570        })?;
571
572        Ok(registration_response)
573    } else {
574        let message = response
575            .text()
576            .await
577            .unwrap_or_else(|_| "unknown error".to_string());
578        Err(ClientError::GatewayError {
579            status: status.as_u16(),
580            message,
581        })
582    }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_endpoint_requires_scheme() {
        let result = build_endpoint("gateway.example.com:8008");
        assert!(matches!(result, Err(ClientError::Configuration(_))));
    }

    #[test]
    fn build_endpoint_rejects_unsupported_scheme() {
        let result = build_endpoint("ftp://gateway.example.com:8008");
        assert!(matches!(result, Err(ClientError::Configuration(_))));
    }

    #[test]
    fn build_endpoint_accepts_http() {
        let endpoint = build_endpoint("http://gateway.example.com:8008").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    #[test]
    fn build_endpoint_accepts_https() {
        let endpoint = build_endpoint("https://gateway.example.com:8443").unwrap();
        assert_eq!(endpoint, "https://gateway.example.com:8443/api/nodes");
    }

    #[test]
    fn build_endpoint_strips_trailing_slash() {
        let endpoint = build_endpoint("http://gateway.example.com:8008/").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    #[test]
    fn build_endpoint_strips_repeated_trailing_slashes() {
        let endpoint = build_endpoint("https://gateway.example.com:8443///").unwrap();
        assert_eq!(endpoint, "https://gateway.example.com:8443/api/nodes");
    }

    #[test]
    fn build_endpoint_trims_surrounding_whitespace() {
        let endpoint = build_endpoint("  http://gateway.example.com:8008 \n").unwrap();
        assert_eq!(endpoint, "http://gateway.example.com:8008/api/nodes");
    }

    #[test]
    fn build_endpoint_preserves_base_path() {
        let endpoint = build_endpoint("https://example.com/hightower/").unwrap();
        assert_eq!(endpoint, "https://example.com/hightower/api/nodes");
    }
}