polysig_relay_server/
server.rs

1use futures::StreamExt;
2use std::{
3    collections::HashMap, net::SocketAddr, sync::Arc, time::Duration,
4};
5use tokio::sync::RwLock;
6use tokio_stream::wrappers::IntervalStream;
7
8use axum::{
9    extract::Extension,
10    http::{HeaderValue, Method, StatusCode},
11    response::{IntoResponse, Response},
12    routing::get,
13    Router,
14};
15use axum_server::{tls_rustls::RustlsConfig, Handle};
16use tower_http::{cors::CorsLayer, trace::TraceLayer};
17use uuid::Uuid;
18
19use polysig_protocol::{hex, uuid, Keypair, SessionManager};
20
21use crate::{
22    config::{ServerConfig, TlsConfig},
23    Result,
24};
25
26use crate::{service::RelayService, websocket::Connection};
27
28pub type State = Arc<RwLock<ServerState>>;
29pub(crate) type Service = Arc<RelayService>;
30
31async fn purge_expired(state: State, interval_secs: u64) {
32    let interval =
33        tokio::time::interval(Duration::from_secs(interval_secs));
34    let mut stream = IntervalStream::new(interval);
35    while stream.next().await.is_some() {
36        let mut writer = state.write().await;
37
38        let expired_sessions = writer
39            .sessions
40            .expired_keys(writer.config.session.timeout);
41        tracing::debug!(
42            expired_sessions = %expired_sessions.len());
43        for key in expired_sessions {
44            writer.sessions.remove_session(&key);
45        }
46    }
47}
48
49pub struct ServerState {
50    /// Server keypair.
51    pub(crate) keypair: Keypair,
52
53    /// Server config.
54    pub(crate) config: ServerConfig,
55
56    /// Pending socket connections in the handshake state.
57    pub(crate) pending: HashMap<Uuid, Connection>,
58
59    /// Active socket connections in the transport state.
60    ///
61    /// Now the hashmap key is the client's public key.
62    pub(crate) active: HashMap<Vec<u8>, Connection>,
63
64    /// Session manager.
65    pub(crate) sessions: SessionManager,
66}
67
68/// Relay web server.
69pub struct RelayServer {
70    state: State,
71}
72
73impl RelayServer {
74    /// Create a new relay server.
75    pub fn new(config: ServerConfig, keypair: Keypair) -> Self {
76        Self {
77            state: Arc::new(RwLock::new(ServerState {
78                keypair,
79                config,
80                pending: Default::default(),
81                active: Default::default(),
82                sessions: Default::default(),
83            })),
84        }
85    }
86
87    /// Start the server.
88    pub async fn start(
89        &self,
90        addr: SocketAddr,
91        handle: Handle,
92    ) -> Result<()> {
93        let reader = self.state.read().await;
94        let interval = reader.config.session.interval;
95        let tls = reader.config.tls.as_ref().cloned();
96        drop(reader);
97
98        // Spawn task to reap expired sessions
99        tokio::task::spawn(purge_expired(
100            Arc::clone(&self.state),
101            interval,
102        ));
103
104        if let Some(tls) = tls {
105            self.run_tls(addr, handle, tls).await
106        } else {
107            self.run(addr, handle).await
108        }
109    }
110
111    /// Start the server running on HTTPS.
112    async fn run_tls(
113        &self,
114        addr: SocketAddr,
115        handle: Handle,
116        tls: TlsConfig,
117    ) -> Result<()> {
118        let tls =
119            RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
120        let app = self.router(Arc::clone(&self.state)).await?;
121        let public_key = {
122            let reader = self.state.read().await;
123            reader.keypair.public_key().to_vec()
124        };
125        tracing::info!("listening on {}", addr);
126        tracing::info!("public key {}", hex::encode(&public_key));
127        axum_server::bind_rustls(addr, tls)
128            .handle(handle)
129            .serve(app.into_make_service())
130            .await?;
131        Ok(())
132    }
133
134    /// Start the server running on HTTP.
135    async fn run(
136        &self,
137        addr: SocketAddr,
138        handle: Handle,
139    ) -> Result<()> {
140        let app = self.router(Arc::clone(&self.state)).await?;
141        let public_key = {
142            let reader = self.state.read().await;
143            reader.keypair.public_key().to_vec()
144        };
145        tracing::info!("listening on {}", addr);
146        tracing::info!("public key {}", hex::encode(&public_key));
147        axum_server::bind(addr)
148            .handle(handle)
149            .serve(app.into_make_service())
150            .await?;
151        Ok(())
152    }
153
154    async fn router(&self, state: State) -> Result<Router> {
155        let origins = {
156            let reader = state.read().await;
157            let mut origins = Vec::new();
158            for url in reader.config.cors.origins.iter() {
159                tracing::info!(url = %url, "cors");
160                origins.push(HeaderValue::from_str(
161                    url.as_str().trim_end_matches('/'),
162                )?);
163            }
164            origins
165        };
166
167        let cors = CorsLayer::new()
168            .allow_methods(vec![Method::GET])
169            //.allow_headers(vec![])
170            //.expose_headers(vec![])
171            .allow_origin(origins);
172
173        let service = Arc::new(RelayService::new(Arc::clone(&state)));
174        let mut app = Router::new()
175            .route("/", get(crate::websocket::upgrade))
176            .route("/public-key", get(public_key));
177        app = app
178            .layer(cors)
179            .layer(TraceLayer::new_for_http())
180            .layer(Extension(service))
181            .layer(Extension(state));
182        Ok(app)
183    }
184}
185
186async fn public_key(
187    Extension(state): Extension<State>,
188) -> std::result::Result<Response, StatusCode> {
189    let reader = state.read().await;
190    let public_key = hex::encode(reader.keypair.public_key());
191    Ok((StatusCode::OK, public_key).into_response())
192}