mpc_relay_server/
server.rs1use futures::StreamExt;
2use std::{
3 collections::HashMap, net::SocketAddr, sync::Arc, time::Duration,
4};
5use tokio::sync::RwLock;
6use tokio_stream::wrappers::IntervalStream;
7
8use axum::{
9 extract::Extension,
10 http::{HeaderValue, Method, StatusCode},
11 response::{IntoResponse, Response},
12 routing::get,
13 Router,
14};
15use axum_server::{tls_rustls::RustlsConfig, Handle};
16use tower_http::{cors::CorsLayer, trace::TraceLayer};
17use uuid::Uuid;
18
19use mpc_protocol::{
20 hex, uuid, Keypair, MeetingManager, SessionManager,
21};
22
23use crate::{
24 config::{ServerConfig, TlsConfig},
25 Result,
26};
27
28use crate::{service::RelayService, websocket::Connection};
29
30pub type State = Arc<RwLock<ServerState>>;
31pub(crate) type Service = Arc<RelayService>;
32
33async fn purge_expired(state: State, interval_secs: u64) {
34 let interval =
35 tokio::time::interval(Duration::from_secs(interval_secs));
36 let mut stream = IntervalStream::new(interval);
37 while stream.next().await.is_some() {
38 let mut writer = state.write().await;
39 let expired_meetings = writer
40 .meetings
41 .expired_keys(writer.config.session.timeout);
42 tracing::debug!(
43 expired_meetings = %expired_meetings.len());
44 for key in expired_meetings {
45 writer.meetings.remove_meeting(&key);
46 }
47
48 let expired_sessions = writer
49 .sessions
50 .expired_keys(writer.config.session.timeout);
51 tracing::debug!(
52 expired_sessions = %expired_sessions.len());
53 for key in expired_sessions {
54 writer.sessions.remove_session(&key);
55 }
56 }
57}
58
59pub struct ServerState {
60 pub(crate) keypair: Keypair,
62
63 pub(crate) config: ServerConfig,
65
66 pub(crate) pending: HashMap<Uuid, Connection>,
68
69 pub(crate) active: HashMap<Vec<u8>, Connection>,
73
74 pub(crate) meetings: MeetingManager,
76
77 pub(crate) sessions: SessionManager,
79}
80
81pub struct RelayServer {
83 state: State,
84}
85
86impl RelayServer {
87 pub fn new(config: ServerConfig, keypair: Keypair) -> Self {
89 Self {
90 state: Arc::new(RwLock::new(ServerState {
91 keypair,
92 config,
93 pending: Default::default(),
94 active: Default::default(),
95 meetings: Default::default(),
96 sessions: Default::default(),
97 })),
98 }
99 }
100
101 pub async fn start(
103 &self,
104 addr: SocketAddr,
105 handle: Handle,
106 ) -> Result<()> {
107 let reader = self.state.read().await;
108 let interval = reader.config.session.interval;
109 let tls = reader.config.tls.as_ref().cloned();
110 drop(reader);
111
112 tokio::task::spawn(purge_expired(
114 Arc::clone(&self.state),
115 interval,
116 ));
117
118 if let Some(tls) = tls {
119 self.run_tls(addr, handle, tls).await
120 } else {
121 self.run(addr, handle).await
122 }
123 }
124
125 async fn run_tls(
127 &self,
128 addr: SocketAddr,
129 handle: Handle,
130 tls: TlsConfig,
131 ) -> Result<()> {
132 let tls =
133 RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
134 let app = self.router(Arc::clone(&self.state)).await?;
135 let public_key = {
136 let reader = self.state.read().await;
137 reader.keypair.public_key().to_vec()
138 };
139 tracing::info!("listening on {}", addr);
140 tracing::info!("public key {}", hex::encode(&public_key));
141 axum_server::bind_rustls(addr, tls)
142 .handle(handle)
143 .serve(app.into_make_service())
144 .await?;
145 Ok(())
146 }
147
148 async fn run(
150 &self,
151 addr: SocketAddr,
152 handle: Handle,
153 ) -> Result<()> {
154 let app = self.router(Arc::clone(&self.state)).await?;
155 let public_key = {
156 let reader = self.state.read().await;
157 reader.keypair.public_key().to_vec()
158 };
159 tracing::info!("listening on {}", addr);
160 tracing::info!("public key {}", hex::encode(&public_key));
161 axum_server::bind(addr)
162 .handle(handle)
163 .serve(app.into_make_service())
164 .await?;
165 Ok(())
166 }
167
168 async fn router(&self, state: State) -> Result<Router> {
169 let origins = {
170 let reader = state.read().await;
171 let mut origins = Vec::new();
172 for url in reader.config.cors.origins.iter() {
173 tracing::info!(url = %url, "cors");
174 origins.push(HeaderValue::from_str(
175 url.as_str().trim_end_matches('/'),
176 )?);
177 }
178 origins
179 };
180
181 let cors = CorsLayer::new()
182 .allow_methods(vec![Method::GET])
183 .allow_origin(origins);
186
187 let service = Arc::new(RelayService::new(Arc::clone(&state)));
188 let mut app = Router::new()
189 .route("/", get(crate::websocket::upgrade))
190 .route("/public-key", get(public_key));
191 app = app
192 .layer(cors)
193 .layer(TraceLayer::new_for_http())
194 .layer(Extension(service))
195 .layer(Extension(state));
196 Ok(app)
197 }
198}
199
200async fn public_key(
201 Extension(state): Extension<State>,
202) -> std::result::Result<Response, StatusCode> {
203 let reader = state.read().await;
204 let public_key = hex::encode(reader.keypair.public_key());
205 Ok((StatusCode::OK, public_key).into_response())
206}