1use std::collections::{BTreeMap, HashMap};
2use std::sync::mpsc as std_mpsc;
3use std::sync::{Arc, Mutex, MutexGuard};
4use std::thread;
5use std::time::Duration as StdDuration;
6
7use anyhow::{anyhow, Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use serde_json::{json, Value};
10use tokio::net::TcpListener;
11use tokio::sync::mpsc;
12use tokio_tungstenite::accept_async;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Default)]
16struct RelayState {
17 events_by_id: BTreeMap<String, Value>,
18 subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19 clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
22fn lock_relay_state(state: &Arc<Mutex<RelayState>>) -> MutexGuard<'_, RelayState> {
23 state.lock().unwrap_or_else(|poison| poison.into_inner())
24}
25
26enum RelayControl {
27 ReplayStored,
28 Snapshot(std_mpsc::Sender<Vec<Value>>),
29 Shutdown,
30}
31
32pub struct TestRelay {
33 control_tx: mpsc::UnboundedSender<RelayControl>,
34 join: Option<thread::JoinHandle<()>>,
35 url: String,
36}
37
38impl TestRelay {
39 pub fn start() -> Self {
40 match Self::start_with_bind("127.0.0.1:0") {
41 Ok(relay) => relay,
42 Err(error) => {
43 eprintln!("failed to start local relay: {error}");
44 let (control_tx, _) = mpsc::unbounded_channel();
45 Self {
46 control_tx,
47 join: None,
48 url: String::new(),
49 }
50 }
51 }
52 }
53
54 pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
55 let (control_tx, mut control_rx) = mpsc::unbounded_channel();
56 let (ready_tx, ready_rx) = std_mpsc::channel();
57 let bind_addr = bind_addr.to_string();
58
59 let join = thread::spawn(move || {
60 let runtime = match tokio::runtime::Builder::new_multi_thread()
61 .enable_all()
62 .build()
63 {
64 Ok(runtime) => runtime,
65 Err(error) => {
66 let _ = ready_tx.send(Err(anyhow!("relay runtime: {error}")));
67 return;
68 }
69 };
70
71 runtime.block_on(async move {
72 let listener = match TcpListener::bind(&bind_addr)
73 .await
74 .with_context(|| format!("bind relay listener {bind_addr}"))
75 {
76 Ok(listener) => listener,
77 Err(error) => {
78 let _ = ready_tx.send(Err(error));
79 return;
80 }
81 };
82 let local_addr = match listener.local_addr() {
83 Ok(addr) => addr,
84 Err(error) => {
85 let _ = ready_tx.send(Err(anyhow!("relay local addr: {error}")));
86 return;
87 }
88 };
89 let state = Arc::new(Mutex::new(RelayState::default()));
90 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
91 let _ = ready_tx.send(Ok(format!("ws://{local_addr}")));
92
93 loop {
94 tokio::select! {
95 Some(control) = control_rx.recv() => {
96 match control {
97 RelayControl::ReplayStored => replay_stored_events(&state),
98 RelayControl::Snapshot(reply_tx) => {
99 let events = lock_relay_state(&state)
100 .events_by_id
101 .values()
102 .cloned()
103 .collect::<Vec<_>>();
104 let _ = reply_tx.send(events);
105 }
106 RelayControl::Shutdown => break,
107 }
108 }
109 accept_result = listener.accept() => {
110 let Ok((stream, _)) = accept_result else {
111 break;
112 };
113 let websocket = match accept_async(stream).await {
114 Ok(websocket) => websocket,
115 Err(error) => {
116 eprintln!("Ignoring failed test relay websocket handshake: {error}");
117 continue;
118 }
119 };
120 let state = state.clone();
121 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
122 tokio::spawn(async move {
123 handle_connection(client_id, websocket, state).await;
124 });
125 }
126 }
127 }
128 });
129 });
130
131 let url = ready_rx
132 .recv_timeout(StdDuration::from_secs(5))
133 .context("relay ready")??;
134
135 Ok(Self {
136 control_tx,
137 join: Some(join),
138 url,
139 })
140 }
141
142 pub fn url(&self) -> &str {
143 &self.url
144 }
145
146 pub fn replay_stored(&self) {
147 let _ = self.control_tx.send(RelayControl::ReplayStored);
148 }
149
150 pub fn events(&self) -> Vec<Value> {
151 let (reply_tx, reply_rx) = std_mpsc::channel();
152 let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
153 reply_rx
154 .recv_timeout(StdDuration::from_secs(5))
155 .unwrap_or_default()
156 }
157}
158
159impl Drop for TestRelay {
160 fn drop(&mut self) {
161 let _ = self.control_tx.send(RelayControl::Shutdown);
162 if let Some(join) = self.join.take() {
163 let _ = join.join();
164 }
165 }
166}
167
168pub fn run_forever(bind_addr: &str) -> Result<()> {
169 let runtime = tokio::runtime::Builder::new_multi_thread()
170 .enable_all()
171 .build()
172 .context("relay runtime")?;
173 let bind_addr = bind_addr.to_string();
174
175 runtime.block_on(async move {
176 let listener = TcpListener::bind(&bind_addr)
177 .await
178 .with_context(|| format!("bind relay listener {bind_addr}"))?;
179 let state = Arc::new(Mutex::new(RelayState::default()));
180 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
181
182 println!("Local Nostr relay listening on ws://{bind_addr}");
183
184 loop {
185 let (stream, _) = listener
186 .accept()
187 .await
188 .with_context(|| format!("accept relay client on {bind_addr}"))?;
189 let websocket = match accept_async(stream).await {
190 Ok(websocket) => websocket,
191 Err(error) => {
192 eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
193 continue;
194 }
195 };
196 let state = state.clone();
197 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
198 tokio::spawn(async move {
199 handle_connection(client_id, websocket, state).await;
200 });
201 }
202 })
203}
204
205async fn handle_connection(
206 client_id: usize,
207 websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
208 state: Arc<Mutex<RelayState>>,
209) {
210 let (mut sink, mut stream) = websocket.split();
211 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
212
213 {
214 let mut relay = lock_relay_state(&state);
215 relay.clients.insert(client_id, client_tx);
216 }
217
218 let writer = tokio::spawn(async move {
219 while let Some(message) = client_rx.recv().await {
220 if sink.send(message).await.is_err() {
221 break;
222 }
223 }
224 });
225
226 while let Some(message) = stream.next().await {
227 let Ok(message) = message else {
228 break;
229 };
230 match message {
231 Message::Text(text) => handle_client_message(client_id, &text, &state),
232 Message::Ping(payload) => {
233 let sender = {
234 let relay = lock_relay_state(&state);
235 relay.clients.get(&client_id).cloned()
236 };
237 if let Some(sender) = sender {
238 let _ = sender.send(Message::Pong(payload));
239 }
240 }
241 Message::Close(_) => break,
242 _ => {}
243 }
244 }
245
246 {
247 let mut relay = lock_relay_state(&state);
248 relay.clients.remove(&client_id);
249 relay.subscriptions.remove(&client_id);
250 }
251
252 writer.abort();
253}
254
255fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
256 let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
257 return;
258 };
259 let Some(parts) = message.as_array() else {
260 return;
261 };
262 let Some(kind) = parts.first().and_then(Value::as_str) else {
263 return;
264 };
265
266 match kind {
267 "REQ" if parts.len() >= 2 => {
268 let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
269 return;
270 };
271 let filters: Vec<Value> = parts
272 .iter()
273 .skip(2)
274 .filter(|value| value.is_object())
275 .cloned()
276 .collect();
277 let (sender, events) = {
278 let mut relay = lock_relay_state(state);
279 relay
280 .subscriptions
281 .entry(client_id)
282 .or_default()
283 .insert(subscription_id.to_string(), filters.clone());
284 (
285 relay.clients.get(&client_id).cloned(),
286 relay.events_by_id.values().cloned().collect::<Vec<_>>(),
287 )
288 };
289
290 if let Some(sender) = sender {
291 for event in events {
292 if matches_any_filter(&event, &filters) {
293 let payload =
294 Message::Text(json!(["EVENT", subscription_id, event]).to_string());
295 let _ = sender.send(payload);
296 }
297 }
298 let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
299 }
300 }
301 "CLOSE" if parts.len() >= 2 => {
302 let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
303 return;
304 };
305 let mut relay = lock_relay_state(state);
306 if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
307 subscriptions.remove(subscription_id);
308 }
309 }
310 "EVENT" if parts.get(1).is_some_and(Value::is_object) => {
311 let Some(event) = parts.get(1).cloned() else {
312 return;
313 };
314 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
315 return;
316 };
317 let (sender, deliveries) = {
318 let mut relay = lock_relay_state(state);
319 relay
320 .events_by_id
321 .insert(event_id.to_string(), event.clone());
322 let sender = relay.clients.get(&client_id).cloned();
323 let deliveries = matching_deliveries(&relay, &event);
324 (sender, deliveries)
325 };
326 if let Some(sender) = sender {
327 let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
328 }
329
330 for (target, payload) in deliveries {
331 let _ = target.send(payload);
332 }
333 }
334 _ => {}
335 }
336}
337
338fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
339 let deliveries = {
340 let relay = lock_relay_state(state);
341 relay
342 .events_by_id
343 .values()
344 .flat_map(|event| matching_deliveries(&relay, event))
345 .collect::<Vec<_>>()
346 };
347
348 for (target, payload) in deliveries {
349 let _ = target.send(payload);
350 }
351}
352
353fn matching_deliveries(
354 relay: &RelayState,
355 event: &Value,
356) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
357 let mut deliveries = Vec::new();
358 for (client_id, subscriptions) in &relay.subscriptions {
359 let Some(target) = relay.clients.get(client_id).cloned() else {
360 continue;
361 };
362 for (subscription_id, filters) in subscriptions {
363 if matches_any_filter(event, filters) {
364 deliveries.push((
365 target.clone(),
366 Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
367 ));
368 }
369 }
370 }
371 deliveries
372}
373
374pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
375 if filters.is_empty() {
376 return true;
377 }
378
379 filters.iter().any(|filter| matches_filter(event, filter))
380}
381
382pub fn matches_filter(event: &Value, filter: &Value) -> bool {
383 let Some(filter_object) = filter.as_object() else {
384 return false;
385 };
386
387 if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
388 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
389 return false;
390 };
391 if !ids
392 .iter()
393 .filter_map(Value::as_str)
394 .any(|id| id == event_id)
395 {
396 return false;
397 }
398 }
399
400 if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
401 let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
402 return false;
403 };
404 if !authors
405 .iter()
406 .filter_map(Value::as_str)
407 .any(|author| author == pubkey)
408 {
409 return false;
410 }
411 }
412
413 if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
414 let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
415 return false;
416 };
417 if !kinds
418 .iter()
419 .filter_map(Value::as_u64)
420 .any(|value| value == kind)
421 {
422 return false;
423 }
424 }
425
426 if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
427 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
428 return false;
429 };
430 if created_at < since {
431 return false;
432 }
433 }
434
435 if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
436 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
437 return false;
438 };
439 if created_at > until {
440 return false;
441 }
442 }
443
444 for (key, value) in filter_object {
445 let Some(tag_name) = key.strip_prefix('#') else {
446 continue;
447 };
448
449 let Some(expected_values) = value.as_array() else {
450 return false;
451 };
452 if expected_values.is_empty() {
453 continue;
454 }
455
456 let Some(tags) = event.get("tags").and_then(Value::as_array) else {
457 return false;
458 };
459 let matched = tags.iter().any(|tag| {
460 let Some(tag_values) = tag.as_array() else {
461 return false;
462 };
463 if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
464 return false;
465 }
466 tag_values
467 .iter()
468 .skip(1)
469 .filter_map(Value::as_str)
470 .any(|tag_value| {
471 expected_values
472 .iter()
473 .filter_map(Value::as_str)
474 .any(|expected| expected == tag_value)
475 })
476 });
477 if !matched {
478 return false;
479 }
480 }
481
482 true
483}