Skip to main content

hashtree_cli/server/
ws_relay.rs

1use axum::{
2    extract::{
3        ws::{Message, WebSocket, WebSocketUpgrade},
4        State,
5    },
6    response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11    ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12    RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21    nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24    encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
28#[derive(Debug, Deserialize)]
29#[serde(tag = "type")]
30enum WsClientMessage {
31    #[serde(rename = "req")]
32    Request { id: u32, hash: String },
33    #[serde(rename = "res")]
34    Response { id: u32, hash: String, found: bool },
35}
36
37#[derive(Debug)]
38enum WsTextMessage {
39    Hashtree(WsClientMessage),
40    Nostr(NostrClientMessage),
41}
42
43#[derive(Debug, Deserialize, Serialize)]
44struct WsRequest {
45    #[serde(rename = "type")]
46    kind: String,
47    id: u32,
48    hash: String,
49}
50
51#[derive(Debug, Serialize)]
52struct WsResponse {
53    #[serde(rename = "type")]
54    kind: &'static str,
55    id: u32,
56    hash: String,
57    found: bool,
58}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61    ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65    state: AppState,
66    ws: WebSocketUpgrade,
67    client_pubkey: Option<String>,
68) -> impl IntoResponse {
69    ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
72async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
73    // Use the Nostr relay's client-id generator when available so `/ws` IDs
74    // can't collide with WebRTC relay clients.
75    let client_id = state
76        .nostr_relay
77        .as_ref()
78        .map(|relay| relay.next_client_id())
79        .unwrap_or_else(|| state.ws_relay.next_id());
80    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
81
82    {
83        let mut clients = state.ws_relay.clients.lock().await;
84        clients.insert(client_id, tx);
85    }
86    {
87        let mut protocols = state.ws_relay.client_protocols.lock().await;
88        protocols.insert(client_id, WsProtocol::HashtreeJson);
89    }
90
91    let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
92        let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
93        relay
94            .register_client(client_id, nostr_tx, client_pubkey.clone())
95            .await;
96        Some(nostr_rx)
97    } else {
98        None
99    };
100
101    let (mut sender, mut receiver) = socket.split();
102    loop {
103        tokio::select! {
104            maybe_msg = rx.recv() => {
105                let Some(msg) = maybe_msg else {
106                    break;
107                };
108                if sender.send(msg).await.is_err() {
109                    break;
110                }
111            }
112            maybe_text = async {
113                match &mut nostr_rx {
114                    Some(rx) => rx.recv().await,
115                    None => std::future::pending().await,
116                }
117            } => {
118                let Some(text) = maybe_text else {
119                    nostr_rx = None;
120                    continue;
121                };
122                if sender.send(Message::Text(text)).await.is_err() {
123                    break;
124                }
125            }
126            maybe_incoming = receiver.next() => {
127                match maybe_incoming {
128                    Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
129                    Some(Err(_)) | None => break,
130                }
131            }
132        }
133    }
134
135    close_all_upstream_nostr_subscriptions(&state, client_id).await;
136
137    {
138        let mut clients = state.ws_relay.clients.lock().await;
139        clients.remove(&client_id);
140    }
141    {
142        let mut protocols = state.ws_relay.client_protocols.lock().await;
143        protocols.remove(&client_id);
144    }
145    {
146        let mut pending = state.ws_relay.pending.lock().await;
147        pending.retain(|(peer_id, _), _| *peer_id != client_id);
148    }
149
150    if let Some(relay) = &state.nostr_relay {
151        relay.unregister_client(client_id).await;
152    }
153}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
156    let trimmed = text.trim_start();
157    if trimmed.starts_with('[') {
158        if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159            return Some(WsTextMessage::Nostr(msg));
160        }
161    }
162
163    if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164        return Some(WsTextMessage::Hashtree(msg));
165    }
166
167    None
168}
169async fn close_upstream_nostr_subscription(
170    state: &AppState,
171    client_id: u64,
172    subscription_id: &SubscriptionId,
173) {
174    let key = (client_id, subscription_id.to_string());
175    let subscription = {
176        let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177        subscriptions.remove(&key)
178    };
179    if let Some(subscription) = subscription {
180        let _ = subscription.close_tx.send(true);
181        for task in subscription.tasks {
182            task.abort();
183        }
184    }
185    state
186        .ws_relay
187        .upstream_pending_eose
188        .lock()
189        .await
190        .remove(&key);
191    state
192        .ws_relay
193        .upstream_seen_events
194        .lock()
195        .await
196        .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200    let keys = {
201        let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202        subscriptions
203            .keys()
204            .filter(|(id, _)| *id == client_id)
205            .cloned()
206            .collect::<Vec<_>>()
207    };
208    for (_, sub_id) in keys {
209        close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210    }
211}
212
213async fn forward_upstream_nostr_message(
214    state: &AppState,
215    client_id: u64,
216    subscription_id: &SubscriptionId,
217    text: &str,
218) {
219    let Ok(message) = NostrRelayMessage::from_json(text) else {
220        return;
221    };
222
223    match message {
224        NostrRelayMessage::Event {
225            subscription_id: sid,
226            event,
227        } if sid == *subscription_id => {
228            let event = *event;
229            let key = (client_id, subscription_id.to_string());
230            let event_id = event.id.to_hex();
231            let inserted = {
232                let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233                seen_events.entry(key).or_default().insert(event_id)
234            };
235            if !inserted {
236                return;
237            }
238            if let Some(relay) = &state.nostr_relay {
239                let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240            }
241            send_nostr(
242                state,
243                client_id,
244                NostrRelayMessage::event(subscription_id.clone(), event),
245            )
246            .await;
247        }
248        NostrRelayMessage::Closed {
249            subscription_id: sid,
250            message,
251        } if sid == *subscription_id => {
252            send_nostr(
253                state,
254                client_id,
255                NostrRelayMessage::closed(subscription_id.clone(), message),
256            )
257            .await;
258        }
259        _ => {}
260    }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264    state: &AppState,
265    client_id: u64,
266    subscription_id: &SubscriptionId,
267) {
268    let key = (client_id, subscription_id.to_string());
269    let should_send_eose = {
270        let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271        let Some(remaining) = pending.get_mut(&key) else {
272            return;
273        };
274        if *remaining > 0 {
275            *remaining -= 1;
276        }
277        if *remaining == 0 {
278            pending.remove(&key);
279            true
280        } else {
281            false
282        }
283    };
284
285    if should_send_eose {
286        send_nostr(
287            state,
288            client_id,
289            NostrRelayMessage::eose(subscription_id.clone()),
290        )
291        .await;
292    }
293}
294
295async fn run_upstream_nostr_subscription(
296    state: AppState,
297    client_id: u64,
298    relay_url: String,
299    subscription_id: SubscriptionId,
300    filters: Vec<NostrFilter>,
301    mut close_rx: watch::Receiver<bool>,
302) {
303    let mut relay_complete = false;
304    let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305        tracing::warn!(
306            "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307            client_id,
308            subscription_id,
309            relay_url,
310        );
311        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312        return;
313    };
314    let (mut write, mut read) = socket.split();
315    let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316    state
317        .ws_relay
318        .note_upstream_relay_send(request.as_bytes().len());
319    if write
320        .send(TungsteniteMessage::Text(request.into()))
321        .await
322        .is_err()
323    {
324        tracing::warn!(
325            "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326            client_id,
327            subscription_id,
328            relay_url,
329        );
330        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331        return;
332    }
333
334    loop {
335        tokio::select! {
336            _ = close_rx.changed() => {
337                if *close_rx.borrow() {
338                    let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339                    state
340                        .ws_relay
341                        .note_upstream_relay_send(close.as_bytes().len());
342                    let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343                    let _ = write.close().await;
344                    break;
345                }
346            }
347            message = read.next() => {
348                match message {
349                    Some(Ok(TungsteniteMessage::Text(text))) => {
350                        state
351                            .ws_relay
352                            .note_upstream_relay_receive(text.as_bytes().len());
353                        if matches!(
354                            NostrRelayMessage::from_json(text.as_str()),
355                            Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
356                        ) {
357                            if !relay_complete {
358                                relay_complete = true;
359                                mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360                            }
361                            continue;
362                        }
363                        forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364                    }
365                    Some(Ok(TungsteniteMessage::Ping(payload))) => {
366                        let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367                    }
368                    Some(Ok(TungsteniteMessage::Close(_))) | None => {
369                        if !relay_complete {
370                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371                        }
372                        break;
373                    }
374                    Some(Err(_)) => {
375                        if !relay_complete {
376                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377                        }
378                        break;
379                    }
380                    _ => {}
381                }
382            }
383        }
384    }
385}
386
387async fn start_upstream_nostr_subscription(
388    state: &AppState,
389    client_id: u64,
390    subscription_id: SubscriptionId,
391    filters: Vec<NostrFilter>,
392) -> usize {
393    let memory_before = process_memory_snapshot();
394    close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
395    if state.nostr_relay_urls.is_empty() || filters.is_empty() {
396        tracing::info!(
397            "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
398            client_id,
399            subscription_id,
400            state.nostr_relay_urls.len(),
401            filters.len(),
402        );
403        return 0;
404    }
405
406    let mut relay_urls = Vec::new();
407    let mut seen = HashSet::new();
408    for relay in &state.nostr_relay_urls {
409        let relay = relay.trim();
410        if relay.is_empty() || !seen.insert(relay.to_string()) {
411            continue;
412        }
413        relay_urls.push(relay.to_string());
414    }
415    if relay_urls.is_empty() {
416        tracing::info!(
417            "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
418            client_id,
419            subscription_id,
420        );
421        return 0;
422    }
423
424    tracing::info!(
425        "upstream nostr relay start: client_id={} subscription_id={} relays={}",
426        client_id,
427        subscription_id,
428        relay_urls.len(),
429    );
430
431    let filter_summary = nostr_filters_summary(&filters);
432    let key = (client_id, subscription_id.to_string());
433    state
434        .ws_relay
435        .upstream_seen_events
436        .lock()
437        .await
438        .insert(key.clone(), HashSet::new());
439    state
440        .ws_relay
441        .upstream_pending_eose
442        .lock()
443        .await
444        .insert(key.clone(), relay_urls.len());
445
446    let (close_tx, close_rx) = watch::channel(false);
447    let mut tasks = Vec::new();
448    for relay_url in &relay_urls {
449        tasks.push(tokio::spawn(run_upstream_nostr_subscription(
450            state.clone(),
451            client_id,
452            relay_url.clone(),
453            subscription_id.clone(),
454            filters.clone(),
455            close_rx.clone(),
456        )));
457    }
458
459    tracing::info!(
460        target: "hashtree_cli::server::ws_relay::upstream",
461        client_id,
462        subscription_id = %subscription_id,
463        relays = relay_urls.len(),
464        tasks = tasks.len(),
465        filters = filters.len(),
466        filter = %filter_summary,
467        memory_before = ?memory_before,
468        memory_after = ?process_memory_snapshot(),
469        "upstream nostr relay tasks spawned",
470    );
471
472    state
473        .ws_relay
474        .upstream_nostr_subscriptions
475        .lock()
476        .await
477        .insert(key, UpstreamNostrSubscription { close_tx, tasks });
478    relay_urls.len()
479}
480
481async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
482    match msg {
483        Message::Text(text) => {
484            if let Some(msg) = parse_ws_text_message(&text) {
485                match msg {
486                    WsTextMessage::Hashtree(msg) => {
487                        set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
488                        match msg {
489                            WsClientMessage::Request { id, hash } => {
490                                handle_request(
491                                    client_id,
492                                    id,
493                                    hash,
494                                    WsProtocol::HashtreeJson,
495                                    state,
496                                )
497                                .await;
498                            }
499                            WsClientMessage::Response { id, hash, found } => {
500                                handle_response(client_id, id, hash, found, state).await;
501                            }
502                        }
503                    }
504                    WsTextMessage::Nostr(msg) => {
505                        if let Some(relay) = &state.nostr_relay {
506                            match msg {
507                                NostrClientMessage::Req {
508                                    subscription_id,
509                                    filters,
510                                } => {
511                                    let local_events = match relay
512                                        .register_subscription_query(
513                                            client_id,
514                                            subscription_id.clone(),
515                                            filters.clone(),
516                                        )
517                                        .await
518                                    {
519                                        Ok(events) => events,
520                                        Err(message) => {
521                                            send_nostr(
522                                                state,
523                                                client_id,
524                                                NostrRelayMessage::closed(subscription_id, message),
525                                            )
526                                            .await;
527                                            return;
528                                        }
529                                    };
530
531                                    let upstream_relays = start_upstream_nostr_subscription(
532                                        state,
533                                        client_id,
534                                        subscription_id.clone(),
535                                        filters,
536                                    )
537                                    .await;
538                                    if upstream_relays > 0 {
539                                        let key = (client_id, subscription_id.to_string());
540                                        let mut seen_events =
541                                            state.ws_relay.upstream_seen_events.lock().await;
542                                        seen_events.entry(key).or_default().extend(
543                                            local_events.iter().map(|event| event.id.to_hex()),
544                                        );
545                                    }
546                                    for event in local_events {
547                                        send_nostr(
548                                            state,
549                                            client_id,
550                                            NostrRelayMessage::event(
551                                                subscription_id.clone(),
552                                                event,
553                                            ),
554                                        )
555                                        .await;
556                                    }
557                                    trim_process_allocations();
558                                    if upstream_relays == 0 {
559                                        send_nostr(
560                                            state,
561                                            client_id,
562                                            NostrRelayMessage::eose(subscription_id),
563                                        )
564                                        .await;
565                                    }
566                                }
567                                NostrClientMessage::Close(subscription_id) => {
568                                    close_upstream_nostr_subscription(
569                                        state,
570                                        client_id,
571                                        &subscription_id,
572                                    )
573                                    .await;
574                                    relay
575                                        .handle_client_message(
576                                            client_id,
577                                            NostrClientMessage::Close(subscription_id.clone()),
578                                        )
579                                        .await;
580                                }
581                                other => {
582                                    relay.handle_client_message(client_id, other).await;
583                                }
584                            }
585                        } else {
586                            handle_nostr_message(client_id, msg, state).await;
587                        }
588                    }
589                }
590            }
591        }
592        Message::Binary(data) => {
593            handle_binary(client_id, data, state).await;
594        }
595        Message::Close(_) => {}
596        _ => {}
597    }
598}
599
600async fn handle_request(
601    client_id: u64,
602    request_id: u32,
603    hash: String,
604    origin_protocol: WsProtocol,
605    state: &AppState,
606) {
607    let hash_hex = hash.to_lowercase();
608    let hash_bytes = match from_hex(&hash_hex) {
609        Ok(bytes) => bytes,
610        Err(_) => {
611            if origin_protocol == WsProtocol::HashtreeJson {
612                send_json(
613                    state,
614                    client_id,
615                    WsResponse {
616                        kind: "res",
617                        id: request_id,
618                        hash,
619                        found: false,
620                    },
621                )
622                .await;
623            }
624            return;
625        }
626    };
627
628    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
629        match origin_protocol {
630            WsProtocol::HashtreeJson => {
631                send_json(
632                    state,
633                    client_id,
634                    WsResponse {
635                        kind: "res",
636                        id: request_id,
637                        hash: hash.clone(),
638                        found: true,
639                    },
640                )
641                .await;
642                send_binary(state, client_id, request_id, data).await;
643            }
644            WsProtocol::HashtreeMsgpack => {
645                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
646            }
647            WsProtocol::Unknown => {}
648        }
649        return;
650    }
651
652    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
653        let clients = state.ws_relay.clients.lock().await;
654        let protocols = state.ws_relay.client_protocols.lock().await;
655        clients
656            .iter()
657            .filter(|(id, _)| **id != client_id)
658            .filter_map(|(id, tx)| {
659                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
660                match protocol {
661                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
662                        Some((*id, tx.clone(), protocol))
663                    }
664                    WsProtocol::Unknown => None,
665                }
666            })
667            .collect()
668    };
669
670    if peers.is_empty() {
671        if origin_protocol == WsProtocol::HashtreeJson {
672            send_json(
673                state,
674                client_id,
675                WsResponse {
676                    kind: "res",
677                    id: request_id,
678                    hash,
679                    found: false,
680                },
681            )
682            .await;
683        }
684        return;
685    }
686
687    {
688        let mut pending = state.ws_relay.pending.lock().await;
689        for (peer_id, _, _) in &peers {
690            pending.insert(
691                (*peer_id, request_id),
692                PendingRequest {
693                    origin_id: client_id,
694                    hash: hash.clone(),
695                    found: false,
696                    origin_protocol,
697                },
698            );
699        }
700    }
701
702    let request_text = serde_json::to_string(&WsRequest {
703        kind: "req".to_string(),
704        id: request_id,
705        hash: hash.clone(),
706    })
707    .unwrap_or_else(|_| String::new());
708    for (peer_id, tx, protocol) in peers {
709        match protocol {
710            WsProtocol::HashtreeMsgpack => {
711                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
712            }
713            WsProtocol::HashtreeJson => {
714                let _ = tx.send(Message::Text(request_text.clone()));
715            }
716            WsProtocol::Unknown => {}
717        }
718    }
719
720    let timeout_state = state.clone();
721    let timeout_hash = hash.clone();
722    tokio::spawn(async move {
723        tokio::time::sleep(Duration::from_millis(1500)).await;
724        let mut pending = timeout_state.ws_relay.pending.lock().await;
725        let still_pending = pending
726            .iter()
727            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
728        let already_found = pending
729            .iter()
730            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
731        if !still_pending || already_found {
732            return;
733        }
734        let origin_protocol = pending
735            .iter()
736            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
737            .map(|(_, p)| p.origin_protocol)
738            .unwrap_or(WsProtocol::HashtreeJson);
739        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
740        drop(pending);
741        if origin_protocol == WsProtocol::HashtreeJson {
742            send_json(
743                &timeout_state,
744                client_id,
745                WsResponse {
746                    kind: "res",
747                    id: request_id,
748                    hash: timeout_hash,
749                    found: false,
750                },
751            )
752            .await;
753        }
754    });
755}
756
757async fn handle_response(
758    client_id: u64,
759    request_id: u32,
760    _hash: String,
761    found: bool,
762    state: &AppState,
763) {
764    let pending_entry = {
765        let pending = state.ws_relay.pending.lock().await;
766        pending
767            .get(&(client_id, request_id))
768            .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
769    };
770
771    let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
772        return;
773    };
774
775    if already_found && !found {
776        let mut pending = state.ws_relay.pending.lock().await;
777        pending.remove(&(client_id, request_id));
778        return;
779    }
780
781    if found {
782        let mut pending = state.ws_relay.pending.lock().await;
783        for ((_, id), p) in pending.iter_mut() {
784            if *id == request_id && p.origin_id == origin_id {
785                p.found = true;
786            }
787        }
788        drop(pending);
789        if origin_protocol == WsProtocol::HashtreeJson {
790            send_json(
791                state,
792                origin_id,
793                WsResponse {
794                    kind: "res",
795                    id: request_id,
796                    hash: pending_hash,
797                    found: true,
798                },
799            )
800            .await;
801        }
802        return;
803    }
804
805    let mut pending = state.ws_relay.pending.lock().await;
806    pending.remove(&(client_id, request_id));
807    let has_remaining = pending
808        .iter()
809        .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
810    drop(pending);
811
812    if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
813        send_json(
814            state,
815            origin_id,
816            WsResponse {
817                kind: "res",
818                id: request_id,
819                hash: pending_hash,
820                found: false,
821            },
822        )
823        .await;
824    }
825}
826
827async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
828    if let Some(msg) = parse_msgpack_message(&data) {
829        set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
830        match msg {
831            DataMessage::Request(req) => {
832                let hash_hex = hex_encode(&req.h);
833                let request_id = state.ws_relay.next_request_id();
834                handle_request(
835                    client_id,
836                    request_id,
837                    hash_hex,
838                    WsProtocol::HashtreeMsgpack,
839                    state,
840                )
841                .await;
842            }
843            DataMessage::Response(res) => {
844                handle_msgpack_response(client_id, res, state).await;
845            }
846            DataMessage::QuoteRequest(_)
847            | DataMessage::QuoteResponse(_)
848            | DataMessage::Payment(_)
849            | DataMessage::PaymentAck(_)
850            | DataMessage::Chunk(_)
851            | DataMessage::PeerHints(_) => {}
852        }
853        return;
854    }
855
856    // Legacy binary: [4-byte LE request_id][data]
857    if data.len() < 4 {
858        return;
859    }
860    let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
861    let pending_entry = {
862        let pending = state.ws_relay.pending.lock().await;
863        pending
864            .get(&(client_id, request_id))
865            .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
866    };
867    let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
868        return;
869    };
870
871    match origin_protocol {
872        WsProtocol::HashtreeJson => {
873            send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
874        }
875        WsProtocol::HashtreeMsgpack => {
876            let Ok(hash_bytes) = from_hex(&hash_hex) else {
877                return;
878            };
879            send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
880        }
881        WsProtocol::Unknown => {}
882    }
883
884    let mut pending = state.ws_relay.pending.lock().await;
885    pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
886}
887
888async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
889    let replies = nostr_responses_for(&msg);
890    for reply in replies {
891        send_nostr(state, client_id, reply).await;
892    }
893}
894
895fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
896    match msg {
897        NostrClientMessage::Event(event) => {
898            let ok = event.verify().is_ok();
899            let message = if ok { "" } else { "invalid: signature" };
900            vec![NostrRelayMessage::ok(event.id, ok, message)]
901        }
902        NostrClientMessage::Req {
903            subscription_id, ..
904        } => {
905            vec![NostrRelayMessage::eose(subscription_id.clone())]
906        }
907        NostrClientMessage::Count {
908            subscription_id, ..
909        } => {
910            vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
911        }
912        NostrClientMessage::Close(_) => Vec::new(),
913        NostrClientMessage::Auth(event) => {
914            let ok = event.verify().is_ok();
915            let message = if ok { "" } else { "invalid auth" };
916            vec![NostrRelayMessage::ok(event.id, ok, message)]
917        }
918        NostrClientMessage::NegOpen { .. }
919        | NostrClientMessage::NegMsg { .. }
920        | NostrClientMessage::NegClose { .. } => {
921            vec![NostrRelayMessage::notice("negentropy not supported")]
922        }
923    }
924}
925
926async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
927    let text = response.as_json();
928    send_to_client(state, client_id, Message::Text(text)).await;
929}
930
931fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
932    let msg = parse_message(data).ok()?;
933    match msg {
934        DataMessage::Request(req) => {
935            if req.h.len() == 32 {
936                Some(DataMessage::Request(req))
937            } else {
938                None
939            }
940        }
941        DataMessage::Response(res) => {
942            if res.h.len() == 32 {
943                Some(DataMessage::Response(res))
944            } else {
945                None
946            }
947        }
948        DataMessage::QuoteRequest(req) => {
949            if req.h.len() == 32 {
950                Some(DataMessage::QuoteRequest(req))
951            } else {
952                None
953            }
954        }
955        DataMessage::QuoteResponse(res) => {
956            if res.h.len() == 32 {
957                Some(DataMessage::QuoteResponse(res))
958            } else {
959                None
960            }
961        }
962        DataMessage::Payment(req) => {
963            if req.h.len() == 32 {
964                Some(DataMessage::Payment(req))
965            } else {
966                None
967            }
968        }
969        DataMessage::PaymentAck(res) => {
970            if res.h.len() == 32 {
971                Some(DataMessage::PaymentAck(res))
972            } else {
973                None
974            }
975        }
976        DataMessage::Chunk(chunk) => {
977            if chunk.h.len() == 32 {
978                Some(DataMessage::Chunk(chunk))
979            } else {
980                None
981            }
982        }
983        DataMessage::PeerHints(_) => None,
984    }
985}
986
987async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
988    let hash_hex = hex_encode(&res.h);
989    let data = res.d.clone();
990    let hash_bytes = res.h.clone();
991
992    let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
993    let mut seen = HashSet::new();
994    {
995        let pending = state.ws_relay.pending.lock().await;
996        for ((peer_id, request_id), p) in pending.iter() {
997            if *peer_id != client_id {
998                continue;
999            }
1000            if p.hash != hash_hex {
1001                continue;
1002            }
1003            if seen.insert((p.origin_id, *request_id)) {
1004                responses.push((p.origin_id, *request_id, p.origin_protocol));
1005            }
1006        }
1007    }
1008
1009    if responses.is_empty() {
1010        return;
1011    }
1012
1013    for (origin_id, request_id, protocol) in &responses {
1014        match protocol {
1015            WsProtocol::HashtreeJson => {
1016                send_json(
1017                    state,
1018                    *origin_id,
1019                    WsResponse {
1020                        kind: "res",
1021                        id: *request_id,
1022                        hash: hash_hex.clone(),
1023                        found: true,
1024                    },
1025                )
1026                .await;
1027                send_binary(state, *origin_id, *request_id, data.clone()).await;
1028            }
1029            WsProtocol::HashtreeMsgpack => {
1030                send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1031            }
1032            WsProtocol::Unknown => {}
1033        }
1034    }
1035
1036    let completed: HashSet<(u64, u32)> = responses
1037        .into_iter()
1038        .map(|(origin_id, request_id, _)| (origin_id, request_id))
1039        .collect();
1040    let mut pending = state.ws_relay.pending.lock().await;
1041    pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1042}
1043
1044async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1045    if let Ok(text) = serde_json::to_string(&response) {
1046        send_to_client(state, client_id, Message::Text(text)).await;
1047    }
1048}
1049
1050async fn send_msgpack_request(
1051    state: &AppState,
1052    client_id: u64,
1053    hash: &[u8],
1054) -> Result<(), rmp_serde::encode::Error> {
1055    let req = DataRequest {
1056        h: hash.to_vec(),
1057        htl: MAX_HTL,
1058        q: None,
1059    };
1060    let wire = encode_request(&req)?;
1061    send_to_client(state, client_id, Message::Binary(wire)).await;
1062    Ok(())
1063}
1064
1065async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1066    let res = DataResponse {
1067        h: hash.to_vec(),
1068        d: data.to_vec(),
1069        i: None,
1070        n: None,
1071    };
1072    if let Ok(wire) = encode_response(&res) {
1073        send_to_client(state, client_id, Message::Binary(wire)).await;
1074    }
1075}
1076
1077async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1078    let mut packet = Vec::with_capacity(4 + payload.len());
1079    packet.extend_from_slice(&request_id.to_le_bytes());
1080    packet.extend_from_slice(&payload);
1081    send_to_client(state, client_id, Message::Binary(packet)).await;
1082}
1083
1084async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1085    let sender = {
1086        let clients = state.ws_relay.clients.lock().await;
1087        clients.get(&client_id).cloned()
1088    };
1089    if let Some(tx) = sender {
1090        let _ = tx.send(msg);
1091    }
1092}
1093
1094async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1095    let mut protocols = state.ws_relay.client_protocols.lock().await;
1096    protocols.insert(client_id, protocol);
1097}
1098
// Tests for this module. They cover:
// - classification of websocket text frames;
// - the local Nostr reply helper;
// - proxying of upstream Nostr subscriptions;
// - end-to-end websocket publish/REQ flows served through
//   `ws_data_with_client_pubkey`, including spambox rate limiting.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
    use anyhow::Result;
    use futures::{SinkExt, StreamExt};
    use nostr::secp256k1::schnorr::Signature;
    use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
    use std::collections::HashSet;
    use std::sync::Arc;
    use tempfile::TempDir;
    use tokio::net::TcpListener;
    use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};

    /// A JSON array starting with `"REQ"` is classified as a Nostr message.
    #[test]
    fn parse_ws_text_message_detects_nostr_req() {
        let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
        match parse_ws_text_message(msg) {
            Some(WsTextMessage::Nostr(_)) => {}
            other => panic!("expected Nostr message, got {:?}", other),
        }
    }

    /// A `{"type":"req",...}` object is classified as a hashtree message.
    #[test]
    fn parse_ws_text_message_detects_hashtree_request() {
        let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
        match parse_ws_text_message(msg) {
            Some(WsTextMessage::Hashtree(_)) => {}
            other => panic!("expected Hashtree message, got {:?}", other),
        }
    }

    /// The local reply helper answers a REQ with exactly one EOSE for the same subscription.
    #[test]
    fn nostr_replies_for_req_is_eose() {
        let sub = SubscriptionId::new("sub-1");
        let msg = NostrClientMessage::req(sub.clone(), vec![]);
        let replies = nostr_responses_for(&msg);
        assert_eq!(replies.len(), 1);
        match &replies[0] {
            NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
            other => panic!("expected EOSE, got {:?}", other),
        }
    }

    /// A correctly signed EVENT gets exactly one OK, with status `true` and the event's id.
    #[test]
    fn nostr_replies_for_event_ok() {
        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::TextNote, "hello", [])
            .to_event(&keys)
            .unwrap();
        let msg = NostrClientMessage::event(event.clone());
        let replies = nostr_responses_for(&msg);
        assert_eq!(replies.len(), 1);
        match &replies[0] {
            NostrRelayMessage::Ok {
                event_id, status, ..
            } => {
                assert_eq!(event_id, &event.id);
                assert!(*status);
            }
            other => panic!("expected OK, got {:?}", other),
        }
    }

    /// An EVENT whose signature has been replaced gets OK with status `false`.
    #[test]
    fn nostr_replies_for_invalid_event_is_not_ok() {
        let keys = Keys::generate();
        let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
            .to_event(&keys)
            .unwrap();
        // Overwrite the valid signature with all zeros so verification fails.
        event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
        let msg = NostrClientMessage::event(event);
        let replies = nostr_responses_for(&msg);
        assert_eq!(replies.len(), 1);
        match &replies[0] {
            NostrRelayMessage::Ok { status, .. } => assert!(!*status),
            other => panic!("expected OK=false, got {:?}", other),
        }
    }

    /// Starts a minimal upstream Nostr relay on an ephemeral localhost port and
    /// returns its `ws://` URL.
    ///
    /// The relay accepts a single websocket connection. For every REQ it
    /// receives, it sends each event from `events` that matches any of the
    /// REQ's filters, then an EOSE for that subscription. Non-text frames and
    /// text that does not parse as a client message are ignored.
    async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
        let addr = listener.local_addr().expect("relay addr");
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept relay");
            let ws = accept_async(stream).await.expect("accept websocket");
            let (mut write, mut read) = ws.split();

            while let Some(Ok(message)) = read.next().await {
                let TungsteniteMessage::Text(text) = message else {
                    continue;
                };
                let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
                    continue;
                };
                if let NostrClientMessage::Req {
                    subscription_id,
                    filters,
                } = parsed
                {
                    for event in events
                        .iter()
                        .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
                    {
                        let _ = write
                            .send(TungsteniteMessage::Text(
                                NostrRelayMessage::event(subscription_id.clone(), event.clone())
                                    .as_json()
                                    .into(),
                            ))
                            .await;
                    }
                    let _ = write
                        .send(TungsteniteMessage::Text(
                            NostrRelayMessage::eose(subscription_id).as_json().into(),
                        ))
                        .await;
                }
            }
        });
        format!("ws://{}", addr)
    }

    /// Builds an `AppState` for tests.
    ///
    /// It has a fresh `HashtreeStore` under `tmp`, no auth, public writes
    /// enabled and empty caches. Its local Nostr relay is `relay`, and its
    /// only upstream relay URL is `relay_url`. Tests that do not use an
    /// upstream pass `String::new()` here.
    fn test_app_state(
        tmp: &TempDir,
        relay: Arc<NostrRelay>,
        relay_url: String,
    ) -> Result<AppState> {
        let store = Arc::new(crate::storage::HashtreeStore::with_options(
            tmp.path(),
            None,
            128 * 1024 * 1024,
        )?);
        Ok(AppState {
            store,
            auth: None,
            peer_mode: crate::config::ServerMode::Normal,
            hash_get_enabled: true,
            webrtc_peers: None,
            ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
            max_upload_bytes: 5 * 1024 * 1024,
            public_writes: true,
            allowed_pubkeys: HashSet::new(),
            upstream_blossom: Vec::new(),
            social_graph: None,
            social_graph_store: None,
            social_graph_root: None,
            socialgraph_snapshot_public: false,
            nostr_relay: Some(relay),
            nostr_relay_urls: vec![relay_url],
            tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
            inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
                std::collections::HashMap::new(),
            )),
            directory_listing_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            resolved_path_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
        })
    }

    /// An upstream subscription forwards the upstream EVENT and then EOSE to
    /// the client under the client's subscription id. It also stores the event
    /// in the local relay. Closing the subscription removes its tracking entry.
    #[tokio::test]
    async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
        let tmp = TempDir::new()?;
        // Open the social-graph store while holding the shared test lock
        // (presumably to serialize store setup across tests; confirm).
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let keys = Keys::generate();
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let event = EventBuilder::new(
            Kind::from(30078_u16),
            "",
            [
                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
            ],
        )
        .to_event(&keys)?;

        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
        let filter = Filter::new()
            .authors(vec![event.pubkey])
            .kinds(vec![event.kind]);
        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
        // Register a fake websocket client whose outbound frames land in `rx`.
        let client_id = 7_u64;
        let (tx, mut rx) = mpsc::unbounded_channel();
        state.ws_relay.clients.lock().await.insert(client_id, tx);
        let subscription_id = SubscriptionId::new("sub-1");

        start_upstream_nostr_subscription(
            &state,
            client_id,
            subscription_id.clone(),
            vec![filter.clone()],
        )
        .await;

        // The first frame is the upstream EVENT, relabelled with the client's subscription id.
        let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
            .await?
            .expect("forwarded upstream event");
        let Message::Text(text) = forwarded else {
            panic!("expected text event");
        };
        match NostrRelayMessage::from_json(text.as_str())? {
            NostrRelayMessage::Event {
                subscription_id: sid,
                event: forwarded_event,
            } => {
                assert_eq!(sid, subscription_id);
                assert_eq!(forwarded_event.id, event.id);
            }
            other => panic!("expected forwarded EVENT, got {:?}", other),
        }

        // The second frame is the upstream EOSE.
        let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
            .await?
            .expect("forwarded upstream eose");
        let Message::Text(eose_text) = eose else {
            panic!("expected text eose");
        };
        match NostrRelayMessage::from_json(eose_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(sid) => {
                assert_eq!(sid, subscription_id);
            }
            other => panic!("expected forwarded EOSE, got {:?}", other),
        }

        // The forwarded event must also have been stored in the local relay.
        let events = relay.query_events(&filter, 10).await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, event.id);

        close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
        assert!(state
            .ws_relay
            .upstream_nostr_subscriptions
            .lock()
            .await
            .is_empty());
        Ok(())
    }

    /// A REQ sent through `handle_message` delivers the upstream EVENT before
    /// its EOSE. The client must not see EOSE ahead of the upstream result.
    #[tokio::test]
    async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let keys = Keys::generate();
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let event = EventBuilder::new(
            Kind::from(30078_u16),
            "",
            [
                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
            ],
        )
        .to_event(&keys)?;

        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
        // Register the client with both the websocket relay (frames go to
        // `ws_rx`) and the local Nostr relay (its channel is not inspected).
        let client_id = 11_u64;
        let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
        let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
        state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
        relay.register_client(client_id, relay_tx, None).await;

        let request = NostrClientMessage::req(
            SubscriptionId::new("feed"),
            vec![Filter::new()
                .authors(vec![event.pubkey])
                .kinds(vec![event.kind])],
        )
        .as_json();

        handle_message(client_id, Message::Text(request.into()), &state).await;

        let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
            .await?
            .expect("first forwarded message");
        let Message::Text(first_text) = first else {
            panic!("expected text event");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::Event {
                event: forwarded_event,
                ..
            } => {
                assert_eq!(forwarded_event.id, event.id);
            }
            other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
        }

        let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
            .await?
            .expect("second forwarded message");
        let Message::Text(second_text) = second else {
            panic!("expected text eose");
        };
        match NostrRelayMessage::from_json(second_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(sid) => {
                assert_eq!(sid, SubscriptionId::new("feed"));
            }
            other => panic!("expected aggregated EOSE, got {:?}", other),
        }

        Ok(())
    }

    /// End to end over a real axum websocket: when an EVENT is published by the
    /// connection's authenticated pubkey, the reply is an OK with status
    /// `true`. The event is then stored in the local relay.
    #[tokio::test]
    async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let author_keys = Keys::generate();
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        // The author's pubkey is passed to both access control and the relay
        // (presumably marking it as trusted; confirm against those constructors).
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::from([author_keys.public_key().to_hex()]),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([author_keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay.clone(), String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let client_pubkey = author_keys.public_key().to_hex();
        // Serve `/ws` so that every connection is attributed to `client_pubkey`.
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                let client_pubkey = client_pubkey.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    let client_pubkey = client_pubkey.clone();
                    async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
            .to_event(&author_keys)?;
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event.clone()).as_json().into(),
            ))
            .await?;

        let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
        let TungsteniteMessage::Text(text) = reply else {
            anyhow::bail!("expected text publish ack");
        };

        match NostrRelayMessage::from_json(text.as_str())? {
            NostrRelayMessage::Ok {
                event_id, status, ..
            } => {
                assert_eq!(event_id, event.id);
                assert!(status);
            }
            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
        }

        let stored = relay
            .query_events(
                &Filter::new()
                    .authors(vec![event.pubkey])
                    .kinds(vec![event.kind]),
                10,
            )
            .await;
        assert!(stored.iter().any(|candidate| candidate.id == event.id));
        Ok(())
    }

    /// With `spambox_max_reqs_per_min: 1`, an anonymous connection's first REQ
    /// gets EOSE. The second REQ is answered with
    /// `["CLOSED", "sub-2", "rate limited"]`.
    #[tokio::test]
    async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::new(),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                spambox_max_reqs_per_min: 1,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay, String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        // Serve `/ws` with no client pubkey, so the connection is anonymous.
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    async move { ws_data_with_client_pubkey(state, ws, None) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
                    .as_json()
                    .into(),
            ))
            .await?;

        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
        let TungsteniteMessage::Text(first_text) = first else {
            anyhow::bail!("expected text EOSE reply");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
                assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
            }
            other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
        }

        // The second REQ exceeds the one-per-minute quota.
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
                    .as_json()
                    .into(),
            ))
            .await?;

        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
        let TungsteniteMessage::Text(second_text) = second else {
            anyhow::bail!("expected text CLOSED reply");
        };
        // Compare as raw JSON so the exact wire form of CLOSED is pinned.
        let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
        assert_eq!(
            second_value,
            serde_json::json!(["CLOSED", "sub-2", "rate limited"])
        );

        Ok(())
    }

    /// With `spambox_max_events_per_min: 1`, an untrusted author's first EVENT
    /// is accepted as OK(true, "spambox"). The second EVENT is rejected as
    /// OK(false, "rate limited").
    #[tokio::test]
    async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::new(),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                spambox_max_events_per_min: 1,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay, String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    async move { ws_data_with_client_pubkey(state, ws, None) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        // A freshly generated author that appears in no trusted set passed above.
        let author_keys = Keys::generate();
        let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
        let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;

        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event_a.clone()).as_json().into(),
            ))
            .await?;

        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
        let TungsteniteMessage::Text(first_text) = first else {
            anyhow::bail!("expected text publish ack");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::Ok {
                event_id,
                status,
                message,
            } => {
                assert_eq!(event_id, event_a.id);
                assert!(status);
                assert_eq!(message, "spambox");
            }
            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
        }

        // The second publish exceeds the one-per-minute event quota.
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event_b.clone()).as_json().into(),
            ))
            .await?;

        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
        let TungsteniteMessage::Text(second_text) = second else {
            anyhow::bail!("expected text publish ack");
        };
        match NostrRelayMessage::from_json(second_text.as_str())? {
            NostrRelayMessage::Ok {
                event_id,
                status,
                message,
            } => {
                assert_eq!(event_id, event_b.id);
                assert!(!status);
                assert_eq!(message, "rate limited");
            }
            other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
        }

        Ok(())
    }
}