1use actix::{Message, MessageResponse, Recipient};
2use bytestring::ByteString;
3use nostr_db::{now, CheckEventResult, Event, Filter};
4use serde::{
5 de::{self, SeqAccess, Visitor},
6 Deserialize, Deserializer,
7};
8use serde_json::{json, Value};
9use std::fmt::Display;
10use std::{fmt, marker::PhantomData};
11
12use crate::{setting::Limitation, Error};
13
14#[derive(Message, Clone, Debug)]
16#[rtype(usize)]
17pub struct Connect {
18 pub addr: Recipient<OutgoingMessage>,
19}
20
21#[derive(Message, Clone, Debug)]
23#[rtype(result = "()")]
24pub struct Disconnect {
25 pub id: usize,
26}
27
28#[derive(Message, Clone, Debug)]
30#[rtype(result = "()")]
31pub struct ClientMessage {
32 pub id: usize,
34 pub text: String,
36 pub msg: IncomingMessage,
38 pub nip70_checked: bool,
40}
41
42impl ClientMessage {
43 pub fn new(id: usize, text: String, msg: IncomingMessage) -> Self {
44 Self {
45 id,
46 text,
47 msg,
48 nip70_checked: false,
49 }
50 }
51}
52
/// Early-returns `Err(Error::Invalid(..))` from the enclosing function when
/// `$source` is strictly greater than `$limit`.
///
/// The error text is the stringified limit expression, a space, and the limit's
/// value (for example `limitation.max_filters 10`).
macro_rules! check_max {
    ($source:expr, $limit:expr) => {
        if $source > $limit {
            let reason = format!("{} {}", stringify!($limit), $limit);
            return Err(Error::Invalid(reason));
        }
    };
}
60
/// Early-returns `Err(Error::Invalid(..))` from the enclosing function when
/// `$source` is strictly less than `$limit`.
///
/// The error text is the stringified limit expression, a space, and the limit's
/// value (for example `limitation.min_prefix 10`).
macro_rules! check_min {
    ($source:expr, $limit:expr) => {
        if $source < $limit {
            let reason = format!("{} {}", stringify!($limit), $limit);
            return Err(Error::Invalid(reason));
        }
    };
}
68
69impl ClientMessage {
70 pub fn validate_nip70(&self) -> Result<(), Error> {
71 if !self.nip70_checked {
72 if let IncomingMessage::Event(event) = &self.msg {
73 for tag in event.tags() {
74 if tag.len() == 1 && tag[0] == "-" {
75 return Err(Error::Message(
76 "blocked: event marked as protected".to_owned(),
77 ));
78 }
79 }
80 }
81 }
82 Ok(())
83 }
84
85 pub fn validate(&mut self, limitation: &Limitation) -> Result<(), Error> {
86 check_max!(self.text.as_bytes().len(), limitation.max_message_length);
87
88 match &mut self.msg {
89 IncomingMessage::Event(event) => {
90 check_max!(event.tags().len(), limitation.max_event_tags);
91 event.validate(
92 now(),
93 limitation.max_event_time_older_than_now,
94 limitation.max_event_time_newer_than_now,
95 )?;
96 }
97
98 IncomingMessage::Req(sub) => {
99 check_max!(sub.filters.len(), limitation.max_filters);
100 check_max!(sub.id.len(), limitation.max_subid_length);
101
102 for f in &mut sub.filters {
103 if let Some(limit) = f.limit {
105 if limit > limitation.max_limit {
106 f.limit = Some(limitation.max_limit);
107 }
108 } else {
109 f.limit = Some(limitation.max_limit);
110 }
111 for id in f.ids.iter() {
112 check_min!(id.len(), limitation.min_prefix);
113 }
114 }
115 }
116 _ => {}
117 }
118 Ok(())
119 }
120}
121
122#[derive(Clone, Debug)]
138pub enum IncomingMessage {
139 Event(Event),
140 Close(String),
141 Req(Subscription),
142 Auth(Event),
144 Count(Subscription),
146 Unknown(String, Vec<Value>),
147}
148
149impl IncomingMessage {
150 pub fn command(&self) -> &str {
151 match self {
152 IncomingMessage::Event(_) => "EVENT",
153 IncomingMessage::Close(_) => "CLOSE",
154 IncomingMessage::Req(_) => "REQ",
155 IncomingMessage::Auth(_) => "AUTH",
156 IncomingMessage::Count(_) => "COUNT",
157 IncomingMessage::Unknown(cmd, _) => cmd,
158 }
159 }
160
161 pub fn known_command(&self) -> Option<&'static str> {
162 match self {
163 IncomingMessage::Event(_) => Some("EVENT"),
164 IncomingMessage::Close(_) => Some("CLOSE"),
165 IncomingMessage::Req(_) => Some("REQ"),
166 IncomingMessage::Auth(_) => Some("AUTH"),
167 IncomingMessage::Count(_) => Some("COUNT"),
168 IncomingMessage::Unknown(_, _) => None,
169 }
170 }
171}
172
/// Serde visitor that decodes a client's JSON array into an [`IncomingMessage`].
///
/// The `PhantomData<()>` field carries no data.
struct MessageVisitor(PhantomData<()>);
176
177impl<'de> Visitor<'de> for MessageVisitor {
178 type Value = IncomingMessage;
179
180 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
181 formatter.write_str("sequence")
182 }
183
184 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
185 where
186 A: SeqAccess<'de>,
187 {
188 let t: &str = seq
189 .next_element()?
190 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
191 match t {
192 "EVENT" => Ok(IncomingMessage::Event(
193 seq.next_element()?
194 .ok_or_else(|| de::Error::invalid_length(0, &self))?,
195 )),
196 "CLOSE" => Ok(IncomingMessage::Close(
197 seq.next_element()?
198 .ok_or_else(|| de::Error::invalid_length(0, &self))?,
199 )),
200 "REQ" => {
201 let t = seq
202 .next_element()?
203 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
204 let r = Vec::<Filter>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
205 Ok(IncomingMessage::Req(Subscription { id: t, filters: r }))
206 }
207 "AUTH" => Ok(IncomingMessage::Auth(
208 seq.next_element()?
209 .ok_or_else(|| de::Error::invalid_length(0, &self))?,
210 )),
211 "COUNT" => {
212 let t = seq
213 .next_element()?
214 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
215 let r = Vec::<Filter>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
216 Ok(IncomingMessage::Count(Subscription { id: t, filters: r }))
217 }
218 _ => Ok(IncomingMessage::Unknown(
219 t.to_string(),
220 Vec::<Value>::deserialize(de::value::SeqAccessDeserializer::new(seq))?,
221 )),
222 }
223 }
224}
225
226impl<'de> Deserialize<'de> for IncomingMessage {
227 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
228 where
229 D: Deserializer<'de>,
230 {
231 deserializer.deserialize_seq(MessageVisitor(PhantomData))
232 }
233}
234
235#[derive(Clone, Debug)]
247pub struct Subscription {
248 pub id: String,
249 pub filters: Vec<Filter>,
250}
251
252#[derive(Message, Clone, Debug)]
286#[rtype(result = "()")]
287pub struct OutgoingMessage(pub String);
288
289impl OutgoingMessage {
290 pub fn notice(message: &str) -> Self {
291 Self(json!(["NOTICE", message]).to_string())
292 }
293
294 pub fn closed(sub_id: &str, message: &str) -> Self {
295 Self(json!(["CLOSED", sub_id, message]).to_string())
296 }
297
298 pub fn eose(sub_id: &str) -> Self {
299 Self(format!(r#"["EOSE","{}"]"#, sub_id))
300 }
301
302 pub fn event(sub_id: &str, event: &str) -> Self {
303 Self(format!(r#"["EVENT","{}",{}]"#, sub_id, event))
304 }
305
306 pub fn ok(event_id: &str, saved: bool, message: &str) -> Self {
307 Self(json!(["OK", event_id, saved, message]).to_string())
308 }
309}
310
311impl Display for OutgoingMessage {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 f.write_str(&self.0)?;
314 Ok(())
315 }
316}
317
318impl From<OutgoingMessage> for ByteString {
325 fn from(val: OutgoingMessage) -> Self {
326 ByteString::from(val.0)
327 }
328}
329
330#[derive(Message, Clone, Debug)]
331#[rtype(result = "()")]
332pub struct WriteEvent {
333 pub id: usize,
334 pub event: Event,
335}
336
337#[derive(Message, Clone, Debug)]
338#[rtype(result = "()")]
339pub enum WriteEventResult {
340 Write {
341 id: usize,
342 event: Event,
343 result: CheckEventResult,
344 },
345 Message {
346 id: usize,
347 event: Event,
348 msg: OutgoingMessage,
349 },
350}
351#[derive(Message, Clone, Debug)]
358#[rtype(result = "()")]
359pub struct ReadEvent {
360 pub id: usize,
361 pub subscription: Subscription,
362}
363
364#[derive(Message, Clone, Debug)]
365#[rtype(result = "()")]
366pub struct ReadEventResult {
367 pub id: usize,
368 pub sub_id: String,
369 pub msg: OutgoingMessage,
370}
371
372#[derive(MessageResponse, Clone, Debug, PartialEq, Eq)]
373pub enum Subscribed {
374 Ok,
375 Overlimit,
376 InvalidIdLength,
377}
378
379#[derive(Message, Clone, Debug)]
380#[rtype(result = "Subscribed")]
381pub struct Subscribe {
382 pub id: usize,
383 pub subscription: Subscription,
384}
385
386#[derive(Message, Clone, Debug)]
387#[rtype(result = "()")]
388pub struct Unsubscribe {
389 pub id: usize,
390 pub sub_id: Option<String>,
391}
392
393#[derive(Message, Clone, Debug)]
394#[rtype(result = "()")]
395pub struct Dispatch {
396 pub id: usize,
397 pub event: Event,
398}
399
400#[derive(Message, Clone, Debug)]
401#[rtype(result = "()")]
402pub struct SubscribeResult {
403 pub id: usize,
404 pub sub_id: String,
405 pub msg: OutgoingMessage,
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn de_incoming_message() -> Result<()> {
        // CLOSE takes exactly one subscription id; extra elements are an error.
        let close: IncomingMessage = serde_json::from_str(r#"["CLOSE", "sub_id1"]"#)?;
        assert!(matches!(close, IncomingMessage::Close(ref id) if id == "sub_id1"));
        let close_extra =
            serde_json::from_str::<IncomingMessage>(r#"["CLOSE", "sub_id1", "other"]"#);
        assert!(close_extra.is_err());

        let event: IncomingMessage = serde_json::from_str(
            r#"["EVENT", {
                "content": "Good morning everyone 😃",
                "created_at": 1680690006,
                "id": "332747c0fab8a1a92def4b0937e177be6df4382ce6dd7724f86dc4710b7d4d7d",
                "kind": 1,
                "pubkey": "7abf57d516b1ff7308ca3bd5650ea6a4674d469c7c5057b1d005fb13d218bfef",
                "sig": "ef4ff4f69ac387239eb1401fb07d7a44a5d5d57127e0dc3466a0403cf7d5486b668608ebfcbe9ff1f8d3b5d710545999fe08ee767284ec0b474e4cf92537678f",
                "tags": [["t", "nostr"], ["t", ""], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]
            }]"#,
        )?;
        assert!(matches!(event, IncomingMessage::Event(ref e) if e.kind() == 1));

        // REQ: each filter must be an object, but an empty filter list is accepted.
        let req: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {}]"#)?;
        assert!(matches!(req, IncomingMessage::Req(sub) if sub.id == "sub_id1"));
        let bad_filter = serde_json::from_str::<IncomingMessage>(r#"["REQ", "sub_id1", ""]"#);
        assert!(bad_filter.is_err());
        let no_filters = serde_json::from_str::<IncomingMessage>(r#"["REQ", "sub_id1"]"#);
        assert!(no_filters.is_ok());

        // Unrecognised commands are preserved rather than rejected.
        let unknown: IncomingMessage = serde_json::from_str(r#"["REQ1", "sub_id1", {}]"#)?;
        assert!(matches!(unknown, IncomingMessage::Unknown(ref cmd, _) if cmd == "REQ1"));

        let auth: IncomingMessage = serde_json::from_str(
            r#"["AUTH", {
                "content": "Good morning everyone 😃",
                "created_at": 1680690006,
                "id": "332747c0fab8a1a92def4b0937e177be6df4382ce6dd7724f86dc4710b7d4d7d",
                "kind": 1,
                "pubkey": "7abf57d516b1ff7308ca3bd5650ea6a4674d469c7c5057b1d005fb13d218bfef",
                "sig": "ef4ff4f69ac387239eb1401fb07d7a44a5d5d57127e0dc3466a0403cf7d5486b668608ebfcbe9ff1f8d3b5d710545999fe08ee767284ec0b474e4cf92537678f",
                "tags": [["t", "nostr"], ["t", ""], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]
            }]"#,
        )?;
        assert!(matches!(auth, IncomingMessage::Auth(ref e) if e.kind() == 1));

        let count: IncomingMessage = serde_json::from_str(r#"["COUNT", "sub_id1", {}]"#)?;
        assert!(matches!(count, IncomingMessage::Count(sub) if sub.id == "sub_id1"));

        Ok(())
    }

    #[test]
    fn se_outgoing_message() -> Result<()> {
        let cases = [
            (OutgoingMessage::notice("hello"), r#"["NOTICE","hello"]"#),
            (
                OutgoingMessage::event("id", r#"{"id":"1"}"#),
                r#"["EVENT","id",{"id":"1"}]"#,
            ),
            (OutgoingMessage::eose("hello"), r#"["EOSE","hello"]"#),
            (OutgoingMessage::closed("1", "hello"), r#"["CLOSED","1","hello"]"#),
        ];
        for (msg, expected) in cases {
            assert_eq!(msg.to_string(), expected);
        }
        Ok(())
    }

    #[test]
    fn validate() -> Result<()> {
        // Without the NIP-70 "-" tag the event passes the protected check.
        let unprotected: IncomingMessage = serde_json::from_str(
            r#"["EVENT", {
                "content": "Good morning everyone 😃",
                "created_at": 1680690006,
                "id": "332747c0fab8a1a92def4b0937e177be6df4382ce6dd7724f86dc4710b7d4d7d",
                "kind": 1,
                "pubkey": "7abf57d516b1ff7308ca3bd5650ea6a4674d469c7c5057b1d005fb13d218bfef",
                "sig": "ef4ff4f69ac387239eb1401fb07d7a44a5d5d57127e0dc3466a0403cf7d5486b668608ebfcbe9ff1f8d3b5d710545999fe08ee767284ec0b474e4cf92537678f",
                "tags": [["t", "nostr"], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]
            }]"#,
        )?;
        let unprotected = ClientMessage::new(1, "text".to_string(), unprotected);
        assert!(unprotected.validate_nip70().is_ok());

        // With the "-" tag the event is blocked.
        let protected: IncomingMessage = serde_json::from_str(
            r#"["EVENT", {
                "content": "Good morning everyone 😃",
                "created_at": 1680690006,
                "id": "332747c0fab8a1a92def4b0937e177be6df4382ce6dd7724f86dc4710b7d4d7d",
                "kind": 1,
                "pubkey": "7abf57d516b1ff7308ca3bd5650ea6a4674d469c7c5057b1d005fb13d218bfef",
                "sig": "ef4ff4f69ac387239eb1401fb07d7a44a5d5d57127e0dc3466a0403cf7d5486b668608ebfcbe9ff1f8d3b5d710545999fe08ee767284ec0b474e4cf92537678f",
                "tags": [["t", "nostr"], ["-"], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]
            }]"#,
        )?;
        let protected = ClientMessage::new(1, "text".to_string(), protected);
        assert!(protected.validate_nip70().is_err());

        let mut limitation = Limitation::default();
        limitation.max_limit = 300;

        // A missing or oversized limit is clamped to max_limit; a smaller one is kept.
        let limit_cases = [
            (r#"["REQ", "sub_id1", {}]"#, 300),
            (r#"["REQ", "sub_id1", {"limit": 400}]"#, 300),
            (r#"["REQ", "sub_id1", {"limit": 200}]"#, 200),
        ];
        for (frame, expected) in limit_cases {
            let mut msg = ClientMessage::new(1, "text".to_string(), serde_json::from_str(frame)?);
            msg.validate(&limitation).unwrap();
            assert!(matches!(
                msg.msg,
                IncomingMessage::Req(sub) if sub.filters.first().unwrap().limit.unwrap() == expected
            ));
        }

        Ok(())
    }
}