1use crate::message::MessageType;
4use crate::{Message, Messenger};
5use anyhow::{Context, Result, anyhow, ensure};
6use async_trait::async_trait;
7use reqwest::{Client, Url};
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use tokio::sync::Mutex;
13
14pub struct MatrixMessenger {
15 name: String,
16 homeserver: String,
17 username: String,
18 password: String,
19 client: Client,
20 access_token: Option<String>,
21 user_id: Option<String>,
22 sync_token: Mutex<Option<String>>,
23 txn_counter: AtomicU64,
24 connected: bool,
25}
26
27impl MatrixMessenger {
28 pub fn new(
29 name: impl Into<String>,
30 homeserver: impl Into<String>,
31 username: impl Into<String>,
32 password: impl Into<String>,
33 ) -> Self {
34 Self {
35 name: name.into(),
36 homeserver: homeserver.into(),
37 username: username.into(),
38 password: password.into(),
39 client: Client::new(),
40 access_token: None,
41 user_id: None,
42 sync_token: Mutex::new(None),
43 txn_counter: AtomicU64::new(1),
44 connected: false,
45 }
46 }
47
48 pub fn with_access_token(
61 name: impl Into<String>,
62 homeserver: impl Into<String>,
63 user_id: impl Into<String>,
64 access_token: impl Into<String>,
65 _device_id: Option<String>,
66 ) -> Self {
67 let user_id_str = user_id.into();
68 Self {
69 name: name.into(),
70 homeserver: homeserver.into(),
71 username: user_id_str.clone(),
72 password: String::new(), client: Client::new(),
74 access_token: Some(access_token.into()),
75 user_id: Some(user_id_str),
76 sync_token: Mutex::new(None),
77 txn_counter: AtomicU64::new(1),
78 connected: false,
79 }
80 }
81
82 fn validate_config(&self) -> Result<()> {
83 ensure!(
84 !self.homeserver.trim().is_empty(),
85 "Matrix homeserver must not be empty"
86 );
87 ensure!(
88 !self.username.trim().is_empty(),
89 "Matrix username must not be empty"
90 );
91 if self.access_token.is_none() {
93 ensure!(
94 !self.password.trim().is_empty(),
95 "Matrix password must not be empty (unless using access_token auth)"
96 );
97 }
98 Ok(())
99 }
100
101 fn access_token(&self) -> Result<&str> {
102 self.access_token
103 .as_deref()
104 .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
105 }
106
107 fn user_id(&self) -> Result<&str> {
108 self.user_id
109 .as_deref()
110 .ok_or_else(|| anyhow!("Matrix messenger is not initialized"))
111 }
112
113 fn url_for_segments(&self, segments: &[&str]) -> Result<Url> {
114 let mut url = Url::parse(self.homeserver.trim_end_matches('/'))
115 .with_context(|| format!("Invalid Matrix homeserver URL: {}", self.homeserver))?;
116 {
117 let mut path_segments = url
118 .path_segments_mut()
119 .map_err(|_| anyhow!("Matrix homeserver URL cannot be a base URL"))?;
120 path_segments.extend(segments.iter().copied());
121 }
122 Ok(url)
123 }
124
125 fn client_api_url(&self, path: &[&str]) -> Result<Url> {
126 let mut segments = vec!["_matrix", "client", "v3"];
127 segments.extend_from_slice(path);
128 self.url_for_segments(&segments)
129 }
130
131 async fn sync_once(&self) -> Result<Vec<Message>> {
132 #[derive(Debug, Deserialize)]
133 struct SyncResponse {
134 next_batch: String,
135 #[serde(default)]
136 rooms: SyncRooms,
137 }
138
139 #[derive(Debug, Default, Deserialize)]
140 struct SyncRooms {
141 #[serde(default)]
142 join: HashMap<String, JoinedRoom>,
143 }
144
145 #[derive(Debug, Default, Deserialize)]
146 struct JoinedRoom {
147 #[serde(default)]
148 timeline: Timeline,
149 }
150
151 #[derive(Debug, Default, Deserialize)]
152 struct Timeline {
153 #[serde(default)]
154 events: Vec<TimelineEvent>,
155 }
156
157 #[derive(Debug, Deserialize)]
158 struct TimelineEvent {
159 #[serde(rename = "type")]
160 event_type: String,
161 event_id: String,
162 sender: String,
163 origin_server_ts: i64,
164 #[serde(default)]
165 content: TimelineContent,
166 }
167
168 #[derive(Debug, Default, Deserialize)]
169 struct TimelineContent {
170 #[serde(default)]
171 body: String,
172 #[serde(default, rename = "m.relates_to")]
173 relates_to: Option<RelatesTo>,
174 }
175
176 #[derive(Debug, Deserialize)]
177 struct RelatesTo {
178 #[serde(default, rename = "m.in_reply_to")]
179 in_reply_to: Option<ReplyTo>,
180 }
181
182 #[derive(Debug, Deserialize)]
183 struct ReplyTo {
184 event_id: String,
185 }
186
187 let since = self.sync_token.lock().await.clone();
188 let mut url = self.client_api_url(&["sync"])?;
189 {
190 let mut query = url.query_pairs_mut();
191 query.append_pair("timeout", "1");
192 if let Some(since) = since {
193 query.append_pair("since", &since);
194 }
195 }
196
197 let response = self
198 .client
199 .get(url)
200 .bearer_auth(self.access_token()?)
201 .send()
202 .await
203 .context("Matrix sync request failed")?;
204
205 let status = response.status();
206 if !status.is_success() {
207 let body = response.text().await.unwrap_or_default();
208 anyhow::bail!("Matrix sync failed {}: {}", status, body);
209 }
210
211 let sync: SyncResponse = response
212 .json()
213 .await
214 .context("Invalid Matrix sync response")?;
215 *self.sync_token.lock().await = Some(sync.next_batch);
216
217 let mut messages = Vec::new();
218 for (room_id, joined_room) in sync.rooms.join {
219 for event in joined_room.timeline.events {
220 if event.event_type != "m.room.message" || event.content.body.is_empty() {
221 continue;
222 }
223
224 messages.push(Message {
225 id: event.event_id,
226 sender: event.sender,
227 content: event.content.body,
228 timestamp: event.origin_server_ts / 1000,
229 channel: Some(room_id.clone()),
230 reply_to: event
231 .content
232 .relates_to
233 .and_then(|r| r.in_reply_to)
234 .map(|r| r.event_id),
235 thread_id: None,
236 media: None,
237 is_direct: false,
238 message_type: MessageType::Text,
239 edited_timestamp: None,
240 reactions: None,
241 });
242 }
243 }
244
245 Ok(messages)
246 }
247
248 async fn join_room_if_needed(&self, recipient: &str) -> Result<String> {
249 if recipient.starts_with('!') {
250 return Ok(recipient.to_string());
251 }
252
253 let response = self
254 .client
255 .post(self.client_api_url(&["join", recipient])?)
256 .bearer_auth(self.access_token()?)
257 .send()
258 .await
259 .context("Matrix join request failed")?;
260
261 let status = response.status();
262 if !status.is_success() {
263 let body = response.text().await.unwrap_or_default();
264 anyhow::bail!("Matrix join failed {}: {}", status, body);
265 }
266
267 #[derive(Deserialize)]
268 struct JoinResponse {
269 room_id: String,
270 }
271
272 let join: JoinResponse = response
273 .json()
274 .await
275 .context("Invalid Matrix join response")?;
276 Ok(join.room_id)
277 }
278}
279
280#[async_trait]
281impl Messenger for MatrixMessenger {
282 fn name(&self) -> &str {
283 &self.name
284 }
285
286 fn messenger_type(&self) -> &str {
287 "matrix"
288 }
289
290 async fn initialize(&mut self) -> Result<()> {
291 if self.access_token.is_some() {
293 *self.sync_token.lock().await = None;
295 let _ = self.sync_once().await?;
296 self.connected = true;
297 return Ok(());
298 }
299
300 #[derive(Deserialize)]
301 struct LoginResponse {
302 access_token: String,
303 user_id: String,
304 }
305
306 self.validate_config()?;
307
308 let response = self
309 .client
310 .post(self.client_api_url(&["login"])?)
311 .json(&json!({
312 "type": "m.login.password",
313 "identifier": {
314 "type": "m.id.user",
315 "user": self.username,
316 },
317 "password": self.password,
318 "initial_device_display_name": self.name,
319 }))
320 .send()
321 .await
322 .context("Matrix login request failed")?;
323
324 let status = response.status();
325 if !status.is_success() {
326 let body = response.text().await.unwrap_or_default();
327 anyhow::bail!("Matrix login failed {}: {}", status, body);
328 }
329
330 let login: LoginResponse = response
331 .json()
332 .await
333 .context("Invalid Matrix login response")?;
334 self.access_token = Some(login.access_token);
335 self.user_id = Some(login.user_id);
336
337 *self.sync_token.lock().await = None;
338 let _ = self.sync_once().await?;
339
340 self.connected = true;
341 Ok(())
342 }
343
344 async fn send_message(&self, recipient: &str, content: &str) -> Result<String> {
345 let room_id = self.join_room_if_needed(recipient).await?;
346 let txn_id = self.txn_counter.fetch_add(1, Ordering::Relaxed).to_string();
347
348 let response = self
349 .client
350 .put(self.client_api_url(&["rooms", &room_id, "send", "m.room.message", &txn_id])?)
351 .bearer_auth(self.access_token()?)
352 .json(&json!({
353 "msgtype": "m.text",
354 "body": content,
355 }))
356 .send()
357 .await
358 .context("Matrix send request failed")?;
359
360 let status = response.status();
361 if !status.is_success() {
362 let body = response.text().await.unwrap_or_default();
363 anyhow::bail!("Matrix send failed {}: {}", status, body);
364 }
365
366 #[derive(Deserialize)]
367 struct SendResponse {
368 event_id: String,
369 }
370
371 let send: SendResponse = response
372 .json()
373 .await
374 .context("Invalid Matrix send response")?;
375 Ok(send.event_id)
376 }
377
378 async fn receive_messages(&self) -> Result<Vec<Message>> {
379 self.sync_once().await
380 }
381
382 fn is_connected(&self) -> bool {
383 self.connected
384 }
385
386 async fn disconnect(&mut self) -> Result<()> {
387 if let Some(token) = self.access_token.as_deref() {
388 let response = self
389 .client
390 .post(self.client_api_url(&["logout"])?)
391 .bearer_auth(token)
392 .send()
393 .await;
394
395 if let Err(error) = response {
396 tracing::warn!(messenger = %self.name, "Matrix logout failed: {error}");
397 }
398 }
399
400 self.access_token = None;
401 self.user_id = None;
402 *self.sync_token.lock().await = None;
403 self.connected = false;
404 Ok(())
405 }
406
407 async fn set_typing(&self, channel: &str, typing: bool) -> Result<()> {
408 let room_id = self.join_room_if_needed(channel).await?;
409 let mut payload = json!({ "typing": typing });
410 if typing {
411 payload["timeout"] = json!(30_000);
412 }
413
414 let response = self
415 .client
416 .put(self.client_api_url(&["rooms", &room_id, "typing", self.user_id()?])?)
417 .bearer_auth(self.access_token()?)
418 .json(&payload)
419 .send()
420 .await
421 .context("Matrix typing request failed")?;
422
423 let status = response.status();
424 if !status.is_success() {
425 let body = response.text().await.unwrap_or_default();
426 anyhow::bail!("Matrix typing failed {}: {}", status, body);
427 }
428
429 Ok(())
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    const HOMESERVER: &str = "https://matrix.example";

    #[test]
    fn validate_config_rejects_empty_homeserver() {
        let messenger = MatrixMessenger::new("matrix", "", "bot", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_username() {
        let messenger = MatrixMessenger::new("matrix", HOMESERVER, "", "secret");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_empty_password() {
        let messenger = MatrixMessenger::new("matrix", HOMESERVER, "bot", "");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_rejects_whitespace_only_password() {
        let messenger = MatrixMessenger::new("matrix", HOMESERVER, "bot", "   ");
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn validate_config_accepts_non_empty_values() {
        let messenger = MatrixMessenger::new("matrix", HOMESERVER, "bot", "secret");
        assert!(messenger.validate_config().is_ok());
    }

    #[test]
    fn validate_config_allows_missing_password_with_access_token() {
        let messenger = MatrixMessenger::with_access_token(
            "matrix",
            HOMESERVER,
            "@bot:matrix.example",
            "token",
            None,
        );
        assert!(messenger.validate_config().is_ok());
    }

    #[test]
    fn validate_config_rejects_empty_homeserver_with_access_token() {
        let messenger =
            MatrixMessenger::with_access_token("matrix", "", "@bot:matrix.example", "token", None);
        assert!(messenger.validate_config().is_err());
    }

    #[test]
    fn credentials_are_unavailable_before_initialize() {
        let messenger = MatrixMessenger::new("matrix", HOMESERVER, "bot", "secret");
        assert!(messenger.access_token().is_err());
        assert!(messenger.user_id().is_err());
    }

    #[test]
    fn access_token_constructor_exposes_credentials() {
        let messenger = MatrixMessenger::with_access_token(
            "matrix",
            HOMESERVER,
            "@bot:matrix.example",
            "token",
            None,
        );
        assert_eq!(messenger.access_token().unwrap(), "token");
        assert_eq!(messenger.user_id().unwrap(), "@bot:matrix.example");
    }

    #[test]
    fn client_api_url_ignores_trailing_slash() {
        let messenger = MatrixMessenger::new("matrix", "https://matrix.example/", "bot", "secret");
        let url = messenger.client_api_url(&["login"]).unwrap();
        assert_eq!(url.as_str(), "https://matrix.example/_matrix/client/v3/login");
    }

    #[test]
    fn client_api_url_percent_encodes_room_alias() {
        let messenger = MatrixMessenger::new("matrix", HOMESERVER, "bot", "secret");
        let url = messenger
            .client_api_url(&["join", "#room:matrix.example"])
            .unwrap();
        assert_eq!(
            url.as_str(),
            "https://matrix.example/_matrix/client/v3/join/%23room:matrix.example"
        );
    }
}