Skip to main content

chat_system/messengers/
matrix.rs

1//! Matrix messenger backed by the Matrix Client-Server HTTP API.
2
3use crate::{Message, Messenger};
4use anyhow::{anyhow, ensure, Context, Result};
5use async_trait::async_trait;
6use reqwest::{Client, Url};
7use serde::Deserialize;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicU64, Ordering};
11use tokio::sync::Mutex;
12
/// Messenger implementation that talks to a Matrix homeserver over HTTP.
///
/// Holds the login configuration, the HTTP client, and the session state
/// (access token, user id, sync position) filled in by `initialize`.
pub struct MatrixMessenger {
    /// Instance name returned by `name()`; also sent as the device display
    /// name in the login request.
    name: String,
    /// Base URL of the homeserver; API path segments are appended to it.
    homeserver: String,
    /// User identifier sent in the password-login request.
    username: String,
    /// Password sent in the password-login request.
    password: String,
    /// HTTP client used for every request made by this messenger.
    client: Client,
    /// Bearer token taken from the login response; `None` before login and
    /// after `disconnect`.
    access_token: Option<String>,
    /// User id taken from the login response; `None` before login and after
    /// `disconnect`.
    user_id: Option<String>,
    /// `next_batch` value of the most recent sync, sent as `since` on the
    /// following one. Behind a mutex so `&self` methods can update it.
    sync_token: Mutex<Option<String>>,
    /// Monotonic counter used to derive a transaction id per sent message.
    txn_counter: AtomicU64,
    /// Set to `true` at the end of `initialize` and to `false` by `disconnect`.
    connected: bool,
}
25
26impl MatrixMessenger {
27    pub fn new(
28        name: impl Into<String>,
29        homeserver: impl Into<String>,
30        username: impl Into<String>,
31        password: impl Into<String>,
32    ) -> Self {
33        Self {
34            name: name.into(),
35            homeserver: homeserver.into(),
36            username: username.into(),
37            password: password.into(),
38            client: Client::new(),
39            access_token: None,
40            user_id: None,
41            sync_token: Mutex::new(None),
42            txn_counter: AtomicU64::new(1),
43            connected: false,
44        }
45    }
46
47    fn validate_config(&self) -> Result<()> {
48        ensure!(
49            !self.homeserver.trim().is_empty(),
50            "Matrix homeserver must not be empty"
51        );
52        ensure!(
53            !self.username.trim().is_empty(),
54            "Matrix username must not be empty"
55        );
56        ensure!(
57            !self.password.trim().is_empty(),
58            "Matrix password must not be empty"
59        );
60        Ok(())
61    }
62
63    fn access_token(&self) -> Result<&str> {
64        self.access_token
65            .as_deref()
66            .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
67    }
68
69    fn user_id(&self) -> Result<&str> {
70        self.user_id
71            .as_deref()
72            .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
73    }
74
75    fn url_for_segments(&self, segments: &[&str]) -> Result<Url> {
76        let mut url = Url::parse(self.homeserver.trim_end_matches('/'))
77            .with_context(|| format!("Invalid Matrix homeserver URL: {}", self.homeserver))?;
78        {
79            let mut path_segments = url
80                .path_segments_mut()
81                .map_err(|_| anyhow!("Matrix homeserver URL cannot be a base URL"))?;
82            path_segments.extend(segments.iter().copied());
83        }
84        Ok(url)
85    }
86
87    fn client_api_url(&self, path: &[&str]) -> Result<Url> {
88        let mut segments = vec!["_matrix", "client", "v3"];
89        segments.extend_from_slice(path);
90        self.url_for_segments(&segments)
91    }
92
93    async fn sync_once(&self) -> Result<Vec<Message>> {
94        #[derive(Debug, Deserialize)]
95        struct SyncResponse {
96            next_batch: String,
97            #[serde(default)]
98            rooms: SyncRooms,
99        }
100
101        #[derive(Debug, Default, Deserialize)]
102        struct SyncRooms {
103            #[serde(default)]
104            join: HashMap<String, JoinedRoom>,
105        }
106
107        #[derive(Debug, Default, Deserialize)]
108        struct JoinedRoom {
109            #[serde(default)]
110            timeline: Timeline,
111        }
112
113        #[derive(Debug, Default, Deserialize)]
114        struct Timeline {
115            #[serde(default)]
116            events: Vec<TimelineEvent>,
117        }
118
119        #[derive(Debug, Deserialize)]
120        struct TimelineEvent {
121            #[serde(rename = "type")]
122            event_type: String,
123            event_id: String,
124            sender: String,
125            origin_server_ts: i64,
126            #[serde(default)]
127            content: TimelineContent,
128        }
129
130        #[derive(Debug, Default, Deserialize)]
131        struct TimelineContent {
132            #[serde(default)]
133            body: String,
134            #[serde(default, rename = "m.relates_to")]
135            relates_to: Option<RelatesTo>,
136        }
137
138        #[derive(Debug, Deserialize)]
139        struct RelatesTo {
140            #[serde(default, rename = "m.in_reply_to")]
141            in_reply_to: Option<ReplyTo>,
142        }
143
144        #[derive(Debug, Deserialize)]
145        struct ReplyTo {
146            event_id: String,
147        }
148
149        let since = self.sync_token.lock().await.clone();
150        let mut url = self.client_api_url(&["sync"])?;
151        {
152            let mut query = url.query_pairs_mut();
153            query.append_pair("timeout", "1");
154            if let Some(since) = since {
155                query.append_pair("since", &since);
156            }
157        }
158
159        let response = self
160            .client
161            .get(url)
162            .bearer_auth(self.access_token()?)
163            .send()
164            .await
165            .context("Matrix sync request failed")?;
166
167        let status = response.status();
168        if !status.is_success() {
169            let body = response.text().await.unwrap_or_default();
170            anyhow::bail!("Matrix sync failed {}: {}", status, body);
171        }
172
173        let sync: SyncResponse = response.json().await.context("Invalid Matrix sync response")?;
174        *self.sync_token.lock().await = Some(sync.next_batch);
175
176        let mut messages = Vec::new();
177        for (room_id, joined_room) in sync.rooms.join {
178            for event in joined_room.timeline.events {
179                if event.event_type != "m.room.message" || event.content.body.is_empty() {
180                    continue;
181                }
182
183                messages.push(Message {
184                    id: event.event_id,
185                    sender: event.sender,
186                    content: event.content.body,
187                    timestamp: event.origin_server_ts / 1000,
188                    channel: Some(room_id.clone()),
189                    reply_to: event
190                        .content
191                        .relates_to
192                        .and_then(|r| r.in_reply_to)
193                        .map(|r| r.event_id),
194                    media: None,
195                    is_direct: false,
196                    reactions: None,
197                });
198            }
199        }
200
201        Ok(messages)
202    }
203
204    async fn join_room_if_needed(&self, recipient: &str) -> Result<String> {
205        if recipient.starts_with('!') {
206            return Ok(recipient.to_string());
207        }
208
209        let response = self
210            .client
211            .post(self.client_api_url(&["join", recipient])?)
212            .bearer_auth(self.access_token()?)
213            .send()
214            .await
215            .context("Matrix join request failed")?;
216
217        let status = response.status();
218        if !status.is_success() {
219            let body = response.text().await.unwrap_or_default();
220            anyhow::bail!("Matrix join failed {}: {}", status, body);
221        }
222
223        #[derive(Deserialize)]
224        struct JoinResponse {
225            room_id: String,
226        }
227
228        let join: JoinResponse = response.json().await.context("Invalid Matrix join response")?;
229        Ok(join.room_id)
230    }
231}
232
233#[async_trait]
234impl Messenger for MatrixMessenger {
235    fn name(&self) -> &str {
236        &self.name
237    }
238
    /// Returns the fixed type tag for this backend, `"matrix"`.
    fn messenger_type(&self) -> &str {
        "matrix"
    }
242
243    async fn initialize(&mut self) -> Result<()> {
244        #[derive(Deserialize)]
245        struct LoginResponse {
246            access_token: String,
247            user_id: String,
248        }
249
250        self.validate_config()?;
251
252        let response = self
253            .client
254            .post(self.client_api_url(&["login"])? )
255            .json(&json!({
256                "type": "m.login.password",
257                "identifier": {
258                    "type": "m.id.user",
259                    "user": self.username,
260                },
261                "password": self.password,
262                "initial_device_display_name": self.name,
263            }))
264            .send()
265            .await
266            .context("Matrix login request failed")?;
267
268        let status = response.status();
269        if !status.is_success() {
270            let body = response.text().await.unwrap_or_default();
271            anyhow::bail!("Matrix login failed {}: {}", status, body);
272        }
273
274        let login: LoginResponse = response.json().await.context("Invalid Matrix login response")?;
275        self.access_token = Some(login.access_token);
276        self.user_id = Some(login.user_id);
277
278        *self.sync_token.lock().await = None;
279        let _ = self.sync_once().await?;
280
281        self.connected = true;
282        Ok(())
283    }
284
285    async fn send_message(&self, recipient: &str, content: &str) -> Result<String> {
286        let room_id = self.join_room_if_needed(recipient).await?;
287        let txn_id = self.txn_counter.fetch_add(1, Ordering::Relaxed).to_string();
288
289        let response = self
290            .client
291            .put(self.client_api_url(&[
292                "rooms",
293                &room_id,
294                "send",
295                "m.room.message",
296                &txn_id,
297            ])?)
298            .bearer_auth(self.access_token()?)
299            .json(&json!({
300                "msgtype": "m.text",
301                "body": content,
302            }))
303            .send()
304            .await
305            .context("Matrix send request failed")?;
306
307        let status = response.status();
308        if !status.is_success() {
309            let body = response.text().await.unwrap_or_default();
310            anyhow::bail!("Matrix send failed {}: {}", status, body);
311        }
312
313        #[derive(Deserialize)]
314        struct SendResponse {
315            event_id: String,
316        }
317
318        let send: SendResponse = response.json().await.context("Invalid Matrix send response")?;
319        Ok(send.event_id)
320    }
321
322    async fn receive_messages(&self) -> Result<Vec<Message>> {
323        self.sync_once().await
324    }
325
    /// Reports the `connected` flag: `true` after a successful `initialize`,
    /// `false` initially and after `disconnect`.
    fn is_connected(&self) -> bool {
        self.connected
    }
329
330    async fn disconnect(&mut self) -> Result<()> {
331        if let Some(token) = self.access_token.as_deref() {
332            let response = self
333                .client
334                .post(self.client_api_url(&["logout"])? )
335                .bearer_auth(token)
336                .send()
337                .await;
338
339            if let Err(error) = response {
340                tracing::warn!(messenger = %self.name, "Matrix logout failed: {error}");
341            }
342        }
343
344        self.access_token = None;
345        self.user_id = None;
346        *self.sync_token.lock().await = None;
347        self.connected = false;
348        Ok(())
349    }
350
351    async fn set_typing(&self, channel: &str, typing: bool) -> Result<()> {
352        let room_id = self.join_room_if_needed(channel).await?;
353        let mut payload = json!({ "typing": typing });
354        if typing {
355            payload["timeout"] = json!(30_000);
356        }
357
358        let response = self
359            .client
360            .put(self.client_api_url(&["rooms", &room_id, "typing", self.user_id()?])?)
361            .bearer_auth(self.access_token()?)
362            .json(&payload)
363            .send()
364            .await
365            .context("Matrix typing request failed")?;
366
367        let status = response.status();
368        if !status.is_success() {
369            let body = response.text().await.unwrap_or_default();
370            anyhow::bail!("Matrix typing failed {}: {}", status, body);
371        }
372
373        Ok(())
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a messenger named "matrix" with the given connection settings.
    fn messenger_with(homeserver: &str, username: &str, password: &str) -> MatrixMessenger {
        MatrixMessenger::new("matrix", homeserver, username, password)
    }

    #[test]
    fn validate_config_rejects_empty_homeserver() {
        let outcome = messenger_with("", "bot", "secret").validate_config();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_config_rejects_empty_username() {
        let outcome = messenger_with("https://matrix.example", "", "secret").validate_config();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_config_rejects_empty_password() {
        let outcome = messenger_with("https://matrix.example", "bot", "").validate_config();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_config_accepts_non_empty_values() {
        let outcome = messenger_with("https://matrix.example", "bot", "secret").validate_config();
        assert!(outcome.is_ok());
    }
}