Skip to main content

hashtree_cli/server/
ws_relay.rs

1use axum::{
2    extract::{
3        ws::{Message, WebSocket, WebSocketUpgrade},
4        State,
5    },
6    response::IntoResponse,
7};
8use futures::{SinkExt, StreamExt};
9use hashtree_core::from_hex;
10use nostr::{
11    ClientMessage as NostrClientMessage, Filter as NostrFilter, JsonUtil as NostrJsonUtil,
12    RelayMessage as NostrRelayMessage, SubscriptionId,
13};
14use serde::{Deserialize, Serialize};
15use std::{collections::HashSet, time::Duration};
16use tokio::sync::{mpsc, watch};
17use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
18
19use super::auth::{AppState, PendingRequest, UpstreamNostrSubscription, WsProtocol};
20use crate::webrtc::types::{
21    encode_request, encode_response, parse_message, DataMessage, DataRequest, DataResponse, MAX_HTL,
22};
23use hex::encode as hex_encode;
24
25#[derive(Debug, Deserialize)]
26#[serde(tag = "type")]
27enum WsClientMessage {
28    #[serde(rename = "req")]
29    Request { id: u32, hash: String },
30    #[serde(rename = "res")]
31    Response { id: u32, hash: String, found: bool },
32}
33
34#[derive(Debug)]
35enum WsTextMessage {
36    Hashtree(WsClientMessage),
37    Nostr(NostrClientMessage),
38}
39
40#[derive(Debug, Deserialize, Serialize)]
41struct WsRequest {
42    #[serde(rename = "type")]
43    kind: String,
44    id: u32,
45    hash: String,
46}
47
48#[derive(Debug, Serialize)]
49struct WsResponse {
50    #[serde(rename = "type")]
51    kind: &'static str,
52    id: u32,
53    hash: String,
54    found: bool,
55}
56
57pub async fn ws_data(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
58    ws_data_with_client_pubkey(state, ws, None)
59}
60
61pub fn ws_data_with_client_pubkey(
62    state: AppState,
63    ws: WebSocketUpgrade,
64    client_pubkey: Option<String>,
65) -> impl IntoResponse {
66    ws.on_upgrade(move |socket| handle_socket(socket, state, client_pubkey))
67}
68
69async fn handle_socket(socket: WebSocket, state: AppState, client_pubkey: Option<String>) {
70    // Use the Nostr relay's client-id generator when available so `/ws` IDs
71    // can't collide with WebRTC relay clients.
72    let client_id = state
73        .nostr_relay
74        .as_ref()
75        .map(|relay| relay.next_client_id())
76        .unwrap_or_else(|| state.ws_relay.next_id());
77    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
78
79    {
80        let mut clients = state.ws_relay.clients.lock().await;
81        clients.insert(client_id, tx);
82    }
83    {
84        let mut protocols = state.ws_relay.client_protocols.lock().await;
85        protocols.insert(client_id, WsProtocol::HashtreeJson);
86    }
87
88    let mut nostr_rx = if let Some(relay) = state.nostr_relay.clone() {
89        let (nostr_tx, nostr_rx) = mpsc::unbounded_channel::<String>();
90        relay
91            .register_client(client_id, nostr_tx, client_pubkey.clone())
92            .await;
93        Some(nostr_rx)
94    } else {
95        None
96    };
97
98    let (mut sender, mut receiver) = socket.split();
99    loop {
100        tokio::select! {
101            maybe_msg = rx.recv() => {
102                let Some(msg) = maybe_msg else {
103                    break;
104                };
105                if sender.send(msg).await.is_err() {
106                    break;
107                }
108            }
109            maybe_text = async {
110                match &mut nostr_rx {
111                    Some(rx) => rx.recv().await,
112                    None => std::future::pending().await,
113                }
114            } => {
115                let Some(text) = maybe_text else {
116                    nostr_rx = None;
117                    continue;
118                };
119                if sender.send(Message::Text(text)).await.is_err() {
120                    break;
121                }
122            }
123            maybe_incoming = receiver.next() => {
124                match maybe_incoming {
125                    Some(Ok(msg)) => handle_message(client_id, msg, &state).await,
126                    Some(Err(_)) | None => break,
127                }
128            }
129        }
130    }
131
132    close_all_upstream_nostr_subscriptions(&state, client_id).await;
133
134    {
135        let mut clients = state.ws_relay.clients.lock().await;
136        clients.remove(&client_id);
137    }
138    {
139        let mut protocols = state.ws_relay.client_protocols.lock().await;
140        protocols.remove(&client_id);
141    }
142    {
143        let mut pending = state.ws_relay.pending.lock().await;
144        pending.retain(|(peer_id, _), _| *peer_id != client_id);
145    }
146
147    if let Some(relay) = &state.nostr_relay {
148        relay.unregister_client(client_id).await;
149    }
150}
151
152fn parse_ws_text_message(text: &str) -> Option<WsTextMessage> {
153    let trimmed = text.trim_start();
154    if trimmed.starts_with('[') {
155        if let Ok(msg) = NostrClientMessage::from_json(trimmed) {
156            return Some(WsTextMessage::Nostr(msg));
157        }
158    }
159
160    if let Ok(msg) = serde_json::from_str::<WsClientMessage>(text) {
161        return Some(WsTextMessage::Hashtree(msg));
162    }
163
164    None
165}
166async fn close_upstream_nostr_subscription(
167    state: &AppState,
168    client_id: u64,
169    subscription_id: &SubscriptionId,
170) {
171    let key = (client_id, subscription_id.to_string());
172    let subscription = {
173        let mut subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
174        subscriptions.remove(&key)
175    };
176    if let Some(subscription) = subscription {
177        let _ = subscription.close_tx.send(true);
178        for task in subscription.tasks {
179            task.abort();
180        }
181    }
182    state
183        .ws_relay
184        .upstream_pending_eose
185        .lock()
186        .await
187        .remove(&key);
188    state
189        .ws_relay
190        .upstream_seen_events
191        .lock()
192        .await
193        .remove(&key);
194}
195
196async fn close_all_upstream_nostr_subscriptions(state: &AppState, client_id: u64) {
197    let keys = {
198        let subscriptions = state.ws_relay.upstream_nostr_subscriptions.lock().await;
199        subscriptions
200            .keys()
201            .filter(|(id, _)| *id == client_id)
202            .cloned()
203            .collect::<Vec<_>>()
204    };
205    for (_, sub_id) in keys {
206        close_upstream_nostr_subscription(state, client_id, &SubscriptionId::new(sub_id)).await;
207    }
208}
209
210async fn forward_upstream_nostr_message(
211    state: &AppState,
212    client_id: u64,
213    subscription_id: &SubscriptionId,
214    text: &str,
215) {
216    let Ok(message) = NostrRelayMessage::from_json(text) else {
217        return;
218    };
219
220    match message {
221        NostrRelayMessage::Event {
222            subscription_id: sid,
223            event,
224        } if sid == *subscription_id => {
225            let event = *event;
226            let key = (client_id, subscription_id.to_string());
227            let event_id = event.id.to_hex();
228            let inserted = {
229                let mut seen_events = state.ws_relay.upstream_seen_events.lock().await;
230                seen_events.entry(key).or_default().insert(event_id)
231            };
232            if !inserted {
233                return;
234            }
235            if let Some(relay) = &state.nostr_relay {
236                let _ = relay.ingest_trusted_event_silent(event.clone()).await;
237            }
238            send_nostr(
239                state,
240                client_id,
241                NostrRelayMessage::event(subscription_id.clone(), event),
242            )
243            .await;
244        }
245        NostrRelayMessage::Closed {
246            subscription_id: sid,
247            message,
248        } if sid == *subscription_id => {
249            send_nostr(
250                state,
251                client_id,
252                NostrRelayMessage::closed(subscription_id.clone(), message),
253            )
254            .await;
255        }
256        _ => {}
257    }
258}
259
260async fn mark_upstream_nostr_relay_complete(
261    state: &AppState,
262    client_id: u64,
263    subscription_id: &SubscriptionId,
264) {
265    let key = (client_id, subscription_id.to_string());
266    let should_send_eose = {
267        let mut pending = state.ws_relay.upstream_pending_eose.lock().await;
268        let Some(remaining) = pending.get_mut(&key) else {
269            return;
270        };
271        if *remaining > 0 {
272            *remaining -= 1;
273        }
274        if *remaining == 0 {
275            pending.remove(&key);
276            true
277        } else {
278            false
279        }
280    };
281
282    if should_send_eose {
283        send_nostr(
284            state,
285            client_id,
286            NostrRelayMessage::eose(subscription_id.clone()),
287        )
288        .await;
289    }
290}
291
292async fn run_upstream_nostr_subscription(
293    state: AppState,
294    client_id: u64,
295    relay_url: String,
296    subscription_id: SubscriptionId,
297    filters: Vec<NostrFilter>,
298    mut close_rx: watch::Receiver<bool>,
299) {
300    let mut relay_complete = false;
301    let Ok((socket, _)) = connect_async(relay_url.as_str()).await else {
302        tracing::warn!(
303            "upstream nostr relay connect failed: client_id={} subscription_id={} relay={}",
304            client_id,
305            subscription_id,
306            relay_url,
307        );
308        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
309        return;
310    };
311    let (mut write, mut read) = socket.split();
312    let request = NostrClientMessage::req(subscription_id.clone(), filters).as_json();
313    state
314        .ws_relay
315        .note_upstream_relay_send(request.as_bytes().len());
316    if write
317        .send(TungsteniteMessage::Text(request.into()))
318        .await
319        .is_err()
320    {
321        tracing::warn!(
322            "upstream nostr relay request send failed: client_id={} subscription_id={} relay={}",
323            client_id,
324            subscription_id,
325            relay_url,
326        );
327        mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
328        return;
329    }
330
331    loop {
332        tokio::select! {
333            _ = close_rx.changed() => {
334                if *close_rx.borrow() {
335                    let close = NostrClientMessage::close(subscription_id.clone()).as_json();
336                    state
337                        .ws_relay
338                        .note_upstream_relay_send(close.as_bytes().len());
339                    let _ = write.send(TungsteniteMessage::Text(close.into())).await;
340                    let _ = write.close().await;
341                    break;
342                }
343            }
344            message = read.next() => {
345                match message {
346                    Some(Ok(TungsteniteMessage::Text(text))) => {
347                        state
348                            .ws_relay
349                            .note_upstream_relay_receive(text.as_bytes().len());
350                        if matches!(
351                            NostrRelayMessage::from_json(text.as_str()),
352                            Ok(NostrRelayMessage::EndOfStoredEvents(sid)) if sid == subscription_id
353                        ) {
354                            if !relay_complete {
355                                relay_complete = true;
356                                mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
357                            }
358                            continue;
359                        }
360                        forward_upstream_nostr_message(&state, client_id, &subscription_id, &text).await;
361                    }
362                    Some(Ok(TungsteniteMessage::Ping(payload))) => {
363                        let _ = write.send(TungsteniteMessage::Pong(payload)).await;
364                    }
365                    Some(Ok(TungsteniteMessage::Close(_))) | None => {
366                        if !relay_complete {
367                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
368                        }
369                        break;
370                    }
371                    Some(Err(_)) => {
372                        if !relay_complete {
373                            mark_upstream_nostr_relay_complete(&state, client_id, &subscription_id).await;
374                        }
375                        break;
376                    }
377                    _ => {}
378                }
379            }
380        }
381    }
382}
383
384async fn start_upstream_nostr_subscription(
385    state: &AppState,
386    client_id: u64,
387    subscription_id: SubscriptionId,
388    filters: Vec<NostrFilter>,
389) -> usize {
390    close_upstream_nostr_subscription(state, client_id, &subscription_id).await;
391    if state.nostr_relay_urls.is_empty() || filters.is_empty() {
392        tracing::info!(
393            "upstream nostr relay skipped: client_id={} subscription_id={} relays={} filters={}",
394            client_id,
395            subscription_id,
396            state.nostr_relay_urls.len(),
397            filters.len(),
398        );
399        return 0;
400    }
401
402    let mut relay_urls = Vec::new();
403    let mut seen = HashSet::new();
404    for relay in &state.nostr_relay_urls {
405        let relay = relay.trim();
406        if relay.is_empty() || !seen.insert(relay.to_string()) {
407            continue;
408        }
409        relay_urls.push(relay.to_string());
410    }
411    if relay_urls.is_empty() {
412        tracing::info!(
413            "upstream nostr relay skipped after normalization: client_id={} subscription_id={}",
414            client_id,
415            subscription_id,
416        );
417        return 0;
418    }
419
420    tracing::info!(
421        "upstream nostr relay start: client_id={} subscription_id={} relays={}",
422        client_id,
423        subscription_id,
424        relay_urls.len(),
425    );
426
427    let key = (client_id, subscription_id.to_string());
428    state
429        .ws_relay
430        .upstream_seen_events
431        .lock()
432        .await
433        .insert(key.clone(), HashSet::new());
434    state
435        .ws_relay
436        .upstream_pending_eose
437        .lock()
438        .await
439        .insert(key.clone(), relay_urls.len());
440
441    let (close_tx, close_rx) = watch::channel(false);
442    let mut tasks = Vec::new();
443    for relay_url in &relay_urls {
444        tasks.push(tokio::spawn(run_upstream_nostr_subscription(
445            state.clone(),
446            client_id,
447            relay_url.clone(),
448            subscription_id.clone(),
449            filters.clone(),
450            close_rx.clone(),
451        )));
452    }
453
454    state
455        .ws_relay
456        .upstream_nostr_subscriptions
457        .lock()
458        .await
459        .insert(key, UpstreamNostrSubscription { close_tx, tasks });
460    relay_urls.len()
461}
462
463async fn handle_message(client_id: u64, msg: Message, state: &AppState) {
464    match msg {
465        Message::Text(text) => {
466            if let Some(msg) = parse_ws_text_message(&text) {
467                match msg {
468                    WsTextMessage::Hashtree(msg) => {
469                        set_client_protocol(state, client_id, WsProtocol::HashtreeJson).await;
470                        match msg {
471                            WsClientMessage::Request { id, hash } => {
472                                handle_request(
473                                    client_id,
474                                    id,
475                                    hash,
476                                    WsProtocol::HashtreeJson,
477                                    state,
478                                )
479                                .await;
480                            }
481                            WsClientMessage::Response { id, hash, found } => {
482                                handle_response(client_id, id, hash, found, state).await;
483                            }
484                        }
485                    }
486                    WsTextMessage::Nostr(msg) => {
487                        if let Some(relay) = &state.nostr_relay {
488                            match msg {
489                                NostrClientMessage::Req {
490                                    subscription_id,
491                                    filters,
492                                } => {
493                                    let local_events = match relay
494                                        .register_subscription_query(
495                                            client_id,
496                                            subscription_id.clone(),
497                                            filters.clone(),
498                                        )
499                                        .await
500                                    {
501                                        Ok(events) => events,
502                                        Err(message) => {
503                                            send_nostr(
504                                                state,
505                                                client_id,
506                                                NostrRelayMessage::closed(subscription_id, message),
507                                            )
508                                            .await;
509                                            return;
510                                        }
511                                    };
512
513                                    let upstream_relays = start_upstream_nostr_subscription(
514                                        state,
515                                        client_id,
516                                        subscription_id.clone(),
517                                        filters,
518                                    )
519                                    .await;
520                                    if upstream_relays > 0 {
521                                        let key = (client_id, subscription_id.to_string());
522                                        let mut seen_events =
523                                            state.ws_relay.upstream_seen_events.lock().await;
524                                        seen_events.entry(key).or_default().extend(
525                                            local_events.iter().map(|event| event.id.to_hex()),
526                                        );
527                                    }
528                                    for event in local_events {
529                                        send_nostr(
530                                            state,
531                                            client_id,
532                                            NostrRelayMessage::event(
533                                                subscription_id.clone(),
534                                                event,
535                                            ),
536                                        )
537                                        .await;
538                                    }
539                                    if upstream_relays == 0 {
540                                        send_nostr(
541                                            state,
542                                            client_id,
543                                            NostrRelayMessage::eose(subscription_id),
544                                        )
545                                        .await;
546                                    }
547                                }
548                                NostrClientMessage::Close(subscription_id) => {
549                                    close_upstream_nostr_subscription(
550                                        state,
551                                        client_id,
552                                        &subscription_id,
553                                    )
554                                    .await;
555                                    relay
556                                        .handle_client_message(
557                                            client_id,
558                                            NostrClientMessage::Close(subscription_id.clone()),
559                                        )
560                                        .await;
561                                }
562                                other => {
563                                    relay.handle_client_message(client_id, other).await;
564                                }
565                            }
566                        } else {
567                            handle_nostr_message(client_id, msg, state).await;
568                        }
569                    }
570                }
571            }
572        }
573        Message::Binary(data) => {
574            handle_binary(client_id, data, state).await;
575        }
576        Message::Close(_) => {}
577        _ => {}
578    }
579}
580
581async fn handle_request(
582    client_id: u64,
583    request_id: u32,
584    hash: String,
585    origin_protocol: WsProtocol,
586    state: &AppState,
587) {
588    let hash_hex = hash.to_lowercase();
589    let hash_bytes = match from_hex(&hash_hex) {
590        Ok(bytes) => bytes,
591        Err(_) => {
592            if origin_protocol == WsProtocol::HashtreeJson {
593                send_json(
594                    state,
595                    client_id,
596                    WsResponse {
597                        kind: "res",
598                        id: request_id,
599                        hash,
600                        found: false,
601                    },
602                )
603                .await;
604            }
605            return;
606        }
607    };
608
609    if let Ok(Some(data)) = state.store.get_blob(&hash_bytes) {
610        match origin_protocol {
611            WsProtocol::HashtreeJson => {
612                send_json(
613                    state,
614                    client_id,
615                    WsResponse {
616                        kind: "res",
617                        id: request_id,
618                        hash: hash.clone(),
619                        found: true,
620                    },
621                )
622                .await;
623                send_binary(state, client_id, request_id, data).await;
624            }
625            WsProtocol::HashtreeMsgpack => {
626                send_msgpack_response(state, client_id, &hash_bytes, &data).await;
627            }
628            WsProtocol::Unknown => {}
629        }
630        return;
631    }
632
633    let peers: Vec<(u64, mpsc::UnboundedSender<Message>, WsProtocol)> = {
634        let clients = state.ws_relay.clients.lock().await;
635        let protocols = state.ws_relay.client_protocols.lock().await;
636        clients
637            .iter()
638            .filter(|(id, _)| **id != client_id)
639            .filter_map(|(id, tx)| {
640                let protocol = protocols.get(id).copied().unwrap_or(WsProtocol::Unknown);
641                match protocol {
642                    WsProtocol::HashtreeJson | WsProtocol::HashtreeMsgpack => {
643                        Some((*id, tx.clone(), protocol))
644                    }
645                    WsProtocol::Unknown => None,
646                }
647            })
648            .collect()
649    };
650
651    if peers.is_empty() {
652        if origin_protocol == WsProtocol::HashtreeJson {
653            send_json(
654                state,
655                client_id,
656                WsResponse {
657                    kind: "res",
658                    id: request_id,
659                    hash,
660                    found: false,
661                },
662            )
663            .await;
664        }
665        return;
666    }
667
668    {
669        let mut pending = state.ws_relay.pending.lock().await;
670        for (peer_id, _, _) in &peers {
671            pending.insert(
672                (*peer_id, request_id),
673                PendingRequest {
674                    origin_id: client_id,
675                    hash: hash.clone(),
676                    found: false,
677                    origin_protocol,
678                },
679            );
680        }
681    }
682
683    let request_text = serde_json::to_string(&WsRequest {
684        kind: "req".to_string(),
685        id: request_id,
686        hash: hash.clone(),
687    })
688    .unwrap_or_else(|_| String::new());
689    for (peer_id, tx, protocol) in peers {
690        match protocol {
691            WsProtocol::HashtreeMsgpack => {
692                let _ = send_msgpack_request(state, peer_id, &hash_bytes).await;
693            }
694            WsProtocol::HashtreeJson => {
695                let _ = tx.send(Message::Text(request_text.clone()));
696            }
697            WsProtocol::Unknown => {}
698        }
699    }
700
701    let timeout_state = state.clone();
702    let timeout_hash = hash.clone();
703    tokio::spawn(async move {
704        tokio::time::sleep(Duration::from_millis(1500)).await;
705        let mut pending = timeout_state.ws_relay.pending.lock().await;
706        let still_pending = pending
707            .iter()
708            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id);
709        let already_found = pending
710            .iter()
711            .any(|((_, id), p)| *id == request_id && p.origin_id == client_id && p.found);
712        if !still_pending || already_found {
713            return;
714        }
715        let origin_protocol = pending
716            .iter()
717            .find(|((_, id), p)| *id == request_id && p.origin_id == client_id)
718            .map(|(_, p)| p.origin_protocol)
719            .unwrap_or(WsProtocol::HashtreeJson);
720        pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == client_id));
721        drop(pending);
722        if origin_protocol == WsProtocol::HashtreeJson {
723            send_json(
724                &timeout_state,
725                client_id,
726                WsResponse {
727                    kind: "res",
728                    id: request_id,
729                    hash: timeout_hash,
730                    found: false,
731                },
732            )
733            .await;
734        }
735    });
736}
737
738async fn handle_response(
739    client_id: u64,
740    request_id: u32,
741    _hash: String,
742    found: bool,
743    state: &AppState,
744) {
745    let pending_entry = {
746        let pending = state.ws_relay.pending.lock().await;
747        pending
748            .get(&(client_id, request_id))
749            .map(|p| (p.origin_id, p.hash.clone(), p.found, p.origin_protocol))
750    };
751
752    let Some((origin_id, pending_hash, already_found, origin_protocol)) = pending_entry else {
753        return;
754    };
755
756    if already_found && !found {
757        let mut pending = state.ws_relay.pending.lock().await;
758        pending.remove(&(client_id, request_id));
759        return;
760    }
761
762    if found {
763        let mut pending = state.ws_relay.pending.lock().await;
764        for ((_, id), p) in pending.iter_mut() {
765            if *id == request_id && p.origin_id == origin_id {
766                p.found = true;
767            }
768        }
769        drop(pending);
770        if origin_protocol == WsProtocol::HashtreeJson {
771            send_json(
772                state,
773                origin_id,
774                WsResponse {
775                    kind: "res",
776                    id: request_id,
777                    hash: pending_hash,
778                    found: true,
779                },
780            )
781            .await;
782        }
783        return;
784    }
785
786    let mut pending = state.ws_relay.pending.lock().await;
787    pending.remove(&(client_id, request_id));
788    let has_remaining = pending
789        .iter()
790        .any(|((_, id), p)| *id == request_id && p.origin_id == origin_id);
791    drop(pending);
792
793    if !has_remaining && origin_protocol == WsProtocol::HashtreeJson {
794        send_json(
795            state,
796            origin_id,
797            WsResponse {
798                kind: "res",
799                id: request_id,
800                hash: pending_hash,
801                found: false,
802            },
803        )
804        .await;
805    }
806}
807
808async fn handle_binary(client_id: u64, data: Vec<u8>, state: &AppState) {
809    if let Some(msg) = parse_msgpack_message(&data) {
810        set_client_protocol(state, client_id, WsProtocol::HashtreeMsgpack).await;
811        match msg {
812            DataMessage::Request(req) => {
813                let hash_hex = hex_encode(&req.h);
814                let request_id = state.ws_relay.next_request_id();
815                handle_request(
816                    client_id,
817                    request_id,
818                    hash_hex,
819                    WsProtocol::HashtreeMsgpack,
820                    state,
821                )
822                .await;
823            }
824            DataMessage::Response(res) => {
825                handle_msgpack_response(client_id, res, state).await;
826            }
827            DataMessage::QuoteRequest(_)
828            | DataMessage::QuoteResponse(_)
829            | DataMessage::Payment(_)
830            | DataMessage::PaymentAck(_)
831            | DataMessage::Chunk(_)
832            | DataMessage::PeerHints(_) => {}
833        }
834        return;
835    }
836
837    // Legacy binary: [4-byte LE request_id][data]
838    if data.len() < 4 {
839        return;
840    }
841    let request_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
842    let pending_entry = {
843        let pending = state.ws_relay.pending.lock().await;
844        pending
845            .get(&(client_id, request_id))
846            .map(|p| (p.origin_id, p.hash.clone(), p.origin_protocol))
847    };
848    let Some((origin_id, hash_hex, origin_protocol)) = pending_entry else {
849        return;
850    };
851
852    match origin_protocol {
853        WsProtocol::HashtreeJson => {
854            send_binary(state, origin_id, request_id, data[4..].to_vec()).await;
855        }
856        WsProtocol::HashtreeMsgpack => {
857            let Ok(hash_bytes) = from_hex(&hash_hex) else {
858                return;
859            };
860            send_msgpack_response(state, origin_id, &hash_bytes, &data[4..]).await;
861        }
862        WsProtocol::Unknown => {}
863    }
864
865    let mut pending = state.ws_relay.pending.lock().await;
866    pending.retain(|(_, id), p| !(*id == request_id && p.origin_id == origin_id));
867}
868
869async fn handle_nostr_message(client_id: u64, msg: NostrClientMessage, state: &AppState) {
870    let replies = nostr_responses_for(&msg);
871    for reply in replies {
872        send_nostr(state, client_id, reply).await;
873    }
874}
875
876fn nostr_responses_for(msg: &NostrClientMessage) -> Vec<NostrRelayMessage> {
877    match msg {
878        NostrClientMessage::Event(event) => {
879            let ok = event.verify().is_ok();
880            let message = if ok { "" } else { "invalid: signature" };
881            vec![NostrRelayMessage::ok(event.id, ok, message)]
882        }
883        NostrClientMessage::Req {
884            subscription_id, ..
885        } => {
886            vec![NostrRelayMessage::eose(subscription_id.clone())]
887        }
888        NostrClientMessage::Count {
889            subscription_id, ..
890        } => {
891            vec![NostrRelayMessage::count(subscription_id.clone(), 0)]
892        }
893        NostrClientMessage::Close(_) => Vec::new(),
894        NostrClientMessage::Auth(event) => {
895            let ok = event.verify().is_ok();
896            let message = if ok { "" } else { "invalid auth" };
897            vec![NostrRelayMessage::ok(event.id, ok, message)]
898        }
899        NostrClientMessage::NegOpen { .. }
900        | NostrClientMessage::NegMsg { .. }
901        | NostrClientMessage::NegClose { .. } => {
902            vec![NostrRelayMessage::notice("negentropy not supported")]
903        }
904    }
905}
906
907async fn send_nostr(state: &AppState, client_id: u64, response: NostrRelayMessage) {
908    let text = response.as_json();
909    send_to_client(state, client_id, Message::Text(text)).await;
910}
911
912fn parse_msgpack_message(data: &[u8]) -> Option<DataMessage> {
913    let msg = parse_message(data).ok()?;
914    match msg {
915        DataMessage::Request(req) => {
916            if req.h.len() == 32 {
917                Some(DataMessage::Request(req))
918            } else {
919                None
920            }
921        }
922        DataMessage::Response(res) => {
923            if res.h.len() == 32 {
924                Some(DataMessage::Response(res))
925            } else {
926                None
927            }
928        }
929        DataMessage::QuoteRequest(req) => {
930            if req.h.len() == 32 {
931                Some(DataMessage::QuoteRequest(req))
932            } else {
933                None
934            }
935        }
936        DataMessage::QuoteResponse(res) => {
937            if res.h.len() == 32 {
938                Some(DataMessage::QuoteResponse(res))
939            } else {
940                None
941            }
942        }
943        DataMessage::Payment(req) => {
944            if req.h.len() == 32 {
945                Some(DataMessage::Payment(req))
946            } else {
947                None
948            }
949        }
950        DataMessage::PaymentAck(res) => {
951            if res.h.len() == 32 {
952                Some(DataMessage::PaymentAck(res))
953            } else {
954                None
955            }
956        }
957        DataMessage::Chunk(chunk) => {
958            if chunk.h.len() == 32 {
959                Some(DataMessage::Chunk(chunk))
960            } else {
961                None
962            }
963        }
964        DataMessage::PeerHints(_) => None,
965    }
966}
967
968async fn handle_msgpack_response(client_id: u64, res: DataResponse, state: &AppState) {
969    let hash_hex = hex_encode(&res.h);
970    let data = res.d.clone();
971    let hash_bytes = res.h.clone();
972
973    let mut responses: Vec<(u64, u32, WsProtocol)> = Vec::new();
974    let mut seen = HashSet::new();
975    {
976        let pending = state.ws_relay.pending.lock().await;
977        for ((peer_id, request_id), p) in pending.iter() {
978            if *peer_id != client_id {
979                continue;
980            }
981            if p.hash != hash_hex {
982                continue;
983            }
984            if seen.insert((p.origin_id, *request_id)) {
985                responses.push((p.origin_id, *request_id, p.origin_protocol));
986            }
987        }
988    }
989
990    if responses.is_empty() {
991        return;
992    }
993
994    for (origin_id, request_id, protocol) in &responses {
995        match protocol {
996            WsProtocol::HashtreeJson => {
997                send_json(
998                    state,
999                    *origin_id,
1000                    WsResponse {
1001                        kind: "res",
1002                        id: *request_id,
1003                        hash: hash_hex.clone(),
1004                        found: true,
1005                    },
1006                )
1007                .await;
1008                send_binary(state, *origin_id, *request_id, data.clone()).await;
1009            }
1010            WsProtocol::HashtreeMsgpack => {
1011                send_msgpack_response(state, *origin_id, &hash_bytes, &data).await;
1012            }
1013            WsProtocol::Unknown => {}
1014        }
1015    }
1016
1017    let completed: HashSet<(u64, u32)> = responses
1018        .into_iter()
1019        .map(|(origin_id, request_id, _)| (origin_id, request_id))
1020        .collect();
1021    let mut pending = state.ws_relay.pending.lock().await;
1022    pending.retain(|(_, id), p| !completed.contains(&(p.origin_id, *id)));
1023}
1024
1025async fn send_json(state: &AppState, client_id: u64, response: WsResponse) {
1026    if let Ok(text) = serde_json::to_string(&response) {
1027        send_to_client(state, client_id, Message::Text(text)).await;
1028    }
1029}
1030
1031async fn send_msgpack_request(
1032    state: &AppState,
1033    client_id: u64,
1034    hash: &[u8],
1035) -> Result<(), rmp_serde::encode::Error> {
1036    let req = DataRequest {
1037        h: hash.to_vec(),
1038        htl: MAX_HTL,
1039        q: None,
1040    };
1041    let wire = encode_request(&req)?;
1042    send_to_client(state, client_id, Message::Binary(wire)).await;
1043    Ok(())
1044}
1045
1046async fn send_msgpack_response(state: &AppState, client_id: u64, hash: &[u8], data: &[u8]) {
1047    let res = DataResponse {
1048        h: hash.to_vec(),
1049        d: data.to_vec(),
1050        i: None,
1051        n: None,
1052    };
1053    if let Ok(wire) = encode_response(&res) {
1054        send_to_client(state, client_id, Message::Binary(wire)).await;
1055    }
1056}
1057
1058async fn send_binary(state: &AppState, client_id: u64, request_id: u32, payload: Vec<u8>) {
1059    let mut packet = Vec::with_capacity(4 + payload.len());
1060    packet.extend_from_slice(&request_id.to_le_bytes());
1061    packet.extend_from_slice(&payload);
1062    send_to_client(state, client_id, Message::Binary(packet)).await;
1063}
1064
1065async fn send_to_client(state: &AppState, client_id: u64, msg: Message) {
1066    let sender = {
1067        let clients = state.ws_relay.clients.lock().await;
1068        clients.get(&client_id).cloned()
1069    };
1070    if let Some(tx) = sender {
1071        let _ = tx.send(msg);
1072    }
1073}
1074
1075async fn set_client_protocol(state: &AppState, client_id: u64, protocol: WsProtocol) {
1076    let mut protocols = state.ws_relay.client_protocols.lock().await;
1077    protocols.insert(client_id, protocol);
1078}
1079
#[cfg(test)]
mod tests {
    //! The unit tests cover text-frame classification (`parse_ws_text_message`)
    //! and the fixed replies of `nostr_responses_for`.
    //!
    //! The async integration tests use a real `NostrRelay` and cover:
    //! * upstream-relay proxying,
    //! * publish acks,
    //! * spambox rate limits.
    //!
    //! Some of them also run a local axum WebSocket server and connect to it
    //! with `tokio_tungstenite`.
    use super::*;
    use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
    use anyhow::Result;
    use futures::{SinkExt, StreamExt};
    use nostr::secp256k1::schnorr::Signature;
    use nostr::{EventBuilder, Filter, Keys, Kind, SubscriptionId};
    use std::collections::HashSet;
    use std::sync::Arc;
    use tempfile::TempDir;
    use tokio::net::TcpListener;
    use tokio_tungstenite::{accept_async, tungstenite::Message as TungsteniteMessage};

    /// A JSON array starting with `"REQ"` is classified as a Nostr message.
    #[test]
    fn parse_ws_text_message_detects_nostr_req() {
        let msg = r#"["REQ","sub-1",{"kinds":[1]}]"#;
        match parse_ws_text_message(msg) {
            Some(WsTextMessage::Nostr(_)) => {}
            other => panic!("expected Nostr message, got {:?}", other),
        }
    }

    /// A JSON object with `"type":"req"` is classified as a hashtree message.
    #[test]
    fn parse_ws_text_message_detects_hashtree_request() {
        let msg = r#"{"type":"req","id":1,"hash":"abcd"}"#;
        match parse_ws_text_message(msg) {
            Some(WsTextMessage::Hashtree(_)) => {}
            other => panic!("expected Hashtree message, got {:?}", other),
        }
    }

    /// A `REQ` yields exactly one `EOSE` carrying the same subscription id.
    #[test]
    fn nostr_replies_for_req_is_eose() {
        let sub = SubscriptionId::new("sub-1");
        let msg = NostrClientMessage::req(sub.clone(), vec![]);
        let replies = nostr_responses_for(&msg);
        assert_eq!(replies.len(), 1);
        match &replies[0] {
            NostrRelayMessage::EndOfStoredEvents(id) => assert_eq!(id, &sub),
            other => panic!("expected EOSE, got {:?}", other),
        }
    }

    /// A correctly signed `EVENT` yields one `OK` with status `true` and the
    /// event's id.
    #[test]
    fn nostr_replies_for_event_ok() {
        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::TextNote, "hello", [])
            .to_event(&keys)
            .unwrap();
        let msg = NostrClientMessage::event(event.clone());
        let replies = nostr_responses_for(&msg);
        assert_eq!(replies.len(), 1);
        match &replies[0] {
            NostrRelayMessage::Ok {
                event_id, status, ..
            } => {
                assert_eq!(event_id, &event.id);
                assert!(*status);
            }
            other => panic!("expected OK, got {:?}", other),
        }
    }

    /// An `EVENT` whose signature was replaced with 64 zero bytes yields an
    /// `OK` with status `false`.
    #[test]
    fn nostr_replies_for_invalid_event_is_not_ok() {
        let keys = Keys::generate();
        let mut event = EventBuilder::new(Kind::TextNote, "hello", [])
            .to_event(&keys)
            .unwrap();
        event.sig = Signature::from_slice(&[0u8; 64]).unwrap();
        let msg = NostrClientMessage::event(event);
        let replies = nostr_responses_for(&msg);
        assert_eq!(replies.len(), 1);
        match &replies[0] {
            NostrRelayMessage::Ok { status, .. } => assert!(!*status),
            other => panic!("expected OK=false, got {:?}", other),
        }
    }

    /// Starts a mock upstream Nostr relay on an ephemeral `127.0.0.1` port and
    /// returns its `ws://` URL.
    ///
    /// The spawned task accepts a single connection. For each `REQ` it sends
    /// every event in `events` that matches any of the request's filters as an
    /// `EVENT`, then sends `EOSE`. Non-text frames and non-`REQ` messages are
    /// ignored. Send errors are discarded.
    async fn spawn_mock_upstream_relay(events: Vec<nostr::Event>) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind relay");
        let addr = listener.local_addr().expect("relay addr");
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept relay");
            let ws = accept_async(stream).await.expect("accept websocket");
            let (mut write, mut read) = ws.split();

            while let Some(Ok(message)) = read.next().await {
                let TungsteniteMessage::Text(text) = message else {
                    continue;
                };
                let Ok(parsed) = NostrClientMessage::from_json(text.as_bytes()) else {
                    continue;
                };
                if let NostrClientMessage::Req {
                    subscription_id,
                    filters,
                } = parsed
                {
                    for event in events
                        .iter()
                        .filter(|event| filters.iter().any(|filter| filter.match_event(event)))
                    {
                        let _ = write
                            .send(TungsteniteMessage::Text(
                                NostrRelayMessage::event(subscription_id.clone(), event.clone())
                                    .as_json()
                                    .into(),
                            ))
                            .await;
                    }
                    let _ = write
                        .send(TungsteniteMessage::Text(
                            NostrRelayMessage::eose(subscription_id).as_json().into(),
                        ))
                        .await;
                }
            }
        });
        format!("ws://{}", addr)
    }

    /// Builds an `AppState` for tests.
    ///
    /// * Uses a `HashtreeStore` rooted in `tmp`, whose third `with_options`
    ///   argument is 128 MiB.
    /// * Disables auth, enables public writes, and starts with empty caches.
    /// * Uses `relay` as the local Nostr relay and `relay_url` as the only
    ///   entry in `nostr_relay_urls`.
    fn test_app_state(
        tmp: &TempDir,
        relay: Arc<NostrRelay>,
        relay_url: String,
    ) -> Result<AppState> {
        let store = Arc::new(crate::storage::HashtreeStore::with_options(
            tmp.path(),
            None,
            128 * 1024 * 1024,
        )?);
        Ok(AppState {
            store,
            auth: None,
            peer_mode: crate::config::ServerMode::Normal,
            hash_get_enabled: true,
            webrtc_peers: None,
            ws_relay: Arc::new(super::super::auth::WsRelayState::new()),
            max_upload_bytes: 5 * 1024 * 1024,
            public_writes: true,
            allowed_pubkeys: HashSet::new(),
            upstream_blossom: Vec::new(),
            social_graph: None,
            social_graph_store: None,
            social_graph_root: None,
            socialgraph_snapshot_public: false,
            nostr_relay: Some(relay),
            nostr_relay_urls: vec![relay_url],
            tree_root_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
            inflight_blob_fetches: Arc::new(tokio::sync::Mutex::new(
                std::collections::HashMap::new(),
            )),
            directory_listing_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            resolved_path_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            thumbnail_path_cache: Arc::new(std::sync::Mutex::new(
                super::super::auth::new_lookup_cache(),
            )),
            cid_size_cache: Arc::new(std::sync::Mutex::new(super::super::auth::new_lookup_cache())),
        })
    }

    /// Checks `start_upstream_nostr_subscription` and its cleanup:
    /// * the upstream `EVENT` and then the `EOSE` are forwarded to the client's
    ///   channel under the client's subscription id;
    /// * the event is stored in the local relay;
    /// * `close_upstream_nostr_subscription` leaves no tracked upstream
    ///   subscriptions.
    #[tokio::test]
    async fn upstream_proxy_forwards_events_and_caches_them() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            // The social-graph test lock is held only while opening the store.
            // Presumably this serializes store creation across parallel tests.
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let keys = Keys::generate();
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let event = EventBuilder::new(
            Kind::from(30078_u16),
            "",
            [
                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
            ],
        )
        .to_event(&keys)?;

        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
        let filter = Filter::new()
            .authors(vec![event.pubkey])
            .kinds(vec![event.kind]);
        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
        let client_id = 7_u64;
        // Register a fake client whose outbound frames land in `rx`.
        let (tx, mut rx) = mpsc::unbounded_channel();
        state.ws_relay.clients.lock().await.insert(client_id, tx);
        let subscription_id = SubscriptionId::new("sub-1");

        start_upstream_nostr_subscription(
            &state,
            client_id,
            subscription_id.clone(),
            vec![filter.clone()],
        )
        .await;

        let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
            .await?
            .expect("forwarded upstream event");
        let Message::Text(text) = forwarded else {
            panic!("expected text event");
        };
        match NostrRelayMessage::from_json(text.as_str())? {
            NostrRelayMessage::Event {
                subscription_id: sid,
                event: forwarded_event,
            } => {
                assert_eq!(sid, subscription_id);
                assert_eq!(forwarded_event.id, event.id);
            }
            other => panic!("expected forwarded EVENT, got {:?}", other),
        }

        let eose = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
            .await?
            .expect("forwarded upstream eose");
        let Message::Text(eose_text) = eose else {
            panic!("expected text eose");
        };
        match NostrRelayMessage::from_json(eose_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(sid) => {
                assert_eq!(sid, subscription_id);
            }
            other => panic!("expected forwarded EOSE, got {:?}", other),
        }

        // The forwarded event must also have been stored in the local relay.
        let events = relay.query_events(&filter, 10).await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, event.id);

        close_upstream_nostr_subscription(&state, client_id, &subscription_id).await;
        assert!(state
            .ws_relay
            .upstream_nostr_subscriptions
            .lock()
            .await
            .is_empty());
        Ok(())
    }

    /// A client `REQ` routed through `handle_message` must deliver the upstream
    /// relay's matching `EVENT` before the `EOSE` for that subscription.
    #[tokio::test]
    async fn req_waits_for_upstream_event_before_eose() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let keys = Keys::generate();
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let event = EventBuilder::new(
            Kind::from(30078_u16),
            "",
            [
                nostr::Tag::parse(&["d", "videos/Test"]).expect("d tag"),
                nostr::Tag::parse(&["l", "hashtree"]).expect("label tag"),
            ],
        )
        .to_event(&keys)?;

        let relay_url = spawn_mock_upstream_relay(vec![event.clone()]).await;
        let state = test_app_state(&tmp, relay.clone(), relay_url)?;
        let client_id = 11_u64;
        let (ws_tx, mut ws_rx) = mpsc::unbounded_channel();
        // The local relay also gets a channel for this client. Its receiver is
        // kept alive but never read; assertions only look at `ws_rx`.
        let (relay_tx, _relay_rx) = mpsc::unbounded_channel();
        state.ws_relay.clients.lock().await.insert(client_id, ws_tx);
        relay.register_client(client_id, relay_tx, None).await;

        let request = NostrClientMessage::req(
            SubscriptionId::new("feed"),
            vec![Filter::new()
                .authors(vec![event.pubkey])
                .kinds(vec![event.kind])],
        )
        .as_json();

        handle_message(client_id, Message::Text(request.into()), &state).await;

        let first = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
            .await?
            .expect("first forwarded message");
        let Message::Text(first_text) = first else {
            panic!("expected text event");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::Event {
                event: forwarded_event,
                ..
            } => {
                assert_eq!(forwarded_event.id, event.id);
            }
            other => panic!("expected upstream EVENT before EOSE, got {:?}", other),
        }

        let second = tokio::time::timeout(std::time::Duration::from_secs(2), ws_rx.recv())
            .await?
            .expect("second forwarded message");
        let Message::Text(second_text) = second else {
            panic!("expected text eose");
        };
        match NostrRelayMessage::from_json(second_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(sid) => {
                assert_eq!(sid, SubscriptionId::new("feed"));
            }
            other => panic!("expected aggregated EOSE, got {:?}", other),
        }

        Ok(())
    }

    /// End-to-end over a real WebSocket: an event is published by an author
    /// whose pubkey is passed both to the access control and to the relay's
    /// key set, and the connection carries that pubkey as the client pubkey.
    /// The event must be acked with `OK true` and then be found in the relay.
    #[tokio::test]
    async fn websocket_publish_returns_ok_for_trusted_event() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let author_keys = Keys::generate();
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::from([author_keys.public_key().to_hex()]),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::from([author_keys.public_key().to_hex()]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay.clone(), String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let client_pubkey = author_keys.public_key().to_hex();
        // Every upgrade on `/ws` is attributed to the author's pubkey.
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                let client_pubkey = client_pubkey.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    let client_pubkey = client_pubkey.clone();
                    async move { ws_data_with_client_pubkey(state, ws, Some(client_pubkey)) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        let event = EventBuilder::new(Kind::TextNote, "websocket publish ack", [])
            .to_event(&author_keys)?;
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event.clone()).as_json().into(),
            ))
            .await?;

        let reply = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before publish ack"))??;
        let TungsteniteMessage::Text(text) = reply else {
            anyhow::bail!("expected text publish ack");
        };

        match NostrRelayMessage::from_json(text.as_str())? {
            NostrRelayMessage::Ok {
                event_id, status, ..
            } => {
                assert_eq!(event_id, event.id);
                assert!(status);
            }
            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
        }

        let stored = relay
            .query_events(
                &Filter::new()
                    .authors(vec![event.pubkey])
                    .kinds(vec![event.kind]),
                10,
            )
            .await;
        assert!(stored.iter().any(|candidate| candidate.id == event.id));
        Ok(())
    }

    /// Setup: `spambox_max_reqs_per_min: 1`, empty key sets, and a connection
    /// with no client pubkey.
    ///
    /// The first `REQ` must receive `EOSE`. The second must be answered with
    /// exactly `["CLOSED","sub-2","rate limited"]`.
    #[tokio::test]
    async fn websocket_req_is_rate_limited_after_configured_quota() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::new(),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                spambox_max_reqs_per_min: 1,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay, String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    async move { ws_data_with_client_pubkey(state, ws, None) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::req(SubscriptionId::new("sub-1"), vec![Filter::new()])
                    .as_json()
                    .into(),
            ))
            .await?;

        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before first relay reply"))??;
        let TungsteniteMessage::Text(first_text) = first else {
            anyhow::bail!("expected text EOSE reply");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::EndOfStoredEvents(subscription_id) => {
                assert_eq!(subscription_id, SubscriptionId::new("sub-1"));
            }
            other => anyhow::bail!("expected EOSE for first request, got {:?}", other),
        }

        // The second REQ exceeds the configured quota of one per minute.
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::req(SubscriptionId::new("sub-2"), vec![Filter::new()])
                    .as_json()
                    .into(),
            ))
            .await?;

        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit reply"))??;
        let TungsteniteMessage::Text(second_text) = second else {
            anyhow::bail!("expected text CLOSED reply");
        };
        // Compared as raw JSON so the exact wire shape of CLOSED is pinned.
        let second_value: serde_json::Value = serde_json::from_str(second_text.as_str())?;
        assert_eq!(
            second_value,
            serde_json::json!(["CLOSED", "sub-2", "rate limited"])
        );

        Ok(())
    }

    /// Setup: `spambox_max_events_per_min: 1` and empty key sets.
    ///
    /// The first event from a freshly generated author must be acked with
    /// `OK true "spambox"`. The second must be rejected with
    /// `OK false "rate limited"`.
    #[tokio::test]
    async fn websocket_publish_is_rate_limited_for_untrusted_spambox_events() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();
        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));
        let relay = Arc::new(NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            HashSet::new(),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                spambox_max_events_per_min: 1,
                ..Default::default()
            },
        )?);

        let state = test_app_state(&tmp, relay, String::new())?;
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let app = axum::Router::new().route(
            "/ws",
            axum::routing::get({
                let state = state.clone();
                move |ws: WebSocketUpgrade| {
                    let state = state.clone();
                    async move { ws_data_with_client_pubkey(state, ws, None) }
                }
            }),
        );
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let (mut socket, _) = connect_async(format!("ws://{addr}/ws")).await?;
        let author_keys = Keys::generate();
        let event_a = EventBuilder::new(Kind::TextNote, "spambox-a", []).to_event(&author_keys)?;
        let event_b = EventBuilder::new(Kind::TextNote, "spambox-b", []).to_event(&author_keys)?;

        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event_a.clone()).as_json().into(),
            ))
            .await?;

        let first = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before first publish ack"))??;
        let TungsteniteMessage::Text(first_text) = first else {
            anyhow::bail!("expected text publish ack");
        };
        match NostrRelayMessage::from_json(first_text.as_str())? {
            NostrRelayMessage::Ok {
                event_id,
                status,
                message,
            } => {
                assert_eq!(event_id, event_a.id);
                assert!(status);
                assert_eq!(message, "spambox");
            }
            other => anyhow::bail!("expected OK publish ack, got {:?}", other),
        }

        // The second event exceeds the configured quota of one per minute.
        socket
            .send(TungsteniteMessage::Text(
                NostrClientMessage::event(event_b.clone()).as_json().into(),
            ))
            .await?;

        let second = tokio::time::timeout(std::time::Duration::from_secs(2), socket.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before rate-limit publish ack"))??;
        let TungsteniteMessage::Text(second_text) = second else {
            anyhow::bail!("expected text publish ack");
        };
        match NostrRelayMessage::from_json(second_text.as_str())? {
            NostrRelayMessage::Ok {
                event_id,
                status,
                message,
            } => {
                assert_eq!(event_id, event_b.id);
                assert!(!status);
                assert_eq!(message, "rate limited");
            }
            other => anyhow::bail!("expected OK=false publish ack, got {:?}", other),
        }

        Ok(())
    }
}