Skip to main content

matrix_bot_sdk/
preprocessors.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5#[async_trait]
6pub trait IPreprocessor: Send + Sync {
7    async fn process(&self, event: &mut Value) -> anyhow::Result<()>;
8}
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct RichReplyInfo {
12    pub was_lenient: bool,
13    pub parent_event_id: String,
14    pub fallback_plain_body: String,
15    pub fallback_html_body: String,
16    pub fallback_sender: String,
17    pub real_event: Option<Value>,
18}
19
/// Preprocessor that detects Matrix rich replies in `m.room.message` events and
/// strips the quoted reply fallback from their content.
///
/// The derived `Default` leaves real-event fetching disabled (`false`).
#[derive(Default)]
pub struct RichRepliesPreprocessor {
    /// Whether the replied-to ("real") event should be fetched.
    ///
    /// NOTE(review): nothing in this file reads this flag yet.
    _fetch_real_event: bool,
}
24
25impl RichRepliesPreprocessor {
26    pub fn new() -> Self {
27        Self {
28            _fetch_real_event: false,
29        }
30    }
31
32    pub fn with_fetch_real_event(fetch: bool) -> Self {
33        Self {
34            _fetch_real_event: fetch,
35        }
36    }
37
38    pub fn get_supported_event_types(&self) -> Vec<&'static str> {
39        vec!["m.room.message"]
40    }
41}
42
43#[async_trait]
44impl IPreprocessor for RichRepliesPreprocessor {
45    async fn process(&self, event: &mut Value) -> anyhow::Result<()> {
46        let Some(content) = event.get_mut("content") else {
47            return Ok(());
48        };
49
50        let relates_to = content
51            .get("m.relates_to")
52            .and_then(|r| r.get("m.in_reply_to"))
53            .and_then(|r| r.get("event_id"))
54            .and_then(Value::as_str);
55
56        let Some(parent_event_id) = relates_to else {
57            return Ok(());
58        };
59        let parent_event_id = parent_event_id.to_owned();
60
61        let body = content
62            .get("body")
63            .and_then(Value::as_str)
64            .map(ToOwned::to_owned)
65            .unwrap_or_default();
66
67        let formatted_body = content
68            .get("formatted_body")
69            .and_then(Value::as_str)
70            .map(ToOwned::to_owned);
71
72        let (fallback_plain_body, fallback_sender, reply_plain_body) =
73            Self::parse_plain_body(&body);
74
75        let (fallback_html_body, reply_html_body) =
76            Self::parse_html_body(&formatted_body.unwrap_or_default());
77
78        let was_lenient = fallback_plain_body.is_empty() && fallback_sender.is_empty();
79
80        if !was_lenient && let Some(obj) = content.as_object_mut() {
81            obj.insert("body".to_owned(), Value::String(reply_plain_body));
82            if let Some(html) = reply_html_body {
83                obj.insert(
84                    "format".to_owned(),
85                    Value::String("org.matrix.custom.html".to_owned()),
86                );
87                obj.insert("formatted_body".to_owned(), Value::String(html));
88            }
89        }
90
91        let info = RichReplyInfo {
92            was_lenient,
93            parent_event_id,
94            fallback_plain_body,
95            fallback_html_body,
96            fallback_sender,
97            real_event: None,
98        };
99
100        if let Some(obj) = event.as_object_mut() {
101            obj.insert("mx_richreply".to_owned(), serde_json::to_value(info)?);
102        }
103
104        Ok(())
105    }
106}
107
108impl RichRepliesPreprocessor {
109    fn parse_plain_body(body: &str) -> (String, String, String) {
110        if !body.starts_with("> <") {
111            return (String::new(), String::new(), body.to_owned());
112        }
113
114        let mut lines = body.lines();
115        let first_line = lines.next().unwrap_or("");
116
117        let sender_start = 3;
118        let sender_end = first_line[sender_start..]
119            .find('>')
120            .map(|pos| sender_start + pos)
121            .unwrap_or(first_line.len());
122
123        let sender = first_line[sender_start..sender_end].to_owned();
124
125        let mut fallback_lines = vec![first_line[sender_end + 2..].to_owned()];
126        let mut reply_lines = Vec::new();
127        let mut in_fallback = true;
128
129        for line in lines {
130            if in_fallback {
131                if let Some(stripped) = line.strip_prefix("> ") {
132                    fallback_lines.push(stripped.to_owned());
133                } else if line.is_empty() {
134                    in_fallback = false;
135                } else {
136                    reply_lines.push(line.to_owned());
137                }
138            } else {
139                reply_lines.push(line.to_owned());
140            }
141        }
142
143        let fallback_plain_body = format!("<{}> {}", sender, fallback_lines.join("\n"));
144        let reply_plain_body = reply_lines.join("\n");
145
146        (fallback_plain_body, sender, reply_plain_body)
147    }
148
149    fn parse_html_body(html: &str) -> (String, Option<String>) {
150        let mx_reply_end = html.find("</mx-reply>");
151        if let Some(end_pos) = mx_reply_end {
152            let fallback = html[..end_pos]
153                .replace("<mx-reply>", "")
154                .replace("</mx-reply>", "");
155
156            let reply_start = end_pos + "</mx-reply>".len();
157            let reply = html[reply_start..].to_owned();
158
159            let fallback_html = Self::extract_html_content(&fallback);
160            (fallback_html, Some(reply))
161        } else {
162            (String::new(), None)
163        }
164    }
165
166    fn extract_html_content(fallback: &str) -> String {
167        if let Some(start) = fallback.find("<br />") {
168            fallback[start + 6..].to_owned()
169        } else if let Some(start) = fallback.find("<br/>") {
170            fallback[start + 5..].to_owned()
171        } else {
172            fallback.to_owned()
173        }
174    }
175}