1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::webrtc::types::{
21 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
22};
23use hex::encode as hex_encode;
24
25#[derive(Debug, Deserialize)]
26#[serde(tag = "type")]
27enum WsClientMessage {
28 #[serde(rename = "req")]
29 Request { id: u32, hash: String },
30 #[serde(rename = "res")]
31 Response { id: u32, hash: String, found: bool },
32}
33
34#[derive(Debug)]
35enum WsTextMessage {
36 Hashtree(WsClientMessage),
37 Nostr(NostrClientMessage),
38}
39
40#[derive(Debug, Deserialize, Serialize)]
41struct WsRequest {
42 #[serde(rename = "type")]
43 kind: String,
44 id: u32,
45 hash: String,
46}
47
48#[derive(Debug, Serialize)]
49struct WsResponse {
50 #[serde(rename = "type")]
51 kind: &'static str,
52 id: u32,
53 hash: String,
54 found: bool,
55}
56
57pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
58 ws_data_with_client_pubkey(state, ws, None)
59}
60
61pub fn ws_data_with_client_pubkey(
62 state: AppState,
63 ws: WebSocketUpgrade,
64 client_pubkey: Option<String>,
65) -> impl IntoResponse {
66 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
67}
68
69async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
70 let client_id = state
73 .nostr_relay
74 .as_ref()
75 .map(|relay| relay.next_client_id())
76 .unwrap_or_else(|| state.ws_relay.next_id());
77 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
78
79 {
80 let mut clients = state.ws_relay.clients.lock().await;
81 clients.insert(client_id, tx);
82 }
83 {
84 let mut protocols = state.ws_relay.client_protocols.lock().await;
85 protocols.insert(client_id, WsProtocol::HashtreeJson);
86 }
87
88 let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
89 let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
90 relay
91 .register_client(client_id, nostr_tx, client_pubkey.clone())
92 .await;
93 Some(nostr_rx)
94 } else {
95 None
96 };
97
98 let (mut sender, mut receiver) = socket.split();
99 loop {
100 tokio::select! {
101 maybe_msg = rx.recv() => {
102 let Some(msg) = maybe_msg else {
103 break;
104 };
105 if sender.send(msg).await.is_err() {
106 break;
107 }
108 }
109 maybe_text = async {
110 match &mut nostr_rx {
111 Some(rx) => rx.recv().await,
112 None => std::future::pending().await,
113 }
114 } => {
115 let Some(text) = maybe_text else {
116 nostr_rx = None;
117 continue;
118 };
119 if sender.send(Message::Text(text)).await.is_err() {
120 break;
121 }
122 }
123 maybe_incoming = receiver.next() => {
124 match maybe_incoming {
125 Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
126 Some(Err(_)) | None => break,
127 }
128 }
129 }
130 }
131
132 close_all_upstream_nostr_subscriptions(&state, client_id).await;
133
134 {
135 let mut clients = state.ws_relay.clients.lock().await;
136 clients.remove(&client_id);
137 }
138 {
139 let mut protocols = state.ws_relay.client_protocols.lock().await;
140 protocols.remove(&client_id);
141 }
142 {
143 let mut pending = state.ws_relay.pending.lock().await;
144 pending.retain(|(peer_id, _), _| *peer_id != client_id);
145 }
146
147 if let Some(relay) = &state.nostr_relay {
148 relay.unregister_client(client_id).await;
149 }
150}
151
152fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
153 let trimmed = text.trim_start();
154 if trimmed.starts_with('[') {
155 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
156 return Some(WsTextMessage::Nostr(msg));
157 }
158 }
159
160 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
161 return Some(WsTextMessage::Hashtree(msg));
162 }
163
164 None
165}
166async fn close_upstream_nostr_subscription(
167 state: &AppState,
168 client_id: u64,
169 subscription_id: &SubscriptionId,
170) {
171 let key = (client_id, subscription_id.to_string());
172 let subscription = {
173 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
174 subscriptions.remove(&key)
175 };
176 if let Some(subscription) = subscription {
177 let _ = subscription.close_tx.send(true);
178 for task in subscription.tasks {
179 task.abort();
180 }
181 }
182 state
183 .ws_relay
184 .upstream_pending_eose
185 .lock()
186 .await
187 .remove(&key);
188 state
189 .ws_relay
190 .upstream_seen_events
191 .lock()
192 .await
193 .remove(&key);
194}
195
196async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
197 let keys = {
198 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
199 subscriptions
200 .keys()
201 .filter(|(id, _)| *id == client_id)
202 .cloned()
203 .collect::<Vec<_>>()
204 };
205 for (_, sub_id) in keys {
206 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
207 }
208}
209
210async fn forward_upstream_nostr_message(
211 state: &AppState,
212 client_id: u64,
213 subscription_id: &SubscriptionId,
214 text: &str,
215) {
216 let Ok(message) = NostrRelayMessage::from_json(text) else {
217 return;
218 };
219
220 match message {
221 NostrRelayMessage::Event {
222 subscription_id: sid,
223 event,
224 } if sid == *subscription_id => {
225 let event = *event;
226 let key = (client_id, subscription_id.to_string());
227 let event_id = event.id.to_hex();
228 let inserted = {
229 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
230 seen_events.entry(key).or_default().insert(event_id)
231 };
232 if !inserted {
233 return;
234 }
235 if let Some(relay) = &state.nostr_relay {
236 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
237 }
238 send_nostr(
239 state,
240 client_id,
241 NostrRelayMessage::event(subscription_id.clone(), event),
242 )
243 .await;
244 }
245 NostrRelayMessage::Closed {
246 subscription_id: sid,
247 message,
248 } if sid == *subscription_id => {
249 send_nostr(
250 state,
251 client_id,
252 NostrRelayMessage::closed(subscription_id.clone(), message),
253 )
254 .await;
255 }
256 _ => {}
257 }
258}
259
260async fn mark_upstream_nostr_relay_complete(
261 state: &AppState,
262 client_id: u64,
263 subscription_id: &SubscriptionId,
264) {
265 let key = (client_id, subscription_id.to_string());
266 let should_send_eose = {
267 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
268 let Some(remaining) = pending.get_mut(&key) else {
269 return;
270 };
271 if *remaining > 0 {
272 *remaining -= 1;
273 }
274 if *remaining == 0 {
275 pending.remove(&key);
276 true
277 } else {
278 false
279 }
280 };
281
282 if should_send_eose {
283 send_nostr(
284 state,
285 client_id,
286 NostrRelayMessage::eose(subscription_id.clone()),
287 )
288 .await;
289 }
290}
291
292async fn run_upstream_nostr_subscription(
293 state: AppState,
294 client_id: u64,
295 relay_url: String,
296 subscription_id: SubscriptionId,
297 filters: Vec<NostrFilter>,
298 mut close_rx: watch::Receiver<bool>,
299) {
300 let mut relay_complete = false;
301 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
302 tracing::warn!(
303 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
304 client_id,
305 subscription_id,
306 relay_url,
307 );
308 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
309 return;
310 };
311 let (mut write, mut read) = socket.split();
312 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
313 state
314 .ws_relay
315 .note_upstream_relay_send(request.as_bytes().len());
316 if write
317 .send(TungsteniteMessage::Text(request.into()))
318 .await
319 .is_err()
320 {
321 tracing::warn!(
322 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
323 client_id,
324 subscription_id,
325 relay_url,
326 );
327 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
328 return;
329 }
330
331 loop {
332 tokio::select! {
333 _ = close_rx.changed() => {
334 if *close_rx.borrow() {
335 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
336 state
337 .ws_relay
338 .note_upstream_relay_send(close.as_bytes().len());
339 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
340 let _ = write.close().await;
341 break;
342 }
343 }
344 message = read.next() => {
345 match message {
346 Some(Ok(TungsteniteMessage::Text(text))) => {
347 state
348 .ws_relay
349 .note_upstream_relay_receive(text.as_bytes().len());
350 if matches!(
351 NostrRelayMessage::from_json(text.as_str()),
352 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
353 ) {
354 if !relay_complete {
355 relay_complete = true;
356 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
357 }
358 continue;
359 }
360 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
361 }
362 Some(Ok(TungsteniteMessage::Ping(payload))) => {
363 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
364 }
365 Some(Ok(TungsteniteMessage::Close(_))) | None => {
366 if !relay_complete {
367 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
368 }
369 break;
370 }
371 Some(Err(_)) => {
372 if !relay_complete {
373 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
374 }
375 break;
376 }
377 _ => {}
378 }
379 }
380 }
381 }
382}
383
384async fn start_upstream_nostr_subscription(
385 state: &AppState,
386 client_id: u64,
387 subscription_id: SubscriptionId,
388 filters: Vec<NostrFilter>,
389) -> usize {
390 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
391 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
392 tracing::info!(
393 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
394 client_id,
395 subscription_id,
396 state.nostr_relay_urls.len(),
397 filters.len(),
398 );
399 return 0;
400 }
401
402 let mut relay_urls = Vec::new();
403 let mut seen = HashSet::new();
404 for relay in &state.nostr_relay_urls {
405 let relay = relay.trim();
406 if relay.is_empty() || !seen.insert(relay.to_string()) {
407 continue;
408 }
409 relay_urls.push(relay.to_string());
410 }
411 if relay_urls.is_empty() {
412 tracing::info!(
413 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
414 client_id,
415 subscription_id,
416 );
417 return 0;
418 }
419
420 tracing::info!(
421 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
422 client_id,
423 subscription_id,
424 relay_urls.len(),
425 );
426
427 let key = (client_id, subscription_id.to_string());
428 state
429 .ws_relay
430 .upstream_seen_events
431 .lock()
432 .await
433 .insert(key.clone(), HashSet::new());
434 state
435 .ws_relay
436 .upstream_pending_eose
437 .lock()
438 .await
439 .insert(key.clone(), relay_urls.len());
440
441 let (close_tx, close_rx) = watch::channel(false);
442 let mut tasks = Vec::new();
443 for relay_url in &relay_urls {
444 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
445 state.clone(),
446 client_id,
447 relay_url.clone(),
448 subscription_id.clone(),
449 filters.clone(),
450 close_rx.clone(),
451 )));
452 }
453
454 state
455 .ws_relay
456 .upstream_nostr_subscriptions
457 .lock()
458 .await
459 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
460 relay_urls.len()
461}
462
463async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
464 match msg {
465 Message::Text(text) => {
466 if let Some(msg) = parse_ws_text_message(&text) {
467 match msg {
468 WsTextMessage::Hashtree(msg) => {
469 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
470 match msg {
471 WsClientMessage::Request { id, hash } => {
472 handle_request(
473 client_id,
474 id,
475 hash,
476 WsProtocol::HashtreeJson,
477 state,
478 )
479 .await;
480 }
481 WsClientMessage::Response { id, hash, found } => {
482 handle_response(client_id, id, hash, found, state).await;
483 }
484 }
485 }
486 WsTextMessage::Nostr(msg) => {
487 if let Some(relay) = &state.nostr_relay {
488 match msg {
489 NostrClientMessage::Req {
490 subscription_id,
491 filters,
492 } => {
493 let local_events = match relay
494 .register_subscription_query(
495 client_id,
496 subscription_id.clone(),
497 filters.clone(),
498 )
499 .await
500 {
501 Ok(events) => events,
502 Err(message) => {
503 send_nostr(
504 state,
505 client_id,
506 NostrRelayMessage::closed(subscription_id, message),
507 )
508 .await;
509 return;
510 }
511 };
512
513 let upstream_relays = start_upstream_nostr_subscription(
514 state,
515 client_id,
516 subscription_id.clone(),
517 filters,
518 )
519 .await;
520 if upstream_relays > 0 {
521 let key = (client_id, subscription_id.to_string());
522 let mut seen_events =
523 state.ws_relay.upstream_seen_events.lock().await;
524 seen_events.entry(key).or_default().extend(
525 local_events.iter().map(|event| event.id.to_hex()),
526 );
527 }
528 for event in local_events {
529 send_nostr(
530 state,
531 client_id,
532 NostrRelayMessage::event(
533 subscription_id.clone(),
534 event,
535 ),
536 )
537 .await;
538 }
539 if upstream_relays == 0 {
540 send_nostr(
541 state,
542 client_id,
543 NostrRelayMessage::eose(subscription_id),
544 )
545 .await;
546 }
547 }
548 NostrClientMessage::Close(subscription_id) => {
549 close_upstream_nostr_subscription(
550 state,
551 client_id,
552 &subscription_id,
553 )
554 .await;
555 relay
556 .handle_client_message(
557 client_id,
558 NostrClientMessage::Close(subscription_id.clone()),
559 )
560 .await;
561 }
562 other => {
563 relay.handle_client_message(client_id, other).await;
564 }
565 }
566 } else {
567 handle_nostr_message(client_id, msg, state).await;
568 }
569 }
570 }
571 }
572 }
573 Message::Binary(data) => {
574 handle_binary(client_id, data, state).await;
575 }
576 Message::Close(_) => {}
577 _ => {}
578 }
579}
580
581async fn handle_request(
582 client_id: u64,
583 request_id: u32,
584 hash: String,
585 origin_protocol: WsProtocol,
586 state: &AppState,
587) {
588 let hash_hex = hash.to_lowercase();
589 let hash_bytes = match from_hex(&hash_hex) {
590 Ok(bytes) => bytes,
591 Err(_) => {
592 if origin_protocol == WsProtocol::HashtreeJson {
593 send_json(
594 state,
595 client_id,
596 WsResponse {
597 kind: "res",
598 id: request_id,
599 hash,
600 found: false,
601 },
602 )
603 .await;
604 }
605 return;
606 }
607 };
608
609 if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
610 match origin_protocol {
611 WsProtocol::HashtreeJson => {
612 send_json(
613 state,
614 client_id,
615 WsResponse {
616 kind: "res",
617 id: request_id,
618 hash: hash.clone(),
619 found: true,
620 },
621 )
622 .await;
623 send_binary(state, client_id, request_id, data).await;
624 }
625 WsProtocol::HashtreeMsgpack => {
626 send_msgpack_response(state, client_id, &hash_bytes, &data).await;
627 }
628 WsProtocol::Unknown => {}
629 }
630 return;
631 }
632
633 let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
634 let clients = state.ws_relay.clients.lock().await;
635 let protocols = state.ws_relay.client_protocols.lock().await;
636 clients
637 .iter()
638 .filter(|(id, _)| **id != client_id)
639 .filter_map(|(id, tx)| {
640 let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
641 match protocol {
642 WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
643 Some((*id, tx.clone(), protocol))
644 }
645 WsProtocol::Unknown => None,
646 }
647 })
648 .collect()
649 };
650
651 if peers.is_empty() {
652 if origin_protocol == WsProtocol::HashtreeJson {
653 send_json(
654 state,
655 client_id,
656 WsResponse {
657 kind: "res",
658 id: request_id,
659 hash,
660 found: false,
661 },
662 )
663 .await;
664 }
665 return;
666 }
667
668 {
669 let mut pending = state.ws_relay.pending.lock().await;
670 for (peer_id, _, _) in &peers {
671 pending.insert(
672 (*peer_id, request_id),
673 PendingRequest {
674 origin_id: client_id,
675 hash: hash.clone(),
676 found: false,
677 origin_protocol,
678 },
679 );
680 }
681 }
682
683 let request_text = serde_json::to_string(&WsRequest {
684 kind: "req".to_string(),
685 id: request_id,
686 hash: hash.clone(),
687 })
688 .unwrap_or_else(|_| String::new());
689 for (peer_id, tx, protocol) in peers {
690 match protocol {
691 WsProtocol::HashtreeMsgpack => {
692 let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
693 }
694 WsProtocol::HashtreeJson => {
695 let _ = tx.send(Message::Text(request_text.clone()));
696 }
697 WsProtocol::Unknown => {}
698 }
699 }
700
701 let timeout_state = state.clone();
702 let timeout_hash = hash.clone();
703 tokio::spawn(async move {
704 tokio::time::sleep(Duration::from_millis(1500)).await;
705 let mut pending = timeout_state.ws_relay.pending.lock().await;
706 let still_pending = pending
707 .iter()
708 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
709 let already_found = pending
710 .iter()
711 .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
712 if !still_pending || already_found {
713 return;
714 }
715 let origin_protocol = pending
716 .iter()
717 .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
718 .map(|(_, p)| p.origin_protocol)
719 .unwrap_or(WsProtocol::HashtreeJson);
720 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
721 drop(pending);
722 if origin_protocol == WsProtocol::HashtreeJson {
723 send_json(
724 &timeout_state,
725 client_id,
726 WsResponse {
727 kind: "res",
728 id: request_id,
729 hash: timeout_hash,
730 found: false,
731 },
732 )
733 .await;
734 }
735 });
736}
737
738async fn handle_response(
739 client_id: u64,
740 request_id: u32,
741 _hash: String,
742 found: bool,
743 state: &AppState,
744) {
745 let pending_entry = {
746 let pending = state.ws_relay.pending.lock().await;
747 pending
748 .get(&(client_id, request_id))
749 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
750 };
751
752 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
753 return;
754 };
755
756 if already_found && !found {
757 let mut pending = state.ws_relay.pending.lock().await;
758 pending.remove(&(client_id, request_id));
759 return;
760 }
761
762 if found {
763 let mut pending = state.ws_relay.pending.lock().await;
764 for ((_, id), p) in pending.iter_mut() {
765 if *id == request_id && p.origin_id == origin_id {
766 p.found = true;
767 }
768 }
769 drop(pending);
770 if origin_protocol == WsProtocol::HashtreeJson {
771 send_json(
772 state,
773 origin_id,
774 WsResponse {
775 kind: "res",
776 id: request_id,
777 hash: pending_hash,
778 found: true,
779 },
780 )
781 .await;
782 }
783 return;
784 }
785
786 let mut pending = state.ws_relay.pending.lock().await;
787 pending.remove(&(client_id, request_id));
788 let has_remaining = pending
789 .iter()
790 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
791 drop(pending);
792
793 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
794 send_json(
795 state,
796 origin_id,
797 WsResponse {
798 kind: "res",
799 id: request_id,
800 hash: pending_hash,
801 found: false,
802 },
803 )
804 .await;
805 }
806}
807
808async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
809 if let Some(msg) = parse_msgpack_message(&data) {
810 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
811 match msg {
812 DataMessage::Request(req) => {
813 let hash_hex = hex_encode(&req.h);
814 let request_id = state.ws_relay.next_request_id();
815 handle_request(
816 client_id,
817 request_id,
818 hash_hex,
819 WsProtocol::HashtreeMsgpack,
820 state,
821 )
822 .await;
823 }
824 DataMessage::Response(res) => {
825 handle_msgpack_response(client_id, res, state).await;
826 }
827 DataMessage::QuoteRequest(_)
828 | DataMessage::QuoteResponse(_)
829 | DataMessage::Payment(_)
830 | DataMessage::PaymentAck(_)
831 | DataMessage::Chunk(_) => {}
832 }
833 return;
834 }
835
836 if data.len() < 4 {
838 return;
839 }
840 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
841 let pending_entry = {
842 let pending = state.ws_relay.pending.lock().await;
843 pending
844 .get(&(client_id, request_id))
845 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
846 };
847 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
848 return;
849 };
850
851 match origin_protocol {
852 WsProtocol::HashtreeJson => {
853 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
854 }
855 WsProtocol::HashtreeMsgpack => {
856 let Ok(hash_bytes) = from_hex(&hash_hex) else {
857 return;
858 };
859 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
860 }
861 WsProtocol::Unknown => {}
862 }
863
864 let mut pending = state.ws_relay.pending.lock().await;
865 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
866}
867
868async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
869 let replies = nostr_responses_for(&msg);
870 for reply in replies {
871 send_nostr(state, client_id, reply).await;
872 }
873}
874
875fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
876 match msg {
877 NostrClientMessage::Event(event) => {
878 let ok = event.verify().is_ok();
879 let message = if ok { "" } else { "invalid: signature" };
880 vec![NostrRelayMessage::ok(event.id, ok, message)]
881 }
882 NostrClientMessage::Req {
883 subscription_id, ..
884 } => {
885 vec![NostrRelayMessage::eose(subscription_id.clone())]
886 }
887 NostrClientMessage::Count {
888 subscription_id, ..
889 } => {
890 vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
891 }
892 NostrClientMessage::Close(_) => Vec::new(),
893 NostrClientMessage::Auth(event) => {
894 let ok = event.verify().is_ok();
895 let message = if ok { "" } else { "invalid auth" };
896 vec![NostrRelayMessage::ok(event.id, ok, message)]
897 }
898 NostrClientMessage::NegOpen { .. }
899 | NostrClientMessage::NegMsg { .. }
900 | NostrClientMessage::NegClose { .. } => {
901 vec![NostrRelayMessage::notice("negentropy not supported")]
902 }
903 }
904}
905
906async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
907 let text = response.as_json();
908 send_to_client(state, client_id, Message::Text(text)).await;
909}
910
911fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
912 let msg = parse_message(data).ok()?;
913 match msg {
914 DataMessage::Request(req) => {
915 if req.h.len() == 32 {
916 Some(DataMessage::Request(req))
917 } else {
918 None
919 }
920 }
921 DataMessage::Response(res) => {
922 if res.h.len() == 32 {
923 Some(DataMessage::Response(res))
924 } else {
925 None
926 }
927 }
928 DataMessage::QuoteRequest(req) => {
929 if req.h.len() == 32 {
930 Some(DataMessage::QuoteRequest(req))
931 } else {
932 None
933 }
934 }
935 DataMessage::QuoteResponse(res) => {
936 if res.h.len() == 32 {
937 Some(DataMessage::QuoteResponse(res))
938 } else {
939 None
940 }
941 }
942 DataMessage::Payment(req) => {
943 if req.h.len() == 32 {
944 Some(DataMessage::Payment(req))
945 } else {
946 None
947 }
948 }
949 DataMessage::PaymentAck(res) => {
950 if res.h.len() == 32 {
951 Some(DataMessage::PaymentAck(res))
952 } else {
953 None
954 }
955 }
956 DataMessage::Chunk(chunk) => {
957 if chunk.h.len() == 32 {
958 Some(DataMessage::Chunk(chunk))
959 } else {
960 None
961 }
962 }
963 }
964}
965
966async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
967 let hash_hex = hex_encode(&res.h);
968 let data = res.d.clone();
969 let hash_bytes = res.h.clone();
970
971 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
972 let mut seen = HashSet::new();
973 {
974 let pending = state.ws_relay.pending.lock().await;
975 for ((peer_id, request_id), p) in pending.iter() {
976 if *peer_id != client_id {
977 continue;
978 }
979 if p.hash != hash_hex {
980 continue;
981 }
982 if seen.insert((p.origin_id, *request_id)) {
983 responses.push((p.origin_id, *request_id, p.origin_protocol));
984 }
985 }
986 }
987
988 if responses.is_empty() {
989 return;
990 }
991
992 for (origin_id, request_id, protocol) in &responses {
993 match protocol {
994 WsProtocol::HashtreeJson => {
995 send_json(
996 state,
997 *origin_id,
998 WsResponse {
999 kind: "res",
1000 id: *request_id,
1001 hash: hash_hex.clone(),
1002 found: true,
1003 },
1004 )
1005 .await;
1006 send_binary(state, *origin_id, *request_id, data.clone()).await;
1007 }
1008 WsProtocol::HashtreeMsgpack => {
1009 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1010 }
1011 WsProtocol::Unknown => {}
1012 }
1013 }
1014
1015 let completed: HashSet<(u64, u32)> = responses
1016 .into_iter()
1017 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1018 .collect();
1019 let mut pending = state.ws_relay.pending.lock().await;
1020 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1021}
1022
1023async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1024 if let Ok(text) = serde_json::to_string(&response) {
1025 send_to_client(state, client_id, Message::Text(text)).await;
1026 }
1027}
1028
1029async fn send_msgpack_request(
1030 state: &AppState,
1031 client_id: u64,
1032 hash: &[u8],
1033) -> Result<(), rmp_serde::encode::Error> {
1034 let req = DataRequest {
1035 h: hash.to_vec(),
1036 htl: MAX_HTL,
1037 q: None,
1038 };
1039 let wire = encode_request(&req)?;
1040 send_to_client(state, client_id, Message::Binary(wire)).await;
1041 Ok(())
1042}
1043
1044async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1045 let res = DataResponse {
1046 h: hash.to_vec(),
1047 d: data.to_vec(),
1048 i: None,
1049 n: None,
1050 };
1051 if let Ok(wire) = encode_response(&res) {
1052 send_to_client(state, client_id, Message::Binary(wire)).await;
1053 }
1054}
1055
1056async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1057 let mut packet = Vec::with_capacity(4 + payload.len());
1058 packet.extend_from_slice(&request_id.to_le_bytes());
1059 packet.extend_from_slice(&payload);
1060 send_to_client(state, client_id, Message::Binary(packet)).await;
1061}
1062
1063async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1064 let sender = {
1065 let clients = state.ws_relay.clients.lock().await;
1066 clients.get(&client_id).cloned()
1067 };
1068 if let Some(tx) = sender {
1069 let _ = tx.send(msg);
1070 }
1071}
1072
1073async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1074 let mut protocols = state.ws_relay.client_protocols.lock().await;
1075 protocols.insert(client_id, protocol);
1076}
1077
1078#[cfg(test)]
1079mod tests {
1080 use super::*;
1081 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1082 use anyhow::Result;
1083 use futures::{SinkExt, StreamExt};
1084 use nostr::secp256k1::schnorr::Signature;
1085 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1086 use std::collections::HashSet;
1087 use std::sync::Arc;
1088 use tempfile::TempDir;
1089 use tokio::net::TcpListener;
1090 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1091
1092 #[test]
1093 fn parse_ws_text_message_detects_nostr_req() {
1094 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1095 match parse_ws_text_message(msg) {
1096 Some(WsTextMessage::Nostr(_)) => {}
1097 other => panic!("expected Nostr message, got {:?}", other),
1098 }
1099 }
1100
1101 #[test]
1102 fn parse_ws_text_message_detects_hashtree_request() {
1103 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1104 match parse_ws_text_message(msg) {
1105 Some(WsTextMessage::Hashtree(_)) => {}
1106 other => panic!("expected Hashtree message, got {:?}", other),
1107 }
1108 }
1109
1110 #[test]
1111 fn nostr_replies_for_req_is_eose() {
1112 let sub = SubscriptionId::new("sub-1");
1113 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1114 let replies = nostr_responses_for(&msg);
1115 assert_eq!(replies.len(), 1);
1116 match &replies[0] {
1117 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1118 other => panic!("expected EOSE, got {:?}", other),
1119 }
1120 }
1121
1122 #[test]
1123 fn nostr_replies_for_event_ok() {
1124 let keys = Keys::generate();
1125 let event = EventBuilder::new(Kind::TextNote, "hello", [])
1126 .to_event(&keys)
1127 .unwrap();
1128 let msg = NostrClientMessage::event(event.clone());
1129 let replies = nostr_responses_for(&msg);
1130 assert_eq!(replies.len(), 1);
1131 match &replies[0] {
1132 NostrRelayMessage::Ok {
1133 event_id, status, ..
1134 } => {
1135 assert_eq!(event_id, &event.id);
1136 assert!(*status);
1137 }
1138 other => panic!("expected OK, got {:?}", other),
1139 }
1140 }
1141
1142 #[test]
1143 fn nostr_replies_for_invalid_event_is_not_ok() {
1144 let keys = Keys::generate();
1145 let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1146 .to_event(&keys)
1147 .unwrap();
1148 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1149 let msg = NostrClientMessage::event(event);
1150 let replies = nostr_responses_for(&msg);
1151 assert_eq!(replies.len(), 1);
1152 match &replies[0] {
1153 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1154 other => panic!("expected OK=false, got {:?}", other),
1155 }
1156 }
1157
1158 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1159 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1160 let addr = listener.local_addr().expect("relay addr");
1161 tokio::spawn(async move {
1162 let (stream, _) = listener.accept().await.expect("accept relay");
1163 let ws = accept_async(stream).await.expect("accept websocket");
1164 let (mut write, mut read) = ws.split();
1165
1166 while let Some(Ok(message)) = read.next().await {
1167 let TungsteniteMessage::Text(text) = message else {
1168 continue;
1169 };
1170 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1171 continue;
1172 };
1173 if let NostrClientMessage::Req {
1174 subscription_id,
1175 filters,
1176 } = parsed
1177 {
1178 for event in events
1179 .iter()
1180 .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1181 {
1182 let _ = write
1183 .send(TungsteniteMessage::Text(
1184 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1185 .as_json()
1186 .into(),
1187 ))
1188 .await;
1189 }
1190 let _ = write
1191 .send(TungsteniteMessage::Text(
1192 NostrRelayMessage::eose(subscription_id).as_json().into(),
1193 ))
1194 .await;
1195 }
1196 }
1197 });
1198 format!("ws://{}", addr)
1199 }
1200
1201 fn test_app_state(
1202 tmp: &TempDir,
1203 relay: Arc<NostrRelay>,
1204 relay_url: String,
1205 ) -> Result<AppState> {
1206 let store = Arc::new(crate::storage::HashtreeStore::with_options(
1207 tmp.path(),
1208 None,
1209 128 * 1024 * 1024,
1210 )?);
1211 Ok(AppState {
1212 store,
1213 auth: None,
1214 webrtc_peers: None,
1215 ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1216 max_upload_bytes: 5 * 1024 * 1024,
1217 public_writes: true,
1218 allowed_pubkeys: HashSet::new(),
1219 upstream_blossom: Vec::new(),
1220 social_graph: None,
1221 social_graph_store: None,
1222 social_graph_root: None,
1223 socialgraph_snapshot_public: false,
1224 nostr_relay: Some(relay),
1225 nostr_relay_urls: vec![relay_url],
1226 tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1227 inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1228 std::collections::HashMap::new(),
1229 )),
1230 directory_listing_cache: Arc::new(std::sync::Mutex::new(
1231 super::super::auth::new_lookup_cache(),
1232 )),
1233 resolved_path_cache: Arc::new(std::sync::Mutex::new(
1234 super::super::auth::new_lookup_cache(),
1235 )),
1236 thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1237 super::super::auth::new_lookup_cache(),
1238 )),
1239 cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1240 })
1241 }
1242
1243 #[tokio::test]
1244 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1245 let tmp = TempDir::new()?;
1246 let graph_store = {
1247 let _guard = crate::socialgraph::test_lock();
1248 crate::socialgraph::open_social_graph_store_with_mapsize(
1249 tmp.path(),
1250 Some(128 * 1024 * 1024),
1251 )?
1252 };
1253 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1254 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1255 Arc::clone(&backend),
1256 0,
1257 HashSet::new(),
1258 ));
1259
1260 let keys = Keys::generate();
1261 let relay = Arc::new(NostrRelay::new(
1262 Arc::clone(&backend),
1263 tmp.path().to_path_buf(),
1264 HashSet::from([keys.public_key().to_hex()]),
1265 Some(access),
1266 NostrRelayConfig {
1267 spambox_db_max_bytes: 0,
1268 ..Default::default()
1269 },
1270 )?);
1271
1272 let event = EventBuilder::new(
1273 Kind::from(30078_u16),
1274 "",
1275 [
1276 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1277 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1278 ],
1279 )
1280 .to_event(&keys)?;
1281
1282 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1283 let filter = Filter::new()
1284 .authors(vec![event.pubkey])
1285 .kinds(vec![event.kind]);
1286 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1287 let client_id = 7_u64;
1288 let (tx, mut rx) = mpsc::unbounded_channel();
1289 state.ws_relay.clients.lock().await.insert(client_id, tx);
1290 let subscription_id = SubscriptionId::new("sub-1");
1291
1292 start_upstream_nostr_subscription(
1293 &state,
1294 client_id,
1295 subscription_id.clone(),
1296 vec![filter.clone()],
1297 )
1298 .await;
1299
1300 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1301 .await?
1302 .expect("forwarded upstream event");
1303 let Message::Text(text) = forwarded else {
1304 panic!("expected text event");
1305 };
1306 match NostrRelayMessage::from_json(text.as_str())? {
1307 NostrRelayMessage::Event {
1308 subscription_id: sid,
1309 event: forwarded_event,
1310 } => {
1311 assert_eq!(sid, subscription_id);
1312 assert_eq!(forwarded_event.id, event.id);
1313 }
1314 other => panic!("expected forwarded EVENT, got {:?}", other),
1315 }
1316
1317 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1318 .await?
1319 .expect("forwarded upstream eose");
1320 let Message::Text(eose_text) = eose else {
1321 panic!("expected text eose");
1322 };
1323 match NostrRelayMessage::from_json(eose_text.as_str())? {
1324 NostrRelayMessage::EndOfStoredEvents(sid) => {
1325 assert_eq!(sid, subscription_id);
1326 }
1327 other => panic!("expected forwarded EOSE, got {:?}", other),
1328 }
1329
1330 let events = relay.query_events(&filter, 10).await;
1331 assert_eq!(events.len(), 1);
1332 assert_eq!(events[0].id, event.id);
1333
1334 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1335 assert!(state
1336 .ws_relay
1337 .upstream_nostr_subscriptions
1338 .lock()
1339 .await
1340 .is_empty());
1341 Ok(())
1342 }
1343
1344 #[tokio::test]
1345 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1346 let tmp = TempDir::new()?;
1347 let graph_store = {
1348 let _guard = crate::socialgraph::test_lock();
1349 crate::socialgraph::open_social_graph_store_with_mapsize(
1350 tmp.path(),
1351 Some(128 * 1024 * 1024),
1352 )?
1353 };
1354 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1355 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1356 Arc::clone(&backend),
1357 0,
1358 HashSet::new(),
1359 ));
1360
1361 let keys = Keys::generate();
1362 let relay = Arc::new(NostrRelay::new(
1363 Arc::clone(&backend),
1364 tmp.path().to_path_buf(),
1365 HashSet::from([keys.public_key().to_hex()]),
1366 Some(access),
1367 NostrRelayConfig {
1368 spambox_db_max_bytes: 0,
1369 ..Default::default()
1370 },
1371 )?);
1372
1373 let event = EventBuilder::new(
1374 Kind::from(30078_u16),
1375 "",
1376 [
1377 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1378 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1379 ],
1380 )
1381 .to_event(&keys)?;
1382
1383 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1384 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1385 let client_id = 11_u64;
1386 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1387 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1388 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1389 relay.register_client(client_id, relay_tx, None).await;
1390
1391 let request = NostrClientMessage::req(
1392 SubscriptionId::new("feed"),
1393 vec![Filter::new()
1394 .authors(vec![event.pubkey])
1395 .kinds(vec![event.kind])],
1396 )
1397 .as_json();
1398
1399 handle_message(client_id, Message::Text(request.into()), &state).await;
1400
1401 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1402 .await?
1403 .expect("first forwarded message");
1404 let Message::Text(first_text) = first else {
1405 panic!("expected text event");
1406 };
1407 match NostrRelayMessage::from_json(first_text.as_str())? {
1408 NostrRelayMessage::Event {
1409 event: forwarded_event,
1410 ..
1411 } => {
1412 assert_eq!(forwarded_event.id, event.id);
1413 }
1414 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1415 }
1416
1417 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1418 .await?
1419 .expect("second forwarded message");
1420 let Message::Text(second_text) = second else {
1421 panic!("expected text eose");
1422 };
1423 match NostrRelayMessage::from_json(second_text.as_str())? {
1424 NostrRelayMessage::EndOfStoredEvents(sid) => {
1425 assert_eq!(sid, SubscriptionId::new("feed"));
1426 }
1427 other => panic!("expected aggregated EOSE, got {:?}", other),
1428 }
1429
1430 Ok(())
1431 }
1432
1433 #[tokio::test]
1434 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1435 let tmp = TempDir::new()?;
1436 let graph_store = {
1437 let _guard = crate::socialgraph::test_lock();
1438 crate::socialgraph::open_social_graph_store_with_mapsize(
1439 tmp.path(),
1440 Some(128 * 1024 * 1024),
1441 )?
1442 };
1443 let author_keys = Keys::generate();
1444 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1445 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1446 Arc::clone(&backend),
1447 0,
1448 HashSet::from([author_keys.public_key().to_hex()]),
1449 ));
1450 let relay = Arc::new(NostrRelay::new(
1451 Arc::clone(&backend),
1452 tmp.path().to_path_buf(),
1453 HashSet::from([author_keys.public_key().to_hex()]),
1454 Some(access),
1455 NostrRelayConfig {
1456 spambox_db_max_bytes: 0,
1457 ..Default::default()
1458 },
1459 )?);
1460
1461 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1462 let listener = TcpListener::bind("127.0.0.1:0").await?;
1463 let addr = listener.local_addr()?;
1464 let client_pubkey = author_keys.public_key().to_hex();
1465 let app = axum::Router::new().route(
1466 "/ws",
1467 axum::routing::get({
1468 let state = state.clone();
1469 let client_pubkey = client_pubkey.clone();
1470 move |ws: WebSocketUpgrade| {
1471 let state = state.clone();
1472 let client_pubkey = client_pubkey.clone();
1473 async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1474 }
1475 }),
1476 );
1477 tokio::spawn(async move {
1478 let _ = axum::serve(listener, app).await;
1479 });
1480
1481 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1482 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1483 .to_event(&author_keys)?;
1484 socket
1485 .send(TungsteniteMessage::Text(
1486 NostrClientMessage::event(event.clone()).as_json().into(),
1487 ))
1488 .await?;
1489
1490 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1491 .await?
1492 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1493 let TungsteniteMessage::Text(text) = reply else {
1494 anyhow::bail!("expected text publish ack");
1495 };
1496
1497 match NostrRelayMessage::from_json(text.as_str())? {
1498 NostrRelayMessage::Ok {
1499 event_id, status, ..
1500 } => {
1501 assert_eq!(event_id, event.id);
1502 assert!(status);
1503 }
1504 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1505 }
1506
1507 let stored = relay
1508 .query_events(
1509 &Filter::new()
1510 .authors(vec![event.pubkey])
1511 .kinds(vec![event.kind]),
1512 10,
1513 )
1514 .await;
1515 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1516 Ok(())
1517 }
1518}