// iris_chat_core/local_relay.rs
use std::collections::{BTreeMap, HashMap};
use std::sync::mpsc as std_mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration as StdDuration;

use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;

15#[derive(Default)]
16struct RelayState {
17 events_by_id: BTreeMap<String, Value>,
18 subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19 clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
22enum RelayControl {
23 ReplayStored,
24 Snapshot(std_mpsc::Sender<Vec<Value>>),
25 Shutdown,
26}
27
28pub struct TestRelay {
29 control_tx: mpsc::UnboundedSender<RelayControl>,
30 join: Option<thread::JoinHandle<()>>,
31 url: String,
32}
33
34impl TestRelay {
35 pub fn start() -> Self {
36 Self::start_with_bind("127.0.0.1:0").expect("start relay")
37 }
38
39 pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
40 let (control_tx, mut control_rx) = mpsc::unbounded_channel();
41 let (ready_tx, ready_rx) = std_mpsc::channel();
42 let bind_addr = bind_addr.to_string();
43
44 let join = thread::spawn(move || {
45 let runtime = tokio::runtime::Builder::new_multi_thread()
46 .enable_all()
47 .build()
48 .expect("relay runtime");
49
50 runtime.block_on(async move {
51 let listener = TcpListener::bind(&bind_addr)
52 .await
53 .with_context(|| format!("bind relay listener {bind_addr}"))
54 .expect("bind relay listener");
55 let local_addr = listener.local_addr().expect("relay local addr");
56 let state = Arc::new(Mutex::new(RelayState::default()));
57 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
58 ready_tx
59 .send(format!("ws://{local_addr}"))
60 .expect("signal relay ready");
61
62 loop {
63 tokio::select! {
64 Some(control) = control_rx.recv() => {
65 match control {
66 RelayControl::ReplayStored => replay_stored_events(&state),
67 RelayControl::Snapshot(reply_tx) => {
68 let events = state
69 .lock()
70 .expect("relay state lock")
71 .events_by_id
72 .values()
73 .cloned()
74 .collect::<Vec<_>>();
75 let _ = reply_tx.send(events);
76 }
77 RelayControl::Shutdown => break,
78 }
79 }
80 accept_result = listener.accept() => {
81 let (stream, _) = accept_result.expect("accept relay client");
82 let websocket = match accept_async(stream).await {
83 Ok(websocket) => websocket,
84 Err(error) => {
85 eprintln!("Ignoring failed test relay websocket handshake: {error}");
86 continue;
87 }
88 };
89 let state = state.clone();
90 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
91 tokio::spawn(async move {
92 handle_connection(client_id, websocket, state).await;
93 });
94 }
95 }
96 }
97 });
98 });
99
100 let url = ready_rx
101 .recv_timeout(StdDuration::from_secs(5))
102 .context("relay ready")?;
103
104 Ok(Self {
105 control_tx,
106 join: Some(join),
107 url,
108 })
109 }
110
111 pub fn url(&self) -> &str {
112 &self.url
113 }
114
115 pub fn replay_stored(&self) {
116 let _ = self.control_tx.send(RelayControl::ReplayStored);
117 }
118
119 pub fn events(&self) -> Vec<Value> {
120 let (reply_tx, reply_rx) = std_mpsc::channel();
121 let _ = self.control_tx.send(RelayControl::Snapshot(reply_tx));
122 reply_rx
123 .recv_timeout(StdDuration::from_secs(5))
124 .unwrap_or_default()
125 }
126}
127
128impl Drop for TestRelay {
129 fn drop(&mut self) {
130 let _ = self.control_tx.send(RelayControl::Shutdown);
131 if let Some(join) = self.join.take() {
132 let _ = join.join();
133 }
134 }
135}
136
137pub fn run_forever(bind_addr: &str) -> Result<()> {
138 let runtime = tokio::runtime::Builder::new_multi_thread()
139 .enable_all()
140 .build()
141 .context("relay runtime")?;
142 let bind_addr = bind_addr.to_string();
143
144 runtime.block_on(async move {
145 let listener = TcpListener::bind(&bind_addr)
146 .await
147 .with_context(|| format!("bind relay listener {bind_addr}"))?;
148 let state = Arc::new(Mutex::new(RelayState::default()));
149 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
150
151 println!("Local Nostr relay listening on ws://{bind_addr}");
152
153 loop {
154 let (stream, _) = listener
155 .accept()
156 .await
157 .with_context(|| format!("accept relay client on {bind_addr}"))?;
158 let websocket = match accept_async(stream).await {
159 Ok(websocket) => websocket,
160 Err(error) => {
161 eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
162 continue;
163 }
164 };
165 let state = state.clone();
166 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
167 tokio::spawn(async move {
168 handle_connection(client_id, websocket, state).await;
169 });
170 }
171 })
172}
173
174async fn handle_connection(
175 client_id: usize,
176 websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
177 state: Arc<Mutex<RelayState>>,
178) {
179 let (mut sink, mut stream) = websocket.split();
180 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
181
182 {
183 let mut relay = state.lock().expect("relay state lock");
184 relay.clients.insert(client_id, client_tx);
185 }
186
187 let writer = tokio::spawn(async move {
188 while let Some(message) = client_rx.recv().await {
189 if sink.send(message).await.is_err() {
190 break;
191 }
192 }
193 });
194
195 while let Some(message) = stream.next().await {
196 let Ok(message) = message else {
197 break;
198 };
199 match message {
200 Message::Text(text) => handle_client_message(client_id, &text, &state),
201 Message::Ping(payload) => {
202 let sender = {
203 let relay = state.lock().expect("relay state lock");
204 relay.clients.get(&client_id).cloned()
205 };
206 if let Some(sender) = sender {
207 let _ = sender.send(Message::Pong(payload));
208 }
209 }
210 Message::Close(_) => break,
211 _ => {}
212 }
213 }
214
215 {
216 let mut relay = state.lock().expect("relay state lock");
217 relay.clients.remove(&client_id);
218 relay.subscriptions.remove(&client_id);
219 }
220
221 writer.abort();
222}
223
224fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
225 let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
226 return;
227 };
228 let Some(parts) = message.as_array() else {
229 return;
230 };
231 let Some(kind) = parts.first().and_then(Value::as_str) else {
232 return;
233 };
234
235 match kind {
236 "REQ" if parts.len() >= 2 => {
237 let Some(subscription_id) = parts[1].as_str() else {
238 return;
239 };
240 let filters: Vec<Value> = parts
241 .iter()
242 .skip(2)
243 .filter(|value| value.is_object())
244 .cloned()
245 .collect();
246 let (sender, events) = {
247 let mut relay = state.lock().expect("relay state lock");
248 relay
249 .subscriptions
250 .entry(client_id)
251 .or_default()
252 .insert(subscription_id.to_string(), filters.clone());
253 (
254 relay.clients.get(&client_id).cloned(),
255 relay.events_by_id.values().cloned().collect::<Vec<_>>(),
256 )
257 };
258
259 if let Some(sender) = sender {
260 for event in events {
261 if matches_any_filter(&event, &filters) {
262 let payload =
263 Message::Text(json!(["EVENT", subscription_id, event]).to_string());
264 let _ = sender.send(payload);
265 }
266 }
267 let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
268 }
269 }
270 "CLOSE" if parts.len() >= 2 => {
271 let Some(subscription_id) = parts[1].as_str() else {
272 return;
273 };
274 let mut relay = state.lock().expect("relay state lock");
275 if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
276 subscriptions.remove(subscription_id);
277 }
278 }
279 "EVENT" if parts.len() >= 2 && parts[1].is_object() => {
280 let event = parts[1].clone();
281 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
282 return;
283 };
284 let (sender, deliveries) = {
285 let mut relay = state.lock().expect("relay state lock");
286 relay
287 .events_by_id
288 .insert(event_id.to_string(), event.clone());
289 let sender = relay.clients.get(&client_id).cloned();
290 let deliveries = matching_deliveries(&relay, &event);
291 (sender, deliveries)
292 };
293 if let Some(sender) = sender {
294 let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
295 }
296
297 for (target, payload) in deliveries {
298 let _ = target.send(payload);
299 }
300 }
301 _ => {}
302 }
303}
304
305fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
306 let deliveries = {
307 let relay = state.lock().expect("relay state lock");
308 relay
309 .events_by_id
310 .values()
311 .flat_map(|event| matching_deliveries(&relay, event))
312 .collect::<Vec<_>>()
313 };
314
315 for (target, payload) in deliveries {
316 let _ = target.send(payload);
317 }
318}
319
320fn matching_deliveries(
321 relay: &RelayState,
322 event: &Value,
323) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
324 let mut deliveries = Vec::new();
325 for (client_id, subscriptions) in &relay.subscriptions {
326 let Some(target) = relay.clients.get(client_id).cloned() else {
327 continue;
328 };
329 for (subscription_id, filters) in subscriptions {
330 if matches_any_filter(event, filters) {
331 deliveries.push((
332 target.clone(),
333 Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
334 ));
335 }
336 }
337 }
338 deliveries
339}
340
341pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
342 if filters.is_empty() {
343 return true;
344 }
345
346 filters.iter().any(|filter| matches_filter(event, filter))
347}
348
349pub fn matches_filter(event: &Value, filter: &Value) -> bool {
350 let Some(filter_object) = filter.as_object() else {
351 return false;
352 };
353
354 if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
355 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
356 return false;
357 };
358 if !ids
359 .iter()
360 .filter_map(Value::as_str)
361 .any(|id| id == event_id)
362 {
363 return false;
364 }
365 }
366
367 if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
368 let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
369 return false;
370 };
371 if !authors
372 .iter()
373 .filter_map(Value::as_str)
374 .any(|author| author == pubkey)
375 {
376 return false;
377 }
378 }
379
380 if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
381 let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
382 return false;
383 };
384 if !kinds
385 .iter()
386 .filter_map(Value::as_u64)
387 .any(|value| value == kind)
388 {
389 return false;
390 }
391 }
392
393 if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
394 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
395 return false;
396 };
397 if created_at < since {
398 return false;
399 }
400 }
401
402 if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
403 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
404 return false;
405 };
406 if created_at > until {
407 return false;
408 }
409 }
410
411 for (key, value) in filter_object {
412 let Some(tag_name) = key.strip_prefix('#') else {
413 continue;
414 };
415
416 let Some(expected_values) = value.as_array() else {
417 return false;
418 };
419 if expected_values.is_empty() {
420 continue;
421 }
422
423 let Some(tags) = event.get("tags").and_then(Value::as_array) else {
424 return false;
425 };
426 let matched = tags.iter().any(|tag| {
427 let Some(tag_values) = tag.as_array() else {
428 return false;
429 };
430 if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
431 return false;
432 }
433 tag_values
434 .iter()
435 .skip(1)
436 .filter_map(Value::as_str)
437 .any(|tag_value| {
438 expected_values
439 .iter()
440 .filter_map(Value::as_str)
441 .any(|expected| expected == tag_value)
442 })
443 });
444 if !matched {
445 return false;
446 }
447 }
448
449 true
450}