1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21 nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
28#[derive(Debug, Deserialize)]
29#[serde(tag = "type")]
30enum WsClientMessage {
31 #[serde(rename = "req")]
32 Request { id: u32, hash: String },
33 #[serde(rename = "res")]
34 Response { id: u32, hash: String, found: bool },
35}
36
37#[derive(Debug)]
38enum WsTextMessage {
39 Hashtree(WsClientMessage),
40 Nostr(NostrClientMessage),
41}
42
43#[derive(Debug, Deserialize, Serialize)]
44struct WsRequest {
45 #[serde(rename = "type")]
46 kind: String,
47 id: u32,
48 hash: String,
49}
50
51#[derive(Debug, Serialize)]
52struct WsResponse {
53 #[serde(rename = "type")]
54 kind: &'static str,
55 id: u32,
56 hash: String,
57 found: bool,
58}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61 ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65 state: AppState,
66 ws: WebSocketUpgrade,
67 client_pubkey: Option<String>,
68) -> impl IntoResponse {
69 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
72async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
73 let client_id = state
76 .nostr_relay
77 .as_ref()
78 .map(|relay| relay.next_client_id())
79 .unwrap_or_else(|| state.ws_relay.next_id());
80 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
81
82 {
83 let mut clients = state.ws_relay.clients.lock().await;
84 clients.insert(client_id, tx);
85 }
86 {
87 let mut protocols = state.ws_relay.client_protocols.lock().await;
88 protocols.insert(client_id, WsProtocol::HashtreeJson);
89 }
90
91 let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
92 let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
93 relay
94 .register_client(client_id, nostr_tx, client_pubkey.clone())
95 .await;
96 Some(nostr_rx)
97 } else {
98 None
99 };
100
101 let (mut sender, mut receiver) = socket.split();
102 loop {
103 tokio::select! {
104 maybe_msg = rx.recv() => {
105 let Some(msg) = maybe_msg else {
106 break;
107 };
108 if sender.send(msg).await.is_err() {
109 break;
110 }
111 }
112 maybe_text = async {
113 match &mut nostr_rx {
114 Some(rx) => rx.recv().await,
115 None => std::future::pending().await,
116 }
117 } => {
118 let Some(text) = maybe_text else {
119 nostr_rx = None;
120 continue;
121 };
122 if sender.send(Message::Text(text)).await.is_err() {
123 break;
124 }
125 }
126 maybe_incoming = receiver.next() => {
127 match maybe_incoming {
128 Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
129 Some(Err(_)) | None => break,
130 }
131 }
132 }
133 }
134
135 close_all_upstream_nostr_subscriptions(&state, client_id).await;
136
137 {
138 let mut clients = state.ws_relay.clients.lock().await;
139 clients.remove(&client_id);
140 }
141 {
142 let mut protocols = state.ws_relay.client_protocols.lock().await;
143 protocols.remove(&client_id);
144 }
145 {
146 let mut pending = state.ws_relay.pending.lock().await;
147 pending.retain(|(peer_id, _), _| *peer_id != client_id);
148 }
149
150 if let Some(relay) = &state.nostr_relay {
151 relay.unregister_client(client_id).await;
152 }
153}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
156 let trimmed = text.trim_start();
157 if trimmed.starts_with('[') {
158 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159 return Some(WsTextMessage::Nostr(msg));
160 }
161 }
162
163 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164 return Some(WsTextMessage::Hashtree(msg));
165 }
166
167 None
168}
169async fn close_upstream_nostr_subscription(
170 state: &AppState,
171 client_id: u64,
172 subscription_id: &SubscriptionId,
173) {
174 let key = (client_id, subscription_id.to_string());
175 let subscription = {
176 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177 subscriptions.remove(&key)
178 };
179 if let Some(subscription) = subscription {
180 let _ = subscription.close_tx.send(true);
181 for task in subscription.tasks {
182 task.abort();
183 }
184 }
185 state
186 .ws_relay
187 .upstream_pending_eose
188 .lock()
189 .await
190 .remove(&key);
191 state
192 .ws_relay
193 .upstream_seen_events
194 .lock()
195 .await
196 .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200 let keys = {
201 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202 subscriptions
203 .keys()
204 .filter(|(id, _)| *id == client_id)
205 .cloned()
206 .collect::<Vec<_>>()
207 };
208 for (_, sub_id) in keys {
209 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210 }
211}
212
213async fn forward_upstream_nostr_message(
214 state: &AppState,
215 client_id: u64,
216 subscription_id: &SubscriptionId,
217 text: &str,
218) {
219 let Ok(message) = NostrRelayMessage::from_json(text) else {
220 return;
221 };
222
223 match message {
224 NostrRelayMessage::Event {
225 subscription_id: sid,
226 event,
227 } if sid == *subscription_id => {
228 let event = *event;
229 let key = (client_id, subscription_id.to_string());
230 let event_id = event.id.to_hex();
231 let inserted = {
232 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233 seen_events.entry(key).or_default().insert(event_id)
234 };
235 if !inserted {
236 return;
237 }
238 if let Some(relay) = &state.nostr_relay {
239 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240 }
241 send_nostr(
242 state,
243 client_id,
244 NostrRelayMessage::event(subscription_id.clone(), event),
245 )
246 .await;
247 }
248 NostrRelayMessage::Closed {
249 subscription_id: sid,
250 message,
251 } if sid == *subscription_id => {
252 send_nostr(
253 state,
254 client_id,
255 NostrRelayMessage::closed(subscription_id.clone(), message),
256 )
257 .await;
258 }
259 _ => {}
260 }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264 state: &AppState,
265 client_id: u64,
266 subscription_id: &SubscriptionId,
267) {
268 let key = (client_id, subscription_id.to_string());
269 let should_send_eose = {
270 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271 let Some(remaining) = pending.get_mut(&key) else {
272 return;
273 };
274 if *remaining > 0 {
275 *remaining -= 1;
276 }
277 if *remaining == 0 {
278 pending.remove(&key);
279 true
280 } else {
281 false
282 }
283 };
284
285 if should_send_eose {
286 send_nostr(
287 state,
288 client_id,
289 NostrRelayMessage::eose(subscription_id.clone()),
290 )
291 .await;
292 }
293}
294
295async fn run_upstream_nostr_subscription(
296 state: AppState,
297 client_id: u64,
298 relay_url: String,
299 subscription_id: SubscriptionId,
300 filters: Vec<NostrFilter>,
301 mut close_rx: watch::Receiver<bool>,
302) {
303 let mut relay_complete = false;
304 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305 tracing::warn!(
306 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307 client_id,
308 subscription_id,
309 relay_url,
310 );
311 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312 return;
313 };
314 let (mut write, mut read) = socket.split();
315 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316 state
317 .ws_relay
318 .note_upstream_relay_send(request.as_bytes().len());
319 if write
320 .send(TungsteniteMessage::Text(request.into()))
321 .await
322 .is_err()
323 {
324 tracing::warn!(
325 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326 client_id,
327 subscription_id,
328 relay_url,
329 );
330 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331 return;
332 }
333
334 loop {
335 tokio::select! {
336 _ = close_rx.changed() => {
337 if *close_rx.borrow() {
338 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339 state
340 .ws_relay
341 .note_upstream_relay_send(close.as_bytes().len());
342 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343 let _ = write.close().await;
344 break;
345 }
346 }
347 message = read.next() => {
348 match message {
349 Some(Ok(TungsteniteMessage::Text(text))) => {
350 state
351 .ws_relay
352 .note_upstream_relay_receive(text.as_bytes().len());
353 if matches!(
354 NostrRelayMessage::from_json(text.as_str()),
355 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
356 ) {
357 if !relay_complete {
358 relay_complete = true;
359 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360 }
361 continue;
362 }
363 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364 }
365 Some(Ok(TungsteniteMessage::Ping(payload))) => {
366 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367 }
368 Some(Ok(TungsteniteMessage::Close(_))) | None => {
369 if !relay_complete {
370 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371 }
372 break;
373 }
374 Some(Err(_)) => {
375 if !relay_complete {
376 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377 }
378 break;
379 }
380 _ => {}
381 }
382 }
383 }
384 }
385}
386
387async fn start_upstream_nostr_subscription(
388 state: &AppState,
389 client_id: u64,
390 subscription_id: SubscriptionId,
391 filters: Vec<NostrFilter>,
392) -> usize {
393 let memory_before = process_memory_snapshot();
394 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
395 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
396 tracing::info!(
397 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
398 client_id,
399 subscription_id,
400 state.nostr_relay_urls.len(),
401 filters.len(),
402 );
403 return 0;
404 }
405
406 let mut relay_urls = Vec::new();
407 let mut seen = HashSet::new();
408 for relay in &state.nostr_relay_urls {
409 let relay = relay.trim();
410 if relay.is_empty() || !seen.insert(relay.to_string()) {
411 continue;
412 }
413 relay_urls.push(relay.to_string());
414 }
415 if relay_urls.is_empty() {
416 tracing::info!(
417 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
418 client_id,
419 subscription_id,
420 );
421 return 0;
422 }
423
424 tracing::info!(
425 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
426 client_id,
427 subscription_id,
428 relay_urls.len(),
429 );
430
431 let filter_summary = nostr_filters_summary(&filters);
432 let key = (client_id, subscription_id.to_string());
433 state
434 .ws_relay
435 .upstream_seen_events
436 .lock()
437 .await
438 .insert(key.clone(), HashSet::new());
439 state
440 .ws_relay
441 .upstream_pending_eose
442 .lock()
443 .await
444 .insert(key.clone(), relay_urls.len());
445
446 let (close_tx, close_rx) = watch::channel(false);
447 let mut tasks = Vec::new();
448 for relay_url in &relay_urls {
449 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
450 state.clone(),
451 client_id,
452 relay_url.clone(),
453 subscription_id.clone(),
454 filters.clone(),
455 close_rx.clone(),
456 )));
457 }
458
459 tracing::info!(
460 target: "hashtree_cli::server::ws_relay::upstream",
461 client_id,
462 subscription_id = %subscription_id,
463 relays = relay_urls.len(),
464 tasks = tasks.len(),
465 filters = filters.len(),
466 filter = %filter_summary,
467 memory_before = ?memory_before,
468 memory_after = ?process_memory_snapshot(),
469 "upstream nostr relay tasks spawned",
470 );
471
472 state
473 .ws_relay
474 .upstream_nostr_subscriptions
475 .lock()
476 .await
477 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
478 relay_urls.len()
479}
480
481async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
482 match msg {
483 Message::Text(text) => {
484 if let Some(msg) = parse_ws_text_message(&text) {
485 match msg {
486 WsTextMessage::Hashtree(msg) => {
487 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
488 match msg {
489 WsClientMessage::Request { id, hash } => {
490 handle_request(
491 client_id,
492 id,
493 hash,
494 WsProtocol::HashtreeJson,
495 state,
496 )
497 .await;
498 }
499 WsClientMessage::Response { id, hash, found } => {
500 handle_response(client_id, id, hash, found, state).await;
501 }
502 }
503 }
504 WsTextMessage::Nostr(msg) => {
505 if let Some(relay) = &state.nostr_relay {
506 match msg {
507 NostrClientMessage::Req {
508 subscription_id,
509 filters,
510 } => {
511 let local_events = match relay
512 .register_subscription_query(
513 client_id,
514 subscription_id.clone(),
515 filters.clone(),
516 )
517 .await
518 {
519 Ok(events) => events,
520 Err(message) => {
521 send_nostr(
522 state,
523 client_id,
524 NostrRelayMessage::closed(subscription_id, message),
525 )
526 .await;
527 return;
528 }
529 };
530
531 let upstream_relays = start_upstream_nostr_subscription(
532 state,
533 client_id,
534 subscription_id.clone(),
535 filters,
536 )
537 .await;
538 if upstream_relays > 0 {
539 let key = (client_id, subscription_id.to_string());
540 let mut seen_events =
541 state.ws_relay.upstream_seen_events.lock().await;
542 seen_events.entry(key).or_default().extend(
543 local_events.iter().map(|event| event.id.to_hex()),
544 );
545 }
546 for event in local_events {
547 send_nostr(
548 state,
549 client_id,
550 NostrRelayMessage::event(
551 subscription_id.clone(),
552 event,
553 ),
554 )
555 .await;
556 }
557 trim_process_allocations();
558 if upstream_relays == 0 {
559 send_nostr(
560 state,
561 client_id,
562 NostrRelayMessage::eose(subscription_id),
563 )
564 .await;
565 }
566 }
567 NostrClientMessage::Close(subscription_id) => {
568 close_upstream_nostr_subscription(
569 state,
570 client_id,
571 &subscription_id,
572 )
573 .await;
574 relay
575 .handle_client_message(
576 client_id,
577 NostrClientMessage::Close(subscription_id.clone()),
578 )
579 .await;
580 }
581 other => {
582 relay.handle_client_message(client_id, other).await;
583 }
584 }
585 } else {
586 handle_nostr_message(client_id, msg, state).await;
587 }
588 }
589 }
590 }
591 }
592 Message::Binary(data) => {
593 handle_binary(client_id, data, state).await;
594 }
595 Message::Close(_) => {}
596 _ => {}
597 }
598}
599
600async fn handle_request(
601 client_id: u64,
602 request_id: u32,
603 hash: String,
604 origin_protocol: WsProtocol,
605 state: &AppState,
606) {
607 let hash_hex = hash.to_lowercase();
608 let hash_bytes = match from_hex(&hash_hex) {
609 Ok(bytes) => bytes,
610 Err(_) => {
611 if origin_protocol == WsProtocol::HashtreeJson {
612 send_json(
613 state,
614 client_id,
615 WsResponse {
616 kind: "res",
617 id: request_id,
618 hash,
619 found: false,
620 },
621 )
622 .await;
623 }
624 return;
625 }
626 };
627
628 if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
629 match origin_protocol {
630 WsProtocol::HashtreeJson => {
631 send_json(
632 state,
633 client_id,
634 WsResponse {
635 kind: "res",
636 id: request_id,
637 hash: hash.clone(),
638 found: true,
639 },
640 )
641 .await;
642 send_binary(state, client_id, request_id, data).await;
643 }
644 WsProtocol::HashtreeMsgpack => {
645 send_msgpack_response(state, client_id, &hash_bytes, &data).await;
646 }
647 WsProtocol::Unknown => {}
648 }
649 return;
650 }
651
652 let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
653 let clients = state.ws_relay.clients.lock().await;
654 let protocols = state.ws_relay.client_protocols.lock().await;
655 clients
656 .iter()
657 .filter(|(id, _)| **id != client_id)
658 .filter_map(|(id, tx)| {
659 let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
660 match protocol {
661 WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
662 Some((*id, tx.clone(), protocol))
663 }
664 WsProtocol::Unknown => None,
665 }
666 })
667 .collect()
668 };
669
670 if peers.is_empty() {
671 if origin_protocol == WsProtocol::HashtreeJson {
672 send_json(
673 state,
674 client_id,
675 WsResponse {
676 kind: "res",
677 id: request_id,
678 hash,
679 found: false,
680 },
681 )
682 .await;
683 }
684 return;
685 }
686
687 {
688 let mut pending = state.ws_relay.pending.lock().await;
689 for (peer_id, _, _) in &peers {
690 pending.insert(
691 (*peer_id, request_id),
692 PendingRequest {
693 origin_id: client_id,
694 hash: hash.clone(),
695 found: false,
696 origin_protocol,
697 },
698 );
699 }
700 }
701
702 let request_text = serde_json::to_string(&WsRequest {
703 kind: "req".to_string(),
704 id: request_id,
705 hash: hash.clone(),
706 })
707 .unwrap_or_else(|_| String::new());
708 for (peer_id, tx, protocol) in peers {
709 match protocol {
710 WsProtocol::HashtreeMsgpack => {
711 let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
712 }
713 WsProtocol::HashtreeJson => {
714 let _ = tx.send(Message::Text(request_text.clone()));
715 }
716 WsProtocol::Unknown => {}
717 }
718 }
719
720 let timeout_state = state.clone();
721 let timeout_hash = hash.clone();
722 tokio::spawn(async move {
723 tokio::time::sleep(Duration::from_millis(1500)).await;
724 let mut pending = timeout_state.ws_relay.pending.lock().await;
725 let still_pending = pending
726 .iter()
727 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
728 let already_found = pending
729 .iter()
730 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
731 if !still_pending || already_found {
732 return;
733 }
734 let origin_protocol = pending
735 .iter()
736 .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
737 .map(|(_, p)| p.origin_protocol)
738 .unwrap_or(WsProtocol::HashtreeJson);
739 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
740 drop(pending);
741 if origin_protocol == WsProtocol::HashtreeJson {
742 send_json(
743 &timeout_state,
744 client_id,
745 WsResponse {
746 kind: "res",
747 id: request_id,
748 hash: timeout_hash,
749 found: false,
750 },
751 )
752 .await;
753 }
754 });
755}
756
757async fn handle_response(
758 client_id: u64,
759 request_id: u32,
760 _hash: String,
761 found: bool,
762 state: &AppState,
763) {
764 let pending_entry = {
765 let pending = state.ws_relay.pending.lock().await;
766 pending
767 .get(&(client_id, request_id))
768 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
769 };
770
771 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
772 return;
773 };
774
775 if already_found && !found {
776 let mut pending = state.ws_relay.pending.lock().await;
777 pending.remove(&(client_id, request_id));
778 return;
779 }
780
781 if found {
782 let mut pending = state.ws_relay.pending.lock().await;
783 for ((_, id), p) in pending.iter_mut() {
784 if *id == request_id && p.origin_id == origin_id {
785 p.found = true;
786 }
787 }
788 drop(pending);
789 if origin_protocol == WsProtocol::HashtreeJson {
790 send_json(
791 state,
792 origin_id,
793 WsResponse {
794 kind: "res",
795 id: request_id,
796 hash: pending_hash,
797 found: true,
798 },
799 )
800 .await;
801 }
802 return;
803 }
804
805 let mut pending = state.ws_relay.pending.lock().await;
806 pending.remove(&(client_id, request_id));
807 let has_remaining = pending
808 .iter()
809 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
810 drop(pending);
811
812 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
813 send_json(
814 state,
815 origin_id,
816 WsResponse {
817 kind: "res",
818 id: request_id,
819 hash: pending_hash,
820 found: false,
821 },
822 )
823 .await;
824 }
825}
826
827async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
828 if let Some(msg) = parse_msgpack_message(&data) {
829 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
830 match msg {
831 DataMessage::Request(req) => {
832 let hash_hex = hex_encode(&req.h);
833 let request_id = state.ws_relay.next_request_id();
834 handle_request(
835 client_id,
836 request_id,
837 hash_hex,
838 WsProtocol::HashtreeMsgpack,
839 state,
840 )
841 .await;
842 }
843 DataMessage::Response(res) => {
844 handle_msgpack_response(client_id, res, state).await;
845 }
846 DataMessage::QuoteRequest(_)
847 | DataMessage::QuoteResponse(_)
848 | DataMessage::Payment(_)
849 | DataMessage::PaymentAck(_)
850 | DataMessage::Chunk(_)
851 | DataMessage::PeerHints(_) => {}
852 }
853 return;
854 }
855
856 if data.len() < 4 {
858 return;
859 }
860 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
861 let pending_entry = {
862 let pending = state.ws_relay.pending.lock().await;
863 pending
864 .get(&(client_id, request_id))
865 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
866 };
867 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
868 return;
869 };
870
871 match origin_protocol {
872 WsProtocol::HashtreeJson => {
873 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
874 }
875 WsProtocol::HashtreeMsgpack => {
876 let Ok(hash_bytes) = from_hex(&hash_hex) else {
877 return;
878 };
879 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
880 }
881 WsProtocol::Unknown => {}
882 }
883
884 let mut pending = state.ws_relay.pending.lock().await;
885 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
886}
887
888async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
889 let replies = nostr_responses_for(&msg);
890 for reply in replies {
891 send_nostr(state, client_id, reply).await;
892 }
893}
894
895fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
896 match msg {
897 NostrClientMessage::Event(event) => {
898 let ok = event.verify().is_ok();
899 let message = if ok { "" } else { "invalid: signature" };
900 vec![NostrRelayMessage::ok(event.id, ok, message)]
901 }
902 NostrClientMessage::Req {
903 subscription_id, ..
904 } => {
905 vec![NostrRelayMessage::eose(subscription_id.clone())]
906 }
907 NostrClientMessage::Count {
908 subscription_id, ..
909 } => {
910 vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
911 }
912 NostrClientMessage::Close(_) => Vec::new(),
913 NostrClientMessage::Auth(event) => {
914 let ok = event.verify().is_ok();
915 let message = if ok { "" } else { "invalid auth" };
916 vec![NostrRelayMessage::ok(event.id, ok, message)]
917 }
918 NostrClientMessage::NegOpen { .. }
919 | NostrClientMessage::NegMsg { .. }
920 | NostrClientMessage::NegClose { .. } => {
921 vec![NostrRelayMessage::notice("negentropy not supported")]
922 }
923 }
924}
925
926async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
927 let text = response.as_json();
928 send_to_client(state, client_id, Message::Text(text)).await;
929}
930
931fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
932 let msg = parse_message(data).ok()?;
933 match msg {
934 DataMessage::Request(req) => {
935 if req.h.len() == 32 {
936 Some(DataMessage::Request(req))
937 } else {
938 None
939 }
940 }
941 DataMessage::Response(res) => {
942 if res.h.len() == 32 {
943 Some(DataMessage::Response(res))
944 } else {
945 None
946 }
947 }
948 DataMessage::QuoteRequest(req) => {
949 if req.h.len() == 32 {
950 Some(DataMessage::QuoteRequest(req))
951 } else {
952 None
953 }
954 }
955 DataMessage::QuoteResponse(res) => {
956 if res.h.len() == 32 {
957 Some(DataMessage::QuoteResponse(res))
958 } else {
959 None
960 }
961 }
962 DataMessage::Payment(req) => {
963 if req.h.len() == 32 {
964 Some(DataMessage::Payment(req))
965 } else {
966 None
967 }
968 }
969 DataMessage::PaymentAck(res) => {
970 if res.h.len() == 32 {
971 Some(DataMessage::PaymentAck(res))
972 } else {
973 None
974 }
975 }
976 DataMessage::Chunk(chunk) => {
977 if chunk.h.len() == 32 {
978 Some(DataMessage::Chunk(chunk))
979 } else {
980 None
981 }
982 }
983 DataMessage::PeerHints(_) => None,
984 }
985}
986
987async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
988 let hash_hex = hex_encode(&res.h);
989 let data = res.d.clone();
990 let hash_bytes = res.h.clone();
991
992 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
993 let mut seen = HashSet::new();
994 {
995 let pending = state.ws_relay.pending.lock().await;
996 for ((peer_id, request_id), p) in pending.iter() {
997 if *peer_id != client_id {
998 continue;
999 }
1000 if p.hash != hash_hex {
1001 continue;
1002 }
1003 if seen.insert((p.origin_id, *request_id)) {
1004 responses.push((p.origin_id, *request_id, p.origin_protocol));
1005 }
1006 }
1007 }
1008
1009 if responses.is_empty() {
1010 return;
1011 }
1012
1013 for (origin_id, request_id, protocol) in &responses {
1014 match protocol {
1015 WsProtocol::HashtreeJson => {
1016 send_json(
1017 state,
1018 *origin_id,
1019 WsResponse {
1020 kind: "res",
1021 id: *request_id,
1022 hash: hash_hex.clone(),
1023 found: true,
1024 },
1025 )
1026 .await;
1027 send_binary(state, *origin_id, *request_id, data.clone()).await;
1028 }
1029 WsProtocol::HashtreeMsgpack => {
1030 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1031 }
1032 WsProtocol::Unknown => {}
1033 }
1034 }
1035
1036 let completed: HashSet<(u64, u32)> = responses
1037 .into_iter()
1038 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1039 .collect();
1040 let mut pending = state.ws_relay.pending.lock().await;
1041 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1042}
1043
1044async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1045 if let Ok(text) = serde_json::to_string(&response) {
1046 send_to_client(state, client_id, Message::Text(text)).await;
1047 }
1048}
1049
1050async fn send_msgpack_request(
1051 state: &AppState,
1052 client_id: u64,
1053 hash: &[u8],
1054) -> Result<(), rmp_serde::encode::Error> {
1055 let req = DataRequest {
1056 h: hash.to_vec(),
1057 htl: MAX_HTL,
1058 q: None,
1059 };
1060 let wire = encode_request(&req)?;
1061 send_to_client(state, client_id, Message::Binary(wire)).await;
1062 Ok(())
1063}
1064
1065async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1066 let res = DataResponse {
1067 h: hash.to_vec(),
1068 d: data.to_vec(),
1069 i: None,
1070 n: None,
1071 };
1072 if let Ok(wire) = encode_response(&res) {
1073 send_to_client(state, client_id, Message::Binary(wire)).await;
1074 }
1075}
1076
1077async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1078 let mut packet = Vec::with_capacity(4 + payload.len());
1079 packet.extend_from_slice(&request_id.to_le_bytes());
1080 packet.extend_from_slice(&payload);
1081 send_to_client(state, client_id, Message::Binary(packet)).await;
1082}
1083
1084async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1085 let sender = {
1086 let clients = state.ws_relay.clients.lock().await;
1087 clients.get(&client_id).cloned()
1088 };
1089 if let Some(tx) = sender {
1090 let _ = tx.send(msg);
1091 }
1092}
1093
1094async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1095 let mut protocols = state.ws_relay.client_protocols.lock().await;
1096 protocols.insert(client_id, protocol);
1097}
1098
1099#[cfg(test)]
1100mod tests {
1101 use super::*;
1102 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1103 use anyhow::Result;
1104 use futures::{SinkExt, StreamExt};
1105 use nostr::secp256k1::schnorr::Signature;
1106 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1107 use std::collections::HashSet;
1108 use std::sync::Arc;
1109 use tempfile::TempDir;
1110 use tokio::net::TcpListener;
1111 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1112
1113 #[test]
1114 fn parse_ws_text_message_detects_nostr_req() {
1115 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1116 match parse_ws_text_message(msg) {
1117 Some(WsTextMessage::Nostr(_)) => {}
1118 other => panic!("expected Nostr message, got {:?}", other),
1119 }
1120 }
1121
1122 #[test]
1123 fn parse_ws_text_message_detects_hashtree_request() {
1124 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1125 match parse_ws_text_message(msg) {
1126 Some(WsTextMessage::Hashtree(_)) => {}
1127 other => panic!("expected Hashtree message, got {:?}", other),
1128 }
1129 }
1130
1131 #[test]
1132 fn nostr_replies_for_req_is_eose() {
1133 let sub = SubscriptionId::new("sub-1");
1134 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1135 let replies = nostr_responses_for(&msg);
1136 assert_eq!(replies.len(), 1);
1137 match &replies[0] {
1138 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1139 other => panic!("expected EOSE, got {:?}", other),
1140 }
1141 }
1142
1143 #[test]
1144 fn nostr_replies_for_event_ok() {
1145 let keys = Keys::generate();
1146 let event = EventBuilder::new(Kind::TextNote, "hello", [])
1147 .to_event(&keys)
1148 .unwrap();
1149 let msg = NostrClientMessage::event(event.clone());
1150 let replies = nostr_responses_for(&msg);
1151 assert_eq!(replies.len(), 1);
1152 match &replies[0] {
1153 NostrRelayMessage::Ok {
1154 event_id, status, ..
1155 } => {
1156 assert_eq!(event_id, &event.id);
1157 assert!(*status);
1158 }
1159 other => panic!("expected OK, got {:?}", other),
1160 }
1161 }
1162
1163 #[test]
1164 fn nostr_replies_for_invalid_event_is_not_ok() {
1165 let keys = Keys::generate();
1166 let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1167 .to_event(&keys)
1168 .unwrap();
1169 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1170 let msg = NostrClientMessage::event(event);
1171 let replies = nostr_responses_for(&msg);
1172 assert_eq!(replies.len(), 1);
1173 match &replies[0] {
1174 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1175 other => panic!("expected OK=false, got {:?}", other),
1176 }
1177 }
1178
1179 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1180 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1181 let addr = listener.local_addr().expect("relay addr");
1182 tokio::spawn(async move {
1183 let (stream, _) = listener.accept().await.expect("accept relay");
1184 let ws = accept_async(stream).await.expect("accept websocket");
1185 let (mut write, mut read) = ws.split();
1186
1187 while let Some(Ok(message)) = read.next().await {
1188 let TungsteniteMessage::Text(text) = message else {
1189 continue;
1190 };
1191 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1192 continue;
1193 };
1194 if let NostrClientMessage::Req {
1195 subscription_id,
1196 filters,
1197 } = parsed
1198 {
1199 for event in events
1200 .iter()
1201 .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1202 {
1203 let _ = write
1204 .send(TungsteniteMessage::Text(
1205 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1206 .as_json()
1207 .into(),
1208 ))
1209 .await;
1210 }
1211 let _ = write
1212 .send(TungsteniteMessage::Text(
1213 NostrRelayMessage::eose(subscription_id).as_json().into(),
1214 ))
1215 .await;
1216 }
1217 }
1218 });
1219 format!("ws://{}", addr)
1220 }
1221
1222 fn test_app_state(
1223 tmp: &TempDir,
1224 relay: Arc<NostrRelay>,
1225 relay_url: String,
1226 ) -> Result<AppState> {
1227 let store = Arc::new(crate::storage::HashtreeStore::with_options(
1228 tmp.path(),
1229 None,
1230 128 * 1024 * 1024,
1231 )?);
1232 Ok(AppState {
1233 store,
1234 auth: None,
1235 peer_mode: crate::config::ServerMode::Normal,
1236 hash_get_enabled: true,
1237 webrtc_peers: None,
1238 ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1239 max_upload_bytes: 5 * 1024 * 1024,
1240 public_writes: true,
1241 allowed_pubkeys: HashSet::new(),
1242 upstream_blossom: Vec::new(),
1243 social_graph: None,
1244 social_graph_store: None,
1245 social_graph_root: None,
1246 socialgraph_snapshot_public: false,
1247 nostr_relay: Some(relay),
1248 nostr_relay_urls: vec![relay_url],
1249 tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1250 inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1251 std::collections::HashMap::new(),
1252 )),
1253 directory_listing_cache: Arc::new(std::sync::Mutex::new(
1254 super::super::auth::new_lookup_cache(),
1255 )),
1256 resolved_path_cache: Arc::new(std::sync::Mutex::new(
1257 super::super::auth::new_lookup_cache(),
1258 )),
1259 thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1260 super::super::auth::new_lookup_cache(),
1261 )),
1262 cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1263 })
1264 }
1265
1266 #[tokio::test]
1267 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1268 let tmp = TempDir::new()?;
1269 let graph_store = {
1270 let _guard = crate::socialgraph::test_lock();
1271 crate::socialgraph::open_social_graph_store_with_mapsize(
1272 tmp.path(),
1273 Some(128 * 1024 * 1024),
1274 )?
1275 };
1276 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1277 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1278 Arc::clone(&backend),
1279 0,
1280 HashSet::new(),
1281 ));
1282
1283 let keys = Keys::generate();
1284 let relay = Arc::new(NostrRelay::new(
1285 Arc::clone(&backend),
1286 tmp.path().to_path_buf(),
1287 HashSet::from([keys.public_key().to_hex()]),
1288 Some(access),
1289 NostrRelayConfig {
1290 spambox_db_max_bytes: 0,
1291 ..Default::default()
1292 },
1293 )?);
1294
1295 let event = EventBuilder::new(
1296 Kind::from(30078_u16),
1297 "",
1298 [
1299 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1300 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1301 ],
1302 )
1303 .to_event(&keys)?;
1304
1305 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1306 let filter = Filter::new()
1307 .authors(vec![event.pubkey])
1308 .kinds(vec![event.kind]);
1309 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1310 let client_id = 7_u64;
1311 let (tx, mut rx) = mpsc::unbounded_channel();
1312 state.ws_relay.clients.lock().await.insert(client_id, tx);
1313 let subscription_id = SubscriptionId::new("sub-1");
1314
1315 start_upstream_nostr_subscription(
1316 &state,
1317 client_id,
1318 subscription_id.clone(),
1319 vec![filter.clone()],
1320 )
1321 .await;
1322
1323 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1324 .await?
1325 .expect("forwarded upstream event");
1326 let Message::Text(text) = forwarded else {
1327 panic!("expected text event");
1328 };
1329 match NostrRelayMessage::from_json(text.as_str())? {
1330 NostrRelayMessage::Event {
1331 subscription_id: sid,
1332 event: forwarded_event,
1333 } => {
1334 assert_eq!(sid, subscription_id);
1335 assert_eq!(forwarded_event.id, event.id);
1336 }
1337 other => panic!("expected forwarded EVENT, got {:?}", other),
1338 }
1339
1340 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1341 .await?
1342 .expect("forwarded upstream eose");
1343 let Message::Text(eose_text) = eose else {
1344 panic!("expected text eose");
1345 };
1346 match NostrRelayMessage::from_json(eose_text.as_str())? {
1347 NostrRelayMessage::EndOfStoredEvents(sid) => {
1348 assert_eq!(sid, subscription_id);
1349 }
1350 other => panic!("expected forwarded EOSE, got {:?}", other),
1351 }
1352
1353 let events = relay.query_events(&filter, 10).await;
1354 assert_eq!(events.len(), 1);
1355 assert_eq!(events[0].id, event.id);
1356
1357 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1358 assert!(state
1359 .ws_relay
1360 .upstream_nostr_subscriptions
1361 .lock()
1362 .await
1363 .is_empty());
1364 Ok(())
1365 }
1366
1367 #[tokio::test]
1368 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1369 let tmp = TempDir::new()?;
1370 let graph_store = {
1371 let _guard = crate::socialgraph::test_lock();
1372 crate::socialgraph::open_social_graph_store_with_mapsize(
1373 tmp.path(),
1374 Some(128 * 1024 * 1024),
1375 )?
1376 };
1377 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1378 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1379 Arc::clone(&backend),
1380 0,
1381 HashSet::new(),
1382 ));
1383
1384 let keys = Keys::generate();
1385 let relay = Arc::new(NostrRelay::new(
1386 Arc::clone(&backend),
1387 tmp.path().to_path_buf(),
1388 HashSet::from([keys.public_key().to_hex()]),
1389 Some(access),
1390 NostrRelayConfig {
1391 spambox_db_max_bytes: 0,
1392 ..Default::default()
1393 },
1394 )?);
1395
1396 let event = EventBuilder::new(
1397 Kind::from(30078_u16),
1398 "",
1399 [
1400 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1401 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1402 ],
1403 )
1404 .to_event(&keys)?;
1405
1406 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1407 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1408 let client_id = 11_u64;
1409 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1410 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1411 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1412 relay.register_client(client_id, relay_tx, None).await;
1413
1414 let request = NostrClientMessage::req(
1415 SubscriptionId::new("feed"),
1416 vec![Filter::new()
1417 .authors(vec![event.pubkey])
1418 .kinds(vec![event.kind])],
1419 )
1420 .as_json();
1421
1422 handle_message(client_id, Message::Text(request.into()), &state).await;
1423
1424 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1425 .await?
1426 .expect("first forwarded message");
1427 let Message::Text(first_text) = first else {
1428 panic!("expected text event");
1429 };
1430 match NostrRelayMessage::from_json(first_text.as_str())? {
1431 NostrRelayMessage::Event {
1432 event: forwarded_event,
1433 ..
1434 } => {
1435 assert_eq!(forwarded_event.id, event.id);
1436 }
1437 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1438 }
1439
1440 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1441 .await?
1442 .expect("second forwarded message");
1443 let Message::Text(second_text) = second else {
1444 panic!("expected text eose");
1445 };
1446 match NostrRelayMessage::from_json(second_text.as_str())? {
1447 NostrRelayMessage::EndOfStoredEvents(sid) => {
1448 assert_eq!(sid, SubscriptionId::new("feed"));
1449 }
1450 other => panic!("expected aggregated EOSE, got {:?}", other),
1451 }
1452
1453 Ok(())
1454 }
1455
1456 #[tokio::test]
1457 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1458 let tmp = TempDir::new()?;
1459 let graph_store = {
1460 let _guard = crate::socialgraph::test_lock();
1461 crate::socialgraph::open_social_graph_store_with_mapsize(
1462 tmp.path(),
1463 Some(128 * 1024 * 1024),
1464 )?
1465 };
1466 let author_keys = Keys::generate();
1467 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1468 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1469 Arc::clone(&backend),
1470 0,
1471 HashSet::from([author_keys.public_key().to_hex()]),
1472 ));
1473 let relay = Arc::new(NostrRelay::new(
1474 Arc::clone(&backend),
1475 tmp.path().to_path_buf(),
1476 HashSet::from([author_keys.public_key().to_hex()]),
1477 Some(access),
1478 NostrRelayConfig {
1479 spambox_db_max_bytes: 0,
1480 ..Default::default()
1481 },
1482 )?);
1483
1484 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1485 let listener = TcpListener::bind("127.0.0.1:0").await?;
1486 let addr = listener.local_addr()?;
1487 let client_pubkey = author_keys.public_key().to_hex();
1488 let app = axum::Router::new().route(
1489 "/ws",
1490 axum::routing::get({
1491 let state = state.clone();
1492 let client_pubkey = client_pubkey.clone();
1493 move |ws: WebSocketUpgrade| {
1494 let state = state.clone();
1495 let client_pubkey = client_pubkey.clone();
1496 async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1497 }
1498 }),
1499 );
1500 tokio::spawn(async move {
1501 let _ = axum::serve(listener, app).await;
1502 });
1503
1504 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1505 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1506 .to_event(&author_keys)?;
1507 socket
1508 .send(TungsteniteMessage::Text(
1509 NostrClientMessage::event(event.clone()).as_json().into(),
1510 ))
1511 .await?;
1512
1513 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1514 .await?
1515 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1516 let TungsteniteMessage::Text(text) = reply else {
1517 anyhow::bail!("expected text publish ack");
1518 };
1519
1520 match NostrRelayMessage::from_json(text.as_str())? {
1521 NostrRelayMessage::Ok {
1522 event_id, status, ..
1523 } => {
1524 assert_eq!(event_id, event.id);
1525 assert!(status);
1526 }
1527 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1528 }
1529
1530 let stored = relay
1531 .query_events(
1532 &Filter::new()
1533 .authors(vec![event.pubkey])
1534 .kinds(vec![event.kind]),
1535 10,
1536 )
1537 .await;
1538 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1539 Ok(())
1540 }
1541
1542 #[tokio::test]
1543 async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1544 let tmp = TempDir::new()?;
1545 let graph_store = {
1546 let _guard = crate::socialgraph::test_lock();
1547 crate::socialgraph::open_social_graph_store_with_mapsize(
1548 tmp.path(),
1549 Some(128 * 1024 * 1024),
1550 )?
1551 };
1552 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1553 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1554 Arc::clone(&backend),
1555 0,
1556 HashSet::new(),
1557 ));
1558 let relay = Arc::new(NostrRelay::new(
1559 Arc::clone(&backend),
1560 tmp.path().to_path_buf(),
1561 HashSet::new(),
1562 Some(access),
1563 NostrRelayConfig {
1564 spambox_db_max_bytes: 0,
1565 spambox_max_reqs_per_min: 1,
1566 ..Default::default()
1567 },
1568 )?);
1569
1570 let state = test_app_state(&tmp, relay, String::new())?;
1571 let listener = TcpListener::bind("127.0.0.1:0").await?;
1572 let addr = listener.local_addr()?;
1573 let app = axum::Router::new().route(
1574 "/ws",
1575 axum::routing::get({
1576 let state = state.clone();
1577 move |ws: WebSocketUpgrade| {
1578 let state = state.clone();
1579 async move { ws_data_with_client_pubkey(state, ws, None) }
1580 }
1581 }),
1582 );
1583 tokio::spawn(async move {
1584 let _ = axum::serve(listener, app).await;
1585 });
1586
1587 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1588 socket
1589 .send(TungsteniteMessage::Text(
1590 NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1591 .as_json()
1592 .into(),
1593 ))
1594 .await?;
1595
1596 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1597 .await?
1598 .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1599 let TungsteniteMessage::Text(first_text) = first else {
1600 anyhow::bail!("expected text EOSE reply");
1601 };
1602 match NostrRelayMessage::from_json(first_text.as_str())? {
1603 NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1604 assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
1605 }
1606 other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1607 }
1608
1609 socket
1610 .send(TungsteniteMessage::Text(
1611 NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1612 .as_json()
1613 .into(),
1614 ))
1615 .await?;
1616
1617 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1618 .await?
1619 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1620 let TungsteniteMessage::Text(second_text) = second else {
1621 anyhow::bail!("expected text CLOSED reply");
1622 };
1623 let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1624 assert_eq!(
1625 second_value,
1626 serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1627 );
1628
1629 Ok(())
1630 }
1631
1632 #[tokio::test]
1633 async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1634 let tmp = TempDir::new()?;
1635 let graph_store = {
1636 let _guard = crate::socialgraph::test_lock();
1637 crate::socialgraph::open_social_graph_store_with_mapsize(
1638 tmp.path(),
1639 Some(128 * 1024 * 1024),
1640 )?
1641 };
1642 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1643 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1644 Arc::clone(&backend),
1645 0,
1646 HashSet::new(),
1647 ));
1648 let relay = Arc::new(NostrRelay::new(
1649 Arc::clone(&backend),
1650 tmp.path().to_path_buf(),
1651 HashSet::new(),
1652 Some(access),
1653 NostrRelayConfig {
1654 spambox_db_max_bytes: 0,
1655 spambox_max_events_per_min: 1,
1656 ..Default::default()
1657 },
1658 )?);
1659
1660 let state = test_app_state(&tmp, relay, String::new())?;
1661 let listener = TcpListener::bind("127.0.0.1:0").await?;
1662 let addr = listener.local_addr()?;
1663 let app = axum::Router::new().route(
1664 "/ws",
1665 axum::routing::get({
1666 let state = state.clone();
1667 move |ws: WebSocketUpgrade| {
1668 let state = state.clone();
1669 async move { ws_data_with_client_pubkey(state, ws, None) }
1670 }
1671 }),
1672 );
1673 tokio::spawn(async move {
1674 let _ = axum::serve(listener, app).await;
1675 });
1676
1677 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1678 let author_keys = Keys::generate();
1679 let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
1680 let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;
1681
1682 socket
1683 .send(TungsteniteMessage::Text(
1684 NostrClientMessage::event(event_a.clone()).as_json().into(),
1685 ))
1686 .await?;
1687
1688 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1689 .await?
1690 .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1691 let TungsteniteMessage::Text(first_text) = first else {
1692 anyhow::bail!("expected text publish ack");
1693 };
1694 match NostrRelayMessage::from_json(first_text.as_str())? {
1695 NostrRelayMessage::Ok {
1696 event_id,
1697 status,
1698 message,
1699 } => {
1700 assert_eq!(event_id, event_a.id);
1701 assert!(status);
1702 assert_eq!(message, "spambox");
1703 }
1704 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1705 }
1706
1707 socket
1708 .send(TungsteniteMessage::Text(
1709 NostrClientMessage::event(event_b.clone()).as_json().into(),
1710 ))
1711 .await?;
1712
1713 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1714 .await?
1715 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1716 let TungsteniteMessage::Text(second_text) = second else {
1717 anyhow::bail!("expected text publish ack");
1718 };
1719 match NostrRelayMessage::from_json(second_text.as_str())? {
1720 NostrRelayMessage::Ok {
1721 event_id,
1722 status,
1723 message,
1724 } => {
1725 assert_eq!(event_id, event_b.id);
1726 assert!(!status);
1727 assert_eq!(message, "rate limited");
1728 }
1729 other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1730 }
1731
1732 Ok(())
1733 }
1734}