1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::webrtc::types::{
21 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
22};
23use hex::encode as hex_encode;
24
25#[derive(Debug, Deserialize)]
26#[serde(tag = "type")]
27enum WsClientMessage {
28 #[serde(rename = "req")]
29 Request { id: u32, hash: String },
30 #[serde(rename = "res")]
31 Response { id: u32, hash: String, found: bool },
32}
33
34#[derive(Debug)]
35enum WsTextMessage {
36 Hashtree(WsClientMessage),
37 Nostr(NostrClientMessage),
38}
39
40#[derive(Debug, Deserialize, Serialize)]
41struct WsRequest {
42 #[serde(rename = "type")]
43 kind: String,
44 id: u32,
45 hash: String,
46}
47
48#[derive(Debug, Serialize)]
49struct WsResponse {
50 #[serde(rename = "type")]
51 kind: &'static str,
52 id: u32,
53 hash: String,
54 found: bool,
55}
56
57pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
58 ws_data_with_client_pubkey(state, ws, None)
59}
60
61pub fn ws_data_with_client_pubkey(
62 state: AppState,
63 ws: WebSocketUpgrade,
64 client_pubkey: Option<String>,
65) -> impl IntoResponse {
66 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
67}
68
69async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
70 let client_id = state
73 .nostr_relay
74 .as_ref()
75 .map(|relay| relay.next_client_id())
76 .unwrap_or_else(|| state.ws_relay.next_id());
77 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
78
79 {
80 let mut clients = state.ws_relay.clients.lock().await;
81 clients.insert(client_id, tx);
82 }
83 {
84 let mut protocols = state.ws_relay.client_protocols.lock().await;
85 protocols.insert(client_id, WsProtocol::HashtreeJson);
86 }
87
88 let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
89 let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
90 relay
91 .register_client(client_id, nostr_tx, client_pubkey.clone())
92 .await;
93 Some(nostr_rx)
94 } else {
95 None
96 };
97
98 let (mut sender, mut receiver) = socket.split();
99 loop {
100 tokio::select! {
101 maybe_msg = rx.recv() => {
102 let Some(msg) = maybe_msg else {
103 break;
104 };
105 if sender.send(msg).await.is_err() {
106 break;
107 }
108 }
109 maybe_text = async {
110 match &mut nostr_rx {
111 Some(rx) => rx.recv().await,
112 None => std::future::pending().await,
113 }
114 } => {
115 let Some(text) = maybe_text else {
116 nostr_rx = None;
117 continue;
118 };
119 if sender.send(Message::Text(text)).await.is_err() {
120 break;
121 }
122 }
123 maybe_incoming = receiver.next() => {
124 match maybe_incoming {
125 Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
126 Some(Err(_)) | None => break,
127 }
128 }
129 }
130 }
131
132 close_all_upstream_nostr_subscriptions(&state, client_id).await;
133
134 {
135 let mut clients = state.ws_relay.clients.lock().await;
136 clients.remove(&client_id);
137 }
138 {
139 let mut protocols = state.ws_relay.client_protocols.lock().await;
140 protocols.remove(&client_id);
141 }
142 {
143 let mut pending = state.ws_relay.pending.lock().await;
144 pending.retain(|(peer_id, _), _| *peer_id != client_id);
145 }
146
147 if let Some(relay) = &state.nostr_relay {
148 relay.unregister_client(client_id).await;
149 }
150}
151
152fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
153 let trimmed = text.trim_start();
154 if trimmed.starts_with('[') {
155 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
156 return Some(WsTextMessage::Nostr(msg));
157 }
158 }
159
160 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
161 return Some(WsTextMessage::Hashtree(msg));
162 }
163
164 None
165}
166async fn close_upstream_nostr_subscription(
167 state: &AppState,
168 client_id: u64,
169 subscription_id: &SubscriptionId,
170) {
171 let key = (client_id, subscription_id.to_string());
172 let subscription = {
173 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
174 subscriptions.remove(&key)
175 };
176 if let Some(subscription) = subscription {
177 let _ = subscription.close_tx.send(true);
178 for task in subscription.tasks {
179 task.abort();
180 }
181 }
182 state
183 .ws_relay
184 .upstream_pending_eose
185 .lock()
186 .await
187 .remove(&key);
188 state
189 .ws_relay
190 .upstream_seen_events
191 .lock()
192 .await
193 .remove(&key);
194}
195
196async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
197 let keys = {
198 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
199 subscriptions
200 .keys()
201 .filter(|(id, _)| *id == client_id)
202 .cloned()
203 .collect::<Vec<_>>()
204 };
205 for (_, sub_id) in keys {
206 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
207 }
208}
209
210async fn forward_upstream_nostr_message(
211 state: &AppState,
212 client_id: u64,
213 subscription_id: &SubscriptionId,
214 text: &str,
215) {
216 let Ok(message) = NostrRelayMessage::from_json(text) else {
217 return;
218 };
219
220 match message {
221 NostrRelayMessage::Event {
222 subscription_id: sid,
223 event,
224 } if sid == *subscription_id => {
225 let event = *event;
226 let key = (client_id, subscription_id.to_string());
227 let event_id = event.id.to_hex();
228 let inserted = {
229 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
230 seen_events.entry(key).or_default().insert(event_id)
231 };
232 if !inserted {
233 return;
234 }
235 if let Some(relay) = &state.nostr_relay {
236 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
237 }
238 send_nostr(
239 state,
240 client_id,
241 NostrRelayMessage::event(subscription_id.clone(), event),
242 )
243 .await;
244 }
245 NostrRelayMessage::Closed {
246 subscription_id: sid,
247 message,
248 } if sid == *subscription_id => {
249 send_nostr(
250 state,
251 client_id,
252 NostrRelayMessage::closed(subscription_id.clone(), message),
253 )
254 .await;
255 }
256 _ => {}
257 }
258}
259
260async fn mark_upstream_nostr_relay_complete(
261 state: &AppState,
262 client_id: u64,
263 subscription_id: &SubscriptionId,
264) {
265 let key = (client_id, subscription_id.to_string());
266 let should_send_eose = {
267 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
268 let Some(remaining) = pending.get_mut(&key) else {
269 return;
270 };
271 if *remaining > 0 {
272 *remaining -= 1;
273 }
274 if *remaining == 0 {
275 pending.remove(&key);
276 true
277 } else {
278 false
279 }
280 };
281
282 if should_send_eose {
283 send_nostr(
284 state,
285 client_id,
286 NostrRelayMessage::eose(subscription_id.clone()),
287 )
288 .await;
289 }
290}
291
292async fn run_upstream_nostr_subscription(
293 state: AppState,
294 client_id: u64,
295 relay_url: String,
296 subscription_id: SubscriptionId,
297 filters: Vec<NostrFilter>,
298 mut close_rx: watch::Receiver<bool>,
299) {
300 let mut relay_complete = false;
301 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
302 tracing::warn!(
303 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
304 client_id,
305 subscription_id,
306 relay_url,
307 );
308 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
309 return;
310 };
311 let (mut write, mut read) = socket.split();
312 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
313 state
314 .ws_relay
315 .note_upstream_relay_send(request.as_bytes().len());
316 if write
317 .send(TungsteniteMessage::Text(request.into()))
318 .await
319 .is_err()
320 {
321 tracing::warn!(
322 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
323 client_id,
324 subscription_id,
325 relay_url,
326 );
327 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
328 return;
329 }
330
331 loop {
332 tokio::select! {
333 _ = close_rx.changed() => {
334 if *close_rx.borrow() {
335 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
336 state
337 .ws_relay
338 .note_upstream_relay_send(close.as_bytes().len());
339 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
340 let _ = write.close().await;
341 break;
342 }
343 }
344 message = read.next() => {
345 match message {
346 Some(Ok(TungsteniteMessage::Text(text))) => {
347 state
348 .ws_relay
349 .note_upstream_relay_receive(text.as_bytes().len());
350 if matches!(
351 NostrRelayMessage::from_json(text.as_str()),
352 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
353 ) {
354 if !relay_complete {
355 relay_complete = true;
356 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
357 }
358 continue;
359 }
360 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
361 }
362 Some(Ok(TungsteniteMessage::Ping(payload))) => {
363 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
364 }
365 Some(Ok(TungsteniteMessage::Close(_))) | None => {
366 if !relay_complete {
367 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
368 }
369 break;
370 }
371 Some(Err(_)) => {
372 if !relay_complete {
373 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
374 }
375 break;
376 }
377 _ => {}
378 }
379 }
380 }
381 }
382}
383
384async fn start_upstream_nostr_subscription(
385 state: &AppState,
386 client_id: u64,
387 subscription_id: SubscriptionId,
388 filters: Vec<NostrFilter>,
389) -> usize {
390 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
391 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
392 tracing::info!(
393 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
394 client_id,
395 subscription_id,
396 state.nostr_relay_urls.len(),
397 filters.len(),
398 );
399 return 0;
400 }
401
402 let mut relay_urls = Vec::new();
403 let mut seen = HashSet::new();
404 for relay in &state.nostr_relay_urls {
405 let relay = relay.trim();
406 if relay.is_empty() || !seen.insert(relay.to_string()) {
407 continue;
408 }
409 relay_urls.push(relay.to_string());
410 }
411 if relay_urls.is_empty() {
412 tracing::info!(
413 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
414 client_id,
415 subscription_id,
416 );
417 return 0;
418 }
419
420 tracing::info!(
421 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
422 client_id,
423 subscription_id,
424 relay_urls.len(),
425 );
426
427 let key = (client_id, subscription_id.to_string());
428 state
429 .ws_relay
430 .upstream_seen_events
431 .lock()
432 .await
433 .insert(key.clone(), HashSet::new());
434 state
435 .ws_relay
436 .upstream_pending_eose
437 .lock()
438 .await
439 .insert(key.clone(), relay_urls.len());
440
441 let (close_tx, close_rx) = watch::channel(false);
442 let mut tasks = Vec::new();
443 for relay_url in &relay_urls {
444 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
445 state.clone(),
446 client_id,
447 relay_url.clone(),
448 subscription_id.clone(),
449 filters.clone(),
450 close_rx.clone(),
451 )));
452 }
453
454 state
455 .ws_relay
456 .upstream_nostr_subscriptions
457 .lock()
458 .await
459 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
460 relay_urls.len()
461}
462
463async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
464 match msg {
465 Message::Text(text) => {
466 if let Some(msg) = parse_ws_text_message(&text) {
467 match msg {
468 WsTextMessage::Hashtree(msg) => {
469 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
470 match msg {
471 WsClientMessage::Request { id, hash } => {
472 handle_request(
473 client_id,
474 id,
475 hash,
476 WsProtocol::HashtreeJson,
477 state,
478 )
479 .await;
480 }
481 WsClientMessage::Response { id, hash, found } => {
482 handle_response(client_id, id, hash, found, state).await;
483 }
484 }
485 }
486 WsTextMessage::Nostr(msg) => {
487 if let Some(relay) = &state.nostr_relay {
488 match msg {
489 NostrClientMessage::Req {
490 subscription_id,
491 filters,
492 } => {
493 let local_events = match relay
494 .register_subscription_query(
495 client_id,
496 subscription_id.clone(),
497 filters.clone(),
498 )
499 .await
500 {
501 Ok(events) => events,
502 Err(message) => {
503 send_nostr(
504 state,
505 client_id,
506 NostrRelayMessage::closed(subscription_id, message),
507 )
508 .await;
509 return;
510 }
511 };
512
513 let upstream_relays = start_upstream_nostr_subscription(
514 state,
515 client_id,
516 subscription_id.clone(),
517 filters,
518 )
519 .await;
520 if upstream_relays > 0 {
521 let key = (client_id, subscription_id.to_string());
522 let mut seen_events =
523 state.ws_relay.upstream_seen_events.lock().await;
524 seen_events.entry(key).or_default().extend(
525 local_events.iter().map(|event| event.id.to_hex()),
526 );
527 }
528 for event in local_events {
529 send_nostr(
530 state,
531 client_id,
532 NostrRelayMessage::event(
533 subscription_id.clone(),
534 event,
535 ),
536 )
537 .await;
538 }
539 if upstream_relays == 0 {
540 send_nostr(
541 state,
542 client_id,
543 NostrRelayMessage::eose(subscription_id),
544 )
545 .await;
546 }
547 }
548 NostrClientMessage::Close(subscription_id) => {
549 close_upstream_nostr_subscription(
550 state,
551 client_id,
552 &subscription_id,
553 )
554 .await;
555 relay
556 .handle_client_message(
557 client_id,
558 NostrClientMessage::Close(subscription_id.clone()),
559 )
560 .await;
561 }
562 other => {
563 relay.handle_client_message(client_id, other).await;
564 }
565 }
566 } else {
567 handle_nostr_message(client_id, msg, state).await;
568 }
569 }
570 }
571 }
572 }
573 Message::Binary(data) => {
574 handle_binary(client_id, data, state).await;
575 }
576 Message::Close(_) => {}
577 _ => {}
578 }
579}
580
581async fn handle_request(
582 client_id: u64,
583 request_id: u32,
584 hash: String,
585 origin_protocol: WsProtocol,
586 state: &AppState,
587) {
588 let hash_hex = hash.to_lowercase();
589 let hash_bytes = match from_hex(&hash_hex) {
590 Ok(bytes) => bytes,
591 Err(_) => {
592 if origin_protocol == WsProtocol::HashtreeJson {
593 send_json(
594 state,
595 client_id,
596 WsResponse {
597 kind: "res",
598 id: request_id,
599 hash,
600 found: false,
601 },
602 )
603 .await;
604 }
605 return;
606 }
607 };
608
609 if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
610 match origin_protocol {
611 WsProtocol::HashtreeJson => {
612 send_json(
613 state,
614 client_id,
615 WsResponse {
616 kind: "res",
617 id: request_id,
618 hash: hash.clone(),
619 found: true,
620 },
621 )
622 .await;
623 send_binary(state, client_id, request_id, data).await;
624 }
625 WsProtocol::HashtreeMsgpack => {
626 send_msgpack_response(state, client_id, &hash_bytes, &data).await;
627 }
628 WsProtocol::Unknown => {}
629 }
630 return;
631 }
632
633 let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
634 let clients = state.ws_relay.clients.lock().await;
635 let protocols = state.ws_relay.client_protocols.lock().await;
636 clients
637 .iter()
638 .filter(|(id, _)| **id != client_id)
639 .filter_map(|(id, tx)| {
640 let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
641 match protocol {
642 WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
643 Some((*id, tx.clone(), protocol))
644 }
645 WsProtocol::Unknown => None,
646 }
647 })
648 .collect()
649 };
650
651 if peers.is_empty() {
652 if origin_protocol == WsProtocol::HashtreeJson {
653 send_json(
654 state,
655 client_id,
656 WsResponse {
657 kind: "res",
658 id: request_id,
659 hash,
660 found: false,
661 },
662 )
663 .await;
664 }
665 return;
666 }
667
668 {
669 let mut pending = state.ws_relay.pending.lock().await;
670 for (peer_id, _, _) in &peers {
671 pending.insert(
672 (*peer_id, request_id),
673 PendingRequest {
674 origin_id: client_id,
675 hash: hash.clone(),
676 found: false,
677 origin_protocol,
678 },
679 );
680 }
681 }
682
683 let request_text = serde_json::to_string(&WsRequest {
684 kind: "req".to_string(),
685 id: request_id,
686 hash: hash.clone(),
687 })
688 .unwrap_or_else(|_| String::new());
689 for (peer_id, tx, protocol) in peers {
690 match protocol {
691 WsProtocol::HashtreeMsgpack => {
692 let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
693 }
694 WsProtocol::HashtreeJson => {
695 let _ = tx.send(Message::Text(request_text.clone()));
696 }
697 WsProtocol::Unknown => {}
698 }
699 }
700
701 let timeout_state = state.clone();
702 let timeout_hash = hash.clone();
703 tokio::spawn(async move {
704 tokio::time::sleep(Duration::from_millis(1500)).await;
705 let mut pending = timeout_state.ws_relay.pending.lock().await;
706 let still_pending = pending
707 .iter()
708 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
709 let already_found = pending
710 .iter()
711 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
712 if !still_pending || already_found {
713 return;
714 }
715 let origin_protocol = pending
716 .iter()
717 .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
718 .map(|(_, p)| p.origin_protocol)
719 .unwrap_or(WsProtocol::HashtreeJson);
720 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
721 drop(pending);
722 if origin_protocol == WsProtocol::HashtreeJson {
723 send_json(
724 &timeout_state,
725 client_id,
726 WsResponse {
727 kind: "res",
728 id: request_id,
729 hash: timeout_hash,
730 found: false,
731 },
732 )
733 .await;
734 }
735 });
736}
737
738async fn handle_response(
739 client_id: u64,
740 request_id: u32,
741 _hash: String,
742 found: bool,
743 state: &AppState,
744) {
745 let pending_entry = {
746 let pending = state.ws_relay.pending.lock().await;
747 pending
748 .get(&(client_id, request_id))
749 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
750 };
751
752 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
753 return;
754 };
755
756 if already_found && !found {
757 let mut pending = state.ws_relay.pending.lock().await;
758 pending.remove(&(client_id, request_id));
759 return;
760 }
761
762 if found {
763 let mut pending = state.ws_relay.pending.lock().await;
764 for ((_, id), p) in pending.iter_mut() {
765 if *id == request_id && p.origin_id == origin_id {
766 p.found = true;
767 }
768 }
769 drop(pending);
770 if origin_protocol == WsProtocol::HashtreeJson {
771 send_json(
772 state,
773 origin_id,
774 WsResponse {
775 kind: "res",
776 id: request_id,
777 hash: pending_hash,
778 found: true,
779 },
780 )
781 .await;
782 }
783 return;
784 }
785
786 let mut pending = state.ws_relay.pending.lock().await;
787 pending.remove(&(client_id, request_id));
788 let has_remaining = pending
789 .iter()
790 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
791 drop(pending);
792
793 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
794 send_json(
795 state,
796 origin_id,
797 WsResponse {
798 kind: "res",
799 id: request_id,
800 hash: pending_hash,
801 found: false,
802 },
803 )
804 .await;
805 }
806}
807
808async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
809 if let Some(msg) = parse_msgpack_message(&data) {
810 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
811 match msg {
812 DataMessage::Request(req) => {
813 let hash_hex = hex_encode(&req.h);
814 let request_id = state.ws_relay.next_request_id();
815 handle_request(
816 client_id,
817 request_id,
818 hash_hex,
819 WsProtocol::HashtreeMsgpack,
820 state,
821 )
822 .await;
823 }
824 DataMessage::Response(res) => {
825 handle_msgpack_response(client_id, res, state).await;
826 }
827 DataMessage::QuoteRequest(_)
828 | DataMessage::QuoteResponse(_)
829 | DataMessage::Payment(_)
830 | DataMessage::PaymentAck(_)
831 | DataMessage::Chunk(_) => {}
832 }
833 return;
834 }
835
836 if data.len() < 4 {
838 return;
839 }
840 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
841 let pending_entry = {
842 let pending = state.ws_relay.pending.lock().await;
843 pending
844 .get(&(client_id, request_id))
845 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
846 };
847 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
848 return;
849 };
850
851 match origin_protocol {
852 WsProtocol::HashtreeJson => {
853 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
854 }
855 WsProtocol::HashtreeMsgpack => {
856 let Ok(hash_bytes) = from_hex(&hash_hex) else {
857 return;
858 };
859 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
860 }
861 WsProtocol::Unknown => {}
862 }
863
864 let mut pending = state.ws_relay.pending.lock().await;
865 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
866}
867
868async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
869 let replies = nostr_responses_for(&msg);
870 for reply in replies {
871 send_nostr(state, client_id, reply).await;
872 }
873}
874
875fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
876 match msg {
877 NostrClientMessage::Event(event) => {
878 let ok = event.verify().is_ok();
879 let message = if ok { "" } else { "invalid: signature" };
880 vec![NostrRelayMessage::ok(event.id, ok, message)]
881 }
882 NostrClientMessage::Req {
883 subscription_id, ..
884 } => {
885 vec![NostrRelayMessage::eose(subscription_id.clone())]
886 }
887 NostrClientMessage::Count {
888 subscription_id, ..
889 } => {
890 vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
891 }
892 NostrClientMessage::Close(_) => Vec::new(),
893 NostrClientMessage::Auth(event) => {
894 let ok = event.verify().is_ok();
895 let message = if ok { "" } else { "invalid auth" };
896 vec![NostrRelayMessage::ok(event.id, ok, message)]
897 }
898 NostrClientMessage::NegOpen { .. }
899 | NostrClientMessage::NegMsg { .. }
900 | NostrClientMessage::NegClose { .. } => {
901 vec![NostrRelayMessage::notice("negentropy not supported")]
902 }
903 }
904}
905
906async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
907 let text = response.as_json();
908 send_to_client(state, client_id, Message::Text(text)).await;
909}
910
911fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
912 let msg = parse_message(data).ok()?;
913 match msg {
914 DataMessage::Request(req) => {
915 if req.h.len() == 32 {
916 Some(DataMessage::Request(req))
917 } else {
918 None
919 }
920 }
921 DataMessage::Response(res) => {
922 if res.h.len() == 32 {
923 Some(DataMessage::Response(res))
924 } else {
925 None
926 }
927 }
928 DataMessage::QuoteRequest(req) => {
929 if req.h.len() == 32 {
930 Some(DataMessage::QuoteRequest(req))
931 } else {
932 None
933 }
934 }
935 DataMessage::QuoteResponse(res) => {
936 if res.h.len() == 32 {
937 Some(DataMessage::QuoteResponse(res))
938 } else {
939 None
940 }
941 }
942 DataMessage::Payment(req) => {
943 if req.h.len() == 32 {
944 Some(DataMessage::Payment(req))
945 } else {
946 None
947 }
948 }
949 DataMessage::PaymentAck(res) => {
950 if res.h.len() == 32 {
951 Some(DataMessage::PaymentAck(res))
952 } else {
953 None
954 }
955 }
956 DataMessage::Chunk(chunk) => {
957 if chunk.h.len() == 32 {
958 Some(DataMessage::Chunk(chunk))
959 } else {
960 None
961 }
962 }
963 }
964}
965
966async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
967 let hash_hex = hex_encode(&res.h);
968 let data = res.d.clone();
969 let hash_bytes = res.h.clone();
970
971 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
972 let mut seen = HashSet::new();
973 {
974 let pending = state.ws_relay.pending.lock().await;
975 for ((peer_id, request_id), p) in pending.iter() {
976 if *peer_id != client_id {
977 continue;
978 }
979 if p.hash != hash_hex {
980 continue;
981 }
982 if seen.insert((p.origin_id, *request_id)) {
983 responses.push((p.origin_id, *request_id, p.origin_protocol));
984 }
985 }
986 }
987
988 if responses.is_empty() {
989 return;
990 }
991
992 for (origin_id, request_id, protocol) in &responses {
993 match protocol {
994 WsProtocol::HashtreeJson => {
995 send_json(
996 state,
997 *origin_id,
998 WsResponse {
999 kind: "res",
1000 id: *request_id,
1001 hash: hash_hex.clone(),
1002 found: true,
1003 },
1004 )
1005 .await;
1006 send_binary(state, *origin_id, *request_id, data.clone()).await;
1007 }
1008 WsProtocol::HashtreeMsgpack => {
1009 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1010 }
1011 WsProtocol::Unknown => {}
1012 }
1013 }
1014
1015 let completed: HashSet<(u64, u32)> = responses
1016 .into_iter()
1017 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1018 .collect();
1019 let mut pending = state.ws_relay.pending.lock().await;
1020 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1021}
1022
1023async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1024 if let Ok(text) = serde_json::to_string(&response) {
1025 send_to_client(state, client_id, Message::Text(text)).await;
1026 }
1027}
1028
1029async fn send_msgpack_request(
1030 state: &AppState,
1031 client_id: u64,
1032 hash: &[u8],
1033) -> Result<(), rmp_serde::encode::Error> {
1034 let req = DataRequest {
1035 h: hash.to_vec(),
1036 htl: MAX_HTL,
1037 q: None,
1038 };
1039 let wire = encode_request(&req)?;
1040 send_to_client(state, client_id, Message::Binary(wire)).await;
1041 Ok(())
1042}
1043
1044async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1045 let res = DataResponse {
1046 h: hash.to_vec(),
1047 d: data.to_vec(),
1048 i: None,
1049 n: None,
1050 };
1051 if let Ok(wire) = encode_response(&res) {
1052 send_to_client(state, client_id, Message::Binary(wire)).await;
1053 }
1054}
1055
1056async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1057 let mut packet = Vec::with_capacity(4 + payload.len());
1058 packet.extend_from_slice(&request_id.to_le_bytes());
1059 packet.extend_from_slice(&payload);
1060 send_to_client(state, client_id, Message::Binary(packet)).await;
1061}
1062
1063async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1064 let sender = {
1065 let clients = state.ws_relay.clients.lock().await;
1066 clients.get(&client_id).cloned()
1067 };
1068 if let Some(tx) = sender {
1069 let _ = tx.send(msg);
1070 }
1071}
1072
1073async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1074 let mut protocols = state.ws_relay.client_protocols.lock().await;
1075 protocols.insert(client_id, protocol);
1076}
1077
1078#[cfg(test)]
1079mod tests {
1080 use super::*;
1081 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1082 use anyhow::Result;
1083 use futures::{SinkExt, StreamExt};
1084 use nostr::secp256k1::schnorr::Signature;
1085 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1086 use std::collections::HashSet;
1087 use std::sync::Arc;
1088 use tempfile::TempDir;
1089 use tokio::net::TcpListener;
1090 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1091
1092 #[test]
1093 fn parse_ws_text_message_detects_nostr_req() {
1094 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1095 match parse_ws_text_message(msg) {
1096 Some(WsTextMessage::Nostr(_)) => {}
1097 other => panic!("expected Nostr message, got {:?}", other),
1098 }
1099 }
1100
1101 #[test]
1102 fn parse_ws_text_message_detects_hashtree_request() {
1103 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1104 match parse_ws_text_message(msg) {
1105 Some(WsTextMessage::Hashtree(_)) => {}
1106 other => panic!("expected Hashtree message, got {:?}", other),
1107 }
1108 }
1109
1110 #[test]
1111 fn nostr_replies_for_req_is_eose() {
1112 let sub = SubscriptionId::new("sub-1");
1113 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1114 let replies = nostr_responses_for(&msg);
1115 assert_eq!(replies.len(), 1);
1116 match &replies[0] {
1117 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1118 other => panic!("expected EOSE, got {:?}", other),
1119 }
1120 }
1121
1122 #[test]
1123 fn nostr_replies_for_event_ok() {
1124 let keys = Keys::generate();
1125 let event = EventBuilder::new(Kind::TextNote, "hello", [])
1126 .to_event(&keys)
1127 .unwrap();
1128 let msg = NostrClientMessage::event(event.clone());
1129 let replies = nostr_responses_for(&msg);
1130 assert_eq!(replies.len(), 1);
1131 match &replies[0] {
1132 NostrRelayMessage::Ok {
1133 event_id, status, ..
1134 } => {
1135 assert_eq!(event_id, &event.id);
1136 assert!(*status);
1137 }
1138 other => panic!("expected OK, got {:?}", other),
1139 }
1140 }
1141
1142 #[test]
1143 fn nostr_replies_for_invalid_event_is_not_ok() {
1144 let keys = Keys::generate();
1145 let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1146 .to_event(&keys)
1147 .unwrap();
1148 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1149 let msg = NostrClientMessage::event(event);
1150 let replies = nostr_responses_for(&msg);
1151 assert_eq!(replies.len(), 1);
1152 match &replies[0] {
1153 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1154 other => panic!("expected OK=false, got {:?}", other),
1155 }
1156 }
1157
1158 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1159 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1160 let addr = listener.local_addr().expect("relay addr");
1161 tokio::spawn(async move {
1162 let (stream, _) = listener.accept().await.expect("accept relay");
1163 let ws = accept_async(stream).await.expect("accept websocket");
1164 let (mut write, mut read) = ws.split();
1165
1166 while let Some(Ok(message)) = read.next().await {
1167 let TungsteniteMessage::Text(text) = message else {
1168 continue;
1169 };
1170 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1171 continue;
1172 };
1173 if let NostrClientMessage::Req {
1174 subscription_id,
1175 filters,
1176 } = parsed
1177 {
1178 for event in events
1179 .iter()
1180 .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1181 {
1182 let _ = write
1183 .send(TungsteniteMessage::Text(
1184 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1185 .as_json()
1186 .into(),
1187 ))
1188 .await;
1189 }
1190 let _ = write
1191 .send(TungsteniteMessage::Text(
1192 NostrRelayMessage::eose(subscription_id).as_json().into(),
1193 ))
1194 .await;
1195 }
1196 }
1197 });
1198 format!("ws://{}", addr)
1199 }
1200
1201 fn test_app_state(
1202 tmp: &TempDir,
1203 relay: Arc<NostrRelay>,
1204 relay_url: String,
1205 ) -> Result<AppState> {
1206 let store = Arc::new(crate::storage::HashtreeStore::with_options(
1207 tmp.path(),
1208 None,
1209 128 * 1024 * 1024,
1210 )?);
1211 Ok(AppState {
1212 store,
1213 auth: None,
1214 peer_mode: crate::config::ServerMode::Normal,
1215 hash_get_enabled: true,
1216 webrtc_peers: None,
1217 ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1218 max_upload_bytes: 5 * 1024 * 1024,
1219 public_writes: true,
1220 allowed_pubkeys: HashSet::new(),
1221 upstream_blossom: Vec::new(),
1222 social_graph: None,
1223 social_graph_store: None,
1224 social_graph_root: None,
1225 socialgraph_snapshot_public: false,
1226 nostr_relay: Some(relay),
1227 nostr_relay_urls: vec![relay_url],
1228 tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1229 inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1230 std::collections::HashMap::new(),
1231 )),
1232 directory_listing_cache: Arc::new(std::sync::Mutex::new(
1233 super::super::auth::new_lookup_cache(),
1234 )),
1235 resolved_path_cache: Arc::new(std::sync::Mutex::new(
1236 super::super::auth::new_lookup_cache(),
1237 )),
1238 thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1239 super::super::auth::new_lookup_cache(),
1240 )),
1241 cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1242 })
1243 }
1244
1245 #[tokio::test]
1246 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1247 let tmp = TempDir::new()?;
1248 let graph_store = {
1249 let _guard = crate::socialgraph::test_lock();
1250 crate::socialgraph::open_social_graph_store_with_mapsize(
1251 tmp.path(),
1252 Some(128 * 1024 * 1024),
1253 )?
1254 };
1255 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1256 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1257 Arc::clone(&backend),
1258 0,
1259 HashSet::new(),
1260 ));
1261
1262 let keys = Keys::generate();
1263 let relay = Arc::new(NostrRelay::new(
1264 Arc::clone(&backend),
1265 tmp.path().to_path_buf(),
1266 HashSet::from([keys.public_key().to_hex()]),
1267 Some(access),
1268 NostrRelayConfig {
1269 spambox_db_max_bytes: 0,
1270 ..Default::default()
1271 },
1272 )?);
1273
1274 let event = EventBuilder::new(
1275 Kind::from(30078_u16),
1276 "",
1277 [
1278 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1279 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1280 ],
1281 )
1282 .to_event(&keys)?;
1283
1284 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1285 let filter = Filter::new()
1286 .authors(vec![event.pubkey])
1287 .kinds(vec![event.kind]);
1288 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1289 let client_id = 7_u64;
1290 let (tx, mut rx) = mpsc::unbounded_channel();
1291 state.ws_relay.clients.lock().await.insert(client_id, tx);
1292 let subscription_id = SubscriptionId::new("sub-1");
1293
1294 start_upstream_nostr_subscription(
1295 &state,
1296 client_id,
1297 subscription_id.clone(),
1298 vec![filter.clone()],
1299 )
1300 .await;
1301
1302 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1303 .await?
1304 .expect("forwarded upstream event");
1305 let Message::Text(text) = forwarded else {
1306 panic!("expected text event");
1307 };
1308 match NostrRelayMessage::from_json(text.as_str())? {
1309 NostrRelayMessage::Event {
1310 subscription_id: sid,
1311 event: forwarded_event,
1312 } => {
1313 assert_eq!(sid, subscription_id);
1314 assert_eq!(forwarded_event.id, event.id);
1315 }
1316 other => panic!("expected forwarded EVENT, got {:?}", other),
1317 }
1318
1319 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1320 .await?
1321 .expect("forwarded upstream eose");
1322 let Message::Text(eose_text) = eose else {
1323 panic!("expected text eose");
1324 };
1325 match NostrRelayMessage::from_json(eose_text.as_str())? {
1326 NostrRelayMessage::EndOfStoredEvents(sid) => {
1327 assert_eq!(sid, subscription_id);
1328 }
1329 other => panic!("expected forwarded EOSE, got {:?}", other),
1330 }
1331
1332 let events = relay.query_events(&filter, 10).await;
1333 assert_eq!(events.len(), 1);
1334 assert_eq!(events[0].id, event.id);
1335
1336 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1337 assert!(state
1338 .ws_relay
1339 .upstream_nostr_subscriptions
1340 .lock()
1341 .await
1342 .is_empty());
1343 Ok(())
1344 }
1345
1346 #[tokio::test]
1347 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1348 let tmp = TempDir::new()?;
1349 let graph_store = {
1350 let _guard = crate::socialgraph::test_lock();
1351 crate::socialgraph::open_social_graph_store_with_mapsize(
1352 tmp.path(),
1353 Some(128 * 1024 * 1024),
1354 )?
1355 };
1356 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1357 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1358 Arc::clone(&backend),
1359 0,
1360 HashSet::new(),
1361 ));
1362
1363 let keys = Keys::generate();
1364 let relay = Arc::new(NostrRelay::new(
1365 Arc::clone(&backend),
1366 tmp.path().to_path_buf(),
1367 HashSet::from([keys.public_key().to_hex()]),
1368 Some(access),
1369 NostrRelayConfig {
1370 spambox_db_max_bytes: 0,
1371 ..Default::default()
1372 },
1373 )?);
1374
1375 let event = EventBuilder::new(
1376 Kind::from(30078_u16),
1377 "",
1378 [
1379 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1380 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1381 ],
1382 )
1383 .to_event(&keys)?;
1384
1385 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1386 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1387 let client_id = 11_u64;
1388 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1389 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1390 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1391 relay.register_client(client_id, relay_tx, None).await;
1392
1393 let request = NostrClientMessage::req(
1394 SubscriptionId::new("feed"),
1395 vec![Filter::new()
1396 .authors(vec![event.pubkey])
1397 .kinds(vec![event.kind])],
1398 )
1399 .as_json();
1400
1401 handle_message(client_id, Message::Text(request.into()), &state).await;
1402
1403 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1404 .await?
1405 .expect("first forwarded message");
1406 let Message::Text(first_text) = first else {
1407 panic!("expected text event");
1408 };
1409 match NostrRelayMessage::from_json(first_text.as_str())? {
1410 NostrRelayMessage::Event {
1411 event: forwarded_event,
1412 ..
1413 } => {
1414 assert_eq!(forwarded_event.id, event.id);
1415 }
1416 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1417 }
1418
1419 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1420 .await?
1421 .expect("second forwarded message");
1422 let Message::Text(second_text) = second else {
1423 panic!("expected text eose");
1424 };
1425 match NostrRelayMessage::from_json(second_text.as_str())? {
1426 NostrRelayMessage::EndOfStoredEvents(sid) => {
1427 assert_eq!(sid, SubscriptionId::new("feed"));
1428 }
1429 other => panic!("expected aggregated EOSE, got {:?}", other),
1430 }
1431
1432 Ok(())
1433 }
1434
1435 #[tokio::test]
1436 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1437 let tmp = TempDir::new()?;
1438 let graph_store = {
1439 let _guard = crate::socialgraph::test_lock();
1440 crate::socialgraph::open_social_graph_store_with_mapsize(
1441 tmp.path(),
1442 Some(128 * 1024 * 1024),
1443 )?
1444 };
1445 let author_keys = Keys::generate();
1446 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1447 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1448 Arc::clone(&backend),
1449 0,
1450 HashSet::from([author_keys.public_key().to_hex()]),
1451 ));
1452 let relay = Arc::new(NostrRelay::new(
1453 Arc::clone(&backend),
1454 tmp.path().to_path_buf(),
1455 HashSet::from([author_keys.public_key().to_hex()]),
1456 Some(access),
1457 NostrRelayConfig {
1458 spambox_db_max_bytes: 0,
1459 ..Default::default()
1460 },
1461 )?);
1462
1463 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1464 let listener = TcpListener::bind("127.0.0.1:0").await?;
1465 let addr = listener.local_addr()?;
1466 let client_pubkey = author_keys.public_key().to_hex();
1467 let app = axum::Router::new().route(
1468 "/ws",
1469 axum::routing::get({
1470 let state = state.clone();
1471 let client_pubkey = client_pubkey.clone();
1472 move |ws: WebSocketUpgrade| {
1473 let state = state.clone();
1474 let client_pubkey = client_pubkey.clone();
1475 async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1476 }
1477 }),
1478 );
1479 tokio::spawn(async move {
1480 let _ = axum::serve(listener, app).await;
1481 });
1482
1483 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1484 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1485 .to_event(&author_keys)?;
1486 socket
1487 .send(TungsteniteMessage::Text(
1488 NostrClientMessage::event(event.clone()).as_json().into(),
1489 ))
1490 .await?;
1491
1492 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1493 .await?
1494 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1495 let TungsteniteMessage::Text(text) = reply else {
1496 anyhow::bail!("expected text publish ack");
1497 };
1498
1499 match NostrRelayMessage::from_json(text.as_str())? {
1500 NostrRelayMessage::Ok {
1501 event_id, status, ..
1502 } => {
1503 assert_eq!(event_id, event.id);
1504 assert!(status);
1505 }
1506 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1507 }
1508
1509 let stored = relay
1510 .query_events(
1511 &Filter::new()
1512 .authors(vec![event.pubkey])
1513 .kinds(vec![event.kind]),
1514 10,
1515 )
1516 .await;
1517 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1518 Ok(())
1519 }
1520
1521 #[tokio::test]
1522 async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1523 let tmp = TempDir::new()?;
1524 let graph_store = {
1525 let _guard = crate::socialgraph::test_lock();
1526 crate::socialgraph::open_social_graph_store_with_mapsize(
1527 tmp.path(),
1528 Some(128 * 1024 * 1024),
1529 )?
1530 };
1531 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1532 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1533 Arc::clone(&backend),
1534 0,
1535 HashSet::new(),
1536 ));
1537 let relay = Arc::new(NostrRelay::new(
1538 Arc::clone(&backend),
1539 tmp.path().to_path_buf(),
1540 HashSet::new(),
1541 Some(access),
1542 NostrRelayConfig {
1543 spambox_db_max_bytes: 0,
1544 spambox_max_reqs_per_min: 1,
1545 ..Default::default()
1546 },
1547 )?);
1548
1549 let state = test_app_state(&tmp, relay, String::new())?;
1550 let listener = TcpListener::bind("127.0.0.1:0").await?;
1551 let addr = listener.local_addr()?;
1552 let app = axum::Router::new().route(
1553 "/ws",
1554 axum::routing::get({
1555 let state = state.clone();
1556 move |ws: WebSocketUpgrade| {
1557 let state = state.clone();
1558 async move { ws_data_with_client_pubkey(state, ws, None) }
1559 }
1560 }),
1561 );
1562 tokio::spawn(async move {
1563 let _ = axum::serve(listener, app).await;
1564 });
1565
1566 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1567 socket
1568 .send(TungsteniteMessage::Text(
1569 NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1570 .as_json()
1571 .into(),
1572 ))
1573 .await?;
1574
1575 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1576 .await?
1577 .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1578 let TungsteniteMessage::Text(first_text) = first else {
1579 anyhow::bail!("expected text EOSE reply");
1580 };
1581 match NostrRelayMessage::from_json(first_text.as_str())? {
1582 NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1583 assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
1584 }
1585 other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1586 }
1587
1588 socket
1589 .send(TungsteniteMessage::Text(
1590 NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1591 .as_json()
1592 .into(),
1593 ))
1594 .await?;
1595
1596 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1597 .await?
1598 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1599 let TungsteniteMessage::Text(second_text) = second else {
1600 anyhow::bail!("expected text CLOSED reply");
1601 };
1602 let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1603 assert_eq!(
1604 second_value,
1605 serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1606 );
1607
1608 Ok(())
1609 }
1610
1611 #[tokio::test]
1612 async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1613 let tmp = TempDir::new()?;
1614 let graph_store = {
1615 let _guard = crate::socialgraph::test_lock();
1616 crate::socialgraph::open_social_graph_store_with_mapsize(
1617 tmp.path(),
1618 Some(128 * 1024 * 1024),
1619 )?
1620 };
1621 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1622 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1623 Arc::clone(&backend),
1624 0,
1625 HashSet::new(),
1626 ));
1627 let relay = Arc::new(NostrRelay::new(
1628 Arc::clone(&backend),
1629 tmp.path().to_path_buf(),
1630 HashSet::new(),
1631 Some(access),
1632 NostrRelayConfig {
1633 spambox_db_max_bytes: 0,
1634 spambox_max_events_per_min: 1,
1635 ..Default::default()
1636 },
1637 )?);
1638
1639 let state = test_app_state(&tmp, relay, String::new())?;
1640 let listener = TcpListener::bind("127.0.0.1:0").await?;
1641 let addr = listener.local_addr()?;
1642 let app = axum::Router::new().route(
1643 "/ws",
1644 axum::routing::get({
1645 let state = state.clone();
1646 move |ws: WebSocketUpgrade| {
1647 let state = state.clone();
1648 async move { ws_data_with_client_pubkey(state, ws, None) }
1649 }
1650 }),
1651 );
1652 tokio::spawn(async move {
1653 let _ = axum::serve(listener, app).await;
1654 });
1655
1656 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1657 let author_keys = Keys::generate();
1658 let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
1659 let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;
1660
1661 socket
1662 .send(TungsteniteMessage::Text(
1663 NostrClientMessage::event(event_a.clone()).as_json().into(),
1664 ))
1665 .await?;
1666
1667 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1668 .await?
1669 .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1670 let TungsteniteMessage::Text(first_text) = first else {
1671 anyhow::bail!("expected text publish ack");
1672 };
1673 match NostrRelayMessage::from_json(first_text.as_str())? {
1674 NostrRelayMessage::Ok {
1675 event_id,
1676 status,
1677 message,
1678 } => {
1679 assert_eq!(event_id, event_a.id);
1680 assert!(status);
1681 assert_eq!(message, "spambox");
1682 }
1683 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1684 }
1685
1686 socket
1687 .send(TungsteniteMessage::Text(
1688 NostrClientMessage::event(event_b.clone()).as_json().into(),
1689 ))
1690 .await?;
1691
1692 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1693 .await?
1694 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1695 let TungsteniteMessage::Text(second_text) = second else {
1696 anyhow::bail!("expected text publish ack");
1697 };
1698 match NostrRelayMessage::from_json(second_text.as_str())? {
1699 NostrRelayMessage::Ok {
1700 event_id,
1701 status,
1702 message,
1703 } => {
1704 assert_eq!(event_id, event_b.id);
1705 assert!(!status);
1706 assert_eq!(message, "rate limited");
1707 }
1708 other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1709 }
1710
1711 Ok(())
1712 }
1713}