hightower_client/
connection.rs

1use crate::error::ClientError;
2use crate::transport::TransportServer;
3use crate::types::{RegistrationRequest, RegistrationResponse};
4use hightower_wireguard::crypto::{dh_generate, PublicKey25519};
5use hightower_wireguard::transport::Server;
6use reqwest::StatusCode;
7use std::net::SocketAddr;
8use tracing::{debug, error};
9
/// Gateway base URL used by `HightowerConnection::connect_with_auth_token`.
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";

/// Path of the node API, appended to the gateway base URL by `build_endpoint`.
const API_PATH: &str = "/api/nodes";
12
13/// Main connection to Hightower gateway with integrated WireGuard transport
14pub struct HightowerConnection {
15    transport: TransportServer,
16    node_id: String,
17    assigned_ip: String,
18    token: String,
19    endpoint: String,
20}
21
22impl HightowerConnection {
23    /// Connect to a Hightower gateway
24    ///
25    /// This method handles everything:
26    /// - Generates WireGuard keypair
27    /// - Creates transport server on 0.0.0.0:0
28    /// - Discovers network info via STUN using actual bound port
29    /// - Registers with gateway
30    /// - Adds gateway as peer
31    ///
32    /// Returns a ready-to-use connection with working transport
33    pub async fn connect(
34        gateway_url: impl Into<String>,
35        auth_token: impl Into<String>,
36    ) -> Result<Self, ClientError> {
37        let gateway_url = gateway_url.into();
38        let auth_token = auth_token.into();
39
40        if gateway_url.is_empty() {
41            return Err(ClientError::Configuration(
42                "gateway_url cannot be empty".into(),
43            ));
44        }
45
46        if auth_token.is_empty() {
47            return Err(ClientError::Configuration(
48                "auth_token cannot be empty".into(),
49            ));
50        }
51
52        let endpoint = build_endpoint(&gateway_url)?;
53
54        // 1. Generate WireGuard keypair
55        let (private_key, public_key) = dh_generate();
56        let public_key_hex = hex::encode(public_key);
57
58        debug!("Generated WireGuard keypair");
59
60        // 2. Create WireGuard transport server on 0.0.0.0:0
61        let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
62            ClientError::Configuration(format!("invalid bind address: {}", e))
63        })?;
64
65        let server = Server::new(bind_addr, private_key)
66            .await
67            .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
68
69        debug!("Created transport server");
70
71        // Spawn background processor
72        let server_clone = server.clone();
73        tokio::spawn(async move {
74            if let Err(e) = server_clone.run().await {
75                error!(error = ?e, "Transport server error");
76            }
77        });
78
79        // Spawn maintenance task
80        let server_maintenance = server.clone();
81        tokio::spawn(async move {
82            server_maintenance.run_maintenance().await
83        });
84
85        // Wait for transport to be ready
86        server
87            .wait_until_ready()
88            .await
89            .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
90
91        debug!("Transport server ready");
92
93        // 3. Discover network info using actual bound port
94        let local_addr = server
95            .local_addr()
96            .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
97
98        let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
99            .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
100
101        debug!(
102            public_ip = %network_info.public_ip,
103            public_port = network_info.public_port,
104            local_ip = %network_info.local_ip,
105            local_port = network_info.local_port,
106            "Discovered network information"
107        );
108
109        // 4. Register with gateway
110        let registration = register_with_gateway(
111            &endpoint,
112            &auth_token,
113            &public_key_hex,
114            &network_info,
115        )
116        .await?;
117
118        debug!(
119            node_id = %registration.node_id,
120            assigned_ip = %registration.assigned_ip,
121            "Registered with gateway"
122        );
123
124        // 5. Add gateway as peer
125        let gateway_public_key_bytes = hex::decode(&registration.gateway_public_key_hex)
126            .map_err(|e| {
127                ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
128            })?;
129
130        let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
131            .as_slice()
132            .try_into()
133            .map_err(|e| {
134                ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
135            })?;
136
137        server
138            .add_peer(gateway_public_key, None)
139            .await
140            .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
141
142        debug!("Added gateway as peer");
143
144        Ok(Self {
145            transport: TransportServer::new(server),
146            node_id: registration.node_id,
147            assigned_ip: registration.assigned_ip,
148            token: registration.token,
149            endpoint,
150        })
151    }
152
153    /// Connect using default gateway (http://127.0.0.1:8008)
154    pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
155        Self::connect(DEFAULT_GATEWAY, auth_token).await
156    }
157
158    /// Get the node ID assigned by the gateway
159    pub fn node_id(&self) -> &str {
160        &self.node_id
161    }
162
163    /// Get the IP address assigned by the gateway
164    pub fn assigned_ip(&self) -> &str {
165        &self.assigned_ip
166    }
167
168    /// Get the transport for sending/receiving data
169    pub fn transport(&self) -> &TransportServer {
170        &self.transport
171    }
172
173    /// Disconnect from the gateway and deregister
174    pub async fn disconnect(self) -> Result<(), ClientError> {
175        let url = format!("{}/{}", self.endpoint, self.token);
176
177        let client = reqwest::Client::new();
178        let response = client.delete(&url).send().await?;
179
180        let status = response.status();
181
182        if status.is_success() || status == StatusCode::NO_CONTENT {
183            debug!("Successfully deregistered from gateway");
184            Ok(())
185        } else {
186            let message = response
187                .text()
188                .await
189                .unwrap_or_else(|_| "unknown error".to_string());
190            Err(ClientError::GatewayError {
191                status: status.as_u16(),
192                message,
193            })
194        }
195    }
196}
197
198fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
199    let gateway_url = gateway_url.trim();
200
201    if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
202        return Err(ClientError::Configuration(
203            "gateway_url must start with http:// or https://".into(),
204        ));
205    }
206
207    Ok(format!(
208        "{}{}",
209        gateway_url.trim_end_matches('/'),
210        API_PATH
211    ))
212}
213
214async fn register_with_gateway(
215    endpoint: &str,
216    auth_token: &str,
217    public_key_hex: &str,
218    network_info: &crate::types::NetworkInfo,
219) -> Result<RegistrationResponse, ClientError> {
220    let payload = RegistrationRequest {
221        public_key_hex,
222        public_ip: Some(network_info.public_ip.as_str()),
223        public_port: Some(network_info.public_port),
224        local_ip: Some(network_info.local_ip.as_str()),
225        local_port: Some(network_info.local_port),
226    };
227
228    let client = reqwest::Client::new();
229    let response = client
230        .post(endpoint)
231        .header("X-HT-Auth", auth_token)
232        .json(&payload)
233        .send()
234        .await?;
235
236    let status = response.status();
237
238    if status.is_success() {
239        let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
240            ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
241        })?;
242
243        Ok(registration_response)
244    } else {
245        let message = response
246            .text()
247            .await
248            .unwrap_or_else(|_| "unknown error".to_string());
249        Err(ClientError::GatewayError {
250            status: status.as_u16(),
251            message,
252        })
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the endpoint for `url`, failing the test if the URL is rejected.
    fn accepted(url: &str) -> String {
        build_endpoint(url).expect("gateway URL with an http(s) scheme should be accepted")
    }

    #[test]
    fn build_endpoint_requires_scheme() {
        let outcome = build_endpoint("gateway.example.com:8008");
        assert!(
            matches!(outcome, Err(ClientError::Configuration(_))),
            "a URL without http:// or https:// must be a configuration error"
        );
    }

    #[test]
    fn build_endpoint_accepts_http() {
        assert_eq!(
            accepted("http://gateway.example.com:8008"),
            "http://gateway.example.com:8008/api/nodes"
        );
    }

    #[test]
    fn build_endpoint_accepts_https() {
        assert_eq!(
            accepted("https://gateway.example.com:8443"),
            "https://gateway.example.com:8443/api/nodes"
        );
    }

    #[test]
    fn build_endpoint_strips_trailing_slash() {
        assert_eq!(
            accepted("http://gateway.example.com:8008/"),
            "http://gateway.example.com:8008/api/nodes"
        );
    }
}