Skip to main content

iris_chat_core/
local_relay.rs

1use std::collections::{BTreeMap, HashMap};
2use std::sync::mpsc as std_mpsc;
3use std::sync::{Arc, Mutex, MutexGuard};
4use std::thread;
5use std::time::Duration as StdDuration;
6
7use anyhow::{anyhow, Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use serde_json::{json, Value};
10use tokio::net::TcpListener;
11use tokio::sync::mpsc;
12use tokio_tungstenite::accept_async;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Default)]
16struct RelayState {
17    events_by_id: BTreeMap<String, Value>,
18    subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19    clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
22fn lock_relay_state(state: &Arc<Mutex<RelayState>>) -> MutexGuard<'_, RelayState> {
23    state.lock().unwrap_or_else(|poison| poison.into_inner())
24}
25
26enum RelayControl {
27    ReplayStored,
28    Snapshot(std_mpsc::Sender<Vec<Value>>),
29    Shutdown,
30}
31
32pub struct TestRelay {
33    control_tx: mpsc::UnboundedSender<RelayControl>,
34    join: Option<thread::JoinHandle<()>>,
35    url: String,
36}
37
38impl TestRelay {
39    pub fn start() -> Self {
40        match Self::start_with_bind("127.0.0.1:0") {
41            Ok(relay) => relay,
42            Err(error) => {
43                eprintln!("failed to start local relay: {error}");
44                let (control_tx, _) = mpsc::unbounded_channel();
45                Self {
46                    control_tx,
47                    join: None,
48                    url: String::new(),
49                }
50            }
51        }
52    }
53
54    pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
55        let (control_tx, mut control_rx) = mpsc::unbounded_channel();
56        let (ready_tx, ready_rx) = std_mpsc::channel();
57        let bind_addr = bind_addr.to_string();
58
59        let join = thread::spawn(move || {
60            let runtime = match tokio::runtime::Builder::new_multi_thread()
61                .enable_all()
62                .build()
63            {
64                Ok(runtime) => runtime,
65                Err(error) => {
66                    let _ = ready_tx.send(Err(anyhow!("relay runtime: {error}")));
67                    return;
68                }
69            };
70
71            runtime.block_on(async move {
72                let listener = match TcpListener::bind(&bind_addr)
73                    .await
74                    .with_context(|| format!("bind relay listener {bind_addr}"))
75                {
76                    Ok(listener) => listener,
77                    Err(error) => {
78                        let _ = ready_tx.send(Err(error));
79                        return;
80                    }
81                };
82                let local_addr = match listener.local_addr() {
83                    Ok(addr) => addr,
84                    Err(error) => {
85                        let _ = ready_tx.send(Err(anyhow!("relay local addr: {error}")));
86                        return;
87                    }
88                };
89                let state = Arc::new(Mutex::new(RelayState::default()));
90                let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
91                let _ = ready_tx.send(Ok(format!("ws://{local_addr}")));
92
93                loop {
94                    tokio::select! {
95                        Some(control) = control_rx.recv() => {
96                            match control {
97                                RelayControl::ReplayStored => replay_stored_events(&state),
98                                RelayControl::Snapshot(reply_tx) => {
99                                    let events = lock_relay_state(&state)
100                                        .events_by_id
101                                        .values()
102                                        .cloned()
103                                        .collect::<Vec<_>>();
104                                    let _ = reply_tx.send(events);
105                                }
106                                RelayControl::Shutdown => break,
107                            }
108                        }
109                        accept_result = listener.accept() => {
110                            let Ok((stream, _)) = accept_result else {
111                                break;
112                            };
113                            let websocket = match accept_async(stream).await {
114                                Ok(websocket) => websocket,
115                                Err(error) => {
116                                    eprintln!("Ignoring failed test relay websocket handshake: {error}");
117                                    continue;
118                                }
119                            };
120                            let state = state.clone();
121                            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
122                            tokio::spawn(async move {
123                                handle_connection(client_id, websocket, state).await;
124                            });
125                        }
126                    }
127                }
128            });
129        });
130
131        let url = ready_rx
132            .recv_timeout(StdDuration::from_secs(5))
133            .context("relay ready")??;
134
135        Ok(Self {
136            control_tx,
137            join: Some(join),
138            url,
139        })
140    }
141
142    pub fn url(&self) -> &str {
143        &self.url
144    }
145
146    pub fn replay_stored(&self) {
147        let _ = self.control_tx.send(RelayControl::ReplayStored);
148    }
149
150    pub fn events(&self) -> Vec<Value> {
151        let (reply_tx, reply_rx) = std_mpsc::channel();
152        let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
153        reply_rx
154            .recv_timeout(StdDuration::from_secs(5))
155            .unwrap_or_default()
156    }
157}
158
159impl Drop for TestRelay {
160    fn drop(&mut self) {
161        let _ = self.control_tx.send(RelayControl::Shutdown);
162        if let Some(join) = self.join.take() {
163            let _ = join.join();
164        }
165    }
166}
167
168pub fn run_forever(bind_addr: &str) -> Result<()> {
169    let runtime = tokio::runtime::Builder::new_multi_thread()
170        .enable_all()
171        .build()
172        .context("relay runtime")?;
173    let bind_addr = bind_addr.to_string();
174
175    runtime.block_on(async move {
176        let listener = TcpListener::bind(&bind_addr)
177            .await
178            .with_context(|| format!("bind relay listener {bind_addr}"))?;
179        let state = Arc::new(Mutex::new(RelayState::default()));
180        let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
181
182        println!("Local Nostr relay listening on ws://{bind_addr}");
183
184        loop {
185            let (stream, _) = listener
186                .accept()
187                .await
188                .with_context(|| format!("accept relay client on {bind_addr}"))?;
189            let websocket = match accept_async(stream).await {
190                Ok(websocket) => websocket,
191                Err(error) => {
192                    eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
193                    continue;
194                }
195            };
196            let state = state.clone();
197            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
198            tokio::spawn(async move {
199                handle_connection(client_id, websocket, state).await;
200            });
201        }
202    })
203}
204
205async fn handle_connection(
206    client_id: usize,
207    websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
208    state: Arc<Mutex<RelayState>>,
209) {
210    let (mut sink, mut stream) = websocket.split();
211    let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
212
213    {
214        let mut relay = lock_relay_state(&state);
215        relay.clients.insert(client_id, client_tx);
216    }
217
218    let writer = tokio::spawn(async move {
219        while let Some(message) = client_rx.recv().await {
220            if sink.send(message).await.is_err() {
221                break;
222            }
223        }
224    });
225
226    while let Some(message) = stream.next().await {
227        let Ok(message) = message else {
228            break;
229        };
230        match message {
231            Message::Text(text) => handle_client_message(client_id, &text, &state),
232            Message::Ping(payload) => {
233                let sender = {
234                    let relay = lock_relay_state(&state);
235                    relay.clients.get(&client_id).cloned()
236                };
237                if let Some(sender) = sender {
238                    let _ = sender.send(Message::Pong(payload));
239                }
240            }
241            Message::Close(_) => break,
242            _ => {}
243        }
244    }
245
246    {
247        let mut relay = lock_relay_state(&state);
248        relay.clients.remove(&client_id);
249        relay.subscriptions.remove(&client_id);
250    }
251
252    writer.abort();
253}
254
255fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
256    let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
257        return;
258    };
259    let Some(parts) = message.as_array() else {
260        return;
261    };
262    let Some(kind) = parts.first().and_then(Value::as_str) else {
263        return;
264    };
265
266    match kind {
267        "REQ" if parts.len() >= 2 => {
268            let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
269                return;
270            };
271            let filters: Vec<Value> = parts
272                .iter()
273                .skip(2)
274                .filter(|value| value.is_object())
275                .cloned()
276                .collect();
277            let (sender, events) = {
278                let mut relay = lock_relay_state(state);
279                relay
280                    .subscriptions
281                    .entry(client_id)
282                    .or_default()
283                    .insert(subscription_id.to_string(), filters.clone());
284                (
285                    relay.clients.get(&client_id).cloned(),
286                    relay.events_by_id.values().cloned().collect::<Vec<_>>(),
287                )
288            };
289
290            if let Some(sender) = sender {
291                for event in events {
292                    if matches_any_filter(&event, &filters) {
293                        let payload =
294                            Message::Text(json!(["EVENT", subscription_id, event]).to_string());
295                        let _ = sender.send(payload);
296                    }
297                }
298                let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
299            }
300        }
301        "CLOSE" if parts.len() >= 2 => {
302            let Some(subscription_id) = parts.get(1).and_then(Value::as_str) else {
303                return;
304            };
305            let mut relay = lock_relay_state(state);
306            if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
307                subscriptions.remove(subscription_id);
308            }
309        }
310        "EVENT" if parts.get(1).is_some_and(Value::is_object) => {
311            let Some(event) = parts.get(1).cloned() else {
312                return;
313            };
314            let Some(event_id) = event.get("id").and_then(Value::as_str) else {
315                return;
316            };
317            let (sender, deliveries) = {
318                let mut relay = lock_relay_state(state);
319                relay
320                    .events_by_id
321                    .insert(event_id.to_string(), event.clone());
322                let sender = relay.clients.get(&client_id).cloned();
323                let deliveries = matching_deliveries(&relay, &event);
324                (sender, deliveries)
325            };
326            if let Some(sender) = sender {
327                let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
328            }
329
330            for (target, payload) in deliveries {
331                let _ = target.send(payload);
332            }
333        }
334        _ => {}
335    }
336}
337
338fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
339    let deliveries = {
340        let relay = lock_relay_state(state);
341        relay
342            .events_by_id
343            .values()
344            .flat_map(|event| matching_deliveries(&relay, event))
345            .collect::<Vec<_>>()
346    };
347
348    for (target, payload) in deliveries {
349        let _ = target.send(payload);
350    }
351}
352
353fn matching_deliveries(
354    relay: &RelayState,
355    event: &Value,
356) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
357    let mut deliveries = Vec::new();
358    for (client_id, subscriptions) in &relay.subscriptions {
359        let Some(target) = relay.clients.get(client_id).cloned() else {
360            continue;
361        };
362        for (subscription_id, filters) in subscriptions {
363            if matches_any_filter(event, filters) {
364                deliveries.push((
365                    target.clone(),
366                    Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
367                ));
368            }
369        }
370    }
371    deliveries
372}
373
374pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
375    if filters.is_empty() {
376        return true;
377    }
378
379    filters.iter().any(|filter| matches_filter(event, filter))
380}
381
382pub fn matches_filter(event: &Value, filter: &Value) -> bool {
383    let Some(filter_object) = filter.as_object() else {
384        return false;
385    };
386
387    if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
388        let Some(event_id) = event.get("id").and_then(Value::as_str) else {
389            return false;
390        };
391        if !ids
392            .iter()
393            .filter_map(Value::as_str)
394            .any(|id| id == event_id)
395        {
396            return false;
397        }
398    }
399
400    if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
401        let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
402            return false;
403        };
404        if !authors
405            .iter()
406            .filter_map(Value::as_str)
407            .any(|author| author == pubkey)
408        {
409            return false;
410        }
411    }
412
413    if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
414        let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
415            return false;
416        };
417        if !kinds
418            .iter()
419            .filter_map(Value::as_u64)
420            .any(|value| value == kind)
421        {
422            return false;
423        }
424    }
425
426    if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
427        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
428            return false;
429        };
430        if created_at < since {
431            return false;
432        }
433    }
434
435    if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
436        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
437            return false;
438        };
439        if created_at > until {
440            return false;
441        }
442    }
443
444    for (key, value) in filter_object {
445        let Some(tag_name) = key.strip_prefix('#') else {
446            continue;
447        };
448
449        let Some(expected_values) = value.as_array() else {
450            return false;
451        };
452        if expected_values.is_empty() {
453            continue;
454        }
455
456        let Some(tags) = event.get("tags").and_then(Value::as_array) else {
457            return false;
458        };
459        let matched = tags.iter().any(|tag| {
460            let Some(tag_values) = tag.as_array() else {
461                return false;
462            };
463            if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
464                return false;
465            }
466            tag_values
467                .iter()
468                .skip(1)
469                .filter_map(Value::as_str)
470                .any(|tag_value| {
471                    expected_values
472                        .iter()
473                        .filter_map(Value::as_str)
474                        .any(|expected| expected == tag_value)
475                })
476        });
477        if !matched {
478            return false;
479        }
480    }
481
482    true
483}