hightower_client/
connection.rs1use crate::error::ClientError;
2use crate::transport::TransportServer;
3use crate::types::{RegistrationRequest, RegistrationResponse};
4use hightower_wireguard::crypto::{dh_generate, PublicKey25519};
5use hightower_wireguard::transport::Server;
6use reqwest::StatusCode;
7use std::net::SocketAddr;
8use tracing::{debug, error};
9
10const DEFAULT_GATEWAY: &str = "http://127.0.0.1:8008";
11const API_PATH: &str = "/api/nodes";
12
13pub struct HightowerConnection {
15 transport: TransportServer,
16 node_id: String,
17 assigned_ip: String,
18 token: String,
19 endpoint: String,
20}
21
22impl HightowerConnection {
23 pub async fn connect(
34 gateway_url: impl Into<String>,
35 auth_token: impl Into<String>,
36 ) -> Result<Self, ClientError> {
37 let gateway_url = gateway_url.into();
38 let auth_token = auth_token.into();
39
40 if gateway_url.is_empty() {
41 return Err(ClientError::Configuration(
42 "gateway_url cannot be empty".into(),
43 ));
44 }
45
46 if auth_token.is_empty() {
47 return Err(ClientError::Configuration(
48 "auth_token cannot be empty".into(),
49 ));
50 }
51
52 let endpoint = build_endpoint(&gateway_url)?;
53
54 let (private_key, public_key) = dh_generate();
56 let public_key_hex = hex::encode(public_key);
57
58 debug!("Generated WireGuard keypair");
59
60 let bind_addr: SocketAddr = "0.0.0.0:0".parse().map_err(|e| {
62 ClientError::Configuration(format!("invalid bind address: {}", e))
63 })?;
64
65 let server = Server::new(bind_addr, private_key)
66 .await
67 .map_err(|e| ClientError::Transport(format!("failed to create transport server: {}", e)))?;
68
69 debug!("Created transport server");
70
71 let server_clone = server.clone();
73 tokio::spawn(async move {
74 if let Err(e) = server_clone.run().await {
75 error!(error = ?e, "Transport server error");
76 }
77 });
78
79 let server_maintenance = server.clone();
81 tokio::spawn(async move {
82 server_maintenance.run_maintenance().await
83 });
84
85 server
87 .wait_until_ready()
88 .await
89 .map_err(|e| ClientError::Transport(format!("transport not ready: {}", e)))?;
90
91 debug!("Transport server ready");
92
93 let local_addr = server
95 .local_addr()
96 .map_err(|e| ClientError::Transport(format!("failed to get local address: {}", e)))?;
97
98 let network_info = crate::ip_discovery::discover_with_bound_address(local_addr, None)
99 .map_err(|e| ClientError::NetworkDiscovery(e.to_string()))?;
100
101 debug!(
102 public_ip = %network_info.public_ip,
103 public_port = network_info.public_port,
104 local_ip = %network_info.local_ip,
105 local_port = network_info.local_port,
106 "Discovered network information"
107 );
108
109 let registration = register_with_gateway(
111 &endpoint,
112 &auth_token,
113 &public_key_hex,
114 &network_info,
115 )
116 .await?;
117
118 debug!(
119 node_id = %registration.node_id,
120 assigned_ip = %registration.assigned_ip,
121 "Registered with gateway"
122 );
123
124 let gateway_public_key_bytes = hex::decode(®istration.gateway_public_key_hex)
126 .map_err(|e| {
127 ClientError::InvalidResponse(format!("invalid gateway public key hex: {}", e))
128 })?;
129
130 let gateway_public_key: PublicKey25519 = gateway_public_key_bytes
131 .as_slice()
132 .try_into()
133 .map_err(|e| {
134 ClientError::InvalidResponse(format!("invalid gateway public key format: {:?}", e))
135 })?;
136
137 server
138 .add_peer(gateway_public_key, None)
139 .await
140 .map_err(|e| ClientError::Transport(format!("failed to add gateway as peer: {}", e)))?;
141
142 debug!("Added gateway as peer");
143
144 Ok(Self {
145 transport: TransportServer::new(server),
146 node_id: registration.node_id,
147 assigned_ip: registration.assigned_ip,
148 token: registration.token,
149 endpoint,
150 })
151 }
152
153 pub async fn connect_with_auth_token(auth_token: impl Into<String>) -> Result<Self, ClientError> {
155 Self::connect(DEFAULT_GATEWAY, auth_token).await
156 }
157
158 pub fn node_id(&self) -> &str {
160 &self.node_id
161 }
162
163 pub fn assigned_ip(&self) -> &str {
165 &self.assigned_ip
166 }
167
168 pub fn transport(&self) -> &TransportServer {
170 &self.transport
171 }
172
173 pub async fn disconnect(self) -> Result<(), ClientError> {
175 let url = format!("{}/{}", self.endpoint, self.token);
176
177 let client = reqwest::Client::new();
178 let response = client.delete(&url).send().await?;
179
180 let status = response.status();
181
182 if status.is_success() || status == StatusCode::NO_CONTENT {
183 debug!("Successfully deregistered from gateway");
184 Ok(())
185 } else {
186 let message = response
187 .text()
188 .await
189 .unwrap_or_else(|_| "unknown error".to_string());
190 Err(ClientError::GatewayError {
191 status: status.as_u16(),
192 message,
193 })
194 }
195 }
196}
197
198fn build_endpoint(gateway_url: &str) -> Result<String, ClientError> {
199 let gateway_url = gateway_url.trim();
200
201 if !gateway_url.starts_with("http://") && !gateway_url.starts_with("https://") {
202 return Err(ClientError::Configuration(
203 "gateway_url must start with http:// or https://".into(),
204 ));
205 }
206
207 Ok(format!(
208 "{}{}",
209 gateway_url.trim_end_matches('/'),
210 API_PATH
211 ))
212}
213
214async fn register_with_gateway(
215 endpoint: &str,
216 auth_token: &str,
217 public_key_hex: &str,
218 network_info: &crate::types::NetworkInfo,
219) -> Result<RegistrationResponse, ClientError> {
220 let payload = RegistrationRequest {
221 public_key_hex,
222 public_ip: Some(network_info.public_ip.as_str()),
223 public_port: Some(network_info.public_port),
224 local_ip: Some(network_info.local_ip.as_str()),
225 local_port: Some(network_info.local_port),
226 };
227
228 let client = reqwest::Client::new();
229 let response = client
230 .post(endpoint)
231 .header("X-HT-Auth", auth_token)
232 .json(&payload)
233 .send()
234 .await?;
235
236 let status = response.status();
237
238 if status.is_success() {
239 let registration_response: RegistrationResponse = response.json().await.map_err(|e| {
240 ClientError::InvalidResponse(format!("failed to parse registration response: {}", e))
241 })?;
242
243 Ok(registration_response)
244 } else {
245 let message = response
246 .text()
247 .await
248 .unwrap_or_else(|_| "unknown error".to_string());
249 Err(ClientError::GatewayError {
250 status: status.as_u16(),
251 message,
252 })
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an endpoint the test expects to be accepted, panicking otherwise.
    fn accepted(gateway_url: &str) -> String {
        build_endpoint(gateway_url).expect("gateway_url should be accepted")
    }

    #[test]
    fn build_endpoint_requires_scheme() {
        let outcome = build_endpoint("gateway.example.com:8008");
        assert!(
            matches!(outcome, Err(ClientError::Configuration(_))),
            "a gateway URL without a scheme must be a configuration error"
        );
    }

    #[test]
    fn build_endpoint_accepts_http() {
        assert_eq!(
            accepted("http://gateway.example.com:8008"),
            "http://gateway.example.com:8008/api/nodes"
        );
    }

    #[test]
    fn build_endpoint_accepts_https() {
        assert_eq!(
            accepted("https://gateway.example.com:8443"),
            "https://gateway.example.com:8443/api/nodes"
        );
    }

    #[test]
    fn build_endpoint_strips_trailing_slash() {
        assert_eq!(
            accepted("http://gateway.example.com:8008/"),
            "http://gateway.example.com:8008/api/nodes"
        );
    }
}