Skip to main content

ap_proxy/server/
proxy_server.rs

1use ap_proxy_protocol::{IdentityFingerprint, ProxyError};
2
3use crate::connection::AuthenticatedConnection;
4use crate::server::handler::ConnectionHandler;
5use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::SystemTime;
10use tokio::net::TcpListener;
11use tokio::sync::RwLock;
12use tokio_tungstenite::accept_async;
13
14pub struct RendevouzEntry {
15    pub fingerprint: IdentityFingerprint,
16    pub created_at: SystemTime,
17    pub used: bool,
18}
19
20pub struct ServerState {
21    pub connections: Arc<RwLock<HashMap<IdentityFingerprint, Vec<Arc<AuthenticatedConnection>>>>>,
22    pub rendezvous_map: Arc<RwLock<HashMap<String, RendevouzEntry>>>,
23}
24
25impl Default for ServerState {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl ServerState {
32    pub fn new() -> Self {
33        Self {
34            connections: Arc::new(RwLock::new(HashMap::new())),
35            rendezvous_map: Arc::new(RwLock::new(HashMap::new())),
36        }
37    }
38}
39
40/// The proxy server that accepts client connections and relays messages.
41///
42/// This server handles:
43/// - Client authentication using MlDsa65 challenge-response
44/// - Rendezvous code generation and lookup
45/// - Message routing between authenticated clients
46/// - Automatic cleanup of expired rendezvous codes
47///
48/// # Examples
49///
50/// Run a standalone server:
51///
52/// ```no_run
53/// use ap_proxy::server::ProxyServer;
54/// use std::net::SocketAddr;
55///
56/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
57/// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
58/// let server = ProxyServer::new(addr);
59///
60/// println!("Starting proxy server on {}", addr);
61/// server.run().await?;
62/// # Ok(())
63/// # }
64/// ```
65///
66/// Embed in an application with cancellation:
67///
68/// ```no_run
69/// use ap_proxy::server::ProxyServer;
70/// use std::net::SocketAddr;
71/// use tokio::signal;
72///
73/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
74/// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
75/// let server = ProxyServer::new(addr);
76///
77/// tokio::select! {
78///     result = server.run() => {
79///         result?;
80///     }
81///     _ = signal::ctrl_c() => {
82///         println!("Shutting down...");
83///     }
84/// }
85/// # Ok(())
86/// # }
87/// ```
88pub struct ProxyServer {
89    bind_addr: SocketAddr,
90    state: Arc<ServerState>,
91    conn_counter: AtomicU64,
92}
93
94impl ProxyServer {
95    /// Create a new proxy server that will listen on the given address.
96    ///
97    /// This does not start the server - call [`run()`](ProxyServer::run) to begin
98    /// accepting connections.
99    ///
100    /// # Examples
101    ///
102    /// ```
103    /// use ap_proxy::server::ProxyServer;
104    /// use std::net::SocketAddr;
105    ///
106    /// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
107    /// let server = ProxyServer::new(addr);
108    /// ```
109    pub fn new(bind_addr: SocketAddr) -> Self {
110        Self {
111            bind_addr,
112            state: Arc::new(ServerState::new()),
113            conn_counter: AtomicU64::new(0),
114        }
115    }
116
117    /// Run the proxy server, accepting and handling connections.
118    ///
119    /// This method:
120    /// 1. Binds to the configured address
121    /// 2. Spawns a background task to clean up expired rendezvous codes
122    /// 3. Accepts incoming WebSocket connections
123    /// 4. Spawns a handler task for each connection
124    /// 5. Runs indefinitely until an error occurs or cancelled
125    ///
126    /// # Cancellation
127    ///
128    /// Use `tokio::select!` or similar to cancel the server:
129    ///
130    /// ```no_run
131    /// use ap_proxy::server::ProxyServer;
132    /// use std::net::SocketAddr;
133    /// use tokio::signal;
134    ///
135    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
136    /// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
137    /// let server = ProxyServer::new(addr);
138    ///
139    /// tokio::select! {
140    ///     result = server.run() => result?,
141    ///     _ = signal::ctrl_c() => {
142    ///         println!("Shutting down gracefully");
143    ///     }
144    /// }
145    /// # Ok(())
146    /// # }
147    /// ```
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if:
152    /// - The bind address is already in use
153    /// - The address is invalid or cannot be bound
154    /// - A network error occurs
155    ///
156    /// # Examples
157    ///
158    /// ```no_run
159    /// use ap_proxy::server::ProxyServer;
160    /// use std::net::SocketAddr;
161    ///
162    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
163    /// let addr: SocketAddr = "127.0.0.1:8080".parse()?;
164    /// let server = ProxyServer::new(addr);
165    /// server.run().await?;
166    /// # Ok(())
167    /// # }
168    /// ```
169    pub async fn run(&self) -> Result<(), ProxyError> {
170        let listener = TcpListener::bind(self.bind_addr).await?;
171        tracing::info!("Proxy server listening on {}", self.bind_addr);
172        self.run_with_listener(listener).await
173    }
174
175    /// Run the proxy server using an already-bound `TcpListener`.
176    ///
177    /// This is useful in tests to avoid the race condition of binding a port,
178    /// dropping the listener, and re-binding.
179    pub async fn run_with_listener(&self, listener: TcpListener) -> Result<(), ProxyError> {
180        // Spawn background cleanup task for expired rendezvous codes
181        let cleanup_state = Arc::clone(&self.state);
182        tokio::spawn(async move {
183            loop {
184                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
185
186                let mut rendezvous_map = cleanup_state.rendezvous_map.write().await;
187                let now = SystemTime::now();
188                let mut expired_codes = Vec::new();
189
190                for (code, entry) in rendezvous_map.iter() {
191                    let elapsed = now.duration_since(entry.created_at).unwrap_or_default();
192
193                    if elapsed.as_secs() > 300 {
194                        expired_codes.push(code.clone());
195                    }
196                }
197
198                for code in expired_codes {
199                    rendezvous_map.remove(&code);
200                    tracing::debug!("Cleaned up expired rendezvous code: {}", code);
201                }
202            }
203        });
204
205        loop {
206            let (stream, peer_addr) = listener.accept().await?;
207            let conn_id = self.conn_counter.fetch_add(1, Ordering::SeqCst);
208
209            tracing::info!("New connection #{} from {}", conn_id, peer_addr);
210
211            let state = Arc::clone(&self.state);
212
213            tokio::spawn(async move {
214                match accept_async(stream).await {
215                    Ok(ws_stream) => {
216                        let handler = ConnectionHandler::new(conn_id, state, ws_stream);
217                        if let Err(e) = handler.handle().await {
218                            tracing::error!("Connection #{} error: {}", conn_id, e);
219                        }
220                    }
221                    Err(e) => {
222                        tracing::error!(
223                            "Failed to accept WebSocket connection #{}: {}",
224                            conn_id,
225                            e
226                        );
227                    }
228                }
229            });
230        }
231    }
232}