1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21 nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
/// Hashtree JSON control messages received from a websocket client.
///
/// Internally tagged on the `"type"` field: `"req"` asks for a blob by hex hash,
/// `"res"` answers a request that this server previously forwarded to the client.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum WsClientMessage {
    /// Request for the blob with hex `hash`; `id` is chosen by the requesting client.
    #[serde(rename = "req")]
    Request { id: u32, hash: String },
    /// Reply to a forwarded request; `found` reports whether the client has the blob.
    #[serde(rename = "res")]
    Response { id: u32, hash: String, found: bool },
}
36
/// A websocket text frame after classification by `parse_ws_text_message`.
#[derive(Debug)]
enum WsTextMessage<'a> {
    /// Hashtree JSON protocol message (`{"type": ...}` object).
    Hashtree(WsClientMessage),
    /// Nostr client message (JSON array such as `["REQ", ...]`).
    Nostr(NostrClientMessage<'a>),
}
42
/// JSON `req` frame forwarded to hashtree-JSON peers.
///
/// Field order is the serialized wire order (`type`, `id`, `hash`).
#[derive(Debug, Deserialize, Serialize)]
struct WsRequest {
    /// Serialized as `"type"`; set to `"req"` where this file constructs it.
    #[serde(rename = "type")]
    kind: String,
    /// Request id, echoed back by the peer in its `res` reply.
    id: u32,
    /// Hex hash of the requested blob.
    hash: String,
}
50
/// JSON `res` frame sent to a hashtree-JSON requester.
///
/// Field order is the serialized wire order (`type`, `id`, `hash`, `found`).
#[derive(Debug, Serialize)]
struct WsResponse {
    /// Serialized as `"type"`; every construction site in this file uses `"res"`.
    #[serde(rename = "type")]
    kind: &'static str,
    /// Id of the request being answered.
    id: u32,
    /// Hex hash of the requested blob.
    hash: String,
    /// Whether the blob was found; when `true` a binary payload frame follows.
    found: bool,
}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61 ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65 state: AppState,
66 ws: WebSocketUpgrade,
67 client_pubkey: Option<String>,
68) -> impl IntoResponse {
69 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
72async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
73 let client_id = state
76 .nostr_relay
77 .as_ref()
78 .map(|relay| relay.next_client_id())
79 .unwrap_or_else(|| state.ws_relay.next_id());
80 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
81
82 {
83 let mut clients = state.ws_relay.clients.lock().await;
84 clients.insert(client_id, tx);
85 }
86 {
87 let mut protocols = state.ws_relay.client_protocols.lock().await;
88 protocols.insert(client_id, WsProtocol::HashtreeJson);
89 }
90
91 let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
92 let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
93 relay
94 .register_client(client_id, nostr_tx, client_pubkey.clone())
95 .await;
96 Some(nostr_rx)
97 } else {
98 None
99 };
100
101 let (mut sender, mut receiver) = socket.split();
102 loop {
103 tokio::select! {
104 maybe_msg = rx.recv() => {
105 let Some(msg) = maybe_msg else {
106 break;
107 };
108 if sender.send(msg).await.is_err() {
109 break;
110 }
111 }
112 maybe_text = async {
113 match &mut nostr_rx {
114 Some(rx) => rx.recv().await,
115 None => std::future::pending().await,
116 }
117 } => {
118 let Some(text) = maybe_text else {
119 nostr_rx = None;
120 continue;
121 };
122 if sender.send(Message::Text(text)).await.is_err() {
123 break;
124 }
125 }
126 maybe_incoming = receiver.next() => {
127 match maybe_incoming {
128 Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
129 Some(Err(_)) | None => break,
130 }
131 }
132 }
133 }
134
135 close_all_upstream_nostr_subscriptions(&state, client_id).await;
136
137 {
138 let mut clients = state.ws_relay.clients.lock().await;
139 clients.remove(&client_id);
140 }
141 {
142 let mut protocols = state.ws_relay.client_protocols.lock().await;
143 protocols.remove(&client_id);
144 }
145 {
146 let mut pending = state.ws_relay.pending.lock().await;
147 pending.retain(|(peer_id, _), _| *peer_id != client_id);
148 }
149
150 if let Some(relay) = &state.nostr_relay {
151 relay.unregister_client(client_id).await;
152 }
153}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage<'static>> {
156 let trimmed = text.trim_start();
157 if trimmed.starts_with('[') {
158 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159 return Some(WsTextMessage::Nostr(msg));
160 }
161 }
162
163 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164 return Some(WsTextMessage::Hashtree(msg));
165 }
166
167 None
168}
169async fn close_upstream_nostr_subscription(
170 state: &AppState,
171 client_id: u64,
172 subscription_id: &SubscriptionId,
173) {
174 let key = (client_id, subscription_id.to_string());
175 let subscription = {
176 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177 subscriptions.remove(&key)
178 };
179 if let Some(subscription) = subscription {
180 let _ = subscription.close_tx.send(true);
181 for task in subscription.tasks {
182 task.abort();
183 }
184 }
185 state
186 .ws_relay
187 .upstream_pending_eose
188 .lock()
189 .await
190 .remove(&key);
191 state
192 .ws_relay
193 .upstream_seen_events
194 .lock()
195 .await
196 .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200 let keys = {
201 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202 subscriptions
203 .keys()
204 .filter(|(id, _)| *id == client_id)
205 .cloned()
206 .collect::<Vec<_>>()
207 };
208 for (_, sub_id) in keys {
209 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210 }
211}
212
213async fn forward_upstream_nostr_message(
214 state: &AppState,
215 client_id: u64,
216 subscription_id: &SubscriptionId,
217 text: &str,
218) {
219 let Ok(message) = NostrRelayMessage::from_json(text) else {
220 return;
221 };
222
223 match message {
224 NostrRelayMessage::Event {
225 subscription_id: sid,
226 event,
227 } if sid.as_ref() == subscription_id => {
228 let event = event.into_owned();
229 let key = (client_id, subscription_id.to_string());
230 let event_id = event.id.to_hex();
231 let inserted = {
232 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233 seen_events.entry(key).or_default().insert(event_id)
234 };
235 if !inserted {
236 return;
237 }
238 if let Some(relay) = &state.nostr_relay {
239 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240 }
241 send_nostr(
242 state,
243 client_id,
244 NostrRelayMessage::event(subscription_id.clone(), event),
245 )
246 .await;
247 }
248 NostrRelayMessage::Closed {
249 subscription_id: sid,
250 message,
251 } if sid.as_ref() == subscription_id => {
252 send_nostr(
253 state,
254 client_id,
255 NostrRelayMessage::closed(subscription_id.clone(), message.into_owned()),
256 )
257 .await;
258 }
259 _ => {}
260 }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264 state: &AppState,
265 client_id: u64,
266 subscription_id: &SubscriptionId,
267) {
268 let key = (client_id, subscription_id.to_string());
269 let should_send_eose = {
270 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271 let Some(remaining) = pending.get_mut(&key) else {
272 return;
273 };
274 if *remaining > 0 {
275 *remaining -= 1;
276 }
277 if *remaining == 0 {
278 pending.remove(&key);
279 true
280 } else {
281 false
282 }
283 };
284
285 if should_send_eose {
286 send_nostr(
287 state,
288 client_id,
289 NostrRelayMessage::eose(subscription_id.clone()),
290 )
291 .await;
292 }
293}
294
295async fn run_upstream_nostr_subscription(
296 state: AppState,
297 client_id: u64,
298 relay_url: String,
299 subscription_id: SubscriptionId,
300 filters: Vec<NostrFilter>,
301 mut close_rx: watch::Receiver<bool>,
302) {
303 let mut relay_complete = false;
304 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305 tracing::warn!(
306 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307 client_id,
308 subscription_id,
309 relay_url,
310 );
311 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312 return;
313 };
314 let (mut write, mut read) = socket.split();
315 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316 state
317 .ws_relay
318 .note_upstream_relay_send(request.as_bytes().len());
319 if write
320 .send(TungsteniteMessage::Text(request.into()))
321 .await
322 .is_err()
323 {
324 tracing::warn!(
325 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326 client_id,
327 subscription_id,
328 relay_url,
329 );
330 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331 return;
332 }
333
334 loop {
335 tokio::select! {
336 _ = close_rx.changed() => {
337 if *close_rx.borrow() {
338 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339 state
340 .ws_relay
341 .note_upstream_relay_send(close.as_bytes().len());
342 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343 let _ = write.close().await;
344 break;
345 }
346 }
347 message = read.next() => {
348 match message {
349 Some(Ok(TungsteniteMessage::Text(text))) => {
350 state
351 .ws_relay
352 .note_upstream_relay_receive(text.as_bytes().len());
353 if matches!(
354 NostrRelayMessage::from_json(text.as_str()),
355 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid.as_ref() == &subscription_id
356 ) {
357 if !relay_complete {
358 relay_complete = true;
359 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360 }
361 continue;
362 }
363 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364 }
365 Some(Ok(TungsteniteMessage::Ping(payload))) => {
366 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367 }
368 Some(Ok(TungsteniteMessage::Close(_))) | None => {
369 if !relay_complete {
370 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371 }
372 break;
373 }
374 Some(Err(_)) => {
375 if !relay_complete {
376 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377 }
378 break;
379 }
380 _ => {}
381 }
382 }
383 }
384 }
385}
386
387async fn start_upstream_nostr_subscription(
388 state: &AppState,
389 client_id: u64,
390 subscription_id: SubscriptionId,
391 filters: Vec<NostrFilter>,
392) -> usize {
393 let memory_before = process_memory_snapshot();
394 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
395 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
396 tracing::info!(
397 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
398 client_id,
399 subscription_id,
400 state.nostr_relay_urls.len(),
401 filters.len(),
402 );
403 return 0;
404 }
405
406 let mut relay_urls = Vec::new();
407 let mut seen = HashSet::new();
408 for relay in &state.nostr_relay_urls {
409 let relay = relay.trim();
410 if relay.is_empty() || !seen.insert(relay.to_string()) {
411 continue;
412 }
413 relay_urls.push(relay.to_string());
414 }
415 if relay_urls.is_empty() {
416 tracing::info!(
417 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
418 client_id,
419 subscription_id,
420 );
421 return 0;
422 }
423
424 tracing::info!(
425 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
426 client_id,
427 subscription_id,
428 relay_urls.len(),
429 );
430
431 let filter_summary = nostr_filters_summary(&filters);
432 let key = (client_id, subscription_id.to_string());
433 state
434 .ws_relay
435 .upstream_seen_events
436 .lock()
437 .await
438 .insert(key.clone(), HashSet::new());
439 state
440 .ws_relay
441 .upstream_pending_eose
442 .lock()
443 .await
444 .insert(key.clone(), relay_urls.len());
445
446 let (close_tx, close_rx) = watch::channel(false);
447 let mut tasks = Vec::new();
448 for relay_url in &relay_urls {
449 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
450 state.clone(),
451 client_id,
452 relay_url.clone(),
453 subscription_id.clone(),
454 filters.clone(),
455 close_rx.clone(),
456 )));
457 }
458
459 tracing::info!(
460 target: "hashtree_cli::server::ws_relay::upstream",
461 client_id,
462 subscription_id = %subscription_id,
463 relays = relay_urls.len(),
464 tasks = tasks.len(),
465 filters = filters.len(),
466 filter = %filter_summary,
467 memory_before = ?memory_before,
468 memory_after = ?process_memory_snapshot(),
469 "upstream nostr relay tasks spawned",
470 );
471
472 state
473 .ws_relay
474 .upstream_nostr_subscriptions
475 .lock()
476 .await
477 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
478 relay_urls.len()
479}
480
481async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
482 match msg {
483 Message::Text(text) => {
484 if let Some(msg) = parse_ws_text_message(&text) {
485 match msg {
486 WsTextMessage::Hashtree(msg) => {
487 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
488 match msg {
489 WsClientMessage::Request { id, hash } => {
490 handle_request(
491 client_id,
492 id,
493 hash,
494 WsProtocol::HashtreeJson,
495 state,
496 )
497 .await;
498 }
499 WsClientMessage::Response { id, hash, found } => {
500 handle_response(client_id, id, hash, found, state).await;
501 }
502 }
503 }
504 WsTextMessage::Nostr(msg) => {
505 if let Some(relay) = &state.nostr_relay {
506 match msg {
507 NostrClientMessage::Req {
508 subscription_id,
509 filters,
510 } => {
511 let subscription_id = subscription_id.into_owned();
512 let filters = filters
513 .into_iter()
514 .map(|filter| filter.into_owned())
515 .collect::<Vec<_>>();
516 let local_events = match relay
517 .register_subscription_query(
518 client_id,
519 subscription_id.clone(),
520 filters.clone(),
521 )
522 .await
523 {
524 Ok(events) => events,
525 Err(message) => {
526 send_nostr(
527 state,
528 client_id,
529 NostrRelayMessage::closed(subscription_id, message),
530 )
531 .await;
532 return;
533 }
534 };
535
536 let upstream_relays = start_upstream_nostr_subscription(
537 state,
538 client_id,
539 subscription_id.clone(),
540 filters,
541 )
542 .await;
543 if upstream_relays > 0 {
544 let key = (client_id, subscription_id.to_string());
545 let mut seen_events =
546 state.ws_relay.upstream_seen_events.lock().await;
547 seen_events.entry(key).or_default().extend(
548 local_events.iter().map(|event| event.id.to_hex()),
549 );
550 }
551 for event in local_events {
552 send_nostr(
553 state,
554 client_id,
555 NostrRelayMessage::event(
556 subscription_id.clone(),
557 event,
558 ),
559 )
560 .await;
561 }
562 trim_process_allocations();
563 if upstream_relays == 0 {
564 send_nostr(
565 state,
566 client_id,
567 NostrRelayMessage::eose(subscription_id),
568 )
569 .await;
570 }
571 }
572 NostrClientMessage::Close(subscription_id) => {
573 let subscription_id = subscription_id.into_owned();
574 close_upstream_nostr_subscription(
575 state,
576 client_id,
577 &subscription_id,
578 )
579 .await;
580 relay
581 .handle_client_message(
582 client_id,
583 NostrClientMessage::close(subscription_id.clone()),
584 )
585 .await;
586 }
587 other => {
588 relay.handle_client_message(client_id, other).await;
589 }
590 }
591 } else {
592 handle_nostr_message(client_id, msg, state).await;
593 }
594 }
595 }
596 }
597 }
598 Message::Binary(data) => {
599 handle_binary(client_id, data, state).await;
600 }
601 Message::Close(_) => {}
602 _ => {}
603 }
604}
605
606async fn handle_request(
607 client_id: u64,
608 request_id: u32,
609 hash: String,
610 origin_protocol: WsProtocol,
611 state: &AppState,
612) {
613 let hash_hex = hash.to_lowercase();
614 let hash_bytes = match from_hex(&hash_hex) {
615 Ok(bytes) => bytes,
616 Err(_) => {
617 if origin_protocol == WsProtocol::HashtreeJson {
618 send_json(
619 state,
620 client_id,
621 WsResponse {
622 kind: "res",
623 id: request_id,
624 hash,
625 found: false,
626 },
627 )
628 .await;
629 }
630 return;
631 }
632 };
633
634 if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
635 match origin_protocol {
636 WsProtocol::HashtreeJson => {
637 send_json(
638 state,
639 client_id,
640 WsResponse {
641 kind: "res",
642 id: request_id,
643 hash: hash.clone(),
644 found: true,
645 },
646 )
647 .await;
648 send_binary(state, client_id, request_id, data).await;
649 }
650 WsProtocol::HashtreeMsgpack => {
651 send_msgpack_response(state, client_id, &hash_bytes, &data).await;
652 }
653 WsProtocol::Unknown => {}
654 }
655 return;
656 }
657
658 let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
659 let clients = state.ws_relay.clients.lock().await;
660 let protocols = state.ws_relay.client_protocols.lock().await;
661 clients
662 .iter()
663 .filter(|(id, _)| **id != client_id)
664 .filter_map(|(id, tx)| {
665 let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
666 match protocol {
667 WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
668 Some((*id, tx.clone(), protocol))
669 }
670 WsProtocol::Unknown => None,
671 }
672 })
673 .collect()
674 };
675
676 if peers.is_empty() {
677 if origin_protocol == WsProtocol::HashtreeJson {
678 send_json(
679 state,
680 client_id,
681 WsResponse {
682 kind: "res",
683 id: request_id,
684 hash,
685 found: false,
686 },
687 )
688 .await;
689 }
690 return;
691 }
692
693 {
694 let mut pending = state.ws_relay.pending.lock().await;
695 for (peer_id, _, _) in &peers {
696 pending.insert(
697 (*peer_id, request_id),
698 PendingRequest {
699 origin_id: client_id,
700 hash: hash.clone(),
701 found: false,
702 origin_protocol,
703 },
704 );
705 }
706 }
707
708 let request_text = serde_json::to_string(&WsRequest {
709 kind: "req".to_string(),
710 id: request_id,
711 hash: hash.clone(),
712 })
713 .unwrap_or_else(|_| String::new());
714 for (peer_id, tx, protocol) in peers {
715 match protocol {
716 WsProtocol::HashtreeMsgpack => {
717 let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
718 }
719 WsProtocol::HashtreeJson => {
720 let _ = tx.send(Message::Text(request_text.clone()));
721 }
722 WsProtocol::Unknown => {}
723 }
724 }
725
726 let timeout_state = state.clone();
727 let timeout_hash = hash.clone();
728 tokio::spawn(async move {
729 tokio::time::sleep(Duration::from_millis(1500)).await;
730 let mut pending = timeout_state.ws_relay.pending.lock().await;
731 let still_pending = pending
732 .iter()
733 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
734 let already_found = pending
735 .iter()
736 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
737 if !still_pending || already_found {
738 return;
739 }
740 let origin_protocol = pending
741 .iter()
742 .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
743 .map(|(_, p)| p.origin_protocol)
744 .unwrap_or(WsProtocol::HashtreeJson);
745 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
746 drop(pending);
747 if origin_protocol == WsProtocol::HashtreeJson {
748 send_json(
749 &timeout_state,
750 client_id,
751 WsResponse {
752 kind: "res",
753 id: request_id,
754 hash: timeout_hash,
755 found: false,
756 },
757 )
758 .await;
759 }
760 });
761}
762
763async fn handle_response(
764 client_id: u64,
765 request_id: u32,
766 _hash: String,
767 found: bool,
768 state: &AppState,
769) {
770 let pending_entry = {
771 let pending = state.ws_relay.pending.lock().await;
772 pending
773 .get(&(client_id, request_id))
774 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
775 };
776
777 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
778 return;
779 };
780
781 if already_found && !found {
782 let mut pending = state.ws_relay.pending.lock().await;
783 pending.remove(&(client_id, request_id));
784 return;
785 }
786
787 if found {
788 let mut pending = state.ws_relay.pending.lock().await;
789 for ((_, id), p) in pending.iter_mut() {
790 if *id == request_id && p.origin_id == origin_id {
791 p.found = true;
792 }
793 }
794 drop(pending);
795 if origin_protocol == WsProtocol::HashtreeJson {
796 send_json(
797 state,
798 origin_id,
799 WsResponse {
800 kind: "res",
801 id: request_id,
802 hash: pending_hash,
803 found: true,
804 },
805 )
806 .await;
807 }
808 return;
809 }
810
811 let mut pending = state.ws_relay.pending.lock().await;
812 pending.remove(&(client_id, request_id));
813 let has_remaining = pending
814 .iter()
815 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
816 drop(pending);
817
818 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
819 send_json(
820 state,
821 origin_id,
822 WsResponse {
823 kind: "res",
824 id: request_id,
825 hash: pending_hash,
826 found: false,
827 },
828 )
829 .await;
830 }
831}
832
833async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
834 if let Some(msg) = parse_msgpack_message(&data) {
835 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
836 match msg {
837 DataMessage::Request(req) => {
838 let hash_hex = hex_encode(&req.h);
839 let request_id = state.ws_relay.next_request_id();
840 handle_request(
841 client_id,
842 request_id,
843 hash_hex,
844 WsProtocol::HashtreeMsgpack,
845 state,
846 )
847 .await;
848 }
849 DataMessage::Response(res) => {
850 handle_msgpack_response(client_id, res, state).await;
851 }
852 DataMessage::QuoteRequest(_)
853 | DataMessage::QuoteResponse(_)
854 | DataMessage::Payment(_)
855 | DataMessage::PaymentAck(_)
856 | DataMessage::Chunk(_)
857 | DataMessage::PeerHints(_)
858 | DataMessage::PubsubInterest(_)
859 | DataMessage::PubsubFrame(_)
860 | DataMessage::PubsubInventory(_)
861 | DataMessage::PubsubWant(_) => {}
862 }
863 return;
864 }
865
866 if data.len() < 4 {
868 return;
869 }
870 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
871 let pending_entry = {
872 let pending = state.ws_relay.pending.lock().await;
873 pending
874 .get(&(client_id, request_id))
875 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
876 };
877 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
878 return;
879 };
880
881 match origin_protocol {
882 WsProtocol::HashtreeJson => {
883 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
884 }
885 WsProtocol::HashtreeMsgpack => {
886 let Ok(hash_bytes) = from_hex(&hash_hex) else {
887 return;
888 };
889 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
890 }
891 WsProtocol::Unknown => {}
892 }
893
894 let mut pending = state.ws_relay.pending.lock().await;
895 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
896}
897
898async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage<'_>, state: &AppState) {
899 let replies = nostr_responses_for(&msg);
900 for reply in replies {
901 send_nostr(state, client_id, reply).await;
902 }
903}
904
905fn nostr_responses_for(msg: &NostrClientMessage<'_>) -> Vec<NostrRelayMessage<'static>> {
906 match msg {
907 NostrClientMessage::Event(event) => {
908 let ok = event.verify().is_ok();
909 let message = if ok { "" } else { "invalid: signature" };
910 vec![NostrRelayMessage::ok(event.id, ok, message)]
911 }
912 NostrClientMessage::Req {
913 subscription_id, ..
914 } => {
915 vec![NostrRelayMessage::eose(
916 subscription_id.clone().into_owned(),
917 )]
918 }
919 NostrClientMessage::Count {
920 subscription_id, ..
921 } => {
922 vec![NostrRelayMessage::count(
923 subscription_id.clone().into_owned(),
924 0,
925 )]
926 }
927 NostrClientMessage::Close(_) => Vec::new(),
928 NostrClientMessage::Auth(event) => {
929 let ok = event.verify().is_ok();
930 let message = if ok { "" } else { "invalid auth" };
931 vec![NostrRelayMessage::ok(event.id, ok, message)]
932 }
933 NostrClientMessage::NegOpen { .. }
934 | NostrClientMessage::NegMsg { .. }
935 | NostrClientMessage::NegClose { .. } => {
936 vec![NostrRelayMessage::notice("negentropy not supported")]
937 }
938 }
939}
940
941async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage<'_>) {
942 let text = response.as_json();
943 send_to_client(state, client_id, Message::Text(text)).await;
944}
945
946fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
947 let msg = parse_message(data).ok()?;
948 match msg {
949 DataMessage::Request(req) => {
950 if req.h.len() == 32 {
951 Some(DataMessage::Request(req))
952 } else {
953 None
954 }
955 }
956 DataMessage::Response(res) => {
957 if res.h.len() == 32 {
958 Some(DataMessage::Response(res))
959 } else {
960 None
961 }
962 }
963 DataMessage::QuoteRequest(req) => {
964 if req.h.len() == 32 {
965 Some(DataMessage::QuoteRequest(req))
966 } else {
967 None
968 }
969 }
970 DataMessage::QuoteResponse(res) => {
971 if res.h.len() == 32 {
972 Some(DataMessage::QuoteResponse(res))
973 } else {
974 None
975 }
976 }
977 DataMessage::Payment(req) => {
978 if req.h.len() == 32 {
979 Some(DataMessage::Payment(req))
980 } else {
981 None
982 }
983 }
984 DataMessage::PaymentAck(res) => {
985 if res.h.len() == 32 {
986 Some(DataMessage::PaymentAck(res))
987 } else {
988 None
989 }
990 }
991 DataMessage::Chunk(chunk) => {
992 if chunk.h.len() == 32 {
993 Some(DataMessage::Chunk(chunk))
994 } else {
995 None
996 }
997 }
998 DataMessage::PeerHints(_)
999 | DataMessage::PubsubInterest(_)
1000 | DataMessage::PubsubFrame(_)
1001 | DataMessage::PubsubInventory(_)
1002 | DataMessage::PubsubWant(_) => Some(msg),
1003 }
1004}
1005
1006async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
1007 let hash_hex = hex_encode(&res.h);
1008 let data = res.d.clone();
1009 let hash_bytes = res.h.clone();
1010
1011 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
1012 let mut seen = HashSet::new();
1013 {
1014 let pending = state.ws_relay.pending.lock().await;
1015 for ((peer_id, request_id), p) in pending.iter() {
1016 if *peer_id != client_id {
1017 continue;
1018 }
1019 if p.hash != hash_hex {
1020 continue;
1021 }
1022 if seen.insert((p.origin_id, *request_id)) {
1023 responses.push((p.origin_id, *request_id, p.origin_protocol));
1024 }
1025 }
1026 }
1027
1028 if responses.is_empty() {
1029 return;
1030 }
1031
1032 for (origin_id, request_id, protocol) in &responses {
1033 match protocol {
1034 WsProtocol::HashtreeJson => {
1035 send_json(
1036 state,
1037 *origin_id,
1038 WsResponse {
1039 kind: "res",
1040 id: *request_id,
1041 hash: hash_hex.clone(),
1042 found: true,
1043 },
1044 )
1045 .await;
1046 send_binary(state, *origin_id, *request_id, data.clone()).await;
1047 }
1048 WsProtocol::HashtreeMsgpack => {
1049 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1050 }
1051 WsProtocol::Unknown => {}
1052 }
1053 }
1054
1055 let completed: HashSet<(u64, u32)> = responses
1056 .into_iter()
1057 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1058 .collect();
1059 let mut pending = state.ws_relay.pending.lock().await;
1060 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1061}
1062
1063async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1064 if let Ok(text) = serde_json::to_string(&response) {
1065 send_to_client(state, client_id, Message::Text(text)).await;
1066 }
1067}
1068
1069async fn send_msgpack_request(
1070 state: &AppState,
1071 client_id: u64,
1072 hash: &[u8],
1073) -> Result<(), rmp_serde::encode::Error> {
1074 let req = DataRequest {
1075 h: hash.to_vec(),
1076 htl: MAX_HTL,
1077 q: None,
1078 };
1079 let wire = encode_request(&req)?;
1080 send_to_client(state, client_id, Message::Binary(wire)).await;
1081 Ok(())
1082}
1083
1084async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1085 let res = DataResponse {
1086 h: hash.to_vec(),
1087 d: data.to_vec(),
1088 i: None,
1089 n: None,
1090 };
1091 if let Ok(wire) = encode_response(&res) {
1092 send_to_client(state, client_id, Message::Binary(wire)).await;
1093 }
1094}
1095
1096async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1097 let mut packet = Vec::with_capacity(4 + payload.len());
1098 packet.extend_from_slice(&request_id.to_le_bytes());
1099 packet.extend_from_slice(&payload);
1100 send_to_client(state, client_id, Message::Binary(packet)).await;
1101}
1102
1103async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1104 let sender = {
1105 let clients = state.ws_relay.clients.lock().await;
1106 clients.get(&client_id).cloned()
1107 };
1108 if let Some(tx) = sender {
1109 let _ = tx.send(msg);
1110 }
1111}
1112
1113async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1114 let mut protocols = state.ws_relay.client_protocols.lock().await;
1115 protocols.insert(client_id, protocol);
1116}
1117
1118#[cfg(test)]
1119mod tests {
1120 use super::*;
1121 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1122 use anyhow::Result;
1123 use futures::{SinkExt, StreamExt};
1124 use nostr::secp256k1::schnorr::Signature;
1125 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1126 use std::collections::HashSet;
1127 use std::sync::Arc;
1128 use tempfile::TempDir;
1129 use tokio::net::TcpListener;
1130 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1131
1132 #[test]
1133 fn parse_ws_text_message_detects_nostr_req() {
1134 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1135 match parse_ws_text_message(msg) {
1136 Some(WsTextMessage::Nostr(_)) => {}
1137 other => panic!("expected Nostr message, got {:?}", other),
1138 }
1139 }
1140
1141 #[test]
1142 fn parse_ws_text_message_detects_hashtree_request() {
1143 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1144 match parse_ws_text_message(msg) {
1145 Some(WsTextMessage::Hashtree(_)) => {}
1146 other => panic!("expected Hashtree message, got {:?}", other),
1147 }
1148 }
1149
1150 #[test]
1151 fn nostr_replies_for_req_is_eose() {
1152 let sub = SubscriptionId::new("sub-1");
1153 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1154 let replies = nostr_responses_for(&msg);
1155 assert_eq!(replies.len(), 1);
1156 match &replies[0] {
1157 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id.as_ref(), &sub),
1158 other => panic!("expected EOSE, got {:?}", other),
1159 }
1160 }
1161
1162 #[test]
1163 fn nostr_replies_for_event_ok() {
1164 let keys = Keys::generate();
1165 let event = EventBuilder::new(Kind::TextNote, "hello")
1166 .sign_with_keys(&keys)
1167 .unwrap();
1168 let msg = NostrClientMessage::event(event.clone());
1169 let replies = nostr_responses_for(&msg);
1170 assert_eq!(replies.len(), 1);
1171 match &replies[0] {
1172 NostrRelayMessage::Ok {
1173 event_id, status, ..
1174 } => {
1175 assert_eq!(event_id, &event.id);
1176 assert!(*status);
1177 }
1178 other => panic!("expected OK, got {:?}", other),
1179 }
1180 }
1181
1182 #[test]
1183 fn nostr_replies_for_invalid_event_is_not_ok() {
1184 let keys = Keys::generate();
1185 let mut event = EventBuilder::new(Kind::TextNote, "hello")
1186 .sign_with_keys(&keys)
1187 .unwrap();
1188 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1189 let msg = NostrClientMessage::event(event);
1190 let replies = nostr_responses_for(&msg);
1191 assert_eq!(replies.len(), 1);
1192 match &replies[0] {
1193 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1194 other => panic!("expected OK=false, got {:?}", other),
1195 }
1196 }
1197
1198 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1199 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1200 let addr = listener.local_addr().expect("relay addr");
1201 tokio::spawn(async move {
1202 let (stream, _) = listener.accept().await.expect("accept relay");
1203 let ws = accept_async(stream).await.expect("accept websocket");
1204 let (mut write, mut read) = ws.split();
1205
1206 while let Some(Ok(message)) = read.next().await {
1207 let TungsteniteMessage::Text(text) = message else {
1208 continue;
1209 };
1210 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1211 continue;
1212 };
1213 if let NostrClientMessage::Req {
1214 subscription_id,
1215 filters,
1216 } = parsed
1217 {
1218 let subscription_id = subscription_id.into_owned();
1219 let filters = filters
1220 .into_iter()
1221 .map(|filter| filter.into_owned())
1222 .collect::<Vec<_>>();
1223 for event in events.iter().filter(|event| {
1224 filters
1225 .iter()
1226 .any(|filter| filter.match_event(event, Default::default()))
1227 }) {
1228 let _ = write
1229 .send(TungsteniteMessage::Text(
1230 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1231 .as_json()
1232 .into(),
1233 ))
1234 .await;
1235 }
1236 let _ = write
1237 .send(TungsteniteMessage::Text(
1238 NostrRelayMessage::eose(subscription_id).as_json().into(),
1239 ))
1240 .await;
1241 }
1242 }
1243 });
1244 format!("ws://{}", addr)
1245 }
1246
/// Builds an `AppState` for tests, backed by a fresh `HashtreeStore` rooted at
/// `tmp.path()`.
///
/// `relay` becomes the local Nostr relay (`nostr_relay: Some(relay)`), and
/// `vec![relay_url]` becomes the upstream Nostr relay URL list.
/// - No auth, WebRTC peers, FIPS transport or social graph are attached
///   (those fields are `None`).
/// - Public writes and public plaintext reads are enabled.
/// - The lookup caches and in-flight maps are freshly created.
///
/// NOTE(review): several tests pass `String::new()` as `relay_url`. That yields
/// `nostr_relay_urls == vec![""]`, not an empty list. Confirm the upstream proxy
/// treats an empty URL as unreachable / "no upstream".
///
/// # Errors
///
/// Propagates any error from `HashtreeStore::with_options`.
fn test_app_state(
    tmp: &TempDir,
    relay: Arc<NostrRelay>,
    relay_url: String,
) -> Result<AppState> {
    // NOTE(review): the third argument is presumably a 128 MiB size limit;
    // confirm against `HashtreeStore::with_options`.
    let store = Arc::new(crate::storage::HashtreeStore::with_options(
        tmp.path(),
        None,
        128 * 1024 * 1024,
    )?);
    Ok(AppState {
        store,
        auth: None,
        // Fixed start time rather than the wall clock.
        daemon_started_at: 1_700_000_000,
        peer_mode: crate::config::ServerMode::Normal,
        hash_get_enabled: true,
        http_webrtc_fetch: true,
        webrtc_peers: None,
        fips_transport: None,
        fetch_from_fips_peers: true,
        ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
        max_upload_bytes: 5 * 1024 * 1024,
        public_writes: true,
        public_plaintext_reads: true,
        require_random_untrusted_ingest: false,
        optimistic_blossom_uploads: false,
        optimistic_upload_queue_bytes: 512 * 1024 * 1024,
        // Permit count is the same literal as `optimistic_upload_queue_bytes`.
        optimistic_upload_queue: Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)),
        allowed_pubkeys: HashSet::new(),
        upstream_blossom: Vec::new(),
        social_graph: None,
        social_graph_store: None,
        social_graph_root: None,
        socialgraph_snapshot_public: false,
        nostr_relay: Some(relay),
        // Single upstream relay; see the NOTE(review) above about empty URLs.
        nostr_relay_urls: vec![relay_url],
        tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
        inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
            std::collections::HashMap::new(),
        )),
        inflight_blob_reads: Arc::new(
            tokio::sync::Mutex::new(std::collections::HashMap::new()),
        ),
        blob_cache: Arc::new(crate::blob_cache::BlobCache::for_tests()),
        directory_listing_cache: Arc::new(std::sync::Mutex::new(
            super::super::auth::new_lookup_cache(),
        )),
        resolved_path_cache: Arc::new(std::sync::Mutex::new(
            super::super::auth::new_lookup_cache(),
        )),
        thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
            super::super::auth::new_lookup_cache(),
        )),
        cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
    })
}
1303
1304 #[tokio::test]
1305 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1306 let tmp = TempDir::new()?;
1307 let graph_store = {
1308 let _guard = crate::socialgraph::test_lock();
1309 crate::socialgraph::open_social_graph_store_with_mapsize(
1310 tmp.path(),
1311 Some(128 * 1024 * 1024),
1312 )?
1313 };
1314 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1315 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1316 Arc::clone(&backend),
1317 0,
1318 HashSet::new(),
1319 ));
1320
1321 let keys = Keys::generate();
1322 let relay = Arc::new(NostrRelay::new(
1323 Arc::clone(&backend),
1324 tmp.path().to_path_buf(),
1325 HashSet::from([keys.public_key().to_hex()]),
1326 Some(access),
1327 NostrRelayConfig {
1328 spambox_db_max_bytes: 0,
1329 ..Default::default()
1330 },
1331 )?);
1332
1333 let event = EventBuilder::new(Kind::from(30078_u16), "")
1334 .tags([
1335 nostr::Tag::parse(["d", "videos/Test"]).expect("d tag"),
1336 nostr::Tag::parse(["l", "hashtree"]).expect("label tag"),
1337 ])
1338 .sign_with_keys(&keys)?;
1339
1340 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1341 let filter = Filter::new()
1342 .authors(vec![event.pubkey])
1343 .kinds(vec![event.kind]);
1344 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1345 let client_id = 7_u64;
1346 let (tx, mut rx) = mpsc::unbounded_channel();
1347 state.ws_relay.clients.lock().await.insert(client_id, tx);
1348 let subscription_id = SubscriptionId::new("sub-1");
1349
1350 start_upstream_nostr_subscription(
1351 &state,
1352 client_id,
1353 subscription_id.clone(),
1354 vec![filter.clone()],
1355 )
1356 .await;
1357
1358 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1359 .await?
1360 .expect("forwarded upstream event");
1361 let Message::Text(text) = forwarded else {
1362 panic!("expected text event");
1363 };
1364 match NostrRelayMessage::from_json(text.as_str())? {
1365 NostrRelayMessage::Event {
1366 subscription_id: sid,
1367 event: forwarded_event,
1368 } => {
1369 assert_eq!(sid.as_ref(), &subscription_id);
1370 assert_eq!(forwarded_event.id, event.id);
1371 }
1372 other => panic!("expected forwarded EVENT, got {:?}", other),
1373 }
1374
1375 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1376 .await?
1377 .expect("forwarded upstream eose");
1378 let Message::Text(eose_text) = eose else {
1379 panic!("expected text eose");
1380 };
1381 match NostrRelayMessage::from_json(eose_text.as_str())? {
1382 NostrRelayMessage::EndOfStoredEvents(sid) => {
1383 assert_eq!(sid.as_ref(), &subscription_id);
1384 }
1385 other => panic!("expected forwarded EOSE, got {:?}", other),
1386 }
1387
1388 let events = relay.query_events(&filter, 10).await;
1389 assert_eq!(events.len(), 1);
1390 assert_eq!(events[0].id, event.id);
1391
1392 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1393 assert!(state
1394 .ws_relay
1395 .upstream_nostr_subscriptions
1396 .lock()
1397 .await
1398 .is_empty());
1399 Ok(())
1400 }
1401
1402 #[tokio::test]
1403 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1404 let tmp = TempDir::new()?;
1405 let graph_store = {
1406 let _guard = crate::socialgraph::test_lock();
1407 crate::socialgraph::open_social_graph_store_with_mapsize(
1408 tmp.path(),
1409 Some(128 * 1024 * 1024),
1410 )?
1411 };
1412 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1413 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1414 Arc::clone(&backend),
1415 0,
1416 HashSet::new(),
1417 ));
1418
1419 let keys = Keys::generate();
1420 let relay = Arc::new(NostrRelay::new(
1421 Arc::clone(&backend),
1422 tmp.path().to_path_buf(),
1423 HashSet::from([keys.public_key().to_hex()]),
1424 Some(access),
1425 NostrRelayConfig {
1426 spambox_db_max_bytes: 0,
1427 ..Default::default()
1428 },
1429 )?);
1430
1431 let event = EventBuilder::new(Kind::from(30078_u16), "")
1432 .tags([
1433 nostr::Tag::parse(["d", "videos/Test"]).expect("d tag"),
1434 nostr::Tag::parse(["l", "hashtree"]).expect("label tag"),
1435 ])
1436 .sign_with_keys(&keys)?;
1437
1438 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1439 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1440 let client_id = 11_u64;
1441 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1442 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1443 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1444 relay.register_client(client_id, relay_tx, None).await;
1445
1446 let request = NostrClientMessage::req(
1447 SubscriptionId::new("feed"),
1448 vec![Filter::new()
1449 .authors(vec![event.pubkey])
1450 .kinds(vec![event.kind])],
1451 )
1452 .as_json();
1453
1454 handle_message(client_id, Message::Text(request.into()), &state).await;
1455
1456 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1457 .await?
1458 .expect("first forwarded message");
1459 let Message::Text(first_text) = first else {
1460 panic!("expected text event");
1461 };
1462 match NostrRelayMessage::from_json(first_text.as_str())? {
1463 NostrRelayMessage::Event {
1464 event: forwarded_event,
1465 ..
1466 } => {
1467 assert_eq!(forwarded_event.id, event.id);
1468 }
1469 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1470 }
1471
1472 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1473 .await?
1474 .expect("second forwarded message");
1475 let Message::Text(second_text) = second else {
1476 panic!("expected text eose");
1477 };
1478 match NostrRelayMessage::from_json(second_text.as_str())? {
1479 NostrRelayMessage::EndOfStoredEvents(sid) => {
1480 assert_eq!(sid.as_ref(), &SubscriptionId::new("feed"));
1481 }
1482 other => panic!("expected aggregated EOSE, got {:?}", other),
1483 }
1484
1485 Ok(())
1486 }
1487
1488 #[tokio::test]
1489 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1490 let tmp = TempDir::new()?;
1491 let graph_store = {
1492 let _guard = crate::socialgraph::test_lock();
1493 crate::socialgraph::open_social_graph_store_with_mapsize(
1494 tmp.path(),
1495 Some(128 * 1024 * 1024),
1496 )?
1497 };
1498 let author_keys = Keys::generate();
1499 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1500 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1501 Arc::clone(&backend),
1502 0,
1503 HashSet::from([author_keys.public_key().to_hex()]),
1504 ));
1505 let relay = Arc::new(NostrRelay::new(
1506 Arc::clone(&backend),
1507 tmp.path().to_path_buf(),
1508 HashSet::from([author_keys.public_key().to_hex()]),
1509 Some(access),
1510 NostrRelayConfig {
1511 spambox_db_max_bytes: 0,
1512 ..Default::default()
1513 },
1514 )?);
1515
1516 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1517 let listener = TcpListener::bind("127.0.0.1:0").await?;
1518 let addr = listener.local_addr()?;
1519 let client_pubkey = author_keys.public_key().to_hex();
1520 let app = axum::Router::new().route(
1521 "/ws",
1522 axum::routing::get({
1523 let state = state.clone();
1524 let client_pubkey = client_pubkey.clone();
1525 move |ws: WebSocketUpgrade| {
1526 let state = state.clone();
1527 let client_pubkey = client_pubkey.clone();
1528 async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1529 }
1530 }),
1531 );
1532 tokio::spawn(async move {
1533 let _ = axum::serve(listener, app).await;
1534 });
1535
1536 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1537 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack")
1538 .sign_with_keys(&author_keys)?;
1539 socket
1540 .send(TungsteniteMessage::Text(
1541 NostrClientMessage::event(event.clone()).as_json().into(),
1542 ))
1543 .await?;
1544
1545 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1546 .await?
1547 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1548 let TungsteniteMessage::Text(text) = reply else {
1549 anyhow::bail!("expected text publish ack");
1550 };
1551
1552 match NostrRelayMessage::from_json(text.as_str())? {
1553 NostrRelayMessage::Ok {
1554 event_id, status, ..
1555 } => {
1556 assert_eq!(event_id, event.id);
1557 assert!(status);
1558 }
1559 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1560 }
1561
1562 let stored = relay
1563 .query_events(
1564 &Filter::new()
1565 .authors(vec![event.pubkey])
1566 .kinds(vec![event.kind]),
1567 10,
1568 )
1569 .await;
1570 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1571 Ok(())
1572 }
1573
1574 #[tokio::test]
1575 async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1576 let tmp = TempDir::new()?;
1577 let graph_store = {
1578 let _guard = crate::socialgraph::test_lock();
1579 crate::socialgraph::open_social_graph_store_with_mapsize(
1580 tmp.path(),
1581 Some(128 * 1024 * 1024),
1582 )?
1583 };
1584 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1585 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1586 Arc::clone(&backend),
1587 0,
1588 HashSet::new(),
1589 ));
1590 let relay = Arc::new(NostrRelay::new(
1591 Arc::clone(&backend),
1592 tmp.path().to_path_buf(),
1593 HashSet::new(),
1594 Some(access),
1595 NostrRelayConfig {
1596 spambox_db_max_bytes: 0,
1597 spambox_max_reqs_per_min: 1,
1598 ..Default::default()
1599 },
1600 )?);
1601
1602 let state = test_app_state(&tmp, relay, String::new())?;
1603 let listener = TcpListener::bind("127.0.0.1:0").await?;
1604 let addr = listener.local_addr()?;
1605 let app = axum::Router::new().route(
1606 "/ws",
1607 axum::routing::get({
1608 let state = state.clone();
1609 move |ws: WebSocketUpgrade| {
1610 let state = state.clone();
1611 async move { ws_data_with_client_pubkey(state, ws, None) }
1612 }
1613 }),
1614 );
1615 tokio::spawn(async move {
1616 let _ = axum::serve(listener, app).await;
1617 });
1618
1619 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1620 socket
1621 .send(TungsteniteMessage::Text(
1622 NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1623 .as_json()
1624 .into(),
1625 ))
1626 .await?;
1627
1628 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1629 .await?
1630 .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1631 let TungsteniteMessage::Text(first_text) = first else {
1632 anyhow::bail!("expected text EOSE reply");
1633 };
1634 match NostrRelayMessage::from_json(first_text.as_str())? {
1635 NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1636 assert_eq!(subscription_id.as_ref(), &SubscriptionId::new("sub-1"));
1637 }
1638 other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1639 }
1640
1641 socket
1642 .send(TungsteniteMessage::Text(
1643 NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1644 .as_json()
1645 .into(),
1646 ))
1647 .await?;
1648
1649 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1650 .await?
1651 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1652 let TungsteniteMessage::Text(second_text) = second else {
1653 anyhow::bail!("expected text CLOSED reply");
1654 };
1655 let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1656 assert_eq!(
1657 second_value,
1658 serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1659 );
1660
1661 Ok(())
1662 }
1663
1664 #[tokio::test]
1665 async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1666 let tmp = TempDir::new()?;
1667 let graph_store = {
1668 let _guard = crate::socialgraph::test_lock();
1669 crate::socialgraph::open_social_graph_store_with_mapsize(
1670 tmp.path(),
1671 Some(128 * 1024 * 1024),
1672 )?
1673 };
1674 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1675 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1676 Arc::clone(&backend),
1677 0,
1678 HashSet::new(),
1679 ));
1680 let relay = Arc::new(NostrRelay::new(
1681 Arc::clone(&backend),
1682 tmp.path().to_path_buf(),
1683 HashSet::new(),
1684 Some(access),
1685 NostrRelayConfig {
1686 spambox_db_max_bytes: 0,
1687 spambox_max_events_per_min: 1,
1688 ..Default::default()
1689 },
1690 )?);
1691
1692 let state = test_app_state(&tmp, relay, String::new())?;
1693 let listener = TcpListener::bind("127.0.0.1:0").await?;
1694 let addr = listener.local_addr()?;
1695 let app = axum::Router::new().route(
1696 "/ws",
1697 axum::routing::get({
1698 let state = state.clone();
1699 move |ws: WebSocketUpgrade| {
1700 let state = state.clone();
1701 async move { ws_data_with_client_pubkey(state, ws, None) }
1702 }
1703 }),
1704 );
1705 tokio::spawn(async move {
1706 let _ = axum::serve(listener, app).await;
1707 });
1708
1709 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1710 let author_keys = Keys::generate();
1711 let event_a =
1712 EventBuilder::new(Kind::TextNote, "spambox-a").sign_with_keys(&author_keys)?;
1713 let event_b =
1714 EventBuilder::new(Kind::TextNote, "spambox-b").sign_with_keys(&author_keys)?;
1715
1716 socket
1717 .send(TungsteniteMessage::Text(
1718 NostrClientMessage::event(event_a.clone()).as_json().into(),
1719 ))
1720 .await?;
1721
1722 let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1723 .await?
1724 .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1725 let TungsteniteMessage::Text(first_text) = first else {
1726 anyhow::bail!("expected text publish ack");
1727 };
1728 match NostrRelayMessage::from_json(first_text.as_str())? {
1729 NostrRelayMessage::Ok {
1730 event_id,
1731 status,
1732 message,
1733 } => {
1734 assert_eq!(event_id, event_a.id);
1735 assert!(status);
1736 assert_eq!(message, "spambox");
1737 }
1738 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1739 }
1740
1741 socket
1742 .send(TungsteniteMessage::Text(
1743 NostrClientMessage::event(event_b.clone()).as_json().into(),
1744 ))
1745 .await?;
1746
1747 let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1748 .await?
1749 .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1750 let TungsteniteMessage::Text(second_text) = second else {
1751 anyhow::bail!("expected text publish ack");
1752 };
1753 match NostrRelayMessage::from_json(second_text.as_str())? {
1754 NostrRelayMessage::Ok {
1755 event_id,
1756 status,
1757 message,
1758 } => {
1759 assert_eq!(event_id, event_b.id);
1760 assert!(!status);
1761 assert_eq!(message, "rate limited");
1762 }
1763 other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1764 }
1765
1766 Ok(())
1767 }
1768}