ap_proxy/server/proxy_server.rs
1use ap_proxy_protocol::{IdentityFingerprint, ProxyError};
2
3use crate::connection::AuthenticatedConnection;
4use crate::server::handler::ConnectionHandler;
5use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::SystemTime;
10use tokio::net::TcpListener;
11use tokio::sync::RwLock;
12use tokio_tungstenite::accept_async;
13
14pub struct RendevouzEntry {
15 pub fingerprint: IdentityFingerprint,
16 pub created_at: SystemTime,
17 pub used: bool,
18}
19
20pub struct ServerState {
21 pub connections: Arc<RwLock<HashMap<IdentityFingerprint, Vec<Arc<AuthenticatedConnection>>>>>,
22 pub rendezvous_map: Arc<RwLock<HashMap<String, RendevouzEntry>>>,
23}
24
25impl Default for ServerState {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl ServerState {
32 pub fn new() -> Self {
33 Self {
34 connections: Arc::new(RwLock::new(HashMap::new())),
35 rendezvous_map: Arc::new(RwLock::new(HashMap::new())),
36 }
37 }
38}
39
40/// The proxy server that accepts client connections and relays messages.
41///
42/// This server handles:
43/// - Client authentication using MlDsa65 challenge-response
44/// - Rendezvous code generation and lookup
45/// - Message routing between authenticated clients
46/// - Automatic cleanup of expired rendezvous codes
47///
48/// # Examples
49///
50/// Run a standalone server:
51///
52/// ```no_run
53/// use ap_proxy::server::ProxyServer;
54/// use std::net::SocketAddr;
55///
56/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
57/// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
58/// let server = ProxyServer::new(addr);
59///
60/// println!("Starting proxy server on {}", addr);
61/// server.run().await?;
62/// # Ok(())
63/// # }
64/// ```
65///
66/// Embed in an application with cancellation:
67///
68/// ```no_run
69/// use ap_proxy::server::ProxyServer;
70/// use std::net::SocketAddr;
71/// use tokio::signal;
72///
73/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
74/// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
75/// let server = ProxyServer::new(addr);
76///
77/// tokio::select! {
78/// result = server.run() => {
79/// result?;
80/// }
81/// _ = signal::ctrl_c() => {
82/// println!("Shutting down...");
83/// }
84/// }
85/// # Ok(())
86/// # }
87/// ```
88pub struct ProxyServer {
89 bind_addr: SocketAddr,
90 state: Arc<ServerState>,
91 conn_counter: AtomicU64,
92}
93
94impl ProxyServer {
95 /// Create a new proxy server that will listen on the given address.
96 ///
97 /// This does not start the server - call [`run()`](ProxyServer::run) to begin
98 /// accepting connections.
99 ///
100 /// # Examples
101 ///
102 /// ```
103 /// use ap_proxy::server::ProxyServer;
104 /// use std::net::SocketAddr;
105 ///
106 /// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
107 /// let server = ProxyServer::new(addr);
108 /// ```
109 pub fn new(bind_addr: SocketAddr) -> Self {
110 Self {
111 bind_addr,
112 state: Arc::new(ServerState::new()),
113 conn_counter: AtomicU64::new(0),
114 }
115 }
116
117 /// Run the proxy server, accepting and handling connections.
118 ///
119 /// This method:
120 /// 1. Binds to the configured address
121 /// 2. Spawns a background task to clean up expired rendezvous codes
122 /// 3. Accepts incoming WebSocket connections
123 /// 4. Spawns a handler task for each connection
124 /// 5. Runs indefinitely until an error occurs or cancelled
125 ///
126 /// # Cancellation
127 ///
128 /// Use `tokio::select!` or similar to cancel the server:
129 ///
130 /// ```no_run
131 /// use ap_proxy::server::ProxyServer;
132 /// use std::net::SocketAddr;
133 /// use tokio::signal;
134 ///
135 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
136 /// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
137 /// let server = ProxyServer::new(addr);
138 ///
139 /// tokio::select! {
140 /// result = server.run() => result?,
141 /// _ = signal::ctrl_c() => {
142 /// println!("Shutting down gracefully");
143 /// }
144 /// }
145 /// # Ok(())
146 /// # }
147 /// ```
148 ///
149 /// # Errors
150 ///
151 /// Returns an error if:
152 /// - The bind address is already in use
153 /// - The address is invalid or cannot be bound
154 /// - A network error occurs
155 ///
156 /// # Examples
157 ///
158 /// ```no_run
159 /// use ap_proxy::server::ProxyServer;
160 /// use std::net::SocketAddr;
161 ///
162 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
163 /// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
164 /// let server = ProxyServer::new(addr);
165 /// server.run().await?;
166 /// # Ok(())
167 /// # }
168 /// ```
169 pub async fn run(&self) -> Result<(), ProxyError> {
170 let listener = TcpListener::bind(self.bind_addr).await?;
171 tracing::info!("Proxy server listening on {}", self.bind_addr);
172 self.run_with_listener(listener).await
173 }
174
175 /// Run the proxy server using an already-bound `TcpListener`.
176 ///
177 /// This is useful in tests to avoid the race condition of binding a port,
178 /// dropping the listener, and re-binding.
179 pub async fn run_with_listener(&self, listener: TcpListener) -> Result<(), ProxyError> {
180 // Spawn background cleanup task for expired rendezvous codes
181 let cleanup_state = Arc::clone(&self.state);
182 tokio::spawn(async move {
183 loop {
184 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
185
186 let mut rendezvous_map = cleanup_state.rendezvous_map.write().await;
187 let now = SystemTime::now();
188 let mut expired_codes = Vec::new();
189
190 for (code, entry) in rendezvous_map.iter() {
191 let elapsed = now.duration_since(entry.created_at).unwrap_or_default();
192
193 if elapsed.as_secs() > 300 {
194 expired_codes.push(code.clone());
195 }
196 }
197
198 for code in expired_codes {
199 rendezvous_map.remove(&code);
200 tracing::debug!("Cleaned up expired rendezvous code: {}", code);
201 }
202 }
203 });
204
205 loop {
206 let (stream, peer_addr) = listener.accept().await?;
207 let conn_id = self.conn_counter.fetch_add(1, Ordering::SeqCst);
208
209 tracing::info!("New connection #{} from {}", conn_id, peer_addr);
210
211 let state = Arc::clone(&self.state);
212
213 tokio::spawn(async move {
214 match accept_async(stream).await {
215 Ok(ws_stream) => {
216 let handler = ConnectionHandler::new(conn_id, state, ws_stream);
217 if let Err(e) = handler.handle().await {
218 tracing::error!("Connection #{} error: {}", conn_id, e);
219 }
220 }
221 Err(e) => {
222 tracing::error!(
223 "Failed to accept WebSocket connection #{}: {}",
224 conn_id,
225 e
226 );
227 }
228 }
229 });
230 }
231 }
232}