polysig_relay_server/
server.rs1use futures::StreamExt;
2use std::{
3 collections::HashMap, net::SocketAddr, sync::Arc, time::Duration,
4};
5use tokio::sync::RwLock;
6use tokio_stream::wrappers::IntervalStream;
7
8use axum::{
9 extract::Extension,
10 http::{HeaderValue, Method, StatusCode},
11 response::{IntoResponse, Response},
12 routing::get,
13 Router,
14};
15use axum_server::{tls_rustls::RustlsConfig, Handle};
16use tower_http::{cors::CorsLayer, trace::TraceLayer};
17use uuid::Uuid;
18
19use polysig_protocol::{hex, uuid, Keypair, SessionManager};
20
21use crate::{
22 config::{ServerConfig, TlsConfig},
23 Result,
24};
25
26use crate::{service::RelayService, websocket::Connection};
27
28pub type State = Arc<RwLock<ServerState>>;
29pub(crate) type Service = Arc<RelayService>;
30
31async fn purge_expired(state: State, interval_secs: u64) {
32 let interval =
33 tokio::time::interval(Duration::from_secs(interval_secs));
34 let mut stream = IntervalStream::new(interval);
35 while stream.next().await.is_some() {
36 let mut writer = state.write().await;
37
38 let expired_sessions = writer
39 .sessions
40 .expired_keys(writer.config.session.timeout);
41 tracing::debug!(
42 expired_sessions = %expired_sessions.len());
43 for key in expired_sessions {
44 writer.sessions.remove_session(&key);
45 }
46 }
47}
48
49pub struct ServerState {
50 pub(crate) keypair: Keypair,
52
53 pub(crate) config: ServerConfig,
55
56 pub(crate) pending: HashMap<Uuid, Connection>,
58
59 pub(crate) active: HashMap<Vec<u8>, Connection>,
63
64 pub(crate) sessions: SessionManager,
66}
67
68pub struct RelayServer {
70 state: State,
71}
72
73impl RelayServer {
74 pub fn new(config: ServerConfig, keypair: Keypair) -> Self {
76 Self {
77 state: Arc::new(RwLock::new(ServerState {
78 keypair,
79 config,
80 pending: Default::default(),
81 active: Default::default(),
82 sessions: Default::default(),
83 })),
84 }
85 }
86
87 pub async fn start(
89 &self,
90 addr: SocketAddr,
91 handle: Handle,
92 ) -> Result<()> {
93 let reader = self.state.read().await;
94 let interval = reader.config.session.interval;
95 let tls = reader.config.tls.as_ref().cloned();
96 drop(reader);
97
98 tokio::task::spawn(purge_expired(
100 Arc::clone(&self.state),
101 interval,
102 ));
103
104 if let Some(tls) = tls {
105 self.run_tls(addr, handle, tls).await
106 } else {
107 self.run(addr, handle).await
108 }
109 }
110
111 async fn run_tls(
113 &self,
114 addr: SocketAddr,
115 handle: Handle,
116 tls: TlsConfig,
117 ) -> Result<()> {
118 let tls =
119 RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
120 let app = self.router(Arc::clone(&self.state)).await?;
121 let public_key = {
122 let reader = self.state.read().await;
123 reader.keypair.public_key().to_vec()
124 };
125 tracing::info!("listening on {}", addr);
126 tracing::info!("public key {}", hex::encode(&public_key));
127 axum_server::bind_rustls(addr, tls)
128 .handle(handle)
129 .serve(app.into_make_service())
130 .await?;
131 Ok(())
132 }
133
134 async fn run(
136 &self,
137 addr: SocketAddr,
138 handle: Handle,
139 ) -> Result<()> {
140 let app = self.router(Arc::clone(&self.state)).await?;
141 let public_key = {
142 let reader = self.state.read().await;
143 reader.keypair.public_key().to_vec()
144 };
145 tracing::info!("listening on {}", addr);
146 tracing::info!("public key {}", hex::encode(&public_key));
147 axum_server::bind(addr)
148 .handle(handle)
149 .serve(app.into_make_service())
150 .await?;
151 Ok(())
152 }
153
154 async fn router(&self, state: State) -> Result<Router> {
155 let origins = {
156 let reader = state.read().await;
157 let mut origins = Vec::new();
158 for url in reader.config.cors.origins.iter() {
159 tracing::info!(url = %url, "cors");
160 origins.push(HeaderValue::from_str(
161 url.as_str().trim_end_matches('/'),
162 )?);
163 }
164 origins
165 };
166
167 let cors = CorsLayer::new()
168 .allow_methods(vec![Method::GET])
169 .allow_origin(origins);
172
173 let service = Arc::new(RelayService::new(Arc::clone(&state)));
174 let mut app = Router::new()
175 .route("/", get(crate::websocket::upgrade))
176 .route("/public-key", get(public_key));
177 app = app
178 .layer(cors)
179 .layer(TraceLayer::new_for_http())
180 .layer(Extension(service))
181 .layer(Extension(state));
182 Ok(app)
183 }
184}
185
186async fn public_key(
187 Extension(state): Extension<State>,
188) -> std::result::Result<Response, StatusCode> {
189 let reader = state.read().await;
190 let public_key = hex::encode(reader.keypair.public_key());
191 Ok((StatusCode::OK, public_key).into_response())
192}