1#![warn(missing_docs)]
5#![warn(rustdoc::missing_doc_code_examples)]
6
7use std::{
8 collections::{HashMap, HashSet},
9 error,
10 fmt::{self, Debug},
11 future::Future,
12 net::{IpAddr, SocketAddr},
13 sync::{
14 atomic::{AtomicU64, Ordering},
15 Arc,
16 },
17};
18
19use async_trait::async_trait;
20use axum::{
21 body::{Body, Bytes},
22 extract::{DefaultBodyLimit, Query, State, WebSocketUpgrade},
23 http::{
24 header::{AUTHORIZATION, CONTENT_TYPE},
25 response::Builder,
26 HeaderValue, Method, StatusCode,
27 },
28 middleware::Next,
29 response::{IntoResponse as _, Response as HttpResponse},
30 routing::{any, post},
31 Router,
32};
33use axum_extra::{
34 headers::{authorization::Basic, Authorization},
35 TypedHeader,
36};
37use blake2::{digest::consts::U32, Blake2b, Digest};
38use futures::{
39 pin_mut,
40 sink::SinkExt,
41 stream::{FuturesUnordered, StreamExt},
42 Stream,
43};
44use serde::{de::Deserialize, ser::Serialize};
45use serde_json::Value;
46use subtle::ConstantTimeEq;
47use thiserror::Error;
48use tokio::{
49 net::TcpListener,
50 sync::{mpsc, RwLock, RwLockReadGuard, RwLockWriteGuard},
51};
52
53use nimiq_jsonrpc_core::{
54 FrameType, Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId,
55 SubscriptionMessage,
56};
57
58pub use axum::extract::ws::Message;
59pub use tokio::sync::Notify;
60use tower_http::cors::{Any, CorsLayer};
61
/// A JSON-RPC [`Response`] paired with an optional [`Notify`] handle.
///
/// When the handle is present, the server stores it under the id it allocated
/// for the request and signals it once that id is passed to `unsubscribe`.
pub type ResponseAndSubScriptionNotifier = (Response, Option<Arc<Notify>>);
64
/// Errors that can occur while serving JSON-RPC connections.
#[derive(Debug, Error)]
pub enum Error {
    /// An error reported by `axum`.
    #[error("HTTP error: {0}")]
    Axum(#[from] axum::Error),

    /// Sending a [`Message`] into a `tokio` mpsc queue failed.
    #[error("Queue error: {0}")]
    Mpsc(#[from] tokio::sync::mpsc::error::SendError<Message>),

    /// An error reported by `serde_json`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// An error reported by `nimiq_jsonrpc_core`.
    #[error("JSON RPC error: {0}")]
    JsonRpc(#[from] nimiq_jsonrpc_core::Error),
}
84
/// Configuration of the JSON-RPC server.
#[derive(Clone, Debug)]
pub struct Config {
    /// Socket address the server's TCP listener binds to.
    pub bind_to: SocketAddr,

    /// Whether JSON-RPC over WebSocket should be enabled.
    pub enable_websocket: bool,

    /// Optional set of IP addresses.
    ///
    /// NOTE(review): presumably the clients allowed to connect, but nothing in
    /// this file consults this field -- confirm where (or whether) it is
    /// enforced.
    pub ip_whitelist: Option<HashSet<IpAddr>>,

    /// If set, requests must carry HTTP basic auth credentials that verify
    /// against these; otherwise they are answered with `401 Unauthorized`.
    pub basic_auth: Option<Credentials>,

    /// CORS configuration. When `None`, the server falls back to
    /// `Cors::default()`.
    pub cors: Option<Cors>,
}
108
109impl Default for Config {
110 fn default() -> Self {
111 Self {
112 bind_to: ([127, 0, 0, 1], 8000).into(),
113 enable_websocket: true,
114 ip_whitelist: None,
115 basic_auth: None,
116 cors: None,
117 }
118 }
119}
120
121fn blake2b(bytes: &[u8]) -> [u8; 32] {
122 *Blake2b::<U32>::digest(bytes).as_ref()
123}
124
125async fn basic_auth_middleware<D: Dispatcher>(
126 State(state): State<Arc<Inner<D>>>,
127 basic_auth_header: Option<TypedHeader<Authorization<Basic>>>,
128 request: axum::extract::Request,
129 next: Next,
130) -> HttpResponse {
131 let auth_config = if let Some(auth_config) = &state.config.basic_auth {
132 auth_config
133 } else {
134 return next.run(request).await;
136 };
137
138 let auth_header = if let Some(auth_header) = basic_auth_header {
139 auth_header
140 } else {
141 return StatusCode::UNAUTHORIZED.into_response();
143 };
144
145 if auth_config
146 .verify(auth_header.username(), auth_header.password())
147 .is_ok()
148 {
149 next.run(request).await
151 } else {
152 StatusCode::UNAUTHORIZED.into_response()
154 }
155}
156
/// CORS configuration of the server, wrapping a `tower_http` [`CorsLayer`].
#[derive(Clone, Debug)]
pub struct Cors(CorsLayer);
160
161impl Cors {
162 pub fn new() -> Self {
164 Self(
165 CorsLayer::new()
166 .allow_headers([AUTHORIZATION, CONTENT_TYPE])
167 .allow_methods([Method::POST]),
168 )
169 }
170
171 pub fn with_origins(mut self, origins: Vec<String>) -> Self {
174 self.0 = self.0.allow_origin::<Vec<HeaderValue>>(
175 origins
176 .iter()
177 .map(|o| o.parse::<HeaderValue>().unwrap())
178 .collect(),
179 );
180 self
181 }
182
183 pub fn with_any_origin(mut self) -> Self {
186 self.0 = self.0.allow_origin(Any);
187 self
188 }
189
190 pub(crate) fn into_layer(self) -> CorsLayer {
191 self.0
192 }
193}
194
195impl Default for Cors {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
/// A username together with the BLAKE2b digest of a password.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// The expected username.
    username: String,
    /// The 32-byte BLAKE2b digest of the expected password.
    password_blake2b: Sensitive<[u8; 32]>,
}
207
/// Error returned by `Credentials::verify` when the username or the password
/// does not match.
#[derive(Clone, Debug)]
pub struct CredentialsVerificationError(());
211
212impl error::Error for CredentialsVerificationError {}
213impl fmt::Display for CredentialsVerificationError {
214 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
215 fmt::Display::fmt("invalid username or password", f)
216 }
217}
218
219impl Credentials {
220 pub fn new<T: Into<String>, U: AsRef<str>>(username: T, password: U) -> Credentials {
222 Credentials::new_from_blake2b(username, blake2b(password.as_ref().as_bytes()))
223 }
224 pub fn new_from_blake2b<T: Into<String>>(
226 username: T,
227 password_blake2b: [u8; 32],
228 ) -> Credentials {
229 Credentials {
230 username: username.into(),
231 password_blake2b: Sensitive(password_blake2b),
232 }
233 }
234 pub fn verify<T: AsRef<str>, U: AsRef<str>>(
236 &self,
237 username: T,
238 password: U,
239 ) -> Result<(), CredentialsVerificationError> {
240 if (self.username.as_bytes().ct_eq(username.as_ref().as_bytes())
241 & self
242 .password_blake2b
243 .ct_eq(&blake2b(password.as_ref().as_bytes())))
244 .into()
245 {
246 Ok(())
247 } else {
248 Err(CredentialsVerificationError(()))
249 }
250 }
251}
252
/// State shared between the server handle and its request handlers.
struct Inner<D: Dispatcher> {
    /// The server configuration.
    config: Config,
    /// The dispatcher handling requests, behind an async read-write lock.
    dispatcher: RwLock<D>,
    /// Counter from which the id passed to each dispatched request is taken.
    next_id: AtomicU64,
    /// Notifiers of registered subscriptions, keyed by subscription id.
    subscription_notifiers: RwLock<HashMap<SubscriptionId, Arc<Notify>>>,
}
259
/// A JSON-RPC server that hands incoming requests to a [`Dispatcher`].
pub struct Server<D: Dispatcher> {
    /// Shared state, also handed to the request handlers.
    inner: Arc<Inner<D>>,
}
264
265impl<D: Dispatcher> Server<D> {
266 pub fn new(config: Config, dispatcher: D) -> Self {
274 Self {
275 inner: Arc::new(Inner {
276 config,
277 dispatcher: RwLock::new(dispatcher),
278 next_id: AtomicU64::new(1),
279 subscription_notifiers: RwLock::new(HashMap::new()),
280 }),
281 }
282 }
283
284 pub async fn config(&self) -> &Config {
286 &self.inner.config
287 }
288
289 pub async fn dispatcher(&self) -> RwLockReadGuard<'_, D> {
291 self.inner.dispatcher.read().await
292 }
293
294 pub async fn dispatcher_mut(&self) -> RwLockWriteGuard<'_, D> {
296 self.inner.dispatcher.write().await
297 }
298
299 pub async fn run(&self) {
301 let inner = Arc::clone(&self.inner);
302 let http_router = Router::new().route(
303 "/",
304 post(|body: Bytes| async move {
305 let data = Self::handle_raw_request(inner, &Message::binary(body), None, None)
306 .await
307 .unwrap_or(Message::Binary(Bytes::new()));
308
309 Builder::new()
310 .status(StatusCode::OK)
311 .header(CONTENT_TYPE, "application/json")
312 .body(Body::from(data.into_data().to_owned()))
313 .unwrap() }),
315 );
316
317 let inner = Arc::clone(&self.inner);
318 let ws_router = Router::new().route(
319 "/ws",
320 any(
321 |Query(params): Query<HashMap<String, String>>, ws: WebSocketUpgrade| async move {
322 Self::upgrade_to_ws(inner, ws, params)
323 },
324 ),
325 );
326
327 let app = Router::new()
328 .merge(http_router)
329 .merge(ws_router)
330 .route_layer(axum::middleware::from_fn_with_state(
331 Arc::clone(&self.inner),
332 basic_auth_middleware,
333 ))
334 .layer(DefaultBodyLimit::max(1024 * 1024 ))
335 .layer(
336 self.inner
337 .config
338 .cors
339 .clone()
340 .unwrap_or_default()
341 .into_layer(),
342 )
343 .with_state(Arc::clone(&self.inner));
344
345 let listener = TcpListener::bind(self.inner.config.bind_to).await.unwrap();
346 axum::serve(listener, app).await.unwrap();
347 }
348
349 fn upgrade_to_ws(
359 inner: Arc<Inner<D>>,
360 ws: WebSocketUpgrade,
361 query_params: HashMap<String, String>,
362 ) -> HttpResponse<Body> {
363 let frame_type: Option<FrameType> = query_params
364 .get("frame")
365 .map(|frame_type| Some(frame_type.into()))
366 .unwrap_or_default();
367
368 ws.on_upgrade(move |websocket| {
369 let (mut tx, mut rx) = websocket.split();
370
371 let (multiplex_tx, mut multiplex_rx) = mpsc::channel::<Message>(16); let forward_fut = async move {
375 while let Some(data) = multiplex_rx.recv().await {
376 if matches!(data, Message::Close(_)) {
378 tx.close().await?;
379 } else {
380 tx.send(data).await?;
381 }
382 }
383 Ok::<(), Error>(())
384 };
385
386 let handle_fut = {
388 async move {
389 while let Some(message) = rx.next().await.transpose()? {
390 if matches!(message, Message::Ping(_))
391 || matches!(message, Message::Pong(_))
392 {
393 } else if matches!(message, Message::Close(_)) {
395 multiplex_tx.send(Message::Close(None)).await?;
397 break;
399 } else if let Some(response) = Self::handle_raw_request(
400 Arc::clone(&inner),
401 &message,
402 Some(&multiplex_tx),
403 frame_type,
404 )
405 .await
406 {
407 multiplex_tx.send(response).await?;
408 }
409 }
410 Ok::<(), Error>(())
411 }
412 };
413
414 async {
415 if let Err(e) = futures::future::try_join(forward_fut, handle_fut).await {
416 log::error!("Websocket error: {}", e);
417 }
418 }
419 })
420 }
421
422 async fn handle_raw_request(
433 inner: Arc<Inner<D>>,
434 request: &Message,
435 tx: Option<&mpsc::Sender<Message>>,
436 frame_type: Option<FrameType>,
437 ) -> Option<Message> {
438 match serde_json::from_slice(request.clone().into_data().as_ref()) {
439 Ok(request) => Self::handle_request(inner, request, tx, frame_type).await,
440 Err(_e) => {
441 log::error!("Received invalid JSON from client");
442 Some(SingleOrBatch::Single(Response::new_error(
443 Value::Null,
444 RpcError::invalid_request(Some("Received invalid JSON".to_owned())),
445 )))
446 }
447 }
448 .map(|response| {
449 if matches!(&request, Message::Text(_)) {
450 Message::text(
451 serde_json::to_string(&response)
452 .expect("Failed to serialize JSON RPC response"),
453 )
454 } else {
455 Message::binary(
456 serde_json::to_vec(&response).expect("Failed to serialize JSON RPC response"),
457 )
458 }
459 })
460 }
461
462 async fn handle_request(
473 inner: Arc<Inner<D>>,
474 request: SingleOrBatch<Request>,
475 tx: Option<&mpsc::Sender<Message>>,
476 frame_type: Option<FrameType>,
477 ) -> Option<SingleOrBatch<Response>> {
478 match request {
479 SingleOrBatch::Single(request) => {
480 Self::handle_single_request(inner, request, tx, frame_type)
481 .await
482 .map(|(response, _)| SingleOrBatch::Single(response))
483 }
484
485 SingleOrBatch::Batch(requests) => {
486 let futures = requests
487 .into_iter()
488 .map(|request| {
489 Self::handle_single_request(Arc::clone(&inner), request, tx, frame_type)
490 })
491 .collect::<FuturesUnordered<_>>();
492
493 let responses = futures
494 .filter_map(|response_opt| async { response_opt.map(|(response, _)| response) })
495 .collect::<Vec<Response>>()
496 .await;
497
498 Some(SingleOrBatch::Batch(responses))
499 }
500 }
501 }
502
503 async fn handle_single_request(
505 inner: Arc<Inner<D>>,
506 request: Request,
507 tx: Option<&mpsc::Sender<Message>>,
508 frame_type: Option<FrameType>,
509 ) -> Option<ResponseAndSubScriptionNotifier> {
510 if request.method == "unsubscribe" {
511 return Self::handle_unsubscribe_stream(request, inner).await;
512 }
513
514 let mut dispatcher = inner.dispatcher.write().await;
515 let id = inner.next_id.fetch_add(1, Ordering::SeqCst);
517
518 log::debug!("request: {:#?}", request);
519
520 let response = dispatcher.dispatch(request, tx, id, frame_type).await;
521
522 log::debug!("response: {:#?}", response);
523
524 if let Some((_, Some(ref handler))) = response {
525 inner
526 .subscription_notifiers
527 .write()
528 .await
529 .insert(SubscriptionId::Number(id), handler.clone());
530 }
531
532 response
533 }
534
535 async fn handle_unsubscribe_stream(
536 request: Request,
537 inner: Arc<Inner<D>>,
538 ) -> Option<ResponseAndSubScriptionNotifier> {
539 let params = if let Some(params) = request.params {
540 params
541 } else {
542 return error_response(request.id, || {
543 RpcError::invalid_request(Some(
544 "Missing request parameter containing a list of subscription ids".to_owned(),
545 ))
546 });
547 };
548
549 let subscription_ids =
550 if let Ok(ids) = serde_json::from_value::<Vec<SubscriptionId>>(params) {
551 ids
552 } else {
553 return error_response(request.id, || {
554 RpcError::invalid_params(Some(
555 "A list of subscription ids is not provided".to_owned(),
556 ))
557 });
558 };
559
560 if subscription_ids.is_empty() {
561 return error_response(request.id, || {
562 RpcError::invalid_params(Some("Empty list of subscription ids provided".to_owned()))
563 });
564 }
565
566 let mut terminated_streams = vec![];
567 let mut subscription_notifiers = inner.subscription_notifiers.write().await;
568 for id in subscription_ids.iter() {
569 if let Some(notifier) = subscription_notifiers.remove(id) {
570 notifier.notify_one();
571 terminated_streams.push(id);
572 }
573 }
574
575 Some((
576 Response::new_success(
577 serde_json::to_value(request.id.unwrap_or_default()).unwrap(),
578 serde_json::to_value(terminated_streams).unwrap(),
579 ),
580 None,
581 ))
582 }
583}
584
/// A handler for JSON-RPC requests.
#[async_trait]
pub trait Dispatcher: Send + Sync + 'static {
    /// Handles a single request.
    ///
    /// # Arguments
    ///
    /// - `request`: The request to handle.
    /// - `tx`: Sender of the connection's outgoing message queue, if there is
    ///   one (the HTTP handler in this file passes `None`).
    /// - `id`: The number the server allocated for this request; a returned
    ///   notifier is registered by the server under this number.
    /// - `frame_type`: Frame type selected for the connection, if any.
    ///
    /// # Returns
    ///
    /// The response, optionally accompanied by a notifier, or `None` if
    /// nothing is to be sent back.
    async fn dispatch(
        &mut self,
        request: Request,
        tx: Option<&mpsc::Sender<Message>>,
        id: u64,
        frame_type: Option<FrameType>,
    ) -> Option<ResponseAndSubScriptionNotifier>;

    /// Returns whether this dispatcher should be offered the method `_name`.
    ///
    /// The default implementation accepts every method.
    fn match_method(&self, _name: &str) -> bool {
        true
    }

    /// Returns the names of the methods this dispatcher reports.
    fn method_names(&self) -> Vec<&str>;
}
615
/// A dispatcher that hands each request to the first of its registered
/// dispatchers whose `match_method` accepts the request's method.
#[derive(Default)]
pub struct ModularDispatcher {
    /// The registered dispatchers, in registration order.
    dispatchers: Vec<Box<dyn Dispatcher>>,
}
621
622impl ModularDispatcher {
623 pub fn add<D: Dispatcher>(&mut self, dispatcher: D) {
625 self.dispatchers.push(Box::new(dispatcher));
626 }
627}
628
629#[async_trait]
630impl Dispatcher for ModularDispatcher {
631 async fn dispatch(
632 &mut self,
633 request: Request,
634 tx: Option<&mpsc::Sender<Message>>,
635 id: u64,
636 frame_type: Option<FrameType>,
637 ) -> Option<ResponseAndSubScriptionNotifier> {
638 for dispatcher in &mut self.dispatchers {
639 let m = dispatcher.match_method(&request.method);
640 log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
641 log::debug!("Methods: {:?}", dispatcher.method_names());
642 if m {
643 return dispatcher.dispatch(request, tx, id, frame_type).await;
644 }
645 }
646
647 method_not_found(request)
648 }
649
650 fn method_names(&self) -> Vec<&str> {
651 self.dispatchers
652 .iter()
653 .flat_map(|dispatcher| dispatcher.method_names())
654 .collect()
655 }
656}
657
/// A dispatcher wrapper that only lets through methods contained in an
/// allowlist.
pub struct AllowListDispatcher<D>
where
    D: Dispatcher,
{
    /// The wrapped dispatcher.
    pub inner: D,

    /// Names of the allowed methods. With `None`, every method is allowed.
    pub method_allowlist: Option<HashSet<String>>,
}
669
670impl<D> AllowListDispatcher<D>
671where
672 D: Dispatcher,
673{
674 pub fn new(inner: D, method_allowlist: Option<HashSet<String>>) -> Self {
682 Self {
683 inner,
684 method_allowlist,
685 }
686 }
687
688 fn is_allowed(&self, method: &str) -> bool {
689 self.method_allowlist
690 .as_ref()
691 .map(|method_allowlist| method_allowlist.contains(method))
692 .unwrap_or(true)
693 }
694}
695
696#[async_trait]
697impl<D> Dispatcher for AllowListDispatcher<D>
698where
699 D: Dispatcher,
700{
701 async fn dispatch(
702 &mut self,
703 request: Request,
704 tx: Option<&mpsc::Sender<Message>>,
705 id: u64,
706 frame_type: Option<FrameType>,
707 ) -> Option<ResponseAndSubScriptionNotifier> {
708 if self.is_allowed(&request.method) {
709 log::debug!("Dispatching method: {}", request.method);
710 self.inner.dispatch(request, tx, id, frame_type).await
711 } else {
712 log::debug!("Method not allowed: {}", request.method);
713 method_not_found(request)
715 }
716 }
717
718 fn match_method(&self, name: &str) -> bool {
719 if !self.is_allowed(name) {
720 log::debug!("Method not allowed: {}", name);
721 false
722 } else {
723 true
724 }
725 }
726
727 fn method_names(&self) -> Vec<&str> {
728 self.inner
729 .method_names()
730 .into_iter()
731 .filter(|method_name| self.is_allowed(method_name))
732 .collect()
733 }
734}
735
736pub async fn dispatch_method_with_args<P, R, E, F, Fut>(
746 request: Request,
747 f: F,
748) -> Option<ResponseAndSubScriptionNotifier>
749where
750 P: for<'de> Deserialize<'de> + Send,
751 R: Serialize,
752 RpcError: From<E>,
753 F: FnOnce(P) -> Fut + Send,
754 Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
755{
756 let params = match request.params {
757 Some(params) => params,
758 None => Value::Array(Vec::new()),
759 };
760
761 let params = match serde_json::from_value(params) {
762 Ok(params) => params,
763 Err(e) => {
764 log::error!("{}", e);
765 return error_response(request.id, || RpcError::invalid_params(Some(e.to_string())));
766 }
767 };
768
769 let result = f(params).await;
770
771 response(request.id, result)
772}
773
774pub async fn dispatch_method_without_args<R, E, F, Fut>(
779 request: Request,
780 f: F,
781) -> Option<ResponseAndSubScriptionNotifier>
782where
783 R: Serialize,
784 RpcError: From<E>,
785 F: FnOnce() -> Fut + Send,
786 Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
787{
788 let result = f().await;
789
790 match request.params {
791 Some(Value::Null) | None => {}
792 Some(Value::Array(a)) if a.is_empty() => {}
793 Some(Value::Object(o)) if o.is_empty() => {}
794 _ => {
795 return error_response(request.id, || {
796 RpcError::invalid_params(Some("Didn't expect any request parameters".to_owned()))
797 })
798 }
799 }
800
801 response(request.id, result)
802}
803
804fn response<R, E>(
806 id_opt: Option<Value>,
807 result: Result<(R, Option<Arc<Notify>>), E>,
808) -> Option<ResponseAndSubScriptionNotifier>
809where
810 R: Serialize,
811 RpcError: From<E>,
812{
813 let response = match (id_opt, result) {
814 (Some(id), Ok((value, subscription))) => {
815 let retval = serde_json::to_value(value).expect("Failed to serialize return value");
816 Some((Response::new_success(id, retval), subscription))
817 }
818 (Some(id), Err(e)) => Some((Response::new_error(id, RpcError::from(e)), None)),
819 (None, _) => None,
820 };
821
822 log::debug!("Sending response: {:?}", response);
823
824 response
825}
826
827pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<ResponseAndSubScriptionNotifier>
835where
836 E: FnOnce() -> RpcError,
837{
838 if let Some(id) = id_opt {
839 let e = e();
840 log::error!("Error response: {:?}", e);
841 Some((Response::new_error(id, e), None))
842 } else {
843 None
844 }
845}
846
847pub fn method_not_found(request: Request) -> Option<ResponseAndSubScriptionNotifier> {
850 let ::nimiq_jsonrpc_core::Request { id, method, .. } = request;
851
852 error_response(id, || {
853 RpcError::method_not_found(Some(format!("Method does not exist: {}", method)))
854 })
855}
856
857async fn forward_notification<T>(
858 item: T,
859 tx: &mut mpsc::Sender<Message>,
860 id: &SubscriptionId,
861 method: &str,
862 frame_type: Option<FrameType>,
863) -> Result<(), Error>
864where
865 T: Serialize + Debug + Send + Sync,
866{
867 let message = SubscriptionMessage {
868 subscription: id.clone(),
869 result: item,
870 };
871
872 let notification = Request::build::<_, ()>(method.to_owned(), Some(&message), None)?;
873
874 log::debug!("Sending notification: {:?}", notification);
875
876 let message = match frame_type {
877 Some(FrameType::Text) => Message::text(serde_json::to_string(¬ification)?),
878 Some(FrameType::Binary) | None => Message::binary(serde_json::to_vec(¬ification)?),
879 };
880
881 tx.send(message).await?;
882
883 Ok(())
884}
885
886pub fn connect_stream<T, S>(
900 stream: S,
901 tx: &mpsc::Sender<Message>,
902 stream_id: u64,
903 method: String,
904 notify_handler: Arc<Notify>,
905 frame_type: Option<FrameType>,
906) -> SubscriptionId
907where
908 T: Serialize + Debug + Send + Sync,
909 S: Stream<Item = T> + Send + 'static,
910{
911 let mut tx = tx.clone();
912 let id: SubscriptionId = stream_id.into();
913
914 {
915 let id = id.clone();
916 tokio::spawn(async move {
917 pin_mut!(stream);
918
919 let notify_future = notify_handler.notified();
920 pin_mut!(notify_future);
921
922 loop {
923 tokio::select! {
924 item = stream.next() => {
925 match item {
926 Some(notification) => {
927 if let Err(error) = forward_notification(notification, &mut tx, &id, &method, frame_type).await {
928 if let Error::Mpsc(_) = error {
930 break;
931 }
932
933 log::error!("{}", error);
934 }
935 },
936 None => break,
937 }
938 }
939 _ = &mut notify_future => {
940 break;
942 }
943 }
944 }
945 });
946 }
947
948 id
949}