gnostr_relay/
message.rs

1use actix::{Message, MessageResponse, Recipient};
2use bytestring::ByteString;
3use nostr_db::{now, CheckEventResult, Event, Filter};
4use serde::{
5    de::{self, SeqAccess, Visitor},
6    Deserialize, Deserializer,
7};
8use serde_json::{json, Value};
9use std::fmt::Display;
10use std::{fmt, marker::PhantomData};
11
12use crate::{setting::Limitation, Error};
13
14/// New session is created
15#[derive(Message, Clone, Debug)]
16#[rtype(usize)]
17pub struct Connect {
18    pub addr: Recipient<OutgoingMessage>,
19}
20
21/// Session is disconnected
22#[derive(Message, Clone, Debug)]
23#[rtype(result = "()")]
24pub struct Disconnect {
25    pub id: usize,
26}
27
28/// Message from client
29#[derive(Message, Clone, Debug)]
30#[rtype(result = "()")]
31pub struct ClientMessage {
32    /// Id of the client session
33    pub id: usize,
34    /// text message
35    pub text: String,
36    /// parsed message
37    pub msg: IncomingMessage,
38    /// is nip70 checked
39    pub nip70_checked: bool,
40}
41
42impl ClientMessage {
43    pub fn new(id: usize, text: String, msg: IncomingMessage) -> Self {
44        Self {
45            id,
46            text,
47            msg,
48            nip70_checked: false,
49        }
50    }
51}
52
/// Early-returns `Err(Error::Invalid(..))` from the enclosing function when
/// `$source` is greater than `$limit`.
///
/// The error text is the limit expression's source text followed by its value,
/// e.g. `"limitation.max_filters 10"`.
macro_rules! check_max {
    ($source:expr, $limit:expr) => {
        if $limit < $source {
            return Err(Error::Invalid(format!(
                "{} {}",
                stringify!($limit),
                $limit
            )));
        }
    };
}
60
/// Early-returns `Err(Error::Invalid(..))` from the enclosing function when
/// `$source` is smaller than `$limit`.
///
/// The error text is the limit expression's source text followed by its value,
/// e.g. `"limitation.min_prefix 4"`.
macro_rules! check_min {
    ($source:expr, $limit:expr) => {
        if $limit > $source {
            return Err(Error::Invalid(format!(
                "{} {}",
                stringify!($limit),
                $limit
            )));
        }
    };
}
68
69impl ClientMessage {
70    pub fn validate_nip70(&self) -> Result<(), Error> {
71        if !self.nip70_checked {
72            if let IncomingMessage::Event(event) = &self.msg {
73                for tag in event.tags() {
74                    if tag.len() == 1 && tag[0] == "-" {
75                        return Err(Error::Message(
76                            "blocked: event marked as protected".to_owned(),
77                        ));
78                    }
79                }
80            }
81        }
82        Ok(())
83    }
84
85    pub fn validate(&mut self, limitation: &Limitation) -> Result<(), Error> {
86        check_max!(self.text.as_bytes().len(), limitation.max_message_length);
87
88        match &mut self.msg {
89            IncomingMessage::Event(event) => {
90                check_max!(event.tags().len(), limitation.max_event_tags);
91                event.validate(
92                    now(),
93                    limitation.max_event_time_older_than_now,
94                    limitation.max_event_time_newer_than_now,
95                )?;
96            }
97
98            IncomingMessage::Req(sub) => {
99                check_max!(sub.filters.len(), limitation.max_filters);
100                check_max!(sub.id.len(), limitation.max_subid_length);
101
102                for f in &mut sub.filters {
103                    // Fill default limit, Override the incoming limit if it is too large
104                    if let Some(limit) = f.limit {
105                        if limit > limitation.max_limit {
106                            f.limit = Some(limitation.max_limit);
107                        }
108                    } else {
109                        f.limit = Some(limitation.max_limit);
110                    }
111                    for id in f.ids.iter() {
112                        check_min!(id.len(), limitation.min_prefix);
113                    }
114                }
115            }
116            _ => {}
117        }
118        Ok(())
119    }
120}
121
122// #[derive(Deserialize, Clone, Debug)]
123// #[serde(rename_all = "UPPERCASE", tag = "0")]
124// pub enum IncomingMessage {
125//     Event {
126//         event: Event,
127//     },
128//     Close {
129//         id: String,
130//     },
131//     Req(Subscription),
132//     #[serde(other, deserialize_with = "ignore_contents")]
133//     Unknown,
134// }
135
136/// Parsed incoming messages from a client
137#[derive(Clone, Debug)]
138pub enum IncomingMessage {
139    Event(Event),
140    Close(String),
141    Req(Subscription),
142    /// nip-42
143    Auth(Event),
144    /// nip-45
145    Count(Subscription),
146    Unknown(String, Vec<Value>),
147}
148
149impl IncomingMessage {
150    pub fn command(&self) -> &str {
151        match self {
152            IncomingMessage::Event(_) => "EVENT",
153            IncomingMessage::Close(_) => "CLOSE",
154            IncomingMessage::Req(_) => "REQ",
155            IncomingMessage::Auth(_) => "AUTH",
156            IncomingMessage::Count(_) => "COUNT",
157            IncomingMessage::Unknown(cmd, _) => cmd,
158        }
159    }
160
161    pub fn known_command(&self) -> Option<&'static str> {
162        match self {
163            IncomingMessage::Event(_) => Some("EVENT"),
164            IncomingMessage::Close(_) => Some("CLOSE"),
165            IncomingMessage::Req(_) => Some("REQ"),
166            IncomingMessage::Auth(_) => Some("AUTH"),
167            IncomingMessage::Count(_) => Some("COUNT"),
168            IncomingMessage::Unknown(_, _) => None,
169        }
170    }
171}
172
173// https://github.com/serde-rs/serde/issues/1337
174
175struct MessageVisitor(PhantomData<()>);
176
177impl<'de> Visitor<'de> for MessageVisitor {
178    type Value = IncomingMessage;
179
180    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
181        formatter.write_str("sequence")
182    }
183
184    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
185    where
186        A: SeqAccess<'de>,
187    {
188        let t: &str = seq
189            .next_element()?
190            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
191        match t {
192            "EVENT" => Ok(IncomingMessage::Event(
193                seq.next_element()?
194                    .ok_or_else(|| de::Error::invalid_length(0, &self))?,
195            )),
196            "CLOSE" => Ok(IncomingMessage::Close(
197                seq.next_element()?
198                    .ok_or_else(|| de::Error::invalid_length(0, &self))?,
199            )),
200            "REQ" => {
201                let t = seq
202                    .next_element()?
203                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
204                let r = Vec::<Filter>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
205                Ok(IncomingMessage::Req(Subscription { id: t, filters: r }))
206            }
207            "AUTH" => Ok(IncomingMessage::Auth(
208                seq.next_element()?
209                    .ok_or_else(|| de::Error::invalid_length(0, &self))?,
210            )),
211            "COUNT" => {
212                let t = seq
213                    .next_element()?
214                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
215                let r = Vec::<Filter>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
216                Ok(IncomingMessage::Count(Subscription { id: t, filters: r }))
217            }
218            _ => Ok(IncomingMessage::Unknown(
219                t.to_string(),
220                Vec::<Value>::deserialize(de::value::SeqAccessDeserializer::new(seq))?,
221            )),
222        }
223    }
224}
225
226impl<'de> Deserialize<'de> for IncomingMessage {
227    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
228    where
229        D: Deserializer<'de>,
230    {
231        deserializer.deserialize_seq(MessageVisitor(PhantomData))
232    }
233}
234
235// fn ignore_contents<'de, D>(deserializer: D) -> Result<(), D::Error>
236// where
237//     D: Deserializer<'de>,
238// {
239//     // Ignore any content at this part of the json structure
240//     let _ = deserializer.deserialize_ignored_any(serde::de::IgnoredAny);
241//     // Return unit as our 'Unknown' variant has no args
242//     Ok(())
243// }
244
245/// Subscription
246#[derive(Clone, Debug)]
247pub struct Subscription {
248    pub id: String,
249    pub filters: Vec<Filter>,
250}
251
252// https://github.com/serde-rs/serde/issues/1337
253// prefix
254// impl<'de> Deserialize<'de> for Subscription {
255//     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
256//     where
257//         D: Deserializer<'de>,
258//     {
259//         struct PrefixVisitor(PhantomData<()>);
260
261//         impl<'de> Visitor<'de> for PrefixVisitor {
262//             type Value = Subscription;
263
264//             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
265//                 formatter.write_str("sequence")
266//             }
267
268//             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
269//             where
270//                 A: SeqAccess<'de>,
271//             {
272//                 let t = seq
273//                     .next_element()?
274//                     .ok_or_else(|| de::Error::invalid_length(0, &self))?;
275//                 let r = Vec::<Filter>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
276//                 Ok(Subscription { id: t, filters: r })
277//             }
278//         }
279
280//         deserializer.deserialize_seq(PrefixVisitor(PhantomData))
281//     }
282// }
283
284/// The message sent to the client
285#[derive(Message, Clone, Debug)]
286#[rtype(result = "()")]
287pub struct OutgoingMessage(pub String);
288
289impl OutgoingMessage {
290    pub fn notice(message: &str) -> Self {
291        Self(json!(["NOTICE", message]).to_string())
292    }
293
294    pub fn closed(sub_id: &str, message: &str) -> Self {
295        Self(json!(["CLOSED", sub_id, message]).to_string())
296    }
297
298    pub fn eose(sub_id: &str) -> Self {
299        Self(format!(r#"["EOSE","{}"]"#, sub_id))
300    }
301
302    pub fn event(sub_id: &str, event: &str) -> Self {
303        Self(format!(r#"["EVENT","{}",{}]"#, sub_id, event))
304    }
305
306    pub fn ok(event_id: &str, saved: bool, message: &str) -> Self {
307        Self(json!(["OK", event_id, saved, message]).to_string())
308    }
309}
310
311impl Display for OutgoingMessage {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.write_str(&self.0)?;
314        Ok(())
315    }
316}
317
318// impl Into<ByteString> for OutgoingMessage {
319//     fn into(self) -> ByteString {
320//         ByteString::from(self.0)
321//     }
322// }
323
324impl From<OutgoingMessage> for ByteString {
325    fn from(val: OutgoingMessage) -> Self {
326        ByteString::from(val.0)
327    }
328}
329
330#[derive(Message, Clone, Debug)]
331#[rtype(result = "()")]
332pub struct WriteEvent {
333    pub id: usize,
334    pub event: Event,
335}
336
337#[derive(Message, Clone, Debug)]
338#[rtype(result = "()")]
339pub enum WriteEventResult {
340    Write {
341        id: usize,
342        event: Event,
343        result: CheckEventResult,
344    },
345    Message {
346        id: usize,
347        event: Event,
348        msg: OutgoingMessage,
349    },
350}
351// pub struct WriteEventResult {
352//     pub id: usize,
353//     pub event: Event,
354//     pub result: CheckEventResult,
355// }
356
357#[derive(Message, Clone, Debug)]
358#[rtype(result = "()")]
359pub struct ReadEvent {
360    pub id: usize,
361    pub subscription: Subscription,
362}
363
364#[derive(Message, Clone, Debug)]
365#[rtype(result = "()")]
366pub struct ReadEventResult {
367    pub id: usize,
368    pub sub_id: String,
369    pub msg: OutgoingMessage,
370}
371
372#[derive(MessageResponse, Clone, Debug, PartialEq, Eq)]
373pub enum Subscribed {
374    Ok,
375    Overlimit,
376    InvalidIdLength,
377}
378
379#[derive(Message, Clone, Debug)]
380#[rtype(result = "Subscribed")]
381pub struct Subscribe {
382    pub id: usize,
383    pub subscription: Subscription,
384}
385
386#[derive(Message, Clone, Debug)]
387#[rtype(result = "()")]
388pub struct Unsubscribe {
389    pub id: usize,
390    pub sub_id: Option<String>,
391}
392
393#[derive(Message, Clone, Debug)]
394#[rtype(result = "()")]
395pub struct Dispatch {
396    pub id: usize,
397    pub event: Event,
398}
399
400#[derive(Message, Clone, Debug)]
401#[rtype(result = "()")]
402pub struct SubscribeResult {
403    pub id: usize,
404    pub sub_id: String,
405    pub msg: OutgoingMessage,
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Tag set used by the EVENT / AUTH parsing fixtures.
    const TAGS: &str = r#"[["t", "nostr"], ["t", ""], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]"#;

    /// Tags without the NIP-70 marker.
    const PLAIN_TAGS: &str = r#"[["t", "nostr"], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]"#;

    /// Tags including the NIP-70 `["-"]` protected marker.
    const PROTECTED_TAGS: &str = r#"[["t", "nostr"], ["-"], ["expiration", "1"], ["delegation", "8e0d3d3eb2881ec137a11debe736a9086715a8c8beeeda615780064d68bc25dd"]]"#;

    /// Returns the JSON object of a fixed kind-1 test event with the given `tags` array.
    fn event_json(tags: &str) -> String {
        format!(
            r#"{{
            "content": "Good morning everyone 😃",
            "created_at": 1680690006,
            "id": "332747c0fab8a1a92def4b0937e177be6df4382ce6dd7724f86dc4710b7d4d7d",
            "kind": 1,
            "pubkey": "7abf57d516b1ff7308ca3bd5650ea6a4674d469c7c5057b1d005fb13d218bfef",
            "sig": "ef4ff4f69ac387239eb1401fb07d7a44a5d5d57127e0dc3466a0403cf7d5486b668608ebfcbe9ff1f8d3b5d710545999fe08ee767284ec0b474e4cf92537678f",
            "tags": {}
          }}"#,
            tags
        )
    }

    /// Wraps an already-serialized `payload` as `[<cmd>, <payload>]`.
    fn frame(cmd: &str, payload: &str) -> String {
        format!(r#"["{}", {}]"#, cmd, payload)
    }

    #[test]
    fn de_incoming_message() -> Result<()> {
        // empty and truncated frames are rejected
        assert!(serde_json::from_str::<IncomingMessage>("[]").is_err());
        assert!(serde_json::from_str::<IncomingMessage>(r#"["EVENT"]"#).is_err());

        // close
        let msg: IncomingMessage = serde_json::from_str(r#"["CLOSE", "sub_id1"]"#)?;
        assert!(matches!(msg, IncomingMessage::Close(ref id) if id == "sub_id1"));
        let msg = serde_json::from_str::<IncomingMessage>(r#"["CLOSE", "sub_id1", "other"]"#);
        assert!(msg.is_err());

        // event
        let msg: IncomingMessage = serde_json::from_str(&frame("EVENT", &event_json(TAGS)))?;
        assert!(matches!(msg, IncomingMessage::Event(ref event) if event.kind() == 1));

        // req
        let msg: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {}]"#)?;
        assert!(matches!(msg, IncomingMessage::Req(ref sub) if sub.id == "sub_id1"));
        let msg: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {}, {}]"#)?;
        assert!(matches!(msg, IncomingMessage::Req(ref sub) if sub.filters.len() == 2));
        let msg = serde_json::from_str::<IncomingMessage>(r#"["REQ", "sub_id1", ""]"#);
        assert!(msg.is_err());
        // a REQ without filters is accepted
        let msg = serde_json::from_str::<IncomingMessage>(r#"["REQ", "sub_id1"]"#);
        assert!(msg.is_ok());

        // unknown commands keep their remaining elements
        let msg: IncomingMessage = serde_json::from_str(r#"["REQ1", "sub_id1", {}]"#)?;
        assert!(
            matches!(msg, IncomingMessage::Unknown(ref cmd, ref rest) if cmd == "REQ1" && rest.len() == 2)
        );

        // auth
        let msg: IncomingMessage = serde_json::from_str(&frame("AUTH", &event_json(TAGS)))?;
        assert!(matches!(msg, IncomingMessage::Auth(ref event) if event.kind() == 1));

        // count
        let msg: IncomingMessage = serde_json::from_str(r#"["COUNT", "sub_id1", {}]"#)?;
        assert!(matches!(msg, IncomingMessage::Count(ref sub) if sub.id == "sub_id1"));

        Ok(())
    }

    #[test]
    fn se_outgoing_message() -> Result<()> {
        let msg = OutgoingMessage::notice("hello");
        assert_eq!(msg.to_string(), r#"["NOTICE","hello"]"#);

        let msg = OutgoingMessage::event("id", r#"{"id":"1"}"#);
        assert_eq!(msg.to_string(), r#"["EVENT","id",{"id":"1"}]"#);

        let msg = OutgoingMessage::eose("hello");
        assert_eq!(msg.to_string(), r#"["EOSE","hello"]"#);

        let msg = OutgoingMessage::closed("1", "hello");
        assert_eq!(msg.to_string(), r#"["CLOSED","1","hello"]"#);

        let msg = OutgoingMessage::ok("id", true, "");
        assert_eq!(msg.to_string(), r#"["OK","id",true,""]"#);

        Ok(())
    }

    #[test]
    fn validate() -> Result<()> {
        // NIP-70: an unprotected event passes...
        let msg: IncomingMessage =
            serde_json::from_str(&frame("EVENT", &event_json(PLAIN_TAGS)))?;
        let msg = ClientMessage::new(1, "text".to_string(), msg);
        assert!(msg.validate_nip70().is_ok());

        // ...while one carrying the bare `["-"]` tag is rejected.
        let msg: IncomingMessage =
            serde_json::from_str(&frame("EVENT", &event_json(PROTECTED_TAGS)))?;
        let msg = ClientMessage::new(1, "text".to_string(), msg);
        assert!(msg.validate_nip70().is_err());

        // REQ limits are defaulted / clamped to `max_limit`.
        let msg: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {}]"#)?;
        let mut msg = ClientMessage::new(1, "text".to_string(), msg);
        let mut limitation = Limitation::default();
        limitation.max_limit = 300;

        msg.validate(&limitation).unwrap();
        assert!(
            matches!(msg.msg, IncomingMessage::Req(ref sub) if sub.filters.first().unwrap().limit.unwrap() == 300)
        );

        let msg: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {"limit": 400}]"#)?;
        let mut msg = ClientMessage::new(1, "text".to_string(), msg);
        msg.validate(&limitation).unwrap();
        assert!(
            matches!(msg.msg, IncomingMessage::Req(ref sub) if sub.filters.first().unwrap().limit.unwrap() == 300)
        );

        let msg: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {"limit": 200}]"#)?;
        let mut msg = ClientMessage::new(1, "text".to_string(), msg);
        msg.validate(&limitation).unwrap();
        assert!(
            matches!(msg.msg, IncomingMessage::Req(ref sub) if sub.filters.first().unwrap().limit.unwrap() == 200)
        );

        // Too many filters are rejected.
        limitation.max_filters = 1;
        let msg: IncomingMessage = serde_json::from_str(r#"["REQ", "sub_id1", {}, {}]"#)?;
        let mut msg = ClientMessage::new(1, "text".to_string(), msg);
        assert!(msg.validate(&limitation).is_err());

        Ok(())
    }
}