// iris_chat_core/local_relay.rs

1use std::collections::{BTreeMap, HashMap};
2use std::sync::mpsc as std_mpsc;
3use std::sync::{Arc, Mutex};
4use std::thread;
5use std::time::Duration as StdDuration;
6
7use anyhow::{Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use serde_json::{json, Value};
10use tokio::net::TcpListener;
11use tokio::sync::mpsc;
12use tokio_tungstenite::accept_async;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Default)]
16struct RelayState {
17    events_by_id: BTreeMap<String, Value>,
18    subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19    clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
/// Commands sent from a [`TestRelay`] handle to its relay loop.
enum RelayControl {
    /// Re-send every stored event to each subscription whose filters match it.
    ReplayStored,
    /// Leave the relay loop.
    Shutdown,
}

27pub struct TestRelay {
28    control_tx: mpsc::UnboundedSender<RelayControl>,
29    join: Option<thread::JoinHandle<()>>,
30}
31
32impl TestRelay {
33    pub fn start() -> Self {
34        Self::start_with_bind("127.0.0.1:4848").expect("start relay")
35    }
36
37    pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
38        let (control_tx, mut control_rx) = mpsc::unbounded_channel();
39        let (ready_tx, ready_rx) = std_mpsc::channel();
40        let bind_addr = bind_addr.to_string();
41
42        let join = thread::spawn(move || {
43            let runtime = tokio::runtime::Builder::new_multi_thread()
44                .enable_all()
45                .build()
46                .expect("relay runtime");
47
48            runtime.block_on(async move {
49                let listener = TcpListener::bind(&bind_addr)
50                    .await
51                    .with_context(|| format!("bind relay listener {bind_addr}"))
52                    .expect("bind relay listener");
53                let state = Arc::new(Mutex::new(RelayState::default()));
54                let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
55                ready_tx.send(()).expect("signal relay ready");
56
57                loop {
58                    tokio::select! {
59                        Some(control) = control_rx.recv() => {
60                            match control {
61                                RelayControl::ReplayStored => replay_stored_events(&state),
62                                RelayControl::Shutdown => break,
63                            }
64                        }
65                        accept_result = listener.accept() => {
66                            let (stream, _) = accept_result.expect("accept relay client");
67                            let websocket = accept_async(stream).await.expect("accept websocket");
68                            let state = state.clone();
69                            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
70                            tokio::spawn(async move {
71                                handle_connection(client_id, websocket, state).await;
72                            });
73                        }
74                    }
75                }
76            });
77        });
78
79        ready_rx
80            .recv_timeout(StdDuration::from_secs(5))
81            .context("relay ready")?;
82
83        Ok(Self {
84            control_tx,
85            join: Some(join),
86        })
87    }
88
89    pub fn replay_stored(&self) {
90        let _ = self.control_tx.send(RelayControl::ReplayStored);
91    }
92}
93
94impl Drop for TestRelay {
95    fn drop(&mut self) {
96        let _ = self.control_tx.send(RelayControl::Shutdown);
97        if let Some(join) = self.join.take() {
98            let _ = join.join();
99        }
100    }
101}
102
103pub fn run_forever(bind_addr: &str) -> Result<()> {
104    let runtime = tokio::runtime::Builder::new_multi_thread()
105        .enable_all()
106        .build()
107        .context("relay runtime")?;
108    let bind_addr = bind_addr.to_string();
109
110    runtime.block_on(async move {
111        let listener = TcpListener::bind(&bind_addr)
112            .await
113            .with_context(|| format!("bind relay listener {bind_addr}"))?;
114        let state = Arc::new(Mutex::new(RelayState::default()));
115        let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
116
117        println!("Local Nostr relay listening on ws://{bind_addr}");
118
119        loop {
120            let (stream, _) = listener
121                .accept()
122                .await
123                .with_context(|| format!("accept relay client on {bind_addr}"))?;
124            let websocket = match accept_async(stream).await {
125                Ok(websocket) => websocket,
126                Err(error) => {
127                    eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
128                    continue;
129                }
130            };
131            let state = state.clone();
132            let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
133            tokio::spawn(async move {
134                handle_connection(client_id, websocket, state).await;
135            });
136        }
137    })
138}
139
140async fn handle_connection(
141    client_id: usize,
142    websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
143    state: Arc<Mutex<RelayState>>,
144) {
145    let (mut sink, mut stream) = websocket.split();
146    let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
147
148    {
149        let mut relay = state.lock().expect("relay state lock");
150        relay.clients.insert(client_id, client_tx);
151    }
152
153    let writer = tokio::spawn(async move {
154        while let Some(message) = client_rx.recv().await {
155            if sink.send(message).await.is_err() {
156                break;
157            }
158        }
159    });
160
161    while let Some(message) = stream.next().await {
162        let Ok(message) = message else {
163            break;
164        };
165        match message {
166            Message::Text(text) => handle_client_message(client_id, &text, &state),
167            Message::Ping(payload) => {
168                let sender = {
169                    let relay = state.lock().expect("relay state lock");
170                    relay.clients.get(&client_id).cloned()
171                };
172                if let Some(sender) = sender {
173                    let _ = sender.send(Message::Pong(payload));
174                }
175            }
176            Message::Close(_) => break,
177            _ => {}
178        }
179    }
180
181    {
182        let mut relay = state.lock().expect("relay state lock");
183        relay.clients.remove(&client_id);
184        relay.subscriptions.remove(&client_id);
185    }
186
187    writer.abort();
188}
189
190fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
191    let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
192        return;
193    };
194    let Some(parts) = message.as_array() else {
195        return;
196    };
197    let Some(kind) = parts.first().and_then(Value::as_str) else {
198        return;
199    };
200
201    match kind {
202        "REQ" if parts.len() >= 2 => {
203            let Some(subscription_id) = parts[1].as_str() else {
204                return;
205            };
206            let filters: Vec<Value> = parts
207                .iter()
208                .skip(2)
209                .filter(|value| value.is_object())
210                .cloned()
211                .collect();
212            let (sender, events) = {
213                let mut relay = state.lock().expect("relay state lock");
214                relay
215                    .subscriptions
216                    .entry(client_id)
217                    .or_default()
218                    .insert(subscription_id.to_string(), filters.clone());
219                (
220                    relay.clients.get(&client_id).cloned(),
221                    relay.events_by_id.values().cloned().collect::<Vec<_>>(),
222                )
223            };
224
225            if let Some(sender) = sender {
226                for event in events {
227                    if matches_any_filter(&event, &filters) {
228                        let payload =
229                            Message::Text(json!(["EVENT", subscription_id, event]).to_string());
230                        let _ = sender.send(payload);
231                    }
232                }
233                let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
234            }
235        }
236        "CLOSE" if parts.len() >= 2 => {
237            let Some(subscription_id) = parts[1].as_str() else {
238                return;
239            };
240            let mut relay = state.lock().expect("relay state lock");
241            if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
242                subscriptions.remove(subscription_id);
243            }
244        }
245        "EVENT" if parts.len() >= 2 && parts[1].is_object() => {
246            let event = parts[1].clone();
247            let Some(event_id) = event.get("id").and_then(Value::as_str) else {
248                return;
249            };
250            let (sender, deliveries) = {
251                let mut relay = state.lock().expect("relay state lock");
252                relay
253                    .events_by_id
254                    .insert(event_id.to_string(), event.clone());
255                let sender = relay.clients.get(&client_id).cloned();
256                let deliveries = matching_deliveries(&relay, &event);
257                (sender, deliveries)
258            };
259            if let Some(sender) = sender {
260                let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
261            }
262
263            for (target, payload) in deliveries {
264                let _ = target.send(payload);
265            }
266        }
267        _ => {}
268    }
269}
270
271fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
272    let deliveries = {
273        let relay = state.lock().expect("relay state lock");
274        relay
275            .events_by_id
276            .values()
277            .flat_map(|event| matching_deliveries(&relay, event))
278            .collect::<Vec<_>>()
279    };
280
281    for (target, payload) in deliveries {
282        let _ = target.send(payload);
283    }
284}
285
286fn matching_deliveries(
287    relay: &RelayState,
288    event: &Value,
289) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
290    let mut deliveries = Vec::new();
291    for (client_id, subscriptions) in &relay.subscriptions {
292        let Some(target) = relay.clients.get(client_id).cloned() else {
293            continue;
294        };
295        for (subscription_id, filters) in subscriptions {
296            if matches_any_filter(event, filters) {
297                deliveries.push((
298                    target.clone(),
299                    Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
300                ));
301            }
302        }
303    }
304    deliveries
305}
306
307pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
308    if filters.is_empty() {
309        return true;
310    }
311
312    filters.iter().any(|filter| matches_filter(event, filter))
313}
314
315pub fn matches_filter(event: &Value, filter: &Value) -> bool {
316    let Some(filter_object) = filter.as_object() else {
317        return false;
318    };
319
320    if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
321        let Some(event_id) = event.get("id").and_then(Value::as_str) else {
322            return false;
323        };
324        if !ids
325            .iter()
326            .filter_map(Value::as_str)
327            .any(|id| id == event_id)
328        {
329            return false;
330        }
331    }
332
333    if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
334        let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
335            return false;
336        };
337        if !authors
338            .iter()
339            .filter_map(Value::as_str)
340            .any(|author| author == pubkey)
341        {
342            return false;
343        }
344    }
345
346    if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
347        let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
348            return false;
349        };
350        if !kinds
351            .iter()
352            .filter_map(Value::as_u64)
353            .any(|value| value == kind)
354        {
355            return false;
356        }
357    }
358
359    if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
360        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
361            return false;
362        };
363        if created_at < since {
364            return false;
365        }
366    }
367
368    if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
369        let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
370            return false;
371        };
372        if created_at > until {
373            return false;
374        }
375    }
376
377    for (key, value) in filter_object {
378        let Some(tag_name) = key.strip_prefix('#') else {
379            continue;
380        };
381
382        let Some(expected_values) = value.as_array() else {
383            return false;
384        };
385        if expected_values.is_empty() {
386            continue;
387        }
388
389        let Some(tags) = event.get("tags").and_then(Value::as_array) else {
390            return false;
391        };
392        let matched = tags.iter().any(|tag| {
393            let Some(tag_values) = tag.as_array() else {
394                return false;
395            };
396            if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
397                return false;
398            }
399            tag_values
400                .iter()
401                .skip(1)
402                .filter_map(Value::as_str)
403                .any(|tag_value| {
404                    expected_values
405                        .iter()
406                        .filter_map(Value::as_str)
407                        .any(|expected| expected == tag_value)
408                })
409        });
410        if !matched {
411            return false;
412        }
413    }
414
415    true
416}