1use crate::{Message, Messenger};
4use anyhow::{anyhow, ensure, Context, Result};
5use async_trait::async_trait;
6use reqwest::{Client, Url};
7use serde::Deserialize;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicU64, Ordering};
11use tokio::sync::Mutex;
12
13pub struct MatrixMessenger {
14 name: String,
15 homeserver: String,
16 username: String,
17 password: String,
18 client: Client,
19 access_token: Option<String>,
20 user_id: Option<String>,
21 sync_token: Mutex<Option<String>>,
22 txn_counter: AtomicU64,
23 connected: bool,
24}
25
26impl MatrixMessenger {
27 pub fn new(
28 name: impl Into<String>,
29 homeserver: impl Into<String>,
30 username: impl Into<String>,
31 password: impl Into<String>,
32 ) -> Self {
33 Self {
34 name: name.into(),
35 homeserver: homeserver.into(),
36 username: username.into(),
37 password: password.into(),
38 client: Client::new(),
39 access_token: None,
40 user_id: None,
41 sync_token: Mutex::new(None),
42 txn_counter: AtomicU64::new(1),
43 connected: false,
44 }
45 }
46
47 fn validate_config(&self) -> Result<()> {
48 ensure!(
49 !self.homeserver.trim().is_empty(),
50 "Matrix homeserver must not be empty"
51 );
52 ensure!(
53 !self.username.trim().is_empty(),
54 "Matrix username must not be empty"
55 );
56 ensure!(
57 !self.password.trim().is_empty(),
58 "Matrix password must not be empty"
59 );
60 Ok(())
61 }
62
63 fn access_token(&self) -> Result<&str> {
64 self.access_token
65 .as_deref()
66 .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
67 }
68
69 fn user_id(&self) -> Result<&str> {
70 self.user_id
71 .as_deref()
72 .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
73 }
74
75 fn url_for_segments(&self, segments: &[&str]) -> Result<Url> {
76 let mut url = Url::parse(self.homeserver.trim_end_matches('/'))
77 .with_context(|| format!("Invalid Matrix homeserver URL: {}", self.homeserver))?;
78 {
79 let mut path_segments = url
80 .path_segments_mut()
81 .map_err(|_| anyhow!("Matrix homeserver URL cannot be a base URL"))?;
82 path_segments.extend(segments.iter().copied());
83 }
84 Ok(url)
85 }
86
87 fn client_api_url(&self, path: &[&str]) -> Result<Url> {
88 let mut segments = vec!["_matrix", "client", "v3"];
89 segments.extend_from_slice(path);
90 self.url_for_segments(&segments)
91 }
92
93 async fn sync_once(&self) -> Result<Vec<Message>> {
94 #[derive(Debug, Deserialize)]
95 struct SyncResponse {
96 next_batch: String,
97 #[serde(default)]
98 rooms: SyncRooms,
99 }
100
101 #[derive(Debug, Default, Deserialize)]
102 struct SyncRooms {
103 #[serde(default)]
104 join: HashMap<String, JoinedRoom>,
105 }
106
107 #[derive(Debug, Default, Deserialize)]
108 struct JoinedRoom {
109 #[serde(default)]
110 timeline: Timeline,
111 }
112
113 #[derive(Debug, Default, Deserialize)]
114 struct Timeline {
115 #[serde(default)]
116 events: Vec<TimelineEvent>,
117 }
118
119 #[derive(Debug, Deserialize)]
120 struct TimelineEvent {
121 #[serde(rename = "type")]
122 event_type: String,
123 event_id: String,
124 sender: String,
125 origin_server_ts: i64,
126 #[serde(default)]
127 content: TimelineContent,
128 }
129
130 #[derive(Debug, Default, Deserialize)]
131 struct TimelineContent {
132 #[serde(default)]
133 body: String,
134 #[serde(default, rename = "m.relates_to")]
135 relates_to: Option<RelatesTo>,
136 }
137
138 #[derive(Debug, Deserialize)]
139 struct RelatesTo {
140 #[serde(default, rename = "m.in_reply_to")]
141 in_reply_to: Option<ReplyTo>,
142 }
143
144 #[derive(Debug, Deserialize)]
145 struct ReplyTo {
146 event_id: String,
147 }
148
149 let since = self.sync_token.lock().await.clone();
150 let mut url = self.client_api_url(&["sync"])?;
151 {
152 let mut query = url.query_pairs_mut();
153 query.append_pair("timeout", "1");
154 if let Some(since) = since {
155 query.append_pair("since", &since);
156 }
157 }
158
159 let response = self
160 .client
161 .get(url)
162 .bearer_auth(self.access_token()?)
163 .send()
164 .await
165 .context("Matrix sync request failed")?;
166
167 let status = response.status();
168 if !status.is_success() {
169 let body = response.text().await.unwrap_or_default();
170 anyhow::bail!("Matrix sync failed {}: {}", status, body);
171 }
172
173 let sync: SyncResponse = response.json().await.context("Invalid Matrix sync response")?;
174 *self.sync_token.lock().await = Some(sync.next_batch);
175
176 let mut messages = Vec::new();
177 for (room_id, joined_room) in sync.rooms.join {
178 for event in joined_room.timeline.events {
179 if event.event_type != "m.room.message" || event.content.body.is_empty() {
180 continue;
181 }
182
183 messages.push(Message {
184 id: event.event_id,
185 sender: event.sender,
186 content: event.content.body,
187 timestamp: event.origin_server_ts / 1000,
188 channel: Some(room_id.clone()),
189 reply_to: event
190 .content
191 .relates_to
192 .and_then(|r| r.in_reply_to)
193 .map(|r| r.event_id),
194 media: None,
195 is_direct: false,
196 reactions: None,
197 });
198 }
199 }
200
201 Ok(messages)
202 }
203
204 async fn join_room_if_needed(&self, recipient: &str) -> Result<String> {
205 if recipient.starts_with('!') {
206 return Ok(recipient.to_string());
207 }
208
209 let response = self
210 .client
211 .post(self.client_api_url(&["join", recipient])?)
212 .bearer_auth(self.access_token()?)
213 .send()
214 .await
215 .context("Matrix join request failed")?;
216
217 let status = response.status();
218 if !status.is_success() {
219 let body = response.text().await.unwrap_or_default();
220 anyhow::bail!("Matrix join failed {}: {}", status, body);
221 }
222
223 #[derive(Deserialize)]
224 struct JoinResponse {
225 room_id: String,
226 }
227
228 let join: JoinResponse = response.json().await.context("Invalid Matrix join response")?;
229 Ok(join.room_id)
230 }
231}
232
233#[async_trait]
234impl Messenger for MatrixMessenger {
235 fn name(&self) -> &str {
236 &self.name
237 }
238
239 fn messenger_type(&self) -> &str {
240 "matrix"
241 }
242
243 async fn initialize(&mut self) -> Result<()> {
244 #[derive(Deserialize)]
245 struct LoginResponse {
246 access_token: String,
247 user_id: String,
248 }
249
250 self.validate_config()?;
251
252 let response = self
253 .client
254 .post(self.client_api_url(&["login"])? )
255 .json(&json!({
256 "type": "m.login.password",
257 "identifier": {
258 "type": "m.id.user",
259 "user": self.username,
260 },
261 "password": self.password,
262 "initial_device_display_name": self.name,
263 }))
264 .send()
265 .await
266 .context("Matrix login request failed")?;
267
268 let status = response.status();
269 if !status.is_success() {
270 let body = response.text().await.unwrap_or_default();
271 anyhow::bail!("Matrix login failed {}: {}", status, body);
272 }
273
274 let login: LoginResponse = response.json().await.context("Invalid Matrix login response")?;
275 self.access_token = Some(login.access_token);
276 self.user_id = Some(login.user_id);
277
278 *self.sync_token.lock().await = None;
279 let _ = self.sync_once().await?;
280
281 self.connected = true;
282 Ok(())
283 }
284
285 async fn send_message(&self, recipient: &str, content: &str) -> Result<String> {
286 let room_id = self.join_room_if_needed(recipient).await?;
287 let txn_id = self.txn_counter.fetch_add(1, Ordering::Relaxed).to_string();
288
289 let response = self
290 .client
291 .put(self.client_api_url(&[
292 "rooms",
293 &room_id,
294 "send",
295 "m.room.message",
296 &txn_id,
297 ])?)
298 .bearer_auth(self.access_token()?)
299 .json(&json!({
300 "msgtype": "m.text",
301 "body": content,
302 }))
303 .send()
304 .await
305 .context("Matrix send request failed")?;
306
307 let status = response.status();
308 if !status.is_success() {
309 let body = response.text().await.unwrap_or_default();
310 anyhow::bail!("Matrix send failed {}: {}", status, body);
311 }
312
313 #[derive(Deserialize)]
314 struct SendResponse {
315 event_id: String,
316 }
317
318 let send: SendResponse = response.json().await.context("Invalid Matrix send response")?;
319 Ok(send.event_id)
320 }
321
322 async fn receive_messages(&self) -> Result<Vec<Message>> {
323 self.sync_once().await
324 }
325
326 fn is_connected(&self) -> bool {
327 self.connected
328 }
329
330 async fn disconnect(&mut self) -> Result<()> {
331 if let Some(token) = self.access_token.as_deref() {
332 let response = self
333 .client
334 .post(self.client_api_url(&["logout"])? )
335 .bearer_auth(token)
336 .send()
337 .await;
338
339 if let Err(error) = response {
340 tracing::warn!(messenger = %self.name, "Matrix logout failed: {error}");
341 }
342 }
343
344 self.access_token = None;
345 self.user_id = None;
346 *self.sync_token.lock().await = None;
347 self.connected = false;
348 Ok(())
349 }
350
351 async fn set_typing(&self, channel: &str, typing: bool) -> Result<()> {
352 let room_id = self.join_room_if_needed(channel).await?;
353 let mut payload = json!({ "typing": typing });
354 if typing {
355 payload["timeout"] = json!(30_000);
356 }
357
358 let response = self
359 .client
360 .put(self.client_api_url(&["rooms", &room_id, "typing", self.user_id()?])?)
361 .bearer_auth(self.access_token()?)
362 .json(&payload)
363 .send()
364 .await
365 .context("Matrix typing request failed")?;
366
367 let status = response.status();
368 if !status.is_success() {
369 let body = response.text().await.unwrap_or_default();
370 anyhow::bail!("Matrix typing failed {}: {}", status, body);
371 }
372
373 Ok(())
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a messenger with valid credentials and the given homeserver.
    fn messenger(homeserver: &str) -> MatrixMessenger {
        MatrixMessenger::new("matrix", homeserver, "bot", "secret")
    }

    #[test]
    fn validate_config_rejects_empty_homeserver() {
        let messenger = MatrixMessenger::new("matrix", "", "bot", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_username() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_password() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "bot", "");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_whitespace_only_values() {
        assert!(MatrixMessenger::new("matrix", "   ", "bot", "secret")
            .validate_config()
            .is_err());
        assert!(MatrixMessenger::new("matrix", "https://matrix.example", " \t", "secret")
            .validate_config()
            .is_err());
        assert!(MatrixMessenger::new("matrix", "https://matrix.example", "bot", "  ")
            .validate_config()
            .is_err());
    }

    #[test]
    fn validate_config_accepts_non_empty_values() {
        assert!(messenger("https://matrix.example").validate_config().is_ok());
    }

    #[test]
    fn client_api_url_appends_versioned_prefix() {
        let url = messenger("https://matrix.example")
            .client_api_url(&["login"])
            .unwrap();
        assert_eq!(url.as_str(), "https://matrix.example/_matrix/client/v3/login");
    }

    #[test]
    fn client_api_url_ignores_trailing_slash_and_keeps_base_path() {
        let url = messenger("https://matrix.example/")
            .client_api_url(&["sync"])
            .unwrap();
        assert_eq!(url.as_str(), "https://matrix.example/_matrix/client/v3/sync");

        let url = messenger("https://example.com/matrix/")
            .client_api_url(&["sync"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://example.com/matrix/_matrix/client/v3/sync"
        );
    }

    #[test]
    fn client_api_url_percent_encodes_room_alias() {
        let url = messenger("https://matrix.example")
            .client_api_url(&["join", "#room:example.org"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://matrix.example/_matrix/client/v3/join/%23room:example.org"
        );
    }

    #[test]
    fn client_api_url_rejects_invalid_homeserver() {
        assert!(messenger("not a url").client_api_url(&["login"]).is_err());
    }

    #[test]
    fn session_accessors_fail_before_initialize() {
        let messenger = messenger("https://matrix.example");
        assert!(messenger.access_token().is_err());
        assert!(messenger.user_id().is_err());
        assert!(!messenger.is_connected());
    }
}