Skip to main content

hashtree_cli/server/
ws_relay.rs

1use axum::{
2    extract::{
3        ws::{Message, WebSocket, WebSocketUpgrade},
4        State,
5    },
6    response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11    ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12    RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::diagnostics::{
21    nostr_filters_summary, process_memory_snapshot, trim_process_allocations,
22};
23use crate::webrtc::types::{
24    encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
25};
26use hex::encode as hex_encode;
27
/// Hashtree JSON text frames accepted from a `/ws` client, discriminated by the
/// `type` field (`"req"` or `"res"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum WsClientMessage {
    /// Request for the blob whose hex-encoded hash is `hash`; `id` is chosen by
    /// the client and echoed back in the matching `res` frame.
    #[serde(rename = "req")]
    Request { id: u32, hash: String },
    /// A peer's answer to a request this server forwarded to it.
    #[serde(rename = "res")]
    Response { id: u32, hash: String, found: bool },
}
36
/// Result of classifying an inbound text frame (see `parse_ws_text_message`):
/// either a hashtree JSON message or a Nostr client message.
#[derive(Debug)]
enum WsTextMessage {
    Hashtree(WsClientMessage),
    Nostr(NostrClientMessage),
}
42
/// Hashtree JSON request frame forwarded to JSON-protocol peers.
///
/// Serialized as `{"type": <kind>, "id": .., "hash": ..}`; in this file `kind`
/// is always `"req"`.
#[derive(Debug, Deserialize, Serialize)]
struct WsRequest {
    #[serde(rename = "type")]
    kind: String,
    id: u32,
    hash: String,
}
50
/// Hashtree JSON `res` frame telling a JSON-protocol client whether the blob
/// for `hash` was found. In this file `kind` is always `"res"`; a `found: true`
/// frame sent straight from the local store is followed by the blob itself as a
/// binary frame.
#[derive(Debug, Serialize)]
struct WsResponse {
    #[serde(rename = "type")]
    kind: &'static str,
    id: u32,
    hash: String,
    found: bool,
}
59
60pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
61    ws_data_with_client_pubkey(state, ws, None)
62}
63
64pub fn ws_data_with_client_pubkey(
65    state: AppState,
66    ws: WebSocketUpgrade,
67    client_pubkey: Option<String>,
68) -> impl IntoResponse {
69    ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
70}
71
72async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
73    // Use the Nostr relay's client-id generator when available so `/ws` IDs
74    // can't collide with WebRTC relay clients.
75    let client_id = state
76        .nostr_relay
77        .as_ref()
78        .map(|relay| relay.next_client_id())
79        .unwrap_or_else(|| state.ws_relay.next_id());
80    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
81
82    {
83        let mut clients = state.ws_relay.clients.lock().await;
84        clients.insert(client_id, tx);
85    }
86    {
87        let mut protocols = state.ws_relay.client_protocols.lock().await;
88        protocols.insert(client_id, WsProtocol::HashtreeJson);
89    }
90
91    let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
92        let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
93        relay
94            .register_client(client_id, nostr_tx, client_pubkey.clone())
95            .await;
96        Some(nostr_rx)
97    } else {
98        None
99    };
100
101    let (mut sender, mut receiver) = socket.split();
102    loop {
103        tokio::select! {
104            maybe_msg = rx.recv() => {
105                let Some(msg) = maybe_msg else {
106                    break;
107                };
108                if sender.send(msg).await.is_err() {
109                    break;
110                }
111            }
112            maybe_text = async {
113                match &mut nostr_rx {
114                    Some(rx) => rx.recv().await,
115                    None => std::future::pending().await,
116                }
117            } => {
118                let Some(text) = maybe_text else {
119                    nostr_rx = None;
120                    continue;
121                };
122                if sender.send(Message::Text(text)).await.is_err() {
123                    break;
124                }
125            }
126            maybe_incoming = receiver.next() => {
127                match maybe_incoming {
128                    Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
129                    Some(Err(_)) | None => break,
130                }
131            }
132        }
133    }
134
135    close_all_upstream_nostr_subscriptions(&state, client_id).await;
136
137    {
138        let mut clients = state.ws_relay.clients.lock().await;
139        clients.remove(&client_id);
140    }
141    {
142        let mut protocols = state.ws_relay.client_protocols.lock().await;
143        protocols.remove(&client_id);
144    }
145    {
146        let mut pending = state.ws_relay.pending.lock().await;
147        pending.retain(|(peer_id, _), _| *peer_id != client_id);
148    }
149
150    if let Some(relay) = &state.nostr_relay {
151        relay.unregister_client(client_id).await;
152    }
153}
154
155fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
156    let trimmed = text.trim_start();
157    if trimmed.starts_with('[') {
158        if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
159            return Some(WsTextMessage::Nostr(msg));
160        }
161    }
162
163    if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
164        return Some(WsTextMessage::Hashtree(msg));
165    }
166
167    None
168}
169async fn close_upstream_nostr_subscription(
170    state: &AppState,
171    client_id: u64,
172    subscription_id: &SubscriptionId,
173) {
174    let key = (client_id, subscription_id.to_string());
175    let subscription = {
176        let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
177        subscriptions.remove(&key)
178    };
179    if let Some(subscription) = subscription {
180        let _ = subscription.close_tx.send(true);
181        for task in subscription.tasks {
182            task.abort();
183        }
184    }
185    state
186        .ws_relay
187        .upstream_pending_eose
188        .lock()
189        .await
190        .remove(&key);
191    state
192        .ws_relay
193        .upstream_seen_events
194        .lock()
195        .await
196        .remove(&key);
197}
198
199async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
200    let keys = {
201        let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
202        subscriptions
203            .keys()
204            .filter(|(id, _)| *id == client_id)
205            .cloned()
206            .collect::<Vec<_>>()
207    };
208    for (_, sub_id) in keys {
209        close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
210    }
211}
212
213async fn forward_upstream_nostr_message(
214    state: &AppState,
215    client_id: u64,
216    subscription_id: &SubscriptionId,
217    text: &str,
218) {
219    let Ok(message) = NostrRelayMessage::from_json(text) else {
220        return;
221    };
222
223    match message {
224        NostrRelayMessage::Event {
225            subscription_id: sid,
226            event,
227        } if sid == *subscription_id => {
228            let event = *event;
229            let key = (client_id, subscription_id.to_string());
230            let event_id = event.id.to_hex();
231            let inserted = {
232                let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
233                seen_events.entry(key).or_default().insert(event_id)
234            };
235            if !inserted {
236                return;
237            }
238            if let Some(relay) = &state.nostr_relay {
239                let _ = relay.ingest_trusted_event_silent(event.clone()).await;
240            }
241            send_nostr(
242                state,
243                client_id,
244                NostrRelayMessage::event(subscription_id.clone(), event),
245            )
246            .await;
247        }
248        NostrRelayMessage::Closed {
249            subscription_id: sid,
250            message,
251        } if sid == *subscription_id => {
252            send_nostr(
253                state,
254                client_id,
255                NostrRelayMessage::closed(subscription_id.clone(), message),
256            )
257            .await;
258        }
259        _ => {}
260    }
261}
262
263async fn mark_upstream_nostr_relay_complete(
264    state: &AppState,
265    client_id: u64,
266    subscription_id: &SubscriptionId,
267) {
268    let key = (client_id, subscription_id.to_string());
269    let should_send_eose = {
270        let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
271        let Some(remaining) = pending.get_mut(&key) else {
272            return;
273        };
274        if *remaining > 0 {
275            *remaining -= 1;
276        }
277        if *remaining == 0 {
278            pending.remove(&key);
279            true
280        } else {
281            false
282        }
283    };
284
285    if should_send_eose {
286        send_nostr(
287            state,
288            client_id,
289            NostrRelayMessage::eose(subscription_id.clone()),
290        )
291        .await;
292    }
293}
294
295async fn run_upstream_nostr_subscription(
296    state: AppState,
297    client_id: u64,
298    relay_url: String,
299    subscription_id: SubscriptionId,
300    filters: Vec<NostrFilter>,
301    mut close_rx: watch::Receiver<bool>,
302) {
303    let mut relay_complete = false;
304    let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
305        tracing::warn!(
306            "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
307            client_id,
308            subscription_id,
309            relay_url,
310        );
311        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
312        return;
313    };
314    let (mut write, mut read) = socket.split();
315    let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
316    state
317        .ws_relay
318        .note_upstream_relay_send(request.as_bytes().len());
319    if write
320        .send(TungsteniteMessage::Text(request.into()))
321        .await
322        .is_err()
323    {
324        tracing::warn!(
325            "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
326            client_id,
327            subscription_id,
328            relay_url,
329        );
330        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
331        return;
332    }
333
334    loop {
335        tokio::select! {
336            _ = close_rx.changed() => {
337                if *close_rx.borrow() {
338                    let close = NostrClientMessage::close(subscription_id.clone()).as_json();
339                    state
340                        .ws_relay
341                        .note_upstream_relay_send(close.as_bytes().len());
342                    let _ = write.send(TungsteniteMessage::Text(close.into())).await;
343                    let _ = write.close().await;
344                    break;
345                }
346            }
347            message = read.next() => {
348                match message {
349                    Some(Ok(TungsteniteMessage::Text(text))) => {
350                        state
351                            .ws_relay
352                            .note_upstream_relay_receive(text.as_bytes().len());
353                        if matches!(
354                            NostrRelayMessage::from_json(text.as_str()),
355                            Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
356                        ) {
357                            if !relay_complete {
358                                relay_complete = true;
359                                mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
360                            }
361                            continue;
362                        }
363                        forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
364                    }
365                    Some(Ok(TungsteniteMessage::Ping(payload))) => {
366                        let _ = write.send(TungsteniteMessage::Pong(payload)).await;
367                    }
368                    Some(Ok(TungsteniteMessage::Close(_))) | None => {
369                        if !relay_complete {
370                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
371                        }
372                        break;
373                    }
374                    Some(Err(_)) => {
375                        if !relay_complete {
376                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
377                        }
378                        break;
379                    }
380                    _ => {}
381                }
382            }
383        }
384    }
385}
386
/// (Re)starts upstream proxying of `subscription_id` for `client_id`.
///
/// Steps:
/// 1. Close any existing upstream subscription under the same key.
/// 2. Trim and deduplicate the configured relay URLs, skipping blank entries.
/// 3. Seed an empty dedupe set and an EOSE countdown equal to the relay count.
/// 4. Spawn one `run_upstream_nostr_subscription` task per remaining relay.
/// 5. Record the tasks and their shared close signal in
///    `upstream_nostr_subscriptions`.
///
/// Returns the number of upstream relays used. `0` means nothing was started
/// (no relays configured, no filters, or nothing left after normalization);
/// the caller in `handle_message` then sends EOSE itself.
async fn start_upstream_nostr_subscription(
    state: &AppState,
    client_id: u64,
    subscription_id: SubscriptionId,
    filters: Vec<NostrFilter>,
) -> usize {
    let memory_before = process_memory_snapshot();
    // A REQ reusing an existing subscription id replaces the old one.
    close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
    if state.nostr_relay_urls.is_empty() || filters.is_empty() {
        tracing::info!(
            "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
            client_id,
            subscription_id,
            state.nostr_relay_urls.len(),
            filters.len(),
        );
        return 0;
    }

    // Trim, drop blanks, and dedupe while preserving configuration order.
    let mut relay_urls = Vec::new();
    let mut seen = HashSet::new();
    for relay in &state.nostr_relay_urls {
        let relay = relay.trim();
        if relay.is_empty() || !seen.insert(relay.to_string()) {
            continue;
        }
        relay_urls.push(relay.to_string());
    }
    if relay_urls.is_empty() {
        tracing::info!(
            "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
            client_id,
            subscription_id,
        );
        return 0;
    }

    tracing::info!(
        "upstream nostr relay start: client_id={} subscription_id={} relays={}",
        client_id,
        subscription_id,
        relay_urls.len(),
    );

    let filter_summary = nostr_filters_summary(&filters);
    let key = (client_id, subscription_id.to_string());
    // Seed per-subscription state before any task is spawned, so the tasks
    // always find the dedupe set and the EOSE countdown in place.
    state
        .ws_relay
        .upstream_seen_events
        .lock()
        .await
        .insert(key.clone(), HashSet::new());
    state
        .ws_relay
        .upstream_pending_eose
        .lock()
        .await
        .insert(key.clone(), relay_urls.len());

    // One watch channel shared by all relay tasks of this subscription.
    let (close_tx, close_rx) = watch::channel(false);
    let mut tasks = Vec::new();
    for relay_url in &relay_urls {
        tasks.push(tokio::spawn(run_upstream_nostr_subscription(
            state.clone(),
            client_id,
            relay_url.clone(),
            subscription_id.clone(),
            filters.clone(),
            close_rx.clone(),
        )));
    }

    tracing::info!(
        target: "hashtree_cli::server::ws_relay::upstream",
        client_id,
        subscription_id = %subscription_id,
        relays = relay_urls.len(),
        tasks = tasks.len(),
        filters = filters.len(),
        filter = %filter_summary,
        memory_before = ?memory_before,
        memory_after = ?process_memory_snapshot(),
        "upstream nostr relay tasks spawned",
    );

    state
        .ws_relay
        .upstream_nostr_subscriptions
        .lock()
        .await
        .insert(key, UpstreamNostrSubscription { close_tx, tasks });
    relay_urls.len()
}
480
/// Dispatches one inbound WebSocket frame from `client_id`.
///
/// Text frames are classified by `parse_ws_text_message`:
/// - Hashtree JSON messages tag the client as `WsProtocol::HashtreeJson` and go
///   to `handle_request` / `handle_response`.
/// - Nostr messages go to the embedded relay when one is configured (REQ and
///   CLOSE additionally manage upstream proxying). Otherwise they go to the
///   canned responder `handle_nostr_message`.
///
/// Binary frames go to `handle_binary`. Unparseable text and all other frame
/// kinds are ignored.
async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
    match msg {
        Message::Text(text) => {
            if let Some(msg) = parse_ws_text_message(&text) {
                match msg {
                    WsTextMessage::Hashtree(msg) => {
                        set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
                        match msg {
                            WsClientMessage::Request { id, hash } => {
                                handle_request(
                                    client_id,
                                    id,
                                    hash,
                                    WsProtocol::HashtreeJson,
                                    state,
                                )
                                .await;
                            }
                            WsClientMessage::Response { id, hash, found } => {
                                handle_response(client_id, id, hash, found, state).await;
                            }
                        }
                    }
                    WsTextMessage::Nostr(msg) => {
                        if let Some(relay) = &state.nostr_relay {
                            match msg {
                                NostrClientMessage::Req {
                                    subscription_id,
                                    filters,
                                } => {
                                    // Register with the local relay first. A
                                    // rejection is reported as CLOSED and
                                    // nothing is proxied upstream.
                                    let local_events = match relay
                                        .register_subscription_query(
                                            client_id,
                                            subscription_id.clone(),
                                            filters.clone(),
                                        )
                                        .await
                                    {
                                        Ok(events) => events,
                                        Err(message) => {
                                            send_nostr(
                                                state,
                                                client_id,
                                                NostrRelayMessage::closed(subscription_id, message),
                                            )
                                            .await;
                                            return;
                                        }
                                    };

                                    let upstream_relays = start_upstream_nostr_subscription(
                                        state,
                                        client_id,
                                        subscription_id.clone(),
                                        filters,
                                    )
                                    .await;
                                    // Seed the upstream dedupe set with the
                                    // local results so upstream copies of the
                                    // same events are not delivered again.
                                    // NOTE(review): the upstream tasks are
                                    // already running here, so a copy that
                                    // arrives before this extend can still be
                                    // delivered twice.
                                    if upstream_relays > 0 {
                                        let key = (client_id, subscription_id.to_string());
                                        let mut seen_events =
                                            state.ws_relay.upstream_seen_events.lock().await;
                                        seen_events.entry(key).or_default().extend(
                                            local_events.iter().map(|event| event.id.to_hex()),
                                        );
                                    }
                                    for event in local_events {
                                        send_nostr(
                                            state,
                                            client_id,
                                            NostrRelayMessage::event(
                                                subscription_id.clone(),
                                                event,
                                            ),
                                        )
                                        .await;
                                    }
                                    trim_process_allocations();
                                    // With upstream relays, EOSE is sent by
                                    // `mark_upstream_nostr_relay_complete` once
                                    // every relay has finished. Otherwise the
                                    // local results are complete now.
                                    if upstream_relays == 0 {
                                        send_nostr(
                                            state,
                                            client_id,
                                            NostrRelayMessage::eose(subscription_id),
                                        )
                                        .await;
                                    }
                                }
                                NostrClientMessage::Close(subscription_id) => {
                                    // Stop the upstream proxies before the
                                    // local relay drops the subscription.
                                    close_upstream_nostr_subscription(
                                        state,
                                        client_id,
                                        &subscription_id,
                                    )
                                    .await;
                                    relay
                                        .handle_client_message(
                                            client_id,
                                            NostrClientMessage::Close(subscription_id.clone()),
                                        )
                                        .await;
                                }
                                other => {
                                    relay.handle_client_message(client_id, other).await;
                                }
                            }
                        } else {
                            handle_nostr_message(client_id, msg, state).await;
                        }
                    }
                }
            }
        }
        Message::Binary(data) => {
            handle_binary(client_id, data, state).await;
        }
        Message::Close(_) => {}
        _ => {}
    }
}
599
/// Serves a hashtree blob request from `client_id`.
///
/// Order of attempts:
/// 1. `hash` is lower-cased and hex-decoded. On failure a JSON origin gets
///    `found: false` and a msgpack origin gets nothing.
/// 2. If the local store has the blob, it is sent straight back in the
///    origin's protocol. Store errors are treated like a miss.
/// 3. Otherwise the request is fanned out to every other connected client with
///    a known hashtree protocol. One `pending` entry is recorded per peer, keyed
///    by `(peer_id, request_id)`. If there are no such peers, a JSON origin gets
///    `found: false` immediately.
/// 4. A detached task waits 1.5 s. If entries for this request are still
///    pending and none has been marked found, they are dropped and a JSON
///    origin is told `found: false`.
///
/// NOTE(review): `request_id` is the origin's own id (client-chosen for JSON).
/// Two origins that use the same id concurrently therefore share `pending`
/// keys, and the later insert overwrites the earlier entry. Confirm this is
/// acceptable.
async fn handle_request(
    client_id: u64,
    request_id: u32,
    hash: String,
    origin_protocol: WsProtocol,
    state: &AppState,
) {
    let hash_hex = hash.to_lowercase();
    let hash_bytes = match from_hex(&hash_hex) {
        Ok(bytes) => bytes,
        Err(_) => {
            if origin_protocol == WsProtocol::HashtreeJson {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash,
                        found: false,
                    },
                )
                .await;
            }
            return;
        }
    };

    // Local hit: JSON origins get a `res` frame followed by the binary blob;
    // msgpack origins get a single msgpack response.
    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
        match origin_protocol {
            WsProtocol::HashtreeJson => {
                send_json(
                    state,
                    client_id,
                    WsResponse {
                        kind: "res",
                        id: request_id,
                        hash: hash.clone(),
                        found: true,
                    },
                )
                .await;
                send_binary(state, client_id, request_id, data).await;
            }
            WsProtocol::HashtreeMsgpack => {
                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
            }
            WsProtocol::Unknown => {}
        }
        return;
    }

    // Snapshot every other client whose hashtree protocol is known.
    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
        let clients = state.ws_relay.clients.lock().await;
        let protocols = state.ws_relay.client_protocols.lock().await;
        clients
            .iter()
            .filter(|(id, _)| **id != client_id)
            .filter_map(|(id, tx)| {
                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
                match protocol {
                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
                        Some((*id, tx.clone(), protocol))
                    }
                    WsProtocol::Unknown => None,
                }
            })
            .collect()
    };

    if peers.is_empty() {
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash,
                    found: false,
                },
            )
            .await;
        }
        return;
    }

    // Record the outstanding request per peer so responses and blobs can be
    // routed back to the origin.
    {
        let mut pending = state.ws_relay.pending.lock().await;
        for (peer_id, _, _) in &peers {
            pending.insert(
                (*peer_id, request_id),
                PendingRequest {
                    origin_id: client_id,
                    hash: hash.clone(),
                    found: false,
                    origin_protocol,
                },
            );
        }
    }

    let request_text = serde_json::to_string(&WsRequest {
        kind: "req".to_string(),
        id: request_id,
        hash: hash.clone(),
    })
    .unwrap_or_else(|_| String::new());
    for (peer_id, tx, protocol) in peers {
        match protocol {
            WsProtocol::HashtreeMsgpack => {
                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
            }
            WsProtocol::HashtreeJson => {
                let _ = tx.send(Message::Text(request_text.clone()));
            }
            WsProtocol::Unknown => {}
        }
    }

    // Give peers 1.5 s to answer. If one has already reported `found`, the
    // entries are left alone; the legacy binary path in `handle_binary` clears
    // them when the blob arrives.
    let timeout_state = state.clone();
    let timeout_hash = hash.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(1500)).await;
        let mut pending = timeout_state.ws_relay.pending.lock().await;
        let still_pending = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
        let already_found = pending
            .iter()
            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
        if !still_pending || already_found {
            return;
        }
        let origin_protocol = pending
            .iter()
            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
            .map(|(_, p)| p.origin_protocol)
            .unwrap_or(WsProtocol::HashtreeJson);
        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
        drop(pending);
        if origin_protocol == WsProtocol::HashtreeJson {
            send_json(
                &timeout_state,
                client_id,
                WsResponse {
                    kind: "res",
                    id: request_id,
                    hash: timeout_hash,
                    found: false,
                },
            )
            .await;
        }
    });
}
756
757async fn handle_response(
758    client_id: u64,
759    request_id: u32,
760    _hash: String,
761    found: bool,
762    state: &AppState,
763) {
764    let pending_entry = {
765        let pending = state.ws_relay.pending.lock().await;
766        pending
767            .get(&(client_id, request_id))
768            .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
769    };
770
771    let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
772        return;
773    };
774
775    if already_found && !found {
776        let mut pending = state.ws_relay.pending.lock().await;
777        pending.remove(&(client_id, request_id));
778        return;
779    }
780
781    if found {
782        let mut pending = state.ws_relay.pending.lock().await;
783        for ((_, id), p) in pending.iter_mut() {
784            if *id == request_id && p.origin_id == origin_id {
785                p.found = true;
786            }
787        }
788        drop(pending);
789        if origin_protocol == WsProtocol::HashtreeJson {
790            send_json(
791                state,
792                origin_id,
793                WsResponse {
794                    kind: "res",
795                    id: request_id,
796                    hash: pending_hash,
797                    found: true,
798                },
799            )
800            .await;
801        }
802        return;
803    }
804
805    let mut pending = state.ws_relay.pending.lock().await;
806    pending.remove(&(client_id, request_id));
807    let has_remaining = pending
808        .iter()
809        .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
810    drop(pending);
811
812    if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
813        send_json(
814            state,
815            origin_id,
816            WsResponse {
817                kind: "res",
818                id: request_id,
819                hash: pending_hash,
820                found: false,
821            },
822        )
823        .await;
824    }
825}
826
/// Handles a binary frame from `client_id`.
///
/// Msgpack frames (anything that decodes as a `DataMessage`):
/// - switch the client to `WsProtocol::HashtreeMsgpack`;
/// - requests are served through `handle_request` under a fresh
///   server-generated request id;
/// - responses go to `handle_msgpack_response`;
/// - all other `DataMessage` kinds are ignored on this transport.
///
/// Any other frame is treated as a legacy JSON-protocol blob: a 4-byte
/// little-endian request id followed by the blob bytes. If this client has a
/// pending entry for that id, the blob is forwarded to the origin in the
/// origin's protocol, and every pending entry for that request is cleared.
/// Frames shorter than 4 bytes, or with no matching pending entry, are dropped.
async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
    if let Some(msg) = parse_msgpack_message(&data) {
        set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
        match msg {
            DataMessage::Request(req) => {
                let hash_hex = hex_encode(&req.h);
                let request_id = state.ws_relay.next_request_id();
                handle_request(
                    client_id,
                    request_id,
                    hash_hex,
                    WsProtocol::HashtreeMsgpack,
                    state,
                )
                .await;
            }
            DataMessage::Response(res) => {
                handle_msgpack_response(client_id, res, state).await;
            }
            DataMessage::QuoteRequest(_)
            | DataMessage::QuoteResponse(_)
            | DataMessage::Payment(_)
            | DataMessage::PaymentAck(_)
            | DataMessage::Chunk(_)
            | DataMessage::PeerHints(_)
            | DataMessage::PubsubInterest(_)
            | DataMessage::PubsubFrame(_)
            | DataMessage::PubsubInventory(_)
            | DataMessage::PubsubWant(_) => {}
        }
        return;
    }

    // Legacy binary: [4-byte LE request_id][data]
    if data.len() < 4 {
        return;
    }
    let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
    let pending_entry = {
        let pending = state.ws_relay.pending.lock().await;
        pending
            .get(&(client_id, request_id))
            .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
    };
    let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
        return;
    };

    match origin_protocol {
        WsProtocol::HashtreeJson => {
            send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
        }
        WsProtocol::HashtreeMsgpack => {
            let Ok(hash_bytes) = from_hex(&hash_hex) else {
                return;
            };
            send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
        }
        WsProtocol::Unknown => {}
    }

    // The blob has been delivered, so drop this request's entries for every peer.
    let mut pending = state.ws_relay.pending.lock().await;
    pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
}
891
892async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
893    let replies = nostr_responses_for(&msg);
894    for reply in replies {
895        send_nostr(state, client_id, reply).await;
896    }
897}
898
899fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
900    match msg {
901        NostrClientMessage::Event(event) => {
902            let ok = event.verify().is_ok();
903            let message = if ok { "" } else { "invalid: signature" };
904            vec![NostrRelayMessage::ok(event.id, ok, message)]
905        }
906        NostrClientMessage::Req {
907            subscription_id, ..
908        } => {
909            vec![NostrRelayMessage::eose(subscription_id.clone())]
910        }
911        NostrClientMessage::Count {
912            subscription_id, ..
913        } => {
914            vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
915        }
916        NostrClientMessage::Close(_) => Vec::new(),
917        NostrClientMessage::Auth(event) => {
918            let ok = event.verify().is_ok();
919            let message = if ok { "" } else { "invalid auth" };
920            vec![NostrRelayMessage::ok(event.id, ok, message)]
921        }
922        NostrClientMessage::NegOpen { .. }
923        | NostrClientMessage::NegMsg { .. }
924        | NostrClientMessage::NegClose { .. } => {
925            vec![NostrRelayMessage::notice("negentropy not supported")]
926        }
927    }
928}
929
930async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
931    let text = response.as_json();
932    send_to_client(state, client_id, Message::Text(text)).await;
933}
934
935fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
936    let msg = parse_message(data).ok()?;
937    match msg {
938        DataMessage::Request(req) => {
939            if req.h.len() == 32 {
940                Some(DataMessage::Request(req))
941            } else {
942                None
943            }
944        }
945        DataMessage::Response(res) => {
946            if res.h.len() == 32 {
947                Some(DataMessage::Response(res))
948            } else {
949                None
950            }
951        }
952        DataMessage::QuoteRequest(req) => {
953            if req.h.len() == 32 {
954                Some(DataMessage::QuoteRequest(req))
955            } else {
956                None
957            }
958        }
959        DataMessage::QuoteResponse(res) => {
960            if res.h.len() == 32 {
961                Some(DataMessage::QuoteResponse(res))
962            } else {
963                None
964            }
965        }
966        DataMessage::Payment(req) => {
967            if req.h.len() == 32 {
968                Some(DataMessage::Payment(req))
969            } else {
970                None
971            }
972        }
973        DataMessage::PaymentAck(res) => {
974            if res.h.len() == 32 {
975                Some(DataMessage::PaymentAck(res))
976            } else {
977                None
978            }
979        }
980        DataMessage::Chunk(chunk) => {
981            if chunk.h.len() == 32 {
982                Some(DataMessage::Chunk(chunk))
983            } else {
984                None
985            }
986        }
987        DataMessage::PeerHints(_)
988        | DataMessage::PubsubInterest(_)
989        | DataMessage::PubsubFrame(_)
990        | DataMessage::PubsubInventory(_)
991        | DataMessage::PubsubWant(_) => Some(msg),
992    }
993}
994
995async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
996    let hash_hex = hex_encode(&res.h);
997    let data = res.d.clone();
998    let hash_bytes = res.h.clone();
999
1000    let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
1001    let mut seen = HashSet::new();
1002    {
1003        let pending = state.ws_relay.pending.lock().await;
1004        for ((peer_id, request_id), p) in pending.iter() {
1005            if *peer_id != client_id {
1006                continue;
1007            }
1008            if p.hash != hash_hex {
1009                continue;
1010            }
1011            if seen.insert((p.origin_id, *request_id)) {
1012                responses.push((p.origin_id, *request_id, p.origin_protocol));
1013            }
1014        }
1015    }
1016
1017    if responses.is_empty() {
1018        return;
1019    }
1020
1021    for (origin_id, request_id, protocol) in &responses {
1022        match protocol {
1023            WsProtocol::HashtreeJson => {
1024                send_json(
1025                    state,
1026                    *origin_id,
1027                    WsResponse {
1028                        kind: "res",
1029                        id: *request_id,
1030                        hash: hash_hex.clone(),
1031                        found: true,
1032                    },
1033                )
1034                .await;
1035                send_binary(state, *origin_id, *request_id, data.clone()).await;
1036            }
1037            WsProtocol::HashtreeMsgpack => {
1038                send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1039            }
1040            WsProtocol::Unknown => {}
1041        }
1042    }
1043
1044    let completed: HashSet<(u64, u32)> = responses
1045        .into_iter()
1046        .map(|(origin_id, request_id, _)| (origin_id, request_id))
1047        .collect();
1048    let mut pending = state.ws_relay.pending.lock().await;
1049    pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1050}
1051
1052async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1053    if let Ok(text) = serde_json::to_string(&response) {
1054        send_to_client(state, client_id, Message::Text(text)).await;
1055    }
1056}
1057
1058async fn send_msgpack_request(
1059    state: &AppState,
1060    client_id: u64,
1061    hash: &[u8],
1062) -> Result<(), rmp_serde::encode::Error> {
1063    let req = DataRequest {
1064        h: hash.to_vec(),
1065        htl: MAX_HTL,
1066        q: None,
1067    };
1068    let wire = encode_request(&req)?;
1069    send_to_client(state, client_id, Message::Binary(wire)).await;
1070    Ok(())
1071}
1072
1073async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1074    let res = DataResponse {
1075        h: hash.to_vec(),
1076        d: data.to_vec(),
1077        i: None,
1078        n: None,
1079    };
1080    if let Ok(wire) = encode_response(&res) {
1081        send_to_client(state, client_id, Message::Binary(wire)).await;
1082    }
1083}
1084
1085async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1086    let mut packet = Vec::with_capacity(4 + payload.len());
1087    packet.extend_from_slice(&request_id.to_le_bytes());
1088    packet.extend_from_slice(&payload);
1089    send_to_client(state, client_id, Message::Binary(packet)).await;
1090}
1091
1092async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1093    let sender = {
1094        let clients = state.ws_relay.clients.lock().await;
1095        clients.get(&client_id).cloned()
1096    };
1097    if let Some(tx) = sender {
1098        let _ = tx.send(msg);
1099    }
1100}
1101
1102async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1103    let mut protocols = state.ws_relay.client_protocols.lock().await;
1104    protocols.insert(client_id, protocol);
1105}
1106
1107#[cfg(test)]
1108mod tests {
1109    use super::*;
1110    use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
1111    use anyhow::Result;
1112    use futures::{SinkExt, StreamExt};
1113    use nostr::secp256k1::schnorr::Signature;
1114    use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
1115    use std::collections::HashSet;
1116    use std::sync::Arc;
1117    use tempfile::TempDir;
1118    use tokio::net::TcpListener;
1119    use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};
1120
1121    #[test]
1122    fn parse_ws_text_message_detects_nostr_req() {
1123        let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
1124        match parse_ws_text_message(msg) {
1125            Some(WsTextMessage::Nostr(_)) => {}
1126            other => panic!("expected Nostr message, got {:?}", other),
1127        }
1128    }
1129
1130    #[test]
1131    fn parse_ws_text_message_detects_hashtree_request() {
1132        let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
1133        match parse_ws_text_message(msg) {
1134            Some(WsTextMessage::Hashtree(_)) => {}
1135            other => panic!("expected Hashtree message, got {:?}", other),
1136        }
1137    }
1138
1139    #[test]
1140    fn nostr_replies_for_req_is_eose() {
1141        let sub = SubscriptionId::new("sub-1");
1142        let msg = NostrClientMessage::req(sub.clone(), vec![]);
1143        let replies = nostr_responses_for(&msg);
1144        assert_eq!(replies.len(), 1);
1145        match &replies[0] {
1146            NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
1147            other => panic!("expected EOSE, got {:?}", other),
1148        }
1149    }
1150
1151    #[test]
1152    fn nostr_replies_for_event_ok() {
1153        let keys = Keys::generate();
1154        let event = EventBuilder::new(Kind::TextNote, "hello", [])
1155            .to_event(&keys)
1156            .unwrap();
1157        let msg = NostrClientMessage::event(event.clone());
1158        let replies = nostr_responses_for(&msg);
1159        assert_eq!(replies.len(), 1);
1160        match &replies[0] {
1161            NostrRelayMessage::Ok {
1162                event_id, status, ..
1163            } => {
1164                assert_eq!(event_id, &event.id);
1165                assert!(*status);
1166            }
1167            other => panic!("expected OK, got {:?}", other),
1168        }
1169    }
1170
1171    #[test]
1172    fn nostr_replies_for_invalid_event_is_not_ok() {
1173        let keys = Keys::generate();
1174        let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
1175            .to_event(&keys)
1176            .unwrap();
1177        event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
1178        let msg = NostrClientMessage::event(event);
1179        let replies = nostr_responses_for(&msg);
1180        assert_eq!(replies.len(), 1);
1181        match &replies[0] {
1182            NostrRelayMessage::Ok { status, .. } => assert!(!*status),
1183            other => panic!("expected OK=false, got {:?}", other),
1184        }
1185    }
1186
1187    async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
1188        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
1189        let addr = listener.local_addr().expect("relay addr");
1190        tokio::spawn(async move {
1191            let (stream, _) = listener.accept().await.expect("accept relay");
1192            let ws = accept_async(stream).await.expect("accept websocket");
1193            let (mut write, mut read) = ws.split();
1194
1195            while let Some(Ok(message)) = read.next().await {
1196                let TungsteniteMessage::Text(text) = message else {
1197                    continue;
1198                };
1199                let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
1200                    continue;
1201                };
1202                if let NostrClientMessage::Req {
1203                    subscription_id,
1204                    filters,
1205                } = parsed
1206                {
1207                    for event in events
1208                        .iter()
1209                        .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
1210                    {
1211                        let _ = write
1212                            .send(TungsteniteMessage::Text(
1213                                NostrRelayMessage::event(subscription_id.clone(), event.clone())
1214                                    .as_json()
1215                                    .into(),
1216                            ))
1217                            .await;
1218                    }
1219                    let _ = write
1220                        .send(TungsteniteMessage::Text(
1221                            NostrRelayMessage::eose(subscription_id).as_json().into(),
1222                        ))
1223                        .await;
1224                }
1225            }
1226        });
1227        format!("ws://{}", addr)
1228    }
1229
1230    fn test_app_state(
1231        tmp: &TempDir,
1232        relay: Arc<NostrRelay>,
1233        relay_url: String,
1234    ) -> Result<AppState> {
1235        let store = Arc::new(crate::storage::HashtreeStore::with_options(
1236            tmp.path(),
1237            None,
1238            128 * 1024 * 1024,
1239        )?);
1240        Ok(AppState {
1241            store,
1242            auth: None,
1243            daemon_started_at: 1_700_000_000,
1244            peer_mode: crate::config::ServerMode::Normal,
1245            hash_get_enabled: true,
1246            http_webrtc_fetch: true,
1247            webrtc_peers: None,
1248            fips_transport: None,
1249            fetch_from_fips_peers: true,
1250            ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
1251            max_upload_bytes: 5 * 1024 * 1024,
1252            public_writes: true,
1253            require_random_untrusted_ingest: false,
1254            optimistic_blossom_uploads: false,
1255            optimistic_upload_queue_bytes: 512 * 1024 * 1024,
1256            optimistic_upload_queue: Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)),
1257            allowed_pubkeys: HashSet::new(),
1258            upstream_blossom: Vec::new(),
1259            social_graph: None,
1260            social_graph_store: None,
1261            social_graph_root: None,
1262            socialgraph_snapshot_public: false,
1263            nostr_relay: Some(relay),
1264            nostr_relay_urls: vec![relay_url],
1265            tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1266            inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
1267                std::collections::HashMap::new(),
1268            )),
1269            inflight_blob_reads: Arc::new(
1270                tokio::sync::Mutex::new(std::collections::HashMap::new()),
1271            ),
1272            blob_cache: Arc::new(crate::blob_cache::BlobCache::for_tests()),
1273            directory_listing_cache: Arc::new(std::sync::Mutex::new(
1274                super::super::auth::new_lookup_cache(),
1275            )),
1276            resolved_path_cache: Arc::new(std::sync::Mutex::new(
1277                super::super::auth::new_lookup_cache(),
1278            )),
1279            thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
1280                super::super::auth::new_lookup_cache(),
1281            )),
1282            cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
1283        })
1284    }
1285
1286    #[tokio::test]
1287    async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
1288        let tmp = TempDir::new()?;
1289        let graph_store = {
1290            let _guard = crate::socialgraph::test_lock();
1291            crate::socialgraph::open_social_graph_store_with_mapsize(
1292                tmp.path(),
1293                Some(128 * 1024 * 1024),
1294            )?
1295        };
1296        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1297        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1298            Arc::clone(&backend),
1299            0,
1300            HashSet::new(),
1301        ));
1302
1303        let keys = Keys::generate();
1304        let relay = Arc::new(NostrRelay::new(
1305            Arc::clone(&backend),
1306            tmp.path().to_path_buf(),
1307            HashSet::from([keys.public_key().to_hex()]),
1308            Some(access),
1309            NostrRelayConfig {
1310                spambox_db_max_bytes: 0,
1311                ..Default::default()
1312            },
1313        )?);
1314
1315        let event = EventBuilder::new(
1316            Kind::from(30078_u16),
1317            "",
1318            [
1319                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1320                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1321            ],
1322        )
1323        .to_event(&keys)?;
1324
1325        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1326        let filter = Filter::new()
1327            .authors(vec![event.pubkey])
1328            .kinds(vec![event.kind]);
1329        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1330        let client_id = 7_u64;
1331        let (tx, mut rx) = mpsc::unbounded_channel();
1332        state.ws_relay.clients.lock().await.insert(client_id, tx);
1333        let subscription_id = SubscriptionId::new("sub-1");
1334
1335        start_upstream_nostr_subscription(
1336            &state,
1337            client_id,
1338            subscription_id.clone(),
1339            vec![filter.clone()],
1340        )
1341        .await;
1342
1343        let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1344            .await?
1345            .expect("forwarded upstream event");
1346        let Message::Text(text) = forwarded else {
1347            panic!("expected text event");
1348        };
1349        match NostrRelayMessage::from_json(text.as_str())? {
1350            NostrRelayMessage::Event {
1351                subscription_id: sid,
1352                event: forwarded_event,
1353            } => {
1354                assert_eq!(sid, subscription_id);
1355                assert_eq!(forwarded_event.id, event.id);
1356            }
1357            other => panic!("expected forwarded EVENT, got {:?}", other),
1358        }
1359
1360        let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
1361            .await?
1362            .expect("forwarded upstream eose");
1363        let Message::Text(eose_text) = eose else {
1364            panic!("expected text eose");
1365        };
1366        match NostrRelayMessage::from_json(eose_text.as_str())? {
1367            NostrRelayMessage::EndOfStoredEvents(sid) => {
1368                assert_eq!(sid, subscription_id);
1369            }
1370            other => panic!("expected forwarded EOSE, got {:?}", other),
1371        }
1372
1373        let events = relay.query_events(&filter, 10).await;
1374        assert_eq!(events.len(), 1);
1375        assert_eq!(events[0].id, event.id);
1376
1377        close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
1378        assert!(state
1379            .ws_relay
1380            .upstream_nostr_subscriptions
1381            .lock()
1382            .await
1383            .is_empty());
1384        Ok(())
1385    }
1386
1387    #[tokio::test]
1388    async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
1389        let tmp = TempDir::new()?;
1390        let graph_store = {
1391            let _guard = crate::socialgraph::test_lock();
1392            crate::socialgraph::open_social_graph_store_with_mapsize(
1393                tmp.path(),
1394                Some(128 * 1024 * 1024),
1395            )?
1396        };
1397        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1398        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1399            Arc::clone(&backend),
1400            0,
1401            HashSet::new(),
1402        ));
1403
1404        let keys = Keys::generate();
1405        let relay = Arc::new(NostrRelay::new(
1406            Arc::clone(&backend),
1407            tmp.path().to_path_buf(),
1408            HashSet::from([keys.public_key().to_hex()]),
1409            Some(access),
1410            NostrRelayConfig {
1411                spambox_db_max_bytes: 0,
1412                ..Default::default()
1413            },
1414        )?);
1415
1416        let event = EventBuilder::new(
1417            Kind::from(30078_u16),
1418            "",
1419            [
1420                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
1421                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
1422            ],
1423        )
1424        .to_event(&keys)?;
1425
1426        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
1427        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
1428        let client_id = 11_u64;
1429        let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
1430        let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
1431        state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
1432        relay.register_client(client_id, relay_tx, None).await;
1433
1434        let request = NostrClientMessage::req(
1435            SubscriptionId::new("feed"),
1436            vec![Filter::new()
1437                .authors(vec![event.pubkey])
1438                .kinds(vec![event.kind])],
1439        )
1440        .as_json();
1441
1442        handle_message(client_id, Message::Text(request.into()), &state).await;
1443
1444        let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1445            .await?
1446            .expect("first forwarded message");
1447        let Message::Text(first_text) = first else {
1448            panic!("expected text event");
1449        };
1450        match NostrRelayMessage::from_json(first_text.as_str())? {
1451            NostrRelayMessage::Event {
1452                event: forwarded_event,
1453                ..
1454            } => {
1455                assert_eq!(forwarded_event.id, event.id);
1456            }
1457            other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
1458        }
1459
1460        let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
1461            .await?
1462            .expect("second forwarded message");
1463        let Message::Text(second_text) = second else {
1464            panic!("expected text eose");
1465        };
1466        match NostrRelayMessage::from_json(second_text.as_str())? {
1467            NostrRelayMessage::EndOfStoredEvents(sid) => {
1468                assert_eq!(sid, SubscriptionId::new("feed"));
1469            }
1470            other => panic!("expected aggregated EOSE, got {:?}", other),
1471        }
1472
1473        Ok(())
1474    }
1475
1476    #[tokio::test]
1477    async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
1478        let tmp = TempDir::new()?;
1479        let graph_store = {
1480            let _guard = crate::socialgraph::test_lock();
1481            crate::socialgraph::open_social_graph_store_with_mapsize(
1482                tmp.path(),
1483                Some(128 * 1024 * 1024),
1484            )?
1485        };
1486        let author_keys = Keys::generate();
1487        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1488        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1489            Arc::clone(&backend),
1490            0,
1491            HashSet::from([author_keys.public_key().to_hex()]),
1492        ));
1493        let relay = Arc::new(NostrRelay::new(
1494            Arc::clone(&backend),
1495            tmp.path().to_path_buf(),
1496            HashSet::from([author_keys.public_key().to_hex()]),
1497            Some(access),
1498            NostrRelayConfig {
1499                spambox_db_max_bytes: 0,
1500                ..Default::default()
1501            },
1502        )?);
1503
1504        let state = test_app_state(&tmp, relay.clone(), String::new())?;
1505        let listener = TcpListener::bind("127.0.0.1:0").await?;
1506        let addr = listener.local_addr()?;
1507        let client_pubkey = author_keys.public_key().to_hex();
1508        let app = axum::Router::new().route(
1509            "/ws",
1510            axum::routing::get({
1511                let state = state.clone();
1512                let client_pubkey = client_pubkey.clone();
1513                move |ws: WebSocketUpgrade| {
1514                    let state = state.clone();
1515                    let client_pubkey = client_pubkey.clone();
1516                    async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
1517                }
1518            }),
1519        );
1520        tokio::spawn(async move {
1521            let _ = axum::serve(listener, app).await;
1522        });
1523
1524        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1525        let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
1526            .to_event(&author_keys)?;
1527        socket
1528            .send(TungsteniteMessage::Text(
1529                NostrClientMessage::event(event.clone()).as_json().into(),
1530            ))
1531            .await?;
1532
1533        let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1534            .await?
1535            .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
1536        let TungsteniteMessage::Text(text) = reply else {
1537            anyhow::bail!("expected text publish ack");
1538        };
1539
1540        match NostrRelayMessage::from_json(text.as_str())? {
1541            NostrRelayMessage::Ok {
1542                event_id, status, ..
1543            } => {
1544                assert_eq!(event_id, event.id);
1545                assert!(status);
1546            }
1547            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1548        }
1549
1550        let stored = relay
1551            .query_events(
1552                &Filter::new()
1553                    .authors(vec![event.pubkey])
1554                    .kinds(vec![event.kind]),
1555                10,
1556            )
1557            .await;
1558        assert!(stored.iter().any(|candidate| candidate.id == event.id));
1559        Ok(())
1560    }
1561
1562    #[tokio::test]
1563    async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
1564        let tmp = TempDir::new()?;
1565        let graph_store = {
1566            let _guard = crate::socialgraph::test_lock();
1567            crate::socialgraph::open_social_graph_store_with_mapsize(
1568                tmp.path(),
1569                Some(128 * 1024 * 1024),
1570            )?
1571        };
1572        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1573        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1574            Arc::clone(&backend),
1575            0,
1576            HashSet::new(),
1577        ));
1578        let relay = Arc::new(NostrRelay::new(
1579            Arc::clone(&backend),
1580            tmp.path().to_path_buf(),
1581            HashSet::new(),
1582            Some(access),
1583            NostrRelayConfig {
1584                spambox_db_max_bytes: 0,
1585                spambox_max_reqs_per_min: 1,
1586                ..Default::default()
1587            },
1588        )?);
1589
1590        let state = test_app_state(&tmp, relay, String::new())?;
1591        let listener = TcpListener::bind("127.0.0.1:0").await?;
1592        let addr = listener.local_addr()?;
1593        let app = axum::Router::new().route(
1594            "/ws",
1595            axum::routing::get({
1596                let state = state.clone();
1597                move |ws: WebSocketUpgrade| {
1598                    let state = state.clone();
1599                    async move { ws_data_with_client_pubkey(state, ws, None) }
1600                }
1601            }),
1602        );
1603        tokio::spawn(async move {
1604            let _ = axum::serve(listener, app).await;
1605        });
1606
1607        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1608        socket
1609            .send(TungsteniteMessage::Text(
1610                NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
1611                    .as_json()
1612                    .into(),
1613            ))
1614            .await?;
1615
1616        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1617            .await?
1618            .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
1619        let TungsteniteMessage::Text(first_text) = first else {
1620            anyhow::bail!("expected text EOSE reply");
1621        };
1622        match NostrRelayMessage::from_json(first_text.as_str())? {
1623            NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
1624                assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
1625            }
1626            other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
1627        }
1628
1629        socket
1630            .send(TungsteniteMessage::Text(
1631                NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
1632                    .as_json()
1633                    .into(),
1634            ))
1635            .await?;
1636
1637        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1638            .await?
1639            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
1640        let TungsteniteMessage::Text(second_text) = second else {
1641            anyhow::bail!("expected text CLOSED reply");
1642        };
1643        let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
1644        assert_eq!(
1645            second_value,
1646            serde_json::json!(["CLOSED", "sub-2", "rate limited"])
1647        );
1648
1649        Ok(())
1650    }
1651
1652    #[tokio::test]
1653    async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
1654        let tmp = TempDir::new()?;
1655        let graph_store = {
1656            let _guard = crate::socialgraph::test_lock();
1657            crate::socialgraph::open_social_graph_store_with_mapsize(
1658                tmp.path(),
1659                Some(128 * 1024 * 1024),
1660            )?
1661        };
1662        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
1663        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
1664            Arc::clone(&backend),
1665            0,
1666            HashSet::new(),
1667        ));
1668        let relay = Arc::new(NostrRelay::new(
1669            Arc::clone(&backend),
1670            tmp.path().to_path_buf(),
1671            HashSet::new(),
1672            Some(access),
1673            NostrRelayConfig {
1674                spambox_db_max_bytes: 0,
1675                spambox_max_events_per_min: 1,
1676                ..Default::default()
1677            },
1678        )?);
1679
1680        let state = test_app_state(&tmp, relay, String::new())?;
1681        let listener = TcpListener::bind("127.0.0.1:0").await?;
1682        let addr = listener.local_addr()?;
1683        let app = axum::Router::new().route(
1684            "/ws",
1685            axum::routing::get({
1686                let state = state.clone();
1687                move |ws: WebSocketUpgrade| {
1688                    let state = state.clone();
1689                    async move { ws_data_with_client_pubkey(state, ws, None) }
1690                }
1691            }),
1692        );
1693        tokio::spawn(async move {
1694            let _ = axum::serve(listener, app).await;
1695        });
1696
1697        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
1698        let author_keys = Keys::generate();
1699        let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
1700        let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;
1701
1702        socket
1703            .send(TungsteniteMessage::Text(
1704                NostrClientMessage::event(event_a.clone()).as_json().into(),
1705            ))
1706            .await?;
1707
1708        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1709            .await?
1710            .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
1711        let TungsteniteMessage::Text(first_text) = first else {
1712            anyhow::bail!("expected text publish ack");
1713        };
1714        match NostrRelayMessage::from_json(first_text.as_str())? {
1715            NostrRelayMessage::Ok {
1716                event_id,
1717                status,
1718                message,
1719            } => {
1720                assert_eq!(event_id, event_a.id);
1721                assert!(status);
1722                assert_eq!(message, "spambox");
1723            }
1724            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
1725        }
1726
1727        socket
1728            .send(TungsteniteMessage::Text(
1729                NostrClientMessage::event(event_b.clone()).as_json().into(),
1730            ))
1731            .await?;
1732
1733        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
1734            .await?
1735            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
1736        let TungsteniteMessage::Text(second_text) = second else {
1737            anyhow::bail!("expected text publish ack");
1738        };
1739        match NostrRelayMessage::from_json(second_text.as_str())? {
1740            NostrRelayMessage::Ok {
1741                event_id,
1742                status,
1743                message,
1744            } => {
1745                assert_eq!(event_id, event_b.id);
1746                assert!(!status);
1747                assert_eq!(message, "rate limited");
1748            }
1749            other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
1750        }
1751
1752        Ok(())
1753    }
1754}