gnostr_relay/
session.rs

1use crate::{hash::NoOpHasherDefault, message::*, App, Server};
2use actix::prelude::*;
3use actix_http::ws::Item;
4use actix_web::web;
5use actix_web_actors::ws;
6use bytes::BytesMut;
7use metrics::{counter, gauge};
8use std::{
9    any::{Any, TypeId},
10    collections::HashMap,
11    time::{Duration, Instant},
12};
13use tracing::debug;
14use ws::Message;
15
16pub struct Session {
17    ip: String,
18
19    /// unique session id
20    id: usize,
21
22    /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
23    /// otherwise we drop connection.
24    hb: Instant,
25
26    /// server
27    server: Addr<Server>,
28
29    /// heartbeat timeout
30    /// How long before lack of client response causes a timeout
31    heartbeat_timeout: Duration,
32
33    /// heartbeat interval
34    /// How often heartbeat pings are sent
35    heartbeat_interval: Duration,
36
37    pub app: web::Data<App>,
38
39    /// Simple store for save extension data
40    data: HashMap<TypeId, Box<dyn Any>, NoOpHasherDefault>,
41
42    /// Buffer for constructing continuation messages
43    cont: Option<BytesMut>,
44}
45
46impl Session {
47    /// save extension data
48    pub fn set<T: 'static>(&mut self, val: T) {
49        self.data.insert(TypeId::of::<T>(), Box::new(val));
50    }
51
52    /// get extension data
53    pub fn get<T: 'static>(&self) -> Option<&T> {
54        self.data
55            .get(&TypeId::of::<T>())
56            .and_then(|boxed| boxed.downcast_ref())
57    }
58
59    /// Get session id
60    pub fn id(&self) -> usize {
61        self.id
62    }
63
64    /// Get ip
65    pub fn ip(&self) -> &String {
66        &self.ip
67    }
68
69    pub fn new(ip: String, app: web::Data<App>) -> Session {
70        let setting = app.setting.read();
71        let heartbeat_timeout = setting.network.heartbeat_timeout.into();
72        let heartbeat_interval = setting.network.heartbeat_interval.into();
73        drop(setting);
74        Self {
75            id: 0,
76            ip,
77            hb: Instant::now(),
78            server: app.server.clone(),
79            heartbeat_timeout,
80            heartbeat_interval,
81            app,
82            data: HashMap::default(),
83            cont: None,
84        }
85    }
86
87    /// helper method that sends ping to client.
88    /// also this method checks heartbeats from client
89    fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
90        ctx.run_interval(self.heartbeat_interval, |act, ctx| {
91            // check client heartbeats
92            if Instant::now().duration_since(act.hb) > act.heartbeat_timeout {
93                // heartbeat timed out
94                // stop actor
95                counter!("nostr_relay_session_stop_total", "reason" => "heartbeat timeout")
96                    .increment(1);
97                ctx.stop();
98                // don't try to send a ping
99                return;
100            }
101
102            ctx.ping(b"");
103        });
104    }
105
106    fn send_error(
107        &self,
108        err: crate::Error,
109        msg: &ClientMessage,
110        ctx: &mut ws::WebsocketContext<Self>,
111    ) {
112        if let IncomingMessage::Event(event) = &msg.msg {
113            ctx.text(OutgoingMessage::ok(
114                &event.id_str(),
115                false,
116                &err.to_string(),
117            ));
118        } else if let IncomingMessage::Req(sub) = &msg.msg {
119            ctx.text(OutgoingMessage::closed(&sub.id, &err.to_string()));
120        } else {
121            ctx.text(OutgoingMessage::notice(&err.to_string()));
122        }
123    }
124
125    fn handle_message(&mut self, text: String, ctx: &mut ws::WebsocketContext<Self>) {
126        let msg = serde_json::from_str::<IncomingMessage>(&text);
127        match msg {
128            Ok(msg) => {
129                if let Some(cmd) = msg.known_command() {
130                    // only insert known command metrics
131                    counter!("nostr_relay_message_total", "command" => cmd).increment(1);
132                }
133
134                let mut msg = ClientMessage::new(self.id, text, msg);
135                {
136                    let r = self.app.setting.read();
137                    if let Err(err) = msg.validate(&r.limitation) {
138                        self.send_error(err, &msg, ctx);
139                        return;
140                    }
141                }
142
143                match self
144                    .app
145                    .clone()
146                    .extensions
147                    .read()
148                    .call_message(msg, self, ctx)
149                {
150                    crate::ExtensionMessageResult::Continue(msg) => {
151                        if let Err(err) = msg.validate_nip70() {
152                            self.send_error(err, &msg, ctx);
153                            return;
154                        }
155                        self.server.do_send(msg);
156                    }
157                    crate::ExtensionMessageResult::Stop(out) => {
158                        ctx.text(out);
159                    }
160                    crate::ExtensionMessageResult::Ignore => {
161                        // ignore
162                    }
163                };
164            }
165            Err(err) => {
166                ctx.text(OutgoingMessage::notice(&format!("json error: {}", err)));
167            }
168        };
169    }
170}
171
172/// Handle messages from server, we simply send it to peer websocket
173impl Handler<OutgoingMessage> for Session {
174    type Result = ();
175
176    fn handle(&mut self, msg: OutgoingMessage, ctx: &mut Self::Context) {
177        ctx.text(msg);
178    }
179}
180
181impl Actor for Session {
182    type Context = ws::WebsocketContext<Self>;
183
184    /// Method is called on actor start. We start the heartbeat process here.
185    fn started(&mut self, ctx: &mut Self::Context) {
186        counter!("nostr_relay_session_total").increment(1);
187        gauge!("nostr_relay_session").increment(1.0);
188
189        // we'll start heartbeat process on session start.
190        self.hb(ctx);
191        // register self in server.
192        let addr = ctx.address();
193        self.server
194            .send(Connect {
195                addr: addr.recipient(),
196            })
197            .into_actor(self)
198            .then(|res, act, ctx| {
199                match res {
200                    Ok(res) => {
201                        act.id = res;
202                        act.app.clone().extensions.read().call_connected(act, ctx);
203                        debug!("Session started {} {}", act.id, act.ip);
204                    }
205                    // something is wrong with server
206                    _ => {
207                        counter!("nostr_relay_session_stop_total", "reason" => "server error")
208                            .increment(1);
209                        ctx.stop()
210                    }
211                }
212                fut::ready(())
213            })
214            .wait(ctx);
215    }
216
217    fn stopping(&mut self, _: &mut Self::Context) -> Running {
218        // notify server
219        self.server.do_send(Disconnect { id: self.id });
220        Running::Stop
221    }
222
223    fn stopped(&mut self, ctx: &mut Self::Context) {
224        gauge!("nostr_relay_session").decrement(1.0);
225        self.app
226            .clone()
227            .extensions
228            .read()
229            .call_disconnected(self, ctx);
230        debug!("Session stopped {} {}", self.id, self.ip);
231    }
232}
233
234/// Handler for `ws::Message`
235impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for Session {
236    fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
237        // Text will log after processing
238        if !matches!(msg, Ok(Message::Text(_)) | Ok(Message::Continuation(_))) {
239            debug!("Session message {} {} {:?}", self.id, self.ip, msg);
240        }
241        let msg = match msg {
242            Err(err) => {
243                match err {
244                    ws::ProtocolError::Overflow => {
245                        ctx.text(OutgoingMessage::notice("payload reached size limit."));
246                    }
247                    _ => {
248                        debug!("Session error {} {} {:?}", self.id, self.ip, err);
249                        counter!("nostr_relay_session_stop_total", "reason" => "message error")
250                            .increment(1);
251                        ctx.stop();
252                    }
253                }
254                return;
255            }
256            Ok(msg) => msg,
257        };
258        match msg {
259            ws::Message::Ping(msg) => {
260                self.hb = Instant::now();
261                ctx.pong(&msg);
262            }
263            ws::Message::Pong(_) => {
264                self.hb = Instant::now();
265            }
266            ws::Message::Text(text) => {
267                let text = text.to_string();
268                debug!("Session text {} {} {}", self.id, self.ip, text);
269                self.handle_message(text, ctx);
270            }
271            ws::Message::Close(reason) => {
272                ctx.close(reason);
273                counter!("nostr_relay_session_stop_total", "reason" => "message close")
274                    .increment(1);
275                ctx.stop();
276            }
277            ws::Message::Binary(_) => {
278                ctx.text(OutgoingMessage::notice("Not support binary message"));
279            }
280            ws::Message::Continuation(cont) => match cont {
281                Item::FirstText(buf) => {
282                    let mut bytes = BytesMut::new();
283                    bytes.extend_from_slice(&buf);
284                    self.cont = Some(bytes);
285                }
286                Item::FirstBinary(_) => {
287                    ctx.text(OutgoingMessage::notice("Not support binary message"));
288                }
289                Item::Continue(buf) => {
290                    if let Some(bytes) = &mut self.cont {
291                        bytes.extend_from_slice(&buf);
292                    }
293                }
294                Item::Last(buf) => {
295                    if let Some(mut bytes) = self.cont.take() {
296                        bytes.extend_from_slice(&buf);
297                        if let Ok(text) = String::from_utf8(bytes.to_vec()) {
298                            debug!("Session text {} {} {}", self.id, self.ip, text);
299                            self.handle_message(text, ctx);
300                        }
301                    }
302                }
303            },
304            ws::Message::Nop => (),
305        }
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{create_test_app, Extension, ExtensionMessageResult};
    use actix_rt::time::sleep;
    use actix_web_actors::ws;
    use anyhow::Result;
    use bytes::Bytes;
    use futures_util::{SinkExt as _, StreamExt as _};

    /// Text frame carrying exactly `payload`.
    fn text_frame(payload: &[u8]) -> ws::Frame {
        ws::Frame::Text(Bytes::copy_from_slice(payload))
    }

    /// The empty ping sent by the session's heartbeat timer.
    fn heartbeat_ping() -> ws::Frame {
        ws::Frame::Ping(Bytes::copy_from_slice(b""))
    }

    /// A client ping is answered with a matching pong, and a normal close is echoed.
    #[actix_rt::test]
    async fn pingpong() -> Result<()> {
        let mut srv = actix_test::start(|| {
            let data = create_test_app("session").unwrap();
            data.web_app()
        });
        let mut client = srv.ws_at("/").await.unwrap();

        client.send(ws::Message::Ping("text".into())).await?;
        let pong = client.next().await.unwrap()?;
        assert_eq!(pong, ws::Frame::Pong(Bytes::copy_from_slice(b"text")));

        client
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        let closed = client.next().await.unwrap()?;
        assert_eq!(closed, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
        Ok(())
    }

    /// With a long timeout, the server keeps pinging every interval.
    #[actix_rt::test]
    async fn heartbeat() -> Result<()> {
        let mut srv = actix_test::start(|| {
            let data = create_test_app("session").unwrap();
            {
                let mut setting = data.setting.write();
                setting.network.heartbeat_interval = Duration::from_secs(1).try_into().unwrap();
                setting.network.heartbeat_timeout = Duration::from_secs(20).try_into().unwrap();
            }
            data.web_app()
        });
        let mut client = srv.ws_at("/").await.unwrap();

        sleep(Duration::from_secs(3)).await;
        for _ in 0..2 {
            let frame = client.next().await.unwrap()?;
            assert_eq!(frame, heartbeat_ping());
        }

        client
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        Ok(())
    }

    /// A client that never answers is disconnected after the timeout.
    #[actix_rt::test]
    async fn heartbeat_timeout() -> Result<()> {
        let mut srv = actix_test::start(|| {
            let data = create_test_app("session").unwrap();
            {
                let mut setting = data.setting.write();
                setting.network.heartbeat_interval = Duration::from_secs(1).try_into().unwrap();
                setting.network.heartbeat_timeout = Duration::from_secs(2).try_into().unwrap();
            }
            data.web_app()
        });
        let mut client = srv.ws_at("/").await.unwrap();

        sleep(Duration::from_secs(3)).await;
        let first = client.next().await.unwrap()?;
        assert_eq!(first, heartbeat_ping());
        assert!(client.next().await.is_none());
        Ok(())
    }

    /// Extension that answers every client message with its own raw text.
    struct Echo;
    impl Extension for Echo {
        fn message(
            &self,
            msg: ClientMessage,
            _session: &mut Session,
            _ctx: &mut <Session as actix::Actor>::Context,
        ) -> ExtensionMessageResult {
            ExtensionMessageResult::Stop(OutgoingMessage(msg.text))
        }

        fn name(&self) -> &'static str {
            "Echo"
        }
    }

    /// Messages reach registered extensions, which can short-circuit the reply.
    #[actix_rt::test]
    async fn extension() -> Result<()> {
        const REQ: &str = r#"["REQ", "1", {}]"#;
        let mut srv = actix_test::start(|| {
            let data = create_test_app("extension").unwrap();
            data.add_extension(Echo).web_app()
        });
        let mut client = srv.ws_at("/").await.unwrap();

        client.send(ws::Message::Text(REQ.into())).await?;
        let echoed = client.next().await.unwrap()?;
        assert_eq!(echoed, text_frame(REQ.as_bytes()));
        Ok(())
    }

    /// Frames up to `max_message_length` pass; larger ones get a size-limit NOTICE.
    #[actix_rt::test]
    async fn max_size() -> Result<()> {
        const REQ: &str = r#"["REQ", "1", {}]"#;
        let max_size = REQ.len() + 1;
        let mut srv = actix_test::start(move || {
            let data = create_test_app("max_size").unwrap();
            {
                let mut setting = data.setting.write();
                setting.limitation.max_message_length = max_size;
            }
            data.add_extension(Echo).web_app()
        });
        let mut client = srv.ws_at("/").await.unwrap();

        client.send(ws::Message::Text(REQ.into())).await?;
        let echoed = client.next().await.unwrap()?;
        assert_eq!(echoed, text_frame(REQ.as_bytes()));

        client
            .send(ws::Message::Text(format!("{}  ", REQ).into()))
            .await?;
        let rejected = client.next().await.unwrap()?;
        assert_eq!(
            rejected,
            text_frame(br#"["NOTICE","payload reached size limit."]"#)
        );
        Ok(())
    }

    /// A text message split across continuation frames is reassembled.
    #[actix_rt::test]
    async fn continuation() -> Result<()> {
        const FRAGMENTED: &[u8] = br#"["REQ", "1", {}]"#;

        let mut srv = actix_test::start(|| {
            let data = create_test_app("extension").unwrap();
            data.add_extension(Echo).web_app()
        });
        let mut client = srv.ws_at("/").await.unwrap();

        let fragments = [
            ws::Message::Continuation(Item::FirstText(Bytes::copy_from_slice(
                &FRAGMENTED[0..2],
            ))),
            ws::Message::Continuation(Item::Continue(Bytes::copy_from_slice(
                &FRAGMENTED[2..4],
            ))),
            ws::Message::Continuation(Item::Last(Bytes::copy_from_slice(&FRAGMENTED[4..]))),
        ];
        for fragment in fragments {
            client.send(fragment).await?;
        }

        let echoed = client.next().await.unwrap()?;
        assert_eq!(echoed, text_frame(FRAGMENTED));
        Ok(())
    }
}