1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21 nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
/// Hashtree JSON messages a client may send over the data WebSocket.
///
/// The enum is internally tagged on `"type"`:
/// `{"type":"req","id":..,"hash":".."}` asks for a blob and
/// `{"type":"res","id":..,"hash":"..","found":bool}` answers a request this
/// server previously forwarded to the client.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum WsClientMessage {
    /// Request for the blob whose hex hash is `hash`; `id` correlates replies.
    #[serde(rename = "req")]
    Request { id: u32, hash: String },
    /// Reply to a forwarded request: whether this client holds `hash`.
    #[serde(rename = "res")]
    Response { id: u32, hash: String, found: bool },
}
36
/// Result of classifying an inbound text frame (see `parse_ws_text_message`).
#[derive(Debug)]
enum WsTextMessage {
    /// A hashtree JSON `req`/`res` message.
    Hashtree(WsClientMessage),
    /// A Nostr client message (`["REQ", ...]`, `["EVENT", ...]`, ...).
    Nostr(NostrClientMessage),
}
42
/// JSON blob request forwarded to hashtree-JSON peers.
///
/// Serialized as `{"type": <kind>, "id": .., "hash": ".."}`; the request
/// forwarding code in this file always sets `kind` to `"req"`.
#[derive(Debug, Deserialize, Serialize)]
struct WsRequest {
    #[serde(rename = "type")]
    kind: String,
    id: u32,
    hash: String,
}
50
/// JSON reply telling a hashtree-JSON client whether a requested blob was found.
///
/// Serialized as `{"type": <kind>, "id": .., "hash": "..", "found": bool}`; every
/// construction site in this file uses `kind: "res"`. When `found` is true the blob
/// bytes follow in a separate binary frame prefixed with the little-endian `id`.
#[derive(Debug, Serialize)]
struct WsResponse {
    #[serde(rename = "type")]
    kind: &'static str,
    id: u32,
    hash: String,
    found: bool,
}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61 ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65 state: AppState,
66 ws: WebSocketUpgrade,
67 client_pubkey: Option<String>,
68) -> impl IntoResponse {
69 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
/// Drives one upgraded data WebSocket connection until either side goes away.
///
/// Setup: allocates a client id, registers the client's outbound channel and an
/// initial `HashtreeJson` protocol tag, and, when a local Nostr relay is
/// configured, registers the client there with its own text channel.
///
/// The main loop multiplexes three sources:
/// - frames queued for this client on `rx` (written to the socket),
/// - text frames pushed by the local Nostr relay on `nostr_rx`,
/// - frames read from the socket, handled inline by `handle_message`.
///
/// Teardown: closes the client's upstream Nostr subscriptions, removes the
/// client and protocol registrations, drops pending requests in which this
/// client was the peer being asked, and unregisters from the Nostr relay.
async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
    // Prefer the Nostr relay's id allocator when a relay exists (presumably so both
    // registries agree on this client's id); otherwise use the WS relay's counter.
    let client_id = state
        .nostr_relay
        .as_ref()
        .map(|relay| relay.next_client_id())
        .unwrap_or_else(|| state.ws_relay.next_id());
    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();

    {
        let mut clients = state.ws_relay.clients.lock().await;
        clients.insert(client_id, tx);
    }
    {
        // Default protocol until the client's first frame tells us otherwise.
        let mut protocols = state.ws_relay.client_protocols.lock().await;
        protocols.insert(client_id, WsProtocol::HashtreeJson);
    }

    let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
        let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
        relay
            .register_client(client_id, nostr_tx, client_pubkey.clone())
            .await;
        Some(nostr_rx)
    } else {
        None
    };

    let (mut sender, mut receiver) = socket.split();
    loop {
        tokio::select! {
            maybe_msg = rx.recv() => {
                // NOTE(review): the matching sender stays in `clients` while this loop
                // runs, so `None` here is presumably unreachable in practice.
                let Some(msg) = maybe_msg else {
                    break;
                };
                if sender.send(msg).await.is_err() {
                    break;
                }
            }
            maybe_text = async {
                // With no relay (or after its channel closed) this branch never
                // resolves, so it cannot win the select.
                match &mut nostr_rx {
                    Some(rx) => rx.recv().await,
                    None => std::future::pending().await,
                }
            } => {
                let Some(text) = maybe_text else {
                    // Relay dropped its sender: stop polling it but keep the socket.
                    nostr_rx = None;
                    continue;
                };
                if sender.send(Message::Text(text)).await.is_err() {
                    break;
                }
            }
            maybe_incoming = receiver.next() => {
                match maybe_incoming {
                    Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
                    Some(Err(_)) | None => break,
                }
            }
        }
    }

    close_all_upstream_nostr_subscriptions(&state, client_id).await;

    {
        let mut clients = state.ws_relay.clients.lock().await;
        clients.remove(&client_id);
    }
    {
        let mut protocols = state.ws_relay.client_protocols.lock().await;
        protocols.remove(&client_id);
    }
    {
        // Only entries where this client was the *peer* are dropped; requests it
        // originated are left to their timeout tasks.
        let mut pending = state.ws_relay.pending.lock().await;
        pending.retain(|(peer_id, _), _| *peer_id != client_id);
    }

    if let Some(relay) = &state.nostr_relay {
        relay.unregister_client(client_id).await;
    }
}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
156 let trimmed = text.trim_start();
157 if trimmed.starts_with('[') {
158 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159 return Some(WsTextMessage::Nostr(msg));
160 }
161 }
162
163 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164 return Some(WsTextMessage::Hashtree(msg));
165 }
166
167 None
168}
169async fn close_upstream_nostr_subscription(
170 state: &AppState,
171 client_id: u64,
172 subscription_id: &SubscriptionId,
173) {
174 let key = (client_id, subscription_id.to_string());
175 let subscription = {
176 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177 subscriptions.remove(&key)
178 };
179 if let Some(subscription) = subscription {
180 let _ = subscription.close_tx.send(true);
181 for task in subscription.tasks {
182 task.abort();
183 }
184 }
185 state
186 .ws_relay
187 .upstream_pending_eose
188 .lock()
189 .await
190 .remove(&key);
191 state
192 .ws_relay
193 .upstream_seen_events
194 .lock()
195 .await
196 .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200 let keys = {
201 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202 subscriptions
203 .keys()
204 .filter(|(id, _)| *id == client_id)
205 .cloned()
206 .collect::<Vec<_>>()
207 };
208 for (_, sub_id) in keys {
209 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210 }
211}
212
213async fn forward_upstream_nostr_message(
214 state: &AppState,
215 client_id: u64,
216 subscription_id: &SubscriptionId,
217 text: &str,
218) {
219 let Ok(message) = NostrRelayMessage::from_json(text) else {
220 return;
221 };
222
223 match message {
224 NostrRelayMessage::Event {
225 subscription_id: sid,
226 event,
227 } if sid == *subscription_id => {
228 let event = *event;
229 let key = (client_id, subscription_id.to_string());
230 let event_id = event.id.to_hex();
231 let inserted = {
232 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233 seen_events.entry(key).or_default().insert(event_id)
234 };
235 if !inserted {
236 return;
237 }
238 if let Some(relay) = &state.nostr_relay {
239 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240 }
241 send_nostr(
242 state,
243 client_id,
244 NostrRelayMessage::event(subscription_id.clone(), event),
245 )
246 .await;
247 }
248 NostrRelayMessage::Closed {
249 subscription_id: sid,
250 message,
251 } if sid == *subscription_id => {
252 send_nostr(
253 state,
254 client_id,
255 NostrRelayMessage::closed(subscription_id.clone(), message),
256 )
257 .await;
258 }
259 _ => {}
260 }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264 state: &AppState,
265 client_id: u64,
266 subscription_id: &SubscriptionId,
267) {
268 let key = (client_id, subscription_id.to_string());
269 let should_send_eose = {
270 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271 let Some(remaining) = pending.get_mut(&key) else {
272 return;
273 };
274 if *remaining > 0 {
275 *remaining -= 1;
276 }
277 if *remaining == 0 {
278 pending.remove(&key);
279 true
280 } else {
281 false
282 }
283 };
284
285 if should_send_eose {
286 send_nostr(
287 state,
288 client_id,
289 NostrRelayMessage::eose(subscription_id.clone()),
290 )
291 .await;
292 }
293}
294
295async fn run_upstream_nostr_subscription(
296 state: AppState,
297 client_id: u64,
298 relay_url: String,
299 subscription_id: SubscriptionId,
300 filters: Vec<NostrFilter>,
301 mut close_rx: watch::Receiver<bool>,
302) {
303 let mut relay_complete = false;
304 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305 tracing::warn!(
306 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307 client_id,
308 subscription_id,
309 relay_url,
310 );
311 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312 return;
313 };
314 let (mut write, mut read) = socket.split();
315 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316 state
317 .ws_relay
318 .note_upstream_relay_send(request.as_bytes().len());
319 if write
320 .send(TungsteniteMessage::Text(request.into()))
321 .await
322 .is_err()
323 {
324 tracing::warn!(
325 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326 client_id,
327 subscription_id,
328 relay_url,
329 );
330 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331 return;
332 }
333
334 loop {
335 tokio::select! {
336 _ = close_rx.changed() => {
337 if *close_rx.borrow() {
338 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339 state
340 .ws_relay
341 .note_upstream_relay_send(close.as_bytes().len());
342 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343 let _ = write.close().await;
344 break;
345 }
346 }
347 message = read.next() => {
348 match message {
349 Some(Ok(TungsteniteMessage::Text(text))) => {
350 state
351 .ws_relay
352 .note_upstream_relay_receive(text.as_bytes().len());
353 if matches!(
354 NostrRelayMessage::from_json(text.as_str()),
355 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
356 ) {
357 if !relay_complete {
358 relay_complete = true;
359 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360 }
361 continue;
362 }
363 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364 }
365 Some(Ok(TungsteniteMessage::Ping(payload))) => {
366 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367 }
368 Some(Ok(TungsteniteMessage::Close(_))) | None => {
369 if !relay_complete {
370 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371 }
372 break;
373 }
374 Some(Err(_)) => {
375 if !relay_complete {
376 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377 }
378 break;
379 }
380 _ => {}
381 }
382 }
383 }
384 }
385}
386
387async fn start_upstream_nostr_subscription(
388 state: &AppState,
389 client_id: u64,
390 subscription_id: SubscriptionId,
391 filters: Vec<NostrFilter>,
392) -> usize {
393 let memory_before = process_memory_snapshot();
394 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
395 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
396 tracing::info!(
397 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
398 client_id,
399 subscription_id,
400 state.nostr_relay_urls.len(),
401 filters.len(),
402 );
403 return 0;
404 }
405
406 let mut relay_urls = Vec::new();
407 let mut seen = HashSet::new();
408 for relay in &state.nostr_relay_urls {
409 let relay = relay.trim();
410 if relay.is_empty() || !seen.insert(relay.to_string()) {
411 continue;
412 }
413 relay_urls.push(relay.to_string());
414 }
415 if relay_urls.is_empty() {
416 tracing::info!(
417 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
418 client_id,
419 subscription_id,
420 );
421 return 0;
422 }
423
424 tracing::info!(
425 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
426 client_id,
427 subscription_id,
428 relay_urls.len(),
429 );
430
431 let filter_summary = nostr_filters_summary(&filters);
432 let key = (client_id, subscription_id.to_string());
433 state
434 .ws_relay
435 .upstream_seen_events
436 .lock()
437 .await
438 .insert(key.clone(), HashSet::new());
439 state
440 .ws_relay
441 .upstream_pending_eose
442 .lock()
443 .await
444 .insert(key.clone(), relay_urls.len());
445
446 let (close_tx, close_rx) = watch::channel(false);
447 let mut tasks = Vec::new();
448 for relay_url in &relay_urls {
449 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
450 state.clone(),
451 client_id,
452 relay_url.clone(),
453 subscription_id.clone(),
454 filters.clone(),
455 close_rx.clone(),
456 )));
457 }
458
459 tracing::info!(
460 target: "hashtree_cli::server::ws_relay::upstream",
461 client_id,
462 subscription_id = %subscription_id,
463 relays = relay_urls.len(),
464 tasks = tasks.len(),
465 filters = filters.len(),
466 filter = %filter_summary,
467 memory_before = ?memory_before,
468 memory_after = ?process_memory_snapshot(),
469 "upstream nostr relay tasks spawned",
470 );
471
472 state
473 .ws_relay
474 .upstream_nostr_subscriptions
475 .lock()
476 .await
477 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
478 relay_urls.len()
479}
480
481async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
482 match msg {
483 Message::Text(text) => {
484 if let Some(msg) = parse_ws_text_message(&text) {
485 match msg {
486 WsTextMessage::Hashtree(msg) => {
487 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
488 match msg {
489 WsClientMessage::Request { id, hash } => {
490 handle_request(
491 client_id,
492 id,
493 hash,
494 WsProtocol::HashtreeJson,
495 state,
496 )
497 .await;
498 }
499 WsClientMessage::Response { id, hash, found } => {
500 handle_response(client_id, id, hash, found, state).await;
501 }
502 }
503 }
504 WsTextMessage::Nostr(msg) => {
505 if let Some(relay) = &state.nostr_relay {
506 match msg {
507 NostrClientMessage::Req {
508 subscription_id,
509 filters,
510 } => {
511 let local_events = match relay
512 .register_subscription_query(
513 client_id,
514 subscription_id.clone(),
515 filters.clone(),
516 )
517 .await
518 {
519 Ok(events) => events,
520 Err(message) => {
521 send_nostr(
522 state,
523 client_id,
524 NostrRelayMessage::closed(subscription_id, message),
525 )
526 .await;
527 return;
528 }
529 };
530
531 let upstream_relays = start_upstream_nostr_subscription(
532 state,
533 client_id,
534 subscription_id.clone(),
535 filters,
536 )
537 .await;
538 if upstream_relays > 0 {
539 let key = (client_id, subscription_id.to_string());
540 let mut seen_events =
541 state.ws_relay.upstream_seen_events.lock().await;
542 seen_events.entry(key).or_default().extend(
543 local_events.iter().map(|event| event.id.to_hex()),
544 );
545 }
546 for event in local_events {
547 send_nostr(
548 state,
549 client_id,
550 NostrRelayMessage::event(
551 subscription_id.clone(),
552 event,
553 ),
554 )
555 .await;
556 }
557 trim_process_allocations();
558 if upstream_relays == 0 {
559 send_nostr(
560 state,
561 client_id,
562 NostrRelayMessage::eose(subscription_id),
563 )
564 .await;
565 }
566 }
567 NostrClientMessage::Close(subscription_id) => {
568 close_upstream_nostr_subscription(
569 state,
570 client_id,
571 &subscription_id,
572 )
573 .await;
574 relay
575 .handle_client_message(
576 client_id,
577 NostrClientMessage::Close(subscription_id.clone()),
578 )
579 .await;
580 }
581 other => {
582 relay.handle_client_message(client_id, other).await;
583 }
584 }
585 } else {
586 handle_nostr_message(client_id, msg, state).await;
587 }
588 }
589 }
590 }
591 }
592 Message::Binary(data) => {
593 handle_binary(client_id, data, state).await;
594 }
595 Message::Close(_) => {}
596 _ => {}
597 }
598}
599
/// Serves a blob request from `client_id`, locally if possible, otherwise by
/// asking every other connected hashtree peer.
///
/// * `hash` is lower-cased before hex decoding. An undecodable hash gets an
///   immediate `found: false` for hashtree-JSON origins; msgpack origins get
///   no reply.
/// * If the local store has the blob, it goes straight back in the origin's
///   protocol: a JSON `res` plus an id-prefixed binary frame, or a msgpack
///   response.
/// * Otherwise a `PendingRequest` keyed by `(peer_id, request_id)` is recorded
///   for every other client whose protocol is known. The request is then sent
///   to each peer in that peer's protocol. With no such peers, JSON origins get
///   `found: false` immediately.
/// * A detached task waits 1.5 s. If entries for this origin request still
///   exist and none was marked found, they are removed and JSON origins are
///   told `found: false`.
///
/// NOTE(review): if a peer answered `found: true` but its payload never
/// arrives, the timeout task returns early and those pending entries are
/// never removed here.
///
/// NOTE(review): JSON origins choose their own `request_id`, and
/// `pending.insert` overwrites on key collision. Two origins using the same id
/// towards the same peer therefore clobber each other's entry — confirm
/// whether ids are meant to be unique.
async fn handle_request(
    client_id: u64,
    request_id: u32,
    hash: String,
    origin_protocol: WsProtocol,
    state: &AppState,
) {
    let hash_hex = hash.to_lowercase();
    let hash_bytes = match from_hex(&hash_hex) {
        Ok(bytes) => bytes,
        Err(_) => {
            if origin_protocol == WsProtocol::HashtreeJson {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash,
                        found: false,
                    },
                )
                .await;
            }
            return;
        }
    };

    // NOTE(review): `get_blob` is called without `.await`, so it appears to be a
    // synchronous store read running on the async executor — confirm it is cheap.
    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
        match origin_protocol {
            WsProtocol::HashtreeJson => {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash: hash.clone(),
                        found: true,
                    },
                )
                .await;
                send_binary(state, client_id, request_id, data).await;
            }
            WsProtocol::HashtreeMsgpack => {
                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
            }
            WsProtocol::Unknown => {}
        }
        return;
    }

    // Snapshot every other client that speaks a known hashtree protocol.
    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
        let clients = state.ws_relay.clients.lock().await;
        let protocols = state.ws_relay.client_protocols.lock().await;
        clients
            .iter()
            .filter(|(id, _)| **id != client_id)
            .filter_map(|(id, tx)| {
                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
                match protocol {
                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
                        Some((*id, tx.clone(), protocol))
                    }
                    WsProtocol::Unknown => None,
                }
            })
            .collect()
    };

    if peers.is_empty() {
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash,
                    found: false,
                },
            )
            .await;
        }
        return;
    }

    // Record the request before sending it, so an immediate reply finds its entry.
    {
        let mut pending = state.ws_relay.pending.lock().await;
        for (peer_id, _, _) in &peers {
            pending.insert(
                (*peer_id, request_id),
                PendingRequest {
                    origin_id: client_id,
                    hash: hash.clone(),
                    found: false,
                    origin_protocol,
                },
            );
        }
    }

    let request_text = serde_json::to_string(&WsRequest {
        kind: "req".to_string(),
        id: request_id,
        hash: hash.clone(),
    })
    .unwrap_or_else(|_| String::new());
    for (peer_id, tx, protocol) in peers {
        match protocol {
            // Msgpack peers get no request id on the wire; their responses are
            // matched back by hash in `handle_msgpack_response`.
            WsProtocol::HashtreeMsgpack => {
                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
            }
            WsProtocol::HashtreeJson => {
                let _ = tx.send(Message::Text(request_text.clone()));
            }
            WsProtocol::Unknown => {}
        }
    }

    // Give peers 1.5 s; if nobody found the blob, clear the entries and report a miss.
    let timeout_state = state.clone();
    let timeout_hash = hash.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(1500)).await;
        let mut pending = timeout_state.ws_relay.pending.lock().await;
        let still_pending = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
        let already_found = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
        if !still_pending || already_found {
            return;
        }
        let origin_protocol = pending
            .iter()
            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
            .map(|(_, p)| p.origin_protocol)
            .unwrap_or(WsProtocol::HashtreeJson);
        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
        drop(pending);
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                &timeout_state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash: timeout_hash,
                    found: false,
                },
            )
            .await;
        }
    });
}
756
757async fn handle_response(
758 client_id: u64,
759 request_id: u32,
760 _hash: String,
761 found: bool,
762 state: &AppState,
763) {
764 let pending_entry = {
765 let pending = state.ws_relay.pending.lock().await;
766 pending
767 .get(&(client_id, request_id))
768 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
769 };
770
771 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
772 return;
773 };
774
775 if already_found && !found {
776 let mut pending = state.ws_relay.pending.lock().await;
777 pending.remove(&(client_id, request_id));
778 return;
779 }
780
781 if found {
782 let mut pending = state.ws_relay.pending.lock().await;
783 for ((_, id), p) in pending.iter_mut() {
784 if *id == request_id && p.origin_id == origin_id {
785 p.found = true;
786 }
787 }
788 drop(pending);
789 if origin_protocol == WsProtocol::HashtreeJson {
790 send_json(
791 state,
792 origin_id,
793 WsResponse {
794 kind: "res",
795 id: request_id,
796 hash: pending_hash,
797 found: true,
798 },
799 )
800 .await;
801 }
802 return;
803 }
804
805 let mut pending = state.ws_relay.pending.lock().await;
806 pending.remove(&(client_id, request_id));
807 let has_remaining = pending
808 .iter()
809 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
810 drop(pending);
811
812 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
813 send_json(
814 state,
815 origin_id,
816 WsResponse {
817 kind: "res",
818 id: request_id,
819 hash: pending_hash,
820 found: false,
821 },
822 )
823 .await;
824 }
825}
826
827async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
828 if let Some(msg) = parse_msgpack_message(&data) {
829 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
830 match msg {
831 DataMessage::Request(req) => {
832 let hash_hex = hex_encode(&req.h);
833 let request_id = state.ws_relay.next_request_id();
834 handle_request(
835 client_id,
836 request_id,
837 hash_hex,
838 WsProtocol::HashtreeMsgpack,
839 state,
840 )
841 .await;
842 }
843 DataMessage::Response(res) => {
844 handle_msgpack_response(client_id, res, state).await;
845 }
846 DataMessage::QuoteRequest(_)
847 | DataMessage::QuoteResponse(_)
848 | DataMessage::Payment(_)
849 | DataMessage::PaymentAck(_)
850 | DataMessage::Chunk(_)
851 | DataMessage::PeerHints(_)
852 | DataMessage::PubsubInterest(_)
853 | DataMessage::PubsubFrame(_)
854 | DataMessage::PubsubInventory(_)
855 | DataMessage::PubsubWant(_) => {}
856 }
857 return;
858 }
859
860 if data.len() < 4 {
862 return;
863 }
864 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
865 let pending_entry = {
866 let pending = state.ws_relay.pending.lock().await;
867 pending
868 .get(&(client_id, request_id))
869 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
870 };
871 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
872 return;
873 };
874
875 match origin_protocol {
876 WsProtocol::HashtreeJson => {
877 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
878 }
879 WsProtocol::HashtreeMsgpack => {
880 let Ok(hash_bytes) = from_hex(&hash_hex) else {
881 return;
882 };
883 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
884 }
885 WsProtocol::Unknown => {}
886 }
887
888 let mut pending = state.ws_relay.pending.lock().await;
889 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
890}
891
892async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
893 let replies = nostr_responses_for(&msg);
894 for reply in replies {
895 send_nostr(state, client_id, reply).await;
896 }
897}
898
899fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
900 match msg {
901 NostrClientMessage::Event(event) => {
902 let ok = event.verify().is_ok();
903 let message = if ok { "" } else { "invalid: signature" };
904 vec![NostrRelayMessage::ok(event.id, ok, message)]
905 }
906 NostrClientMessage::Req {
907 subscription_id, ..
908 } => {
909 vec![NostrRelayMessage::eose(subscription_id.clone())]
910 }
911 NostrClientMessage::Count {
912 subscription_id, ..
913 } => {
914 vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
915 }
916 NostrClientMessage::Close(_) => Vec::new(),
917 NostrClientMessage::Auth(event) => {
918 let ok = event.verify().is_ok();
919 let message = if ok { "" } else { "invalid auth" };
920 vec![NostrRelayMessage::ok(event.id, ok, message)]
921 }
922 NostrClientMessage::NegOpen { .. }
923 | NostrClientMessage::NegMsg { .. }
924 | NostrClientMessage::NegClose { .. } => {
925 vec![NostrRelayMessage::notice("negentropy not supported")]
926 }
927 }
928}
929
930async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
931 let text = response.as_json();
932 send_to_client(state, client_id, Message::Text(text)).await;
933}
934
935fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
936 let msg = parse_message(data).ok()?;
937 match msg {
938 DataMessage::Request(req) => {
939 if req.h.len() == 32 {
940 Some(DataMessage::Request(req))
941 } else {
942 None
943 }
944 }
945 DataMessage::Response(res) => {
946 if res.h.len() == 32 {
947 Some(DataMessage::Response(res))
948 } else {
949 None
950 }
951 }
952 DataMessage::QuoteRequest(req) => {
953 if req.h.len() == 32 {
954 Some(DataMessage::QuoteRequest(req))
955 } else {
956 None
957 }
958 }
959 DataMessage::QuoteResponse(res) => {
960 if res.h.len() == 32 {
961 Some(DataMessage::QuoteResponse(res))
962 } else {
963 None
964 }
965 }
966 DataMessage::Payment(req) => {
967 if req.h.len() == 32 {
968 Some(DataMessage::Payment(req))
969 } else {
970 None
971 }
972 }
973 DataMessage::PaymentAck(res) => {
974 if res.h.len() == 32 {
975 Some(DataMessage::PaymentAck(res))
976 } else {
977 None
978 }
979 }
980 DataMessage::Chunk(chunk) => {
981 if chunk.h.len() == 32 {
982 Some(DataMessage::Chunk(chunk))
983 } else {
984 None
985 }
986 }
987 DataMessage::PeerHints(_)
988 | DataMessage::PubsubInterest(_)
989 | DataMessage::PubsubFrame(_)
990 | DataMessage::PubsubInventory(_)
991 | DataMessage::PubsubWant(_) => Some(msg),
992 }
993}
994
995async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
996 let hash_hex = hex_encode(&res.h);
997 let data = res.d.clone();
998 let hash_bytes = res.h.clone();
999
1000 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
1001 let mut seen = HashSet::new();
1002 {
1003 let pending = state.ws_relay.pending.lock().await;
1004 for ((peer_id, request_id), p) in pending.iter() {
1005 if *peer_id != client_id {
1006 continue;
1007 }
1008 if p.hash != hash_hex {
1009 continue;
1010 }
1011 if seen.insert((p.origin_id, *request_id)) {
1012 responses.push((p.origin_id, *request_id, p.origin_protocol));
1013 }
1014 }
1015 }
1016
1017 if responses.is_empty() {
1018 return;
1019 }
1020
1021 for (origin_id, request_id, protocol) in &responses {
1022 match protocol {
1023 WsProtocol::HashtreeJson => {
1024 send_json(
1025 state,
1026 *origin_id,
1027 WsResponse {
1028 kind: "res",
1029 id: *request_id,
1030 hash: hash_hex.clone(),
1031 found: true,
1032 },
1033 )
1034 .await;
1035 send_binary(state, *origin_id, *request_id, data.clone()).await;
1036 }
1037 WsProtocol::HashtreeMsgpack => {
1038 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1039 }
1040 WsProtocol::Unknown => {}
1041 }
1042 }
1043
1044 let completed: HashSet<(u64, u32)> = responses
1045 .into_iter()
1046 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1047 .collect();
1048 let mut pending = state.ws_relay.pending.lock().await;
1049 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1050}
1051
1052async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1053 if let Ok(text) = serde_json::to_string(&response) {
1054 send_to_client(state, client_id, Message::Text(text)).await;
1055 }
1056}
1057
1058async fn send_msgpack_request(
1059 state: &AppState,
1060 client_id: u64,
1061 hash: &[u8],
1062) -> Result<(), rmp_serde::encode::Error> {
1063 let req = DataRequest {
1064 h: hash.to_vec(),
1065 htl: MAX_HTL,
1066 q: None,
1067 };
1068 let wire = encode_request(&req)?;
1069 send_to_client(state, client_id, Message::Binary(wire)).await;
1070 Ok(())
1071}
1072
1073async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1074 let res = DataResponse {
1075 h: hash.to_vec(),
1076 d: data.to_vec(),
1077 i: None,
1078 n: None,
1079 };
1080 if let Ok(wire) = encode_response(&res) {
1081 send_to_client(state, client_id, Message::Binary(wire)).await;
1082 }
1083}
1084
1085async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1086 let mut packet = Vec::with_capacity(4 + payload.len());
1087 packet.extend_from_slice(&request_id.to_le_bytes());
1088 packet.extend_from_slice(&payload);
1089 send_to_client(state, client_id, Message::Binary(packet)).await;
1090}
1091
1092async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1093 let sender = {
1094 let clients = state.ws_relay.clients.lock().await;
1095 clients.get(&client_id).cloned()
1096 };
1097 if let Some(tx) = sender {
1098 let _ = tx.send(msg);
1099 }
1100}
1101
1102async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1103 let mut protocols = state.ws_relay.client_protocols.lock().await;
1104 protocols.insert(client_id, protocol);
1105}
1106
1107#[cfg(test)]
1108mod tests {
1109 use super::*;
1110 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1111 use anyhow::Result;
1112 use futures::{SinkExt, StreamExt};
1113 use nostr::secp256k1::schnorr::Signature;
1114 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1115 use std::collections::HashSet;
1116 use std::sync::Arc;
1117 use tempfile::TempDir;
1118 use tokio::net::TcpListener;
1119 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1120
1121 #[test]
1122 fn parse_ws_text_message_detects_nostr_req() {
1123 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1124 match parse_ws_text_message(msg) {
1125 Some(WsTextMessage::Nostr(_)) => {}
1126 other => panic!("expected Nostr message, got {:?}", other),
1127 }
1128 }
1129
1130 #[test]
1131 fn parse_ws_text_message_detects_hashtree_request() {
1132 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1133 match parse_ws_text_message(msg) {
1134 Some(WsTextMessage::Hashtree(_)) => {}
1135 other => panic!("expected Hashtree message, got {:?}", other),
1136 }
1137 }
1138
1139 #[test]
1140 fn nostr_replies_for_req_is_eose() {
1141 let sub = SubscriptionId::new("sub-1");
1142 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1143 let replies = nostr_responses_for(&msg);
1144 assert_eq!(replies.len(), 1);
1145 match &replies[0] {
1146 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1147 other => panic!("expected EOSE, got {:?}", other),
1148 }
1149 }
1150
1151 #[test]
1152 fn nostr_replies_for_event_ok() {
1153 let keys = Keys::generate();
1154 let event = EventBuilder::new(Kind::TextNote, "hello", [])
1155 .to_event(&keys)
1156 .unwrap();
1157 let msg = NostrClientMessage::event(event.clone());
1158 let replies = nostr_responses_for(&msg);
1159 assert_eq!(replies.len(), 1);
1160 match &replies[0] {
1161 NostrRelayMessage::Ok {
1162 event_id, status, ..
1163 } => {
1164 assert_eq!(event_id, &event.id);
1165 assert!(*status);
1166 }
1167 other => panic!("expected OK, got {:?}", other),
1168 }
1169 }
1170
1171 #[test]
1172 fn nostr_replies_for_invalid_event_is_not_ok() {
1173 let keys = Keys::generate();
1174 let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1175 .to_event(&keys)
1176 .unwrap();
1177 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1178 let msg = NostrClientMessage::event(event);
1179 let replies = nostr_responses_for(&msg);
1180 assert_eq!(replies.len(), 1);
1181 match &replies[0] {
1182 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1183 other => panic!("expected OK=false, got {:?}", other),
1184 }
1185 }
1186
1187 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1188 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1189 let addr = listener.local_addr().expect("relay addr");
1190 tokio::spawn(async move {
1191 let (stream, _) = listener.accept().await.expect("accept relay");
1192 let ws = accept_async(stream).await.expect("accept websocket");
1193 let (mut write, mut read) = ws.split();
1194
1195 while let Some(Ok(message)) = read.next().await {
1196 let TungsteniteMessage::Text(text) = message else {
1197 continue;
1198 };
1199 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1200 continue;
1201 };
1202 if let NostrClientMessage::Req {
1203 subscription_id,
1204 filters,
1205 } = parsed
1206 {
1207 for event in events
1208 .iter()
1209 .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1210 {
1211 let _ = write
1212 .send(TungsteniteMessage::Text(
1213 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1214 .as_json()
1215 .into(),
1216 ))
1217 .await;
1218 }
1219 let _ = write
1220 .send(TungsteniteMessage::Text(
1221 NostrRelayMessage::eose(subscription_id).as_json().into(),
1222 ))
1223 .await;
1224 }
1225 }
1226 });
1227 format!("ws://{}", addr)
1228 }
1229
1230 fn test_app_state(
1231 tmp: &TempDir,
1232 relay: Arc<NostrRelay>,
1233 relay_url: String,
1234 ) -> Result<AppState> {
1235 let store = Arc::new(crate::storage::HashtreeStore::with_options(
1236 tmp.path(),
1237 None,
1238 128 * 1024 * 1024,
1239 )?);
1240 Ok(AppState {
1241 store,
1242 auth: None,
1243 peer_mode: crate::config::ServerMode::Normal,
1244 hash_get_enabled: true,
1245 webrtc_peers: None,
1246 ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1247 max_upload_bytes: 5 * 1024 * 1024,
1248 public_writes: true,
1249 allowed_pubkeys: HashSet::new(),
1250 upstream_blossom: Vec::new(),
1251 social_graph: None,
1252 social_graph_store: None,
1253 social_graph_root: None,
1254 socialgraph_snapshot_public: false,
1255 nostr_relay: Some(relay),
1256 nostr_relay_urls: vec![relay_url],
1257 tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1258 inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1259 std::collections::HashMap::new(),
1260 )),
1261 directory_listing_cache: Arc::new(std::sync::Mutex::new(
1262 super::super::auth::new_lookup_cache(),
1263 )),
1264 resolved_path_cache: Arc::new(std::sync::Mutex::new(
1265 super::super::auth::new_lookup_cache(),
1266 )),
1267 thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1268 super::super::auth::new_lookup_cache(),
1269 )),
1270 cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1271 })
1272 }
1273
1274 #[tokio::test]
1275 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1276 let tmp = TempDir::new()?;
1277 let graph_store = {
1278 let _guard = crate::socialgraph::test_lock();
1279 crate::socialgraph::open_social_graph_store_with_mapsize(
1280 tmp.path(),
1281 Some(128 * 1024 * 1024),
1282 )?
1283 };
1284 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1285 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1286 Arc::clone(&backend),
1287 0,
1288 HashSet::new(),
1289 ));
1290
1291 let keys = Keys::generate();
1292 let relay = Arc::new(NostrRelay::new(
1293 Arc::clone(&backend),
1294 tmp.path().to_path_buf(),
1295 HashSet::from([keys.public_key().to_hex()]),
1296 Some(access),
1297 NostrRelayConfig {
1298 spambox_db_max_bytes: 0,
1299 ..Default::default()
1300 },
1301 )?);
1302
1303 let event = EventBuilder::new(
1304 Kind::from(30078_u16),
1305 "",
1306 [
1307 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1308 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1309 ],
1310 )
1311 .to_event(&keys)?;
1312
1313 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1314 let filter = Filter::new()
1315 .authors(vec![event.pubkey])
1316 .kinds(vec![event.kind]);
1317 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1318 let client_id = 7_u64;
1319 let (tx, mut rx) = mpsc::unbounded_channel();
1320 state.ws_relay.clients.lock().await.insert(client_id, tx);
1321 let subscription_id = SubscriptionId::new("sub-1");
1322
1323 start_upstream_nostr_subscription(
1324 &state,
1325 client_id,
1326 subscription_id.clone(),
1327 vec![filter.clone()],
1328 )
1329 .await;
1330
1331 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1332 .await?
1333 .expect("forwarded upstream event");
1334 let Message::Text(text) = forwarded else {
1335 panic!("expected text event");
1336 };
1337 match NostrRelayMessage::from_json(text.as_str())? {
1338 NostrRelayMessage::Event {
1339 subscription_id: sid,
1340 event: forwarded_event,
1341 } => {
1342 assert_eq!(sid, subscription_id);
1343 assert_eq!(forwarded_event.id, event.id);
1344 }
1345 other => panic!("expected forwarded EVENT, got {:?}", other),
1346 }
1347
1348 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1349 .await?
1350 .expect("forwarded upstream eose");
1351 let Message::Text(eose_text) = eose else {
1352 panic!("expected text eose");
1353 };
1354 match NostrRelayMessage::from_json(eose_text.as_str())? {
1355 NostrRelayMessage::EndOfStoredEvents(sid) => {
1356 assert_eq!(sid, subscription_id);
1357 }
1358 other => panic!("expected forwarded EOSE, got {:?}", other),
1359 }
1360
1361 let events = relay.query_events(&filter, 10).await;
1362 assert_eq!(events.len(), 1);
1363 assert_eq!(events[0].id, event.id);
1364
1365 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1366 assert!(state
1367 .ws_relay
1368 .upstream_nostr_subscriptions
1369 .lock()
1370 .await
1371 .is_empty());
1372 Ok(())
1373 }
1374
1375 #[tokio::test]
1376 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1377 let tmp = TempDir::new()?;
1378 let graph_store = {
1379 let _guard = crate::socialgraph::test_lock();
1380 crate::socialgraph::open_social_graph_store_with_mapsize(
1381 tmp.path(),
1382 Some(128 * 1024 * 1024),
1383 )?
1384 };
1385 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1386 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1387 Arc::clone(&backend),
1388 0,
1389 HashSet::new(),
1390 ));
1391
1392 let keys = Keys::generate();
1393 let relay = Arc::new(NostrRelay::new(
1394 Arc::clone(&backend),
1395 tmp.path().to_path_buf(),
1396 HashSet::from([keys.public_key().to_hex()]),
1397 Some(access),
1398 NostrRelayConfig {
1399 spambox_db_max_bytes: 0,
1400 ..Default::default()
1401 },
1402 )?);
1403
1404 let event = EventBuilder::new(
1405 Kind::from(30078_u16),
1406 "",
1407 [
1408 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1409 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1410 ],
1411 )
1412 .to_event(&keys)?;
1413
1414 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1415 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1416 let client_id = 11_u64;
1417 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1418 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1419 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1420 relay.register_client(client_id, relay_tx, None).await;
1421
1422 let request = NostrClientMessage::req(
1423 SubscriptionId::new("feed"),
1424 vec![Filter::new()
1425 .authors(vec![event.pubkey])
1426 .kinds(vec![event.kind])],
1427 )
1428 .as_json();
1429
1430 handle_message(client_id, Message::Text(request.into()), &state).await;
1431
1432 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1433 .await?
1434 .expect("first forwarded message");
1435 let Message::Text(first_text) = first else {
1436 panic!("expected text event");
1437 };
1438 match NostrRelayMessage::from_json(first_text.as_str())? {
1439 NostrRelayMessage::Event {
1440 event: forwarded_event,
1441 ..
1442 } => {
1443 assert_eq!(forwarded_event.id, event.id);
1444 }
1445 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1446 }
1447
1448 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1449 .await?
1450 .expect("second forwarded message");
1451 let Message::Text(second_text) = second else {
1452 panic!("expected text eose");
1453 };
1454 match NostrRelayMessage::from_json(second_text.as_str())? {
1455 NostrRelayMessage::EndOfStoredEvents(sid) => {
1456 assert_eq!(sid, SubscriptionId::new("feed"));
1457 }
1458 other => panic!("expected aggregated EOSE, got {:?}", other),
1459 }
1460
1461 Ok(())
1462 }
1463
1464 #[tokio::test]
1465 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1466 let tmp = TempDir::new()?;
1467 let graph_store = {
1468 let _guard = crate::socialgraph::test_lock();
1469 crate::socialgraph::open_social_graph_store_with_mapsize(
1470 tmp.path(),
1471 Some(128 * 1024 * 1024),
1472 )?
1473 };
1474 let author_keys = Keys::generate();
1475 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1476 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1477 Arc::clone(&backend),
1478 0,
1479 HashSet::from([author_keys.public_key().to_hex()]),
1480 ));
1481 let relay = Arc::new(NostrRelay::new(
1482 Arc::clone(&backend),
1483 tmp.path().to_path_buf(),
1484 HashSet::from([author_keys.public_key().to_hex()]),
1485 Some(access),
1486 NostrRelayConfig {
1487 spambox_db_max_bytes: 0,
1488 ..Default::default()
1489 },
1490 )?);
1491
1492 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1493 let listener = TcpListener::bind("127.0.0.1:0").await?;
1494 let addr = listener.local_addr()?;
1495 let client_pubkey = author_keys.public_key().to_hex();
1496 let app = axum::Router::new().route(
1497 "/ws",
1498 axum::routing::get({
1499 let state = state.clone();
1500 let client_pubkey = client_pubkey.clone();
1501 move |ws: WebSocketUpgrade| {
1502 let state = state.clone();
1503 let client_pubkey = client_pubkey.clone();
1504 async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1505 }
1506 }),
1507 );
1508 tokio::spawn(async move {
1509 let _ = axum::serve(listener, app).await;
1510 });
1511
1512 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1513 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1514 .to_event(&author_keys)?;
1515 socket
1516 .send(TungsteniteMessage::Text(
1517 NostrClientMessage::event(event.clone()).as_json().into(),
1518 ))
1519 .await?;
1520
1521 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1522 .await?
1523 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1524 let TungsteniteMessage::Text(text) = reply else {
1525 anyhow::bail!("expected text publish ack");
1526 };
1527
1528 match NostrRelayMessage::from_json(text.as_str())? {
1529 NostrRelayMessage::Ok {
1530 event_id, status, ..
1531 } => {
1532 assert_eq!(event_id, event.id);
1533 assert!(status);
1534 }
1535 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1536 }
1537
1538 let stored = relay
1539 .query_events(
1540 &Filter::new()
1541 .authors(vec![event.pubkey])
1542 .kinds(vec![event.kind]),
1543 10,
1544 )
1545 .await;
1546 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1547 Ok(())
1548 }
1549
1550 #[tokio::test]
1551 async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1552 let tmp = TempDir::new()?;
1553 let graph_store = {
1554 let _guard = crate::socialgraph::test_lock();
1555 crate::socialgraph::open_social_graph_store_with_mapsize(
1556 tmp.path(),
1557 Some(128 * 1024 * 1024),
1558 )?
1559 };
1560 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1561 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1562 Arc::clone(&backend),
1563 0,
1564 HashSet::new(),
1565 ));
1566 let relay = Arc::new(NostrRelay::new(
1567 Arc::clone(&backend),
1568 tmp.path().to_path_buf(),
1569 HashSet::new(),
1570 Some(access),
1571 NostrRelayConfig {
1572 spambox_db_max_bytes: 0,
1573 spambox_max_reqs_per_min: 1,
1574 ..Default::default()
1575 },
1576 )?);
1577
1578 let state = test_app_state(&tmp, relay, String::new())?;
1579 let listener = TcpListener::bind("127.0.0.1:0").await?;
1580 let addr = listener.local_addr()?;
1581 let app = axum::Router::new().route(
1582 "/ws",
1583 axum::routing::get({
1584 let state = state.clone();
1585 move |ws: WebSocketUpgrade| {
1586 let state = state.clone();
1587 async move { ws_data_with_client_pubkey(state, ws, None) }
1588 }
1589 }),
1590 );
1591 tokio::spawn(async move {
1592 let _ = axum::serve(listener, app).await;
1593 });
1594
1595 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1596 socket
1597 .send(TungsteniteMessage::Text(
1598 NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1599 .as_json()
1600 .into(),
1601 ))
1602 .await?;
1603
1604 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1605 .await?
1606 .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1607 let TungsteniteMessage::Text(first_text) = first else {
1608 anyhow::bail!("expected text EOSE reply");
1609 };
1610 match NostrRelayMessage::from_json(first_text.as_str())? {
1611 NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1612 assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
1613 }
1614 other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1615 }
1616
1617 socket
1618 .send(TungsteniteMessage::Text(
1619 NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1620 .as_json()
1621 .into(),
1622 ))
1623 .await?;
1624
1625 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1626 .await?
1627 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1628 let TungsteniteMessage::Text(second_text) = second else {
1629 anyhow::bail!("expected text CLOSED reply");
1630 };
1631 let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1632 assert_eq!(
1633 second_value,
1634 serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1635 );
1636
1637 Ok(())
1638 }
1639
1640 #[tokio::test]
1641 async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1642 let tmp = TempDir::new()?;
1643 let graph_store = {
1644 let _guard = crate::socialgraph::test_lock();
1645 crate::socialgraph::open_social_graph_store_with_mapsize(
1646 tmp.path(),
1647 Some(128 * 1024 * 1024),
1648 )?
1649 };
1650 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1651 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1652 Arc::clone(&backend),
1653 0,
1654 HashSet::new(),
1655 ));
1656 let relay = Arc::new(NostrRelay::new(
1657 Arc::clone(&backend),
1658 tmp.path().to_path_buf(),
1659 HashSet::new(),
1660 Some(access),
1661 NostrRelayConfig {
1662 spambox_db_max_bytes: 0,
1663 spambox_max_events_per_min: 1,
1664 ..Default::default()
1665 },
1666 )?);
1667
1668 let state = test_app_state(&tmp, relay, String::new())?;
1669 let listener = TcpListener::bind("127.0.0.1:0").await?;
1670 let addr = listener.local_addr()?;
1671 let app = axum::Router::new().route(
1672 "/ws",
1673 axum::routing::get({
1674 let state = state.clone();
1675 move |ws: WebSocketUpgrade| {
1676 let state = state.clone();
1677 async move { ws_data_with_client_pubkey(state, ws, None) }
1678 }
1679 }),
1680 );
1681 tokio::spawn(async move {
1682 let _ = axum::serve(listener, app).await;
1683 });
1684
1685 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1686 let author_keys = Keys::generate();
1687 let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
1688 let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;
1689
1690 socket
1691 .send(TungsteniteMessage::Text(
1692 NostrClientMessage::event(event_a.clone()).as_json().into(),
1693 ))
1694 .await?;
1695
1696 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1697 .await?
1698 .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1699 let TungsteniteMessage::Text(first_text) = first else {
1700 anyhow::bail!("expected text publish ack");
1701 };
1702 match NostrRelayMessage::from_json(first_text.as_str())? {
1703 NostrRelayMessage::Ok {
1704 event_id,
1705 status,
1706 message,
1707 } => {
1708 assert_eq!(event_id, event_a.id);
1709 assert!(status);
1710 assert_eq!(message, "spambox");
1711 }
1712 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1713 }
1714
1715 socket
1716 .send(TungsteniteMessage::Text(
1717 NostrClientMessage::event(event_b.clone()).as_json().into(),
1718 ))
1719 .await?;
1720
1721 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1722 .await?
1723 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1724 let TungsteniteMessage::Text(second_text) = second else {
1725 anyhow::bail!("expected text publish ack");
1726 };
1727 match NostrRelayMessage::from_json(second_text.as_str())? {
1728 NostrRelayMessage::Ok {
1729 event_id,
1730 status,
1731 message,
1732 } => {
1733 assert_eq!(event_id, event_b.id);
1734 assert!(!status);
1735 assert_eq!(message, "rate limited");
1736 }
1737 other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1738 }
1739
1740 Ok(())
1741 }
1742}