mx_tester/
registration.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{HashMap, HashSet},
17    convert::TryFrom,
18};
19
20use anyhow::{anyhow, Context, Error};
21use data_encoding::HEXLOWER;
22use hmac::{Hmac, Mac};
23use log::debug;
24use matrix_sdk::{
25    ruma::{api::client::error::ErrorKind, RoomAliasId},
26    HttpError,
27};
28use reqwest::StatusCode;
29use serde::{Deserialize, Serialize};
30use sha1::Sha1;
31use typed_builder::TypedBuilder;
32
33use crate::util::{AsRumaError, Retry};
34
35type HmacSha1 = Hmac<Sha1>;
36
/// The maximal number of attempts for each HTTP request issued while registering users
/// and while logging them in through the matrix_sdk client.
const RETRY_ATTEMPTS: u64 = 10;

/// Timeout, in seconds, given to the matrix_sdk client's request retry configuration.
const TIMEOUT_SEC: u64 = 15;
40
41#[derive(Clone, Debug, Deserialize)]
42pub enum RateLimit {
43    /// Leave the rate limit unchanged.
44    #[serde(alias = "default")]
45    Default,
46
47    /// Specify that the user shouldn't be rate-limited.
48    #[serde(alias = "unlimited")]
49    Unlimited,
50}
51impl Default for RateLimit {
52    fn default() -> Self {
53        RateLimit::Default
54    }
55}
56
57#[derive(Clone, TypedBuilder, Debug, Deserialize)]
58pub struct User {
59    /// Create user as admin?
60    #[serde(default)]
61    #[builder(default = false)]
62    pub admin: bool,
63
64    pub localname: String,
65
66    /// The password for this user. If unspecified, we use `"password"` as password.
67    #[serde(default = "User::default_password")]
68    #[builder(default = User::default_password())]
69    pub password: String,
70
71    #[serde(default)]
72    #[builder(default)]
73    pub rooms: Vec<Room>,
74
75    /// If specified, override the maximal number of messages per second
76    /// that this user can send.
77    #[serde(default)]
78    #[builder(default)]
79    pub rate_limit: RateLimit,
80}
81
82impl User {
83    fn default_password() -> String {
84        "password".to_string()
85    }
86}
87
88/// Instructions for creating a room.
89#[derive(Clone, TypedBuilder, Debug, Deserialize)]
90pub struct Room {
91    /// Whether the room should be public.
92    #[serde(default)]
93    #[builder(default = false)]
94    pub public: bool,
95
96    /// A list of room members.
97    ///
98    /// These must have been created by mx-tester.
99    #[serde(default)]
100    #[builder(default)]
101    pub members: Vec<String>,
102
103    /// A name for the room.
104    #[serde(default)]
105    #[builder(default)]
106    pub name: Option<String>,
107
108    /// A public alias for the room.
109    #[serde(default)]
110    #[builder(default)]
111    pub alias: Option<String>,
112
113    /// A topic for the room.
114    #[serde(default)]
115    #[builder(default)]
116    pub topic: Option<String>,
117}
118
119/// Register a user using the admin api and a registration shared secret.
120/// The base_url is the Scheme and Authority of the URL to access synapse via.
121/// Returns a RegistrationResponse if registration succeeded, otherwise returns an error.
122async fn register_user(
123    base_url: &str,
124    registration_shared_secret: &str,
125    user: &User,
126) -> Result<(), Error> {
127    #[derive(Debug, Deserialize)]
128    struct GetRegisterResponse {
129        nonce: String,
130    }
131    let registration_url = format!("{}/_synapse/admin/v1/register", base_url);
132    debug!(
133        "Registration shared secret: {}, url: {}, user: {:#?}",
134        registration_shared_secret, registration_url, user
135    );
136    let client = reqwest::Client::new();
137    let nonce = client
138        .get(&registration_url)
139        .auto_retry(RETRY_ATTEMPTS)
140        .await?
141        .json::<GetRegisterResponse>()
142        .await?
143        .nonce;
144    // We use map_err here because Hmac::InvalidKeyLength doesn't implement the std::error::Error trait.
145    let mut mac =
146        HmacSha1::new_from_slice(registration_shared_secret.as_bytes()).map_err(|err| {
147            anyhow!(
148                "Couldn't use the provided registration shared secret to create a hmac: {}",
149                err
150            )
151        })?;
152    mac.update(
153        format!(
154            "{nonce}\0{username}\0{password}\0{admin}",
155            nonce = nonce,
156            username = user.localname,
157            password = user.password,
158            admin = if user.admin { "admin" } else { "notadmin" }
159        )
160        .as_bytes(),
161    );
162
163    #[derive(Debug, Serialize)]
164    struct RegistrationPayload {
165        nonce: String,
166        username: String,
167        displayname: String,
168        password: String,
169        admin: bool,
170        mac: String,
171    }
172
173    let registration_payload = RegistrationPayload {
174        nonce,
175        username: user.localname.to_string(),
176        displayname: user.localname.to_string(),
177        password: user.password.to_string(),
178        admin: user.admin,
179        mac: HEXLOWER.encode(&mac.finalize().into_bytes()),
180    };
181    debug!(
182        "Sending payload {:#?}",
183        serde_json::to_string_pretty(&registration_payload)
184    );
185
186    #[derive(Debug, Deserialize)]
187    struct ErrorResponse {
188        errcode: String,
189        error: String,
190    }
191    let client = reqwest::Client::new();
192    let response = client
193        .post(&registration_url)
194        .json(&registration_payload)
195        .auto_retry(RETRY_ATTEMPTS)
196        .await?;
197    match response.status() {
198        StatusCode::OK => Ok(()),
199        _ => {
200            let body = response.json::<ErrorResponse>().await?;
201            Err(anyhow!(
202                "Homeserver responded with errcode: {}, error: {}",
203                body.errcode,
204                body.error
205            ))
206        }
207    }
208}
209
210/// Try to login with the user details provided. If login fails, try to register that user.
211/// If registration then fails, returns an error explaining why, otherwise returns the login details.
212async fn ensure_user_exists(
213    base_url: &str,
214    registration_shared_secret: &str,
215    user: &User,
216) -> Result<matrix_sdk::Client, Error> {
217    debug!(
218        "ensure_user_exists at {}: user {} with password {}",
219        base_url, user.localname, user.password
220    );
221    use matrix_sdk::ruma::api::client::error::*;
222    let homeserver_url = reqwest::Url::parse(base_url)?;
223    let request_config = matrix_sdk::config::RequestConfig::new()
224        .retry_limit(RETRY_ATTEMPTS)
225        .retry_timeout(std::time::Duration::new(TIMEOUT_SEC, 0));
226    let client = matrix_sdk::Client::builder()
227        .request_config(request_config)
228        .homeserver_url(homeserver_url)
229        .build()
230        .await?;
231    match client
232        .login(&user.localname, &user.password, None, None)
233        .await
234    {
235        Ok(_) => return Ok(client),
236        Err(err) => {
237            match err.as_ruma_error() {
238                Some(err) if err.kind == ErrorKind::Forbidden => {
239                    debug!("Could not authenticate {}", err);
240                    // Proceed with registration.
241                }
242                _ => return Err(err).context("Error attempting to login"),
243            }
244        }
245    }
246    register_user(base_url, registration_shared_secret, user).await?;
247    client
248        .login(&user.localname, &user.password, None, None)
249        .await?;
250    Ok(client)
251}
252
253pub async fn handle_user_registration(config: &crate::Config) -> Result<(), Error> {
254    // Create an admin user. We'll need it later to unthrottle users.
255    let admin = ensure_user_exists(
256        &config.homeserver.public_baseurl,
257        &config.homeserver.registration_shared_secret,
258        &User::builder()
259            .admin(true)
260            .localname("mx-tester-admin".to_string())
261            .build(),
262    )
263    .await?;
264
265    let mut clients = HashMap::new();
266    // Create users
267    for user in &config.users {
268        let client = ensure_user_exists(
269            &config.homeserver.public_baseurl,
270            &config.homeserver.registration_shared_secret,
271            user,
272        )
273        .await
274        .with_context(|| format!("Could not setup user {}", user.localname))?;
275
276        // If the user is not rate limited, remove the rate limit.
277        if let RateLimit::Unlimited = user.rate_limit {
278            use override_rate_limits::*;
279            let user_id = client
280                .user_id()
281                .await
282                .expect("Client doesn't have a user id");
283            let request = Request::new(&user_id, Some(0), Some(0));
284            let _ = admin.send(request, None).await?;
285        }
286
287        clients.insert(user.localname.clone(), client);
288    }
289
290    // Create rooms
291    let mut aliases = HashSet::new();
292    for user in &config.users {
293        if user.rooms.is_empty() {
294            continue;
295        }
296        let client = clients.get(&user.localname).unwrap(); // We just inserted it.
297        let my_user_id = client.user_id().await.ok_or_else(|| {
298            anyhow!(
299                "Cannot determine full user id for own user {}.",
300                user.localname
301            )
302        })?;
303        for room in &user.rooms {
304            let mut request = matrix_sdk::ruma::api::client::room::create_room::v3::Request::new();
305            if room.public {
306                request.preset = Some(
307                    matrix_sdk::ruma::api::client::room::create_room::v3::RoomPreset::PublicChat,
308                );
309            } else {
310                request.preset = Some(
311                    matrix_sdk::ruma::api::client::room::create_room::v3::RoomPreset::PrivateChat,
312                );
313            }
314            if let Some(ref name) = room.name {
315                request.name = Some(TryFrom::<&str>::try_from(name.as_str())?);
316            }
317            if let Some(ref alias) = room.alias {
318                if !aliases.insert(alias) {
319                    return Err(anyhow!(
320                        "Attempting to create more than one room with alias {}",
321                        alias
322                    ));
323                }
324                request.room_alias_name = Some(alias.as_ref());
325                // If the alias is already taken, we may need to remove it.
326                let full_alias = format!("#{}:{}", alias, config.homeserver.server_name);
327                debug!("Attempting to register alias {}, this may require unregistering previous instances first.", full_alias);
328                let room_alias_id = <&RoomAliasId as TryFrom<&str>>::try_from(full_alias.as_ref())?;
329                match client
330                    .send(
331                        matrix_sdk::ruma::api::client::alias::delete_alias::v3::Request::new(
332                            &room_alias_id,
333                        ),
334                        None,
335                    )
336                    .await
337                {
338                    // Room alias was successfully removed.
339                    Ok(_) => Ok(()),
340                    // Room alias wasn't removed because it didn't exist.
341                    Err(HttpError::Server(ref code)) if code.as_u16() == 404 => Ok(()),
342                    Err(err) => {
343                        match err.as_ruma_error() {
344                            Some(err) if err.kind == ErrorKind::NotFound => Ok(()),
345                            // Room alias wasn't removed for any other reason.
346                            _ => Err(err),
347                        }
348                    }
349                }
350                .context("Error while attempting to unregister existing alias")?;
351            }
352            if let Some(ref topic) = room.topic {
353                request.topic = Some(topic.as_ref());
354            }
355
356            // Place invites.
357            let mut invites = vec![];
358            for member in &room.members {
359                let member_client = clients.get(member).ok_or_else(|| {
360                    anyhow!(
361                        "Cannot invite user {}: we haven't created this user.",
362                        member
363                    )
364                })?;
365                let user_id = member_client
366                    .user_id()
367                    .await
368                    .ok_or_else(|| anyhow!("Cannot determine full user id for user {}.", member))?;
369                if my_user_id == user_id {
370                    // Don't invite oneself.
371                    continue;
372                }
373                invites.push(user_id);
374            }
375            request.invite = &invites;
376            let room_id = client.create_room(request).await?.room_id;
377
378            // Respond to invites.
379            for member in &room.members {
380                let member_client = clients.get(member).unwrap(); // We checked this a few lines ago.
381                member_client.join_room_by_id(&room_id).await?;
382            }
383        }
384    }
385    Ok(())
386}
387
388mod override_rate_limits {
389    use matrix_sdk::ruma::api::ruma_api;
390    use matrix_sdk::ruma::UserId;
391    use serde::{Deserialize, Serialize};
392
393    ruma_api! {
394        metadata: {
395            description: "Override rate limits",
396            method: POST,
397            name: "override_rate_limit",
398            unstable_path: "/_synapse/admin/v1/users/:user_id/override_ratelimit",
399            rate_limited: false,
400            authentication: AccessToken,
401        }
402
403        request: {
404            /// user ID
405            #[ruma_api(path)]
406            pub user_id: &'a UserId,
407
408            /// The number of actions that can be performed in a second. Defaults to 0.
409            #[serde(default, skip_serializing_if = "Option::is_none")]
410            pub messages_per_second: Option<u32>,
411
412            /// How many actions that can be performed before being limited. Defaults to 0.
413            #[serde(default, skip_serializing_if = "Option::is_none")]
414            pub burst_count: Option<u32>
415        }
416
417        response: {
418            /// Details about the user.
419            #[ruma_api(body)]
420            pub limits: UserLimits,
421        }
422    }
423
424    #[derive(Serialize, Deserialize, Clone, Debug)]
425    pub struct UserLimits {
426        pub messages_per_second: u32,
427        pub burst_count: u32,
428    }
429
430    impl<'a> Request<'a> {
431        /// Creates an `Request` with the given user ID.
432        pub fn new(
433            user_id: &'a UserId,
434            messages_per_second: Option<u32>,
435            burst_count: Option<u32>,
436        ) -> Self {
437            Self {
438                user_id,
439                messages_per_second,
440                burst_count,
441            }
442        }
443    }
444
445    impl Response {
446        /// Creates a new `Response` with all parameters defaulted.
447        #[allow(dead_code)]
448        pub fn new(limits: UserLimits) -> Self {
449            Self { limits }
450        }
451    }
452}