Skip to main content

iris_chat_core/
local_relay.rs

use std::collections::{BTreeMap, HashMap};
use std::net::SocketAddr;
use std::sync::mpsc as std_mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration as StdDuration;

use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;

15#[derive(Default)]
16struct RelayState {
17    events_by_id: BTreeMap<String, Value>,
18    subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19    clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
22enum RelayControl {
23    ReplayStored,
24    Snapshot(std_mpsc::Sender<Vec<Value>>),
25    Shutdown,
26}
27
28pub struct TestRelay {
29    control_tx: mpsc::UnboundedSender<RelayControl>,
30    join: Option<thread::JoinHandle<()>>,
31    url: String,
32}
33
34impl TestRelay {
35    pub fn start() -> Self {
36        Self::start_with_bind("127.0.0.1:0").expect("start relay")
37    }
38
39    pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
40        let (control_tx, mut control_rx) = mpsc::unbounded_channel();
41        let (ready_tx, ready_rx) = std_mpsc::channel();
42        let bind_addr = bind_addr.to_string();
43
44        let join = thread::spawn(move || {
45            let runtime = tokio::runtime::Builder::new_multi_thread()
46                .enable_all()
47                .build()
48                .expect("relay runtime");
49
50            runtime.block_on(async move {
51                let listener = TcpListener::bind(&bind_addr)
52                    .await
53                    .with_context(|| format!("bind relay listener {bind_addr}"))
54                    .expect("bind relay listener");
55                let local_addr = listener.local_addr().expect("relay local addr");
56                let state = Arc::new(Mutex::new(RelayState::default()));
57                let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
58                ready_tx
59                    .send(format!("ws://{local_addr}"))
60                    .expect("signal relay ready");
61
62                loop {
63                    tokio::select! {
64                        Some(control) = control_rx.recv() => {
65                            match control {
66                                RelayControl::ReplayStored => replay_stored_events(&state),
67                                RelayControl::Snapshot(reply_tx) => {
68                                    let events = state
69                                        .lock()
70                                        .expect("relay state lock")
71                                        .events_by_id
72                                        .values()
73                                        .cloned()
74                                        .collect::<Vec<_>>();
75                                    let _ = reply_tx.send(events);
76                                }
77                                RelayControl::Shutdown => break,
78                            }
79                        }
80                        accept_result = listener.accept() => {
81                            let (stream, _) = accept_result.expect("accept relay client");
82                            let websocket = accept_async(stream).await.expect("accept websocket");
83                            let state = state.clone();
84                            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
85                            tokio::spawn(async move {
86                                handle_connection(client_id, websocket, state).await;
87                            });
88                        }
89                    }
90                }
91            });
92        });
93
94        let url = ready_rx
95            .recv_timeout(StdDuration::from_secs(5))
96            .context("relay ready")?;
97
98        Ok(Self {
99            control_tx,
100            join: Some(join),
101            url,
102        })
103    }
104
105    pub fn url(&self) -> &str {
106        &self.url
107    }
108
109    pub fn replay_stored(&self) {
110        let _ = self.control_tx.send(RelayControl::ReplayStored);
111    }
112
113    pub fn events(&self) -> Vec<Value> {
114        let (reply_tx, reply_rx) = std_mpsc::channel();
115        let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
116        reply_rx
117            .recv_timeout(StdDuration::from_secs(5))
118            .unwrap_or_default()
119    }
120}
121
122impl Drop for TestRelay {
123    fn drop(&mut self) {
124        let _ = self.control_tx.send(RelayControl::Shutdown);
125        if let Some(join) = self.join.take() {
126            let _ = join.join();
127        }
128    }
129}
130
131pub fn run_forever(bind_addr: &str) -> Result<()> {
132    let runtime = tokio::runtime::Builder::new_multi_thread()
133        .enable_all()
134        .build()
135        .context("relay runtime")?;
136    let bind_addr = bind_addr.to_string();
137
138    runtime.block_on(async move {
139        let listener = TcpListener::bind(&bind_addr)
140            .await
141            .with_context(|| format!("bind relay listener {bind_addr}"))?;
142        let state = Arc::new(Mutex::new(RelayState::default()));
143        let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
144
145        println!("Local Nostr relay listening on ws://{bind_addr}");
146
147        loop {
148            let (stream, _) = listener
149                .accept()
150                .await
151                .with_context(|| format!("accept relay client on {bind_addr}"))?;
152            let websocket = match accept_async(stream).await {
153                Ok(websocket) => websocket,
154                Err(error) => {
155                    eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
156                    continue;
157                }
158            };
159            let state = state.clone();
160            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
161            tokio::spawn(async move {
162                handle_connection(client_id, websocket, state).await;
163            });
164        }
165    })
166}
167
168async fn handle_connection(
169    client_id: usize,
170    websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
171    state: Arc<Mutex<RelayState>>,
172) {
173    let (mut sink, mut stream) = websocket.split();
174    let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
175
176    {
177        let mut relay = state.lock().expect("relay state lock");
178        relay.clients.insert(client_id, client_tx);
179    }
180
181    let writer = tokio::spawn(async move {
182        while let Some(message) = client_rx.recv().await {
183            if sink.send(message).await.is_err() {
184                break;
185            }
186        }
187    });
188
189    while let Some(message) = stream.next().await {
190        let Ok(message) = message else {
191            break;
192        };
193        match message {
194            Message::Text(text) => handle_client_message(client_id, &text, &state),
195            Message::Ping(payload) => {
196                let sender = {
197                    let relay = state.lock().expect("relay state lock");
198                    relay.clients.get(&client_id).cloned()
199                };
200                if let Some(sender) = sender {
201                    let _ = sender.send(Message::Pong(payload));
202                }
203            }
204            Message::Close(_) => break,
205            _ => {}
206        }
207    }
208
209    {
210        let mut relay = state.lock().expect("relay state lock");
211        relay.clients.remove(&client_id);
212        relay.subscriptions.remove(&client_id);
213    }
214
215    writer.abort();
216}
217
218fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
219    let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
220        return;
221    };
222    let Some(parts) = message.as_array() else {
223        return;
224    };
225    let Some(kind) = parts.first().and_then(Value::as_str) else {
226        return;
227    };
228
229    match kind {
230        "REQ" if parts.len() >= 2 => {
231            let Some(subscription_id) = parts[1].as_str() else {
232                return;
233            };
234            let filters: Vec<Value> = parts
235                .iter()
236                .skip(2)
237                .filter(|value| value.is_object())
238                .cloned()
239                .collect();
240            let (sender, events) = {
241                let mut relay = state.lock().expect("relay state lock");
242                relay
243                    .subscriptions
244                    .entry(client_id)
245                    .or_default()
246                    .insert(subscription_id.to_string(), filters.clone());
247                (
248                    relay.clients.get(&client_id).cloned(),
249                    relay.events_by_id.values().cloned().collect::<Vec<_>>(),
250                )
251            };
252
253            if let Some(sender) = sender {
254                for event in events {
255                    if matches_any_filter(&event, &filters) {
256                        let payload =
257                            Message::Text(json!(["EVENT", subscription_id, event]).to_string());
258                        let _ = sender.send(payload);
259                    }
260                }
261                let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
262            }
263        }
264        "CLOSE" if parts.len() >= 2 => {
265            let Some(subscription_id) = parts[1].as_str() else {
266                return;
267            };
268            let mut relay = state.lock().expect("relay state lock");
269            if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
270                subscriptions.remove(subscription_id);
271            }
272        }
273        "EVENT" if parts.len() >= 2 && parts[1].is_object() => {
274            let event = parts[1].clone();
275            let Some(event_id) = event.get("id").and_then(Value::as_str) else {
276                return;
277            };
278            let (sender, deliveries) = {
279                let mut relay = state.lock().expect("relay state lock");
280                relay
281                    .events_by_id
282                    .insert(event_id.to_string(), event.clone());
283                let sender = relay.clients.get(&client_id).cloned();
284                let deliveries = matching_deliveries(&relay, &event);
285                (sender, deliveries)
286            };
287            if let Some(sender) = sender {
288                let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
289            }
290
291            for (target, payload) in deliveries {
292                let _ = target.send(payload);
293            }
294        }
295        _ => {}
296    }
297}
298
299fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
300    let deliveries = {
301        let relay = state.lock().expect("relay state lock");
302        relay
303            .events_by_id
304            .values()
305            .flat_map(|event| matching_deliveries(&relay, event))
306            .collect::<Vec<_>>()
307    };
308
309    for (target, payload) in deliveries {
310        let _ = target.send(payload);
311    }
312}
313
314fn matching_deliveries(
315    relay: &RelayState,
316    event: &Value,
317) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
318    let mut deliveries = Vec::new();
319    for (client_id, subscriptions) in &relay.subscriptions {
320        let Some(target) = relay.clients.get(client_id).cloned() else {
321            continue;
322        };
323        for (subscription_id, filters) in subscriptions {
324            if matches_any_filter(event, filters) {
325                deliveries.push((
326                    target.clone(),
327                    Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
328                ));
329            }
330        }
331    }
332    deliveries
333}
334
335pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
336    if filters.is_empty() {
337        return true;
338    }
339
340    filters.iter().any(|filter| matches_filter(event, filter))
341}
342
343pub fn matches_filter(event: &Value, filter: &Value) -> bool {
344    let Some(filter_object) = filter.as_object() else {
345        return false;
346    };
347
348    if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
349        let Some(event_id) = event.get("id").and_then(Value::as_str) else {
350            return false;
351        };
352        if !ids
353            .iter()
354            .filter_map(Value::as_str)
355            .any(|id| id == event_id)
356        {
357            return false;
358        }
359    }
360
361    if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
362        let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
363            return false;
364        };
365        if !authors
366            .iter()
367            .filter_map(Value::as_str)
368            .any(|author| author == pubkey)
369        {
370            return false;
371        }
372    }
373
374    if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
375        let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
376            return false;
377        };
378        if !kinds
379            .iter()
380            .filter_map(Value::as_u64)
381            .any(|value| value == kind)
382        {
383            return false;
384        }
385    }
386
387    if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
388        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
389            return false;
390        };
391        if created_at < since {
392            return false;
393        }
394    }
395
396    if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
397        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
398            return false;
399        };
400        if created_at > until {
401            return false;
402        }
403    }
404
405    for (key, value) in filter_object {
406        let Some(tag_name) = key.strip_prefix('#') else {
407            continue;
408        };
409
410        let Some(expected_values) = value.as_array() else {
411            return false;
412        };
413        if expected_values.is_empty() {
414            continue;
415        }
416
417        let Some(tags) = event.get("tags").and_then(Value::as_array) else {
418            return false;
419        };
420        let matched = tags.iter().any(|tag| {
421            let Some(tag_values) = tag.as_array() else {
422                return false;
423            };
424            if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
425                return false;
426            }
427            tag_values
428                .iter()
429                .skip(1)
430                .filter_map(Value::as_str)
431                .any(|tag_value| {
432                    expected_values
433                        .iter()
434                        .filter_map(Value::as_str)
435                        .any(|expected| expected == tag_value)
436                })
437        });
438        if !matched {
439            return false;
440        }
441    }
442
443    true
444}