1use crate::{hash::NoOpHasherDefault, message::*, App, Server};
2use actix::prelude::*;
3use actix_http::ws::Item;
4use actix_web::web;
5use actix_web_actors::ws;
6use bytes::BytesMut;
7use metrics::{counter, gauge};
8use std::{
9 any::{Any, TypeId},
10 collections::HashMap,
11 time::{Duration, Instant},
12};
13use tracing::debug;
14use ws::Message;
15
/// Per-connection WebSocket actor.
///
/// A `Session` registers itself with the shared [`Server`] actor when it starts,
/// forwards validated client messages to it, writes [`OutgoingMessage`]s it receives
/// to the socket, and pings the client periodically to detect dead connections.
pub struct Session {
    /// Client IP string as passed to [`Session::new`]; exposed via `ip()` and used in logs.
    ip: String,

    /// Id returned by the server in reply to `Connect`; stays `0` until that reply arrives.
    id: usize,

    /// Heartbeat timestamp: set at creation and refreshed on every ping/pong from the client.
    hb: Instant,

    /// Address of the server actor (receives `Connect`, `Disconnect` and client messages).
    server: Addr<Server>,

    /// Maximum allowed time since `hb` before the heartbeat check stops the session.
    heartbeat_timeout: Duration,

    /// Period of the heartbeat check / ping timer.
    heartbeat_interval: Duration,

    /// Shared application state (settings, extensions, server address).
    pub app: web::Data<App>,

    /// Type-keyed per-session storage used by `set`/`get` (one value per type).
    data: HashMap<TypeId, Box<dyn Any>, NoOpHasherDefault>,

    /// Buffer for a fragmented text message that is being reassembled from continuation frames.
    cont: Option<BytesMut>,
}
45
46impl Session {
47 pub fn set<T: 'static>(&mut self, val: T) {
49 self.data.insert(TypeId::of::<T>(), Box::new(val));
50 }
51
52 pub fn get<T: 'static>(&self) -> Option<&T> {
54 self.data
55 .get(&TypeId::of::<T>())
56 .and_then(|boxed| boxed.downcast_ref())
57 }
58
59 pub fn id(&self) -> usize {
61 self.id
62 }
63
64 pub fn ip(&self) -> &String {
66 &self.ip
67 }
68
69 pub fn new(ip: String, app: web::Data<App>) -> Session {
70 let setting = app.setting.read();
71 let heartbeat_timeout = setting.network.heartbeat_timeout.into();
72 let heartbeat_interval = setting.network.heartbeat_interval.into();
73 drop(setting);
74 Self {
75 id: 0,
76 ip,
77 hb: Instant::now(),
78 server: app.server.clone(),
79 heartbeat_timeout,
80 heartbeat_interval,
81 app,
82 data: HashMap::default(),
83 cont: None,
84 }
85 }
86
87 fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
90 ctx.run_interval(self.heartbeat_interval, |act, ctx| {
91 if Instant::now().duration_since(act.hb) > act.heartbeat_timeout {
93 counter!("nostr_relay_session_stop_total", "reason" => "heartbeat timeout")
96 .increment(1);
97 ctx.stop();
98 return;
100 }
101
102 ctx.ping(b"");
103 });
104 }
105
106 fn send_error(
107 &self,
108 err: crate::Error,
109 msg: &ClientMessage,
110 ctx: &mut ws::WebsocketContext<Self>,
111 ) {
112 if let IncomingMessage::Event(event) = &msg.msg {
113 ctx.text(OutgoingMessage::ok(
114 &event.id_str(),
115 false,
116 &err.to_string(),
117 ));
118 } else if let IncomingMessage::Req(sub) = &msg.msg {
119 ctx.text(OutgoingMessage::closed(&sub.id, &err.to_string()));
120 } else {
121 ctx.text(OutgoingMessage::notice(&err.to_string()));
122 }
123 }
124
125 fn handle_message(&mut self, text: String, ctx: &mut ws::WebsocketContext<Self>) {
126 let msg = serde_json::from_str::<IncomingMessage>(&text);
127 match msg {
128 Ok(msg) => {
129 if let Some(cmd) = msg.known_command() {
130 counter!("nostr_relay_message_total", "command" => cmd).increment(1);
132 }
133
134 let mut msg = ClientMessage::new(self.id, text, msg);
135 {
136 let r = self.app.setting.read();
137 if let Err(err) = msg.validate(&r.limitation) {
138 self.send_error(err, &msg, ctx);
139 return;
140 }
141 }
142
143 match self
144 .app
145 .clone()
146 .extensions
147 .read()
148 .call_message(msg, self, ctx)
149 {
150 crate::ExtensionMessageResult::Continue(msg) => {
151 if let Err(err) = msg.validate_nip70() {
152 self.send_error(err, &msg, ctx);
153 return;
154 }
155 self.server.do_send(msg);
156 }
157 crate::ExtensionMessageResult::Stop(out) => {
158 ctx.text(out);
159 }
160 crate::ExtensionMessageResult::Ignore => {
161 }
163 };
164 }
165 Err(err) => {
166 ctx.text(OutgoingMessage::notice(&format!("json error: {}", err)));
167 }
168 };
169 }
170}
171
172impl Handler<OutgoingMessage> for Session {
174 type Result = ();
175
176 fn handle(&mut self, msg: OutgoingMessage, ctx: &mut Self::Context) {
177 ctx.text(msg);
178 }
179}
180
181impl Actor for Session {
182 type Context = ws::WebsocketContext<Self>;
183
184 fn started(&mut self, ctx: &mut Self::Context) {
186 counter!("nostr_relay_session_total").increment(1);
187 gauge!("nostr_relay_session").increment(1.0);
188
189 self.hb(ctx);
191 let addr = ctx.address();
193 self.server
194 .send(Connect {
195 addr: addr.recipient(),
196 })
197 .into_actor(self)
198 .then(|res, act, ctx| {
199 match res {
200 Ok(res) => {
201 act.id = res;
202 act.app.clone().extensions.read().call_connected(act, ctx);
203 debug!("Session started {} {}", act.id, act.ip);
204 }
205 _ => {
207 counter!("nostr_relay_session_stop_total", "reason" => "server error")
208 .increment(1);
209 ctx.stop()
210 }
211 }
212 fut::ready(())
213 })
214 .wait(ctx);
215 }
216
217 fn stopping(&mut self, _: &mut Self::Context) -> Running {
218 self.server.do_send(Disconnect { id: self.id });
220 Running::Stop
221 }
222
223 fn stopped(&mut self, ctx: &mut Self::Context) {
224 gauge!("nostr_relay_session").decrement(1.0);
225 self.app
226 .clone()
227 .extensions
228 .read()
229 .call_disconnected(self, ctx);
230 debug!("Session stopped {} {}", self.id, self.ip);
231 }
232}
233
234impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for Session {
236 fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
237 if !matches!(msg, Ok(Message::Text(_)) | Ok(Message::Continuation(_))) {
239 debug!("Session message {} {} {:?}", self.id, self.ip, msg);
240 }
241 let msg = match msg {
242 Err(err) => {
243 match err {
244 ws::ProtocolError::Overflow => {
245 ctx.text(OutgoingMessage::notice("payload reached size limit."));
246 }
247 _ => {
248 debug!("Session error {} {} {:?}", self.id, self.ip, err);
249 counter!("nostr_relay_session_stop_total", "reason" => "message error")
250 .increment(1);
251 ctx.stop();
252 }
253 }
254 return;
255 }
256 Ok(msg) => msg,
257 };
258 match msg {
259 ws::Message::Ping(msg) => {
260 self.hb = Instant::now();
261 ctx.pong(&msg);
262 }
263 ws::Message::Pong(_) => {
264 self.hb = Instant::now();
265 }
266 ws::Message::Text(text) => {
267 let text = text.to_string();
268 debug!("Session text {} {} {}", self.id, self.ip, text);
269 self.handle_message(text, ctx);
270 }
271 ws::Message::Close(reason) => {
272 ctx.close(reason);
273 counter!("nostr_relay_session_stop_total", "reason" => "message close")
274 .increment(1);
275 ctx.stop();
276 }
277 ws::Message::Binary(_) => {
278 ctx.text(OutgoingMessage::notice("Not support binary message"));
279 }
280 ws::Message::Continuation(cont) => match cont {
281 Item::FirstText(buf) => {
282 let mut bytes = BytesMut::new();
283 bytes.extend_from_slice(&buf);
284 self.cont = Some(bytes);
285 }
286 Item::FirstBinary(_) => {
287 ctx.text(OutgoingMessage::notice("Not support binary message"));
288 }
289 Item::Continue(buf) => {
290 if let Some(bytes) = &mut self.cont {
291 bytes.extend_from_slice(&buf);
292 }
293 }
294 Item::Last(buf) => {
295 if let Some(mut bytes) = self.cont.take() {
296 bytes.extend_from_slice(&buf);
297 if let Ok(text) = String::from_utf8(bytes.to_vec()) {
298 debug!("Session text {} {} {}", self.id, self.ip, text);
299 self.handle_message(text, ctx);
300 }
301 }
302 }
303 },
304 ws::Message::Nop => (),
305 }
306 }
307}
308
#[cfg(test)]
mod tests {
    //! End-to-end tests that drive a `Session` through a real WebSocket connection
    //! to a test server started with `actix_test::start`.
    use super::*;
    use crate::{create_test_app, Extension, ExtensionMessageResult};
    use actix_rt::time::sleep;
    use actix_web_actors::ws;
    use anyhow::Result;
    use bytes::Bytes;
    use futures_util::{SinkExt as _, StreamExt as _};

    /// A client ping is answered with a pong carrying the same payload, and a
    /// client close is echoed back with the same close code.
    #[actix_rt::test]
    async fn pingpong() -> Result<()> {
        let mut srv = actix_test::start(|| {
            let data = create_test_app("session").unwrap();
            data.web_app()
        });

        let mut framed = srv.ws_at("/").await.unwrap();

        framed.send(ws::Message::Ping("text".into())).await?;
        let item = framed.next().await.unwrap()?;
        assert_eq!(item, ws::Frame::Pong(Bytes::copy_from_slice(b"text")));

        framed
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        let item = framed.next().await.unwrap()?;
        assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
        Ok(())
    }

    /// With a 1s heartbeat interval and a 20s timeout, the server keeps sending
    /// empty pings and does not drop the connection.
    #[actix_rt::test]
    async fn heartbeat() -> Result<()> {
        let mut srv = actix_test::start(|| {
            let data = create_test_app("session").unwrap();
            {
                let mut w = data.setting.write();
                w.network.heartbeat_interval = Duration::from_secs(1).try_into().unwrap();
                w.network.heartbeat_timeout = Duration::from_secs(20).try_into().unwrap();
            }
            data.web_app()
        });

        let mut framed = srv.ws_at("/").await.unwrap();

        // Wait long enough for at least two heartbeat ticks to have fired.
        sleep(Duration::from_secs(3)).await;
        let item = framed.next().await.unwrap()?;
        assert_eq!(item, ws::Frame::Ping(Bytes::copy_from_slice(b"")));

        let item = framed.next().await.unwrap()?;
        assert_eq!(item, ws::Frame::Ping(Bytes::copy_from_slice(b"")));

        framed
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        Ok(())
    }

    /// With a 2s timeout and a client that never answers pings, the server sends one
    /// ping and then stops the session, so the stream ends.
    #[actix_rt::test]
    async fn heartbeat_timeout() -> Result<()> {
        let mut srv = actix_test::start(|| {
            let data = create_test_app("session").unwrap();
            {
                let mut w = data.setting.write();
                w.network.heartbeat_interval = Duration::from_secs(1).try_into().unwrap();
                w.network.heartbeat_timeout = Duration::from_secs(2).try_into().unwrap();
            }
            data.web_app()
        });
        let mut framed = srv.ws_at("/").await.unwrap();

        sleep(Duration::from_secs(3)).await;
        let item = framed.next().await.unwrap()?;
        assert_eq!(item, ws::Frame::Ping(Bytes::copy_from_slice(b"")));
        let item = framed.next().await;
        assert!(item.is_none());
        Ok(())
    }

    /// Test extension that answers every message with its own raw text
    /// and stops further processing.
    struct Echo;
    impl Extension for Echo {
        fn message(
            &self,
            msg: ClientMessage,
            _session: &mut Session,
            _ctx: &mut <Session as actix::Actor>::Context,
        ) -> ExtensionMessageResult {
            ExtensionMessageResult::Stop(OutgoingMessage(msg.text))
        }

        fn name(&self) -> &'static str {
            "Echo"
        }
    }

    /// An extension returning `Stop` has its reply written straight back to the client.
    #[actix_rt::test]
    async fn extension() -> Result<()> {
        let text = r#"["REQ", "1", {}]"#;
        let mut srv = actix_test::start(|| {
            let data = create_test_app("extension").unwrap();
            data.add_extension(Echo).web_app()
        });
        let mut framed = srv.ws_at("/").await.unwrap();
        framed.send(ws::Message::Text(text.into())).await?;
        let item = framed.next().await.unwrap()?;
        assert_eq!(
            item,
            ws::Frame::Text(Bytes::copy_from_slice(text.as_bytes()))
        );
        Ok(())
    }

    /// With `max_message_length` set to one byte more than `text`, `text` is echoed but a
    /// message one byte longer gets the size-limit NOTICE.
    #[actix_rt::test]
    async fn max_size() -> Result<()> {
        let text = r#"["REQ", "1", {}]"#;
        let max_size = text.len() + 1;
        let mut srv = actix_test::start(move || {
            let data = create_test_app("max_size").unwrap();
            {
                let mut w = data.setting.write();
                w.limitation.max_message_length = max_size;
            }
            data.add_extension(Echo).web_app()
        });
        let mut framed = srv.ws_at("/").await.unwrap();
        framed.send(ws::Message::Text(text.into())).await?;
        let item = framed.next().await.unwrap()?;
        assert_eq!(
            item,
            ws::Frame::Text(Bytes::copy_from_slice(text.as_bytes()))
        );

        framed
            .send(ws::Message::Text(format!("{} ", text).into()))
            .await?;
        let item = framed.next().await.unwrap()?;
        assert_eq!(
            item,
            ws::Frame::Text(Bytes::copy_from_slice(
                br#"["NOTICE","payload reached size limit."]"#
            ))
        );
        Ok(())
    }

    /// A text message split into FirstText / Continue / Last fragments is reassembled
    /// and processed as one message, so the echo returns the full text.
    #[actix_rt::test]
    async fn continuation() -> Result<()> {
        let text = br#"["REQ", "1", {}]"#;

        let mut srv = actix_test::start(|| {
            let data = create_test_app("extension").unwrap();
            data.add_extension(Echo).web_app()
        });
        let mut framed = srv.ws_at("/").await.unwrap();
        framed
            .send(ws::Message::Continuation(Item::FirstText(
                Bytes::copy_from_slice(&text[0..2]),
            )))
            .await?;

        framed
            .send(ws::Message::Continuation(Item::Continue(
                Bytes::copy_from_slice(&text[2..4]),
            )))
            .await?;
        framed
            .send(ws::Message::Continuation(Item::Last(
                Bytes::copy_from_slice(&text[4..]),
            )))
            .await?;

        let item = framed.next().await.unwrap()?;
        assert_eq!(item, ws::Frame::Text(Bytes::copy_from_slice(text)));

        Ok(())
    }
}