1use std::{
16 collections::{HashMap, HashSet},
17 convert::TryFrom,
18};
19
20use anyhow::{anyhow, Context, Error};
21use data_encoding::HEXLOWER;
22use hmac::{Hmac, Mac};
23use log::debug;
24use matrix_sdk::{
25 ruma::{api::client::error::ErrorKind, RoomAliasId},
26 HttpError,
27};
28use reqwest::StatusCode;
29use serde::{Deserialize, Serialize};
30use sha1::Sha1;
31use typed_builder::TypedBuilder;
32
33use crate::util::{AsRumaError, Retry};
34
35type HmacSha1 = Hmac<Sha1>;
36
37const RETRY_ATTEMPTS: u64 = 10;
39const TIMEOUT_SEC: u64 = 15;
40
41#[derive(Clone, Debug, Deserialize)]
42pub enum RateLimit {
43 #[serde(alias = "default")]
45 Default,
46
47 #[serde(alias = "unlimited")]
49 Unlimited,
50}
51impl Default for RateLimit {
52 fn default() -> Self {
53 RateLimit::Default
54 }
55}
56
57#[derive(Clone, TypedBuilder, Debug, Deserialize)]
58pub struct User {
59 #[serde(default)]
61 #[builder(default = false)]
62 pub admin: bool,
63
64 pub localname: String,
65
66 #[serde(default = "User::default_password")]
68 #[builder(default = User::default_password())]
69 pub password: String,
70
71 #[serde(default)]
72 #[builder(default)]
73 pub rooms: Vec<Room>,
74
75 #[serde(default)]
78 #[builder(default)]
79 pub rate_limit: RateLimit,
80}
81
82impl User {
83 fn default_password() -> String {
84 "password".to_string()
85 }
86}
87
88#[derive(Clone, TypedBuilder, Debug, Deserialize)]
90pub struct Room {
91 #[serde(default)]
93 #[builder(default = false)]
94 pub public: bool,
95
96 #[serde(default)]
100 #[builder(default)]
101 pub members: Vec<String>,
102
103 #[serde(default)]
105 #[builder(default)]
106 pub name: Option<String>,
107
108 #[serde(default)]
110 #[builder(default)]
111 pub alias: Option<String>,
112
113 #[serde(default)]
115 #[builder(default)]
116 pub topic: Option<String>,
117}
118
119async fn register_user(
123 base_url: &str,
124 registration_shared_secret: &str,
125 user: &User,
126) -> Result<(), Error> {
127 #[derive(Debug, Deserialize)]
128 struct GetRegisterResponse {
129 nonce: String,
130 }
131 let registration_url = format!("{}/_synapse/admin/v1/register", base_url);
132 debug!(
133 "Registration shared secret: {}, url: {}, user: {:#?}",
134 registration_shared_secret, registration_url, user
135 );
136 let client = reqwest::Client::new();
137 let nonce = client
138 .get(®istration_url)
139 .auto_retry(RETRY_ATTEMPTS)
140 .await?
141 .json::<GetRegisterResponse>()
142 .await?
143 .nonce;
144 let mut mac =
146 HmacSha1::new_from_slice(registration_shared_secret.as_bytes()).map_err(|err| {
147 anyhow!(
148 "Couldn't use the provided registration shared secret to create a hmac: {}",
149 err
150 )
151 })?;
152 mac.update(
153 format!(
154 "{nonce}\0{username}\0{password}\0{admin}",
155 nonce = nonce,
156 username = user.localname,
157 password = user.password,
158 admin = if user.admin { "admin" } else { "notadmin" }
159 )
160 .as_bytes(),
161 );
162
163 #[derive(Debug, Serialize)]
164 struct RegistrationPayload {
165 nonce: String,
166 username: String,
167 displayname: String,
168 password: String,
169 admin: bool,
170 mac: String,
171 }
172
173 let registration_payload = RegistrationPayload {
174 nonce,
175 username: user.localname.to_string(),
176 displayname: user.localname.to_string(),
177 password: user.password.to_string(),
178 admin: user.admin,
179 mac: HEXLOWER.encode(&mac.finalize().into_bytes()),
180 };
181 debug!(
182 "Sending payload {:#?}",
183 serde_json::to_string_pretty(®istration_payload)
184 );
185
186 #[derive(Debug, Deserialize)]
187 struct ErrorResponse {
188 errcode: String,
189 error: String,
190 }
191 let client = reqwest::Client::new();
192 let response = client
193 .post(®istration_url)
194 .json(®istration_payload)
195 .auto_retry(RETRY_ATTEMPTS)
196 .await?;
197 match response.status() {
198 StatusCode::OK => Ok(()),
199 _ => {
200 let body = response.json::<ErrorResponse>().await?;
201 Err(anyhow!(
202 "Homeserver responded with errcode: {}, error: {}",
203 body.errcode,
204 body.error
205 ))
206 }
207 }
208}
209
210async fn ensure_user_exists(
213 base_url: &str,
214 registration_shared_secret: &str,
215 user: &User,
216) -> Result<matrix_sdk::Client, Error> {
217 debug!(
218 "ensure_user_exists at {}: user {} with password {}",
219 base_url, user.localname, user.password
220 );
221 use matrix_sdk::ruma::api::client::error::*;
222 let homeserver_url = reqwest::Url::parse(base_url)?;
223 let request_config = matrix_sdk::config::RequestConfig::new()
224 .retry_limit(RETRY_ATTEMPTS)
225 .retry_timeout(std::time::Duration::new(TIMEOUT_SEC, 0));
226 let client = matrix_sdk::Client::builder()
227 .request_config(request_config)
228 .homeserver_url(homeserver_url)
229 .build()
230 .await?;
231 match client
232 .login(&user.localname, &user.password, None, None)
233 .await
234 {
235 Ok(_) => return Ok(client),
236 Err(err) => {
237 match err.as_ruma_error() {
238 Some(err) if err.kind == ErrorKind::Forbidden => {
239 debug!("Could not authenticate {}", err);
240 }
242 _ => return Err(err).context("Error attempting to login"),
243 }
244 }
245 }
246 register_user(base_url, registration_shared_secret, user).await?;
247 client
248 .login(&user.localname, &user.password, None, None)
249 .await?;
250 Ok(client)
251}
252
253pub async fn handle_user_registration(config: &crate::Config) -> Result<(), Error> {
254 let admin = ensure_user_exists(
256 &config.homeserver.public_baseurl,
257 &config.homeserver.registration_shared_secret,
258 &User::builder()
259 .admin(true)
260 .localname("mx-tester-admin".to_string())
261 .build(),
262 )
263 .await?;
264
265 let mut clients = HashMap::new();
266 for user in &config.users {
268 let client = ensure_user_exists(
269 &config.homeserver.public_baseurl,
270 &config.homeserver.registration_shared_secret,
271 user,
272 )
273 .await
274 .with_context(|| format!("Could not setup user {}", user.localname))?;
275
276 if let RateLimit::Unlimited = user.rate_limit {
278 use override_rate_limits::*;
279 let user_id = client
280 .user_id()
281 .await
282 .expect("Client doesn't have a user id");
283 let request = Request::new(&user_id, Some(0), Some(0));
284 let _ = admin.send(request, None).await?;
285 }
286
287 clients.insert(user.localname.clone(), client);
288 }
289
290 let mut aliases = HashSet::new();
292 for user in &config.users {
293 if user.rooms.is_empty() {
294 continue;
295 }
296 let client = clients.get(&user.localname).unwrap(); let my_user_id = client.user_id().await.ok_or_else(|| {
298 anyhow!(
299 "Cannot determine full user id for own user {}.",
300 user.localname
301 )
302 })?;
303 for room in &user.rooms {
304 let mut request = matrix_sdk::ruma::api::client::room::create_room::v3::Request::new();
305 if room.public {
306 request.preset = Some(
307 matrix_sdk::ruma::api::client::room::create_room::v3::RoomPreset::PublicChat,
308 );
309 } else {
310 request.preset = Some(
311 matrix_sdk::ruma::api::client::room::create_room::v3::RoomPreset::PrivateChat,
312 );
313 }
314 if let Some(ref name) = room.name {
315 request.name = Some(TryFrom::<&str>::try_from(name.as_str())?);
316 }
317 if let Some(ref alias) = room.alias {
318 if !aliases.insert(alias) {
319 return Err(anyhow!(
320 "Attempting to create more than one room with alias {}",
321 alias
322 ));
323 }
324 request.room_alias_name = Some(alias.as_ref());
325 let full_alias = format!("#{}:{}", alias, config.homeserver.server_name);
327 debug!("Attempting to register alias {}, this may require unregistering previous instances first.", full_alias);
328 let room_alias_id = <&RoomAliasId as TryFrom<&str>>::try_from(full_alias.as_ref())?;
329 match client
330 .send(
331 matrix_sdk::ruma::api::client::alias::delete_alias::v3::Request::new(
332 &room_alias_id,
333 ),
334 None,
335 )
336 .await
337 {
338 Ok(_) => Ok(()),
340 Err(HttpError::Server(ref code)) if code.as_u16() == 404 => Ok(()),
342 Err(err) => {
343 match err.as_ruma_error() {
344 Some(err) if err.kind == ErrorKind::NotFound => Ok(()),
345 _ => Err(err),
347 }
348 }
349 }
350 .context("Error while attempting to unregister existing alias")?;
351 }
352 if let Some(ref topic) = room.topic {
353 request.topic = Some(topic.as_ref());
354 }
355
356 let mut invites = vec![];
358 for member in &room.members {
359 let member_client = clients.get(member).ok_or_else(|| {
360 anyhow!(
361 "Cannot invite user {}: we haven't created this user.",
362 member
363 )
364 })?;
365 let user_id = member_client
366 .user_id()
367 .await
368 .ok_or_else(|| anyhow!("Cannot determine full user id for user {}.", member))?;
369 if my_user_id == user_id {
370 continue;
372 }
373 invites.push(user_id);
374 }
375 request.invite = &invites;
376 let room_id = client.create_room(request).await?.room_id;
377
378 for member in &room.members {
380 let member_client = clients.get(member).unwrap(); member_client.join_room_by_id(&room_id).await?;
382 }
383 }
384 }
385 Ok(())
386}
387
388mod override_rate_limits {
389 use matrix_sdk::ruma::api::ruma_api;
390 use matrix_sdk::ruma::UserId;
391 use serde::{Deserialize, Serialize};
392
393 ruma_api! {
394 metadata: {
395 description: "Override rate limits",
396 method: POST,
397 name: "override_rate_limit",
398 unstable_path: "/_synapse/admin/v1/users/:user_id/override_ratelimit",
399 rate_limited: false,
400 authentication: AccessToken,
401 }
402
403 request: {
404 #[ruma_api(path)]
406 pub user_id: &'a UserId,
407
408 #[serde(default, skip_serializing_if = "Option::is_none")]
410 pub messages_per_second: Option<u32>,
411
412 #[serde(default, skip_serializing_if = "Option::is_none")]
414 pub burst_count: Option<u32>
415 }
416
417 response: {
418 #[ruma_api(body)]
420 pub limits: UserLimits,
421 }
422 }
423
424 #[derive(Serialize, Deserialize, Clone, Debug)]
425 pub struct UserLimits {
426 pub messages_per_second: u32,
427 pub burst_count: u32,
428 }
429
430 impl<'a> Request<'a> {
431 pub fn new(
433 user_id: &'a UserId,
434 messages_per_second: Option<u32>,
435 burst_count: Option<u32>,
436 ) -> Self {
437 Self {
438 user_id,
439 messages_per_second,
440 burst_count,
441 }
442 }
443 }
444
445 impl Response {
446 #[allow(dead_code)]
448 pub fn new(limits: UserLimits) -> Self {
449 Self { limits }
450 }
451 }
452}