1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::webrtc::types::{
21 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
22};
23use hex::encode as hex_encode;
24
/// JSON frames of the hashtree WebSocket protocol, discriminated by their
/// `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum WsClientMessage {
    /// `{"type":"req","id":..,"hash":..}`: request a blob by hex hash.
    #[serde(rename = "req")]
    Request { id: u32, hash: String },
    /// `{"type":"res","id":..,"hash":..,"found":..}`: a peer's answer to a
    /// request this server forwarded to it (see `handle_response`).
    #[serde(rename = "res")]
    Response { id: u32, hash: String, found: bool },
}
33
/// A text frame as classified by `parse_ws_text_message`: either a hashtree
/// JSON frame or a Nostr client message (`["REQ", ...]`, `["EVENT", ...]`, ...).
#[derive(Debug)]
enum WsTextMessage {
    Hashtree(WsClientMessage),
    Nostr(NostrClientMessage),
}
39
/// Request frame forwarded to JSON peers when a blob is not available locally
/// (built in `handle_request` with `kind = "req"`). `kind` is serialized as
/// the `type` field.
#[derive(Debug, Deserialize, Serialize)]
struct WsRequest {
    #[serde(rename = "type")]
    kind: String,
    id: u32,
    hash: String,
}
47
/// `res` frame sent to JSON clients; `kind` is serialized as the `type` field.
///
/// The blob itself never travels in this frame. It is sent separately as a
/// binary frame made of the little-endian `u32` request id followed by the
/// data (see `send_binary`).
#[derive(Debug, Serialize)]
struct WsResponse {
    #[serde(rename = "type")]
    kind: &'static str,
    id: u32,
    hash: String,
    found: bool,
}
56
57pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
58 ws_data_with_client_pubkey(state, ws, None)
59}
60
61pub fn ws_data_with_client_pubkey(
62 state: AppState,
63 ws: WebSocketUpgrade,
64 client_pubkey: Option<String>,
65) -> impl IntoResponse {
66 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
67}
68
69async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
70 let client_id = state
73 .nostr_relay
74 .as_ref()
75 .map(|relay| relay.next_client_id())
76 .unwrap_or_else(|| state.ws_relay.next_id());
77 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
78
79 {
80 let mut clients = state.ws_relay.clients.lock().await;
81 clients.insert(client_id, tx);
82 }
83 {
84 let mut protocols = state.ws_relay.client_protocols.lock().await;
85 protocols.insert(client_id, WsProtocol::HashtreeJson);
86 }
87
88 let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
89 let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
90 relay
91 .register_client(client_id, nostr_tx, client_pubkey.clone())
92 .await;
93 Some(nostr_rx)
94 } else {
95 None
96 };
97
98 let (mut sender, mut receiver) = socket.split();
99 loop {
100 tokio::select! {
101 maybe_msg = rx.recv() => {
102 let Some(msg) = maybe_msg else {
103 break;
104 };
105 if sender.send(msg).await.is_err() {
106 break;
107 }
108 }
109 maybe_text = async {
110 match &mut nostr_rx {
111 Some(rx) => rx.recv().await,
112 None => std::future::pending().await,
113 }
114 } => {
115 let Some(text) = maybe_text else {
116 nostr_rx = None;
117 continue;
118 };
119 if sender.send(Message::Text(text)).await.is_err() {
120 break;
121 }
122 }
123 maybe_incoming = receiver.next() => {
124 match maybe_incoming {
125 Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
126 Some(Err(_)) | None => break,
127 }
128 }
129 }
130 }
131
132 close_all_upstream_nostr_subscriptions(&state, client_id).await;
133
134 {
135 let mut clients = state.ws_relay.clients.lock().await;
136 clients.remove(&client_id);
137 }
138 {
139 let mut protocols = state.ws_relay.client_protocols.lock().await;
140 protocols.remove(&client_id);
141 }
142 {
143 let mut pending = state.ws_relay.pending.lock().await;
144 pending.retain(|(peer_id, _), _| *peer_id != client_id);
145 }
146
147 if let Some(relay) = &state.nostr_relay {
148 relay.unregister_client(client_id).await;
149 }
150}
151
152fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
153 let trimmed = text.trim_start();
154 if trimmed.starts_with('[') {
155 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
156 return Some(WsTextMessage::Nostr(msg));
157 }
158 }
159
160 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
161 return Some(WsTextMessage::Hashtree(msg));
162 }
163
164 None
165}
166async fn close_upstream_nostr_subscription(
167 state: &AppState,
168 client_id: u64,
169 subscription_id: &SubscriptionId,
170) {
171 let key = (client_id, subscription_id.to_string());
172 let subscription = {
173 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
174 subscriptions.remove(&key)
175 };
176 if let Some(subscription) = subscription {
177 let _ = subscription.close_tx.send(true);
178 for task in subscription.tasks {
179 task.abort();
180 }
181 }
182 state
183 .ws_relay
184 .upstream_pending_eose
185 .lock()
186 .await
187 .remove(&key);
188 state
189 .ws_relay
190 .upstream_seen_events
191 .lock()
192 .await
193 .remove(&key);
194}
195
196async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
197 let keys = {
198 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
199 subscriptions
200 .keys()
201 .filter(|(id, _)| *id == client_id)
202 .cloned()
203 .collect::<Vec<_>>()
204 };
205 for (_, sub_id) in keys {
206 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
207 }
208}
209
210async fn forward_upstream_nostr_message(
211 state: &AppState,
212 client_id: u64,
213 subscription_id: &SubscriptionId,
214 text: &str,
215) {
216 let Ok(message) = NostrRelayMessage::from_json(text) else {
217 return;
218 };
219
220 match message {
221 NostrRelayMessage::Event {
222 subscription_id: sid,
223 event,
224 } if sid == *subscription_id => {
225 let event = *event;
226 let key = (client_id, subscription_id.to_string());
227 let event_id = event.id.to_hex();
228 let inserted = {
229 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
230 seen_events.entry(key).or_default().insert(event_id)
231 };
232 if !inserted {
233 return;
234 }
235 if let Some(relay) = &state.nostr_relay {
236 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
237 }
238 send_nostr(
239 state,
240 client_id,
241 NostrRelayMessage::event(subscription_id.clone(), event),
242 )
243 .await;
244 }
245 NostrRelayMessage::Closed {
246 subscription_id: sid,
247 message,
248 } if sid == *subscription_id => {
249 send_nostr(
250 state,
251 client_id,
252 NostrRelayMessage::closed(subscription_id.clone(), message),
253 )
254 .await;
255 }
256 _ => {}
257 }
258}
259
260async fn mark_upstream_nostr_relay_complete(
261 state: &AppState,
262 client_id: u64,
263 subscription_id: &SubscriptionId,
264) {
265 let key = (client_id, subscription_id.to_string());
266 let should_send_eose = {
267 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
268 let Some(remaining) = pending.get_mut(&key) else {
269 return;
270 };
271 if *remaining > 0 {
272 *remaining -= 1;
273 }
274 if *remaining == 0 {
275 pending.remove(&key);
276 true
277 } else {
278 false
279 }
280 };
281
282 if should_send_eose {
283 send_nostr(
284 state,
285 client_id,
286 NostrRelayMessage::eose(subscription_id.clone()),
287 )
288 .await;
289 }
290}
291
292async fn run_upstream_nostr_subscription(
293 state: AppState,
294 client_id: u64,
295 relay_url: String,
296 subscription_id: SubscriptionId,
297 filters: Vec<NostrFilter>,
298 mut close_rx: watch::Receiver<bool>,
299) {
300 let mut relay_complete = false;
301 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
302 tracing::warn!(
303 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
304 client_id,
305 subscription_id,
306 relay_url,
307 );
308 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
309 return;
310 };
311 let (mut write, mut read) = socket.split();
312 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
313 state
314 .ws_relay
315 .note_upstream_relay_send(request.as_bytes().len());
316 if write
317 .send(TungsteniteMessage::Text(request.into()))
318 .await
319 .is_err()
320 {
321 tracing::warn!(
322 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
323 client_id,
324 subscription_id,
325 relay_url,
326 );
327 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
328 return;
329 }
330
331 loop {
332 tokio::select! {
333 _ = close_rx.changed() => {
334 if *close_rx.borrow() {
335 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
336 state
337 .ws_relay
338 .note_upstream_relay_send(close.as_bytes().len());
339 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
340 let _ = write.close().await;
341 break;
342 }
343 }
344 message = read.next() => {
345 match message {
346 Some(Ok(TungsteniteMessage::Text(text))) => {
347 state
348 .ws_relay
349 .note_upstream_relay_receive(text.as_bytes().len());
350 if matches!(
351 NostrRelayMessage::from_json(text.as_str()),
352 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
353 ) {
354 if !relay_complete {
355 relay_complete = true;
356 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
357 }
358 continue;
359 }
360 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
361 }
362 Some(Ok(TungsteniteMessage::Ping(payload))) => {
363 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
364 }
365 Some(Ok(TungsteniteMessage::Close(_))) | None => {
366 if !relay_complete {
367 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
368 }
369 break;
370 }
371 Some(Err(_)) => {
372 if !relay_complete {
373 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
374 }
375 break;
376 }
377 _ => {}
378 }
379 }
380 }
381 }
382}
383
384async fn start_upstream_nostr_subscription(
385 state: &AppState,
386 client_id: u64,
387 subscription_id: SubscriptionId,
388 filters: Vec<NostrFilter>,
389) -> usize {
390 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
391 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
392 tracing::info!(
393 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
394 client_id,
395 subscription_id,
396 state.nostr_relay_urls.len(),
397 filters.len(),
398 );
399 return 0;
400 }
401
402 let mut relay_urls = Vec::new();
403 let mut seen = HashSet::new();
404 for relay in &state.nostr_relay_urls {
405 let relay = relay.trim();
406 if relay.is_empty() || !seen.insert(relay.to_string()) {
407 continue;
408 }
409 relay_urls.push(relay.to_string());
410 }
411 if relay_urls.is_empty() {
412 tracing::info!(
413 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
414 client_id,
415 subscription_id,
416 );
417 return 0;
418 }
419
420 tracing::info!(
421 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
422 client_id,
423 subscription_id,
424 relay_urls.len(),
425 );
426
427 let key = (client_id, subscription_id.to_string());
428 state
429 .ws_relay
430 .upstream_seen_events
431 .lock()
432 .await
433 .insert(key.clone(), HashSet::new());
434 state
435 .ws_relay
436 .upstream_pending_eose
437 .lock()
438 .await
439 .insert(key.clone(), relay_urls.len());
440
441 let (close_tx, close_rx) = watch::channel(false);
442 let mut tasks = Vec::new();
443 for relay_url in &relay_urls {
444 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
445 state.clone(),
446 client_id,
447 relay_url.clone(),
448 subscription_id.clone(),
449 filters.clone(),
450 close_rx.clone(),
451 )));
452 }
453
454 state
455 .ws_relay
456 .upstream_nostr_subscriptions
457 .lock()
458 .await
459 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
460 relay_urls.len()
461}
462
463async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
464 match msg {
465 Message::Text(text) => {
466 if let Some(msg) = parse_ws_text_message(&text) {
467 match msg {
468 WsTextMessage::Hashtree(msg) => {
469 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
470 match msg {
471 WsClientMessage::Request { id, hash } => {
472 handle_request(
473 client_id,
474 id,
475 hash,
476 WsProtocol::HashtreeJson,
477 state,
478 )
479 .await;
480 }
481 WsClientMessage::Response { id, hash, found } => {
482 handle_response(client_id, id, hash, found, state).await;
483 }
484 }
485 }
486 WsTextMessage::Nostr(msg) => {
487 if let Some(relay) = &state.nostr_relay {
488 match msg {
489 NostrClientMessage::Req {
490 subscription_id,
491 filters,
492 } => {
493 let local_events = match relay
494 .register_subscription_query(
495 client_id,
496 subscription_id.clone(),
497 filters.clone(),
498 )
499 .await
500 {
501 Ok(events) => events,
502 Err(message) => {
503 send_nostr(
504 state,
505 client_id,
506 NostrRelayMessage::closed(subscription_id, message),
507 )
508 .await;
509 return;
510 }
511 };
512
513 let upstream_relays = start_upstream_nostr_subscription(
514 state,
515 client_id,
516 subscription_id.clone(),
517 filters,
518 )
519 .await;
520 if upstream_relays > 0 {
521 let key = (client_id, subscription_id.to_string());
522 let mut seen_events =
523 state.ws_relay.upstream_seen_events.lock().await;
524 seen_events.entry(key).or_default().extend(
525 local_events.iter().map(|event| event.id.to_hex()),
526 );
527 }
528 for event in local_events {
529 send_nostr(
530 state,
531 client_id,
532 NostrRelayMessage::event(
533 subscription_id.clone(),
534 event,
535 ),
536 )
537 .await;
538 }
539 if upstream_relays == 0 {
540 send_nostr(
541 state,
542 client_id,
543 NostrRelayMessage::eose(subscription_id),
544 )
545 .await;
546 }
547 }
548 NostrClientMessage::Close(subscription_id) => {
549 close_upstream_nostr_subscription(
550 state,
551 client_id,
552 &subscription_id,
553 )
554 .await;
555 relay
556 .handle_client_message(
557 client_id,
558 NostrClientMessage::Close(subscription_id.clone()),
559 )
560 .await;
561 }
562 other => {
563 relay.handle_client_message(client_id, other).await;
564 }
565 }
566 } else {
567 handle_nostr_message(client_id, msg, state).await;
568 }
569 }
570 }
571 }
572 }
573 Message::Binary(data) => {
574 handle_binary(client_id, data, state).await;
575 }
576 Message::Close(_) => {}
577 _ => {}
578 }
579}
580
/// Serves or routes a blob request from client `client_id`.
///
/// `hash` is hex. It is lower-cased for decoding, but the string exactly as
/// given is what gets echoed in JSON replies, stored in `pending`, and sent to
/// JSON peers. `origin_protocol` decides how the requester is answered:
/// * JSON origins get `res` frames, plus an id-prefixed binary frame on a hit;
/// * msgpack origins only ever receive a `DataResponse`, on a hit.
///
/// Resolution order:
/// 1. Undecodable hex → `found: false` (JSON origins only).
/// 2. Blob in the local store → answered immediately.
/// 3. No other client speaking a hashtree protocol → `found: false` (JSON
///    origins only).
/// 4. Otherwise the request is recorded in `pending` once per peer and sent
///    to every peer in that peer's protocol. A task then answers
///    `found: false` after 1.5 s, unless a peer has already reported a hit or
///    the entries are already gone.
async fn handle_request(
    client_id: u64,
    request_id: u32,
    hash: String,
    origin_protocol: WsProtocol,
    state: &AppState,
) {
    let hash_hex = hash.to_lowercase();
    let hash_bytes = match from_hex(&hash_hex) {
        Ok(bytes) => bytes,
        Err(_) => {
            if origin_protocol == WsProtocol::HashtreeJson {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash,
                        found: false,
                    },
                )
                .await;
            }
            return;
        }
    };

    // A store error is treated like a miss and falls through to peer routing.
    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
        match origin_protocol {
            WsProtocol::HashtreeJson => {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash: hash.clone(),
                        found: true,
                    },
                )
                .await;
                send_binary(state, client_id, request_id, data).await;
            }
            WsProtocol::HashtreeMsgpack => {
                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
            }
            WsProtocol::Unknown => {}
        }
        return;
    }

    // Snapshot every other client with a known hashtree protocol. Both locks
    // are held together so the protocol lookup matches the client set.
    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
        let clients = state.ws_relay.clients.lock().await;
        let protocols = state.ws_relay.client_protocols.lock().await;
        clients
            .iter()
            .filter(|(id, _)| **id != client_id)
            .filter_map(|(id, tx)| {
                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
                match protocol {
                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
                        Some((*id, tx.clone(), protocol))
                    }
                    WsProtocol::Unknown => None,
                }
            })
            .collect()
    };

    if peers.is_empty() {
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash,
                    found: false,
                },
            )
            .await;
        }
        return;
    }

    // One entry per peer, keyed by (peer, request id). An existing entry for
    // the same pair is overwritten.
    {
        let mut pending = state.ws_relay.pending.lock().await;
        for (peer_id, _, _) in &peers {
            pending.insert(
                (*peer_id, request_id),
                PendingRequest {
                    origin_id: client_id,
                    hash: hash.clone(),
                    found: false,
                    origin_protocol,
                },
            );
        }
    }

    // JSON peers receive the request id; msgpack peers receive a hash-only
    // request and are matched back by hash in `handle_msgpack_response`.
    let request_text = serde_json::to_string(&WsRequest {
        kind: "req".to_string(),
        id: request_id,
        hash: hash.clone(),
    })
    .unwrap_or_else(|_| String::new());
    for (peer_id, tx, protocol) in peers {
        match protocol {
            WsProtocol::HashtreeMsgpack => {
                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
            }
            WsProtocol::HashtreeJson => {
                let _ = tx.send(Message::Text(request_text.clone()));
            }
            WsProtocol::Unknown => {}
        }
    }

    // Timeout: if entries for this (origin, request id) remain and none has
    // reported a hit, drop them and tell a JSON origin `found: false`.
    // Entries already marked found are left in place.
    let timeout_state = state.clone();
    let timeout_hash = hash.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(1500)).await;
        let mut pending = timeout_state.ws_relay.pending.lock().await;
        let still_pending = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
        let already_found = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
        if !still_pending || already_found {
            return;
        }
        let origin_protocol = pending
            .iter()
            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
            .map(|(_, p)| p.origin_protocol)
            .unwrap_or(WsProtocol::HashtreeJson);
        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
        drop(pending);
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                &timeout_state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash: timeout_hash,
                    found: false,
                },
            )
            .await;
        }
    });
}
737
738async fn handle_response(
739 client_id: u64,
740 request_id: u32,
741 _hash: String,
742 found: bool,
743 state: &AppState,
744) {
745 let pending_entry = {
746 let pending = state.ws_relay.pending.lock().await;
747 pending
748 .get(&(client_id, request_id))
749 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
750 };
751
752 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
753 return;
754 };
755
756 if already_found && !found {
757 let mut pending = state.ws_relay.pending.lock().await;
758 pending.remove(&(client_id, request_id));
759 return;
760 }
761
762 if found {
763 let mut pending = state.ws_relay.pending.lock().await;
764 for ((_, id), p) in pending.iter_mut() {
765 if *id == request_id && p.origin_id == origin_id {
766 p.found = true;
767 }
768 }
769 drop(pending);
770 if origin_protocol == WsProtocol::HashtreeJson {
771 send_json(
772 state,
773 origin_id,
774 WsResponse {
775 kind: "res",
776 id: request_id,
777 hash: pending_hash,
778 found: true,
779 },
780 )
781 .await;
782 }
783 return;
784 }
785
786 let mut pending = state.ws_relay.pending.lock().await;
787 pending.remove(&(client_id, request_id));
788 let has_remaining = pending
789 .iter()
790 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
791 drop(pending);
792
793 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
794 send_json(
795 state,
796 origin_id,
797 WsResponse {
798 kind: "res",
799 id: request_id,
800 hash: pending_hash,
801 found: false,
802 },
803 )
804 .await;
805 }
806}
807
808async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
809 if let Some(msg) = parse_msgpack_message(&data) {
810 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
811 match msg {
812 DataMessage::Request(req) => {
813 let hash_hex = hex_encode(&req.h);
814 let request_id = state.ws_relay.next_request_id();
815 handle_request(
816 client_id,
817 request_id,
818 hash_hex,
819 WsProtocol::HashtreeMsgpack,
820 state,
821 )
822 .await;
823 }
824 DataMessage::Response(res) => {
825 handle_msgpack_response(client_id, res, state).await;
826 }
827 DataMessage::QuoteRequest(_)
828 | DataMessage::QuoteResponse(_)
829 | DataMessage::Payment(_)
830 | DataMessage::PaymentAck(_)
831 | DataMessage::Chunk(_)
832 | DataMessage::PeerHints(_) => {}
833 }
834 return;
835 }
836
837 if data.len() < 4 {
839 return;
840 }
841 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
842 let pending_entry = {
843 let pending = state.ws_relay.pending.lock().await;
844 pending
845 .get(&(client_id, request_id))
846 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
847 };
848 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
849 return;
850 };
851
852 match origin_protocol {
853 WsProtocol::HashtreeJson => {
854 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
855 }
856 WsProtocol::HashtreeMsgpack => {
857 let Ok(hash_bytes) = from_hex(&hash_hex) else {
858 return;
859 };
860 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
861 }
862 WsProtocol::Unknown => {}
863 }
864
865 let mut pending = state.ws_relay.pending.lock().await;
866 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
867}
868
869async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
870 let replies = nostr_responses_for(&msg);
871 for reply in replies {
872 send_nostr(state, client_id, reply).await;
873 }
874}
875
876fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
877 match msg {
878 NostrClientMessage::Event(event) => {
879 let ok = event.verify().is_ok();
880 let message = if ok { "" } else { "invalid: signature" };
881 vec![NostrRelayMessage::ok(event.id, ok, message)]
882 }
883 NostrClientMessage::Req {
884 subscription_id, ..
885 } => {
886 vec![NostrRelayMessage::eose(subscription_id.clone())]
887 }
888 NostrClientMessage::Count {
889 subscription_id, ..
890 } => {
891 vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
892 }
893 NostrClientMessage::Close(_) => Vec::new(),
894 NostrClientMessage::Auth(event) => {
895 let ok = event.verify().is_ok();
896 let message = if ok { "" } else { "invalid auth" };
897 vec![NostrRelayMessage::ok(event.id, ok, message)]
898 }
899 NostrClientMessage::NegOpen { .. }
900 | NostrClientMessage::NegMsg { .. }
901 | NostrClientMessage::NegClose { .. } => {
902 vec![NostrRelayMessage::notice("negentropy not supported")]
903 }
904 }
905}
906
907async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
908 let text = response.as_json();
909 send_to_client(state, client_id, Message::Text(text)).await;
910}
911
912fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
913 let msg = parse_message(data).ok()?;
914 match msg {
915 DataMessage::Request(req) => {
916 if req.h.len() == 32 {
917 Some(DataMessage::Request(req))
918 } else {
919 None
920 }
921 }
922 DataMessage::Response(res) => {
923 if res.h.len() == 32 {
924 Some(DataMessage::Response(res))
925 } else {
926 None
927 }
928 }
929 DataMessage::QuoteRequest(req) => {
930 if req.h.len() == 32 {
931 Some(DataMessage::QuoteRequest(req))
932 } else {
933 None
934 }
935 }
936 DataMessage::QuoteResponse(res) => {
937 if res.h.len() == 32 {
938 Some(DataMessage::QuoteResponse(res))
939 } else {
940 None
941 }
942 }
943 DataMessage::Payment(req) => {
944 if req.h.len() == 32 {
945 Some(DataMessage::Payment(req))
946 } else {
947 None
948 }
949 }
950 DataMessage::PaymentAck(res) => {
951 if res.h.len() == 32 {
952 Some(DataMessage::PaymentAck(res))
953 } else {
954 None
955 }
956 }
957 DataMessage::Chunk(chunk) => {
958 if chunk.h.len() == 32 {
959 Some(DataMessage::Chunk(chunk))
960 } else {
961 None
962 }
963 }
964 DataMessage::PeerHints(_) => None,
965 }
966}
967
968async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
969 let hash_hex = hex_encode(&res.h);
970 let data = res.d.clone();
971 let hash_bytes = res.h.clone();
972
973 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
974 let mut seen = HashSet::new();
975 {
976 let pending = state.ws_relay.pending.lock().await;
977 for ((peer_id, request_id), p) in pending.iter() {
978 if *peer_id != client_id {
979 continue;
980 }
981 if p.hash != hash_hex {
982 continue;
983 }
984 if seen.insert((p.origin_id, *request_id)) {
985 responses.push((p.origin_id, *request_id, p.origin_protocol));
986 }
987 }
988 }
989
990 if responses.is_empty() {
991 return;
992 }
993
994 for (origin_id, request_id, protocol) in &responses {
995 match protocol {
996 WsProtocol::HashtreeJson => {
997 send_json(
998 state,
999 *origin_id,
1000 WsResponse {
1001 kind: "res",
1002 id: *request_id,
1003 hash: hash_hex.clone(),
1004 found: true,
1005 },
1006 )
1007 .await;
1008 send_binary(state, *origin_id, *request_id, data.clone()).await;
1009 }
1010 WsProtocol::HashtreeMsgpack => {
1011 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1012 }
1013 WsProtocol::Unknown => {}
1014 }
1015 }
1016
1017 let completed: HashSet<(u64, u32)> = responses
1018 .into_iter()
1019 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1020 .collect();
1021 let mut pending = state.ws_relay.pending.lock().await;
1022 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1023}
1024
1025async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1026 if let Ok(text) = serde_json::to_string(&response) {
1027 send_to_client(state, client_id, Message::Text(text)).await;
1028 }
1029}
1030
1031async fn send_msgpack_request(
1032 state: &AppState,
1033 client_id: u64,
1034 hash: &[u8],
1035) -> Result<(), rmp_serde::encode::Error> {
1036 let req = DataRequest {
1037 h: hash.to_vec(),
1038 htl: MAX_HTL,
1039 q: None,
1040 };
1041 let wire = encode_request(&req)?;
1042 send_to_client(state, client_id, Message::Binary(wire)).await;
1043 Ok(())
1044}
1045
1046async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1047 let res = DataResponse {
1048 h: hash.to_vec(),
1049 d: data.to_vec(),
1050 i: None,
1051 n: None,
1052 };
1053 if let Ok(wire) = encode_response(&res) {
1054 send_to_client(state, client_id, Message::Binary(wire)).await;
1055 }
1056}
1057
1058async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1059 let mut packet = Vec::with_capacity(4 + payload.len());
1060 packet.extend_from_slice(&request_id.to_le_bytes());
1061 packet.extend_from_slice(&payload);
1062 send_to_client(state, client_id, Message::Binary(packet)).await;
1063}
1064
1065async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1066 let sender = {
1067 let clients = state.ws_relay.clients.lock().await;
1068 clients.get(&client_id).cloned()
1069 };
1070 if let Some(tx) = sender {
1071 let _ = tx.send(msg);
1072 }
1073}
1074
1075async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1076 let mut protocols = state.ws_relay.client_protocols.lock().await;
1077 protocols.insert(client_id, protocol);
1078}
1079
1080#[cfg(test)]
1081mod tests {
1082 use super::*;
1083 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1084 use anyhow::Result;
1085 use futures::{SinkExt, StreamExt};
1086 use nostr::secp256k1::schnorr::Signature;
1087 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1088 use std::collections::HashSet;
1089 use std::sync::Arc;
1090 use tempfile::TempDir;
1091 use tokio::net::TcpListener;
1092 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1093
1094 #[test]
1095 fn parse_ws_text_message_detects_nostr_req() {
1096 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1097 match parse_ws_text_message(msg) {
1098 Some(WsTextMessage::Nostr(_)) => {}
1099 other => panic!("expected Nostr message, got {:?}", other),
1100 }
1101 }
1102
1103 #[test]
1104 fn parse_ws_text_message_detects_hashtree_request() {
1105 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1106 match parse_ws_text_message(msg) {
1107 Some(WsTextMessage::Hashtree(_)) => {}
1108 other => panic!("expected Hashtree message, got {:?}", other),
1109 }
1110 }
1111
1112 #[test]
1113 fn nostr_replies_for_req_is_eose() {
1114 let sub = SubscriptionId::new("sub-1");
1115 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1116 let replies = nostr_responses_for(&msg);
1117 assert_eq!(replies.len(), 1);
1118 match &replies[0] {
1119 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1120 other => panic!("expected EOSE, got {:?}", other),
1121 }
1122 }
1123
1124 #[test]
1125 fn nostr_replies_for_event_ok() {
1126 let keys = Keys::generate();
1127 let event = EventBuilder::new(Kind::TextNote, "hello", [])
1128 .to_event(&keys)
1129 .unwrap();
1130 let msg = NostrClientMessage::event(event.clone());
1131 let replies = nostr_responses_for(&msg);
1132 assert_eq!(replies.len(), 1);
1133 match &replies[0] {
1134 NostrRelayMessage::Ok {
1135 event_id, status, ..
1136 } => {
1137 assert_eq!(event_id, &event.id);
1138 assert!(*status);
1139 }
1140 other => panic!("expected OK, got {:?}", other),
1141 }
1142 }
1143
1144 #[test]
1145 fn nostr_replies_for_invalid_event_is_not_ok() {
1146 let keys = Keys::generate();
1147 let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1148 .to_event(&keys)
1149 .unwrap();
1150 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1151 let msg = NostrClientMessage::event(event);
1152 let replies = nostr_responses_for(&msg);
1153 assert_eq!(replies.len(), 1);
1154 match &replies[0] {
1155 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1156 other => panic!("expected OK=false, got {:?}", other),
1157 }
1158 }
1159
1160 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1161 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1162 let addr = listener.local_addr().expect("relay addr");
1163 tokio::spawn(async move {
1164 let (stream, _) = listener.accept().await.expect("accept relay");
1165 let ws = accept_async(stream).await.expect("accept websocket");
1166 let (mut write, mut read) = ws.split();
1167
1168 while let Some(Ok(message)) = read.next().await {
1169 let TungsteniteMessage::Text(text) = message else {
1170 continue;
1171 };
1172 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1173 continue;
1174 };
1175 if let NostrClientMessage::Req {
1176 subscription_id,
1177 filters,
1178 } = parsed
1179 {
1180 for event in events
1181 .iter()
1182 .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1183 {
1184 let _ = write
1185 .send(TungsteniteMessage::Text(
1186 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1187 .as_json()
1188 .into(),
1189 ))
1190 .await;
1191 }
1192 let _ = write
1193 .send(TungsteniteMessage::Text(
1194 NostrRelayMessage::eose(subscription_id).as_json().into(),
1195 ))
1196 .await;
1197 }
1198 }
1199 });
1200 format!("ws://{}", addr)
1201 }
1202
1203 fn test_app_state(
1204 tmp: &TempDir,
1205 relay: Arc<NostrRelay>,
1206 relay_url: String,
1207 ) -> Result<AppState> {
1208 let store = Arc::new(crate::storage::HashtreeStore::with_options(
1209 tmp.path(),
1210 None,
1211 128 * 1024 * 1024,
1212 )?);
1213 Ok(AppState {
1214 store,
1215 auth: None,
1216 peer_mode: crate::config::ServerMode::Normal,
1217 hash_get_enabled: true,
1218 webrtc_peers: None,
1219 ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1220 max_upload_bytes: 5 * 1024 * 1024,
1221 public_writes: true,
1222 allowed_pubkeys: HashSet::new(),
1223 upstream_blossom: Vec::new(),
1224 social_graph: None,
1225 social_graph_store: None,
1226 social_graph_root: None,
1227 socialgraph_snapshot_public: false,
1228 nostr_relay: Some(relay),
1229 nostr_relay_urls: vec![relay_url],
1230 tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1231 inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1232 std::collections::HashMap::new(),
1233 )),
1234 directory_listing_cache: Arc::new(std::sync::Mutex::new(
1235 super::super::auth::new_lookup_cache(),
1236 )),
1237 resolved_path_cache: Arc::new(std::sync::Mutex::new(
1238 super::super::auth::new_lookup_cache(),
1239 )),
1240 thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1241 super::super::auth::new_lookup_cache(),
1242 )),
1243 cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1244 })
1245 }
1246
1247 #[tokio::test]
1248 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1249 let tmp = TempDir::new()?;
1250 let graph_store = {
1251 let _guard = crate::socialgraph::test_lock();
1252 crate::socialgraph::open_social_graph_store_with_mapsize(
1253 tmp.path(),
1254 Some(128 * 1024 * 1024),
1255 )?
1256 };
1257 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1258 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1259 Arc::clone(&backend),
1260 0,
1261 HashSet::new(),
1262 ));
1263
1264 let keys = Keys::generate();
1265 let relay = Arc::new(NostrRelay::new(
1266 Arc::clone(&backend),
1267 tmp.path().to_path_buf(),
1268 HashSet::from([keys.public_key().to_hex()]),
1269 Some(access),
1270 NostrRelayConfig {
1271 spambox_db_max_bytes: 0,
1272 ..Default::default()
1273 },
1274 )?);
1275
1276 let event = EventBuilder::new(
1277 Kind::from(30078_u16),
1278 "",
1279 [
1280 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1281 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1282 ],
1283 )
1284 .to_event(&keys)?;
1285
1286 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1287 let filter = Filter::new()
1288 .authors(vec![event.pubkey])
1289 .kinds(vec![event.kind]);
1290 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1291 let client_id = 7_u64;
1292 let (tx, mut rx) = mpsc::unbounded_channel();
1293 state.ws_relay.clients.lock().await.insert(client_id, tx);
1294 let subscription_id = SubscriptionId::new("sub-1");
1295
1296 start_upstream_nostr_subscription(
1297 &state,
1298 client_id,
1299 subscription_id.clone(),
1300 vec![filter.clone()],
1301 )
1302 .await;
1303
1304 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1305 .await?
1306 .expect("forwarded upstream event");
1307 let Message::Text(text) = forwarded else {
1308 panic!("expected text event");
1309 };
1310 match NostrRelayMessage::from_json(text.as_str())? {
1311 NostrRelayMessage::Event {
1312 subscription_id: sid,
1313 event: forwarded_event,
1314 } => {
1315 assert_eq!(sid, subscription_id);
1316 assert_eq!(forwarded_event.id, event.id);
1317 }
1318 other => panic!("expected forwarded EVENT, got {:?}", other),
1319 }
1320
1321 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1322 .await?
1323 .expect("forwarded upstream eose");
1324 let Message::Text(eose_text) = eose else {
1325 panic!("expected text eose");
1326 };
1327 match NostrRelayMessage::from_json(eose_text.as_str())? {
1328 NostrRelayMessage::EndOfStoredEvents(sid) => {
1329 assert_eq!(sid, subscription_id);
1330 }
1331 other => panic!("expected forwarded EOSE, got {:?}", other),
1332 }
1333
1334 let events = relay.query_events(&filter, 10).await;
1335 assert_eq!(events.len(), 1);
1336 assert_eq!(events[0].id, event.id);
1337
1338 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1339 assert!(state
1340 .ws_relay
1341 .upstream_nostr_subscriptions
1342 .lock()
1343 .await
1344 .is_empty());
1345 Ok(())
1346 }
1347
1348 #[tokio::test]
1349 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1350 let tmp = TempDir::new()?;
1351 let graph_store = {
1352 let _guard = crate::socialgraph::test_lock();
1353 crate::socialgraph::open_social_graph_store_with_mapsize(
1354 tmp.path(),
1355 Some(128 * 1024 * 1024),
1356 )?
1357 };
1358 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1359 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1360 Arc::clone(&backend),
1361 0,
1362 HashSet::new(),
1363 ));
1364
1365 let keys = Keys::generate();
1366 let relay = Arc::new(NostrRelay::new(
1367 Arc::clone(&backend),
1368 tmp.path().to_path_buf(),
1369 HashSet::from([keys.public_key().to_hex()]),
1370 Some(access),
1371 NostrRelayConfig {
1372 spambox_db_max_bytes: 0,
1373 ..Default::default()
1374 },
1375 )?);
1376
1377 let event = EventBuilder::new(
1378 Kind::from(30078_u16),
1379 "",
1380 [
1381 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1382 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1383 ],
1384 )
1385 .to_event(&keys)?;
1386
1387 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1388 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1389 let client_id = 11_u64;
1390 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1391 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1392 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1393 relay.register_client(client_id, relay_tx, None).await;
1394
1395 let request = NostrClientMessage::req(
1396 SubscriptionId::new("feed"),
1397 vec![Filter::new()
1398 .authors(vec![event.pubkey])
1399 .kinds(vec![event.kind])],
1400 )
1401 .as_json();
1402
1403 handle_message(client_id, Message::Text(request.into()), &state).await;
1404
1405 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1406 .await?
1407 .expect("first forwarded message");
1408 let Message::Text(first_text) = first else {
1409 panic!("expected text event");
1410 };
1411 match NostrRelayMessage::from_json(first_text.as_str())? {
1412 NostrRelayMessage::Event {
1413 event: forwarded_event,
1414 ..
1415 } => {
1416 assert_eq!(forwarded_event.id, event.id);
1417 }
1418 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1419 }
1420
1421 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1422 .await?
1423 .expect("second forwarded message");
1424 let Message::Text(second_text) = second else {
1425 panic!("expected text eose");
1426 };
1427 match NostrRelayMessage::from_json(second_text.as_str())? {
1428 NostrRelayMessage::EndOfStoredEvents(sid) => {
1429 assert_eq!(sid, SubscriptionId::new("feed"));
1430 }
1431 other => panic!("expected aggregated EOSE, got {:?}", other),
1432 }
1433
1434 Ok(())
1435 }
1436
1437 #[tokio::test]
1438 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1439 let tmp = TempDir::new()?;
1440 let graph_store = {
1441 let _guard = crate::socialgraph::test_lock();
1442 crate::socialgraph::open_social_graph_store_with_mapsize(
1443 tmp.path(),
1444 Some(128 * 1024 * 1024),
1445 )?
1446 };
1447 let author_keys = Keys::generate();
1448 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1449 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1450 Arc::clone(&backend),
1451 0,
1452 HashSet::from([author_keys.public_key().to_hex()]),
1453 ));
1454 let relay = Arc::new(NostrRelay::new(
1455 Arc::clone(&backend),
1456 tmp.path().to_path_buf(),
1457 HashSet::from([author_keys.public_key().to_hex()]),
1458 Some(access),
1459 NostrRelayConfig {
1460 spambox_db_max_bytes: 0,
1461 ..Default::default()
1462 },
1463 )?);
1464
1465 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1466 let listener = TcpListener::bind("127.0.0.1:0").await?;
1467 let addr = listener.local_addr()?;
1468 let client_pubkey = author_keys.public_key().to_hex();
1469 let app = axum::Router::new().route(
1470 "/ws",
1471 axum::routing::get({
1472 let state = state.clone();
1473 let client_pubkey = client_pubkey.clone();
1474 move |ws: WebSocketUpgrade| {
1475 let state = state.clone();
1476 let client_pubkey = client_pubkey.clone();
1477 async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1478 }
1479 }),
1480 );
1481 tokio::spawn(async move {
1482 let _ = axum::serve(listener, app).await;
1483 });
1484
1485 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1486 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1487 .to_event(&author_keys)?;
1488 socket
1489 .send(TungsteniteMessage::Text(
1490 NostrClientMessage::event(event.clone()).as_json().into(),
1491 ))
1492 .await?;
1493
1494 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1495 .await?
1496 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1497 let TungsteniteMessage::Text(text) = reply else {
1498 anyhow::bail!("expected text publish ack");
1499 };
1500
1501 match NostrRelayMessage::from_json(text.as_str())? {
1502 NostrRelayMessage::Ok {
1503 event_id, status, ..
1504 } => {
1505 assert_eq!(event_id, event.id);
1506 assert!(status);
1507 }
1508 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1509 }
1510
1511 let stored = relay
1512 .query_events(
1513 &Filter::new()
1514 .authors(vec![event.pubkey])
1515 .kinds(vec![event.kind]),
1516 10,
1517 )
1518 .await;
1519 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1520 Ok(())
1521 }
1522
1523 #[tokio::test]
1524 async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1525 let tmp = TempDir::new()?;
1526 let graph_store = {
1527 let _guard = crate::socialgraph::test_lock();
1528 crate::socialgraph::open_social_graph_store_with_mapsize(
1529 tmp.path(),
1530 Some(128 * 1024 * 1024),
1531 )?
1532 };
1533 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1534 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1535 Arc::clone(&backend),
1536 0,
1537 HashSet::new(),
1538 ));
1539 let relay = Arc::new(NostrRelay::new(
1540 Arc::clone(&backend),
1541 tmp.path().to_path_buf(),
1542 HashSet::new(),
1543 Some(access),
1544 NostrRelayConfig {
1545 spambox_db_max_bytes: 0,
1546 spambox_max_reqs_per_min: 1,
1547 ..Default::default()
1548 },
1549 )?);
1550
1551 let state = test_app_state(&tmp, relay, String::new())?;
1552 let listener = TcpListener::bind("127.0.0.1:0").await?;
1553 let addr = listener.local_addr()?;
1554 let app = axum::Router::new().route(
1555 "/ws",
1556 axum::routing::get({
1557 let state = state.clone();
1558 move |ws: WebSocketUpgrade| {
1559 let state = state.clone();
1560 async move { ws_data_with_client_pubkey(state, ws, None) }
1561 }
1562 }),
1563 );
1564 tokio::spawn(async move {
1565 let _ = axum::serve(listener, app).await;
1566 });
1567
1568 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1569 socket
1570 .send(TungsteniteMessage::Text(
1571 NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1572 .as_json()
1573 .into(),
1574 ))
1575 .await?;
1576
1577 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1578 .await?
1579 .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1580 let TungsteniteMessage::Text(first_text) = first else {
1581 anyhow::bail!("expected text EOSE reply");
1582 };
1583 match NostrRelayMessage::from_json(first_text.as_str())? {
1584 NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1585 assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
1586 }
1587 other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1588 }
1589
1590 socket
1591 .send(TungsteniteMessage::Text(
1592 NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1593 .as_json()
1594 .into(),
1595 ))
1596 .await?;
1597
1598 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1599 .await?
1600 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1601 let TungsteniteMessage::Text(second_text) = second else {
1602 anyhow::bail!("expected text CLOSED reply");
1603 };
1604 let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1605 assert_eq!(
1606 second_value,
1607 serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1608 );
1609
1610 Ok(())
1611 }
1612
1613 #[tokio::test]
1614 async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1615 let tmp = TempDir::new()?;
1616 let graph_store = {
1617 let _guard = crate::socialgraph::test_lock();
1618 crate::socialgraph::open_social_graph_store_with_mapsize(
1619 tmp.path(),
1620 Some(128 * 1024 * 1024),
1621 )?
1622 };
1623 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1624 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1625 Arc::clone(&backend),
1626 0,
1627 HashSet::new(),
1628 ));
1629 let relay = Arc::new(NostrRelay::new(
1630 Arc::clone(&backend),
1631 tmp.path().to_path_buf(),
1632 HashSet::new(),
1633 Some(access),
1634 NostrRelayConfig {
1635 spambox_db_max_bytes: 0,
1636 spambox_max_events_per_min: 1,
1637 ..Default::default()
1638 },
1639 )?);
1640
1641 let state = test_app_state(&tmp, relay, String::new())?;
1642 let listener = TcpListener::bind("127.0.0.1:0").await?;
1643 let addr = listener.local_addr()?;
1644 let app = axum::Router::new().route(
1645 "/ws",
1646 axum::routing::get({
1647 let state = state.clone();
1648 move |ws: WebSocketUpgrade| {
1649 let state = state.clone();
1650 async move { ws_data_with_client_pubkey(state, ws, None) }
1651 }
1652 }),
1653 );
1654 tokio::spawn(async move {
1655 let _ = axum::serve(listener, app).await;
1656 });
1657
1658 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1659 let author_keys = Keys::generate();
1660 let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
1661 let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;
1662
1663 socket
1664 .send(TungsteniteMessage::Text(
1665 NostrClientMessage::event(event_a.clone()).as_json().into(),
1666 ))
1667 .await?;
1668
1669 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1670 .await?
1671 .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1672 let TungsteniteMessage::Text(first_text) = first else {
1673 anyhow::bail!("expected text publish ack");
1674 };
1675 match NostrRelayMessage::from_json(first_text.as_str())? {
1676 NostrRelayMessage::Ok {
1677 event_id,
1678 status,
1679 message,
1680 } => {
1681 assert_eq!(event_id, event_a.id);
1682 assert!(status);
1683 assert_eq!(message, "spambox");
1684 }
1685 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1686 }
1687
1688 socket
1689 .send(TungsteniteMessage::Text(
1690 NostrClientMessage::event(event_b.clone()).as_json().into(),
1691 ))
1692 .await?;
1693
1694 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1695 .await?
1696 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1697 let TungsteniteMessage::Text(second_text) = second else {
1698 anyhow::bail!("expected text publish ack");
1699 };
1700 match NostrRelayMessage::from_json(second_text.as_str())? {
1701 NostrRelayMessage::Ok {
1702 event_id,
1703 status,
1704 message,
1705 } => {
1706 assert_eq!(event_id, event_b.id);
1707 assert!(!status);
1708 assert_eq!(message, "rate limited");
1709 }
1710 other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1711 }
1712
1713 Ok(())
1714 }
1715}