Skip to main content

hashtree_cli/server/
ws_relay.rs

1use axum::{
2    extract::{
3        ws::{Message, WebSocket, WebSocketUpgrade},
4        State,
5    },
6    response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11    ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12    RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21    nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24    encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
/// Hashtree JSON control messages a client may send as WebSocket text frames.
///
/// The variant is selected by the `"type"` field (`"req"` or `"res"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum WsClientMessage {
    /// `{"type":"req","id":…,"hash":…}` — ask for the blob with this hex hash;
    /// served from the local store or relayed to other connected peers.
    #[serde(rename = "req")]
    Request { id: u32, hash: String },
    /// `{"type":"res","id":…,"hash":…,"found":…}` — this client's answer to a
    /// request that was previously relayed to it.
    #[serde(rename = "res")]
    Response { id: u32, hash: String, found: bool },
}
36
/// A text frame as classified by `parse_ws_text_message`.
#[derive(Debug)]
enum WsTextMessage {
    /// Hashtree `req`/`res` JSON object.
    Hashtree(WsClientMessage),
    /// Nostr client message (JSON array such as `REQ`, `EVENT`, `CLOSE`).
    Nostr(NostrClientMessage),
}
42
/// Hashtree JSON request relayed to JSON-protocol peers, serialized as
/// `{"type":"req","id":…,"hash":…}`.
#[derive(Debug, Deserialize, Serialize)]
struct WsRequest {
    /// Message tag; set to `"req"` where this file constructs it.
    #[serde(rename = "type")]
    kind: String,
    /// Request id under which the pending request is tracked.
    id: u32,
    /// Hex blob hash, as supplied by the originating request.
    hash: String,
}
50
/// Hashtree JSON reply sent to a client, serialized as
/// `{"type":"res","id":…,"hash":…,"found":…}`.
///
/// When `found` is true the blob bytes are delivered separately as a binary frame.
#[derive(Debug, Serialize)]
struct WsResponse {
    /// Message tag; always `"res"` where this file constructs it.
    #[serde(rename = "type")]
    kind: &'static str,
    /// Id of the request being answered.
    id: u32,
    /// Hash echoed back from the request.
    hash: String,
    /// Whether the blob was located.
    found: bool,
}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61    ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65    state: AppState,
66    ws: WebSocketUpgrade,
67    client_pubkey: Option<String>,
68) -> impl IntoResponse {
69    ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
72async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
73    // Use the Nostr relay's client-id generator when available so `/ws` IDs
74    // can't collide with WebRTC relay clients.
75    let client_id = state
76        .nostr_relay
77        .as_ref()
78        .map(|relay| relay.next_client_id())
79        .unwrap_or_else(|| state.ws_relay.next_id());
80    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
81
82    {
83        let mut clients = state.ws_relay.clients.lock().await;
84        clients.insert(client_id, tx);
85    }
86    {
87        let mut protocols = state.ws_relay.client_protocols.lock().await;
88        protocols.insert(client_id, WsProtocol::HashtreeJson);
89    }
90
91    let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
92        let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
93        relay
94            .register_client(client_id, nostr_tx, client_pubkey.clone())
95            .await;
96        Some(nostr_rx)
97    } else {
98        None
99    };
100
101    let (mut sender, mut receiver) = socket.split();
102    loop {
103        tokio::select! {
104            maybe_msg = rx.recv() => {
105                let Some(msg) = maybe_msg else {
106                    break;
107                };
108                if sender.send(msg).await.is_err() {
109                    break;
110                }
111            }
112            maybe_text = async {
113                match &mut nostr_rx {
114                    Some(rx) => rx.recv().await,
115                    None => std::future::pending().await,
116                }
117            } => {
118                let Some(text) = maybe_text else {
119                    nostr_rx = None;
120                    continue;
121                };
122                if sender.send(Message::Text(text)).await.is_err() {
123                    break;
124                }
125            }
126            maybe_incoming = receiver.next() => {
127                match maybe_incoming {
128                    Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
129                    Some(Err(_)) | None => break,
130                }
131            }
132        }
133    }
134
135    close_all_upstream_nostr_subscriptions(&state, client_id).await;
136
137    {
138        let mut clients = state.ws_relay.clients.lock().await;
139        clients.remove(&client_id);
140    }
141    {
142        let mut protocols = state.ws_relay.client_protocols.lock().await;
143        protocols.remove(&client_id);
144    }
145    {
146        let mut pending = state.ws_relay.pending.lock().await;
147        pending.retain(|(peer_id, _), _| *peer_id != client_id);
148    }
149
150    if let Some(relay) = &state.nostr_relay {
151        relay.unregister_client(client_id).await;
152    }
153}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
156    let trimmed = text.trim_start();
157    if trimmed.starts_with('[') {
158        if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159            return Some(WsTextMessage::Nostr(msg));
160        }
161    }
162
163    if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164        return Some(WsTextMessage::Hashtree(msg));
165    }
166
167    None
168}
169async fn close_upstream_nostr_subscription(
170    state: &AppState,
171    client_id: u64,
172    subscription_id: &SubscriptionId,
173) {
174    let key = (client_id, subscription_id.to_string());
175    let subscription = {
176        let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177        subscriptions.remove(&key)
178    };
179    if let Some(subscription) = subscription {
180        let _ = subscription.close_tx.send(true);
181        for task in subscription.tasks {
182            task.abort();
183        }
184    }
185    state
186        .ws_relay
187        .upstream_pending_eose
188        .lock()
189        .await
190        .remove(&key);
191    state
192        .ws_relay
193        .upstream_seen_events
194        .lock()
195        .await
196        .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200    let keys = {
201        let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202        subscriptions
203            .keys()
204            .filter(|(id, _)| *id == client_id)
205            .cloned()
206            .collect::<Vec<_>>()
207    };
208    for (_, sub_id) in keys {
209        close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210    }
211}
212
213async fn forward_upstream_nostr_message(
214    state: &AppState,
215    client_id: u64,
216    subscription_id: &SubscriptionId,
217    text: &str,
218) {
219    let Ok(message) = NostrRelayMessage::from_json(text) else {
220        return;
221    };
222
223    match message {
224        NostrRelayMessage::Event {
225            subscription_id: sid,
226            event,
227        } if sid == *subscription_id => {
228            let event = *event;
229            let key = (client_id, subscription_id.to_string());
230            let event_id = event.id.to_hex();
231            let inserted = {
232                let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233                seen_events.entry(key).or_default().insert(event_id)
234            };
235            if !inserted {
236                return;
237            }
238            if let Some(relay) = &state.nostr_relay {
239                let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240            }
241            send_nostr(
242                state,
243                client_id,
244                NostrRelayMessage::event(subscription_id.clone(), event),
245            )
246            .await;
247        }
248        NostrRelayMessage::Closed {
249            subscription_id: sid,
250            message,
251        } if sid == *subscription_id => {
252            send_nostr(
253                state,
254                client_id,
255                NostrRelayMessage::closed(subscription_id.clone(), message),
256            )
257            .await;
258        }
259        _ => {}
260    }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264    state: &AppState,
265    client_id: u64,
266    subscription_id: &SubscriptionId,
267) {
268    let key = (client_id, subscription_id.to_string());
269    let should_send_eose = {
270        let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271        let Some(remaining) = pending.get_mut(&key) else {
272            return;
273        };
274        if *remaining > 0 {
275            *remaining -= 1;
276        }
277        if *remaining == 0 {
278            pending.remove(&key);
279            true
280        } else {
281            false
282        }
283    };
284
285    if should_send_eose {
286        send_nostr(
287            state,
288            client_id,
289            NostrRelayMessage::eose(subscription_id.clone()),
290        )
291        .await;
292    }
293}
294
295async fn run_upstream_nostr_subscription(
296    state: AppState,
297    client_id: u64,
298    relay_url: String,
299    subscription_id: SubscriptionId,
300    filters: Vec<NostrFilter>,
301    mut close_rx: watch::Receiver<bool>,
302) {
303    let mut relay_complete = false;
304    let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305        tracing::warn!(
306            "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307            client_id,
308            subscription_id,
309            relay_url,
310        );
311        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312        return;
313    };
314    let (mut write, mut read) = socket.split();
315    let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316    state
317        .ws_relay
318        .note_upstream_relay_send(request.as_bytes().len());
319    if write
320        .send(TungsteniteMessage::Text(request.into()))
321        .await
322        .is_err()
323    {
324        tracing::warn!(
325            "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326            client_id,
327            subscription_id,
328            relay_url,
329        );
330        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331        return;
332    }
333
334    loop {
335        tokio::select! {
336            _ = close_rx.changed() => {
337                if *close_rx.borrow() {
338                    let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339                    state
340                        .ws_relay
341                        .note_upstream_relay_send(close.as_bytes().len());
342                    let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343                    let _ = write.close().await;
344                    break;
345                }
346            }
347            message = read.next() => {
348                match message {
349                    Some(Ok(TungsteniteMessage::Text(text))) => {
350                        state
351                            .ws_relay
352                            .note_upstream_relay_receive(text.as_bytes().len());
353                        if matches!(
354                            NostrRelayMessage::from_json(text.as_str()),
355                            Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
356                        ) {
357                            if !relay_complete {
358                                relay_complete = true;
359                                mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360                            }
361                            continue;
362                        }
363                        forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364                    }
365                    Some(Ok(TungsteniteMessage::Ping(payload))) => {
366                        let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367                    }
368                    Some(Ok(TungsteniteMessage::Close(_))) | None => {
369                        if !relay_complete {
370                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371                        }
372                        break;
373                    }
374                    Some(Err(_)) => {
375                        if !relay_complete {
376                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377                        }
378                        break;
379                    }
380                    _ => {}
381                }
382            }
383        }
384    }
385}
386
387async fn start_upstream_nostr_subscription(
388    state: &AppState,
389    client_id: u64,
390    subscription_id: SubscriptionId,
391    filters: Vec<NostrFilter>,
392) -> usize {
393    let memory_before = process_memory_snapshot();
394    close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
395    if state.nostr_relay_urls.is_empty() || filters.is_empty() {
396        tracing::info!(
397            "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
398            client_id,
399            subscription_id,
400            state.nostr_relay_urls.len(),
401            filters.len(),
402        );
403        return 0;
404    }
405
406    let mut relay_urls = Vec::new();
407    let mut seen = HashSet::new();
408    for relay in &state.nostr_relay_urls {
409        let relay = relay.trim();
410        if relay.is_empty() || !seen.insert(relay.to_string()) {
411            continue;
412        }
413        relay_urls.push(relay.to_string());
414    }
415    if relay_urls.is_empty() {
416        tracing::info!(
417            "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
418            client_id,
419            subscription_id,
420        );
421        return 0;
422    }
423
424    tracing::info!(
425        "upstream nostr relay start: client_id={} subscription_id={} relays={}",
426        client_id,
427        subscription_id,
428        relay_urls.len(),
429    );
430
431    let filter_summary = nostr_filters_summary(&filters);
432    let key = (client_id, subscription_id.to_string());
433    state
434        .ws_relay
435        .upstream_seen_events
436        .lock()
437        .await
438        .insert(key.clone(), HashSet::new());
439    state
440        .ws_relay
441        .upstream_pending_eose
442        .lock()
443        .await
444        .insert(key.clone(), relay_urls.len());
445
446    let (close_tx, close_rx) = watch::channel(false);
447    let mut tasks = Vec::new();
448    for relay_url in &relay_urls {
449        tasks.push(tokio::spawn(run_upstream_nostr_subscription(
450            state.clone(),
451            client_id,
452            relay_url.clone(),
453            subscription_id.clone(),
454            filters.clone(),
455            close_rx.clone(),
456        )));
457    }
458
459    tracing::info!(
460        target: "hashtree_cli::server::ws_relay::upstream",
461        client_id,
462        subscription_id = %subscription_id,
463        relays = relay_urls.len(),
464        tasks = tasks.len(),
465        filters = filters.len(),
466        filter = %filter_summary,
467        memory_before = ?memory_before,
468        memory_after = ?process_memory_snapshot(),
469        "upstream nostr relay tasks spawned",
470    );
471
472    state
473        .ws_relay
474        .upstream_nostr_subscriptions
475        .lock()
476        .await
477        .insert(key, UpstreamNostrSubscription { close_tx, tasks });
478    relay_urls.len()
479}
480
481async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
482    match msg {
483        Message::Text(text) => {
484            if let Some(msg) = parse_ws_text_message(&text) {
485                match msg {
486                    WsTextMessage::Hashtree(msg) => {
487                        set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
488                        match msg {
489                            WsClientMessage::Request { id, hash } => {
490                                handle_request(
491                                    client_id,
492                                    id,
493                                    hash,
494                                    WsProtocol::HashtreeJson,
495                                    state,
496                                )
497                                .await;
498                            }
499                            WsClientMessage::Response { id, hash, found } => {
500                                handle_response(client_id, id, hash, found, state).await;
501                            }
502                        }
503                    }
504                    WsTextMessage::Nostr(msg) => {
505                        if let Some(relay) = &state.nostr_relay {
506                            match msg {
507                                NostrClientMessage::Req {
508                                    subscription_id,
509                                    filters,
510                                } => {
511                                    let local_events = match relay
512                                        .register_subscription_query(
513                                            client_id,
514                                            subscription_id.clone(),
515                                            filters.clone(),
516                                        )
517                                        .await
518                                    {
519                                        Ok(events) => events,
520                                        Err(message) => {
521                                            send_nostr(
522                                                state,
523                                                client_id,
524                                                NostrRelayMessage::closed(subscription_id, message),
525                                            )
526                                            .await;
527                                            return;
528                                        }
529                                    };
530
531                                    let upstream_relays = start_upstream_nostr_subscription(
532                                        state,
533                                        client_id,
534                                        subscription_id.clone(),
535                                        filters,
536                                    )
537                                    .await;
538                                    if upstream_relays > 0 {
539                                        let key = (client_id, subscription_id.to_string());
540                                        let mut seen_events =
541                                            state.ws_relay.upstream_seen_events.lock().await;
542                                        seen_events.entry(key).or_default().extend(
543                                            local_events.iter().map(|event| event.id.to_hex()),
544                                        );
545                                    }
546                                    for event in local_events {
547                                        send_nostr(
548                                            state,
549                                            client_id,
550                                            NostrRelayMessage::event(
551                                                subscription_id.clone(),
552                                                event,
553                                            ),
554                                        )
555                                        .await;
556                                    }
557                                    trim_process_allocations();
558                                    if upstream_relays == 0 {
559                                        send_nostr(
560                                            state,
561                                            client_id,
562                                            NostrRelayMessage::eose(subscription_id),
563                                        )
564                                        .await;
565                                    }
566                                }
567                                NostrClientMessage::Close(subscription_id) => {
568                                    close_upstream_nostr_subscription(
569                                        state,
570                                        client_id,
571                                        &subscription_id,
572                                    )
573                                    .await;
574                                    relay
575                                        .handle_client_message(
576                                            client_id,
577                                            NostrClientMessage::Close(subscription_id.clone()),
578                                        )
579                                        .await;
580                                }
581                                other => {
582                                    relay.handle_client_message(client_id, other).await;
583                                }
584                            }
585                        } else {
586                            handle_nostr_message(client_id, msg, state).await;
587                        }
588                    }
589                }
590            }
591        }
592        Message::Binary(data) => {
593            handle_binary(client_id, data, state).await;
594        }
595        Message::Close(_) => {}
596        _ => {}
597    }
598}
599
/// Serves a hashtree blob request from `client_id`.
///
/// `request_id` tags every reply, and `origin_protocol` selects how the
/// requester is answered. A lowercased copy of `hash` is hex-decoded, while
/// replies echo `hash` exactly as supplied.
///
/// 1. Invalid hex: a JSON origin gets `found: false`; a msgpack origin gets no reply.
/// 2. Local store hit:
///    * a JSON origin gets a `res` message followed by a binary frame;
///    * a msgpack origin gets a msgpack response.
///
///    A store error is treated like a miss.
/// 3. Otherwise the request is relayed to every other client whose protocol is
///    `HashtreeJson` or `HashtreeMsgpack`, with one `PendingRequest` recorded per
///    peer under `(peer_id, request_id)`.
///    * With no eligible peers, a JSON origin gets `found: false` immediately.
///    * A detached task waits 1.5 s. If entries for the request are still
///      pending and none reported a hit, it clears them and sends a JSON origin
///      `found: false`.
///
/// NOTE(review): for JSON origins `request_id` is chosen by the client. Two
/// origins reusing the same id while relaying to the same peer would therefore
/// overwrite each other's `pending` entry. Confirm whether ids are expected to
/// be unique server-wide.
async fn handle_request(
    client_id: u64,
    request_id: u32,
    hash: String,
    origin_protocol: WsProtocol,
    state: &AppState,
) {
    let hash_hex = hash.to_lowercase();
    let hash_bytes = match from_hex(&hash_hex) {
        Ok(bytes) => bytes,
        Err(_) => {
            if origin_protocol == WsProtocol::HashtreeJson {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash,
                        found: false,
                    },
                )
                .await;
            }
            return;
        }
    };

    // Local store hit: answer directly. `Err` from the store falls through to the peer path.
    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
        match origin_protocol {
            WsProtocol::HashtreeJson => {
                // JSON clients get the `res` header first, then the blob as a binary frame.
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash: hash.clone(),
                        found: true,
                    },
                )
                .await;
                send_binary(state, client_id, request_id, data).await;
            }
            WsProtocol::HashtreeMsgpack => {
                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
            }
            WsProtocol::Unknown => {}
        }
        return;
    }

    // Snapshot every other client with a known hashtree protocol. Both locks are
    // released when this block ends, before anything is sent.
    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
        let clients = state.ws_relay.clients.lock().await;
        let protocols = state.ws_relay.client_protocols.lock().await;
        clients
            .iter()
            .filter(|(id, _)| **id != client_id)
            .filter_map(|(id, tx)| {
                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
                match protocol {
                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
                        Some((*id, tx.clone(), protocol))
                    }
                    WsProtocol::Unknown => None,
                }
            })
            .collect()
    };

    if peers.is_empty() {
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash,
                    found: false,
                },
            )
            .await;
        }
        return;
    }

    // One pending entry per peer. Peer responses and legacy binary replies look
    // these up by `(peer_id, request_id)`.
    {
        let mut pending = state.ws_relay.pending.lock().await;
        for (peer_id, _, _) in &peers {
            pending.insert(
                (*peer_id, request_id),
                PendingRequest {
                    origin_id: client_id,
                    hash: hash.clone(),
                    found: false,
                    origin_protocol,
                },
            );
        }
    }

    // Serialized once and reused for every JSON peer.
    let request_text = serde_json::to_string(&WsRequest {
        kind: "req".to_string(),
        id: request_id,
        hash: hash.clone(),
    })
    .unwrap_or_else(|_| String::new());
    for (peer_id, tx, protocol) in peers {
        match protocol {
            WsProtocol::HashtreeMsgpack => {
                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
            }
            WsProtocol::HashtreeJson => {
                let _ = tx.send(Message::Text(request_text.clone()));
            }
            WsProtocol::Unknown => {}
        }
    }

    // Give peers 1.5 s. If the request is still pending without a hit, drop its
    // entries and report a miss to a JSON origin. If the entries are already gone,
    // another path has handled the request, so stay silent.
    let timeout_state = state.clone();
    let timeout_hash = hash.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(1500)).await;
        let mut pending = timeout_state.ws_relay.pending.lock().await;
        let still_pending = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
        let already_found = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
        if !still_pending || already_found {
            return;
        }
        // The fallback is not expected to be reached: `still_pending` guarantees a match.
        let origin_protocol = pending
            .iter()
            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
            .map(|(_, p)| p.origin_protocol)
            .unwrap_or(WsProtocol::HashtreeJson);
        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
        drop(pending);
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                &timeout_state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash: timeout_hash,
                    found: false,
                },
            )
            .await;
        }
    });
}
756
757async fn handle_response(
758    client_id: u64,
759    request_id: u32,
760    _hash: String,
761    found: bool,
762    state: &AppState,
763) {
764    let pending_entry = {
765        let pending = state.ws_relay.pending.lock().await;
766        pending
767            .get(&(client_id, request_id))
768            .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
769    };
770
771    let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
772        return;
773    };
774
775    if already_found && !found {
776        let mut pending = state.ws_relay.pending.lock().await;
777        pending.remove(&(client_id, request_id));
778        return;
779    }
780
781    if found {
782        let mut pending = state.ws_relay.pending.lock().await;
783        for ((_, id), p) in pending.iter_mut() {
784            if *id == request_id && p.origin_id == origin_id {
785                p.found = true;
786            }
787        }
788        drop(pending);
789        if origin_protocol == WsProtocol::HashtreeJson {
790            send_json(
791                state,
792                origin_id,
793                WsResponse {
794                    kind: "res",
795                    id: request_id,
796                    hash: pending_hash,
797                    found: true,
798                },
799            )
800            .await;
801        }
802        return;
803    }
804
805    let mut pending = state.ws_relay.pending.lock().await;
806    pending.remove(&(client_id, request_id));
807    let has_remaining = pending
808        .iter()
809        .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
810    drop(pending);
811
812    if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
813        send_json(
814            state,
815            origin_id,
816            WsResponse {
817                kind: "res",
818                id: request_id,
819                hash: pending_hash,
820                found: false,
821            },
822        )
823        .await;
824    }
825}
826
827async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
828    if let Some(msg) = parse_msgpack_message(&data) {
829        set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
830        match msg {
831            DataMessage::Request(req) => {
832                let hash_hex = hex_encode(&req.h);
833                let request_id = state.ws_relay.next_request_id();
834                handle_request(
835                    client_id,
836                    request_id,
837                    hash_hex,
838                    WsProtocol::HashtreeMsgpack,
839                    state,
840                )
841                .await;
842            }
843            DataMessage::Response(res) => {
844                handle_msgpack_response(client_id, res, state).await;
845            }
846            DataMessage::QuoteRequest(_)
847            | DataMessage::QuoteResponse(_)
848            | DataMessage::Payment(_)
849            | DataMessage::PaymentAck(_)
850            | DataMessage::Chunk(_)
851            | DataMessage::PeerHints(_)
852            | DataMessage::PubsubInterest(_)
853            | DataMessage::PubsubFrame(_)
854            | DataMessage::PubsubInventory(_)
855            | DataMessage::PubsubWant(_) => {}
856        }
857        return;
858    }
859
860    // Legacy binary: [4-byte LE request_id][data]
861    if data.len() < 4 {
862        return;
863    }
864    let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
865    let pending_entry = {
866        let pending = state.ws_relay.pending.lock().await;
867        pending
868            .get(&(client_id, request_id))
869            .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
870    };
871    let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
872        return;
873    };
874
875    match origin_protocol {
876        WsProtocol::HashtreeJson => {
877            send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
878        }
879        WsProtocol::HashtreeMsgpack => {
880            let Ok(hash_bytes) = from_hex(&hash_hex) else {
881                return;
882            };
883            send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
884        }
885        WsProtocol::Unknown => {}
886    }
887
888    let mut pending = state.ws_relay.pending.lock().await;
889    pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
890}
891
892async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
893    let replies = nostr_responses_for(&msg);
894    for reply in replies {
895        send_nostr(state, client_id, reply).await;
896    }
897}
898
899fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
900    match msg {
901        NostrClientMessage::Event(event) => {
902            let ok = event.verify().is_ok();
903            let message = if ok { "" } else { "invalid: signature" };
904            vec![NostrRelayMessage::ok(event.id, ok, message)]
905        }
906        NostrClientMessage::Req {
907            subscription_id, ..
908        } => {
909            vec![NostrRelayMessage::eose(subscription_id.clone())]
910        }
911        NostrClientMessage::Count {
912            subscription_id, ..
913        } => {
914            vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
915        }
916        NostrClientMessage::Close(_) => Vec::new(),
917        NostrClientMessage::Auth(event) => {
918            let ok = event.verify().is_ok();
919            let message = if ok { "" } else { "invalid auth" };
920            vec![NostrRelayMessage::ok(event.id, ok, message)]
921        }
922        NostrClientMessage::NegOpen { .. }
923        | NostrClientMessage::NegMsg { .. }
924        | NostrClientMessage::NegClose { .. } => {
925            vec![NostrRelayMessage::notice("negentropy not supported")]
926        }
927    }
928}
929
930async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
931    let text = response.as_json();
932    send_to_client(state, client_id, Message::Text(text)).await;
933}
934
935fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
936    let msg = parse_message(data).ok()?;
937    match msg {
938        DataMessage::Request(req) => {
939            if req.h.len() == 32 {
940                Some(DataMessage::Request(req))
941            } else {
942                None
943            }
944        }
945        DataMessage::Response(res) => {
946            if res.h.len() == 32 {
947                Some(DataMessage::Response(res))
948            } else {
949                None
950            }
951        }
952        DataMessage::QuoteRequest(req) => {
953            if req.h.len() == 32 {
954                Some(DataMessage::QuoteRequest(req))
955            } else {
956                None
957            }
958        }
959        DataMessage::QuoteResponse(res) => {
960            if res.h.len() == 32 {
961                Some(DataMessage::QuoteResponse(res))
962            } else {
963                None
964            }
965        }
966        DataMessage::Payment(req) => {
967            if req.h.len() == 32 {
968                Some(DataMessage::Payment(req))
969            } else {
970                None
971            }
972        }
973        DataMessage::PaymentAck(res) => {
974            if res.h.len() == 32 {
975                Some(DataMessage::PaymentAck(res))
976            } else {
977                None
978            }
979        }
980        DataMessage::Chunk(chunk) => {
981            if chunk.h.len() == 32 {
982                Some(DataMessage::Chunk(chunk))
983            } else {
984                None
985            }
986        }
987        DataMessage::PeerHints(_)
988        | DataMessage::PubsubInterest(_)
989        | DataMessage::PubsubFrame(_)
990        | DataMessage::PubsubInventory(_)
991        | DataMessage::PubsubWant(_) => Some(msg),
992    }
993}
994
995async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
996    let hash_hex = hex_encode(&res.h);
997    let data = res.d.clone();
998    let hash_bytes = res.h.clone();
999
1000    let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
1001    let mut seen = HashSet::new();
1002    {
1003        let pending = state.ws_relay.pending.lock().await;
1004        for ((peer_id, request_id), p) in pending.iter() {
1005            if *peer_id != client_id {
1006                continue;
1007            }
1008            if p.hash != hash_hex {
1009                continue;
1010            }
1011            if seen.insert((p.origin_id, *request_id)) {
1012                responses.push((p.origin_id, *request_id, p.origin_protocol));
1013            }
1014        }
1015    }
1016
1017    if responses.is_empty() {
1018        return;
1019    }
1020
1021    for (origin_id, request_id, protocol) in &responses {
1022        match protocol {
1023            WsProtocol::HashtreeJson => {
1024                send_json(
1025                    state,
1026                    *origin_id,
1027                    WsResponse {
1028                        kind: "res",
1029                        id: *request_id,
1030                        hash: hash_hex.clone(),
1031                        found: true,
1032                    },
1033                )
1034                .await;
1035                send_binary(state, *origin_id, *request_id, data.clone()).await;
1036            }
1037            WsProtocol::HashtreeMsgpack => {
1038                send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1039            }
1040            WsProtocol::Unknown => {}
1041        }
1042    }
1043
1044    let completed: HashSet<(u64, u32)> = responses
1045        .into_iter()
1046        .map(|(origin_id, request_id, _)| (origin_id, request_id))
1047        .collect();
1048    let mut pending = state.ws_relay.pending.lock().await;
1049    pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1050}
1051
1052async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1053    if let Ok(text) = serde_json::to_string(&response) {
1054        send_to_client(state, client_id, Message::Text(text)).await;
1055    }
1056}
1057
1058async fn send_msgpack_request(
1059    state: &AppState,
1060    client_id: u64,
1061    hash: &[u8],
1062) -> Result<(), rmp_serde::encode::Error> {
1063    let req = DataRequest {
1064        h: hash.to_vec(),
1065        htl: MAX_HTL,
1066        q: None,
1067    };
1068    let wire = encode_request(&req)?;
1069    send_to_client(state, client_id, Message::Binary(wire)).await;
1070    Ok(())
1071}
1072
1073async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1074    let res = DataResponse {
1075        h: hash.to_vec(),
1076        d: data.to_vec(),
1077        i: None,
1078        n: None,
1079    };
1080    if let Ok(wire) = encode_response(&res) {
1081        send_to_client(state, client_id, Message::Binary(wire)).await;
1082    }
1083}
1084
1085async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1086    let mut packet = Vec::with_capacity(4 + payload.len());
1087    packet.extend_from_slice(&request_id.to_le_bytes());
1088    packet.extend_from_slice(&payload);
1089    send_to_client(state, client_id, Message::Binary(packet)).await;
1090}
1091
1092async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1093    let sender = {
1094        let clients = state.ws_relay.clients.lock().await;
1095        clients.get(&client_id).cloned()
1096    };
1097    if let Some(tx) = sender {
1098        let _ = tx.send(msg);
1099    }
1100}
1101
1102async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1103    let mut protocols = state.ws_relay.client_protocols.lock().await;
1104    protocols.insert(client_id, protocol);
1105}
1106
1107#[cfg(test)]
1108mod tests {
1109    use super::*;
1110    use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1111    use anyhow::Result;
1112    use futures::{SinkExt, StreamExt};
1113    use nostr::secp256k1::schnorr::Signature;
1114    use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1115    use std::collections::HashSet;
1116    use std::sync::Arc;
1117    use tempfile::TempDir;
1118    use tokio::net::TcpListener;
1119    use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1120
1121    #[test]
1122    fn parse_ws_text_message_detects_nostr_req() {
1123        let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1124        match parse_ws_text_message(msg) {
1125            Some(WsTextMessage::Nostr(_)) => {}
1126            other => panic!("expected Nostr message, got {:?}", other),
1127        }
1128    }
1129
1130    #[test]
1131    fn parse_ws_text_message_detects_hashtree_request() {
1132        let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1133        match parse_ws_text_message(msg) {
1134            Some(WsTextMessage::Hashtree(_)) => {}
1135            other => panic!("expected Hashtree message, got {:?}", other),
1136        }
1137    }
1138
1139    #[test]
1140    fn nostr_replies_for_req_is_eose() {
1141        let sub = SubscriptionId::new("sub-1");
1142        let msg = NostrClientMessage::req(sub.clone(), vec![]);
1143        let replies = nostr_responses_for(&msg);
1144        assert_eq!(replies.len(), 1);
1145        match &replies[0] {
1146            NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1147            other => panic!("expected EOSE, got {:?}", other),
1148        }
1149    }
1150
1151    #[test]
1152    fn nostr_replies_for_event_ok() {
1153        let keys = Keys::generate();
1154        let event = EventBuilder::new(Kind::TextNote, "hello", [])
1155            .to_event(&keys)
1156            .unwrap();
1157        let msg = NostrClientMessage::event(event.clone());
1158        let replies = nostr_responses_for(&msg);
1159        assert_eq!(replies.len(), 1);
1160        match &replies[0] {
1161            NostrRelayMessage::Ok {
1162                event_id, status, ..
1163            } => {
1164                assert_eq!(event_id, &event.id);
1165                assert!(*status);
1166            }
1167            other => panic!("expected OK, got {:?}", other),
1168        }
1169    }
1170
1171    #[test]
1172    fn nostr_replies_for_invalid_event_is_not_ok() {
1173        let keys = Keys::generate();
1174        let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1175            .to_event(&keys)
1176            .unwrap();
1177        event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1178        let msg = NostrClientMessage::event(event);
1179        let replies = nostr_responses_for(&msg);
1180        assert_eq!(replies.len(), 1);
1181        match &replies[0] {
1182            NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1183            other => panic!("expected OK=false, got {:?}", other),
1184        }
1185    }
1186
1187    async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1188        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1189        let addr = listener.local_addr().expect("relay addr");
1190        tokio::spawn(async move {
1191            let (stream, _) = listener.accept().await.expect("accept relay");
1192            let ws = accept_async(stream).await.expect("accept websocket");
1193            let (mut write, mut read) = ws.split();
1194
1195            while let Some(Ok(message)) = read.next().await {
1196                let TungsteniteMessage::Text(text) = message else {
1197                    continue;
1198                };
1199                let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1200                    continue;
1201                };
1202                if let NostrClientMessage::Req {
1203                    subscription_id,
1204                    filters,
1205                } = parsed
1206                {
1207                    for event in events
1208                        .iter()
1209                        .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1210                    {
1211                        let _ = write
1212                            .send(TungsteniteMessage::Text(
1213                                NostrRelayMessage::event(subscription_id.clone(), event.clone())
1214                                    .as_json()
1215                                    .into(),
1216                            ))
1217                            .await;
1218                    }
1219                    let _ = write
1220                        .send(TungsteniteMessage::Text(
1221                            NostrRelayMessage::eose(subscription_id).as_json().into(),
1222                        ))
1223                        .await;
1224                }
1225            }
1226        });
1227        format!("ws://{}", addr)
1228    }
1229
1230    fn test_app_state(
1231        tmp: &TempDir,
1232        relay: Arc<NostrRelay>,
1233        relay_url: String,
1234    ) -> Result<AppState> {
1235        let store = Arc::new(crate::storage::HashtreeStore::with_options(
1236            tmp.path(),
1237            None,
1238            128 * 1024 * 1024,
1239        )?);
1240        Ok(AppState {
1241            store,
1242            auth: None,
1243            peer_mode: crate::config::ServerMode::Normal,
1244            hash_get_enabled: true,
1245            webrtc_peers: None,
1246            ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1247            max_upload_bytes: 5 * 1024 * 1024,
1248            public_writes: true,
1249            allowed_pubkeys: HashSet::new(),
1250            upstream_blossom: Vec::new(),
1251            social_graph: None,
1252            social_graph_store: None,
1253            social_graph_root: None,
1254            socialgraph_snapshot_public: false,
1255            nostr_relay: Some(relay),
1256            nostr_relay_urls: vec![relay_url],
1257            tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1258            inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1259                std::collections::HashMap::new(),
1260            )),
1261            directory_listing_cache: Arc::new(std::sync::Mutex::new(
1262                super::super::auth::new_lookup_cache(),
1263            )),
1264            resolved_path_cache: Arc::new(std::sync::Mutex::new(
1265                super::super::auth::new_lookup_cache(),
1266            )),
1267            thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1268                super::super::auth::new_lookup_cache(),
1269            )),
1270            cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1271        })
1272    }
1273
1274    #[tokio::test]
1275    async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1276        let tmp = TempDir::new()?;
1277        let graph_store = {
1278            let _guard = crate::socialgraph::test_lock();
1279            crate::socialgraph::open_social_graph_store_with_mapsize(
1280                tmp.path(),
1281                Some(128 * 1024 * 1024),
1282            )?
1283        };
1284        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1285        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1286            Arc::clone(&backend),
1287            0,
1288            HashSet::new(),
1289        ));
1290
1291        let keys = Keys::generate();
1292        let relay = Arc::new(NostrRelay::new(
1293            Arc::clone(&backend),
1294            tmp.path().to_path_buf(),
1295            HashSet::from([keys.public_key().to_hex()]),
1296            Some(access),
1297            NostrRelayConfig {
1298                spambox_db_max_bytes: 0,
1299                ..Default::default()
1300            },
1301        )?);
1302
1303        let event = EventBuilder::new(
1304            Kind::from(30078_u16),
1305            "",
1306            [
1307                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1308                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1309            ],
1310        )
1311        .to_event(&keys)?;
1312
1313        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1314        let filter = Filter::new()
1315            .authors(vec![event.pubkey])
1316            .kinds(vec![event.kind]);
1317        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1318        let client_id = 7_u64;
1319        let (tx, mut rx) = mpsc::unbounded_channel();
1320        state.ws_relay.clients.lock().await.insert(client_id, tx);
1321        let subscription_id = SubscriptionId::new("sub-1");
1322
1323        start_upstream_nostr_subscription(
1324            &state,
1325            client_id,
1326            subscription_id.clone(),
1327            vec![filter.clone()],
1328        )
1329        .await;
1330
1331        let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1332            .await?
1333            .expect("forwarded upstream event");
1334        let Message::Text(text) = forwarded else {
1335            panic!("expected text event");
1336        };
1337        match NostrRelayMessage::from_json(text.as_str())? {
1338            NostrRelayMessage::Event {
1339                subscription_id: sid,
1340                event: forwarded_event,
1341            } => {
1342                assert_eq!(sid, subscription_id);
1343                assert_eq!(forwarded_event.id, event.id);
1344            }
1345            other => panic!("expected forwarded EVENT, got {:?}", other),
1346        }
1347
1348        let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1349            .await?
1350            .expect("forwarded upstream eose");
1351        let Message::Text(eose_text) = eose else {
1352            panic!("expected text eose");
1353        };
1354        match NostrRelayMessage::from_json(eose_text.as_str())? {
1355            NostrRelayMessage::EndOfStoredEvents(sid) => {
1356                assert_eq!(sid, subscription_id);
1357            }
1358            other => panic!("expected forwarded EOSE, got {:?}", other),
1359        }
1360
1361        let events = relay.query_events(&filter, 10).await;
1362        assert_eq!(events.len(), 1);
1363        assert_eq!(events[0].id, event.id);
1364
1365        close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1366        assert!(state
1367            .ws_relay
1368            .upstream_nostr_subscriptions
1369            .lock()
1370            .await
1371            .is_empty());
1372        Ok(())
1373    }
1374
1375    #[tokio::test]
1376    async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1377        let tmp = TempDir::new()?;
1378        let graph_store = {
1379            let _guard = crate::socialgraph::test_lock();
1380            crate::socialgraph::open_social_graph_store_with_mapsize(
1381                tmp.path(),
1382                Some(128 * 1024 * 1024),
1383            )?
1384        };
1385        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1386        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1387            Arc::clone(&backend),
1388            0,
1389            HashSet::new(),
1390        ));
1391
1392        let keys = Keys::generate();
1393        let relay = Arc::new(NostrRelay::new(
1394            Arc::clone(&backend),
1395            tmp.path().to_path_buf(),
1396            HashSet::from([keys.public_key().to_hex()]),
1397            Some(access),
1398            NostrRelayConfig {
1399                spambox_db_max_bytes: 0,
1400                ..Default::default()
1401            },
1402        )?);
1403
1404        let event = EventBuilder::new(
1405            Kind::from(30078_u16),
1406            "",
1407            [
1408                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1409                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1410            ],
1411        )
1412        .to_event(&keys)?;
1413
1414        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1415        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1416        let client_id = 11_u64;
1417        let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1418        let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1419        state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1420        relay.register_client(client_id, relay_tx, None).await;
1421
1422        let request = NostrClientMessage::req(
1423            SubscriptionId::new("feed"),
1424            vec![Filter::new()
1425                .authors(vec![event.pubkey])
1426                .kinds(vec![event.kind])],
1427        )
1428        .as_json();
1429
1430        handle_message(client_id, Message::Text(request.into()), &state).await;
1431
1432        let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1433            .await?
1434            .expect("first forwarded message");
1435        let Message::Text(first_text) = first else {
1436            panic!("expected text event");
1437        };
1438        match NostrRelayMessage::from_json(first_text.as_str())? {
1439            NostrRelayMessage::Event {
1440                event: forwarded_event,
1441                ..
1442            } => {
1443                assert_eq!(forwarded_event.id, event.id);
1444            }
1445            other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1446        }
1447
1448        let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1449            .await?
1450            .expect("second forwarded message");
1451        let Message::Text(second_text) = second else {
1452            panic!("expected text eose");
1453        };
1454        match NostrRelayMessage::from_json(second_text.as_str())? {
1455            NostrRelayMessage::EndOfStoredEvents(sid) => {
1456                assert_eq!(sid, SubscriptionId::new("feed"));
1457            }
1458            other => panic!("expected aggregated EOSE, got {:?}", other),
1459        }
1460
1461        Ok(())
1462    }
1463
1464    #[tokio::test]
1465    async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1466        let tmp = TempDir::new()?;
1467        let graph_store = {
1468            let _guard = crate::socialgraph::test_lock();
1469            crate::socialgraph::open_social_graph_store_with_mapsize(
1470                tmp.path(),
1471                Some(128 * 1024 * 1024),
1472            )?
1473        };
1474        let author_keys = Keys::generate();
1475        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1476        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1477            Arc::clone(&backend),
1478            0,
1479            HashSet::from([author_keys.public_key().to_hex()]),
1480        ));
1481        let relay = Arc::new(NostrRelay::new(
1482            Arc::clone(&backend),
1483            tmp.path().to_path_buf(),
1484            HashSet::from([author_keys.public_key().to_hex()]),
1485            Some(access),
1486            NostrRelayConfig {
1487                spambox_db_max_bytes: 0,
1488                ..Default::default()
1489            },
1490        )?);
1491
1492        let state = test_app_state(&tmp, relay.clone(), String::new())?;
1493        let listener = TcpListener::bind("127.0.0.1:0").await?;
1494        let addr = listener.local_addr()?;
1495        let client_pubkey = author_keys.public_key().to_hex();
1496        let app = axum::Router::new().route(
1497            "/ws",
1498            axum::routing::get({
1499                let state = state.clone();
1500                let client_pubkey = client_pubkey.clone();
1501                move |ws: WebSocketUpgrade| {
1502                    let state = state.clone();
1503                    let client_pubkey = client_pubkey.clone();
1504                    async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1505                }
1506            }),
1507        );
1508        tokio::spawn(async move {
1509            let _ = axum::serve(listener, app).await;
1510        });
1511
1512        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1513        let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1514            .to_event(&author_keys)?;
1515        socket
1516            .send(TungsteniteMessage::Text(
1517                NostrClientMessage::event(event.clone()).as_json().into(),
1518            ))
1519            .await?;
1520
1521        let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1522            .await?
1523            .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1524        let TungsteniteMessage::Text(text) = reply else {
1525            anyhow::bail!("expected text publish ack");
1526        };
1527
1528        match NostrRelayMessage::from_json(text.as_str())? {
1529            NostrRelayMessage::Ok {
1530                event_id, status, ..
1531            } => {
1532                assert_eq!(event_id, event.id);
1533                assert!(status);
1534            }
1535            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1536        }
1537
1538        let stored = relay
1539            .query_events(
1540                &Filter::new()
1541                    .authors(vec![event.pubkey])
1542                    .kinds(vec![event.kind]),
1543                10,
1544            )
1545            .await;
1546        assert!(stored.iter().any(|candidate| candidate.id == event.id));
1547        Ok(())
1548    }
1549
1550    #[tokio::test]
1551    async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1552        let tmp = TempDir::new()?;
1553        let graph_store = {
1554            let _guard = crate::socialgraph::test_lock();
1555            crate::socialgraph::open_social_graph_store_with_mapsize(
1556                tmp.path(),
1557                Some(128 * 1024 * 1024),
1558            )?
1559        };
1560        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1561        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1562            Arc::clone(&backend),
1563            0,
1564            HashSet::new(),
1565        ));
1566        let relay = Arc::new(NostrRelay::new(
1567            Arc::clone(&backend),
1568            tmp.path().to_path_buf(),
1569            HashSet::new(),
1570            Some(access),
1571            NostrRelayConfig {
1572                spambox_db_max_bytes: 0,
1573                spambox_max_reqs_per_min: 1,
1574                ..Default::default()
1575            },
1576        )?);
1577
1578        let state = test_app_state(&tmp, relay, String::new())?;
1579        let listener = TcpListener::bind("127.0.0.1:0").await?;
1580        let addr = listener.local_addr()?;
1581        let app = axum::Router::new().route(
1582            "/ws",
1583            axum::routing::get({
1584                let state = state.clone();
1585                move |ws: WebSocketUpgrade| {
1586                    let state = state.clone();
1587                    async move { ws_data_with_client_pubkey(state, ws, None) }
1588                }
1589            }),
1590        );
1591        tokio::spawn(async move {
1592            let _ = axum::serve(listener, app).await;
1593        });
1594
1595        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1596        socket
1597            .send(TungsteniteMessage::Text(
1598                NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1599                    .as_json()
1600                    .into(),
1601            ))
1602            .await?;
1603
1604        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1605            .await?
1606            .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1607        let TungsteniteMessage::Text(first_text) = first else {
1608            anyhow::bail!("expected text EOSE reply");
1609        };
1610        match NostrRelayMessage::from_json(first_text.as_str())? {
1611            NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1612                assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
1613            }
1614            other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1615        }
1616
1617        socket
1618            .send(TungsteniteMessage::Text(
1619                NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1620                    .as_json()
1621                    .into(),
1622            ))
1623            .await?;
1624
1625        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1626            .await?
1627            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1628        let TungsteniteMessage::Text(second_text) = second else {
1629            anyhow::bail!("expected text CLOSED reply");
1630        };
1631        let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1632        assert_eq!(
1633            second_value,
1634            serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1635        );
1636
1637        Ok(())
1638    }
1639
1640    #[tokio::test]
1641    async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1642        let tmp = TempDir::new()?;
1643        let graph_store = {
1644            let _guard = crate::socialgraph::test_lock();
1645            crate::socialgraph::open_social_graph_store_with_mapsize(
1646                tmp.path(),
1647                Some(128 * 1024 * 1024),
1648            )?
1649        };
1650        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1651        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1652            Arc::clone(&backend),
1653            0,
1654            HashSet::new(),
1655        ));
1656        let relay = Arc::new(NostrRelay::new(
1657            Arc::clone(&backend),
1658            tmp.path().to_path_buf(),
1659            HashSet::new(),
1660            Some(access),
1661            NostrRelayConfig {
1662                spambox_db_max_bytes: 0,
1663                spambox_max_events_per_min: 1,
1664                ..Default::default()
1665            },
1666        )?);
1667
1668        let state = test_app_state(&tmp, relay, String::new())?;
1669        let listener = TcpListener::bind("127.0.0.1:0").await?;
1670        let addr = listener.local_addr()?;
1671        let app = axum::Router::new().route(
1672            "/ws",
1673            axum::routing::get({
1674                let state = state.clone();
1675                move |ws: WebSocketUpgrade| {
1676                    let state = state.clone();
1677                    async move { ws_data_with_client_pubkey(state, ws, None) }
1678                }
1679            }),
1680        );
1681        tokio::spawn(async move {
1682            let _ = axum::serve(listener, app).await;
1683        });
1684
1685        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1686        let author_keys = Keys::generate();
1687        let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
1688        let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;
1689
1690        socket
1691            .send(TungsteniteMessage::Text(
1692                NostrClientMessage::event(event_a.clone()).as_json().into(),
1693            ))
1694            .await?;
1695
1696        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1697            .await?
1698            .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1699        let TungsteniteMessage::Text(first_text) = first else {
1700            anyhow::bail!("expected text publish ack");
1701        };
1702        match NostrRelayMessage::from_json(first_text.as_str())? {
1703            NostrRelayMessage::Ok {
1704                event_id,
1705                status,
1706                message,
1707            } => {
1708                assert_eq!(event_id, event_a.id);
1709                assert!(status);
1710                assert_eq!(message, "spambox");
1711            }
1712            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1713        }
1714
1715        socket
1716            .send(TungsteniteMessage::Text(
1717                NostrClientMessage::event(event_b.clone()).as_json().into(),
1718            ))
1719            .await?;
1720
1721        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1722            .await?
1723            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1724        let TungsteniteMessage::Text(second_text) = second else {
1725            anyhow::bail!("expected text publish ack");
1726        };
1727        match NostrRelayMessage::from_json(second_text.as_str())? {
1728            NostrRelayMessage::Ok {
1729                event_id,
1730                status,
1731                message,
1732            } => {
1733                assert_eq!(event_id, event_b.id);
1734                assert!(!status);
1735                assert_eq!(message, "rate limited");
1736            }
1737            other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1738        }
1739
1740        Ok(())
1741    }
1742}