1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::path::{Path, PathBuf};
3use std::sync::mpsc as std_mpsc;
4use std::sync::{Arc, Mutex, MutexGuard};
5use std::thread;
6use std::time::Duration as StdDuration;
7
8use anyhow::{anyhow, Context, Result};
9use futures_util::{SinkExt, StreamExt};
10use serde_json::{json, Value};
11use tokio::net::TcpListener;
12use tokio::sync::mpsc;
13use tokio_tungstenite::accept_async;
14use tokio_tungstenite::tungstenite::Message;
15
16#[derive(Default)]
17struct RelayState {
18 events_by_id: BTreeMap<String, Value>,
19 subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
20 clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
21 faults: RelayFaults,
22 dropped_event_ids: HashSet<String>,
23}
24
/// Fault-injection settings for the relay.
#[derive(Clone, Default)]
struct RelayFaults {
    /// File listing the event ids to drop; `None` disables dropping.
    drop_event_ids_file: Option<PathBuf>,
    /// When `true`, a listed event id is dropped only the first time it is
    /// published; when `false`, it is dropped on every publish.
    drop_matching_events_once: bool,
}
30
31impl RelayState {
32 fn from_env() -> Self {
33 Self {
34 faults: RelayFaults::from_env(),
35 ..Self::default()
36 }
37 }
38
39 fn should_drop_event(&mut self, event_id: &str) -> bool {
40 let Some(path) = self.faults.drop_event_ids_file.as_ref() else {
41 return false;
42 };
43 if self.faults.drop_matching_events_once && self.dropped_event_ids.contains(event_id) {
44 return false;
45 }
46 if !drop_event_ids(path).contains(event_id) {
47 return false;
48 }
49 self.dropped_event_ids.insert(event_id.to_string());
50 true
51 }
52}
53
54impl RelayFaults {
55 fn from_env() -> Self {
56 let drop_event_ids_file = std::env::var_os("IRIS_LOCAL_RELAY_DROP_EVENT_IDS_FILE")
57 .filter(|value| !value.is_empty())
58 .map(PathBuf::from);
59 let drop_matching_events_once = !env_flag("IRIS_LOCAL_RELAY_DROP_EVENT_IDS_ALWAYS");
60 Self {
61 drop_event_ids_file,
62 drop_matching_events_once,
63 }
64 }
65}
66
/// Returns `true` when the environment variable `name` holds a truthy value.
///
/// Truthy values are `1`, `true`, `yes` and `on`, compared case-insensitively
/// after trimming surrounding whitespace. An unset or unreadable variable
/// counts as `false`.
fn env_flag(name: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];

    let Ok(raw) = std::env::var(name) else {
        return false;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    TRUTHY.contains(&normalized.as_str())
}
77
/// Loads the set of event ids listed in the drop file at `path`.
///
/// One id per line. Everything from the first `#` on a line is a comment,
/// surrounding whitespace is ignored and blank lines are skipped. A file that
/// cannot be read yields an empty set.
fn drop_event_ids(path: &Path) -> HashSet<String> {
    let contents = std::fs::read_to_string(path).unwrap_or_default();

    let mut ids = HashSet::new();
    for line in contents.lines() {
        let before_comment = match line.split_once('#') {
            Some((head, _comment)) => head,
            None => line,
        };
        let id = before_comment.trim();
        if !id.is_empty() {
            ids.insert(id.to_owned());
        }
    }
    ids
}
89
90fn lock_relay_state(state: &Arc<Mutex<RelayState>>) -> MutexGuard<'_, RelayState> {
91 state.lock().unwrap_or_else(|poison| poison.into_inner())
92}
93
94enum RelayControl {
95 ReplayStored,
96 Snapshot(std_mpsc::Sender<Vec<Value>>),
97 Shutdown,
98}
99
100pub struct TestRelay {
101 control_tx: mpsc::UnboundedSender<RelayControl>,
102 join: Option<thread::JoinHandle<()>>,
103 url: String,
104}
105
106impl TestRelay {
107 pub fn start() -> Self {
108 match Self::start_with_bind("127.0.0.1:0") {
109 Ok(relay) => relay,
110 Err(error) => {
111 eprintln!("failed to start local relay: {error}");
112 let (control_tx, _) = mpsc::unbounded_channel();
113 Self {
114 control_tx,
115 join: None,
116 url: String::new(),
117 }
118 }
119 }
120 }
121
122 pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
123 let (control_tx, mut control_rx) = mpsc::unbounded_channel();
124 let (ready_tx, ready_rx) = std_mpsc::channel();
125 let bind_addr = bind_addr.to_string();
126
127 let join = thread::spawn(move || {
128 let runtime = match tokio::runtime::Builder::new_multi_thread()
129 .enable_all()
130 .build()
131 {
132 Ok(runtime) => runtime,
133 Err(error) => {
134 let _ = ready_tx.send(Err(anyhow!("relay runtime: {error}")));
135 return;
136 }
137 };
138
139 runtime.block_on(async move {
140 let listener = match TcpListener::bind(&bind_addr)
141 .await
142 .with_context(|| format!("bind relay listener {bind_addr}"))
143 {
144 Ok(listener) => listener,
145 Err(error) => {
146 let _ = ready_tx.send(Err(error));
147 return;
148 }
149 };
150 let local_addr = match listener.local_addr() {
151 Ok(addr) => addr,
152 Err(error) => {
153 let _ = ready_tx.send(Err(anyhow!("relay local addr: {error}")));
154 return;
155 }
156 };
157 let state = Arc::new(Mutex::new(RelayState::default()));
158 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
159 let _ = ready_tx.send(Ok(format!("ws://{local_addr}")));
160
161 loop {
162 tokio::select! {
163 Some(control) = control_rx.recv() => {
164 match control {
165 RelayControl::ReplayStored => replay_stored_events(&state),
166 RelayControl::Snapshot(reply_tx) => {
167 let events = lock_relay_state(&state)
168 .events_by_id
169 .values()
170 .cloned()
171 .collect::<Vec<_>>();
172 let _ = reply_tx.send(events);
173 }
174 RelayControl::Shutdown => break,
175 }
176 }
177 accept_result = listener.accept() => {
178 let Ok((stream, _)) = accept_result else {
179 break;
180 };
181 let websocket = match accept_async(stream).await {
182 Ok(websocket) => websocket,
183 Err(error) => {
184 eprintln!("Ignoring failed test relay websocket handshake: {error}");
185 continue;
186 }
187 };
188 let state = state.clone();
189 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
190 tokio::spawn(async move {
191 handle_connection(client_id, websocket, state).await;
192 });
193 }
194 }
195 }
196 });
197 });
198
199 let url = ready_rx
200 .recv_timeout(StdDuration::from_secs(5))
201 .context("relay ready")??;
202
203 Ok(Self {
204 control_tx,
205 join: Some(join),
206 url,
207 })
208 }
209
210 pub fn url(&self) -> &str {
211 &self.url
212 }
213
214 pub fn replay_stored(&self) {
215 let _ = self.control_tx.send(RelayControl::ReplayStored);
216 }
217
218 pub fn events(&self) -> Vec<Value> {
219 let (reply_tx, reply_rx) = std_mpsc::channel();
220 let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
221 reply_rx
222 .recv_timeout(StdDuration::from_secs(5))
223 .unwrap_or_default()
224 }
225}
226
227impl Drop for TestRelay {
228 fn drop(&mut self) {
229 let _ = self.control_tx.send(RelayControl::Shutdown);
230 if let Some(join) = self.join.take() {
231 let _ = join.join();
232 }
233 }
234}
235
236pub fn run_forever(bind_addr: &str) -> Result<()> {
237 let runtime = tokio::runtime::Builder::new_multi_thread()
238 .enable_all()
239 .build()
240 .context("relay runtime")?;
241 let bind_addr = bind_addr.to_string();
242
243 runtime.block_on(async move {
244 let listener = TcpListener::bind(&bind_addr)
245 .await
246 .with_context(|| format!("bind relay listener {bind_addr}"))?;
247 let state = Arc::new(Mutex::new(RelayState::from_env()));
248 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
249
250 println!("Local Nostr relay listening on ws://{bind_addr}");
251
252 loop {
253 let (stream, _) = listener
254 .accept()
255 .await
256 .with_context(|| format!("accept relay client on {bind_addr}"))?;
257 let websocket = match accept_async(stream).await {
258 Ok(websocket) => websocket,
259 Err(error) => {
260 eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
261 continue;
262 }
263 };
264 let state = state.clone();
265 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
266 tokio::spawn(async move {
267 handle_connection(client_id, websocket, state).await;
268 });
269 }
270 })
271}
272
273async fn handle_connection(
274 client_id: usize,
275 websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
276 state: Arc<Mutex<RelayState>>,
277) {
278 let (mut sink, mut stream) = websocket.split();
279 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
280
281 {
282 let mut relay = lock_relay_state(&state);
283 relay.clients.insert(client_id, client_tx);
284 }
285
286 let writer = tokio::spawn(async move {
287 while let Some(message) = client_rx.recv().await {
288 if sink.send(message).await.is_err() {
289 break;
290 }
291 }
292 });
293
294 while let Some(message) = stream.next().await {
295 let Ok(message) = message else {
296 break;
297 };
298 match message {
299 Message::Text(text) => handle_client_message(client_id, &text, &state),
300 Message::Ping(payload) => {
301 let sender = {
302 let relay = lock_relay_state(&state);
303 relay.clients.get(&client_id).cloned()
304 };
305 if let Some(sender) = sender {
306 let _ = sender.send(Message::Pong(payload));
307 }
308 }
309 Message::Close(_) => break,
310 _ => {}
311 }
312 }
313
314 {
315 let mut relay = lock_relay_state(&state);
316 relay.clients.remove(&client_id);
317 relay.subscriptions.remove(&client_id);
318 }
319
320 writer.abort();
321}
322
323fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
324 let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
325 return;
326 };
327 let Some(parts) = message.as_array() else {
328 return;
329 };
330 let Some(kind) = parts.first().and_then(Value::as_str) else {
331 return;
332 };
333
334 match kind {
335 "REQ" if parts.len() >= 2 => {
336 let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
337 return;
338 };
339 let filters: Vec<Value> = parts
340 .iter()
341 .skip(2)
342 .filter(|value| value.is_object())
343 .cloned()
344 .collect();
345 let (sender, events) = {
346 let mut relay = lock_relay_state(state);
347 relay
348 .subscriptions
349 .entry(client_id)
350 .or_default()
351 .insert(subscription_id.to_string(), filters.clone());
352 (
353 relay.clients.get(&client_id).cloned(),
354 relay.events_by_id.values().cloned().collect::<Vec<_>>(),
355 )
356 };
357
358 if let Some(sender) = sender {
359 for event in events {
360 if matches_any_filter(&event, &filters) {
361 let payload =
362 Message::Text(json!(["EVENT", subscription_id, event]).to_string());
363 let _ = sender.send(payload);
364 }
365 }
366 let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
367 }
368 }
369 "CLOSE" if parts.len() >= 2 => {
370 let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
371 return;
372 };
373 let mut relay = lock_relay_state(state);
374 if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
375 subscriptions.remove(subscription_id);
376 }
377 }
378 "EVENT" if parts.get(1).is_some_and(Value::is_object) => {
379 let Some(event) = parts.get(1).cloned() else {
380 return;
381 };
382 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
383 return;
384 };
385 let event_id = event_id.to_string();
386 let (sender, deliveries, dropped) = {
387 let mut relay = lock_relay_state(state);
388 let sender = relay.clients.get(&client_id).cloned();
389 if relay.should_drop_event(&event_id) {
390 (sender, Vec::new(), true)
391 } else {
392 relay.events_by_id.insert(event_id.clone(), event.clone());
393 let deliveries = matching_deliveries(&relay, &event);
394 (sender, deliveries, false)
395 }
396 };
397 if dropped {
398 eprintln!("Local relay fault dropped event_id={event_id}");
399 }
400 if let Some(sender) = sender {
401 let message = if dropped {
402 "fault: dropped by local relay"
403 } else {
404 ""
405 };
406 let _ = sender.send(Message::Text(
407 json!(["OK", event_id, true, message]).to_string(),
408 ));
409 }
410 if dropped {
411 return;
412 }
413
414 for (target, payload) in deliveries {
415 let _ = target.send(payload);
416 }
417 }
418 _ => {}
419 }
420}
421
422fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
423 let deliveries = {
424 let relay = lock_relay_state(state);
425 relay
426 .events_by_id
427 .values()
428 .flat_map(|event| matching_deliveries(&relay, event))
429 .collect::<Vec<_>>()
430 };
431
432 for (target, payload) in deliveries {
433 let _ = target.send(payload);
434 }
435}
436
437fn matching_deliveries(
438 relay: &RelayState,
439 event: &Value,
440) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
441 let mut deliveries = Vec::new();
442 for (client_id, subscriptions) in &relay.subscriptions {
443 let Some(target) = relay.clients.get(client_id).cloned() else {
444 continue;
445 };
446 for (subscription_id, filters) in subscriptions {
447 if matches_any_filter(event, filters) {
448 deliveries.push((
449 target.clone(),
450 Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
451 ));
452 }
453 }
454 }
455 deliveries
456}
457
458pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
459 if filters.is_empty() {
460 return true;
461 }
462
463 filters.iter().any(|filter| matches_filter(event, filter))
464}
465
466pub fn matches_filter(event: &Value, filter: &Value) -> bool {
467 let Some(filter_object) = filter.as_object() else {
468 return false;
469 };
470
471 if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
472 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
473 return false;
474 };
475 if !ids
476 .iter()
477 .filter_map(Value::as_str)
478 .any(|id| id == event_id)
479 {
480 return false;
481 }
482 }
483
484 if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
485 let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
486 return false;
487 };
488 if !authors
489 .iter()
490 .filter_map(Value::as_str)
491 .any(|author| author == pubkey)
492 {
493 return false;
494 }
495 }
496
497 if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
498 let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
499 return false;
500 };
501 if !kinds
502 .iter()
503 .filter_map(Value::as_u64)
504 .any(|value| value == kind)
505 {
506 return false;
507 }
508 }
509
510 if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
511 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
512 return false;
513 };
514 if created_at < since {
515 return false;
516 }
517 }
518
519 if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
520 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
521 return false;
522 };
523 if created_at > until {
524 return false;
525 }
526 }
527
528 for (key, value) in filter_object {
529 let Some(tag_name) = key.strip_prefix('#') else {
530 continue;
531 };
532
533 let Some(expected_values) = value.as_array() else {
534 return false;
535 };
536 if expected_values.is_empty() {
537 continue;
538 }
539
540 let Some(tags) = event.get("tags").and_then(Value::as_array) else {
541 return false;
542 };
543 let matched = tags.iter().any(|tag| {
544 let Some(tag_values) = tag.as_array() else {
545 return false;
546 };
547 if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
548 return false;
549 }
550 tag_values
551 .iter()
552 .skip(1)
553 .filter_map(Value::as_str)
554 .any(|tag_value| {
555 expected_values
556 .iter()
557 .filter_map(Value::as_str)
558 .any(|expected| expected == tag_value)
559 })
560 });
561 if !matched {
562 return false;
563 }
564 }
565
566 true
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates a temporary drop file containing `contents` plus a trailing
    /// newline.
    fn drop_file(contents: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().expect("temp drop file");
        writeln!(file, "{contents}").expect("write drop file");
        file
    }

    /// Builds a relay state whose fault injection reads ids from `path`.
    fn state_dropping_from(path: &Path, once: bool) -> RelayState {
        RelayState {
            faults: RelayFaults {
                drop_event_ids_file: Some(path.to_path_buf()),
                drop_matching_events_once: once,
            },
            ..RelayState::default()
        }
    }

    #[test]
    fn drop_event_ids_file_ignores_comments_and_blank_lines() {
        let file = drop_file("\n# comment\nabc\n def # inline comment\n");

        let ids = drop_event_ids(file.path());

        assert!(ids.contains("abc"));
        assert!(ids.contains("def"));
        assert!(!ids.contains("# comment"));
    }

    #[test]
    fn relay_fault_drops_matching_event_once_by_default() {
        let file = drop_file("event-to-drop");
        let mut state = state_dropping_from(file.path(), true);

        // Dropped the first time, let through afterwards.
        assert!(state.should_drop_event("event-to-drop"));
        assert!(!state.should_drop_event("event-to-drop"));
        // Ids not listed in the file are never dropped.
        assert!(!state.should_drop_event("different-event"));
    }

    #[test]
    fn relay_fault_can_drop_matching_event_every_time() {
        let file = drop_file("event-to-drop");
        let mut state = state_dropping_from(file.path(), false);

        assert!(state.should_drop_event("event-to-drop"));
        assert!(state.should_drop_event("event-to-drop"));
    }
}