iris_chat_core/
local_relay.rs1use std::collections::{BTreeMap, HashMap};
2use std::sync::mpsc as std_mpsc;
3use std::sync::{Arc, Mutex};
4use std::thread;
5use std::time::Duration as StdDuration;
6
7use anyhow::{Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use serde_json::{json, Value};
10use tokio::net::TcpListener;
11use tokio::sync::mpsc;
12use tokio_tungstenite::accept_async;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Default)]
16struct RelayState {
17 events_by_id: BTreeMap<String, Value>,
18 subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19 clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
22enum RelayControl {
23 ReplayStored,
24 Snapshot(std_mpsc::Sender<Vec<Value>>),
25 Shutdown,
26}
27
28pub struct TestRelay {
29 control_tx: mpsc::UnboundedSender<RelayControl>,
30 join: Option<thread::JoinHandle<()>>,
31 url: String,
32}
33
34impl TestRelay {
35 pub fn start() -> Self {
36 Self::start_with_bind("127.0.0.1:0").expect("start relay")
37 }
38
39 pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
40 let (control_tx, mut control_rx) = mpsc::unbounded_channel();
41 let (ready_tx, ready_rx) = std_mpsc::channel();
42 let bind_addr = bind_addr.to_string();
43
44 let join = thread::spawn(move || {
45 let runtime = tokio::runtime::Builder::new_multi_thread()
46 .enable_all()
47 .build()
48 .expect("relay runtime");
49
50 runtime.block_on(async move {
51 let listener = TcpListener::bind(&bind_addr)
52 .await
53 .with_context(|| format!("bind relay listener {bind_addr}"))
54 .expect("bind relay listener");
55 let local_addr = listener.local_addr().expect("relay local addr");
56 let state = Arc::new(Mutex::new(RelayState::default()));
57 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
58 ready_tx
59 .send(format!("ws://{local_addr}"))
60 .expect("signal relay ready");
61
62 loop {
63 tokio::select! {
64 Some(control) = control_rx.recv() => {
65 match control {
66 RelayControl::ReplayStored => replay_stored_events(&state),
67 RelayControl::Snapshot(reply_tx) => {
68 let events = state
69 .lock()
70 .expect("relay state lock")
71 .events_by_id
72 .values()
73 .cloned()
74 .collect::<Vec<_>>();
75 let _ = reply_tx.send(events);
76 }
77 RelayControl::Shutdown => break,
78 }
79 }
80 accept_result = listener.accept() => {
81 let (stream, _) = accept_result.expect("accept relay client");
82 let websocket = accept_async(stream).await.expect("accept websocket");
83 let state = state.clone();
84 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
85 tokio::spawn(async move {
86 handle_connection(client_id, websocket, state).await;
87 });
88 }
89 }
90 }
91 });
92 });
93
94 let url = ready_rx
95 .recv_timeout(StdDuration::from_secs(5))
96 .context("relay ready")?;
97
98 Ok(Self {
99 control_tx,
100 join: Some(join),
101 url,
102 })
103 }
104
105 pub fn url(&self) -> &str {
106 &self.url
107 }
108
109 pub fn replay_stored(&self) {
110 let _ = self.control_tx.send(RelayControl::ReplayStored);
111 }
112
113 pub fn events(&self) -> Vec<Value> {
114 let (reply_tx, reply_rx) = std_mpsc::channel();
115 let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
116 reply_rx
117 .recv_timeout(StdDuration::from_secs(5))
118 .unwrap_or_default()
119 }
120}
121
122impl Drop for TestRelay {
123 fn drop(&mut self) {
124 let _ = self.control_tx.send(RelayControl::Shutdown);
125 if let Some(join) = self.join.take() {
126 let _ = join.join();
127 }
128 }
129}
130
131pub fn run_forever(bind_addr: &str) -> Result<()> {
132 let runtime = tokio::runtime::Builder::new_multi_thread()
133 .enable_all()
134 .build()
135 .context("relay runtime")?;
136 let bind_addr = bind_addr.to_string();
137
138 runtime.block_on(async move {
139 let listener = TcpListener::bind(&bind_addr)
140 .await
141 .with_context(|| format!("bind relay listener {bind_addr}"))?;
142 let state = Arc::new(Mutex::new(RelayState::default()));
143 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
144
145 println!("Local Nostr relay listening on ws://{bind_addr}");
146
147 loop {
148 let (stream, _) = listener
149 .accept()
150 .await
151 .with_context(|| format!("accept relay client on {bind_addr}"))?;
152 let websocket = match accept_async(stream).await {
153 Ok(websocket) => websocket,
154 Err(error) => {
155 eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
156 continue;
157 }
158 };
159 let state = state.clone();
160 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
161 tokio::spawn(async move {
162 handle_connection(client_id, websocket, state).await;
163 });
164 }
165 })
166}
167
168async fn handle_connection(
169 client_id: usize,
170 websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
171 state: Arc<Mutex<RelayState>>,
172) {
173 let (mut sink, mut stream) = websocket.split();
174 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
175
176 {
177 let mut relay = state.lock().expect("relay state lock");
178 relay.clients.insert(client_id, client_tx);
179 }
180
181 let writer = tokio::spawn(async move {
182 while let Some(message) = client_rx.recv().await {
183 if sink.send(message).await.is_err() {
184 break;
185 }
186 }
187 });
188
189 while let Some(message) = stream.next().await {
190 let Ok(message) = message else {
191 break;
192 };
193 match message {
194 Message::Text(text) => handle_client_message(client_id, &text, &state),
195 Message::Ping(payload) => {
196 let sender = {
197 let relay = state.lock().expect("relay state lock");
198 relay.clients.get(&client_id).cloned()
199 };
200 if let Some(sender) = sender {
201 let _ = sender.send(Message::Pong(payload));
202 }
203 }
204 Message::Close(_) => break,
205 _ => {}
206 }
207 }
208
209 {
210 let mut relay = state.lock().expect("relay state lock");
211 relay.clients.remove(&client_id);
212 relay.subscriptions.remove(&client_id);
213 }
214
215 writer.abort();
216}
217
218fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
219 let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
220 return;
221 };
222 let Some(parts) = message.as_array() else {
223 return;
224 };
225 let Some(kind) = parts.first().and_then(Value::as_str) else {
226 return;
227 };
228
229 match kind {
230 "REQ" if parts.len() >= 2 => {
231 let Some(subscription_id) = parts[1].as_str() else {
232 return;
233 };
234 let filters: Vec<Value> = parts
235 .iter()
236 .skip(2)
237 .filter(|value| value.is_object())
238 .cloned()
239 .collect();
240 let (sender, events) = {
241 let mut relay = state.lock().expect("relay state lock");
242 relay
243 .subscriptions
244 .entry(client_id)
245 .or_default()
246 .insert(subscription_id.to_string(), filters.clone());
247 (
248 relay.clients.get(&client_id).cloned(),
249 relay.events_by_id.values().cloned().collect::<Vec<_>>(),
250 )
251 };
252
253 if let Some(sender) = sender {
254 for event in events {
255 if matches_any_filter(&event, &filters) {
256 let payload =
257 Message::Text(json!(["EVENT", subscription_id, event]).to_string());
258 let _ = sender.send(payload);
259 }
260 }
261 let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
262 }
263 }
264 "CLOSE" if parts.len() >= 2 => {
265 let Some(subscription_id) = parts[1].as_str() else {
266 return;
267 };
268 let mut relay = state.lock().expect("relay state lock");
269 if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
270 subscriptions.remove(subscription_id);
271 }
272 }
273 "EVENT" if parts.len() >= 2 && parts[1].is_object() => {
274 let event = parts[1].clone();
275 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
276 return;
277 };
278 let (sender, deliveries) = {
279 let mut relay = state.lock().expect("relay state lock");
280 relay
281 .events_by_id
282 .insert(event_id.to_string(), event.clone());
283 let sender = relay.clients.get(&client_id).cloned();
284 let deliveries = matching_deliveries(&relay, &event);
285 (sender, deliveries)
286 };
287 if let Some(sender) = sender {
288 let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
289 }
290
291 for (target, payload) in deliveries {
292 let _ = target.send(payload);
293 }
294 }
295 _ => {}
296 }
297}
298
299fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
300 let deliveries = {
301 let relay = state.lock().expect("relay state lock");
302 relay
303 .events_by_id
304 .values()
305 .flat_map(|event| matching_deliveries(&relay, event))
306 .collect::<Vec<_>>()
307 };
308
309 for (target, payload) in deliveries {
310 let _ = target.send(payload);
311 }
312}
313
314fn matching_deliveries(
315 relay: &RelayState,
316 event: &Value,
317) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
318 let mut deliveries = Vec::new();
319 for (client_id, subscriptions) in &relay.subscriptions {
320 let Some(target) = relay.clients.get(client_id).cloned() else {
321 continue;
322 };
323 for (subscription_id, filters) in subscriptions {
324 if matches_any_filter(event, filters) {
325 deliveries.push((
326 target.clone(),
327 Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
328 ));
329 }
330 }
331 }
332 deliveries
333}
334
335pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
336 if filters.is_empty() {
337 return true;
338 }
339
340 filters.iter().any(|filter| matches_filter(event, filter))
341}
342
343pub fn matches_filter(event: &Value, filter: &Value) -> bool {
344 let Some(filter_object) = filter.as_object() else {
345 return false;
346 };
347
348 if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
349 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
350 return false;
351 };
352 if !ids
353 .iter()
354 .filter_map(Value::as_str)
355 .any(|id| id == event_id)
356 {
357 return false;
358 }
359 }
360
361 if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
362 let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
363 return false;
364 };
365 if !authors
366 .iter()
367 .filter_map(Value::as_str)
368 .any(|author| author == pubkey)
369 {
370 return false;
371 }
372 }
373
374 if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
375 let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
376 return false;
377 };
378 if !kinds
379 .iter()
380 .filter_map(Value::as_u64)
381 .any(|value| value == kind)
382 {
383 return false;
384 }
385 }
386
387 if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
388 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
389 return false;
390 };
391 if created_at < since {
392 return false;
393 }
394 }
395
396 if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
397 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
398 return false;
399 };
400 if created_at > until {
401 return false;
402 }
403 }
404
405 for (key, value) in filter_object {
406 let Some(tag_name) = key.strip_prefix('#') else {
407 continue;
408 };
409
410 let Some(expected_values) = value.as_array() else {
411 return false;
412 };
413 if expected_values.is_empty() {
414 continue;
415 }
416
417 let Some(tags) = event.get("tags").and_then(Value::as_array) else {
418 return false;
419 };
420 let matched = tags.iter().any(|tag| {
421 let Some(tag_values) = tag.as_array() else {
422 return false;
423 };
424 if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
425 return false;
426 }
427 tag_values
428 .iter()
429 .skip(1)
430 .filter_map(Value::as_str)
431 .any(|tag_value| {
432 expected_values
433 .iter()
434 .filter_map(Value::as_str)
435 .any(|expected| expected == tag_value)
436 })
437 });
438 if !matched {
439 return false;
440 }
441 }
442
443 true
444}