Skip to main content

matrix_bot_sdk/
helpers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use base64::Engine;
6use base64::engine::general_purpose::STANDARD_NO_PAD;
7use regex::Regex;
8use tokio::sync::RwLock;
9
10use crate::models::MatrixProfile;
11
/// Plain-string alias for a Matrix user ID (not validated).
pub type UserId = String;
/// Plain-string alias for a Matrix room ID (not validated).
pub type RoomId = String;
/// Plain-string alias for a Matrix event ID (not validated).
pub type EventId = String;

/// A Matrix identifier classified by its leading sigil character.
///
/// Every variant stores the complete input string, sigil included.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixEntity {
    /// Input beginning with `@`.
    UserId(String),
    /// Input beginning with `!`.
    RoomId(String),
    /// Input beginning with `#`.
    RoomAlias(String),
    /// Input beginning with `$`.
    EventId(String),
    /// Any other input, including the empty string.
    Unknown(String),
}

/// Classifies `value` by looking only at its first character.
///
/// No further validation is performed: `"@"` alone is still reported as a user ID.
pub fn parse_entity(value: &str) -> MatrixEntity {
    let owned = value.to_owned();
    match value.chars().next() {
        Some('@') => MatrixEntity::UserId(owned),
        Some('!') => MatrixEntity::RoomId(owned),
        Some('#') => MatrixEntity::RoomAlias(owned),
        Some('$') => MatrixEntity::EventId(owned),
        _ => MatrixEntity::Unknown(owned),
    }
}
38
/// Splits a sigil-prefixed Matrix identifier (`<sigil>localpart:domain`) into its
/// localpart and domain.
///
/// The split happens at the FIRST `:` after the sigil, so a domain that carries a port
/// (`example.org:8448`) is kept intact in the domain part.
///
/// Returns `None` when `value` does not start with `sigil`, contains no `:` after it,
/// or when either the localpart or the domain would be empty (such identifiers are not
/// well-formed Matrix IDs).
fn parse_localpart_and_domain(sigil: char, value: &str) -> Option<(String, String)> {
    // `strip_prefix` removes exactly `sigil.len_utf8()` bytes, so a multi-byte sigil can
    // never cause a slice at a non-character boundary (which `&value[1..]` would panic on).
    let rest = value.strip_prefix(sigil)?;
    let (localpart, domain) = rest.split_once(':')?;
    if localpart.is_empty() || domain.is_empty() {
        return None;
    }
    Some((localpart.to_owned(), domain.to_owned()))
}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct UserID {
52    value: String,
53    localpart: String,
54    domain: String,
55}
56
57impl UserID {
58    pub fn new(value: &str) -> Option<Self> {
59        let (localpart, domain) = parse_localpart_and_domain('@', value)?;
60        Some(Self {
61            value: value.to_owned(),
62            localpart,
63            domain,
64        })
65    }
66
67    pub fn localpart(&self) -> &str {
68        &self.localpart
69    }
70
71    pub fn domain(&self) -> &str {
72        &self.domain
73    }
74}
75
76impl std::fmt::Display for UserID {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.write_str(&self.value)
79    }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct RoomAlias {
84    value: String,
85    localpart: String,
86    domain: String,
87}
88
89impl RoomAlias {
90    pub fn new(value: &str) -> Option<Self> {
91        let (localpart, domain) = parse_localpart_and_domain('#', value)?;
92        Some(Self {
93            value: value.to_owned(),
94            localpart,
95            domain,
96        })
97    }
98
99    pub fn localpart(&self) -> &str {
100        &self.localpart
101    }
102
103    pub fn domain(&self) -> &str {
104        &self.domain
105    }
106}
107
108impl std::fmt::Display for RoomAlias {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        f.write_str(&self.value)
111    }
112}
113
114#[derive(Debug, Clone, PartialEq, Eq)]
115pub struct RoomID {
116    value: String,
117    localpart: String,
118    domain: String,
119}
120
121impl RoomID {
122    pub fn new(value: &str) -> Option<Self> {
123        let (localpart, domain) = parse_localpart_and_domain('!', value)?;
124        Some(Self {
125            value: value.to_owned(),
126            localpart,
127            domain,
128        })
129    }
130
131    pub fn localpart(&self) -> &str {
132        &self.localpart
133    }
134
135    pub fn domain(&self) -> &str {
136        &self.domain
137    }
138}
139
140impl std::fmt::Display for RoomID {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        f.write_str(&self.value)
143    }
144}
145
146#[derive(Debug, Clone)]
147pub struct MatrixGlob {
148    pattern: String,
149    regex: Regex,
150}
151
152impl MatrixGlob {
153    pub fn new(pattern: impl Into<String>) -> anyhow::Result<Self> {
154        let pattern = pattern.into();
155        let escaped = regex::escape(&pattern)
156            .replace("\\*", ".*")
157            .replace("\\?", ".");
158        let regex = Regex::new(&format!("^{escaped}$"))?;
159        Ok(Self { pattern, regex })
160    }
161
162    pub fn is_match(&self, value: &str) -> bool {
163        self.regex.is_match(value)
164    }
165
166    pub fn pattern(&self) -> &str {
167        &self.pattern
168    }
169}
170
171#[derive(Debug, Clone)]
172pub struct MentionPillResult {
173    pub html: String,
174    pub text: String,
175}
176
177pub struct MentionPill;
178
179impl MentionPill {
180    pub fn html(user_id: &str, display_name: Option<&str>) -> String {
181        let label = display_name.unwrap_or(user_id);
182        format!(r#"<a href="https://matrix.to/#/{user_id}">{label}</a>"#)
183    }
184
185    pub fn for_user(user_id: &str) -> MentionPillResult {
186        MentionPillResult {
187            html: Self::html(user_id, None),
188            text: user_id.to_owned(),
189        }
190    }
191
192    pub fn for_user_with_display_name(user_id: &str, display_name: &str) -> MentionPillResult {
193        MentionPillResult {
194            html: Self::html(user_id, Some(display_name)),
195            text: display_name.to_owned(),
196        }
197    }
198
199    pub fn for_room(room_id_or_alias: &str) -> MentionPillResult {
200        MentionPillResult {
201            html: format!(
202                r#"<a href="https://matrix.to/#/{room_id_or_alias}">{room_id_or_alias}</a>"#
203            ),
204            text: room_id_or_alias.to_owned(),
205        }
206    }
207}
208
/// Components extracted from a matrix.to permalink by [`Permalinks::parse`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PermalinkParts {
    /// First path segment of a non-user permalink (normally a room ID `!…` or alias `#…`),
    /// percent-decoded.
    pub room_id_or_alias: Option<String>,
    /// Remainder of the path after the first `/` (normally an event ID), percent-decoded.
    pub event_id: Option<String>,
    /// Set instead of the room fields when the permalink's entity starts with `@`.
    pub user_id: Option<String>,
    /// Values of every `via=` query parameter, percent-decoded, in order of appearance.
    pub via_servers: Vec<String>,
}

/// Builders and a parser for `https://matrix.to/#/…` permalinks.
///
/// The builders insert identifiers verbatim (they do not percent-encode); the parser
/// accepts both verbatim and percent-encoded identifiers.
pub struct Permalinks;

impl Permalinks {
    const PREFIX: &str = "https://matrix.to/#/";

    /// Permalink to a room ID or alias.
    pub fn room(room_id: &str) -> String {
        format!("{}{}", Self::PREFIX, room_id)
    }

    /// Room permalink with one `via=` parameter per server (none when `via` is empty).
    pub fn room_with_via(room_id: &str, via: &[&str]) -> String {
        format!("{}{}", Self::room(room_id), Self::via_query(via))
    }

    /// Permalink to `event_id` within `room_id`.
    pub fn event(room_id: &str, event_id: &str) -> String {
        format!("{}{}/{}", Self::PREFIX, room_id, event_id)
    }

    /// Event permalink with one `via=` parameter per server (none when `via` is empty).
    pub fn event_with_via(room_id: &str, event_id: &str, via: &[&str]) -> String {
        format!("{}{}", Self::event(room_id, event_id), Self::via_query(via))
    }

    /// Permalink to a user ID.
    pub fn user(user_id: &str) -> String {
        format!("{}{}", Self::PREFIX, user_id)
    }

    /// Parses a matrix.to permalink.
    ///
    /// Returns `None` when `url` does not start with `https://matrix.to/#/` or when the
    /// entity segment is empty. Identifiers and `via` values are percent-decoded
    /// (RFC 3986 `%XX`), so both `#room:server` and `%23room%3Aserver` are understood.
    /// Trailing `/` characters on the path are ignored.
    pub fn parse(url: &str) -> Option<PermalinkParts> {
        let rest = url.strip_prefix(Self::PREFIX)?;
        let (path, via_servers) = match rest.split_once('?') {
            Some((path, query)) => (path, Self::parse_via(query)),
            None => (rest, Vec::new()),
        };
        let path = path.trim_end_matches('/');

        // Split BEFORE decoding so an encoded `%2F` inside an ID is not taken as a separator.
        let (raw_entity, raw_event) = match path.split_once('/') {
            Some((entity, event)) => (entity, Some(event)),
            None => (path, None),
        };
        let entity = Self::percent_decode(raw_entity);
        if entity.is_empty() {
            return None;
        }

        if entity.starts_with('@') {
            // User permalinks keep the whole path as the user ID.
            return Some(PermalinkParts {
                user_id: Some(Self::percent_decode(path)),
                room_id_or_alias: None,
                event_id: None,
                via_servers,
            });
        }

        Some(PermalinkParts {
            room_id_or_alias: Some(entity),
            event_id: raw_event.map(Self::percent_decode),
            user_id: None,
            via_servers,
        })
    }

    /// Builds `?via=a&via=b`, or an empty string when `via` is empty.
    fn via_query(via: &[&str]) -> String {
        if via.is_empty() {
            return String::new();
        }
        let joined = via
            .iter()
            .map(|server| format!("via={server}"))
            .collect::<Vec<_>>()
            .join("&");
        format!("?{joined}")
    }

    /// Collects the (decoded) values of every `via=` pair in a query string.
    fn parse_via(query: &str) -> Vec<String> {
        query
            .split('&')
            .filter_map(|pair| {
                let (key, value) = pair.split_once('=')?;
                (key == "via").then(|| Self::percent_decode(value))
            })
            .collect()
    }

    /// Decodes `%XX` escapes. Malformed escapes are kept literally; if the decoded bytes
    /// are not valid UTF-8 the input is returned unchanged.
    fn percent_decode(input: &str) -> String {
        if !input.contains('%') {
            return input.to_owned();
        }
        let bytes = input.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            let escaped = match bytes.get(i..i + 3) {
                Some(&[b'%', hi, lo]) => match (Self::hex_value(hi), Self::hex_value(lo)) {
                    (Some(hi), Some(lo)) => Some((hi << 4) | lo),
                    _ => None,
                },
                _ => None,
            };
            match escaped {
                Some(byte) => {
                    decoded.push(byte);
                    i += 3;
                }
                None => {
                    decoded.push(bytes[i]);
                    i += 1;
                }
            }
        }
        String::from_utf8(decoded).unwrap_or_else(|_| input.to_owned())
    }

    /// Value of one ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }
}
306
307#[derive(Debug, Clone)]
308pub struct RichReply {
309    pub body: String,
310    pub formatted_body: String,
311}
312
313impl RichReply {
314    pub fn new(original_sender: &str, original_body: &str, reply_body: &str) -> Self {
315        let plain_prefix = format!("> <{original_sender}> {original_body}\n\n");
316        let html_prefix = format!(
317            "<mx-reply><blockquote><a href=\"{}\">In reply to</a> <a href=\"{}\">{}</a><br />{}</blockquote></mx-reply>",
318            Permalinks::room("!unknown:example.org"),
319            Permalinks::user(original_sender),
320            original_sender,
321            html_escape(original_body),
322        );
323
324        Self {
325            body: format!("{plain_prefix}{reply_body}"),
326            formatted_body: format!("{html_prefix}{}", html_escape(reply_body)),
327        }
328    }
329}
330
331#[derive(Debug, Clone)]
332pub struct ProfileCache {
333    ttl: Duration,
334    entries: Arc<RwLock<HashMap<String, CachedProfile>>>,
335}
336
337#[derive(Debug, Clone)]
338struct CachedProfile {
339    profile: MatrixProfile,
340    expires_at: Instant,
341}
342
343impl ProfileCache {
344    pub fn new(ttl: Duration) -> Self {
345        Self {
346            ttl,
347            entries: Arc::new(RwLock::new(HashMap::new())),
348        }
349    }
350
351    pub async fn get(&self, user_id: &str) -> Option<MatrixProfile> {
352        let mut guard = self.entries.write().await;
353        match guard.get(user_id) {
354            Some(entry) if entry.expires_at > Instant::now() => Some(entry.profile.clone()),
355            Some(_) => {
356                guard.remove(user_id);
357                None
358            }
359            None => None,
360        }
361    }
362
363    pub async fn put(&self, profile: MatrixProfile) {
364        let key = profile.user_id.clone();
365        self.entries.write().await.insert(
366            key,
367            CachedProfile {
368                profile,
369                expires_at: Instant::now() + self.ttl,
370            },
371        );
372    }
373
374    /// Returns the cached profile for the given user ID, or fetches it using the provided
375    /// async callback if not cached (or expired). The fetched profile is then stored in the cache.
376    pub async fn get_or_fetch<F, Fut>(&self, user_id: &str, fetch: F) -> Option<MatrixProfile>
377    where
378        F: FnOnce(&str) -> Fut,
379        Fut: std::future::Future<Output = Option<MatrixProfile>>,
380    {
381        if let Some(profile) = self.get(user_id).await {
382            return Some(profile);
383        }
384        if let Some(profile) = fetch(user_id).await {
385            self.put(profile.clone()).await;
386            return Some(profile);
387        }
388        None
389    }
390}
391
392pub struct UnpaddedBase64;
393
394impl UnpaddedBase64 {
395    pub fn encode(bytes: &[u8]) -> String {
396        STANDARD_NO_PAD.encode(bytes)
397    }
398
399    pub fn decode(value: &str) -> anyhow::Result<Vec<u8>> {
400        Ok(STANDARD_NO_PAD.decode(value)?)
401    }
402}
403
/// Escapes `&`, `<`, `>` and `"` for safe embedding in HTML text and double-quoted
/// attribute values. All other characters (including `'`) pass through unchanged.
fn html_escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
411
/// Checks whether `s` is an acceptable Matrix Space child `order` string.
///
/// Accepted strings:
/// - are between 1 and 50 characters long, and
/// - consist solely of printable ASCII, from `0x20` (space) through `0x7E` (`~`).
pub fn validate_space_order_string(s: &str) -> bool {
    const MAX_LEN: usize = 50;
    // Any non-ASCII character encodes to bytes >= 0x80, so a byte-wise range check
    // rejects it just like a char-wise one would.
    let printable = 0x20u8..=0x7E;
    (1..=MAX_LEN).contains(&s.len()) && s.bytes().all(|b| printable.contains(&b))
}