1use crate::error::ClientError;
2use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
3use crate::transport::TransportServer;
4use crate::types::{PeerInfo, RegistrationRequest, RegistrationResponse};
5use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
6use hightower_wireguard::transport::Server;
7use reqwest::StatusCode;
8use std::net::SocketAddr;
9use std::path::PathBuf;
10use tracing::{debug, error, info, warn};
11
/// Gateway base URL used by `HightowerConnection::connect_with_auth_token`,
/// which takes no URL argument.
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";

/// Path, appended to the gateway base URL, of the node-registration API.
const API_PATH: &str = "/api/nodes";
14
15pub struct HightowerConnection {
17 transport: TransportServer,
18 node_id: String,
19 assigned_ip: String,
20 token: String,
21 endpoint: String,
22 gateway_url: String,
23 gateway_endpoint: SocketAddr,
24 gateway_public_key: PublicKey25519,
25 storage: Option<ConnectionStorage>,
26}
27
28impl HightowerConnection {
29 pub async fn connect(
41 gateway_url: impl Into<String>,
42 auth_token: impl Into<String>,
43 ) -> Result<Self, ClientError> {
44 Self::connect_internal(gateway_url, auth_token, None, false).await
45 }
46
47 pub async fn connect_ephemeral(
49 gateway_url: impl Into<String>,
50 auth_token: impl Into<String>,
51 ) -> Result<Self, ClientError> {
52 Self::connect_internal(gateway_url, auth_token, None, true).await
53 }
54
55 pub async fn connect_with_storage(
57 gateway_url: impl Into<String>,
58 auth_token: impl Into<String>,
59 storage_dir: impl Into<PathBuf>,
60 ) -> Result<Self, ClientError> {
61 Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
62 }
63
64 pub async fn connect_fresh(
66 gateway_url: impl Into<String>,
67 auth_token: impl Into<String>,
68 ) -> Result<Self, ClientError> {
69 let gateway_url = gateway_url.into();
70 let auth_token = auth_token.into();
71
72 if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
74 let _ = storage.delete_connection();
75 }
76
77 Self::connect_internal(gateway_url, auth_token, None, false).await
78 }
79
80 async fn connect_internal(
81 gateway_url: impl Into<String>,
82 auth_token: impl Into<String>,
83 storage_dir: Option<PathBuf>,
84 ephemeral: bool,
85 ) -> Result<Self, ClientError> {
86 let gateway_url = gateway_url.into();
87 let auth_token = auth_token.into();
88
89 if gateway_url.is_empty() {
90 return Err(ClientError::Configuration(
91 "gateway_url cannot be empty".into(),
92 ));
93 }
94
95 if auth_token.is_empty() {
96 return Err(ClientError::Configuration(
97 "auth_token cannot be empty".into(),
98 ));
99 }
100
101 let endpoint = build_endpoint(&gateway_url)?;
102
103 let storage = if ephemeral {
105 None
106 } else if let Some(dir) = storage_dir {
107 match ConnectionStorage::new(dir) {
109 Ok(s) => Some(s),
110 Err(e) => {
111 warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
112 None
113 }
114 }
115 } else {
116 match ConnectionStorage::for_gateway(&gateway_url) {
118 Ok(s) => Some(s),
119 Err(e) => {
120 warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
121 None
122 }
123 }
124 };
125
126 if let Some(ref storage) = storage {
128 if let Ok(Some(stored)) = storage.get_connection() {
129 info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
130
131 match Self::restore_from_stored(stored, storage.clone()).await {
132 Ok(conn) => {
133 info!(node_id = %conn.node_id, "Successfully restored connection from storage");
134 return Ok(conn);
135 }
136 Err(e) => {
137 warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
138 }
140 }
141 }
142 }
143
144 info!("Creating fresh connection to gateway");
146
147 let (private_key, public_key) = dh_generate();
149 let public_key_hex = hex::encode(public_key);
150 let private_key_hex = hex::encode(private_key);
151
152 debug!("Generated WireGuard keypair");
153
154 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
156 ClientError::Configuration(format!("invalid bind address: {}", e))
157 })?;
158
159 let server = Server::new(bind_addr, private_key)
160 .await
161 .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
162
163 debug!("Created transport server");
164
165 let server_clone = server.clone();
167 tokio::spawn(async move {
168 if let Err(e) = server_clone.run().await {
169 error!(error = ?e, "Transport server error");
170 }
171 });
172
173 let server_maintenance = server.clone();
175 tokio::spawn(async move {
176 server_maintenance.run_maintenance().await
177 });
178
179 server
181 .wait_until_ready()
182 .await
183 .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
184
185 debug!("Transport server ready");
186
187 let local_addr = server
189 .local_addr()
190 .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
191
192 let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
193 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
194
195 debug!(
196 public_ip = %network_info.public_ip,
197 public_port = network_info.public_port,
198 local_ip = %network_info.local_ip,
199 local_port = network_info.local_port,
200 "Discovered network information"
201 );
202
203 let registration = register_with_gateway(
205 &endpoint,
206 &auth_token,
207 &public_key_hex,
208 &network_info,
209 )
210 .await?;
211
212 debug!(
213 node_id = %registration.node_id,
214 assigned_ip = %registration.assigned_ip,
215 "Registered with gateway"
216 );
217
218 let gateway_public_key_bytes = hex::decode(®istration.gateway_public_key_hex)
220 .map_err(|e| {
221 ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
222 })?;
223
224 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
225 .as_slice()
226 .try_into()
227 .map_err(|e| {
228 ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
229 })?;
230
231 server
232 .add_peer(gateway_public_key, None)
233 .await
234 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
235
236 debug!("Added gateway as peer");
237
238 if let Some(ref storage) = storage {
240 let now = current_timestamp();
241 let stored = StoredConnection {
242 node_id: registration.node_id.clone(),
243 token: registration.token.clone(),
244 gateway_url: gateway_url.clone(),
245 assigned_ip: registration.assigned_ip.clone(),
246 private_key_hex,
247 public_key_hex,
248 gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
249 created_at: now,
250 last_connected_at: now,
251 };
252
253 if let Err(e) = storage.store_connection(&stored) {
254 warn!(error = ?e, "Failed to persist connection to storage");
255 } else {
256 debug!("Persisted connection to storage");
257 }
258 }
259
260 let gateway_endpoint: SocketAddr = "127.0.0.1:51820".parse()
262 .map_err(|e| ClientError::Configuration(format!("invalid gateway endpoint: {}", e)))?;
263
264 Ok(Self {
265 transport: TransportServer::new(server),
266 node_id: registration.node_id,
267 assigned_ip: registration.assigned_ip,
268 token: registration.token,
269 endpoint,
270 gateway_url,
271 gateway_endpoint,
272 gateway_public_key,
273 storage,
274 })
275 }
276
277 async fn restore_from_stored(
279 stored: StoredConnection,
280 storage: ConnectionStorage,
281 ) -> Result<Self, ClientError> {
282 let endpoint = build_endpoint(&stored.gateway_url)?;
283
284 let private_key_bytes = hex::decode(&stored.private_key_hex)
286 .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
287 let private_key: PrivateKey = private_key_bytes
288 .as_slice()
289 .try_into()
290 .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
291
292 let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
293 .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
294 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
295 .as_slice()
296 .try_into()
297 .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
298
299 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
301 ClientError::Configuration(format!("invalid bind address: {}", e))
302 })?;
303
304 let server = Server::new(bind_addr, private_key)
305 .await
306 .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
307
308 debug!("Created transport server with stored keys");
309
310 let server_clone = server.clone();
312 tokio::spawn(async move {
313 if let Err(e) = server_clone.run().await {
314 error!(error = ?e, "Transport server error");
315 }
316 });
317
318 let server_maintenance = server.clone();
320 tokio::spawn(async move {
321 server_maintenance.run_maintenance().await
322 });
323
324 server
326 .wait_until_ready()
327 .await
328 .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
329
330 let local_addr = server
332 .local_addr()
333 .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
334
335 let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
336 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
337
338 debug!("Rediscovered network information");
339
340 server
342 .add_peer(gateway_public_key, None)
343 .await
344 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
345
346 debug!("Added gateway as peer");
347
348 if let Err(e) = storage.update_last_connected() {
350 warn!(error = ?e, "Failed to update last_connected timestamp");
351 }
352
353 let gateway_endpoint: SocketAddr = "127.0.0.1:51820".parse()
355 .map_err(|e| ClientError::Configuration(format!("invalid gateway endpoint: {}", e)))?;
356
357 Ok(Self {
358 transport: TransportServer::new(server),
359 node_id: stored.node_id,
360 assigned_ip: stored.assigned_ip,
361 token: stored.token,
362 endpoint,
363 gateway_url: stored.gateway_url,
364 gateway_endpoint,
365 gateway_public_key,
366 storage: Some(storage),
367 })
368 }
369
370 pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
372 Self::connect(DEFAULT_GATEWAY, auth_token).await
373 }
374
375 pub fn node_id(&self) -> &str {
377 &self.node_id
378 }
379
380 pub fn assigned_ip(&self) -> &str {
382 &self.assigned_ip
383 }
384
385 pub fn transport(&self) -> &TransportServer {
387 &self.transport
388 }
389
390 pub async fn ping_gateway(&self) -> Result<(), ClientError> {
392 debug!("Pinging gateway over WireGuard");
393
394 let conn = self
396 .transport
397 .server()
398 .dial("tcp", &self.gateway_endpoint.to_string(), self.gateway_public_key)
399 .await
400 .map_err(|e| ClientError::Transport(format!("failed to dial gateway: {}", e)))?;
401
402 debug!("WireGuard connection established to gateway");
403
404 let request = b"GET /ping HTTP/1.1\r\nHost: gateway\r\nConnection: close\r\n\r\n";
406 conn.send(request)
407 .await
408 .map_err(|e| ClientError::Transport(format!("failed to send ping request: {}", e)))?;
409
410 let mut buf = vec![0u8; 8192];
412 let n = conn
413 .recv(&mut buf)
414 .await
415 .map_err(|e| ClientError::Transport(format!("failed to receive ping response: {}", e)))?;
416
417 let response = String::from_utf8_lossy(&buf[..n]);
418
419 if response.contains("200 OK") && response.contains("Pong") {
420 debug!("Successfully pinged gateway");
421 Ok(())
422 } else {
423 Err(ClientError::GatewayError {
424 status: 500,
425 message: format!("Unexpected ping response: {}", response),
426 })
427 }
428 }
429
430 pub async fn get_peer_info(&self, node_id_or_ip: &str) -> Result<PeerInfo, ClientError> {
435 debug!(peer = %node_id_or_ip, "Fetching peer info from gateway");
436
437 let url = format!("{}/api/peers/{}", self.gateway_url.trim_end_matches('/'), node_id_or_ip);
439
440 let client = reqwest::Client::new();
441 let response = client
442 .get(&url)
443 .send()
444 .await?;
445
446 let status = response.status();
447
448 if status.is_success() {
449 let peer_info: PeerInfo = response.json().await.map_err(|e| {
450 ClientError::InvalidResponse(format!("failed to parse peer info: {}", e))
451 })?;
452
453 debug!(
454 node_id = %peer_info.node_id,
455 assigned_ip = %peer_info.assigned_ip,
456 "Retrieved peer info from gateway"
457 );
458
459 Ok(peer_info)
460 } else {
461 let message = response
462 .text()
463 .await
464 .unwrap_or_else(|_| "unknown error".to_string());
465 Err(ClientError::GatewayError {
466 status: status.as_u16(),
467 message: format!("Failed to get peer info: {}", message),
468 })
469 }
470 }
471
472 pub async fn dial(&self, peer: &str, port: u16) -> Result<hightower_wireguard::transport::Conn, ClientError> {
492 let peer_info = self.get_peer_info(peer).await?;
494
495 let peer_public_key_bytes = hex::decode(&peer_info.public_key_hex)
497 .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key hex: {}", e)))?;
498
499 let peer_public_key: PublicKey25519 = peer_public_key_bytes
500 .as_slice()
501 .try_into()
502 .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key format: {:?}", e)))?;
503
504 self.transport
506 .server()
507 .add_peer(peer_public_key, peer_info.endpoint)
508 .await
509 .map_err(|e| ClientError::Transport(format!("failed to add peer: {}", e)))?;
510
511 debug!(
512 peer_id = %peer_info.node_id,
513 peer_ip = %peer_info.assigned_ip,
514 port = port,
515 "Added peer and dialing"
516 );
517
518 let addr = format!("{}:{}", peer_info.assigned_ip, port);
520 let conn = self
521 .transport
522 .server()
523 .dial("tcp", &addr, peer_public_key)
524 .await
525 .map_err(|e| ClientError::Transport(format!("failed to dial peer: {}", e)))?;
526
527 debug!(
528 peer_id = %peer_info.node_id,
529 addr = %addr,
530 "Successfully dialed peer"
531 );
532
533 Ok(conn)
534 }
535
536 pub async fn disconnect(self) -> Result<(), ClientError> {
538 let url = format!("{}/{}", self.endpoint, self.token);
539
540 let client = reqwest::Client::new();
541 let response = client.delete(&url).send().await?;
542
543 let status = response.status();
544
545 if status.is_success() || status == StatusCode::NO_CONTENT {
546 debug!("Successfully deregistered from gateway");
547
548 if let Some(storage) = self.storage {
550 if let Err(e) = storage.delete_connection() {
551 warn!(error = ?e, "Failed to delete stored connection");
552 } else {
553 debug!("Deleted stored connection");
554 }
555 }
556
557 Ok(())
558 } else {
559 let message = response
560 .text()
561 .await
562 .unwrap_or_else(|_| "unknown error".to_string());
563 Err(ClientError::GatewayError {
564 status: status.as_u16(),
565 message,
566 })
567 }
568 }
569}
570
571fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
572 let gateway_url = gateway_url.trim();
573
574 if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
575 return Err(ClientError::Configuration(
576 "gateway_url must start with http:// or https://".into(),
577 ));
578 }
579
580 Ok(format!(
581 "{}{}",
582 gateway_url.trim_end_matches('/'),
583 API_PATH
584 ))
585}
586
587async fn register_with_gateway(
588 endpoint: &str,
589 auth_token: &str,
590 public_key_hex: &str,
591 network_info: &crate::types::NetworkInfo,
592) -> Result<RegistrationResponse, ClientError> {
593 let payload = RegistrationRequest {
594 public_key_hex,
595 public_ip: Some(network_info.public_ip.as_str()),
596 public_port: Some(network_info.public_port),
597 local_ip: Some(network_info.local_ip.as_str()),
598 local_port: Some(network_info.local_port),
599 };
600
601 let client = reqwest::Client::new();
602 let response = client
603 .post(endpoint)
604 .header("X-HT-Auth", auth_token)
605 .json(&payload)
606 .send()
607 .await?;
608
609 let status = response.status();
610
611 if status.is_success() {
612 let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
613 ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
614 })?;
615
616 Ok(registration_response)
617 } else {
618 let message = response
619 .text()
620 .await
621 .unwrap_or_else(|_| "unknown error".to_string());
622 Err(ClientError::GatewayError {
623 status: status.as_u16(),
624 message,
625 })
626 }
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `build_endpoint` accepts `input` and yields `expected`.
    fn assert_endpoint(input: &str, expected: &str) {
        match build_endpoint(input) {
            Ok(endpoint) => assert_eq!(endpoint, expected),
            Err(e) => panic!("expected {:?} to be accepted, got {:?}", input, e),
        }
    }

    #[test]
    fn build_endpoint_requires_scheme() {
        let outcome = build_endpoint("gateway.example.com:8008");
        assert!(matches!(outcome, Err(ClientError::Configuration(_))));
    }

    #[test]
    fn build_endpoint_accepts_http() {
        assert_endpoint(
            "http://gateway.example.com:8008",
            "http://gateway.example.com:8008/api/nodes",
        );
    }

    #[test]
    fn build_endpoint_accepts_https() {
        assert_endpoint(
            "https://gateway.example.com:8443",
            "https://gateway.example.com:8443/api/nodes",
        );
    }

    #[test]
    fn build_endpoint_strips_trailing_slash() {
        assert_endpoint(
            "http://gateway.example.com:8008/",
            "http://gateway.example.com:8008/api/nodes",
        );
    }
}