mpc_relay_server/
server.rs

1use futures::StreamExt;
2use std::{
3    collections::HashMap, net::SocketAddr, sync::Arc, time::Duration,
4};
5use tokio::sync::RwLock;
6use tokio_stream::wrappers::IntervalStream;
7
8use axum::{
9    extract::Extension,
10    http::{HeaderValue, Method, StatusCode},
11    response::{IntoResponse, Response},
12    routing::get,
13    Router,
14};
15use axum_server::{tls_rustls::RustlsConfig, Handle};
16use tower_http::{cors::CorsLayer, trace::TraceLayer};
17use uuid::Uuid;
18
19use mpc_protocol::{
20    hex, uuid, Keypair, MeetingManager, SessionManager,
21};
22
23use crate::{
24    config::{ServerConfig, TlsConfig},
25    Result,
26};
27
28use crate::{service::RelayService, websocket::Connection};
29
30pub type State = Arc<RwLock<ServerState>>;
31pub(crate) type Service = Arc<RelayService>;
32
33async fn purge_expired(state: State, interval_secs: u64) {
34    let interval =
35        tokio::time::interval(Duration::from_secs(interval_secs));
36    let mut stream = IntervalStream::new(interval);
37    while stream.next().await.is_some() {
38        let mut writer = state.write().await;
39        let expired_meetings = writer
40            .meetings
41            .expired_keys(writer.config.session.timeout);
42        tracing::debug!(
43            expired_meetings = %expired_meetings.len());
44        for key in expired_meetings {
45            writer.meetings.remove_meeting(&key);
46        }
47
48        let expired_sessions = writer
49            .sessions
50            .expired_keys(writer.config.session.timeout);
51        tracing::debug!(
52            expired_sessions = %expired_sessions.len());
53        for key in expired_sessions {
54            writer.sessions.remove_session(&key);
55        }
56    }
57}
58
59pub struct ServerState {
60    /// Server keypair.
61    pub(crate) keypair: Keypair,
62
63    /// Server config.
64    pub(crate) config: ServerConfig,
65
66    /// Pending socket connections in the handshake state.
67    pub(crate) pending: HashMap<Uuid, Connection>,
68
69    /// Active socket connections in the transport state.
70    ///
71    /// Now the hashmap key is the client's public key.
72    pub(crate) active: HashMap<Vec<u8>, Connection>,
73
74    /// Meeting point manager.
75    pub(crate) meetings: MeetingManager,
76
77    /// Session manager.
78    pub(crate) sessions: SessionManager,
79}
80
81/// Relay web server.
82pub struct RelayServer {
83    state: State,
84}
85
86impl RelayServer {
87    /// Create a new relay server.
88    pub fn new(config: ServerConfig, keypair: Keypair) -> Self {
89        Self {
90            state: Arc::new(RwLock::new(ServerState {
91                keypair,
92                config,
93                pending: Default::default(),
94                active: Default::default(),
95                meetings: Default::default(),
96                sessions: Default::default(),
97            })),
98        }
99    }
100
101    /// Start the server.
102    pub async fn start(
103        &self,
104        addr: SocketAddr,
105        handle: Handle,
106    ) -> Result<()> {
107        let reader = self.state.read().await;
108        let interval = reader.config.session.interval;
109        let tls = reader.config.tls.as_ref().cloned();
110        drop(reader);
111
112        // Spawn task to reap expired sessions
113        tokio::task::spawn(purge_expired(
114            Arc::clone(&self.state),
115            interval,
116        ));
117
118        if let Some(tls) = tls {
119            self.run_tls(addr, handle, tls).await
120        } else {
121            self.run(addr, handle).await
122        }
123    }
124
125    /// Start the server running on HTTPS.
126    async fn run_tls(
127        &self,
128        addr: SocketAddr,
129        handle: Handle,
130        tls: TlsConfig,
131    ) -> Result<()> {
132        let tls =
133            RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
134        let app = self.router(Arc::clone(&self.state)).await?;
135        let public_key = {
136            let reader = self.state.read().await;
137            reader.keypair.public_key().to_vec()
138        };
139        tracing::info!("listening on {}", addr);
140        tracing::info!("public key {}", hex::encode(&public_key));
141        axum_server::bind_rustls(addr, tls)
142            .handle(handle)
143            .serve(app.into_make_service())
144            .await?;
145        Ok(())
146    }
147
148    /// Start the server running on HTTP.
149    async fn run(
150        &self,
151        addr: SocketAddr,
152        handle: Handle,
153    ) -> Result<()> {
154        let app = self.router(Arc::clone(&self.state)).await?;
155        let public_key = {
156            let reader = self.state.read().await;
157            reader.keypair.public_key().to_vec()
158        };
159        tracing::info!("listening on {}", addr);
160        tracing::info!("public key {}", hex::encode(&public_key));
161        axum_server::bind(addr)
162            .handle(handle)
163            .serve(app.into_make_service())
164            .await?;
165        Ok(())
166    }
167
168    async fn router(&self, state: State) -> Result<Router> {
169        let origins = {
170            let reader = state.read().await;
171            let mut origins = Vec::new();
172            for url in reader.config.cors.origins.iter() {
173                tracing::info!(url = %url, "cors");
174                origins.push(HeaderValue::from_str(
175                    url.as_str().trim_end_matches('/'),
176                )?);
177            }
178            origins
179        };
180
181        let cors = CorsLayer::new()
182            .allow_methods(vec![Method::GET])
183            //.allow_headers(vec![])
184            //.expose_headers(vec![])
185            .allow_origin(origins);
186
187        let service = Arc::new(RelayService::new(Arc::clone(&state)));
188        let mut app = Router::new()
189            .route("/", get(crate::websocket::upgrade))
190            .route("/public-key", get(public_key));
191        app = app
192            .layer(cors)
193            .layer(TraceLayer::new_for_http())
194            .layer(Extension(service))
195            .layer(Extension(state));
196        Ok(app)
197    }
198}
199
200async fn public_key(
201    Extension(state): Extension<State>,
202) -> std::result::Result<Response, StatusCode> {
203    let reader = state.read().await;
204    let public_key = hex::encode(reader.keypair.public_key());
205    Ok((StatusCode::OK, public_key).into_response())
206}