1use crate::error::ClientError;
2use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
3use crate::transport::TransportServer;
4use crate::types::{RegistrationRequest, RegistrationResponse};
5use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
6use hightower_wireguard::transport::Server;
7use reqwest::StatusCode;
8use std::net::SocketAddr;
9use std::path::PathBuf;
10use tracing::{debug, error, info, warn};
11
/// Gateway base URL used by `HightowerConnection::connect_with_auth_token`
/// when the caller does not supply one.
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";
/// Path appended to the gateway base URL by `build_endpoint` to form the
/// node registration endpoint.
const API_PATH: &str = "/api/nodes";
14
/// A node's live connection to a Hightower gateway.
///
/// Values are created by the `connect*` constructors, which either register a
/// new node with the gateway or restore a previously persisted registration.
pub struct HightowerConnection {
    /// Wrapper around the transport server created for this connection.
    transport: TransportServer,
    /// Node identifier taken from the gateway's registration response (or from
    /// storage when the connection was restored).
    node_id: String,
    /// IP address the gateway assigned to this node.
    assigned_ip: String,
    /// Token returned at registration; `disconnect` appends it to `endpoint`
    /// to build the deregistration URL.
    token: String,
    /// Registration endpoint: the gateway base URL followed by `API_PATH`.
    endpoint: String,
    /// Gateway base URL this connection was created with. It is stored but not
    /// read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    gateway_url: String,
    /// Persistence backend, or `None` for ephemeral connections and when
    /// storage could not be initialized.
    storage: Option<ConnectionStorage>,
}
26
27impl HightowerConnection {
28 pub async fn connect(
40 gateway_url: impl Into<String>,
41 auth_token: impl Into<String>,
42 ) -> Result<Self, ClientError> {
43 Self::connect_internal(gateway_url, auth_token, None, false).await
44 }
45
46 pub async fn connect_ephemeral(
48 gateway_url: impl Into<String>,
49 auth_token: impl Into<String>,
50 ) -> Result<Self, ClientError> {
51 Self::connect_internal(gateway_url, auth_token, None, true).await
52 }
53
54 pub async fn connect_with_storage(
56 gateway_url: impl Into<String>,
57 auth_token: impl Into<String>,
58 storage_dir: impl Into<PathBuf>,
59 ) -> Result<Self, ClientError> {
60 Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
61 }
62
63 pub async fn connect_fresh(
65 gateway_url: impl Into<String>,
66 auth_token: impl Into<String>,
67 ) -> Result<Self, ClientError> {
68 let gateway_url = gateway_url.into();
69 let auth_token = auth_token.into();
70
71 if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
73 let _ = storage.delete_connection();
74 }
75
76 Self::connect_internal(gateway_url, auth_token, None, false).await
77 }
78
79 async fn connect_internal(
80 gateway_url: impl Into<String>,
81 auth_token: impl Into<String>,
82 storage_dir: Option<PathBuf>,
83 ephemeral: bool,
84 ) -> Result<Self, ClientError> {
85 let gateway_url = gateway_url.into();
86 let auth_token = auth_token.into();
87
88 if gateway_url.is_empty() {
89 return Err(ClientError::Configuration(
90 "gateway_url cannot be empty".into(),
91 ));
92 }
93
94 if auth_token.is_empty() {
95 return Err(ClientError::Configuration(
96 "auth_token cannot be empty".into(),
97 ));
98 }
99
100 let endpoint = build_endpoint(&gateway_url)?;
101
102 let storage = if ephemeral {
104 None
105 } else if let Some(dir) = storage_dir {
106 match ConnectionStorage::new(dir) {
108 Ok(s) => Some(s),
109 Err(e) => {
110 warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
111 None
112 }
113 }
114 } else {
115 match ConnectionStorage::for_gateway(&gateway_url) {
117 Ok(s) => Some(s),
118 Err(e) => {
119 warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
120 None
121 }
122 }
123 };
124
125 if let Some(ref storage) = storage {
127 if let Ok(Some(stored)) = storage.get_connection() {
128 info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
129
130 match Self::restore_from_stored(stored, storage.clone()).await {
131 Ok(conn) => {
132 info!(node_id = %conn.node_id, "Successfully restored connection from storage");
133 return Ok(conn);
134 }
135 Err(e) => {
136 warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
137 }
139 }
140 }
141 }
142
143 info!("Creating fresh connection to gateway");
145
146 let (private_key, public_key) = dh_generate();
148 let public_key_hex = hex::encode(public_key);
149 let private_key_hex = hex::encode(private_key);
150
151 debug!("Generated WireGuard keypair");
152
153 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
155 ClientError::Configuration(format!("invalid bind address: {}", e))
156 })?;
157
158 let server = Server::new(bind_addr, private_key)
159 .await
160 .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
161
162 debug!("Created transport server");
163
164 let server_clone = server.clone();
166 tokio::spawn(async move {
167 if let Err(e) = server_clone.run().await {
168 error!(error = ?e, "Transport server error");
169 }
170 });
171
172 let server_maintenance = server.clone();
174 tokio::spawn(async move {
175 server_maintenance.run_maintenance().await
176 });
177
178 server
180 .wait_until_ready()
181 .await
182 .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
183
184 debug!("Transport server ready");
185
186 let local_addr = server
188 .local_addr()
189 .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
190
191 let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
192 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
193
194 debug!(
195 public_ip = %network_info.public_ip,
196 public_port = network_info.public_port,
197 local_ip = %network_info.local_ip,
198 local_port = network_info.local_port,
199 "Discovered network information"
200 );
201
202 let registration = register_with_gateway(
204 &endpoint,
205 &auth_token,
206 &public_key_hex,
207 &network_info,
208 )
209 .await?;
210
211 debug!(
212 node_id = %registration.node_id,
213 assigned_ip = %registration.assigned_ip,
214 "Registered with gateway"
215 );
216
217 let gateway_public_key_bytes = hex::decode(®istration.gateway_public_key_hex)
219 .map_err(|e| {
220 ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
221 })?;
222
223 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
224 .as_slice()
225 .try_into()
226 .map_err(|e| {
227 ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
228 })?;
229
230 server
231 .add_peer(gateway_public_key, None)
232 .await
233 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
234
235 debug!("Added gateway as peer");
236
237 if let Some(ref storage) = storage {
239 let now = current_timestamp();
240 let stored = StoredConnection {
241 node_id: registration.node_id.clone(),
242 token: registration.token.clone(),
243 gateway_url: gateway_url.clone(),
244 assigned_ip: registration.assigned_ip.clone(),
245 private_key_hex,
246 public_key_hex,
247 gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
248 created_at: now,
249 last_connected_at: now,
250 };
251
252 if let Err(e) = storage.store_connection(&stored) {
253 warn!(error = ?e, "Failed to persist connection to storage");
254 } else {
255 debug!("Persisted connection to storage");
256 }
257 }
258
259 Ok(Self {
260 transport: TransportServer::new(server),
261 node_id: registration.node_id,
262 assigned_ip: registration.assigned_ip,
263 token: registration.token,
264 endpoint,
265 gateway_url,
266 storage,
267 })
268 }
269
270 async fn restore_from_stored(
272 stored: StoredConnection,
273 storage: ConnectionStorage,
274 ) -> Result<Self, ClientError> {
275 let endpoint = build_endpoint(&stored.gateway_url)?;
276
277 let private_key_bytes = hex::decode(&stored.private_key_hex)
279 .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
280 let private_key: PrivateKey = private_key_bytes
281 .as_slice()
282 .try_into()
283 .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
284
285 let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
286 .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
287 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
288 .as_slice()
289 .try_into()
290 .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
291
292 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
294 ClientError::Configuration(format!("invalid bind address: {}", e))
295 })?;
296
297 let server = Server::new(bind_addr, private_key)
298 .await
299 .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
300
301 debug!("Created transport server with stored keys");
302
303 let server_clone = server.clone();
305 tokio::spawn(async move {
306 if let Err(e) = server_clone.run().await {
307 error!(error = ?e, "Transport server error");
308 }
309 });
310
311 let server_maintenance = server.clone();
313 tokio::spawn(async move {
314 server_maintenance.run_maintenance().await
315 });
316
317 server
319 .wait_until_ready()
320 .await
321 .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
322
323 let local_addr = server
325 .local_addr()
326 .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
327
328 let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
329 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
330
331 debug!("Rediscovered network information");
332
333 server
335 .add_peer(gateway_public_key, None)
336 .await
337 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
338
339 debug!("Added gateway as peer");
340
341 if let Err(e) = storage.update_last_connected() {
343 warn!(error = ?e, "Failed to update last_connected timestamp");
344 }
345
346 Ok(Self {
347 transport: TransportServer::new(server),
348 node_id: stored.node_id,
349 assigned_ip: stored.assigned_ip,
350 token: stored.token,
351 endpoint,
352 gateway_url: stored.gateway_url,
353 storage: Some(storage),
354 })
355 }
356
357 pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
359 Self::connect(DEFAULT_GATEWAY, auth_token).await
360 }
361
362 pub fn node_id(&self) -> &str {
364 &self.node_id
365 }
366
367 pub fn assigned_ip(&self) -> &str {
369 &self.assigned_ip
370 }
371
372 pub fn transport(&self) -> &TransportServer {
374 &self.transport
375 }
376
377 pub async fn disconnect(self) -> Result<(), ClientError> {
379 let url = format!("{}/{}", self.endpoint, self.token);
380
381 let client = reqwest::Client::new();
382 let response = client.delete(&url).send().await?;
383
384 let status = response.status();
385
386 if status.is_success() || status == StatusCode::NO_CONTENT {
387 debug!("Successfully deregistered from gateway");
388
389 if let Some(storage) = self.storage {
391 if let Err(e) = storage.delete_connection() {
392 warn!(error = ?e, "Failed to delete stored connection");
393 } else {
394 debug!("Deleted stored connection");
395 }
396 }
397
398 Ok(())
399 } else {
400 let message = response
401 .text()
402 .await
403 .unwrap_or_else(|_| "unknown error".to_string());
404 Err(ClientError::GatewayError {
405 status: status.as_u16(),
406 message,
407 })
408 }
409 }
410}
411
412fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
413 let gateway_url = gateway_url.trim();
414
415 if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
416 return Err(ClientError::Configuration(
417 "gateway_url must start with http:// or https://".into(),
418 ));
419 }
420
421 Ok(format!(
422 "{}{}",
423 gateway_url.trim_end_matches('/'),
424 API_PATH
425 ))
426}
427
428async fn register_with_gateway(
429 endpoint: &str,
430 auth_token: &str,
431 public_key_hex: &str,
432 network_info: &crate::types::NetworkInfo,
433) -> Result<RegistrationResponse, ClientError> {
434 let payload = RegistrationRequest {
435 public_key_hex,
436 public_ip: Some(network_info.public_ip.as_str()),
437 public_port: Some(network_info.public_port),
438 local_ip: Some(network_info.local_ip.as_str()),
439 local_port: Some(network_info.local_port),
440 };
441
442 let client = reqwest::Client::new();
443 let response = client
444 .post(endpoint)
445 .header("X-HT-Auth", auth_token)
446 .json(&payload)
447 .send()
448 .await?;
449
450 let status = response.status();
451
452 if status.is_success() {
453 let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
454 ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
455 })?;
456
457 Ok(registration_response)
458 } else {
459 let message = response
460 .text()
461 .await
462 .unwrap_or_else(|_| "unknown error".to_string());
463 Err(ClientError::GatewayError {
464 status: status.as_u16(),
465 message,
466 })
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// A URL without an `http://` or `https://` prefix is a configuration error.
    #[test]
    fn build_endpoint_requires_scheme() {
        assert!(matches!(
            build_endpoint("gateway.example.com:8008"),
            Err(ClientError::Configuration(_))
        ));
    }

    /// Plain HTTP base URLs get the API path appended.
    #[test]
    fn build_endpoint_accepts_http() {
        assert_eq!(
            build_endpoint("http://gateway.example.com:8008").unwrap(),
            "http://gateway.example.com:8008/api/nodes"
        );
    }

    /// HTTPS base URLs get the API path appended.
    #[test]
    fn build_endpoint_accepts_https() {
        assert_eq!(
            build_endpoint("https://gateway.example.com:8443").unwrap(),
            "https://gateway.example.com:8443/api/nodes"
        );
    }

    /// A trailing slash on the base URL does not produce a double slash.
    #[test]
    fn build_endpoint_strips_trailing_slash() {
        assert_eq!(
            build_endpoint("http://gateway.example.com:8008/").unwrap(),
            "http://gateway.example.com:8008/api/nodes"
        );
    }
}