//! `iris_chat_core/local_relay.rs`: a minimal local Nostr relay, embeddable in
//! tests via `TestRelay` or run standalone via `run_forever`.

1use std::collections::{BTreeMap, HashMap};
2use std::sync::mpsc as std_mpsc;
3use std::sync::{Arc, Mutex};
4use std::thread;
5use std::time::Duration as StdDuration;
6
7use anyhow::{Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use serde_json::{json, Value};
10use tokio::net::TcpListener;
11use tokio::sync::mpsc;
12use tokio_tungstenite::accept_async;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Default)]
16struct RelayState {
17    events_by_id: BTreeMap<String, Value>,
18    subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19    clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
22enum RelayControl {
23    ReplayStored,
24    Snapshot(std_mpsc::Sender<Vec<Value>>),
25    Shutdown,
26}
27
28pub struct TestRelay {
29    control_tx: mpsc::UnboundedSender<RelayControl>,
30    join: Option<thread::JoinHandle<()>>,
31    url: String,
32}
33
34impl TestRelay {
35    pub fn start() -> Self {
36        Self::start_with_bind("127.0.0.1:0").expect("start relay")
37    }
38
39    pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
40        let (control_tx, mut control_rx) = mpsc::unbounded_channel();
41        let (ready_tx, ready_rx) = std_mpsc::channel();
42        let bind_addr = bind_addr.to_string();
43
44        let join = thread::spawn(move || {
45            let runtime = tokio::runtime::Builder::new_multi_thread()
46                .enable_all()
47                .build()
48                .expect("relay runtime");
49
50            runtime.block_on(async move {
51                let listener = TcpListener::bind(&bind_addr)
52                    .await
53                    .with_context(|| format!("bind relay listener {bind_addr}"))
54                    .expect("bind relay listener");
55                let local_addr = listener.local_addr().expect("relay local addr");
56                let state = Arc::new(Mutex::new(RelayState::default()));
57                let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
58                ready_tx
59                    .send(format!("ws://{local_addr}"))
60                    .expect("signal relay ready");
61
62                loop {
63                    tokio::select! {
64                        Some(control) = control_rx.recv() => {
65                            match control {
66                                RelayControl::ReplayStored => replay_stored_events(&state),
67                                RelayControl::Snapshot(reply_tx) => {
68                                    let events = state
69                                        .lock()
70                                        .expect("relay state lock")
71                                        .events_by_id
72                                        .values()
73                                        .cloned()
74                                        .collect::<Vec<_>>();
75                                    let _ = reply_tx.send(events);
76                                }
77                                RelayControl::Shutdown => break,
78                            }
79                        }
80                        accept_result = listener.accept() => {
81                            let (stream, _) = accept_result.expect("accept relay client");
82                            let websocket = match accept_async(stream).await {
83                                Ok(websocket) => websocket,
84                                Err(error) => {
85                                    eprintln!("Ignoring failed test relay websocket handshake: {error}");
86                                    continue;
87                                }
88                            };
89                            let state = state.clone();
90                            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
91                            tokio::spawn(async move {
92                                handle_connection(client_id, websocket, state).await;
93                            });
94                        }
95                    }
96                }
97            });
98        });
99
100        let url = ready_rx
101            .recv_timeout(StdDuration::from_secs(5))
102            .context("relay ready")?;
103
104        Ok(Self {
105            control_tx,
106            join: Some(join),
107            url,
108        })
109    }
110
111    pub fn url(&self) -> &str {
112        &self.url
113    }
114
115    pub fn replay_stored(&self) {
116        let _ = self.control_tx.send(RelayControl::ReplayStored);
117    }
118
119    pub fn events(&self) -> Vec<Value> {
120        let (reply_tx, reply_rx) = std_mpsc::channel();
121        let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
122        reply_rx
123            .recv_timeout(StdDuration::from_secs(5))
124            .unwrap_or_default()
125    }
126}
127
128impl Drop for TestRelay {
129    fn drop(&mut self) {
130        let _ = self.control_tx.send(RelayControl::Shutdown);
131        if let Some(join) = self.join.take() {
132            let _ = join.join();
133        }
134    }
135}
136
137pub fn run_forever(bind_addr: &str) -> Result<()> {
138    let runtime = tokio::runtime::Builder::new_multi_thread()
139        .enable_all()
140        .build()
141        .context("relay runtime")?;
142    let bind_addr = bind_addr.to_string();
143
144    runtime.block_on(async move {
145        let listener = TcpListener::bind(&bind_addr)
146            .await
147            .with_context(|| format!("bind relay listener {bind_addr}"))?;
148        let state = Arc::new(Mutex::new(RelayState::default()));
149        let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
150
151        println!("Local Nostr relay listening on ws://{bind_addr}");
152
153        loop {
154            let (stream, _) = listener
155                .accept()
156                .await
157                .with_context(|| format!("accept relay client on {bind_addr}"))?;
158            let websocket = match accept_async(stream).await {
159                Ok(websocket) => websocket,
160                Err(error) => {
161                    eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
162                    continue;
163                }
164            };
165            let state = state.clone();
166            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
167            tokio::spawn(async move {
168                handle_connection(client_id, websocket, state).await;
169            });
170        }
171    })
172}
173
174async fn handle_connection(
175    client_id: usize,
176    websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
177    state: Arc<Mutex<RelayState>>,
178) {
179    let (mut sink, mut stream) = websocket.split();
180    let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
181
182    {
183        let mut relay = state.lock().expect("relay state lock");
184        relay.clients.insert(client_id, client_tx);
185    }
186
187    let writer = tokio::spawn(async move {
188        while let Some(message) = client_rx.recv().await {
189            if sink.send(message).await.is_err() {
190                break;
191            }
192        }
193    });
194
195    while let Some(message) = stream.next().await {
196        let Ok(message) = message else {
197            break;
198        };
199        match message {
200            Message::Text(text) => handle_client_message(client_id, &text, &state),
201            Message::Ping(payload) => {
202                let sender = {
203                    let relay = state.lock().expect("relay state lock");
204                    relay.clients.get(&client_id).cloned()
205                };
206                if let Some(sender) = sender {
207                    let _ = sender.send(Message::Pong(payload));
208                }
209            }
210            Message::Close(_) => break,
211            _ => {}
212        }
213    }
214
215    {
216        let mut relay = state.lock().expect("relay state lock");
217        relay.clients.remove(&client_id);
218        relay.subscriptions.remove(&client_id);
219    }
220
221    writer.abort();
222}
223
224fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
225    let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
226        return;
227    };
228    let Some(parts) = message.as_array() else {
229        return;
230    };
231    let Some(kind) = parts.first().and_then(Value::as_str) else {
232        return;
233    };
234
235    match kind {
236        "REQ" if parts.len() >= 2 => {
237            let Some(subscription_id) = parts[1].as_str() else {
238                return;
239            };
240            let filters: Vec<Value> = parts
241                .iter()
242                .skip(2)
243                .filter(|value| value.is_object())
244                .cloned()
245                .collect();
246            let (sender, events) = {
247                let mut relay = state.lock().expect("relay state lock");
248                relay
249                    .subscriptions
250                    .entry(client_id)
251                    .or_default()
252                    .insert(subscription_id.to_string(), filters.clone());
253                (
254                    relay.clients.get(&client_id).cloned(),
255                    relay.events_by_id.values().cloned().collect::<Vec<_>>(),
256                )
257            };
258
259            if let Some(sender) = sender {
260                for event in events {
261                    if matches_any_filter(&event, &filters) {
262                        let payload =
263                            Message::Text(json!(["EVENT", subscription_id, event]).to_string());
264                        let _ = sender.send(payload);
265                    }
266                }
267                let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
268            }
269        }
270        "CLOSE" if parts.len() >= 2 => {
271            let Some(subscription_id) = parts[1].as_str() else {
272                return;
273            };
274            let mut relay = state.lock().expect("relay state lock");
275            if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
276                subscriptions.remove(subscription_id);
277            }
278        }
279        "EVENT" if parts.len() >= 2 && parts[1].is_object() => {
280            let event = parts[1].clone();
281            let Some(event_id) = event.get("id").and_then(Value::as_str) else {
282                return;
283            };
284            let (sender, deliveries) = {
285                let mut relay = state.lock().expect("relay state lock");
286                relay
287                    .events_by_id
288                    .insert(event_id.to_string(), event.clone());
289                let sender = relay.clients.get(&client_id).cloned();
290                let deliveries = matching_deliveries(&relay, &event);
291                (sender, deliveries)
292            };
293            if let Some(sender) = sender {
294                let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
295            }
296
297            for (target, payload) in deliveries {
298                let _ = target.send(payload);
299            }
300        }
301        _ => {}
302    }
303}
304
305fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
306    let deliveries = {
307        let relay = state.lock().expect("relay state lock");
308        relay
309            .events_by_id
310            .values()
311            .flat_map(|event| matching_deliveries(&relay, event))
312            .collect::<Vec<_>>()
313    };
314
315    for (target, payload) in deliveries {
316        let _ = target.send(payload);
317    }
318}
319
320fn matching_deliveries(
321    relay: &RelayState,
322    event: &Value,
323) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
324    let mut deliveries = Vec::new();
325    for (client_id, subscriptions) in &relay.subscriptions {
326        let Some(target) = relay.clients.get(client_id).cloned() else {
327            continue;
328        };
329        for (subscription_id, filters) in subscriptions {
330            if matches_any_filter(event, filters) {
331                deliveries.push((
332                    target.clone(),
333                    Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
334                ));
335            }
336        }
337    }
338    deliveries
339}
340
341pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
342    if filters.is_empty() {
343        return true;
344    }
345
346    filters.iter().any(|filter| matches_filter(event, filter))
347}
348
349pub fn matches_filter(event: &Value, filter: &Value) -> bool {
350    let Some(filter_object) = filter.as_object() else {
351        return false;
352    };
353
354    if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
355        let Some(event_id) = event.get("id").and_then(Value::as_str) else {
356            return false;
357        };
358        if !ids
359            .iter()
360            .filter_map(Value::as_str)
361            .any(|id| id == event_id)
362        {
363            return false;
364        }
365    }
366
367    if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
368        let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
369            return false;
370        };
371        if !authors
372            .iter()
373            .filter_map(Value::as_str)
374            .any(|author| author == pubkey)
375        {
376            return false;
377        }
378    }
379
380    if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
381        let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
382            return false;
383        };
384        if !kinds
385            .iter()
386            .filter_map(Value::as_u64)
387            .any(|value| value == kind)
388        {
389            return false;
390        }
391    }
392
393    if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
394        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
395            return false;
396        };
397        if created_at < since {
398            return false;
399        }
400    }
401
402    if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
403        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
404            return false;
405        };
406        if created_at > until {
407            return false;
408        }
409    }
410
411    for (key, value) in filter_object {
412        let Some(tag_name) = key.strip_prefix('#') else {
413            continue;
414        };
415
416        let Some(expected_values) = value.as_array() else {
417            return false;
418        };
419        if expected_values.is_empty() {
420            continue;
421        }
422
423        let Some(tags) = event.get("tags").and_then(Value::as_array) else {
424            return false;
425        };
426        let matched = tags.iter().any(|tag| {
427            let Some(tag_values) = tag.as_array() else {
428                return false;
429            };
430            if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
431                return false;
432            }
433            tag_values
434                .iter()
435                .skip(1)
436                .filter_map(Value::as_str)
437                .any(|tag_value| {
438                    expected_values
439                        .iter()
440                        .filter_map(Value::as_str)
441                        .any(|expected| expected == tag_value)
442                })
443        });
444        if !matched {
445            return false;
446        }
447    }
448
449    true
450}