1use crate::error::ClientError;
2use crate::storage::{ConnectionStorage, StoredConnection, current_timestamp};
3use crate::transport::TransportServer;
4use crate::types::{PeerInfo, RegistrationRequest, RegistrationResponse};
5use hightower_wireguard::crypto::{dh_generate, PrivateKey, PublicKey25519};
6use hightower_wireguard::connection::Connection;
7use reqwest::StatusCode;
8use std::net::SocketAddr;
9use std::path::PathBuf;
10use tracing::{debug, info, warn};
11
/// Gateway base URL used by `connect_with_auth_token`, which takes no URL from the caller.
const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";
/// Path that `build_endpoint` appends to the gateway base URL to form the node API endpoint.
const API_PATH: &str = "/api/nodes";
14
/// A registered connection to a Hightower gateway.
///
/// Holds the transport together with the identity data obtained from the
/// gateway at registration time (or reloaded from storage on restore).
pub struct HightowerConnection {
    /// Transport server wrapping the underlying `Connection`.
    transport: TransportServer,
    /// Node identifier taken from the registration response or from storage.
    node_id: String,
    /// IP address assigned to this node, as reported by the gateway.
    assigned_ip: String,
    /// Token from the registration response; `disconnect` appends it to
    /// `endpoint` to build the deregistration URL.
    token: String,
    /// Node API endpoint: the gateway base URL with `API_PATH` appended.
    endpoint: String,
    /// Gateway base URL as supplied by the caller or as read from storage.
    gateway_url: String,
    /// Resolved socket address of the gateway host on port 51820.
    gateway_endpoint: SocketAddr,
    /// Gateway public key, hex-decoded from the registration response or storage.
    gateway_public_key: PublicKey25519,
    /// Persistence backend; `None` for ephemeral connections or when storage
    /// could not be initialised.
    storage: Option<ConnectionStorage>,
}
27
28impl HightowerConnection {
29 pub async fn connect(
41 gateway_url: impl Into<String>,
42 auth_token: impl Into<String>,
43 ) -> Result<Self, ClientError> {
44 Self::connect_internal(gateway_url, auth_token, None, false).await
45 }
46
47 pub async fn connect_ephemeral(
49 gateway_url: impl Into<String>,
50 auth_token: impl Into<String>,
51 ) -> Result<Self, ClientError> {
52 Self::connect_internal(gateway_url, auth_token, None, true).await
53 }
54
55 pub async fn connect_with_storage(
57 gateway_url: impl Into<String>,
58 auth_token: impl Into<String>,
59 storage_dir: impl Into<PathBuf>,
60 ) -> Result<Self, ClientError> {
61 Self::connect_internal(gateway_url, auth_token, Some(storage_dir.into()), false).await
62 }
63
64 pub async fn connect_fresh(
66 gateway_url: impl Into<String>,
67 auth_token: impl Into<String>,
68 ) -> Result<Self, ClientError> {
69 let gateway_url = gateway_url.into();
70 let auth_token = auth_token.into();
71
72 if let Ok(storage) = ConnectionStorage::for_gateway(&gateway_url) {
74 let _ = storage.delete_connection();
75 }
76
77 Self::connect_internal(gateway_url, auth_token, None, false).await
78 }
79
80 async fn connect_internal(
81 gateway_url: impl Into<String>,
82 auth_token: impl Into<String>,
83 storage_dir: Option<PathBuf>,
84 ephemeral: bool,
85 ) -> Result<Self, ClientError> {
86 let gateway_url = gateway_url.into();
87 let auth_token = auth_token.into();
88
89 if gateway_url.is_empty() {
90 return Err(ClientError::Configuration(
91 "gateway_url cannot be empty".into(),
92 ));
93 }
94
95 if auth_token.is_empty() {
96 return Err(ClientError::Configuration(
97 "auth_token cannot be empty".into(),
98 ));
99 }
100
101 let endpoint = build_endpoint(&gateway_url)?;
102
103 let storage = if ephemeral {
105 None
106 } else if let Some(dir) = storage_dir {
107 match ConnectionStorage::new(dir) {
109 Ok(s) => Some(s),
110 Err(e) => {
111 warn!(error = ?e, "Failed to initialize custom storage, continuing without persistence");
112 None
113 }
114 }
115 } else {
116 match ConnectionStorage::for_gateway(&gateway_url) {
118 Ok(s) => Some(s),
119 Err(e) => {
120 warn!(error = ?e, "Failed to initialize storage, continuing without persistence");
121 None
122 }
123 }
124 };
125
126 if let Some(ref storage) = storage {
128 if let Ok(Some(stored)) = storage.get_connection() {
129 info!(node_id = %stored.node_id, "Found stored connection, attempting to restore");
130
131 match Self::restore_from_stored(stored, storage.clone()).await {
132 Ok(conn) => {
133 info!(node_id = %conn.node_id, "Successfully restored connection from storage");
134 return Ok(conn);
135 }
136 Err(e) => {
137 warn!(error = ?e, "Failed to restore stored connection, will create fresh connection");
138 }
140 }
141 }
142 }
143
144 info!("Creating fresh connection to gateway");
146
147 let (private_key, public_key) = dh_generate();
149 let public_key_hex = hex::encode(public_key);
150 let private_key_hex = hex::encode(private_key);
151
152 debug!("Generated WireGuard keypair");
153
154 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
156 ClientError::Configuration(format!("invalid bind address: {}", e))
157 })?;
158
159 let connection = Connection::new(bind_addr, private_key)
160 .await
161 .map_err(|e| ClientError::Transport(format!("failed to create transport connection: {}", e)))?;
162
163 debug!("Created transport connection");
164
165 let local_addr = connection.local_addr();
167
168 let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
169 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
170
171 debug!(
172 public_ip = %network_info.public_ip,
173 public_port = network_info.public_port,
174 local_ip = %network_info.local_ip,
175 local_port = network_info.local_port,
176 "Discovered network information"
177 );
178
179 let registration = register_with_gateway(
181 &endpoint,
182 &auth_token,
183 &public_key_hex,
184 &network_info,
185 )
186 .await?;
187
188 debug!(
189 node_id = %registration.node_id,
190 assigned_ip = %registration.assigned_ip,
191 "Registered with gateway"
192 );
193
194 let gateway_public_key_bytes = hex::decode(®istration.gateway_public_key_hex)
196 .map_err(|e| {
197 ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
198 })?;
199
200 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
201 .as_slice()
202 .try_into()
203 .map_err(|e| {
204 ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
205 })?;
206
207 connection
208 .add_peer(gateway_public_key, None)
209 .await
210 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
211
212 debug!("Added gateway as peer");
213
214 if let Some(ref storage) = storage {
216 let now = current_timestamp();
217 let stored = StoredConnection {
218 node_id: registration.node_id.clone(),
219 token: registration.token.clone(),
220 gateway_url: gateway_url.clone(),
221 assigned_ip: registration.assigned_ip.clone(),
222 private_key_hex,
223 public_key_hex,
224 gateway_public_key_hex: registration.gateway_public_key_hex.clone(),
225 created_at: now,
226 last_connected_at: now,
227 };
228
229 if let Err(e) = storage.store_connection(&stored) {
230 warn!(error = ?e, "Failed to persist connection to storage");
231 } else {
232 debug!("Persisted connection to storage");
233 }
234 }
235
236 let gateway_endpoint = parse_gateway_wireguard_endpoint(&gateway_url).await?;
238
239 Ok(Self {
240 transport: TransportServer::new(connection),
241 node_id: registration.node_id,
242 assigned_ip: registration.assigned_ip,
243 token: registration.token,
244 endpoint,
245 gateway_url,
246 gateway_endpoint,
247 gateway_public_key,
248 storage,
249 })
250 }
251
252 async fn restore_from_stored(
254 stored: StoredConnection,
255 storage: ConnectionStorage,
256 ) -> Result<Self, ClientError> {
257 let endpoint = build_endpoint(&stored.gateway_url)?;
258
259 let private_key_bytes = hex::decode(&stored.private_key_hex)
261 .map_err(|e| ClientError::Storage(format!("invalid private key hex: {}", e)))?;
262 let private_key: PrivateKey = private_key_bytes
263 .as_slice()
264 .try_into()
265 .map_err(|e| ClientError::Storage(format!("invalid private key format: {:?}", e)))?;
266
267 let gateway_public_key_bytes = hex::decode(&stored.gateway_public_key_hex)
268 .map_err(|e| ClientError::Storage(format!("invalid gateway public key hex: {}", e)))?;
269 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
270 .as_slice()
271 .try_into()
272 .map_err(|e| ClientError::Storage(format!("invalid gateway public key format: {:?}", e)))?;
273
274 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
276 ClientError::Configuration(format!("invalid bind address: {}", e))
277 })?;
278
279 let connection = Connection::new(bind_addr, private_key)
280 .await
281 .map_err(|e| ClientError::Transport(format!("failed to create transport connection: {}", e)))?;
282
283 debug!("Created transport connection with stored keys");
284
285 let local_addr = connection.local_addr();
287
288 let _network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
289 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
290
291 debug!("Rediscovered network information");
292
293 connection
295 .add_peer(gateway_public_key, None)
296 .await
297 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
298
299 debug!("Added gateway as peer");
300
301 if let Err(e) = storage.update_last_connected() {
303 warn!(error = ?e, "Failed to update last_connected timestamp");
304 }
305
306 let gateway_endpoint = parse_gateway_wireguard_endpoint(&stored.gateway_url).await?;
308
309 Ok(Self {
310 transport: TransportServer::new(connection),
311 node_id: stored.node_id,
312 assigned_ip: stored.assigned_ip,
313 token: stored.token,
314 endpoint,
315 gateway_url: stored.gateway_url,
316 gateway_endpoint,
317 gateway_public_key,
318 storage: Some(storage),
319 })
320 }
321
322 pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
324 Self::connect(DEFAULT_GATEWAY, auth_token).await
325 }
326
327 pub fn node_id(&self) -> &str {
329 &self.node_id
330 }
331
332 pub fn assigned_ip(&self) -> &str {
334 &self.assigned_ip
335 }
336
337 pub fn transport(&self) -> &TransportServer {
339 &self.transport
340 }
341
342 pub async fn ping_gateway(&self) -> Result<(), ClientError> {
344 debug!("Pinging gateway over WireGuard");
345
346 let mut stream = self
348 .transport
349 .connection()
350 .connect(self.gateway_endpoint, self.gateway_public_key)
351 .await
352 .map_err(|e| ClientError::Transport(format!("failed to connect to gateway: {}", e)))?;
353
354 debug!("WireGuard connection established to gateway");
355
356 let request = b"GET /ping HTTP/1.1\r\nHost: gateway\r\nConnection: close\r\n\r\n";
358 stream.send(request)
359 .await
360 .map_err(|e| ClientError::Transport(format!("failed to send ping request: {}", e)))?;
361
362 let response_bytes = stream
364 .recv()
365 .await
366 .map_err(|e| ClientError::Transport(format!("failed to receive ping response: {}", e)))?;
367
368 let response = String::from_utf8_lossy(&response_bytes);
369
370 if response.contains("200 OK") && response.contains("Pong") {
371 debug!("Successfully pinged gateway");
372 Ok(())
373 } else {
374 Err(ClientError::GatewayError {
375 status: 500,
376 message: format!("Unexpected ping response: {}", response),
377 })
378 }
379 }
380
381 pub async fn get_peer_info(&self, node_id_or_ip: &str) -> Result<PeerInfo, ClientError> {
386 debug!(peer = %node_id_or_ip, "Fetching peer info from gateway");
387
388 let url = format!("{}/api/peers/{}", self.gateway_url.trim_end_matches('/'), node_id_or_ip);
390
391 let client = reqwest::Client::new();
392 let response = client
393 .get(&url)
394 .send()
395 .await?;
396
397 let status = response.status();
398
399 if status.is_success() {
400 let peer_info: PeerInfo = response.json().await.map_err(|e| {
401 ClientError::InvalidResponse(format!("failed to parse peer info: {}", e))
402 })?;
403
404 debug!(
405 node_id = %peer_info.node_id,
406 assigned_ip = %peer_info.assigned_ip,
407 "Retrieved peer info from gateway"
408 );
409
410 Ok(peer_info)
411 } else {
412 let message = response
413 .text()
414 .await
415 .unwrap_or_else(|_| "unknown error".to_string());
416 Err(ClientError::GatewayError {
417 status: status.as_u16(),
418 message: format!("Failed to get peer info: {}", message),
419 })
420 }
421 }
422
423 pub async fn dial(&self, peer: &str, port: u16) -> Result<hightower_wireguard::connection::Stream, ClientError> {
443 let peer_info = self.get_peer_info(peer).await?;
445
446 let peer_public_key_bytes = hex::decode(&peer_info.public_key_hex)
448 .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key hex: {}", e)))?;
449
450 let peer_public_key: PublicKey25519 = peer_public_key_bytes
451 .as_slice()
452 .try_into()
453 .map_err(|e| ClientError::InvalidResponse(format!("invalid peer public key format: {:?}", e)))?;
454
455 self.transport
457 .connection()
458 .add_peer(peer_public_key, peer_info.endpoint)
459 .await
460 .map_err(|e| ClientError::Transport(format!("failed to add peer: {}", e)))?;
461
462 debug!(
463 peer_id = %peer_info.node_id,
464 peer_ip = %peer_info.assigned_ip,
465 port = port,
466 "Added peer and connecting"
467 );
468
469 let peer_addr: SocketAddr = format!("{}:{}", peer_info.assigned_ip, port)
471 .parse()
472 .map_err(|e| ClientError::Transport(format!("invalid peer address: {}", e)))?;
473
474 let stream = self
475 .transport
476 .connection()
477 .connect(peer_addr, peer_public_key)
478 .await
479 .map_err(|e| ClientError::Transport(format!("failed to connect to peer: {}", e)))?;
480
481 debug!(
482 peer_id = %peer_info.node_id,
483 addr = %peer_addr,
484 "Successfully connected to peer"
485 );
486
487 Ok(stream)
488 }
489
490 pub async fn disconnect(self) -> Result<(), ClientError> {
492 let url = format!("{}/{}", self.endpoint, self.token);
493
494 let client = reqwest::Client::new();
495 let response = client.delete(&url).send().await?;
496
497 let status = response.status();
498
499 if status.is_success() || status == StatusCode::NO_CONTENT {
500 debug!("Successfully deregistered from gateway");
501
502 if let Some(storage) = self.storage {
504 if let Err(e) = storage.delete_connection() {
505 warn!(error = ?e, "Failed to delete stored connection");
506 } else {
507 debug!("Deleted stored connection");
508 }
509 }
510
511 Ok(())
512 } else {
513 let message = response
514 .text()
515 .await
516 .unwrap_or_else(|_| "unknown error".to_string());
517 Err(ClientError::GatewayError {
518 status: status.as_u16(),
519 message,
520 })
521 }
522 }
523}
524
525fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
526 let gateway_url = gateway_url.trim();
527
528 if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
529 return Err(ClientError::Configuration(
530 "gateway_url must start with http:// or https://".into(),
531 ));
532 }
533
534 Ok(format!(
535 "{}{}",
536 gateway_url.trim_end_matches('/'),
537 API_PATH
538 ))
539}
540
541async fn parse_gateway_wireguard_endpoint(gateway_url: &str) -> Result<SocketAddr, ClientError> {
542 let parsed_url = url::Url::parse(gateway_url)
543 .map_err(|e| ClientError::Configuration(format!("invalid gateway URL: {}", e)))?;
544
545 let host = parsed_url.host_str()
546 .ok_or_else(|| ClientError::Configuration("gateway URL has no host".into()))?;
547
548 let endpoint_str = format!("{}:51820", host);
550
551 let mut addrs = tokio::net::lookup_host(&endpoint_str)
553 .await
554 .map_err(|e| ClientError::Configuration(format!("failed to resolve gateway endpoint {}: {}", endpoint_str, e)))?;
555
556 addrs.next()
557 .ok_or_else(|| ClientError::Configuration(format!("no addresses found for gateway endpoint: {}", endpoint_str)))
558}
559
560async fn register_with_gateway(
561 endpoint: &str,
562 auth_token: &str,
563 public_key_hex: &str,
564 network_info: &crate::types::NetworkInfo,
565) -> Result<RegistrationResponse, ClientError> {
566 let payload = RegistrationRequest {
567 public_key_hex,
568 public_ip: Some(network_info.public_ip.as_str()),
569 public_port: Some(network_info.public_port),
570 local_ip: Some(network_info.local_ip.as_str()),
571 local_port: Some(network_info.local_port),
572 };
573
574 let client = reqwest::Client::new();
575 let response = client
576 .post(endpoint)
577 .header("X-HT-Auth", auth_token)
578 .json(&payload)
579 .send()
580 .await?;
581
582 let status = response.status();
583
584 if status.is_success() {
585 let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
586 ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
587 })?;
588
589 Ok(registration_response)
590 } else {
591 let message = response
592 .text()
593 .await
594 .unwrap_or_else(|_| "unknown error".to_string());
595 Err(ClientError::GatewayError {
596 status: status.as_u16(),
597 message,
598 })
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// A URL without an `http://` or `https://` prefix is a configuration error.
    #[test]
    fn build_endpoint_requires_scheme() {
        let outcome = build_endpoint("gateway.example.com:8008");
        assert!(matches!(outcome, Err(ClientError::Configuration(_))));
    }

    /// A plain `http://` base URL gets the node API path appended.
    #[test]
    fn build_endpoint_accepts_http() {
        assert_eq!(
            build_endpoint("http://gateway.example.com:8008").unwrap(),
            "http://gateway.example.com:8008/api/nodes"
        );
    }

    /// An `https://` base URL is accepted as well.
    #[test]
    fn build_endpoint_accepts_https() {
        assert_eq!(
            build_endpoint("https://gateway.example.com:8443").unwrap(),
            "https://gateway.example.com:8443/api/nodes"
        );
    }

    /// A trailing slash on the base URL does not produce a double slash.
    #[test]
    fn build_endpoint_strips_trailing_slash() {
        assert_eq!(
            build_endpoint("http://gateway.example.com:8008/").unwrap(),
            "http://gateway.example.com:8008/api/nodes"
        );
    }
}