Skip to main content

chat_system/messengers/
matrix.rs

1//! Matrix messenger backed by the Matrix Client-Server HTTP API.
2
3use crate::message::MessageType;
4use crate::{Message, Messenger};
5use anyhow::{Context, Result, anyhow, ensure};
6use async_trait::async_trait;
7use reqwest::{Client, Url};
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use tokio::sync::Mutex;
13
14pub struct MatrixMessenger {
15    name: String,
16    homeserver: String,
17    username: String,
18    password: String,
19    client: Client,
20    access_token: Option<String>,
21    user_id: Option<String>,
22    sync_token: Mutex<Option<String>>,
23    txn_counter: AtomicU64,
24    connected: bool,
25}
26
27impl MatrixMessenger {
28    pub fn new(
29        name: impl Into<String>,
30        homeserver: impl Into<String>,
31        username: impl Into<String>,
32        password: impl Into<String>,
33    ) -> Self {
34        Self {
35            name: name.into(),
36            homeserver: homeserver.into(),
37            username: username.into(),
38            password: password.into(),
39            client: Client::new(),
40            access_token: None,
41            user_id: None,
42            sync_token: Mutex::new(None),
43            txn_counter: AtomicU64::new(1),
44            connected: false,
45        }
46    }
47
48    fn validate_config(&self) -> Result<()> {
49        ensure!(
50            !self.homeserver.trim().is_empty(),
51            "Matrix homeserver must not be empty"
52        );
53        ensure!(
54            !self.username.trim().is_empty(),
55            "Matrix username must not be empty"
56        );
57        ensure!(
58            !self.password.trim().is_empty(),
59            "Matrix password must not be empty"
60        );
61        Ok(())
62    }
63
64    fn access_token(&self) -> Result<&str> {
65        self.access_token
66            .as_deref()
67            .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
68    }
69
70    fn user_id(&self) -> Result<&str> {
71        self.user_id
72            .as_deref()
73            .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
74    }
75
76    fn url_for_segments(&self, segments: &[&str]) -> Result<Url> {
77        let mut url = Url::parse(self.homeserver.trim_end_matches('/'))
78            .with_context(|| format!("Invalid Matrix homeserver URL: {}", self.homeserver))?;
79        {
80            let mut path_segments = url
81                .path_segments_mut()
82                .map_err(|_| anyhow!("Matrix homeserver URL cannot be a base URL"))?;
83            path_segments.extend(segments.iter().copied());
84        }
85        Ok(url)
86    }
87
88    fn client_api_url(&self, path: &[&str]) -> Result<Url> {
89        let mut segments = vec!["_matrix", "client", "v3"];
90        segments.extend_from_slice(path);
91        self.url_for_segments(&segments)
92    }
93
94    async fn sync_once(&self) -> Result<Vec<Message>> {
95        #[derive(Debug, Deserialize)]
96        struct SyncResponse {
97            next_batch: String,
98            #[serde(default)]
99            rooms: SyncRooms,
100        }
101
102        #[derive(Debug, Default, Deserialize)]
103        struct SyncRooms {
104            #[serde(default)]
105            join: HashMap<String, JoinedRoom>,
106        }
107
108        #[derive(Debug, Default, Deserialize)]
109        struct JoinedRoom {
110            #[serde(default)]
111            timeline: Timeline,
112        }
113
114        #[derive(Debug, Default, Deserialize)]
115        struct Timeline {
116            #[serde(default)]
117            events: Vec<TimelineEvent>,
118        }
119
120        #[derive(Debug, Deserialize)]
121        struct TimelineEvent {
122            #[serde(rename = "type")]
123            event_type: String,
124            event_id: String,
125            sender: String,
126            origin_server_ts: i64,
127            #[serde(default)]
128            content: TimelineContent,
129        }
130
131        #[derive(Debug, Default, Deserialize)]
132        struct TimelineContent {
133            #[serde(default)]
134            body: String,
135            #[serde(default, rename = "m.relates_to")]
136            relates_to: Option<RelatesTo>,
137        }
138
139        #[derive(Debug, Deserialize)]
140        struct RelatesTo {
141            #[serde(default, rename = "m.in_reply_to")]
142            in_reply_to: Option<ReplyTo>,
143        }
144
145        #[derive(Debug, Deserialize)]
146        struct ReplyTo {
147            event_id: String,
148        }
149
150        let since = self.sync_token.lock().await.clone();
151        let mut url = self.client_api_url(&["sync"])?;
152        {
153            let mut query = url.query_pairs_mut();
154            query.append_pair("timeout", "1");
155            if let Some(since) = since {
156                query.append_pair("since", &since);
157            }
158        }
159
160        let response = self
161            .client
162            .get(url)
163            .bearer_auth(self.access_token()?)
164            .send()
165            .await
166            .context("Matrix sync request failed")?;
167
168        let status = response.status();
169        if !status.is_success() {
170            let body = response.text().await.unwrap_or_default();
171            anyhow::bail!("Matrix sync failed {}: {}", status, body);
172        }
173
174        let sync: SyncResponse = response
175            .json()
176            .await
177            .context("Invalid Matrix sync response")?;
178        *self.sync_token.lock().await = Some(sync.next_batch);
179
180        let mut messages = Vec::new();
181        for (room_id, joined_room) in sync.rooms.join {
182            for event in joined_room.timeline.events {
183                if event.event_type != "m.room.message" || event.content.body.is_empty() {
184                    continue;
185                }
186
187                messages.push(Message {
188                    id: event.event_id,
189                    sender: event.sender,
190                    content: event.content.body,
191                    timestamp: event.origin_server_ts / 1000,
192                    channel: Some(room_id.clone()),
193                    reply_to: event
194                        .content
195                        .relates_to
196                        .and_then(|r| r.in_reply_to)
197                        .map(|r| r.event_id),
198                    thread_id: None,
199                    media: None,
200                    is_direct: false,
201                    message_type: MessageType::Text,
202                    edited_timestamp: None,
203                    reactions: None,
204                });
205            }
206        }
207
208        Ok(messages)
209    }
210
211    async fn join_room_if_needed(&self, recipient: &str) -> Result<String> {
212        if recipient.starts_with('!') {
213            return Ok(recipient.to_string());
214        }
215
216        let response = self
217            .client
218            .post(self.client_api_url(&["join", recipient])?)
219            .bearer_auth(self.access_token()?)
220            .send()
221            .await
222            .context("Matrix join request failed")?;
223
224        let status = response.status();
225        if !status.is_success() {
226            let body = response.text().await.unwrap_or_default();
227            anyhow::bail!("Matrix join failed {}: {}", status, body);
228        }
229
230        #[derive(Deserialize)]
231        struct JoinResponse {
232            room_id: String,
233        }
234
235        let join: JoinResponse = response
236            .json()
237            .await
238            .context("Invalid Matrix join response")?;
239        Ok(join.room_id)
240    }
241}
242
243#[async_trait]
244impl Messenger for MatrixMessenger {
245    fn name(&self) -> &str {
246        &self.name
247    }
248
249    fn messenger_type(&self) -> &str {
250        "matrix"
251    }
252
253    async fn initialize(&mut self) -> Result<()> {
254        #[derive(Deserialize)]
255        struct LoginResponse {
256            access_token: String,
257            user_id: String,
258        }
259
260        self.validate_config()?;
261
262        let response = self
263            .client
264            .post(self.client_api_url(&["login"])?)
265            .json(&json!({
266                "type": "m.login.password",
267                "identifier": {
268                    "type": "m.id.user",
269                    "user": self.username,
270                },
271                "password": self.password,
272                "initial_device_display_name": self.name,
273            }))
274            .send()
275            .await
276            .context("Matrix login request failed")?;
277
278        let status = response.status();
279        if !status.is_success() {
280            let body = response.text().await.unwrap_or_default();
281            anyhow::bail!("Matrix login failed {}: {}", status, body);
282        }
283
284        let login: LoginResponse = response
285            .json()
286            .await
287            .context("Invalid Matrix login response")?;
288        self.access_token = Some(login.access_token);
289        self.user_id = Some(login.user_id);
290
291        *self.sync_token.lock().await = None;
292        let _ = self.sync_once().await?;
293
294        self.connected = true;
295        Ok(())
296    }
297
298    async fn send_message(&self, recipient: &str, content: &str) -> Result<String> {
299        let room_id = self.join_room_if_needed(recipient).await?;
300        let txn_id = self.txn_counter.fetch_add(1, Ordering::Relaxed).to_string();
301
302        let response = self
303            .client
304            .put(self.client_api_url(&["rooms", &room_id, "send", "m.room.message", &txn_id])?)
305            .bearer_auth(self.access_token()?)
306            .json(&json!({
307                "msgtype": "m.text",
308                "body": content,
309            }))
310            .send()
311            .await
312            .context("Matrix send request failed")?;
313
314        let status = response.status();
315        if !status.is_success() {
316            let body = response.text().await.unwrap_or_default();
317            anyhow::bail!("Matrix send failed {}: {}", status, body);
318        }
319
320        #[derive(Deserialize)]
321        struct SendResponse {
322            event_id: String,
323        }
324
325        let send: SendResponse = response
326            .json()
327            .await
328            .context("Invalid Matrix send response")?;
329        Ok(send.event_id)
330    }
331
332    async fn receive_messages(&self) -> Result<Vec<Message>> {
333        self.sync_once().await
334    }
335
336    fn is_connected(&self) -> bool {
337        self.connected
338    }
339
340    async fn disconnect(&mut self) -> Result<()> {
341        if let Some(token) = self.access_token.as_deref() {
342            let response = self
343                .client
344                .post(self.client_api_url(&["logout"])?)
345                .bearer_auth(token)
346                .send()
347                .await;
348
349            if let Err(error) = response {
350                tracing::warn!(messenger = %self.name, "Matrix logout failed: {error}");
351            }
352        }
353
354        self.access_token = None;
355        self.user_id = None;
356        *self.sync_token.lock().await = None;
357        self.connected = false;
358        Ok(())
359    }
360
361    async fn set_typing(&self, channel: &str, typing: bool) -> Result<()> {
362        let room_id = self.join_room_if_needed(channel).await?;
363        let mut payload = json!({ "typing": typing });
364        if typing {
365            payload["timeout"] = json!(30_000);
366        }
367
368        let response = self
369            .client
370            .put(self.client_api_url(&["rooms", &room_id, "typing", self.user_id()?])?)
371            .bearer_auth(self.access_token()?)
372            .json(&payload)
373            .send()
374            .await
375            .context("Matrix typing request failed")?;
376
377        let status = response.status();
378        if !status.is_success() {
379            let body = response.text().await.unwrap_or_default();
380            anyhow::bail!("Matrix typing failed {}: {}", status, body);
381        }
382
383        Ok(())
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a messenger with a fixed instance name; only config values vary.
    fn messenger(homeserver: &str, username: &str, password: &str) -> MatrixMessenger {
        MatrixMessenger::new("matrix", homeserver, username, password)
    }

    #[test]
    fn validate_config_rejects_empty_homeserver() {
        assert!(messenger("", "bot", "secret").validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_username() {
        assert!(messenger("https://matrix.example", "", "secret")
            .validate_config()
            .is_err());
    }

    #[test]
    fn validate_config_rejects_empty_password() {
        assert!(messenger("https://matrix.example", "bot", "")
            .validate_config()
            .is_err());
    }

    #[test]
    fn validate_config_rejects_whitespace_only_values() {
        assert!(messenger("   ", "bot", "secret").validate_config().is_err());
        assert!(messenger("https://matrix.example", " ", "secret")
            .validate_config()
            .is_err());
        assert!(messenger("https://matrix.example", "bot", "\t")
            .validate_config()
            .is_err());
    }

    #[test]
    fn validate_config_accepts_non_empty_values() {
        assert!(messenger("https://matrix.example", "bot", "secret")
            .validate_config()
            .is_ok());
    }

    #[test]
    fn client_api_url_appends_segments_to_bare_host() {
        let url = messenger("https://matrix.example", "bot", "secret")
            .client_api_url(&["login"])
            .unwrap();
        assert_eq!(url.as_str(), "https://matrix.example/_matrix/client/v3/login");
    }

    #[test]
    fn client_api_url_ignores_trailing_slash_and_keeps_path_prefix() {
        let url = messenger("https://matrix.example/base/", "bot", "secret")
            .client_api_url(&["sync"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://matrix.example/base/_matrix/client/v3/sync"
        );
    }

    #[test]
    fn client_api_url_percent_encodes_room_alias() {
        let url = messenger("https://matrix.example", "bot", "secret")
            .client_api_url(&["join", "#room:matrix.example"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://matrix.example/_matrix/client/v3/join/%23room:matrix.example"
        );
    }

    #[test]
    fn client_api_url_rejects_unparseable_homeserver() {
        assert!(messenger("not a url", "bot", "secret")
            .client_api_url(&["login"])
            .is_err());
    }

    #[test]
    fn client_api_url_rejects_cannot_be_a_base_homeserver() {
        assert!(messenger("mailto:bot@matrix.example", "bot", "secret")
            .client_api_url(&["login"])
            .is_err());
    }
}