1use axum::{
2 extract::{
3 ws::{Message, WebSocket, WebSocketUpgrade},
4 State,
5 },
6 response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11 ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12 RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::webrtc::types::{
21 encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
22};
23use hex::encode as hex_encode;
24
25#[derive(Debug, Deserialize)]
26#[serde(tag = "type")]
27enum WsClientMessage {
28 #[serde(rename = "req")]
29 Request { id: u32, hash: String },
30 #[serde(rename = "res")]
31 Response { id: u32, hash: String, found: bool },
32}
33
34#[derive(Debug)]
35enum WsTextMessage {
36 Hashtree(WsClientMessage),
37 Nostr(NostrClientMessage),
38}
39
40#[derive(Debug, Deserialize, Serialize)]
41struct WsRequest {
42 #[serde(rename = "type")]
43 kind: String,
44 id: u32,
45 hash: String,
46}
47
48#[derive(Debug, Serialize)]
49struct WsResponse {
50 #[serde(rename = "type")]
51 kind: &'static str,
52 id: u32,
53 hash: String,
54 found: bool,
55}
56
57pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
58 ws_data_with_client_pubkey(state, ws, None)
59}
60
61pub fn ws_data_with_client_pubkey(
62 state: AppState,
63 ws: WebSocketUpgrade,
64 client_pubkey: Option<String>,
65) -> impl IntoResponse {
66 ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
67}
68
69async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
70 let client_id = state
73 .nostr_relay
74 .as_ref()
75 .map(|relay| relay.next_client_id())
76 .unwrap_or_else(|| state.ws_relay.next_id());
77 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
78
79 {
80 let mut clients = state.ws_relay.clients.lock().await;
81 clients.insert(client_id, tx);
82 }
83 {
84 let mut protocols = state.ws_relay.client_protocols.lock().await;
85 protocols.insert(client_id, WsProtocol::HashtreeJson);
86 }
87
88 let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
89 let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
90 relay
91 .register_client(client_id, nostr_tx, client_pubkey.clone())
92 .await;
93 Some(nostr_rx)
94 } else {
95 None
96 };
97
98 let (mut sender, mut receiver) = socket.split();
99 loop {
100 tokio::select! {
101 maybe_msg = rx.recv() => {
102 let Some(msg) = maybe_msg else {
103 break;
104 };
105 if sender.send(msg).await.is_err() {
106 break;
107 }
108 }
109 maybe_text = async {
110 match &mut nostr_rx {
111 Some(rx) => rx.recv().await,
112 None => std::future::pending().await,
113 }
114 } => {
115 let Some(text) = maybe_text else {
116 nostr_rx = None;
117 continue;
118 };
119 if sender.send(Message::Text(text)).await.is_err() {
120 break;
121 }
122 }
123 maybe_incoming = receiver.next() => {
124 match maybe_incoming {
125 Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
126 Some(Err(_)) | None => break,
127 }
128 }
129 }
130 }
131
132 close_all_upstream_nostr_subscriptions(&state, client_id).await;
133
134 {
135 let mut clients = state.ws_relay.clients.lock().await;
136 clients.remove(&client_id);
137 }
138 {
139 let mut protocols = state.ws_relay.client_protocols.lock().await;
140 protocols.remove(&client_id);
141 }
142 {
143 let mut pending = state.ws_relay.pending.lock().await;
144 pending.retain(|(peer_id, _), _| *peer_id != client_id);
145 }
146
147 if let Some(relay) = &state.nostr_relay {
148 relay.unregister_client(client_id).await;
149 }
150}
151
152fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
153 let trimmed = text.trim_start();
154 if trimmed.starts_with('[') {
155 if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
156 return Some(WsTextMessage::Nostr(msg));
157 }
158 }
159
160 if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
161 return Some(WsTextMessage::Hashtree(msg));
162 }
163
164 None
165}
166async fn close_upstream_nostr_subscription(
167 state: &AppState,
168 client_id: u64,
169 subscription_id: &SubscriptionId,
170) {
171 let key = (client_id, subscription_id.to_string());
172 let subscription = {
173 let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
174 subscriptions.remove(&key)
175 };
176 if let Some(subscription) = subscription {
177 let _ = subscription.close_tx.send(true);
178 for task in subscription.tasks {
179 task.abort();
180 }
181 }
182 state
183 .ws_relay
184 .upstream_pending_eose
185 .lock()
186 .await
187 .remove(&key);
188 state
189 .ws_relay
190 .upstream_seen_events
191 .lock()
192 .await
193 .remove(&key);
194}
195
196async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
197 let keys = {
198 let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
199 subscriptions
200 .keys()
201 .filter(|(id, _)| *id == client_id)
202 .cloned()
203 .collect::<Vec<_>>()
204 };
205 for (_, sub_id) in keys {
206 close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
207 }
208}
209
210async fn forward_upstream_nostr_message(
211 state: &AppState,
212 client_id: u64,
213 subscription_id: &SubscriptionId,
214 text: &str,
215) {
216 let Ok(message) = NostrRelayMessage::from_json(text) else {
217 return;
218 };
219
220 match message {
221 NostrRelayMessage::Event {
222 subscription_id: sid,
223 event,
224 } if sid == *subscription_id => {
225 let event = *event;
226 let key = (client_id, subscription_id.to_string());
227 let event_id = event.id.to_hex();
228 let inserted = {
229 let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
230 seen_events.entry(key).or_default().insert(event_id)
231 };
232 if !inserted {
233 return;
234 }
235 if let Some(relay) = &state.nostr_relay {
236 let _ = relay.ingest_trusted_event_silent(event.clone()).await;
237 }
238 send_nostr(
239 state,
240 client_id,
241 NostrRelayMessage::event(subscription_id.clone(), event),
242 )
243 .await;
244 }
245 NostrRelayMessage::Closed {
246 subscription_id: sid,
247 message,
248 } if sid == *subscription_id => {
249 send_nostr(
250 state,
251 client_id,
252 NostrRelayMessage::closed(subscription_id.clone(), message),
253 )
254 .await;
255 }
256 _ => {}
257 }
258}
259
260async fn mark_upstream_nostr_relay_complete(
261 state: &AppState,
262 client_id: u64,
263 subscription_id: &SubscriptionId,
264) {
265 let key = (client_id, subscription_id.to_string());
266 let should_send_eose = {
267 let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
268 let Some(remaining) = pending.get_mut(&key) else {
269 return;
270 };
271 if *remaining > 0 {
272 *remaining -= 1;
273 }
274 if *remaining == 0 {
275 pending.remove(&key);
276 true
277 } else {
278 false
279 }
280 };
281
282 if should_send_eose {
283 send_nostr(
284 state,
285 client_id,
286 NostrRelayMessage::eose(subscription_id.clone()),
287 )
288 .await;
289 }
290}
291
292async fn run_upstream_nostr_subscription(
293 state: AppState,
294 client_id: u64,
295 relay_url: String,
296 subscription_id: SubscriptionId,
297 filters: Vec<NostrFilter>,
298 mut close_rx: watch::Receiver<bool>,
299) {
300 let mut relay_complete = false;
301 let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
302 tracing::warn!(
303 "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
304 client_id,
305 subscription_id,
306 relay_url,
307 );
308 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
309 return;
310 };
311 let (mut write, mut read) = socket.split();
312 let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
313 state
314 .ws_relay
315 .note_upstream_relay_send(request.as_bytes().len());
316 if write
317 .send(TungsteniteMessage::Text(request.into()))
318 .await
319 .is_err()
320 {
321 tracing::warn!(
322 "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
323 client_id,
324 subscription_id,
325 relay_url,
326 );
327 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
328 return;
329 }
330
331 loop {
332 tokio::select! {
333 _ = close_rx.changed() => {
334 if *close_rx.borrow() {
335 let close = NostrClientMessage::close(subscription_id.clone()).as_json();
336 state
337 .ws_relay
338 .note_upstream_relay_send(close.as_bytes().len());
339 let _ = write.send(TungsteniteMessage::Text(close.into())).await;
340 let _ = write.close().await;
341 break;
342 }
343 }
344 message = read.next() => {
345 match message {
346 Some(Ok(TungsteniteMessage::Text(text))) => {
347 state
348 .ws_relay
349 .note_upstream_relay_receive(text.as_bytes().len());
350 if matches!(
351 NostrRelayMessage::from_json(text.as_str()),
352 Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
353 ) {
354 if !relay_complete {
355 relay_complete = true;
356 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
357 }
358 continue;
359 }
360 forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
361 }
362 Some(Ok(TungsteniteMessage::Ping(payload))) => {
363 let _ = write.send(TungsteniteMessage::Pong(payload)).await;
364 }
365 Some(Ok(TungsteniteMessage::Close(_))) | None => {
366 if !relay_complete {
367 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
368 }
369 break;
370 }
371 Some(Err(_)) => {
372 if !relay_complete {
373 mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
374 }
375 break;
376 }
377 _ => {}
378 }
379 }
380 }
381 }
382}
383
384async fn start_upstream_nostr_subscription(
385 state: &AppState,
386 client_id: u64,
387 subscription_id: SubscriptionId,
388 filters: Vec<NostrFilter>,
389) -> usize {
390 close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
391 if state.nostr_relay_urls.is_empty() || filters.is_empty() {
392 tracing::info!(
393 "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
394 client_id,
395 subscription_id,
396 state.nostr_relay_urls.len(),
397 filters.len(),
398 );
399 return 0;
400 }
401
402 let mut relay_urls = Vec::new();
403 let mut seen = HashSet::new();
404 for relay in &state.nostr_relay_urls {
405 let relay = relay.trim();
406 if relay.is_empty() || !seen.insert(relay.to_string()) {
407 continue;
408 }
409 relay_urls.push(relay.to_string());
410 }
411 if relay_urls.is_empty() {
412 tracing::info!(
413 "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
414 client_id,
415 subscription_id,
416 );
417 return 0;
418 }
419
420 tracing::info!(
421 "upstream nostr relay start: client_id={} subscription_id={} relays={}",
422 client_id,
423 subscription_id,
424 relay_urls.len(),
425 );
426
427 let key = (client_id, subscription_id.to_string());
428 state
429 .ws_relay
430 .upstream_seen_events
431 .lock()
432 .await
433 .insert(key.clone(), HashSet::new());
434 state
435 .ws_relay
436 .upstream_pending_eose
437 .lock()
438 .await
439 .insert(key.clone(), relay_urls.len());
440
441 let (close_tx, close_rx) = watch::channel(false);
442 let mut tasks = Vec::new();
443 for relay_url in &relay_urls {
444 tasks.push(tokio::spawn(run_upstream_nostr_subscription(
445 state.clone(),
446 client_id,
447 relay_url.clone(),
448 subscription_id.clone(),
449 filters.clone(),
450 close_rx.clone(),
451 )));
452 }
453
454 state
455 .ws_relay
456 .upstream_nostr_subscriptions
457 .lock()
458 .await
459 .insert(key, UpstreamNostrSubscription { close_tx, tasks });
460 relay_urls.len()
461}
462
463async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
464 match msg {
465 Message::Text(text) => {
466 if let Some(msg) = parse_ws_text_message(&text) {
467 match msg {
468 WsTextMessage::Hashtree(msg) => {
469 set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
470 match msg {
471 WsClientMessage::Request { id, hash } => {
472 handle_request(
473 client_id,
474 id,
475 hash,
476 WsProtocol::HashtreeJson,
477 state,
478 )
479 .await;
480 }
481 WsClientMessage::Response { id, hash, found } => {
482 handle_response(client_id, id, hash, found, state).await;
483 }
484 }
485 }
486 WsTextMessage::Nostr(msg) => {
487 if let Some(relay) = &state.nostr_relay {
488 match msg {
489 NostrClientMessage::Req {
490 subscription_id,
491 filters,
492 } => {
493 let local_events = match relay
494 .register_subscription_query(
495 client_id,
496 subscription_id.clone(),
497 filters.clone(),
498 )
499 .await
500 {
501 Ok(events) => events,
502 Err(message) => {
503 send_nostr(
504 state,
505 client_id,
506 NostrRelayMessage::closed(subscription_id, message),
507 )
508 .await;
509 return;
510 }
511 };
512
513 let upstream_relays = start_upstream_nostr_subscription(
514 state,
515 client_id,
516 subscription_id.clone(),
517 filters,
518 )
519 .await;
520 if upstream_relays > 0 {
521 let key = (client_id, subscription_id.to_string());
522 let mut seen_events =
523 state.ws_relay.upstream_seen_events.lock().await;
524 seen_events.entry(key).or_default().extend(
525 local_events.iter().map(|event| event.id.to_hex()),
526 );
527 }
528 for event in local_events {
529 send_nostr(
530 state,
531 client_id,
532 NostrRelayMessage::event(
533 subscription_id.clone(),
534 event,
535 ),
536 )
537 .await;
538 }
539 if upstream_relays == 0 {
540 send_nostr(
541 state,
542 client_id,
543 NostrRelayMessage::eose(subscription_id),
544 )
545 .await;
546 }
547 }
548 NostrClientMessage::Close(subscription_id) => {
549 close_upstream_nostr_subscription(
550 state,
551 client_id,
552 &subscription_id,
553 )
554 .await;
555 relay
556 .handle_client_message(
557 client_id,
558 NostrClientMessage::Close(subscription_id.clone()),
559 )
560 .await;
561 }
562 other => {
563 relay.handle_client_message(client_id, other).await;
564 }
565 }
566 } else {
567 handle_nostr_message(client_id, msg, state).await;
568 }
569 }
570 }
571 }
572 }
573 Message::Binary(data) => {
574 handle_binary(client_id, data, state).await;
575 }
576 Message::Close(_) => {}
577 _ => {}
578 }
579}
580
/// Resolves a blob request from `client_id`.
///
/// `hash` is hex. `origin_protocol` selects how answers are encoded for the
/// requester:
/// - `HashtreeJson`: a `res` JSON frame plus a request-id-prefixed binary frame;
/// - `HashtreeMsgpack`: a msgpack `DataResponse`;
/// - `Unknown`: nothing.
///
/// Flow:
/// 1. Invalid hex: JSON origins get `found: false`.
/// 2. Blob in the local store: answered directly.
/// 3. Otherwise the request is fanned out to every other client with a known
///    protocol. One `pending` entry per peer is keyed by `(peer_id, request_id)`.
///    With no such peers, JSON origins get `found: false` immediately.
/// 4. A detached task waits 1.5 s. If entries for this request are still pending
///    and none was marked found, they are dropped and a JSON origin gets
///    `found: false`. Entries marked found are left for the binary payload
///    handler to clear.
///
/// NOTE(review): for JSON origins, `request_id` is the id chosen by the client.
/// Two origins using the same id against the same peer overwrite each other's
/// `pending` entry; confirm whether ids need relay-side remapping.
async fn handle_request(
    client_id: u64,
    request_id: u32,
    hash: String,
    origin_protocol: WsProtocol,
    state: &AppState,
) {
    // Decode from the lowercased form; the caller's `hash` string is what is
    // echoed back and stored in `pending`.
    let hash_hex = hash.to_lowercase();
    let hash_bytes = match from_hex(&hash_hex) {
        Ok(bytes) => bytes,
        Err(_) => {
            if origin_protocol == WsProtocol::HashtreeJson {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash,
                        found: false,
                    },
                )
                .await;
            }
            return;
        }
    };

    // Fast path: serve from the local store. A store error is treated like a miss.
    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
        match origin_protocol {
            WsProtocol::HashtreeJson => {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash: hash.clone(),
                        found: true,
                    },
                )
                .await;
                send_binary(state, client_id, request_id, data).await;
            }
            WsProtocol::HashtreeMsgpack => {
                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
            }
            WsProtocol::Unknown => {}
        }
        return;
    }

    // Snapshot candidate peers: every other client whose protocol is known.
    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
        let clients = state.ws_relay.clients.lock().await;
        let protocols = state.ws_relay.client_protocols.lock().await;
        clients
            .iter()
            .filter(|(id, _)| **id != client_id)
            .filter_map(|(id, tx)| {
                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
                match protocol {
                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
                        Some((*id, tx.clone(), protocol))
                    }
                    WsProtocol::Unknown => None,
                }
            })
            .collect()
    };

    if peers.is_empty() {
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash,
                    found: false,
                },
            )
            .await;
        }
        return;
    }

    // Record the pending entries before sending, so an early answer finds them.
    {
        let mut pending = state.ws_relay.pending.lock().await;
        for (peer_id, _, _) in &peers {
            pending.insert(
                (*peer_id, request_id),
                PendingRequest {
                    origin_id: client_id,
                    hash: hash.clone(),
                    found: false,
                    origin_protocol,
                },
            );
        }
    }

    let request_text = serde_json::to_string(&WsRequest {
        kind: "req".to_string(),
        id: request_id,
        hash: hash.clone(),
    })
    .unwrap_or_else(|_| String::new());
    for (peer_id, tx, protocol) in peers {
        match protocol {
            // Msgpack requests carry no request id; answers are matched by hash
            // (see handle_msgpack_response).
            WsProtocol::HashtreeMsgpack => {
                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
            }
            WsProtocol::HashtreeJson => {
                let _ = tx.send(Message::Text(request_text.clone()));
            }
            WsProtocol::Unknown => {}
        }
    }

    // Detached timeout: give peers 1.5 s before answering "not found".
    let timeout_state = state.clone();
    let timeout_hash = hash.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(1500)).await;
        let mut pending = timeout_state.ws_relay.pending.lock().await;
        let still_pending = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
        let already_found = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
        if !still_pending || already_found {
            return;
        }
        let origin_protocol = pending
            .iter()
            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
            .map(|(_, p)| p.origin_protocol)
            .unwrap_or(WsProtocol::HashtreeJson);
        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
        // Release the lock before queueing the reply.
        drop(pending);
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                &timeout_state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash: timeout_hash,
                    found: false,
                },
            )
            .await;
        }
    });
}
737
738async fn handle_response(
739 client_id: u64,
740 request_id: u32,
741 _hash: String,
742 found: bool,
743 state: &AppState,
744) {
745 let pending_entry = {
746 let pending = state.ws_relay.pending.lock().await;
747 pending
748 .get(&(client_id, request_id))
749 .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
750 };
751
752 let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
753 return;
754 };
755
756 if already_found && !found {
757 let mut pending = state.ws_relay.pending.lock().await;
758 pending.remove(&(client_id, request_id));
759 return;
760 }
761
762 if found {
763 let mut pending = state.ws_relay.pending.lock().await;
764 for ((_, id), p) in pending.iter_mut() {
765 if *id == request_id && p.origin_id == origin_id {
766 p.found = true;
767 }
768 }
769 drop(pending);
770 if origin_protocol == WsProtocol::HashtreeJson {
771 send_json(
772 state,
773 origin_id,
774 WsResponse {
775 kind: "res",
776 id: request_id,
777 hash: pending_hash,
778 found: true,
779 },
780 )
781 .await;
782 }
783 return;
784 }
785
786 let mut pending = state.ws_relay.pending.lock().await;
787 pending.remove(&(client_id, request_id));
788 let has_remaining = pending
789 .iter()
790 .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
791 drop(pending);
792
793 if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
794 send_json(
795 state,
796 origin_id,
797 WsResponse {
798 kind: "res",
799 id: request_id,
800 hash: pending_hash,
801 found: false,
802 },
803 )
804 .await;
805 }
806}
807
808async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
809 if let Some(msg) = parse_msgpack_message(&data) {
810 set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
811 match msg {
812 DataMessage::Request(req) => {
813 let hash_hex = hex_encode(&req.h);
814 let request_id = state.ws_relay.next_request_id();
815 handle_request(
816 client_id,
817 request_id,
818 hash_hex,
819 WsProtocol::HashtreeMsgpack,
820 state,
821 )
822 .await;
823 }
824 DataMessage::Response(res) => {
825 handle_msgpack_response(client_id, res, state).await;
826 }
827 DataMessage::QuoteRequest(_)
828 | DataMessage::QuoteResponse(_)
829 | DataMessage::Payment(_)
830 | DataMessage::PaymentAck(_)
831 | DataMessage::Chunk(_) => {}
832 }
833 return;
834 }
835
836 if data.len() < 4 {
838 return;
839 }
840 let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
841 let pending_entry = {
842 let pending = state.ws_relay.pending.lock().await;
843 pending
844 .get(&(client_id, request_id))
845 .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
846 };
847 let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
848 return;
849 };
850
851 match origin_protocol {
852 WsProtocol::HashtreeJson => {
853 send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
854 }
855 WsProtocol::HashtreeMsgpack => {
856 let Ok(hash_bytes) = from_hex(&hash_hex) else {
857 return;
858 };
859 send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
860 }
861 WsProtocol::Unknown => {}
862 }
863
864 let mut pending = state.ws_relay.pending.lock().await;
865 pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
866}
867
868async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
869 let replies = nostr_responses_for(&msg);
870 for reply in replies {
871 send_nostr(state, client_id, reply).await;
872 }
873}
874
875fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
876 match msg {
877 NostrClientMessage::Event(event) => {
878 let ok = event.verify().is_ok();
879 let message = if ok { "" } else { "invalid: signature" };
880 vec![NostrRelayMessage::ok(event.id, ok, message)]
881 }
882 NostrClientMessage::Req {
883 subscription_id, ..
884 } => {
885 vec![NostrRelayMessage::eose(subscription_id.clone())]
886 }
887 NostrClientMessage::Count {
888 subscription_id, ..
889 } => {
890 vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
891 }
892 NostrClientMessage::Close(_) => Vec::new(),
893 NostrClientMessage::Auth(event) => {
894 let ok = event.verify().is_ok();
895 let message = if ok { "" } else { "invalid auth" };
896 vec![NostrRelayMessage::ok(event.id, ok, message)]
897 }
898 NostrClientMessage::NegOpen { .. }
899 | NostrClientMessage::NegMsg { .. }
900 | NostrClientMessage::NegClose { .. } => {
901 vec![NostrRelayMessage::notice("negentropy not supported")]
902 }
903 }
904}
905
906async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
907 let text = response.as_json();
908 send_to_client(state, client_id, Message::Text(text)).await;
909}
910
911fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
912 let msg = parse_message(data).ok()?;
913 match msg {
914 DataMessage::Request(req) => {
915 if req.h.len() == 32 {
916 Some(DataMessage::Request(req))
917 } else {
918 None
919 }
920 }
921 DataMessage::Response(res) => {
922 if res.h.len() == 32 {
923 Some(DataMessage::Response(res))
924 } else {
925 None
926 }
927 }
928 DataMessage::QuoteRequest(req) => {
929 if req.h.len() == 32 {
930 Some(DataMessage::QuoteRequest(req))
931 } else {
932 None
933 }
934 }
935 DataMessage::QuoteResponse(res) => {
936 if res.h.len() == 32 {
937 Some(DataMessage::QuoteResponse(res))
938 } else {
939 None
940 }
941 }
942 DataMessage::Payment(req) => {
943 if req.h.len() == 32 {
944 Some(DataMessage::Payment(req))
945 } else {
946 None
947 }
948 }
949 DataMessage::PaymentAck(res) => {
950 if res.h.len() == 32 {
951 Some(DataMessage::PaymentAck(res))
952 } else {
953 None
954 }
955 }
956 DataMessage::Chunk(chunk) => {
957 if chunk.h.len() == 32 {
958 Some(DataMessage::Chunk(chunk))
959 } else {
960 None
961 }
962 }
963 }
964}
965
966async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
967 let hash_hex = hex_encode(&res.h);
968 let data = res.d.clone();
969 let hash_bytes = res.h.clone();
970
971 let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
972 let mut seen = HashSet::new();
973 {
974 let pending = state.ws_relay.pending.lock().await;
975 for ((peer_id, request_id), p) in pending.iter() {
976 if *peer_id != client_id {
977 continue;
978 }
979 if p.hash != hash_hex {
980 continue;
981 }
982 if seen.insert((p.origin_id, *request_id)) {
983 responses.push((p.origin_id, *request_id, p.origin_protocol));
984 }
985 }
986 }
987
988 if responses.is_empty() {
989 return;
990 }
991
992 for (origin_id, request_id, protocol) in &responses {
993 match protocol {
994 WsProtocol::HashtreeJson => {
995 send_json(
996 state,
997 *origin_id,
998 WsResponse {
999 kind: "res",
1000 id: *request_id,
1001 hash: hash_hex.clone(),
1002 found: true,
1003 },
1004 )
1005 .await;
1006 send_binary(state, *origin_id, *request_id, data.clone()).await;
1007 }
1008 WsProtocol::HashtreeMsgpack => {
1009 send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1010 }
1011 WsProtocol::Unknown => {}
1012 }
1013 }
1014
1015 let completed: HashSet<(u64, u32)> = responses
1016 .into_iter()
1017 .map(|(origin_id, request_id, _)| (origin_id, request_id))
1018 .collect();
1019 let mut pending = state.ws_relay.pending.lock().await;
1020 pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1021}
1022
1023async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1024 if let Ok(text) = serde_json::to_string(&response) {
1025 send_to_client(state, client_id, Message::Text(text)).await;
1026 }
1027}
1028
1029async fn send_msgpack_request(
1030 state: &AppState,
1031 client_id: u64,
1032 hash: &[u8],
1033) -> Result<(), rmp_serde::encode::Error> {
1034 let req = DataRequest {
1035 h: hash.to_vec(),
1036 htl: MAX_HTL,
1037 q: None,
1038 };
1039 let wire = encode_request(&req)?;
1040 send_to_client(state, client_id, Message::Binary(wire)).await;
1041 Ok(())
1042}
1043
1044async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1045 let res = DataResponse {
1046 h: hash.to_vec(),
1047 d: data.to_vec(),
1048 };
1049 if let Ok(wire) = encode_response(&res) {
1050 send_to_client(state, client_id, Message::Binary(wire)).await;
1051 }
1052}
1053
1054async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1055 let mut packet = Vec::with_capacity(4 + payload.len());
1056 packet.extend_from_slice(&request_id.to_le_bytes());
1057 packet.extend_from_slice(&payload);
1058 send_to_client(state, client_id, Message::Binary(packet)).await;
1059}
1060
1061async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1062 let sender = {
1063 let clients = state.ws_relay.clients.lock().await;
1064 clients.get(&client_id).cloned()
1065 };
1066 if let Some(tx) = sender {
1067 let _ = tx.send(msg);
1068 }
1069}
1070
1071async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1072 let mut protocols = state.ws_relay.client_protocols.lock().await;
1073 protocols.insert(client_id, protocol);
1074}
1075
1076#[cfg(test)]
1077mod tests {
1078 use super::*;
1079 use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1080 use anyhow::Result;
1081 use futures::{SinkExt, StreamExt};
1082 use nostr::secp256k1::schnorr::Signature;
1083 use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1084 use std::collections::HashSet;
1085 use std::sync::Arc;
1086 use tempfile::TempDir;
1087 use tokio::net::TcpListener;
1088 use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1089
1090 #[test]
1091 fn parse_ws_text_message_detects_nostr_req() {
1092 let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1093 match parse_ws_text_message(msg) {
1094 Some(WsTextMessage::Nostr(_)) => {}
1095 other => panic!("expected Nostr message, got {:?}", other),
1096 }
1097 }
1098
1099 #[test]
1100 fn parse_ws_text_message_detects_hashtree_request() {
1101 let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1102 match parse_ws_text_message(msg) {
1103 Some(WsTextMessage::Hashtree(_)) => {}
1104 other => panic!("expected Hashtree message, got {:?}", other),
1105 }
1106 }
1107
1108 #[test]
1109 fn nostr_replies_for_req_is_eose() {
1110 let sub = SubscriptionId::new("sub-1");
1111 let msg = NostrClientMessage::req(sub.clone(), vec![]);
1112 let replies = nostr_responses_for(&msg);
1113 assert_eq!(replies.len(), 1);
1114 match &replies[0] {
1115 NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1116 other => panic!("expected EOSE, got {:?}", other),
1117 }
1118 }
1119
1120 #[test]
1121 fn nostr_replies_for_event_ok() {
1122 let keys = Keys::generate();
1123 let event = EventBuilder::new(Kind::TextNote, "hello", [])
1124 .to_event(&keys)
1125 .unwrap();
1126 let msg = NostrClientMessage::event(event.clone());
1127 let replies = nostr_responses_for(&msg);
1128 assert_eq!(replies.len(), 1);
1129 match &replies[0] {
1130 NostrRelayMessage::Ok {
1131 event_id, status, ..
1132 } => {
1133 assert_eq!(event_id, &event.id);
1134 assert!(*status);
1135 }
1136 other => panic!("expected OK, got {:?}", other),
1137 }
1138 }
1139
1140 #[test]
1141 fn nostr_replies_for_invalid_event_is_not_ok() {
1142 let keys = Keys::generate();
1143 let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1144 .to_event(&keys)
1145 .unwrap();
1146 event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1147 let msg = NostrClientMessage::event(event);
1148 let replies = nostr_responses_for(&msg);
1149 assert_eq!(replies.len(), 1);
1150 match &replies[0] {
1151 NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1152 other => panic!("expected OK=false, got {:?}", other),
1153 }
1154 }
1155
1156 async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1157 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1158 let addr = listener.local_addr().expect("relay addr");
1159 tokio::spawn(async move {
1160 let (stream, _) = listener.accept().await.expect("accept relay");
1161 let ws = accept_async(stream).await.expect("accept websocket");
1162 let (mut write, mut read) = ws.split();
1163
1164 while let Some(Ok(message)) = read.next().await {
1165 let TungsteniteMessage::Text(text) = message else {
1166 continue;
1167 };
1168 let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1169 continue;
1170 };
1171 if let NostrClientMessage::Req {
1172 subscription_id,
1173 filters,
1174 } = parsed
1175 {
1176 for event in events
1177 .iter()
1178 .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1179 {
1180 let _ = write
1181 .send(TungsteniteMessage::Text(
1182 NostrRelayMessage::event(subscription_id.clone(), event.clone())
1183 .as_json()
1184 .into(),
1185 ))
1186 .await;
1187 }
1188 let _ = write
1189 .send(TungsteniteMessage::Text(
1190 NostrRelayMessage::eose(subscription_id).as_json().into(),
1191 ))
1192 .await;
1193 }
1194 }
1195 });
1196 format!("ws://{}", addr)
1197 }
1198
1199 fn test_app_state(
1200 tmp: &TempDir,
1201 relay: Arc<NostrRelay>,
1202 relay_url: String,
1203 ) -> Result<AppState> {
1204 let store = Arc::new(crate::storage::HashtreeStore::with_options(
1205 tmp.path(),
1206 None,
1207 128 * 1024 * 1024,
1208 )?);
1209 Ok(AppState {
1210 store,
1211 auth: None,
1212 webrtc_peers: None,
1213 ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1214 max_upload_bytes: 5 * 1024 * 1024,
1215 public_writes: true,
1216 allowed_pubkeys: HashSet::new(),
1217 upstream_blossom: Vec::new(),
1218 social_graph: None,
1219 social_graph_store: None,
1220 social_graph_root: None,
1221 socialgraph_snapshot_public: false,
1222 nostr_relay: Some(relay),
1223 nostr_relay_urls: vec![relay_url],
1224 tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1225 inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1226 std::collections::HashMap::new(),
1227 )),
1228 directory_listing_cache: Arc::new(std::sync::Mutex::new(
1229 super::super::auth::new_lookup_cache(),
1230 )),
1231 resolved_path_cache: Arc::new(std::sync::Mutex::new(
1232 super::super::auth::new_lookup_cache(),
1233 )),
1234 thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1235 super::super::auth::new_lookup_cache(),
1236 )),
1237 cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1238 })
1239 }
1240
1241 #[tokio::test]
1242 async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1243 let tmp = TempDir::new()?;
1244 let graph_store = {
1245 let _guard = crate::socialgraph::test_lock();
1246 crate::socialgraph::open_social_graph_store_with_mapsize(
1247 tmp.path(),
1248 Some(128 * 1024 * 1024),
1249 )?
1250 };
1251 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1252 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1253 Arc::clone(&backend),
1254 0,
1255 HashSet::new(),
1256 ));
1257
1258 let keys = Keys::generate();
1259 let relay = Arc::new(NostrRelay::new(
1260 Arc::clone(&backend),
1261 tmp.path().to_path_buf(),
1262 HashSet::from([keys.public_key().to_hex()]),
1263 Some(access),
1264 NostrRelayConfig {
1265 spambox_db_max_bytes: 0,
1266 ..Default::default()
1267 },
1268 )?);
1269
1270 let event = EventBuilder::new(
1271 Kind::from(30078_u16),
1272 "",
1273 [
1274 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1275 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1276 ],
1277 )
1278 .to_event(&keys)?;
1279
1280 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1281 let filter = Filter::new()
1282 .authors(vec![event.pubkey])
1283 .kinds(vec![event.kind]);
1284 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1285 let client_id = 7_u64;
1286 let (tx, mut rx) = mpsc::unbounded_channel();
1287 state.ws_relay.clients.lock().await.insert(client_id, tx);
1288 let subscription_id = SubscriptionId::new("sub-1");
1289
1290 start_upstream_nostr_subscription(
1291 &state,
1292 client_id,
1293 subscription_id.clone(),
1294 vec![filter.clone()],
1295 )
1296 .await;
1297
1298 let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1299 .await?
1300 .expect("forwarded upstream event");
1301 let Message::Text(text) = forwarded else {
1302 panic!("expected text event");
1303 };
1304 match NostrRelayMessage::from_json(text.as_str())? {
1305 NostrRelayMessage::Event {
1306 subscription_id: sid,
1307 event: forwarded_event,
1308 } => {
1309 assert_eq!(sid, subscription_id);
1310 assert_eq!(forwarded_event.id, event.id);
1311 }
1312 other => panic!("expected forwarded EVENT, got {:?}", other),
1313 }
1314
1315 let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1316 .await?
1317 .expect("forwarded upstream eose");
1318 let Message::Text(eose_text) = eose else {
1319 panic!("expected text eose");
1320 };
1321 match NostrRelayMessage::from_json(eose_text.as_str())? {
1322 NostrRelayMessage::EndOfStoredEvents(sid) => {
1323 assert_eq!(sid, subscription_id);
1324 }
1325 other => panic!("expected forwarded EOSE, got {:?}", other),
1326 }
1327
1328 let events = relay.query_events(&filter, 10).await;
1329 assert_eq!(events.len(), 1);
1330 assert_eq!(events[0].id, event.id);
1331
1332 close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1333 assert!(state
1334 .ws_relay
1335 .upstream_nostr_subscriptions
1336 .lock()
1337 .await
1338 .is_empty());
1339 Ok(())
1340 }
1341
1342 #[tokio::test]
1343 async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1344 let tmp = TempDir::new()?;
1345 let graph_store = {
1346 let _guard = crate::socialgraph::test_lock();
1347 crate::socialgraph::open_social_graph_store_with_mapsize(
1348 tmp.path(),
1349 Some(128 * 1024 * 1024),
1350 )?
1351 };
1352 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1353 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1354 Arc::clone(&backend),
1355 0,
1356 HashSet::new(),
1357 ));
1358
1359 let keys = Keys::generate();
1360 let relay = Arc::new(NostrRelay::new(
1361 Arc::clone(&backend),
1362 tmp.path().to_path_buf(),
1363 HashSet::from([keys.public_key().to_hex()]),
1364 Some(access),
1365 NostrRelayConfig {
1366 spambox_db_max_bytes: 0,
1367 ..Default::default()
1368 },
1369 )?);
1370
1371 let event = EventBuilder::new(
1372 Kind::from(30078_u16),
1373 "",
1374 [
1375 nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1376 nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1377 ],
1378 )
1379 .to_event(&keys)?;
1380
1381 let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1382 let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1383 let client_id = 11_u64;
1384 let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1385 let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1386 state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1387 relay.register_client(client_id, relay_tx, None).await;
1388
1389 let request = NostrClientMessage::req(
1390 SubscriptionId::new("feed"),
1391 vec![Filter::new()
1392 .authors(vec![event.pubkey])
1393 .kinds(vec![event.kind])],
1394 )
1395 .as_json();
1396
1397 handle_message(client_id, Message::Text(request.into()), &state).await;
1398
1399 let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1400 .await?
1401 .expect("first forwarded message");
1402 let Message::Text(first_text) = first else {
1403 panic!("expected text event");
1404 };
1405 match NostrRelayMessage::from_json(first_text.as_str())? {
1406 NostrRelayMessage::Event {
1407 event: forwarded_event,
1408 ..
1409 } => {
1410 assert_eq!(forwarded_event.id, event.id);
1411 }
1412 other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1413 }
1414
1415 let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1416 .await?
1417 .expect("second forwarded message");
1418 let Message::Text(second_text) = second else {
1419 panic!("expected text eose");
1420 };
1421 match NostrRelayMessage::from_json(second_text.as_str())? {
1422 NostrRelayMessage::EndOfStoredEvents(sid) => {
1423 assert_eq!(sid, SubscriptionId::new("feed"));
1424 }
1425 other => panic!("expected aggregated EOSE, got {:?}", other),
1426 }
1427
1428 Ok(())
1429 }
1430
1431 #[tokio::test]
1432 async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1433 let tmp = TempDir::new()?;
1434 let graph_store = {
1435 let _guard = crate::socialgraph::test_lock();
1436 crate::socialgraph::open_social_graph_store_with_mapsize(
1437 tmp.path(),
1438 Some(128 * 1024 * 1024),
1439 )?
1440 };
1441 let author_keys = Keys::generate();
1442 let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1443 let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1444 Arc::clone(&backend),
1445 0,
1446 HashSet::from([author_keys.public_key().to_hex()]),
1447 ));
1448 let relay = Arc::new(NostrRelay::new(
1449 Arc::clone(&backend),
1450 tmp.path().to_path_buf(),
1451 HashSet::from([author_keys.public_key().to_hex()]),
1452 Some(access),
1453 NostrRelayConfig {
1454 spambox_db_max_bytes: 0,
1455 ..Default::default()
1456 },
1457 )?);
1458
1459 let state = test_app_state(&tmp, relay.clone(), String::new())?;
1460 let listener = TcpListener::bind("127.0.0.1:0").await?;
1461 let addr = listener.local_addr()?;
1462 let client_pubkey = author_keys.public_key().to_hex();
1463 let app = axum::Router::new().route(
1464 "/ws",
1465 axum::routing::get({
1466 let state = state.clone();
1467 let client_pubkey = client_pubkey.clone();
1468 move |ws: WebSocketUpgrade| {
1469 let state = state.clone();
1470 let client_pubkey = client_pubkey.clone();
1471 async move {
1472 ws_data_with_client_pubkey(state, ws, Some(client_pubkey))
1473 }
1474 }
1475 }),
1476 );
1477 tokio::spawn(async move {
1478 let _ = axum::serve(listener, app).await;
1479 });
1480
1481 let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1482 let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1483 .to_event(&author_keys)?;
1484 socket
1485 .send(TungsteniteMessage::Text(
1486 NostrClientMessage::event(event.clone()).as_json().into(),
1487 ))
1488 .await?;
1489
1490 let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1491 .await?
1492 .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1493 let TungsteniteMessage::Text(text) = reply else {
1494 anyhow::bail!("expected text publish ack");
1495 };
1496
1497 match NostrRelayMessage::from_json(text.as_str())? {
1498 NostrRelayMessage::Ok {
1499 event_id, status, ..
1500 } => {
1501 assert_eq!(event_id, event.id);
1502 assert!(status);
1503 }
1504 other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1505 }
1506
1507 let stored = relay
1508 .query_events(
1509 &Filter::new()
1510 .authors(vec![event.pubkey])
1511 .kinds(vec![event.kind]),
1512 10,
1513 )
1514 .await;
1515 assert!(stored.iter().any(|candidate| candidate.id == event.id));
1516 Ok(())
1517 }
1518}