iris_chat_core/
use std::collections::{BTreeMap, HashMap};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc as std_mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration as StdDuration;

use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Default)]
16struct RelayState {
17 events_by_id: BTreeMap<String, Value>,
18 subscriptions: HashMap<usize, HashMap<String, Vec<Value>>>,
19 clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
20}
21
/// Commands a `TestRelay` handle sends to its background relay loop.
enum RelayControl {
    /// Re-deliver every stored event to all currently matching subscriptions.
    ReplayStored,
    /// Leave the relay loop so the relay thread can finish.
    Shutdown,
}
26
27pub struct TestRelay {
28 control_tx: mpsc::UnboundedSender<RelayControl>,
29 join: Option<thread::JoinHandle<()>>,
30}
31
32impl TestRelay {
33 pub fn start() -> Self {
34 Self::start_with_bind("127.0.0.1:4848").expect("start relay")
35 }
36
37 pub fn start_with_bind(bind_addr: &str) -> Result<Self> {
38 let (control_tx, mut control_rx) = mpsc::unbounded_channel();
39 let (ready_tx, ready_rx) = std_mpsc::channel();
40 let bind_addr = bind_addr.to_string();
41
42 let join = thread::spawn(move || {
43 let runtime = tokio::runtime::Builder::new_multi_thread()
44 .enable_all()
45 .build()
46 .expect("relay runtime");
47
48 runtime.block_on(async move {
49 let listener = TcpListener::bind(&bind_addr)
50 .await
51 .with_context(|| format!("bind relay listener {bind_addr}"))
52 .expect("bind relay listener");
53 let state = Arc::new(Mutex::new(RelayState::default()));
54 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
55 ready_tx.send(()).expect("signal relay ready");
56
57 loop {
58 tokio::select! {
59 Some(control) = control_rx.recv() => {
60 match control {
61 RelayControl::ReplayStored => replay_stored_events(&state),
62 RelayControl::Shutdown => break,
63 }
64 }
65 accept_result = listener.accept() => {
66 let (stream, _) = accept_result.expect("accept relay client");
67 let websocket = accept_async(stream).await.expect("accept websocket");
68 let state = state.clone();
69 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
70 tokio::spawn(async move {
71 handle_connection(client_id, websocket, state).await;
72 });
73 }
74 }
75 }
76 });
77 });
78
79 ready_rx
80 .recv_timeout(StdDuration::from_secs(5))
81 .context("relay ready")?;
82
83 Ok(Self {
84 control_tx,
85 join: Some(join),
86 })
87 }
88
89 pub fn replay_stored(&self) {
90 let _ = self.control_tx.send(RelayControl::ReplayStored);
91 }
92}
93
94impl Drop for TestRelay {
95 fn drop(&mut self) {
96 let _ = self.control_tx.send(RelayControl::Shutdown);
97 if let Some(join) = self.join.take() {
98 let _ = join.join();
99 }
100 }
101}
102
103pub fn run_forever(bind_addr: &str) -> Result<()> {
104 let runtime = tokio::runtime::Builder::new_multi_thread()
105 .enable_all()
106 .build()
107 .context("relay runtime")?;
108 let bind_addr = bind_addr.to_string();
109
110 runtime.block_on(async move {
111 let listener = TcpListener::bind(&bind_addr)
112 .await
113 .with_context(|| format!("bind relay listener {bind_addr}"))?;
114 let state = Arc::new(Mutex::new(RelayState::default()));
115 let next_client_id = Arc::new(std::sync::atomic::AtomicUsize::new(1));
116
117 println!("Local Nostr relay listening on ws://{bind_addr}");
118
119 loop {
120 let (stream, _) = listener
121 .accept()
122 .await
123 .with_context(|| format!("accept relay client on {bind_addr}"))?;
124 let websocket = match accept_async(stream).await {
125 Ok(websocket) => websocket,
126 Err(error) => {
127 eprintln!("Ignoring failed websocket handshake on {bind_addr}: {error}");
128 continue;
129 }
130 };
131 let state = state.clone();
132 let client_id = next_client_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
133 tokio::spawn(async move {
134 handle_connection(client_id, websocket, state).await;
135 });
136 }
137 })
138}
139
140async fn handle_connection(
141 client_id: usize,
142 websocket: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
143 state: Arc<Mutex<RelayState>>,
144) {
145 let (mut sink, mut stream) = websocket.split();
146 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Message>();
147
148 {
149 let mut relay = state.lock().expect("relay state lock");
150 relay.clients.insert(client_id, client_tx);
151 }
152
153 let writer = tokio::spawn(async move {
154 while let Some(message) = client_rx.recv().await {
155 if sink.send(message).await.is_err() {
156 break;
157 }
158 }
159 });
160
161 while let Some(message) = stream.next().await {
162 let Ok(message) = message else {
163 break;
164 };
165 match message {
166 Message::Text(text) => handle_client_message(client_id, &text, &state),
167 Message::Ping(payload) => {
168 let sender = {
169 let relay = state.lock().expect("relay state lock");
170 relay.clients.get(&client_id).cloned()
171 };
172 if let Some(sender) = sender {
173 let _ = sender.send(Message::Pong(payload));
174 }
175 }
176 Message::Close(_) => break,
177 _ => {}
178 }
179 }
180
181 {
182 let mut relay = state.lock().expect("relay state lock");
183 relay.clients.remove(&client_id);
184 relay.subscriptions.remove(&client_id);
185 }
186
187 writer.abort();
188}
189
190fn handle_client_message(client_id: usize, raw_message: &str, state: &Arc<Mutex<RelayState>>) {
191 let Ok(message) = serde_json::from_str::<Value>(raw_message) else {
192 return;
193 };
194 let Some(parts) = message.as_array() else {
195 return;
196 };
197 let Some(kind) = parts.first().and_then(Value::as_str) else {
198 return;
199 };
200
201 match kind {
202 "REQ" if parts.len() >= 2 => {
203 let Some(subscription_id) = parts[1].as_str() else {
204 return;
205 };
206 let filters: Vec<Value> = parts
207 .iter()
208 .skip(2)
209 .filter(|value| value.is_object())
210 .cloned()
211 .collect();
212 let (sender, events) = {
213 let mut relay = state.lock().expect("relay state lock");
214 relay
215 .subscriptions
216 .entry(client_id)
217 .or_default()
218 .insert(subscription_id.to_string(), filters.clone());
219 (
220 relay.clients.get(&client_id).cloned(),
221 relay.events_by_id.values().cloned().collect::<Vec<_>>(),
222 )
223 };
224
225 if let Some(sender) = sender {
226 for event in events {
227 if matches_any_filter(&event, &filters) {
228 let payload =
229 Message::Text(json!(["EVENT", subscription_id, event]).to_string());
230 let _ = sender.send(payload);
231 }
232 }
233 let _ = sender.send(Message::Text(json!(["EOSE", subscription_id]).to_string()));
234 }
235 }
236 "CLOSE" if parts.len() >= 2 => {
237 let Some(subscription_id) = parts[1].as_str() else {
238 return;
239 };
240 let mut relay = state.lock().expect("relay state lock");
241 if let Some(subscriptions) = relay.subscriptions.get_mut(&client_id) {
242 subscriptions.remove(subscription_id);
243 }
244 }
245 "EVENT" if parts.len() >= 2 && parts[1].is_object() => {
246 let event = parts[1].clone();
247 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
248 return;
249 };
250 let (sender, deliveries) = {
251 let mut relay = state.lock().expect("relay state lock");
252 relay
253 .events_by_id
254 .insert(event_id.to_string(), event.clone());
255 let sender = relay.clients.get(&client_id).cloned();
256 let deliveries = matching_deliveries(&relay, &event);
257 (sender, deliveries)
258 };
259 if let Some(sender) = sender {
260 let _ = sender.send(Message::Text(json!(["OK", event_id, true, ""]).to_string()));
261 }
262
263 for (target, payload) in deliveries {
264 let _ = target.send(payload);
265 }
266 }
267 _ => {}
268 }
269}
270
271fn replay_stored_events(state: &Arc<Mutex<RelayState>>) {
272 let deliveries = {
273 let relay = state.lock().expect("relay state lock");
274 relay
275 .events_by_id
276 .values()
277 .flat_map(|event| matching_deliveries(&relay, event))
278 .collect::<Vec<_>>()
279 };
280
281 for (target, payload) in deliveries {
282 let _ = target.send(payload);
283 }
284}
285
286fn matching_deliveries(
287 relay: &RelayState,
288 event: &Value,
289) -> Vec<(mpsc::UnboundedSender<Message>, Message)> {
290 let mut deliveries = Vec::new();
291 for (client_id, subscriptions) in &relay.subscriptions {
292 let Some(target) = relay.clients.get(client_id).cloned() else {
293 continue;
294 };
295 for (subscription_id, filters) in subscriptions {
296 if matches_any_filter(event, filters) {
297 deliveries.push((
298 target.clone(),
299 Message::Text(json!(["EVENT", subscription_id, event]).to_string()),
300 ));
301 }
302 }
303 }
304 deliveries
305}
306
307pub fn matches_any_filter(event: &Value, filters: &[Value]) -> bool {
308 if filters.is_empty() {
309 return true;
310 }
311
312 filters.iter().any(|filter| matches_filter(event, filter))
313}
314
315pub fn matches_filter(event: &Value, filter: &Value) -> bool {
316 let Some(filter_object) = filter.as_object() else {
317 return false;
318 };
319
320 if let Some(ids) = filter_object.get("ids").and_then(Value::as_array) {
321 let Some(event_id) = event.get("id").and_then(Value::as_str) else {
322 return false;
323 };
324 if !ids
325 .iter()
326 .filter_map(Value::as_str)
327 .any(|id| id == event_id)
328 {
329 return false;
330 }
331 }
332
333 if let Some(authors) = filter_object.get("authors").and_then(Value::as_array) {
334 let Some(pubkey) = event.get("pubkey").and_then(Value::as_str) else {
335 return false;
336 };
337 if !authors
338 .iter()
339 .filter_map(Value::as_str)
340 .any(|author| author == pubkey)
341 {
342 return false;
343 }
344 }
345
346 if let Some(kinds) = filter_object.get("kinds").and_then(Value::as_array) {
347 let Some(kind) = event.get("kind").and_then(Value::as_u64) else {
348 return false;
349 };
350 if !kinds
351 .iter()
352 .filter_map(Value::as_u64)
353 .any(|value| value == kind)
354 {
355 return false;
356 }
357 }
358
359 if let Some(since) = filter_object.get("since").and_then(Value::as_u64) {
360 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
361 return false;
362 };
363 if created_at < since {
364 return false;
365 }
366 }
367
368 if let Some(until) = filter_object.get("until").and_then(Value::as_u64) {
369 let Some(created_at) = event.get("created_at").and_then(Value::as_u64) else {
370 return false;
371 };
372 if created_at > until {
373 return false;
374 }
375 }
376
377 for (key, value) in filter_object {
378 let Some(tag_name) = key.strip_prefix('#') else {
379 continue;
380 };
381
382 let Some(expected_values) = value.as_array() else {
383 return false;
384 };
385 if expected_values.is_empty() {
386 continue;
387 }
388
389 let Some(tags) = event.get("tags").and_then(Value::as_array) else {
390 return false;
391 };
392 let matched = tags.iter().any(|tag| {
393 let Some(tag_values) = tag.as_array() else {
394 return false;
395 };
396 if tag_values.first().and_then(Value::as_str) != Some(tag_name) {
397 return false;
398 }
399 tag_values
400 .iter()
401 .skip(1)
402 .filter_map(Value::as_str)
403 .any(|tag_value| {
404 expected_values
405 .iter()
406 .filter_map(Value::as_str)
407 .any(|expected| expected == tag_value)
408 })
409 });
410 if !matched {
411 return false;
412 }
413 }
414
415 true
416}