Skip to main content

chat_system/messengers/
matrix.rs

1//! Matrix messenger backed by the Matrix Client-Server HTTP API.
2
3use crate::message::MessageType;
4use crate::{Message, Messenger};
5use anyhow::{Context, Result, anyhow, ensure};
6use async_trait::async_trait;
7use reqwest::{Client, Url};
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use tokio::sync::Mutex;
13
14pub struct MatrixMessenger {
15    name: String,
16    homeserver: String,
17    username: String,
18    password: String,
19    client: Client,
20    access_token: Option<String>,
21    user_id: Option<String>,
22    sync_token: Mutex<Option<String>>,
23    txn_counter: AtomicU64,
24    connected: bool,
25}
26
27impl MatrixMessenger {
28    pub fn new(
29        name: impl Into<String>,
30        homeserver: impl Into<String>,
31        username: impl Into<String>,
32        password: impl Into<String>,
33    ) -> Self {
34        Self {
35            name: name.into(),
36            homeserver: homeserver.into(),
37            username: username.into(),
38            password: password.into(),
39            client: Client::new(),
40            access_token: None,
41            user_id: None,
42            sync_token: Mutex::new(None),
43            txn_counter: AtomicU64::new(1),
44            connected: false,
45        }
46    }
47
48    /// Create a Matrix messenger using a pre-existing access token.
49    ///
50    /// This skips the password login flow and uses the provided token directly.
51    /// Useful when you already have an access token from a previous session or
52    /// from an external authentication flow.
53    ///
54    /// # Arguments
55    /// * `name` - Messenger instance name
56    /// * `homeserver` - Matrix homeserver URL (e.g., "https://matrix.org")
57    /// * `user_id` - Full Matrix user ID (e.g., "@user:matrix.org")
58    /// * `access_token` - Pre-existing access token
59    /// * `device_id` - Optional device ID (for E2EE tracking, not used in this implementation)
60    pub fn with_access_token(
61        name: impl Into<String>,
62        homeserver: impl Into<String>,
63        user_id: impl Into<String>,
64        access_token: impl Into<String>,
65        _device_id: Option<String>,
66    ) -> Self {
67        let user_id_str = user_id.into();
68        Self {
69            name: name.into(),
70            homeserver: homeserver.into(),
71            username: user_id_str.clone(),
72            password: String::new(), // Not needed for token auth
73            client: Client::new(),
74            access_token: Some(access_token.into()),
75            user_id: Some(user_id_str),
76            sync_token: Mutex::new(None),
77            txn_counter: AtomicU64::new(1),
78            connected: false,
79        }
80    }
81
82    fn validate_config(&self) -> Result<()> {
83        ensure!(
84            !self.homeserver.trim().is_empty(),
85            "Matrix homeserver must not be empty"
86        );
87        ensure!(
88            !self.username.trim().is_empty(),
89            "Matrix username must not be empty"
90        );
91        // Password is only required if we don't have a pre-existing access token
92        if self.access_token.is_none() {
93            ensure!(
94                !self.password.trim().is_empty(),
95                "Matrix password must not be empty (unless using access_token auth)"
96            );
97        }
98        Ok(())
99    }
100
101    fn access_token(&self) -> Result<&str> {
102        self.access_token
103            .as_deref()
104            .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
105    }
106
107    fn user_id(&self) -> Result<&str> {
108        self.user_id
109            .as_deref()
110            .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
111    }
112
113    fn url_for_segments(&self, segments: &[&str]) -> Result<Url> {
114        let mut url = Url::parse(self.homeserver.trim_end_matches('/'))
115            .with_context(|| format!("Invalid Matrix homeserver URL: {}", self.homeserver))?;
116        {
117            let mut path_segments = url
118                .path_segments_mut()
119                .map_err(|_| anyhow!("Matrix homeserver URL cannot be a base URL"))?;
120            path_segments.extend(segments.iter().copied());
121        }
122        Ok(url)
123    }
124
125    fn client_api_url(&self, path: &[&str]) -> Result<Url> {
126        let mut segments = vec!["_matrix", "client", "v3"];
127        segments.extend_from_slice(path);
128        self.url_for_segments(&segments)
129    }
130
131    async fn sync_once(&self) -> Result<Vec<Message>> {
132        #[derive(Debug, Deserialize)]
133        struct SyncResponse {
134            next_batch: String,
135            #[serde(default)]
136            rooms: SyncRooms,
137        }
138
139        #[derive(Debug, Default, Deserialize)]
140        struct SyncRooms {
141            #[serde(default)]
142            join: HashMap<String, JoinedRoom>,
143        }
144
145        #[derive(Debug, Default, Deserialize)]
146        struct JoinedRoom {
147            #[serde(default)]
148            timeline: Timeline,
149        }
150
151        #[derive(Debug, Default, Deserialize)]
152        struct Timeline {
153            #[serde(default)]
154            events: Vec<TimelineEvent>,
155        }
156
157        #[derive(Debug, Deserialize)]
158        struct TimelineEvent {
159            #[serde(rename = "type")]
160            event_type: String,
161            event_id: String,
162            sender: String,
163            origin_server_ts: i64,
164            #[serde(default)]
165            content: TimelineContent,
166        }
167
168        #[derive(Debug, Default, Deserialize)]
169        struct TimelineContent {
170            #[serde(default)]
171            body: String,
172            #[serde(default, rename = "m.relates_to")]
173            relates_to: Option<RelatesTo>,
174        }
175
176        #[derive(Debug, Deserialize)]
177        struct RelatesTo {
178            #[serde(default, rename = "m.in_reply_to")]
179            in_reply_to: Option<ReplyTo>,
180        }
181
182        #[derive(Debug, Deserialize)]
183        struct ReplyTo {
184            event_id: String,
185        }
186
187        let since = self.sync_token.lock().await.clone();
188        let mut url = self.client_api_url(&["sync"])?;
189        {
190            let mut query = url.query_pairs_mut();
191            query.append_pair("timeout", "1");
192            if let Some(since) = since {
193                query.append_pair("since", &since);
194            }
195        }
196
197        let response = self
198            .client
199            .get(url)
200            .bearer_auth(self.access_token()?)
201            .send()
202            .await
203            .context("Matrix sync request failed")?;
204
205        let status = response.status();
206        if !status.is_success() {
207            let body = response.text().await.unwrap_or_default();
208            anyhow::bail!("Matrix sync failed {}: {}", status, body);
209        }
210
211        let sync: SyncResponse = response
212            .json()
213            .await
214            .context("Invalid Matrix sync response")?;
215        *self.sync_token.lock().await = Some(sync.next_batch);
216
217        let mut messages = Vec::new();
218        for (room_id, joined_room) in sync.rooms.join {
219            for event in joined_room.timeline.events {
220                if event.event_type != "m.room.message" || event.content.body.is_empty() {
221                    continue;
222                }
223
224                messages.push(Message {
225                    id: event.event_id,
226                    sender: event.sender,
227                    content: event.content.body,
228                    timestamp: event.origin_server_ts / 1000,
229                    channel: Some(room_id.clone()),
230                    reply_to: event
231                        .content
232                        .relates_to
233                        .and_then(|r| r.in_reply_to)
234                        .map(|r| r.event_id),
235                    thread_id: None,
236                    media: None,
237                    is_direct: false,
238                    message_type: MessageType::Text,
239                    edited_timestamp: None,
240                    reactions: None,
241                });
242            }
243        }
244
245        Ok(messages)
246    }
247
248    async fn join_room_if_needed(&self, recipient: &str) -> Result<String> {
249        if recipient.starts_with('!') {
250            return Ok(recipient.to_string());
251        }
252
253        let response = self
254            .client
255            .post(self.client_api_url(&["join", recipient])?)
256            .bearer_auth(self.access_token()?)
257            .send()
258            .await
259            .context("Matrix join request failed")?;
260
261        let status = response.status();
262        if !status.is_success() {
263            let body = response.text().await.unwrap_or_default();
264            anyhow::bail!("Matrix join failed {}: {}", status, body);
265        }
266
267        #[derive(Deserialize)]
268        struct JoinResponse {
269            room_id: String,
270        }
271
272        let join: JoinResponse = response
273            .json()
274            .await
275            .context("Invalid Matrix join response")?;
276        Ok(join.room_id)
277    }
278}
279
280#[async_trait]
281impl Messenger for MatrixMessenger {
282    fn name(&self) -> &str {
283        &self.name
284    }
285
286    fn messenger_type(&self) -> &str {
287        "matrix"
288    }
289
290    async fn initialize(&mut self) -> Result<()> {
291        // If we already have an access token (from with_access_token), skip login
292        if self.access_token.is_some() {
293            // Validate token by doing an initial sync
294            *self.sync_token.lock().await = None;
295            let _ = self.sync_once().await?;
296            self.connected = true;
297            return Ok(());
298        }
299
300        #[derive(Deserialize)]
301        struct LoginResponse {
302            access_token: String,
303            user_id: String,
304        }
305
306        self.validate_config()?;
307
308        let response = self
309            .client
310            .post(self.client_api_url(&["login"])?)
311            .json(&json!({
312                "type": "m.login.password",
313                "identifier": {
314                    "type": "m.id.user",
315                    "user": self.username,
316                },
317                "password": self.password,
318                "initial_device_display_name": self.name,
319            }))
320            .send()
321            .await
322            .context("Matrix login request failed")?;
323
324        let status = response.status();
325        if !status.is_success() {
326            let body = response.text().await.unwrap_or_default();
327            anyhow::bail!("Matrix login failed {}: {}", status, body);
328        }
329
330        let login: LoginResponse = response
331            .json()
332            .await
333            .context("Invalid Matrix login response")?;
334        self.access_token = Some(login.access_token);
335        self.user_id = Some(login.user_id);
336
337        *self.sync_token.lock().await = None;
338        let _ = self.sync_once().await?;
339
340        self.connected = true;
341        Ok(())
342    }
343
344    async fn send_message(&self, recipient: &str, content: &str) -> Result<String> {
345        let room_id = self.join_room_if_needed(recipient).await?;
346        let txn_id = self.txn_counter.fetch_add(1, Ordering::Relaxed).to_string();
347
348        let response = self
349            .client
350            .put(self.client_api_url(&["rooms", &room_id, "send", "m.room.message", &txn_id])?)
351            .bearer_auth(self.access_token()?)
352            .json(&json!({
353                "msgtype": "m.text",
354                "body": content,
355            }))
356            .send()
357            .await
358            .context("Matrix send request failed")?;
359
360        let status = response.status();
361        if !status.is_success() {
362            let body = response.text().await.unwrap_or_default();
363            anyhow::bail!("Matrix send failed {}: {}", status, body);
364        }
365
366        #[derive(Deserialize)]
367        struct SendResponse {
368            event_id: String,
369        }
370
371        let send: SendResponse = response
372            .json()
373            .await
374            .context("Invalid Matrix send response")?;
375        Ok(send.event_id)
376    }
377
378    async fn receive_messages(&self) -> Result<Vec<Message>> {
379        self.sync_once().await
380    }
381
382    fn is_connected(&self) -> bool {
383        self.connected
384    }
385
386    async fn disconnect(&mut self) -> Result<()> {
387        if let Some(token) = self.access_token.as_deref() {
388            let response = self
389                .client
390                .post(self.client_api_url(&["logout"])?)
391                .bearer_auth(token)
392                .send()
393                .await;
394
395            if let Err(error) = response {
396                tracing::warn!(messenger = %self.name, "Matrix logout failed: {error}");
397            }
398        }
399
400        self.access_token = None;
401        self.user_id = None;
402        *self.sync_token.lock().await = None;
403        self.connected = false;
404        Ok(())
405    }
406
407    async fn set_typing(&self, channel: &str, typing: bool) -> Result<()> {
408        let room_id = self.join_room_if_needed(channel).await?;
409        let mut payload = json!({ "typing": typing });
410        if typing {
411            payload["timeout"] = json!(30_000);
412        }
413
414        let response = self
415            .client
416            .put(self.client_api_url(&["rooms", &room_id, "typing", self.user_id()?])?)
417            .bearer_auth(self.access_token()?)
418            .json(&payload)
419            .send()
420            .await
421            .context("Matrix typing request failed")?;
422
423        let status = response.status();
424        if !status.is_success() {
425            let body = response.text().await.unwrap_or_default();
426            anyhow::bail!("Matrix typing failed {}: {}", status, body);
427        }
428
429        Ok(())
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_config_rejects_empty_homeserver() {
        let messenger = MatrixMessenger::new("matrix", "", "bot", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_username() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_password() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "bot", "");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_accepts_non_empty_values() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "bot", "secret");
        assert!(messenger.validate_config().is_ok());
    }

    /// Token auth has no password, which must not be treated as a config error.
    #[test]
    fn validate_config_accepts_access_token_without_password() {
        let messenger = MatrixMessenger::with_access_token(
            "matrix",
            "https://matrix.example",
            "@bot:matrix.example",
            "token",
            None,
        );
        assert!(messenger.validate_config().is_ok());
    }

    #[test]
    fn client_api_url_ignores_trailing_slash() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example/", "bot", "secret");
        let url = messenger
            .client_api_url(&["sync"])
            .expect("homeserver URL is valid");
        assert_eq!(url.as_str(), "https://matrix.example/_matrix/client/v3/sync");
    }

    #[test]
    fn client_api_url_keeps_homeserver_path_prefix() {
        let messenger =
            MatrixMessenger::new("matrix", "https://example.com/matrix/", "bot", "secret");
        let url = messenger
            .client_api_url(&["sync"])
            .expect("homeserver URL is valid");
        assert_eq!(url.as_str(), "https://example.com/matrix/_matrix/client/v3/sync");
    }

    /// Aliases contain `#`, which must be escaped so it is not read as a fragment.
    #[test]
    fn client_api_url_percent_encodes_room_alias() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "bot", "secret");
        let url = messenger
            .client_api_url(&["join", "#room:matrix.example"])
            .expect("homeserver URL is valid");
        assert_eq!(
            url.as_str(),
            "https://matrix.example/_matrix/client/v3/join/%23room:matrix.example"
        );
    }
}