1use crate::message::MessageType;
4use crate::{Message, Messenger};
5use anyhow::{Context, Result, anyhow, ensure};
6use async_trait::async_trait;
7use reqwest::{Client, Url};
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use tokio::sync::Mutex;
13
14pub struct MatrixMessenger {
15 name: String,
16 homeserver: String,
17 username: String,
18 password: String,
19 client: Client,
20 access_token: Option<String>,
21 user_id: Option<String>,
22 sync_token: Mutex<Option<String>>,
23 txn_counter: AtomicU64,
24 connected: bool,
25}
26
27impl MatrixMessenger {
28 pub fn new(
29 name: impl Into<String>,
30 homeserver: impl Into<String>,
31 username: impl Into<String>,
32 password: impl Into<String>,
33 ) -> Self {
34 Self {
35 name: name.into(),
36 homeserver: homeserver.into(),
37 username: username.into(),
38 password: password.into(),
39 client: Client::new(),
40 access_token: None,
41 user_id: None,
42 sync_token: Mutex::new(None),
43 txn_counter: AtomicU64::new(1),
44 connected: false,
45 }
46 }
47
48 fn validate_config(&self) -> Result<()> {
49 ensure!(
50 !self.homeserver.trim().is_empty(),
51 "Matrix homeserver must not be empty"
52 );
53 ensure!(
54 !self.username.trim().is_empty(),
55 "Matrix username must not be empty"
56 );
57 ensure!(
58 !self.password.trim().is_empty(),
59 "Matrix password must not be empty"
60 );
61 Ok(())
62 }
63
64 fn access_token(&self) -> Result<&str> {
65 self.access_token
66 .as_deref()
67 .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
68 }
69
70 fn user_id(&self) -> Result<&str> {
71 self.user_id
72 .as_deref()
73 .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
74 }
75
76 fn url_for_segments(&self, segments: &[&str]) -> Result<Url> {
77 let mut url = Url::parse(self.homeserver.trim_end_matches('/'))
78 .with_context(|| format!("Invalid Matrix homeserver URL: {}", self.homeserver))?;
79 {
80 let mut path_segments = url
81 .path_segments_mut()
82 .map_err(|_| anyhow!("Matrix homeserver URL cannot be a base URL"))?;
83 path_segments.extend(segments.iter().copied());
84 }
85 Ok(url)
86 }
87
88 fn client_api_url(&self, path: &[&str]) -> Result<Url> {
89 let mut segments = vec!["_matrix", "client", "v3"];
90 segments.extend_from_slice(path);
91 self.url_for_segments(&segments)
92 }
93
94 async fn sync_once(&self) -> Result<Vec<Message>> {
95 #[derive(Debug, Deserialize)]
96 struct SyncResponse {
97 next_batch: String,
98 #[serde(default)]
99 rooms: SyncRooms,
100 }
101
102 #[derive(Debug, Default, Deserialize)]
103 struct SyncRooms {
104 #[serde(default)]
105 join: HashMap<String, JoinedRoom>,
106 }
107
108 #[derive(Debug, Default, Deserialize)]
109 struct JoinedRoom {
110 #[serde(default)]
111 timeline: Timeline,
112 }
113
114 #[derive(Debug, Default, Deserialize)]
115 struct Timeline {
116 #[serde(default)]
117 events: Vec<TimelineEvent>,
118 }
119
120 #[derive(Debug, Deserialize)]
121 struct TimelineEvent {
122 #[serde(rename = "type")]
123 event_type: String,
124 event_id: String,
125 sender: String,
126 origin_server_ts: i64,
127 #[serde(default)]
128 content: TimelineContent,
129 }
130
131 #[derive(Debug, Default, Deserialize)]
132 struct TimelineContent {
133 #[serde(default)]
134 body: String,
135 #[serde(default, rename = "m.relates_to")]
136 relates_to: Option<RelatesTo>,
137 }
138
139 #[derive(Debug, Deserialize)]
140 struct RelatesTo {
141 #[serde(default, rename = "m.in_reply_to")]
142 in_reply_to: Option<ReplyTo>,
143 }
144
145 #[derive(Debug, Deserialize)]
146 struct ReplyTo {
147 event_id: String,
148 }
149
150 let since = self.sync_token.lock().await.clone();
151 let mut url = self.client_api_url(&["sync"])?;
152 {
153 let mut query = url.query_pairs_mut();
154 query.append_pair("timeout", "1");
155 if let Some(since) = since {
156 query.append_pair("since", &since);
157 }
158 }
159
160 let response = self
161 .client
162 .get(url)
163 .bearer_auth(self.access_token()?)
164 .send()
165 .await
166 .context("Matrix sync request failed")?;
167
168 let status = response.status();
169 if !status.is_success() {
170 let body = response.text().await.unwrap_or_default();
171 anyhow::bail!("Matrix sync failed {}: {}", status, body);
172 }
173
174 let sync: SyncResponse = response
175 .json()
176 .await
177 .context("Invalid Matrix sync response")?;
178 *self.sync_token.lock().await = Some(sync.next_batch);
179
180 let mut messages = Vec::new();
181 for (room_id, joined_room) in sync.rooms.join {
182 for event in joined_room.timeline.events {
183 if event.event_type != "m.room.message" || event.content.body.is_empty() {
184 continue;
185 }
186
187 messages.push(Message {
188 id: event.event_id,
189 sender: event.sender,
190 content: event.content.body,
191 timestamp: event.origin_server_ts / 1000,
192 channel: Some(room_id.clone()),
193 reply_to: event
194 .content
195 .relates_to
196 .and_then(|r| r.in_reply_to)
197 .map(|r| r.event_id),
198 thread_id: None,
199 media: None,
200 is_direct: false,
201 message_type: MessageType::Text,
202 edited_timestamp: None,
203 reactions: None,
204 });
205 }
206 }
207
208 Ok(messages)
209 }
210
211 async fn join_room_if_needed(&self, recipient: &str) -> Result<String> {
212 if recipient.starts_with('!') {
213 return Ok(recipient.to_string());
214 }
215
216 let response = self
217 .client
218 .post(self.client_api_url(&["join", recipient])?)
219 .bearer_auth(self.access_token()?)
220 .send()
221 .await
222 .context("Matrix join request failed")?;
223
224 let status = response.status();
225 if !status.is_success() {
226 let body = response.text().await.unwrap_or_default();
227 anyhow::bail!("Matrix join failed {}: {}", status, body);
228 }
229
230 #[derive(Deserialize)]
231 struct JoinResponse {
232 room_id: String,
233 }
234
235 let join: JoinResponse = response
236 .json()
237 .await
238 .context("Invalid Matrix join response")?;
239 Ok(join.room_id)
240 }
241}
242
243#[async_trait]
244impl Messenger for MatrixMessenger {
245 fn name(&self) -> &str {
246 &self.name
247 }
248
249 fn messenger_type(&self) -> &str {
250 "matrix"
251 }
252
253 async fn initialize(&mut self) -> Result<()> {
254 #[derive(Deserialize)]
255 struct LoginResponse {
256 access_token: String,
257 user_id: String,
258 }
259
260 self.validate_config()?;
261
262 let response = self
263 .client
264 .post(self.client_api_url(&["login"])?)
265 .json(&json!({
266 "type": "m.login.password",
267 "identifier": {
268 "type": "m.id.user",
269 "user": self.username,
270 },
271 "password": self.password,
272 "initial_device_display_name": self.name,
273 }))
274 .send()
275 .await
276 .context("Matrix login request failed")?;
277
278 let status = response.status();
279 if !status.is_success() {
280 let body = response.text().await.unwrap_or_default();
281 anyhow::bail!("Matrix login failed {}: {}", status, body);
282 }
283
284 let login: LoginResponse = response
285 .json()
286 .await
287 .context("Invalid Matrix login response")?;
288 self.access_token = Some(login.access_token);
289 self.user_id = Some(login.user_id);
290
291 *self.sync_token.lock().await = None;
292 let _ = self.sync_once().await?;
293
294 self.connected = true;
295 Ok(())
296 }
297
298 async fn send_message(&self, recipient: &str, content: &str) -> Result<String> {
299 let room_id = self.join_room_if_needed(recipient).await?;
300 let txn_id = self.txn_counter.fetch_add(1, Ordering::Relaxed).to_string();
301
302 let response = self
303 .client
304 .put(self.client_api_url(&["rooms", &room_id, "send", "m.room.message", &txn_id])?)
305 .bearer_auth(self.access_token()?)
306 .json(&json!({
307 "msgtype": "m.text",
308 "body": content,
309 }))
310 .send()
311 .await
312 .context("Matrix send request failed")?;
313
314 let status = response.status();
315 if !status.is_success() {
316 let body = response.text().await.unwrap_or_default();
317 anyhow::bail!("Matrix send failed {}: {}", status, body);
318 }
319
320 #[derive(Deserialize)]
321 struct SendResponse {
322 event_id: String,
323 }
324
325 let send: SendResponse = response
326 .json()
327 .await
328 .context("Invalid Matrix send response")?;
329 Ok(send.event_id)
330 }
331
332 async fn receive_messages(&self) -> Result<Vec<Message>> {
333 self.sync_once().await
334 }
335
336 fn is_connected(&self) -> bool {
337 self.connected
338 }
339
340 async fn disconnect(&mut self) -> Result<()> {
341 if let Some(token) = self.access_token.as_deref() {
342 let response = self
343 .client
344 .post(self.client_api_url(&["logout"])?)
345 .bearer_auth(token)
346 .send()
347 .await;
348
349 if let Err(error) = response {
350 tracing::warn!(messenger = %self.name, "Matrix logout failed: {error}");
351 }
352 }
353
354 self.access_token = None;
355 self.user_id = None;
356 *self.sync_token.lock().await = None;
357 self.connected = false;
358 Ok(())
359 }
360
361 async fn set_typing(&self, channel: &str, typing: bool) -> Result<()> {
362 let room_id = self.join_room_if_needed(channel).await?;
363 let mut payload = json!({ "typing": typing });
364 if typing {
365 payload["timeout"] = json!(30_000);
366 }
367
368 let response = self
369 .client
370 .put(self.client_api_url(&["rooms", &room_id, "typing", self.user_id()?])?)
371 .bearer_auth(self.access_token()?)
372 .json(&payload)
373 .send()
374 .await
375 .context("Matrix typing request failed")?;
376
377 let status = response.status();
378 if !status.is_success() {
379 let body = response.text().await.unwrap_or_default();
380 anyhow::bail!("Matrix typing failed {}: {}", status, body);
381 }
382
383 Ok(())
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a messenger with valid credentials and the given homeserver.
    fn messenger_for(homeserver: &str) -> MatrixMessenger {
        MatrixMessenger::new("matrix", homeserver, "bot", "secret")
    }

    #[test]
    fn validate_config_rejects_empty_homeserver() {
        let messenger = MatrixMessenger::new("matrix", "", "bot", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_username() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_password() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "bot", "");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_accepts_non_empty_values() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example", "bot", "secret");
        assert!(messenger.validate_config().is_ok());
    }

    #[test]
    fn validate_config_rejects_whitespace_only_values() {
        for (homeserver, username, password) in [
            ("   ", "bot", "secret"),
            ("https://matrix.example", " \t", "secret"),
            ("https://matrix.example", "bot", "\n"),
        ] {
            let messenger = MatrixMessenger::new("matrix", homeserver, username, password);
            assert!(
                messenger.validate_config().is_err(),
                "accepted {homeserver:?} / {username:?} / {password:?}"
            );
        }
    }

    #[test]
    fn client_api_url_ignores_trailing_slash() {
        let url = messenger_for("https://matrix.example/")
            .client_api_url(&["login"])
            .unwrap();
        assert_eq!(url.as_str(), "https://matrix.example/_matrix/client/v3/login");
    }

    #[test]
    fn client_api_url_keeps_homeserver_path_prefix() {
        let url = messenger_for("https://example.com/matrix/")
            .client_api_url(&["sync"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://example.com/matrix/_matrix/client/v3/sync"
        );
    }

    #[test]
    fn client_api_url_percent_encodes_room_alias() {
        let url = messenger_for("https://matrix.example")
            .client_api_url(&["join", "#room:example.org"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://matrix.example/_matrix/client/v3/join/%23room:example.org"
        );
    }

    #[test]
    fn client_api_url_rejects_invalid_homeserver() {
        assert!(messenger_for("not a url").client_api_url(&["login"]).is_err());
        assert!(
            messenger_for("mailto:bot@example.org")
                .client_api_url(&["login"])
                .is_err()
        );
    }

    #[test]
    fn credentials_are_unavailable_before_initialize() {
        let messenger = messenger_for("https://matrix.example");
        assert!(messenger.access_token().is_err());
        assert!(messenger.user_id().is_err());
    }
}