Skip to main content

iris_chat_core/
local_relay.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::path::{Path, PathBuf};
3use std::sync::mpsc as std_mpsc;
4use std::sync::{Arc, Mutex, MutexGuard};
5use std::thread;
6use std::time::Duration as StdDuration;
7
8use anyhow::{anyhow, Context, Result};
9use futures_util::{SinkExt, StreamExt};
10use serde_json::{json, Value};
11use tokio::net::TcpListener;
12use tokio::sync::mpsc;
13use tokio_tungstenite::accept_async;
14use tokio_tungstenite::tungstenite::Message;
15
16#[derive(Default)]
17struct RelayState {
18    events_by_id: BTreeMap<String, Value>,
19    subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
20    clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
21    faults: RelayFaults,
22    dropped_event_ids: HashSet<String>,
23}
24
/// Fault-injection knobs for the local relay.
#[derive(Default, Clone)]
struct RelayFaults {
    /// `true`: a listed event id is dropped only the first time it is published.
    /// `false`: it is dropped on every publication.
    drop_matching_events_once: bool,
    /// File listing event ids to drop, one per line (`#` starts a comment).
    /// `None` disables dropping entirely.
    drop_event_ids_file: Option<PathBuf>,
}
30
31impl RelayState {
32    fn from_env() -> Self {
33        Self {
34            faults: RelayFaults::from_env(),
35            ..Self::default()
36        }
37    }
38
39    fn should_drop_event(&mut self, event_id: &str) -> bool {
40        let Some(path) = self.faults.drop_event_ids_file.as_ref() else {
41            return false;
42        };
43        if self.faults.drop_matching_events_once && self.dropped_event_ids.contains(event_id) {
44            return false;
45        }
46        if !drop_event_ids(path).contains(event_id) {
47            return false;
48        }
49        self.dropped_event_ids.insert(event_id.to_string());
50        true
51    }
52}
53
54impl RelayFaults {
55    fn from_env() -> Self {
56        let drop_event_ids_file = std::env::var_os("IRIS_LOCAL_RELAY_DROP_EVENT_IDS_FILE")
57            .filter(|value| !value.is_empty())
58            .map(PathBuf::from);
59        let drop_matching_events_once = !env_flag("IRIS_LOCAL_RELAY_DROP_EVENT_IDS_ALWAYS");
60        Self {
61            drop_event_ids_file,
62            drop_matching_events_once,
63        }
64    }
65}
66
/// Returns `true` when the environment variable `name` holds a truthy value.
///
/// Truthy values are `1`, `true`, `yes` and `on`. Case and surrounding
/// whitespace are ignored. An unset or non-UTF-8 variable counts as `false`.
fn env_flag(name: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    let Ok(raw) = std::env::var(name) else {
        return false;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    TRUTHY.contains(&normalized.as_str())
}
77
/// Loads the set of event ids listed in the drop-list file at `path`.
///
/// Each line contributes at most one id. Everything from the first `#` on is
/// a comment, surrounding whitespace is trimmed, and blank results are skipped.
/// An unreadable or missing file yields an empty set.
fn drop_event_ids(path: &Path) -> HashSet<String> {
    let contents = match std::fs::read_to_string(path) {
        Ok(contents) => contents,
        Err(_) => return HashSet::new(),
    };
    let mut ids = HashSet::new();
    for line in contents.lines() {
        // `find` returns the byte offset of an ASCII char, so the slice is on a char boundary.
        let without_comment = match line.find('#') {
            Some(index) => &line[..index],
            None => line,
        };
        let id = without_comment.trim();
        if !id.is_empty() {
            ids.insert(id.to_owned());
        }
    }
    ids
}
89
90fn lock_relay_state(state: &Arc<Mutex<RelayState>>) -> MutexGuard<'_, RelayState> {
91    state.lock().unwrap_or_else(|poison| poison.into_inner())
92}
93
94enum RelayControl {
95    ReplayStored,
96    Snapshot(std_mpsc::Sender<Vec<Value>>),
97    Shutdown,
98}
99
100pub struct TestRelay {
101    control_tx: mpsc::UnboundedSender<RelayControl>,
102    join: Option<thread::JoinHandle<()>>,
103    url: String,
104}
105
106impl TestRelay {
107    pub fn start() -> Self {
108        match Self::start_with_bind("127.0.0.1:0") {
109            Ok(relay) => relay,
110            Err(error) => {
111                eprintln!("failed to start local relay: {error}");
112                let (control_tx, _) = mpsc::unbounded_channel();
113                Self {
114                    control_tx,
115                    join: None,
116                    url: String::new(),
117                }
118            }
119        }
120    }
121
122    pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
123        let (control_tx, mut control_rx) = mpsc::unbounded_channel();
124        let (ready_tx, ready_rx) = std_mpsc::channel();
125        let bind_addr = bind_addr.to_string();
126
127        let join = thread::spawn(move || {
128            let runtime = match tokio::runtime::Builder::new_multi_thread()
129                .enable_all()
130                .build()
131            {
132                Ok(runtime) => runtime,
133                Err(error) => {
134                    let _ = ready_tx.send(Err(anyhow!("relay runtime: {error}")));
135                    return;
136                }
137            };
138
139            runtime.block_on(async move {
140                let listener = match TcpListener::bind(&bind_addr)
141                    .await
142                    .with_context(|| format!("bind relay listener {bind_addr}"))
143                {
144                    Ok(listener) => listener,
145                    Err(error) => {
146                        let _ = ready_tx.send(Err(error));
147                        return;
148                    }
149                };
150                let local_addr = match listener.local_addr() {
151                    Ok(addr) => addr,
152                    Err(error) => {
153                        let _ = ready_tx.send(Err(anyhow!("relay local addr: {error}")));
154                        return;
155                    }
156                };
157                let state = Arc::new(Mutex::new(RelayState::default()));
158                let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
159                let _ = ready_tx.send(Ok(format!("ws://{local_addr}")));
160
161                loop {
162                    tokio::select! {
163                        Some(control) = control_rx.recv() => {
164                            match control {
165                                RelayControl::ReplayStored => replay_stored_events(&state),
166                                RelayControl::Snapshot(reply_tx) => {
167                                    let events = lock_relay_state(&state)
168                                        .events_by_id
169                                        .values()
170                                        .cloned()
171                                        .collect::<Vec<_>>();
172                                    let _ = reply_tx.send(events);
173                                }
174                                RelayControl::Shutdown => break,
175                            }
176                        }
177                        accept_result = listener.accept() => {
178                            let Ok((stream, _)) = accept_result else {
179                                break;
180                            };
181                            let websocket = match accept_async(stream).await {
182                                Ok(websocket) => websocket,
183                                Err(error) => {
184                                    eprintln!("Ignoring failed test relay websocket handshake: {error}");
185                                    continue;
186                                }
187                            };
188                            let state = state.clone();
189                            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
190                            tokio::spawn(async move {
191                                handle_connection(client_id, websocket, state).await;
192                            });
193                        }
194                    }
195                }
196            });
197        });
198
199        let url = ready_rx
200            .recv_timeout(StdDuration::from_secs(5))
201            .context("relay ready")??;
202
203        Ok(Self {
204            control_tx,
205            join: Some(join),
206            url,
207        })
208    }
209
210    pub fn url(&self) -> &str {
211        &self.url
212    }
213
214    pub fn replay_stored(&self) {
215        let _ = self.control_tx.send(RelayControl::ReplayStored);
216    }
217
218    pub fn events(&self) -> Vec<Value> {
219        let (reply_tx, reply_rx) = std_mpsc::channel();
220        let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
221        reply_rx
222            .recv_timeout(StdDuration::from_secs(5))
223            .unwrap_or_default()
224    }
225}
226
227impl Drop for TestRelay {
228    fn drop(&mut self) {
229        let _ = self.control_tx.send(RelayControl::Shutdown);
230        if let Some(join) = self.join.take() {
231            let _ = join.join();
232        }
233    }
234}
235
236pub fn run_forever(bind_addr: &str) -> Result<()> {
237    let runtime = tokio::runtime::Builder::new_multi_thread()
238        .enable_all()
239        .build()
240        .context("relay runtime")?;
241    let bind_addr = bind_addr.to_string();
242
243    runtime.block_on(async move {
244        let listener = TcpListener::bind(&bind_addr)
245            .await
246            .with_context(|| format!("bind relay listener {bind_addr}"))?;
247        let state = Arc::new(Mutex::new(RelayState::from_env()));
248        let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
249
250        println!("Local Nostr relay listening on ws://{bind_addr}");
251
252        loop {
253            let (stream, _) = listener
254                .accept()
255                .await
256                .with_context(|| format!("accept relay client on {bind_addr}"))?;
257            let websocket = match accept_async(stream).await {
258                Ok(websocket) => websocket,
259                Err(error) => {
260                    eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
261                    continue;
262                }
263            };
264            let state = state.clone();
265            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
266            tokio::spawn(async move {
267                handle_connection(client_id, websocket, state).await;
268            });
269        }
270    })
271}
272
273async fn handle_connection(
274    client_id: usize,
275    websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
276    state: Arc<Mutex<RelayState>>,
277) {
278    let (mut sink, mut stream) = websocket.split();
279    let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
280
281    {
282        let mut relay = lock_relay_state(&state);
283        relay.clients.insert(client_id, client_tx);
284    }
285
286    let writer = tokio::spawn(async move {
287        while let Some(message) = client_rx.recv().await {
288            if sink.send(message).await.is_err() {
289                break;
290            }
291        }
292    });
293
294    while let Some(message) = stream.next().await {
295        let Ok(message) = message else {
296            break;
297        };
298        match message {
299            Message::Text(text) => handle_client_message(client_id, &text, &state),
300            Message::Ping(payload) => {
301                let sender = {
302                    let relay = lock_relay_state(&state);
303                    relay.clients.get(&client_id).cloned()
304                };
305                if let Some(sender) = sender {
306                    let _ = sender.send(Message::Pong(payload));
307                }
308            }
309            Message::Close(_) => break,
310            _ => {}
311        }
312    }
313
314    {
315        let mut relay = lock_relay_state(&state);
316        relay.clients.remove(&client_id);
317        relay.subscriptions.remove(&client_id);
318    }
319
320    writer.abort();
321}
322
323fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
324    let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
325        return;
326    };
327    let Some(parts) = message.as_array() else {
328        return;
329    };
330    let Some(kind) = parts.first().and_then(Value::as_str) else {
331        return;
332    };
333
334    match kind {
335        "REQ" if parts.len() >= 2 => {
336            let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
337                return;
338            };
339            let filters: Vec<Value> = parts
340                .iter()
341                .skip(2)
342                .filter(|value| value.is_object())
343                .cloned()
344                .collect();
345            let (sender, events) = {
346                let mut relay = lock_relay_state(state);
347                relay
348                    .subscriptions
349                    .entry(client_id)
350                    .or_default()
351                    .insert(subscription_id.to_string(), filters.clone());
352                (
353                    relay.clients.get(&client_id).cloned(),
354                    relay.events_by_id.values().cloned().collect::<Vec<_>>(),
355                )
356            };
357
358            if let Some(sender) = sender {
359                for event in events {
360                    if matches_any_filter(&event, &filters) {
361                        let payload =
362                            Message::Text(json!(["EVENT", subscription_id, event]).to_string());
363                        let _ = sender.send(payload);
364                    }
365                }
366                let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
367            }
368        }
369        "CLOSE" if parts.len() >= 2 => {
370            let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
371                return;
372            };
373            let mut relay = lock_relay_state(state);
374            if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
375                subscriptions.remove(subscription_id);
376            }
377        }
378        "EVENT" if parts.get(1).is_some_and(Value::is_object) => {
379            let Some(event) = parts.get(1).cloned() else {
380                return;
381            };
382            let Some(event_id) = event.get("id").and_then(Value::as_str) else {
383                return;
384            };
385            let event_id = event_id.to_string();
386            let (sender, deliveries, dropped) = {
387                let mut relay = lock_relay_state(state);
388                let sender = relay.clients.get(&client_id).cloned();
389                if relay.should_drop_event(&event_id) {
390                    (sender, Vec::new(), true)
391                } else {
392                    relay.events_by_id.insert(event_id.clone(), event.clone());
393                    let deliveries = matching_deliveries(&relay, &event);
394                    (sender, deliveries, false)
395                }
396            };
397            if dropped {
398                eprintln!("Local relay fault dropped event_id={event_id}");
399            }
400            if let Some(sender) = sender {
401                let message = if dropped {
402                    "fault: dropped by local relay"
403                } else {
404                    ""
405                };
406                let _ = sender.send(Message::Text(
407                    json!(["OK", event_id, true, message]).to_string(),
408                ));
409            }
410            if dropped {
411                return;
412            }
413
414            for (target, payload) in deliveries {
415                let _ = target.send(payload);
416            }
417        }
418        _ => {}
419    }
420}
421
422fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
423    let deliveries = {
424        let relay = lock_relay_state(state);
425        relay
426            .events_by_id
427            .values()
428            .flat_map(|event| matching_deliveries(&relay, event))
429            .collect::<Vec<_>>()
430    };
431
432    for (target, payload) in deliveries {
433        let _ = target.send(payload);
434    }
435}
436
437fn matching_deliveries(
438    relay: &RelayState,
439    event: &Value,
440) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
441    let mut deliveries = Vec::new();
442    for (client_id, subscriptions) in &relay.subscriptions {
443        let Some(target) = relay.clients.get(client_id).cloned() else {
444            continue;
445        };
446        for (subscription_id, filters) in subscriptions {
447            if matches_any_filter(event, filters) {
448                deliveries.push((
449                    target.clone(),
450                    Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
451                ));
452            }
453        }
454    }
455    deliveries
456}
457
458pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
459    if filters.is_empty() {
460        return true;
461    }
462
463    filters.iter().any(|filter| matches_filter(event, filter))
464}
465
466pub fn matches_filter(event: &Value, filter: &Value) -> bool {
467    let Some(filter_object) = filter.as_object() else {
468        return false;
469    };
470
471    if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
472        let Some(event_id) = event.get("id").and_then(Value::as_str) else {
473            return false;
474        };
475        if !ids
476            .iter()
477            .filter_map(Value::as_str)
478            .any(|id| id == event_id)
479        {
480            return false;
481        }
482    }
483
484    if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
485        let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
486            return false;
487        };
488        if !authors
489            .iter()
490            .filter_map(Value::as_str)
491            .any(|author| author == pubkey)
492        {
493            return false;
494        }
495    }
496
497    if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
498        let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
499            return false;
500        };
501        if !kinds
502            .iter()
503            .filter_map(Value::as_u64)
504            .any(|value| value == kind)
505        {
506            return false;
507        }
508    }
509
510    if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
511        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
512            return false;
513        };
514        if created_at < since {
515            return false;
516        }
517    }
518
519    if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
520        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
521            return false;
522        };
523        if created_at > until {
524            return false;
525        }
526    }
527
528    for (key, value) in filter_object {
529        let Some(tag_name) = key.strip_prefix('#') else {
530            continue;
531        };
532
533        let Some(expected_values) = value.as_array() else {
534            return false;
535        };
536        if expected_values.is_empty() {
537            continue;
538        }
539
540        let Some(tags) = event.get("tags").and_then(Value::as_array) else {
541            return false;
542        };
543        let matched = tags.iter().any(|tag| {
544            let Some(tag_values) = tag.as_array() else {
545                return false;
546            };
547            if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
548                return false;
549            }
550            tag_values
551                .iter()
552                .skip(1)
553                .filter_map(Value::as_str)
554                .any(|tag_value| {
555                    expected_values
556                        .iter()
557                        .filter_map(Value::as_str)
558                        .any(|expected| expected == tag_value)
559                })
560        });
561        if !matched {
562            return false;
563        }
564    }
565
566    true
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    #[test]
    fn drop_event_ids_file_ignores_comments_and_blank_lines() {
        let mut file = tempfile::NamedTempFile::new().expect("temp drop file");
        writeln!(file, "\n# comment\nabc\n  def  # inline comment\n").expect("write drop file");

        // `drop_event_ids` takes `&Path`, so there is no need to allocate a `PathBuf`.
        let ids = drop_event_ids(file.path());

        assert!(ids.contains("abc"));
        assert!(ids.contains("def"));
        assert!(!ids.contains("# comment"));
        assert_eq!(ids.len(), 2);
    }

    #[test]
    fn drop_event_ids_missing_file_yields_empty_set() {
        let dir = tempfile::tempdir().expect("temp dir");
        assert!(drop_event_ids(&dir.path().join("missing.txt")).is_empty());
    }

    #[test]
    fn relay_fault_drops_matching_event_once_by_default() {
        let mut file = tempfile::NamedTempFile::new().expect("temp drop file");
        writeln!(file, "event-to-drop").expect("write drop file");
        let mut state = RelayState {
            faults: RelayFaults {
                drop_event_ids_file: Some(file.path().to_path_buf()),
                drop_matching_events_once: true,
            },
            ..RelayState::default()
        };

        assert!(state.should_drop_event("event-to-drop"));
        assert!(!state.should_drop_event("event-to-drop"));
        assert!(!state.should_drop_event("different-event"));
    }

    #[test]
    fn relay_fault_can_drop_matching_event_every_time() {
        let mut file = tempfile::NamedTempFile::new().expect("temp drop file");
        writeln!(file, "event-to-drop").expect("write drop file");
        let mut state = RelayState {
            faults: RelayFaults {
                drop_event_ids_file: Some(file.path().to_path_buf()),
                drop_matching_events_once: false,
            },
            ..RelayState::default()
        };

        assert!(state.should_drop_event("event-to-drop"));
        assert!(state.should_drop_event("event-to-drop"));
    }

    #[test]
    fn relay_without_drop_file_never_drops() {
        let mut state = RelayState::default();
        assert!(!state.should_drop_event("any-event"));
    }

    #[test]
    fn filter_checks_ids_authors_kinds_and_inclusive_time_bounds() {
        let event = json!({
            "id": "event-1",
            "pubkey": "alice",
            "kind": 1,
            "created_at": 100,
            "tags": [],
        });

        assert!(matches_filter(&event, &json!({"ids": ["event-1"], "authors": ["alice"]})));
        assert!(matches_filter(&event, &json!({"kinds": [4, 1]})));
        assert!(matches_filter(&event, &json!({"since": 100, "until": 100})));
        assert!(!matches_filter(&event, &json!({"since": 101})));
        assert!(!matches_filter(&event, &json!({"until": 99})));
        assert!(!matches_filter(&event, &json!({"authors": ["bob"]})));
        assert!(!matches_filter(&event, &json!("not an object")));
    }

    #[test]
    fn filter_matches_tag_values() {
        let event = json!({
            "id": "event-2",
            "kind": 1,
            "tags": [["p", "bob"], ["e", "root-id"]],
        });

        assert!(matches_filter(&event, &json!({"#p": ["carol", "bob"]})));
        assert!(matches_filter(&event, &json!({"#e": ["root-id"], "#p": ["bob"]})));
        assert!(!matches_filter(&event, &json!({"#p": ["root-id"]})));
        assert!(matches_filter(&event, &json!({"#p": []})));
    }

    #[test]
    fn empty_filter_list_matches_everything() {
        assert!(matches_any_filter(&json!({"kind": 1}), &[]));
        assert!(!matches_any_filter(&json!({"kind": 1}), &[json!({"kinds": [2]})]));
    }
}
618}