gnostr_relay/
session.rs

1use crate::{hash::NoOpHasherDefault, message::*, App, Server};
2use actix::prelude::*;
3use actix_http::ws::Item;
4use actix_web::web;
5use actix_web_actors::ws;
6use bytes::BytesMut;
7use metrics::{counter, gauge};
8use std::{
9    any::{Any, TypeId},
10    collections::HashMap,
11    time::{Duration, Instant},
12};
13use tracing::{debug, info, warn};
14use ws::Message;
15
16pub struct Session {
17    ip: String,
18
19    /// unique session id
20    id: usize,
21
22    /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
23    /// otherwise we drop connection.
24    hb: Instant,
25
26    /// server
27    server: Addr<Server>,
28
29    /// heartbeat timeout
30    /// How long before lack of client response causes a timeout
31    heartbeat_timeout: Duration,
32
33    /// heartbeat interval
34    /// How often heartbeat pings are sent
35    heartbeat_interval: Duration,
36
37    pub app: web::Data<App>,
38
39    /// Simple store for save extension data
40    data: HashMap<TypeId, Box<dyn Any>, NoOpHasherDefault>,
41
42    /// Buffer for constructing continuation messages
43    cont: Option<BytesMut>,
44}
45
46impl Session {
47    /// save extension data
48    pub fn set<T: 'static>(&mut self, val: T) {
49        self.data.insert(TypeId::of::<T>(), Box::new(val));
50    }
51
52    /// get extension data
53    pub fn get<T: 'static>(&self) -> Option<&T> {
54        self.data
55            .get(&TypeId::of::<T>())
56            .and_then(|boxed| boxed.downcast_ref())
57    }
58
59    /// Get session id
60    pub fn id(&self) -> usize {
61        self.id
62    }
63
64    /// Get ip
65    pub fn ip(&self) -> &String {
66        &self.ip
67    }
68
69    pub fn new(ip: String, app: web::Data<App>) -> Session {
70        let setting = app.setting.read();
71        let heartbeat_timeout = setting.network.heartbeat_timeout.into();
72        let heartbeat_interval = setting.network.heartbeat_interval.into();
73        drop(setting);
74        Self {
75            id: 0,
76            ip,
77            hb: Instant::now(),
78            server: app.server.clone(),
79            heartbeat_timeout,
80            heartbeat_interval,
81            app,
82            data: HashMap::default(),
83            cont: None,
84        }
85    }
86
87    /// helper method that sends ping to client.
88    /// also this method checks heartbeats from client
89    fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
90        ctx.run_interval(self.heartbeat_interval, |act, ctx| {
91            // check client heartbeats
92            if Instant::now().duration_since(act.hb) > act.heartbeat_timeout {
93                // heartbeat timed out
94                // stop actor
95                counter!("nostr_relay_session_stop_total", "reason" => "heartbeat timeout")
96                    .increment(1);
97                ctx.stop();
98                // don't try to send a ping
99                return;
100            }
101
102            ctx.ping(b"");
103        });
104    }
105
106    fn send_error(
107        &self,
108        err: crate::Error,
109        msg: &ClientMessage,
110        ctx: &mut ws::WebsocketContext<Self>,
111    ) {
112        if let IncomingMessage::Event(event) = &msg.msg {
113            ctx.text(OutgoingMessage::ok(
114                &event.id_str(),
115                false,
116                &err.to_string(),
117            ));
118        } else if let IncomingMessage::Req(sub) = &msg.msg {
119            ctx.text(OutgoingMessage::closed(&sub.id, &err.to_string()));
120        } else {
121            ctx.text(OutgoingMessage::notice(&err.to_string()));
122        }
123    }
124
125    fn handle_message(&mut self, text: String, ctx: &mut ws::WebsocketContext<Self>) {
126        let msg = serde_json::from_str::<IncomingMessage>(&text);
127        match msg {
128            Ok(msg) => {
129                if let Some(cmd) = msg.known_command() {
130                    // only insert known command metrics
131                    counter!("nostr_relay_message_total", "command" => cmd).increment(1);
132                }
133
134                let mut msg = ClientMessage::new(self.id, text, msg);
135                {
136                    let r = self.app.setting.read();
137                    if let Err(err) = msg.validate(&r.limitation) {
138                        self.send_error(err, &msg, ctx);
139                        return;
140                    }
141                }
142
143                match self
144                    .app
145                    .clone()
146                    .extensions
147                    .read()
148                    .call_message(msg, self, ctx)
149                {
150                    crate::ExtensionMessageResult::Continue(msg) => {
151                        if let Err(err) = msg.validate_nip70() {
152                            self.send_error(err, &msg, ctx);
153                            return;
154                        }
155                        self.server.do_send(msg);
156                    }
157                    crate::ExtensionMessageResult::Stop(out) => {
158                        ctx.text(out);
159                    }
160                    crate::ExtensionMessageResult::Ignore => {
161                        // ignore
162                    }
163                };
164            }
165            Err(err) => {
166                ctx.text(OutgoingMessage::notice(&format!("json error: {}", err)));
167            }
168        };
169    }
170}
171
172/// Handle messages from server, we simply send it to peer websocket
173impl Handler<OutgoingMessage> for Session {
174    type Result = ();
175
176    fn handle(&mut self, msg: OutgoingMessage, ctx: &mut Self::Context) {
177        ctx.text(msg);
178    }
179}
180
181impl Actor for Session {
182    type Context = ws::WebsocketContext<Self>;
183
184    /// Method is called on actor start. We start the heartbeat process here.
185    fn started(&mut self, ctx: &mut Self::Context) {
186        counter!("nostr_relay_session_total").increment(1);
187        gauge!("nostr_relay_session").increment(1.0);
188
189        // we'll start heartbeat process on session start.
190        self.hb(ctx);
191        // register self in server.
192        let addr = ctx.address();
193        self.server
194            .send(Connect {
195                addr: addr.recipient(),
196            })
197            .into_actor(self)
198            .then(|res, act, ctx| {
199                match res {
200                    Ok(res) => {
201                        act.id = res;
202                        act.app.clone().extensions.read().call_connected(act, ctx);
203                        debug!("Session started {} {}", act.id, act.ip);
204                    }
205                    // something is wrong with server
206                    _ => {
207                        counter!("nostr_relay_session_stop_total", "reason" => "server error")
208                            .increment(1);
209                        ctx.stop()
210                    }
211                }
212                fut::ready(())
213            })
214            .wait(ctx);
215    }
216
217    fn stopping(&mut self, _: &mut Self::Context) -> Running {
218        // notify server
219        self.server.do_send(Disconnect { id: self.id });
220        Running::Stop
221    }
222
223    fn stopped(&mut self, ctx: &mut Self::Context) {
224        gauge!("nostr_relay_session").decrement(1.0);
225        self.app
226            .clone()
227            .extensions
228            .read()
229            .call_disconnected(self, ctx);
230        debug!("Session stopped {} {}", self.id, self.ip);
231    }
232}
233
234/// Handler for `ws::Message`
235impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for Session {
236    fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
237        // Text will log after processing
238        if !matches!(msg, Ok(Message::Text(_)) | Ok(Message::Continuation(_))) {
239            debug!("Session message {} {} {:?}", self.id, self.ip, msg);
240        }
241        let msg = match msg {
242            Err(err) => {
243                match err {
244                    ws::ProtocolError::Overflow => {
245                        ctx.text(OutgoingMessage::notice("payload reached size limit."));
246                    }
247                    _ => {
248                        warn!("Session error {} {} {:?}", self.id, self.ip, err);
249                        counter!("nostr_relay_session_stop_total", "reason" => "message error")
250                            .increment(1);
251                        ctx.stop();
252                    }
253                }
254                return;
255            }
256            Ok(msg) => msg,
257        };
258        match msg {
259            ws::Message::Ping(msg) => {
260                self.hb = Instant::now();
261                ctx.pong(&msg);
262            }
263            ws::Message::Pong(_) => {
264                self.hb = Instant::now();
265            }
266            ws::Message::Text(text) => {
267                let text = text.to_string();
268                info!(
269                    "Session text self.id={} self.ip={} text={}",
270                    self.id, self.ip, text
271                );
272                self.handle_message(text, ctx);
273            }
274            ws::Message::Close(reason) => {
275                ctx.close(reason);
276                counter!("nostr_relay_session_stop_total", "reason" => "message close")
277                    .increment(1);
278                ctx.stop();
279            }
280            ws::Message::Binary(_) => {
281                ctx.text(OutgoingMessage::notice("Not support binary message"));
282            }
283            ws::Message::Continuation(cont) => match cont {
284                Item::FirstText(buf) => {
285                    let mut bytes = BytesMut::new();
286                    bytes.extend_from_slice(&buf);
287                    self.cont = Some(bytes);
288                }
289                Item::FirstBinary(_) => {
290                    ctx.text(OutgoingMessage::notice("Not support binary message"));
291                }
292                Item::Continue(buf) => {
293                    if let Some(bytes) = &mut self.cont {
294                        bytes.extend_from_slice(&buf);
295                    }
296                }
297                Item::Last(buf) => {
298                    if let Some(mut bytes) = self.cont.take() {
299                        bytes.extend_from_slice(&buf);
300                        if let Ok(text) = String::from_utf8(bytes.to_vec()) {
301                            debug!("Session text {} {} {}", self.id, self.ip, text);
302                            self.handle_message(text, ctx);
303                        }
304                    }
305                }
306            },
307            ws::Message::Nop => (),
308        }
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{create_test_app, Extension, ExtensionMessageResult};
    use actix_rt::time::sleep;
    use actix_web_actors::ws;
    use anyhow::Result;
    use bytes::Bytes;
    use futures_util::{SinkExt as _, StreamExt as _};

    /// A client ping is answered with a pong carrying the same payload, and a
    /// normal close is echoed back.
    #[actix_rt::test]
    async fn pingpong() -> Result<()> {
        let mut server = actix_test::start(|| {
            let app = create_test_app("session").unwrap();
            app.web_app()
        });
        let mut client = server.ws_at("/").await.unwrap();

        client.send(ws::Message::Ping("text".into())).await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Pong(Bytes::from_static(b"text")));

        client
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
        Ok(())
    }

    /// With a 1s interval and a 20s timeout, a silent client keeps receiving
    /// empty pings.
    #[actix_rt::test]
    async fn heartbeat() -> Result<()> {
        let mut server = actix_test::start(|| {
            let app = create_test_app("session").unwrap();
            {
                let mut setting = app.setting.write();
                setting.network.heartbeat_interval = Duration::from_secs(1).try_into().unwrap();
                setting.network.heartbeat_timeout = Duration::from_secs(20).try_into().unwrap();
            }
            app.web_app()
        });
        let mut client = server.ws_at("/").await.unwrap();

        sleep(Duration::from_secs(3)).await;
        for _ in 0..2 {
            let frame = client.next().await.unwrap()?;
            assert_eq!(frame, ws::Frame::Ping(Bytes::new()));
        }

        client
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        Ok(())
    }

    /// With a 2s timeout, a silent client receives one ping and then the
    /// stream ends.
    #[actix_rt::test]
    async fn heartbeat_timeout() -> Result<()> {
        let mut server = actix_test::start(|| {
            let app = create_test_app("session").unwrap();
            {
                let mut setting = app.setting.write();
                setting.network.heartbeat_interval = Duration::from_secs(1).try_into().unwrap();
                setting.network.heartbeat_timeout = Duration::from_secs(2).try_into().unwrap();
            }
            app.web_app()
        });
        let mut client = server.ws_at("/").await.unwrap();

        sleep(Duration::from_secs(3)).await;
        let first = client.next().await.unwrap()?;
        assert_eq!(first, ws::Frame::Ping(Bytes::new()));
        assert!(client.next().await.is_none());
        Ok(())
    }

    /// Test extension that answers every client message with its raw text.
    struct Echo;

    impl Extension for Echo {
        fn message(
            &self,
            incoming: ClientMessage,
            _session: &mut Session,
            _ctx: &mut <Session as actix::Actor>::Context,
        ) -> ExtensionMessageResult {
            let raw = incoming.text;
            ExtensionMessageResult::Stop(OutgoingMessage(raw))
        }

        fn name(&self) -> &'static str {
            "Echo"
        }
    }

    /// Messages are routed through registered extensions.
    #[actix_rt::test]
    async fn extension() -> Result<()> {
        let req = r#"["REQ", "1", {}]"#;
        let mut server = actix_test::start(|| {
            let app = create_test_app("extension").unwrap();
            app.add_extension(Echo).web_app()
        });
        let mut client = server.ws_at("/").await.unwrap();

        client.send(ws::Message::Text(req.into())).await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Text(Bytes::copy_from_slice(req.as_bytes())));
        Ok(())
    }

    /// A frame within `max_message_length` is accepted; a longer one gets the
    /// size-limit notice.
    #[actix_rt::test]
    async fn max_size() -> Result<()> {
        let req = r#"["REQ", "1", {}]"#;
        let limit = req.len() + 1;
        let mut server = actix_test::start(move || {
            let app = create_test_app("max_size").unwrap();
            app.setting.write().limitation.max_message_length = limit;
            app.add_extension(Echo).web_app()
        });
        let mut client = server.ws_at("/").await.unwrap();

        client.send(ws::Message::Text(req.into())).await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Text(Bytes::copy_from_slice(req.as_bytes())));

        // Two extra bytes push the frame past the limit.
        let oversized = format!("{}  ", req);
        client.send(ws::Message::Text(oversized.into())).await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(
            reply,
            ws::Frame::Text(Bytes::from_static(
                br#"["NOTICE","payload reached size limit."]"#
            ))
        );
        Ok(())
    }

    /// A text message split into continuation frames is reassembled before
    /// being handled.
    #[actix_rt::test]
    async fn continuation() -> Result<()> {
        let req = br#"["REQ", "1", {}]"#;
        let mut server = actix_test::start(|| {
            let app = create_test_app("extension").unwrap();
            app.add_extension(Echo).web_app()
        });
        let mut client = server.ws_at("/").await.unwrap();

        let fragments = vec![
            Item::FirstText(Bytes::copy_from_slice(&req[0..2])),
            Item::Continue(Bytes::copy_from_slice(&req[2..4])),
            Item::Last(Bytes::copy_from_slice(&req[4..])),
        ];
        for fragment in fragments {
            client.send(ws::Message::Continuation(fragment)).await?;
        }

        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Text(Bytes::copy_from_slice(req)));
        Ok(())
    }
}