Skip to main content

hashtree_cli/server/
ws_relay.rs

1use axum::{
2    extract::{
3        ws::{Message, WebSocket, WebSocketUpgrade},
4        State,
5    },
6    response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11    ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12    RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21    nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24    encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
28#[derive(Debug, Deserialize)]
29#[serde(tag = "type")]
30enum WsClientMessage {
31    #[serde(rename = "req")]
32    Request { id: u32, hash: String },
33    #[serde(rename = "res")]
34    Response { id: u32, hash: String, found: bool },
35}
36
37#[derive(Debug)]
38enum WsTextMessage<'a> {
39    Hashtree(WsClientMessage),
40    Nostr(NostrClientMessage<'a>),
41}
42
43#[derive(Debug, Deserialize, Serialize)]
44struct WsRequest {
45    #[serde(rename = "type")]
46    kind: String,
47    id: u32,
48    hash: String,
49}
50
51#[derive(Debug, Serialize)]
52struct WsResponse {
53    #[serde(rename = "type")]
54    kind: &'static str,
55    id: u32,
56    hash: String,
57    found: bool,
58}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61    ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65    state: AppState,
66    ws: WebSocketUpgrade,
67    client_pubkey: Option<String>,
68) -> impl IntoResponse {
69    ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
72async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
73    // Use the Nostr relay's client-id generator when available so `/ws` IDs
74    // can't collide with WebRTC relay clients.
75    let client_id = state
76        .nostr_relay
77        .as_ref()
78        .map(|relay| relay.next_client_id())
79        .unwrap_or_else(|| state.ws_relay.next_id());
80    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
81
82    {
83        let mut clients = state.ws_relay.clients.lock().await;
84        clients.insert(client_id, tx);
85    }
86    {
87        let mut protocols = state.ws_relay.client_protocols.lock().await;
88        protocols.insert(client_id, WsProtocol::HashtreeJson);
89    }
90
91    let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
92        let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
93        relay
94            .register_client(client_id, nostr_tx, client_pubkey.clone())
95            .await;
96        Some(nostr_rx)
97    } else {
98        None
99    };
100
101    let (mut sender, mut receiver) = socket.split();
102    loop {
103        tokio::select! {
104            maybe_msg = rx.recv() => {
105                let Some(msg) = maybe_msg else {
106                    break;
107                };
108                if sender.send(msg).await.is_err() {
109                    break;
110                }
111            }
112            maybe_text = async {
113                match &mut nostr_rx {
114                    Some(rx) => rx.recv().await,
115                    None => std::future::pending().await,
116                }
117            } => {
118                let Some(text) = maybe_text else {
119                    nostr_rx = None;
120                    continue;
121                };
122                if sender.send(Message::Text(text)).await.is_err() {
123                    break;
124                }
125            }
126            maybe_incoming = receiver.next() => {
127                match maybe_incoming {
128                    Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
129                    Some(Err(_)) | None => break,
130                }
131            }
132        }
133    }
134
135    close_all_upstream_nostr_subscriptions(&state, client_id).await;
136
137    {
138        let mut clients = state.ws_relay.clients.lock().await;
139        clients.remove(&client_id);
140    }
141    {
142        let mut protocols = state.ws_relay.client_protocols.lock().await;
143        protocols.remove(&client_id);
144    }
145    {
146        let mut pending = state.ws_relay.pending.lock().await;
147        pending.retain(|(peer_id, _), _| *peer_id != client_id);
148    }
149
150    if let Some(relay) = &state.nostr_relay {
151        relay.unregister_client(client_id).await;
152    }
153}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage<'static>> {
156    let trimmed = text.trim_start();
157    if trimmed.starts_with('[') {
158        if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159            return Some(WsTextMessage::Nostr(msg));
160        }
161    }
162
163    if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164        return Some(WsTextMessage::Hashtree(msg));
165    }
166
167    None
168}
169async fn close_upstream_nostr_subscription(
170    state: &AppState,
171    client_id: u64,
172    subscription_id: &SubscriptionId,
173) {
174    let key = (client_id, subscription_id.to_string());
175    let subscription = {
176        let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177        subscriptions.remove(&key)
178    };
179    if let Some(subscription) = subscription {
180        let _ = subscription.close_tx.send(true);
181        for task in subscription.tasks {
182            task.abort();
183        }
184    }
185    state
186        .ws_relay
187        .upstream_pending_eose
188        .lock()
189        .await
190        .remove(&key);
191    state
192        .ws_relay
193        .upstream_seen_events
194        .lock()
195        .await
196        .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200    let keys = {
201        let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202        subscriptions
203            .keys()
204            .filter(|(id, _)| *id == client_id)
205            .cloned()
206            .collect::<Vec<_>>()
207    };
208    for (_, sub_id) in keys {
209        close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210    }
211}
212
213async fn forward_upstream_nostr_message(
214    state: &AppState,
215    client_id: u64,
216    subscription_id: &SubscriptionId,
217    text: &str,
218) {
219    let Ok(message) = NostrRelayMessage::from_json(text) else {
220        return;
221    };
222
223    match message {
224        NostrRelayMessage::Event {
225            subscription_id: sid,
226            event,
227        } if sid.as_ref() == subscription_id => {
228            let event = event.into_owned();
229            let key = (client_id, subscription_id.to_string());
230            let event_id = event.id.to_hex();
231            let inserted = {
232                let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233                seen_events.entry(key).or_default().insert(event_id)
234            };
235            if !inserted {
236                return;
237            }
238            if let Some(relay) = &state.nostr_relay {
239                let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240            }
241            send_nostr(
242                state,
243                client_id,
244                NostrRelayMessage::event(subscription_id.clone(), event),
245            )
246            .await;
247        }
248        NostrRelayMessage::Closed {
249            subscription_id: sid,
250            message,
251        } if sid.as_ref() == subscription_id => {
252            send_nostr(
253                state,
254                client_id,
255                NostrRelayMessage::closed(subscription_id.clone(), message.into_owned()),
256            )
257            .await;
258        }
259        _ => {}
260    }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264    state: &AppState,
265    client_id: u64,
266    subscription_id: &SubscriptionId,
267) {
268    let key = (client_id, subscription_id.to_string());
269    let should_send_eose = {
270        let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271        let Some(remaining) = pending.get_mut(&key) else {
272            return;
273        };
274        if *remaining > 0 {
275            *remaining -= 1;
276        }
277        if *remaining == 0 {
278            pending.remove(&key);
279            true
280        } else {
281            false
282        }
283    };
284
285    if should_send_eose {
286        send_nostr(
287            state,
288            client_id,
289            NostrRelayMessage::eose(subscription_id.clone()),
290        )
291        .await;
292    }
293}
294
295async fn run_upstream_nostr_subscription(
296    state: AppState,
297    client_id: u64,
298    relay_url: String,
299    subscription_id: SubscriptionId,
300    filters: Vec<NostrFilter>,
301    mut close_rx: watch::Receiver<bool>,
302) {
303    let mut relay_complete = false;
304    let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305        tracing::warn!(
306            "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307            client_id,
308            subscription_id,
309            relay_url,
310        );
311        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312        return;
313    };
314    let (mut write, mut read) = socket.split();
315    let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316    state
317        .ws_relay
318        .note_upstream_relay_send(request.as_bytes().len());
319    if write
320        .send(TungsteniteMessage::Text(request.into()))
321        .await
322        .is_err()
323    {
324        tracing::warn!(
325            "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326            client_id,
327            subscription_id,
328            relay_url,
329        );
330        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331        return;
332    }
333
334    loop {
335        tokio::select! {
336            _ = close_rx.changed() => {
337                if *close_rx.borrow() {
338                    let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339                    state
340                        .ws_relay
341                        .note_upstream_relay_send(close.as_bytes().len());
342                    let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343                    let _ = write.close().await;
344                    break;
345                }
346            }
347            message = read.next() => {
348                match message {
349                    Some(Ok(TungsteniteMessage::Text(text))) => {
350                        state
351                            .ws_relay
352                            .note_upstream_relay_receive(text.as_bytes().len());
353                        if matches!(
354                            NostrRelayMessage::from_json(text.as_str()),
355                            Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid.as_ref() == &subscription_id
356                        ) {
357                            if !relay_complete {
358                                relay_complete = true;
359                                mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360                            }
361                            continue;
362                        }
363                        forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364                    }
365                    Some(Ok(TungsteniteMessage::Ping(payload))) => {
366                        let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367                    }
368                    Some(Ok(TungsteniteMessage::Close(_))) | None => {
369                        if !relay_complete {
370                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371                        }
372                        break;
373                    }
374                    Some(Err(_)) => {
375                        if !relay_complete {
376                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377                        }
378                        break;
379                    }
380                    _ => {}
381                }
382            }
383        }
384    }
385}
386
387async fn start_upstream_nostr_subscription(
388    state: &AppState,
389    client_id: u64,
390    subscription_id: SubscriptionId,
391    filters: Vec<NostrFilter>,
392) -> usize {
393    let memory_before = process_memory_snapshot();
394    close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
395    if state.nostr_relay_urls.is_empty() || filters.is_empty() {
396        tracing::info!(
397            "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
398            client_id,
399            subscription_id,
400            state.nostr_relay_urls.len(),
401            filters.len(),
402        );
403        return 0;
404    }
405
406    let mut relay_urls = Vec::new();
407    let mut seen = HashSet::new();
408    for relay in &state.nostr_relay_urls {
409        let relay = relay.trim();
410        if relay.is_empty() || !seen.insert(relay.to_string()) {
411            continue;
412        }
413        relay_urls.push(relay.to_string());
414    }
415    if relay_urls.is_empty() {
416        tracing::info!(
417            "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
418            client_id,
419            subscription_id,
420        );
421        return 0;
422    }
423
424    tracing::info!(
425        "upstream nostr relay start: client_id={} subscription_id={} relays={}",
426        client_id,
427        subscription_id,
428        relay_urls.len(),
429    );
430
431    let filter_summary = nostr_filters_summary(&filters);
432    let key = (client_id, subscription_id.to_string());
433    state
434        .ws_relay
435        .upstream_seen_events
436        .lock()
437        .await
438        .insert(key.clone(), HashSet::new());
439    state
440        .ws_relay
441        .upstream_pending_eose
442        .lock()
443        .await
444        .insert(key.clone(), relay_urls.len());
445
446    let (close_tx, close_rx) = watch::channel(false);
447    let mut tasks = Vec::new();
448    for relay_url in &relay_urls {
449        tasks.push(tokio::spawn(run_upstream_nostr_subscription(
450            state.clone(),
451            client_id,
452            relay_url.clone(),
453            subscription_id.clone(),
454            filters.clone(),
455            close_rx.clone(),
456        )));
457    }
458
459    tracing::info!(
460        target: "hashtree_cli::server::ws_relay::upstream",
461        client_id,
462        subscription_id = %subscription_id,
463        relays = relay_urls.len(),
464        tasks = tasks.len(),
465        filters = filters.len(),
466        filter = %filter_summary,
467        memory_before = ?memory_before,
468        memory_after = ?process_memory_snapshot(),
469        "upstream nostr relay tasks spawned",
470    );
471
472    state
473        .ws_relay
474        .upstream_nostr_subscriptions
475        .lock()
476        .await
477        .insert(key, UpstreamNostrSubscription { close_tx, tasks });
478    relay_urls.len()
479}
480
481async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
482    match msg {
483        Message::Text(text) => {
484            if let Some(msg) = parse_ws_text_message(&text) {
485                match msg {
486                    WsTextMessage::Hashtree(msg) => {
487                        set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
488                        match msg {
489                            WsClientMessage::Request { id, hash } => {
490                                handle_request(
491                                    client_id,
492                                    id,
493                                    hash,
494                                    WsProtocol::HashtreeJson,
495                                    state,
496                                )
497                                .await;
498                            }
499                            WsClientMessage::Response { id, hash, found } => {
500                                handle_response(client_id, id, hash, found, state).await;
501                            }
502                        }
503                    }
504                    WsTextMessage::Nostr(msg) => {
505                        if let Some(relay) = &state.nostr_relay {
506                            match msg {
507                                NostrClientMessage::Req {
508                                    subscription_id,
509                                    filters,
510                                } => {
511                                    let subscription_id = subscription_id.into_owned();
512                                    let filters = filters
513                                        .into_iter()
514                                        .map(|filter| filter.into_owned())
515                                        .collect::<Vec<_>>();
516                                    let local_events = match relay
517                                        .register_subscription_query(
518                                            client_id,
519                                            subscription_id.clone(),
520                                            filters.clone(),
521                                        )
522                                        .await
523                                    {
524                                        Ok(events) => events,
525                                        Err(message) => {
526                                            send_nostr(
527                                                state,
528                                                client_id,
529                                                NostrRelayMessage::closed(subscription_id, message),
530                                            )
531                                            .await;
532                                            return;
533                                        }
534                                    };
535
536                                    let upstream_relays = start_upstream_nostr_subscription(
537                                        state,
538                                        client_id,
539                                        subscription_id.clone(),
540                                        filters,
541                                    )
542                                    .await;
543                                    if upstream_relays > 0 {
544                                        let key = (client_id, subscription_id.to_string());
545                                        let mut seen_events =
546                                            state.ws_relay.upstream_seen_events.lock().await;
547                                        seen_events.entry(key).or_default().extend(
548                                            local_events.iter().map(|event| event.id.to_hex()),
549                                        );
550                                    }
551                                    for event in local_events {
552                                        send_nostr(
553                                            state,
554                                            client_id,
555                                            NostrRelayMessage::event(
556                                                subscription_id.clone(),
557                                                event,
558                                            ),
559                                        )
560                                        .await;
561                                    }
562                                    trim_process_allocations();
563                                    if upstream_relays == 0 {
564                                        send_nostr(
565                                            state,
566                                            client_id,
567                                            NostrRelayMessage::eose(subscription_id),
568                                        )
569                                        .await;
570                                    }
571                                }
572                                NostrClientMessage::Close(subscription_id) => {
573                                    let subscription_id = subscription_id.into_owned();
574                                    close_upstream_nostr_subscription(
575                                        state,
576                                        client_id,
577                                        &subscription_id,
578                                    )
579                                    .await;
580                                    relay
581                                        .handle_client_message(
582                                            client_id,
583                                            NostrClientMessage::close(subscription_id.clone()),
584                                        )
585                                        .await;
586                                }
587                                other => {
588                                    relay.handle_client_message(client_id, other).await;
589                                }
590                            }
591                        } else {
592                            handle_nostr_message(client_id, msg, state).await;
593                        }
594                    }
595                }
596            }
597        }
598        Message::Binary(data) => {
599            handle_binary(client_id, data, state).await;
600        }
601        Message::Close(_) => {}
602        _ => {}
603    }
604}
605
606async fn handle_request(
607    client_id: u64,
608    request_id: u32,
609    hash: String,
610    origin_protocol: WsProtocol,
611    state: &AppState,
612) {
613    let hash_hex = hash.to_lowercase();
614    let hash_bytes = match from_hex(&hash_hex) {
615        Ok(bytes) => bytes,
616        Err(_) => {
617            if origin_protocol == WsProtocol::HashtreeJson {
618                send_json(
619                    state,
620                    client_id,
621                    WsResponse {
622                        kind: "res",
623                        id: request_id,
624                        hash,
625                        found: false,
626                    },
627                )
628                .await;
629            }
630            return;
631        }
632    };
633
634    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
635        match origin_protocol {
636            WsProtocol::HashtreeJson => {
637                send_json(
638                    state,
639                    client_id,
640                    WsResponse {
641                        kind: "res",
642                        id: request_id,
643                        hash: hash.clone(),
644                        found: true,
645                    },
646                )
647                .await;
648                send_binary(state, client_id, request_id, data).await;
649            }
650            WsProtocol::HashtreeMsgpack => {
651                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
652            }
653            WsProtocol::Unknown => {}
654        }
655        return;
656    }
657
658    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
659        let clients = state.ws_relay.clients.lock().await;
660        let protocols = state.ws_relay.client_protocols.lock().await;
661        clients
662            .iter()
663            .filter(|(id, _)| **id != client_id)
664            .filter_map(|(id, tx)| {
665                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
666                match protocol {
667                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
668                        Some((*id, tx.clone(), protocol))
669                    }
670                    WsProtocol::Unknown => None,
671                }
672            })
673            .collect()
674    };
675
676    if peers.is_empty() {
677        if origin_protocol == WsProtocol::HashtreeJson {
678            send_json(
679                state,
680                client_id,
681                WsResponse {
682                    kind: "res",
683                    id: request_id,
684                    hash,
685                    found: false,
686                },
687            )
688            .await;
689        }
690        return;
691    }
692
693    {
694        let mut pending = state.ws_relay.pending.lock().await;
695        for (peer_id, _, _) in &peers {
696            pending.insert(
697                (*peer_id, request_id),
698                PendingRequest {
699                    origin_id: client_id,
700                    hash: hash.clone(),
701                    found: false,
702                    origin_protocol,
703                },
704            );
705        }
706    }
707
708    let request_text = serde_json::to_string(&WsRequest {
709        kind: "req".to_string(),
710        id: request_id,
711        hash: hash.clone(),
712    })
713    .unwrap_or_else(|_| String::new());
714    for (peer_id, tx, protocol) in peers {
715        match protocol {
716            WsProtocol::HashtreeMsgpack => {
717                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
718            }
719            WsProtocol::HashtreeJson => {
720                let _ = tx.send(Message::Text(request_text.clone()));
721            }
722            WsProtocol::Unknown => {}
723        }
724    }
725
726    let timeout_state = state.clone();
727    let timeout_hash = hash.clone();
728    tokio::spawn(async move {
729        tokio::time::sleep(Duration::from_millis(1500)).await;
730        let mut pending = timeout_state.ws_relay.pending.lock().await;
731        let still_pending = pending
732            .iter()
733            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
734        let already_found = pending
735            .iter()
736            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
737        if !still_pending || already_found {
738            return;
739        }
740        let origin_protocol = pending
741            .iter()
742            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
743            .map(|(_, p)| p.origin_protocol)
744            .unwrap_or(WsProtocol::HashtreeJson);
745        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
746        drop(pending);
747        if origin_protocol == WsProtocol::HashtreeJson {
748            send_json(
749                &timeout_state,
750                client_id,
751                WsResponse {
752                    kind: "res",
753                    id: request_id,
754                    hash: timeout_hash,
755                    found: false,
756                },
757            )
758            .await;
759        }
760    });
761}
762
763async fn handle_response(
764    client_id: u64,
765    request_id: u32,
766    _hash: String,
767    found: bool,
768    state: &AppState,
769) {
770    let pending_entry = {
771        let pending = state.ws_relay.pending.lock().await;
772        pending
773            .get(&(client_id, request_id))
774            .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
775    };
776
777    let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
778        return;
779    };
780
781    if already_found && !found {
782        let mut pending = state.ws_relay.pending.lock().await;
783        pending.remove(&(client_id, request_id));
784        return;
785    }
786
787    if found {
788        let mut pending = state.ws_relay.pending.lock().await;
789        for ((_, id), p) in pending.iter_mut() {
790            if *id == request_id && p.origin_id == origin_id {
791                p.found = true;
792            }
793        }
794        drop(pending);
795        if origin_protocol == WsProtocol::HashtreeJson {
796            send_json(
797                state,
798                origin_id,
799                WsResponse {
800                    kind: "res",
801                    id: request_id,
802                    hash: pending_hash,
803                    found: true,
804                },
805            )
806            .await;
807        }
808        return;
809    }
810
811    let mut pending = state.ws_relay.pending.lock().await;
812    pending.remove(&(client_id, request_id));
813    let has_remaining = pending
814        .iter()
815        .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
816    drop(pending);
817
818    if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
819        send_json(
820            state,
821            origin_id,
822            WsResponse {
823                kind: "res",
824                id: request_id,
825                hash: pending_hash,
826                found: false,
827            },
828        )
829        .await;
830    }
831}
832
833async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
834    if let Some(msg) = parse_msgpack_message(&data) {
835        set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
836        match msg {
837            DataMessage::Request(req) => {
838                let hash_hex = hex_encode(&req.h);
839                let request_id = state.ws_relay.next_request_id();
840                handle_request(
841                    client_id,
842                    request_id,
843                    hash_hex,
844                    WsProtocol::HashtreeMsgpack,
845                    state,
846                )
847                .await;
848            }
849            DataMessage::Response(res) => {
850                handle_msgpack_response(client_id, res, state).await;
851            }
852            DataMessage::QuoteRequest(_)
853            | DataMessage::QuoteResponse(_)
854            | DataMessage::Payment(_)
855            | DataMessage::PaymentAck(_)
856            | DataMessage::Chunk(_)
857            | DataMessage::PeerHints(_)
858            | DataMessage::PubsubInterest(_)
859            | DataMessage::PubsubFrame(_)
860            | DataMessage::PubsubInventory(_)
861            | DataMessage::PubsubWant(_) => {}
862        }
863        return;
864    }
865
866    // Legacy binary: [4-byte LE request_id][data]
867    if data.len() < 4 {
868        return;
869    }
870    let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
871    let pending_entry = {
872        let pending = state.ws_relay.pending.lock().await;
873        pending
874            .get(&(client_id, request_id))
875            .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
876    };
877    let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
878        return;
879    };
880
881    match origin_protocol {
882        WsProtocol::HashtreeJson => {
883            send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
884        }
885        WsProtocol::HashtreeMsgpack => {
886            let Ok(hash_bytes) = from_hex(&hash_hex) else {
887                return;
888            };
889            send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
890        }
891        WsProtocol::Unknown => {}
892    }
893
894    let mut pending = state.ws_relay.pending.lock().await;
895    pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
896}
897
898async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage<'_>, state: &AppState) {
899    let replies = nostr_responses_for(&msg);
900    for reply in replies {
901        send_nostr(state, client_id, reply).await;
902    }
903}
904
905fn nostr_responses_for(msg: &NostrClientMessage<'_>) -> Vec<NostrRelayMessage<'static>> {
906    match msg {
907        NostrClientMessage::Event(event) => {
908            let ok = event.verify().is_ok();
909            let message = if ok { "" } else { "invalid: signature" };
910            vec![NostrRelayMessage::ok(event.id, ok, message)]
911        }
912        NostrClientMessage::Req {
913            subscription_id, ..
914        } => {
915            vec![NostrRelayMessage::eose(
916                subscription_id.clone().into_owned(),
917            )]
918        }
919        NostrClientMessage::Count {
920            subscription_id, ..
921        } => {
922            vec![NostrRelayMessage::count(
923                subscription_id.clone().into_owned(),
924                0,
925            )]
926        }
927        NostrClientMessage::Close(_) => Vec::new(),
928        NostrClientMessage::Auth(event) => {
929            let ok = event.verify().is_ok();
930            let message = if ok { "" } else { "invalid auth" };
931            vec![NostrRelayMessage::ok(event.id, ok, message)]
932        }
933        NostrClientMessage::NegOpen { .. }
934        | NostrClientMessage::NegMsg { .. }
935        | NostrClientMessage::NegClose { .. } => {
936            vec![NostrRelayMessage::notice("negentropy not supported")]
937        }
938    }
939}
940
941async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage<'_>) {
942    let text = response.as_json();
943    send_to_client(state, client_id, Message::Text(text)).await;
944}
945
946fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
947    let msg = parse_message(data).ok()?;
948    match msg {
949        DataMessage::Request(req) => {
950            if req.h.len() == 32 {
951                Some(DataMessage::Request(req))
952            } else {
953                None
954            }
955        }
956        DataMessage::Response(res) => {
957            if res.h.len() == 32 {
958                Some(DataMessage::Response(res))
959            } else {
960                None
961            }
962        }
963        DataMessage::QuoteRequest(req) => {
964            if req.h.len() == 32 {
965                Some(DataMessage::QuoteRequest(req))
966            } else {
967                None
968            }
969        }
970        DataMessage::QuoteResponse(res) => {
971            if res.h.len() == 32 {
972                Some(DataMessage::QuoteResponse(res))
973            } else {
974                None
975            }
976        }
977        DataMessage::Payment(req) => {
978            if req.h.len() == 32 {
979                Some(DataMessage::Payment(req))
980            } else {
981                None
982            }
983        }
984        DataMessage::PaymentAck(res) => {
985            if res.h.len() == 32 {
986                Some(DataMessage::PaymentAck(res))
987            } else {
988                None
989            }
990        }
991        DataMessage::Chunk(chunk) => {
992            if chunk.h.len() == 32 {
993                Some(DataMessage::Chunk(chunk))
994            } else {
995                None
996            }
997        }
998        DataMessage::PeerHints(_)
999        | DataMessage::PubsubInterest(_)
1000        | DataMessage::PubsubFrame(_)
1001        | DataMessage::PubsubInventory(_)
1002        | DataMessage::PubsubWant(_) => Some(msg),
1003    }
1004}
1005
1006async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
1007    let hash_hex = hex_encode(&res.h);
1008    let data = res.d.clone();
1009    let hash_bytes = res.h.clone();
1010
1011    let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
1012    let mut seen = HashSet::new();
1013    {
1014        let pending = state.ws_relay.pending.lock().await;
1015        for ((peer_id, request_id), p) in pending.iter() {
1016            if *peer_id != client_id {
1017                continue;
1018            }
1019            if p.hash != hash_hex {
1020                continue;
1021            }
1022            if seen.insert((p.origin_id, *request_id)) {
1023                responses.push((p.origin_id, *request_id, p.origin_protocol));
1024            }
1025        }
1026    }
1027
1028    if responses.is_empty() {
1029        return;
1030    }
1031
1032    for (origin_id, request_id, protocol) in &responses {
1033        match protocol {
1034            WsProtocol::HashtreeJson => {
1035                send_json(
1036                    state,
1037                    *origin_id,
1038                    WsResponse {
1039                        kind: "res",
1040                        id: *request_id,
1041                        hash: hash_hex.clone(),
1042                        found: true,
1043                    },
1044                )
1045                .await;
1046                send_binary(state, *origin_id, *request_id, data.clone()).await;
1047            }
1048            WsProtocol::HashtreeMsgpack => {
1049                send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1050            }
1051            WsProtocol::Unknown => {}
1052        }
1053    }
1054
1055    let completed: HashSet<(u64, u32)> = responses
1056        .into_iter()
1057        .map(|(origin_id, request_id, _)| (origin_id, request_id))
1058        .collect();
1059    let mut pending = state.ws_relay.pending.lock().await;
1060    pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1061}
1062
1063async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1064    if let Ok(text) = serde_json::to_string(&response) {
1065        send_to_client(state, client_id, Message::Text(text)).await;
1066    }
1067}
1068
1069async fn send_msgpack_request(
1070    state: &AppState,
1071    client_id: u64,
1072    hash: &[u8],
1073) -> Result<(), rmp_serde::encode::Error> {
1074    let req = DataRequest {
1075        h: hash.to_vec(),
1076        htl: MAX_HTL,
1077        q: None,
1078    };
1079    let wire = encode_request(&req)?;
1080    send_to_client(state, client_id, Message::Binary(wire)).await;
1081    Ok(())
1082}
1083
1084async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1085    let res = DataResponse {
1086        h: hash.to_vec(),
1087        d: data.to_vec(),
1088        i: None,
1089        n: None,
1090    };
1091    if let Ok(wire) = encode_response(&res) {
1092        send_to_client(state, client_id, Message::Binary(wire)).await;
1093    }
1094}
1095
1096async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1097    let mut packet = Vec::with_capacity(4 + payload.len());
1098    packet.extend_from_slice(&request_id.to_le_bytes());
1099    packet.extend_from_slice(&payload);
1100    send_to_client(state, client_id, Message::Binary(packet)).await;
1101}
1102
1103async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1104    let sender = {
1105        let clients = state.ws_relay.clients.lock().await;
1106        clients.get(&client_id).cloned()
1107    };
1108    if let Some(tx) = sender {
1109        let _ = tx.send(msg);
1110    }
1111}
1112
1113async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1114    let mut protocols = state.ws_relay.client_protocols.lock().await;
1115    protocols.insert(client_id, protocol);
1116}
1117
1118#[cfg(test)]
1119mod tests {
1120    use super::*;
1121    use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1122    use anyhow::Result;
1123    use futures::{SinkExt, StreamExt};
1124    use nostr::secp256k1::schnorr::Signature;
1125    use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1126    use std::collections::HashSet;
1127    use std::sync::Arc;
1128    use tempfile::TempDir;
1129    use tokio::net::TcpListener;
1130    use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1131
1132    #[test]
1133    fn parse_ws_text_message_detects_nostr_req() {
1134        let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1135        match parse_ws_text_message(msg) {
1136            Some(WsTextMessage::Nostr(_)) => {}
1137            other => panic!("expected Nostr message, got {:?}", other),
1138        }
1139    }
1140
1141    #[test]
1142    fn parse_ws_text_message_detects_hashtree_request() {
1143        let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1144        match parse_ws_text_message(msg) {
1145            Some(WsTextMessage::Hashtree(_)) => {}
1146            other => panic!("expected Hashtree message, got {:?}", other),
1147        }
1148    }
1149
1150    #[test]
1151    fn nostr_replies_for_req_is_eose() {
1152        let sub = SubscriptionId::new("sub-1");
1153        let msg = NostrClientMessage::req(sub.clone(), vec![]);
1154        let replies = nostr_responses_for(&msg);
1155        assert_eq!(replies.len(), 1);
1156        match &replies[0] {
1157            NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id.as_ref(), &sub),
1158            other => panic!("expected EOSE, got {:?}", other),
1159        }
1160    }
1161
1162    #[test]
1163    fn nostr_replies_for_event_ok() {
1164        let keys = Keys::generate();
1165        let event = EventBuilder::new(Kind::TextNote, "hello")
1166            .sign_with_keys(&keys)
1167            .unwrap();
1168        let msg = NostrClientMessage::event(event.clone());
1169        let replies = nostr_responses_for(&msg);
1170        assert_eq!(replies.len(), 1);
1171        match &replies[0] {
1172            NostrRelayMessage::Ok {
1173                event_id, status, ..
1174            } => {
1175                assert_eq!(event_id, &event.id);
1176                assert!(*status);
1177            }
1178            other => panic!("expected OK, got {:?}", other),
1179        }
1180    }
1181
1182    #[test]
1183    fn nostr_replies_for_invalid_event_is_not_ok() {
1184        let keys = Keys::generate();
1185        let mut event = EventBuilder::new(Kind::TextNote, "hello")
1186            .sign_with_keys(&keys)
1187            .unwrap();
1188        event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1189        let msg = NostrClientMessage::event(event);
1190        let replies = nostr_responses_for(&msg);
1191        assert_eq!(replies.len(), 1);
1192        match &replies[0] {
1193            NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1194            other => panic!("expected OK=false, got {:?}", other),
1195        }
1196    }
1197
    /// Start a mock upstream Nostr relay on an ephemeral localhost port and return its
    /// `ws://` URL.
    ///
    /// The mock serves exactly one WebSocket connection. For every `REQ` it receives,
    /// it sends each entry of `events` that matches any of the request's filters as an
    /// `EVENT`, followed by an `EOSE` for that subscription. Non-text frames,
    /// unparseable text and other client message kinds are ignored, and send errors
    /// are discarded.
    async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
        // Port 0 lets the OS pick a free port; the bound address is returned below.
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
        let addr = listener.local_addr().expect("relay addr");
        tokio::spawn(async move {
            // Only the first inbound connection is accepted and served.
            let (stream, _) = listener.accept().await.expect("accept relay");
            let ws = accept_async(stream).await.expect("accept websocket");
            let (mut write, mut read) = ws.split();

            // Serve until the stream ends or yields an error.
            while let Some(Ok(message)) = read.next().await {
                let TungsteniteMessage::Text(text) = message else {
                    continue;
                };
                let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
                    continue;
                };
                if let NostrClientMessage::Req {
                    subscription_id,
                    filters,
                } = parsed
                {
                    let subscription_id = subscription_id.into_owned();
                    let filters = filters
                        .into_iter()
                        .map(|filter| filter.into_owned())
                        .collect::<Vec<_>>();
                    // Replay every canned event that matches at least one filter.
                    for event in events.iter().filter(|event| {
                        filters
                            .iter()
                            .any(|filter| filter.match_event(event, Default::default()))
                    }) {
                        let _ = write
                            .send(TungsteniteMessage::Text(
                                NostrRelayMessage::event(subscription_id.clone(), event.clone())
                                    .as_json()
                                    .into(),
                            ))
                            .await;
                    }
                    // Signal the end of stored events for this subscription.
                    let _ = write
                        .send(TungsteniteMessage::Text(
                            NostrRelayMessage::eose(subscription_id).as_json().into(),
                        ))
                        .await;
                }
            }
        });
        format!("ws://{}", addr)
    }
1246
    /// Build an [`AppState`] for tests.
    ///
    /// The state is backed by a `HashtreeStore` rooted at `tmp`, uses `relay` as the
    /// local Nostr relay and lists `relay_url` as its only upstream relay URL. Tests
    /// that do not need an upstream pass an empty string. Every cache and in-flight map
    /// starts empty, and each call gets a fresh `WsRelayState`.
    fn test_app_state(
        tmp: &TempDir,
        relay: Arc<NostrRelay>,
        relay_url: String,
    ) -> Result<AppState> {
        let store = Arc::new(crate::storage::HashtreeStore::with_options(
            tmp.path(),
            None,
            128 * 1024 * 1024,
        )?);
        Ok(AppState {
            store,
            auth: None,
            daemon_started_at: 1_700_000_000,
            peer_mode: crate::config::ServerMode::Normal,
            hash_get_enabled: true,
            http_webrtc_fetch: true,
            webrtc_peers: None,
            fips_transport: None,
            fetch_from_fips_peers: true,
            // Fresh WebSocket relay bookkeeping (clients, pending requests, subscriptions).
            ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
            max_upload_bytes: 5 * 1024 * 1024,
            public_writes: true,
            public_plaintext_reads: true,
            require_random_untrusted_ingest: false,
            optimistic_blossom_uploads: false,
            // The semaphore's permit count mirrors `optimistic_upload_queue_bytes`.
            optimistic_upload_queue_bytes: 512 * 1024 * 1024,
            optimistic_upload_queue: Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)),
            allowed_pubkeys: HashSet::new(),
            upstream_blossom: Vec::new(),
            social_graph: None,
            social_graph_store: None,
            social_graph_root: None,
            socialgraph_snapshot_public: false,
            nostr_relay: Some(relay),
            nostr_relay_urls: vec![relay_url],
            tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
            inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
                std::collections::HashMap::new(),
            )),
            inflight_blob_reads: Arc::new(
                tokio::sync::Mutex::new(std::collections::HashMap::new()),
            ),
            blob_cache: Arc::new(crate::blob_cache::BlobCache::for_tests()),
            directory_listing_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            resolved_path_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
        })
    }
1303
    /// `start_upstream_nostr_subscription` forwards the upstream EVENT and then EOSE to
    /// the client under the client's subscription id. Afterwards the event can be
    /// queried from the local relay, and closing the subscription removes its tracking
    /// entry.
    #[tokio::test]
    async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
        let tmp = TempDir::new()?;
        // The test lock is held only while the social-graph store is opened.
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let keys = Keys::generate();
        // The event author's pubkey is passed to the local relay's pubkey set.
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let event = EventBuilder::new(Kind::from(30078_u16), "")
            .tags([
                nostr::Tag::parse(["d", "videos/Test"]).expect("d tag"),
                nostr::Tag::parse(["l", "hashtree"]).expect("label tag"),
            ])
            .sign_with_keys(&keys)?;

        // The mock upstream relay holds only `event`.
        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
        let filter = Filter::new()
            .authors(vec![event.pubkey])
            .kinds(vec![event.kind]);
        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
        // Register a fake WebSocket client; frames sent to it arrive on `rx`.
        let client_id = 7_u64;
        let (tx, mut rx) = mpsc::unbounded_channel();
        state.ws_relay.clients.lock().await.insert(client_id, tx);
        let subscription_id = SubscriptionId::new("sub-1");

        start_upstream_nostr_subscription(
            &state,
            client_id,
            subscription_id.clone(),
            vec![filter.clone()],
        )
        .await;

        // First frame: the upstream EVENT, tagged with the client's subscription id.
        let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
            .await?
            .expect("forwarded upstream event");
        let Message::Text(text) = forwarded else {
            panic!("expected text event");
        };
        match NostrRelayMessage::from_json(text.as_str())? {
            NostrRelayMessage::Event {
                subscription_id: sid,
                event: forwarded_event,
            } => {
                assert_eq!(sid.as_ref(), &subscription_id);
                assert_eq!(forwarded_event.id, event.id);
            }
            other => panic!("expected forwarded EVENT, got {:?}", other),
        }

        // Second frame: EOSE for the same subscription.
        let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
            .await?
            .expect("forwarded upstream eose");
        let Message::Text(eose_text) = eose else {
            panic!("expected text eose");
        };
        match NostrRelayMessage::from_json(eose_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(sid) => {
                assert_eq!(sid.as_ref(), &subscription_id);
            }
            other => panic!("expected forwarded EOSE, got {:?}", other),
        }

        // The upstream event can now be queried from the local relay.
        let events = relay.query_events(&filter, 10).await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, event.id);

        // Closing the subscription leaves no tracked upstream subscriptions.
        close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
        assert!(state
            .ws_relay
            .upstream_nostr_subscriptions
            .lock()
            .await
            .is_empty());
        Ok(())
    }
1401
    /// A REQ sent through `handle_message` delivers the upstream EVENT to the client
    /// before the EOSE for the subscription.
    #[tokio::test]
    async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
        let tmp = TempDir::new()?;
        // The test lock is held only while the social-graph store is opened.
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let keys = Keys::generate();
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let event = EventBuilder::new(Kind::from(30078_u16), "")
            .tags([
                nostr::Tag::parse(["d", "videos/Test"]).expect("d tag"),
                nostr::Tag::parse(["l", "hashtree"]).expect("label tag"),
            ])
            .sign_with_keys(&keys)?;

        // The mock upstream relay holds only `event`.
        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
        let client_id = 11_u64;
        // `ws_rx` receives frames sent to the WebSocket client. `_relay_rx` is a named
        // binding (not `_`), so the receiver, and with it the relay channel, stays alive
        // for the whole test.
        let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
        let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
        state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
        relay.register_client(client_id, relay_tx, None).await;

        let request = NostrClientMessage::req(
            SubscriptionId::new("feed"),
            vec![Filter::new()
                .authors(vec![event.pubkey])
                .kinds(vec![event.kind])],
        )
        .as_json();

        handle_message(client_id, Message::Text(request.into()), &state).await;

        // First frame must be the upstream EVENT, not an early EOSE.
        let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
            .await?
            .expect("first forwarded message");
        let Message::Text(first_text) = first else {
            panic!("expected text event");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::Event {
                event: forwarded_event,
                ..
            } => {
                assert_eq!(forwarded_event.id, event.id);
            }
            other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
        }

        // Second frame: a single EOSE for the client's subscription id.
        let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
            .await?
            .expect("second forwarded message");
        let Message::Text(second_text) = second else {
            panic!("expected text eose");
        };
        match NostrRelayMessage::from_json(second_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(sid) => {
                assert_eq!(sid.as_ref(), &SubscriptionId::new("feed"));
            }
            other => panic!("expected aggregated EOSE, got {:?}", other),
        }

        Ok(())
    }
1487
    /// End-to-end over a real axum WebSocket: an EVENT published by the author whose
    /// pubkey was given to the handler gets `OK true`, and the event is then stored in
    /// the local relay.
    #[tokio::test]
    async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
        let tmp = TempDir::new()?;
        // The test lock is held only while the social-graph store is opened.
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let author_keys = Keys::generate();
        // The author's pubkey goes into both the access-control set and the relay's
        // pubkey set (presumably the trusted roots; confirm against NostrRelay::new).
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::from([author_keys.public_key().to_hex()]),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([author_keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay.clone(), String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        // Every upgrade on /ws is handled as if it came from the author's pubkey.
        let client_pubkey = author_keys.public_key().to_hex();
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                let client_pubkey = client_pubkey.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    let client_pubkey = client_pubkey.clone();
                    async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        let event = EventBuilder::new(Kind::TextNote, "websocket publish ack")
            .sign_with_keys(&author_keys)?;
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event.clone()).as_json().into(),
            ))
            .await?;

        // The first frame back must be the OK acknowledgement for this event.
        let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
        let TungsteniteMessage::Text(text) = reply else {
            anyhow::bail!("expected text publish ack");
        };

        match NostrRelayMessage::from_json(text.as_str())? {
            NostrRelayMessage::Ok {
                event_id, status, ..
            } => {
                assert_eq!(event_id, event.id);
                assert!(status);
            }
            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
        }

        // The accepted event is queryable from the local relay.
        let stored = relay
            .query_events(
                &Filter::new()
                    .authors(vec![event.pubkey])
                    .kinds(vec![event.kind]),
                10,
            )
            .await;
        assert!(stored.iter().any(|candidate| candidate.id == event.id));
        Ok(())
    }
1573
    /// With `spambox_max_reqs_per_min: 1`, an anonymous client's first REQ gets an EOSE
    /// and its second REQ is refused with `["CLOSED", <sub>, "rate limited"]`.
    #[tokio::test]
    async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
        let tmp = TempDir::new()?;
        // The test lock is held only while the social-graph store is opened.
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));
        // No pubkeys are configured, so (presumably) this client falls under the spambox
        // limits; the REQ quota is one per minute.
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::new(),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                spambox_max_reqs_per_min: 1,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay, String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        // No client pubkey is passed to the WebSocket handler.
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    async move { ws_data_with_client_pubkey(state, ws, None) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
                    .as_json()
                    .into(),
            ))
            .await?;

        // The first REQ is within quota and is answered with EOSE.
        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
        let TungsteniteMessage::Text(first_text) = first else {
            anyhow::bail!("expected text EOSE reply");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
                assert_eq!(subscription_id.as_ref(), &SubscriptionId::new("sub-1"));
            }
            other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
        }

        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
                    .as_json()
                    .into(),
            ))
            .await?;

        // The second REQ exceeds the quota and is closed with a rate-limit reason. The
        // frame is compared as raw JSON so the exact wire shape is pinned.
        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
        let TungsteniteMessage::Text(second_text) = second else {
            anyhow::bail!("expected text CLOSED reply");
        };
        let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
        assert_eq!(
            second_value,
            serde_json::json!(["CLOSED", "sub-2", "rate limited"])
        );

        Ok(())
    }
1663
1664    #[tokio::test]
1665    async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1666        let tmp = TempDir::new()?;
1667        let graph_store = {
1668            let _guard = crate::socialgraph::test_lock();
1669            crate::socialgraph::open_social_graph_store_with_mapsize(
1670                tmp.path(),
1671                Some(128 * 1024 * 1024),
1672            )?
1673        };
1674        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1675        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1676            Arc::clone(&backend),
1677            0,
1678            HashSet::new(),
1679        ));
1680        let relay = Arc::new(NostrRelay::new(
1681            Arc::clone(&backend),
1682            tmp.path().to_path_buf(),
1683            HashSet::new(),
1684            Some(access),
1685            NostrRelayConfig {
1686                spambox_db_max_bytes: 0,
1687                spambox_max_events_per_min: 1,
1688                ..Default::default()
1689            },
1690        )?);
1691
1692        let state = test_app_state(&tmp, relay, String::new())?;
1693        let listener = TcpListener::bind("127.0.0.1:0").await?;
1694        let addr = listener.local_addr()?;
1695        let app = axum::Router::new().route(
1696            "/ws",
1697            axum::routing::get({
1698                let state = state.clone();
1699                move |ws: WebSocketUpgrade| {
1700                    let state = state.clone();
1701                    async move { ws_data_with_client_pubkey(state, ws, None) }
1702                }
1703            }),
1704        );
1705        tokio::spawn(async move {
1706            let _ = axum::serve(listener, app).await;
1707        });
1708
1709        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1710        let author_keys = Keys::generate();
1711        let event_a =
1712            EventBuilder::new(Kind::TextNote, "spambox-a").sign_with_keys(&author_keys)?;
1713        let event_b =
1714            EventBuilder::new(Kind::TextNote, "spambox-b").sign_with_keys(&author_keys)?;
1715
1716        socket
1717            .send(TungsteniteMessage::Text(
1718                NostrClientMessage::event(event_a.clone()).as_json().into(),
1719            ))
1720            .await?;
1721
1722        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1723            .await?
1724            .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1725        let TungsteniteMessage::Text(first_text) = first else {
1726            anyhow::bail!("expected text publish ack");
1727        };
1728        match NostrRelayMessage::from_json(first_text.as_str())? {
1729            NostrRelayMessage::Ok {
1730                event_id,
1731                status,
1732                message,
1733            } => {
1734                assert_eq!(event_id, event_a.id);
1735                assert!(status);
1736                assert_eq!(message, "spambox");
1737            }
1738            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1739        }
1740
1741        socket
1742            .send(TungsteniteMessage::Text(
1743                NostrClientMessage::event(event_b.clone()).as_json().into(),
1744            ))
1745            .await?;
1746
1747        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1748            .await?
1749            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1750        let TungsteniteMessage::Text(second_text) = second else {
1751            anyhow::bail!("expected text publish ack");
1752        };
1753        match NostrRelayMessage::from_json(second_text.as_str())? {
1754            NostrRelayMessage::Ok {
1755                event_id,
1756                status,
1757                message,
1758            } => {
1759                assert_eq!(event_id, event_b.id);
1760                assert!(!status);
1761                assert_eq!(message, "rate limited");
1762            }
1763            other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1764        }
1765
1766        Ok(())
1767    }
1768}