racetime/
lib.rs

1//! Utilities for creating chat bots for [racetime.gg](https://racetime.gg/).
2//!
3//! The main entry point is [`Bot::run`]. You can also create new race rooms using [`StartRace::start`].
4//!
5//! For documentation, see also <https://github.com/racetimeGG/racetime-app/wiki/Category-bots>.
6
7use {
8    std::{
9        borrow::Cow,
10        collections::BTreeMap,
11        num::NonZeroU16,
12    },
13    collect_mac::collect,
14    itertools::Itertools as _,
15    lazy_regex::regex_captures,
16    serde::Deserialize,
17    tokio::net::ToSocketAddrs,
18    url::Url,
19};
20pub use crate::{
21    bot::Bot,
22    builder::BotBuilder,
23    handler::RaceHandler,
24};
25
26pub mod bot;
27mod builder;
28pub mod handler;
29pub mod model;
30
/// Host name of the public racetime.gg server; this is the host used by [`HostInfo::default`].
const RACETIME_HOST: &str = "racetime.gg";

/// An unsigned duration. This is a reexport of [`std::time::Duration`].
///
/// [`authorize`] uses it to report how long the returned access token stays valid.
pub type UDuration = std::time::Duration;
35
36#[derive(Debug, thiserror::Error)]
37pub enum Error {
38    #[error(transparent)] Custom(#[from] Box<dyn std::error::Error + Send + Sync>),
39    #[error(transparent)] HeaderToStr(#[from] reqwest::header::ToStrError),
40    #[error(transparent)] InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
41    #[error(transparent)] Io(#[from] std::io::Error),
42    #[error(transparent)] Json(#[from] serde_json::Error),
43    #[error(transparent)] Reqwest(#[from] reqwest::Error),
44    #[error(transparent)] Task(#[from] tokio::task::JoinError),
45    #[error(transparent)] UrlParse(#[from] url::ParseError),
46    #[error("websocket connection closed by the server")]
47    EndOfStream,
48    #[error("the startrace location did not match the input category")]
49    LocationCategory,
50    #[error("the startrace location header did not have the expected format")]
51    LocationFormat,
52    #[error("the startrace response did not include a location header")]
53    MissingLocationHeader,
54    #[error("{inner}, body:\n\n{}", .text.as_ref().map(|text| text.clone()).unwrap_or_else(|e| e.to_string()))]
55    ResponseStatus {
56        #[source]
57        inner: reqwest::Error,
58        headers: reqwest::header::HeaderMap,
59        text: reqwest::Result<String>,
60    },
61    #[error("server errors:{}", .0.into_iter().map(|msg| format!("\n• {msg}")).format(""))]
62    Server(Vec<String>),
63    #[error("WebSocket error: {0}")]
64    Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
65    #[error("expected text message from websocket, but received {0:?}")]
66    UnexpectedMessageType(tokio_tungstenite::tungstenite::Message),
67}
68
69trait ReqwestResponseExt: Sized {
70    /// Like `error_for_status` but includes response headers and text in the error.
71    async fn detailed_error_for_status(self) -> Result<Self, Error>;
72}
73
74impl ReqwestResponseExt for reqwest::Response {
75    async fn detailed_error_for_status(self) -> Result<Self, Error> {
76        match self.error_for_status_ref() {
77            Ok(_) => Ok(self),
78            Err(inner) => Err(Error::ResponseStatus {
79                headers: self.headers().clone(),
80                text: self.text().await,
81                inner,
82            }),
83        }
84    }
85}
86
87/// A convenience trait for converting results to use this crate's [`Error`] type.
88pub trait ResultExt {
89    type Ok;
90
91    /// Convert the error to this crate's [`Error`] type using the [`Error::Custom`] variant.
92    fn to_racetime(self) -> Result<Self::Ok, Error>;
93}
94
95impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt for Result<T, E> {
96    type Ok = T;
97
98    fn to_racetime(self) -> Result<T, Error> {
99        self.map_err(|e| Error::Custom(Box::new(e)))
100    }
101}
102
103#[derive(Debug, Clone)]
104pub struct HostInfo {
105    pub hostname: Cow<'static, str>,
106    pub port: NonZeroU16,
107    pub secure: bool,
108}
109
110impl HostInfo {
111    pub fn new(hostname: impl Into<Cow<'static, str>>, port: NonZeroU16, secure: bool) -> Self {
112        Self {
113            hostname: hostname.into(),
114            secure, port,
115        }
116    }
117
118    fn http_protocol(&self) -> &'static str {
119        match self.secure {
120            true => "https",
121            false => "http",
122        }
123    }
124
125    fn websocket_protocol(&self) -> &'static str {
126        match self.secure {
127            true => "wss",
128            false => "ws",
129        }
130    }
131    fn http_uri(&self, url: &str) -> Result<Url, Error>  {
132        uri(self.http_protocol(), &self.hostname, self.port, url)
133    }
134
135    fn websocket_uri(&self, url: &str) -> Result<Url, Error> {
136        uri(self.websocket_protocol(), &self.hostname, self.port, url)
137    }
138
139    fn websocket_socketaddrs(&self) -> impl ToSocketAddrs + '_ {
140        (&*self.hostname, self.port.get())
141    }
142}
143
144impl Default for HostInfo {
145    /// Returns the host info for racetime.gg.
146    fn default() -> Self {
147        Self {
148            hostname: Cow::Borrowed(RACETIME_HOST),
149            port: NonZeroU16::new(443).unwrap(),
150            secure: true,
151        }
152    }
153}
154
155/// Generate a URI from the given protocol and URL path fragment.
156fn uri(proto: &str, host: &str, port: NonZeroU16, url: &str) -> Result<Url, Error> {
157    Ok(format!("{proto}://{host}:{port}{url}").parse()?)
158}
159
160/// Get an OAuth2 token from the authentication server.
161pub async fn authorize(client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
162    authorize_with_host(&HostInfo::default(), client_id, client_secret, client).await
163}
164
165pub async fn authorize_with_host(host_info: &HostInfo, client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
166    #[derive(Deserialize)]
167    struct AuthResponse {
168        access_token: String,
169        expires_in: Option<u64>,
170    }
171
172    let data = client.post(host_info.http_uri("/o/token")?)
173        .form(&collect![as BTreeMap<_, _>:
174            "client_id" => client_id,
175            "client_secret" => client_secret,
176            "grant_type" => "client_credentials",
177        ])
178        .send().await?
179        .detailed_error_for_status().await?
180        .json::<AuthResponse>().await?;
181    Ok((
182        data.access_token,
183        UDuration::from_secs(data.expires_in.unwrap_or(36000)),
184    ))
185}
186
/// Renders a boolean as the literal `true` or `false` for use as a form field value.
///
/// The result always borrows a static string, so no allocation takes place.
fn form_bool(value: &bool) -> Cow<'static, str> {
    let literal = match *value {
        true => "true",
        false => "false",
    };
    Cow::Borrowed(literal)
}
190
/// Settings for a race room.
///
/// The settings are submitted as form fields when creating a room with [`StartRace::start`] or
/// editing one with [`StartRace::edit`]. Apart from the goal, each field is sent under its own name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StartRace {
    /// If the race has already started, this must match the current goal.
    pub goal: String,
    /// Whether [`goal`](Self::goal) is submitted as `custom_goal` (`true`) or as `goal` (`false`).
    ///
    /// If the race has already started, this must match the current goal.
    pub goal_is_custom: bool,
    pub team_race: bool,
    /// If editing the race, this must match the current state. Use [`RaceContext::set_invitational`](handler::RaceContext::set_invitational) or [`RaceContext::set_open`](handler::RaceContext::set_open) instead.
    pub invitational: bool,
    /// Bots always have permission to set this field.
    pub unlisted: bool,
    /// Only available if category is opted-in to the 1v1 ladder beta.
    pub partitionable: bool,
    /// Only available if category is opted-in to the anonymised races beta.
    pub hide_entrants: bool,
    pub ranked: bool,
    pub info_user: String,
    pub info_bot: String,
    pub require_even_teams: bool,
    /// Number of seconds the countdown should run for. Must be in `10..=60`.
    /// If the race has already started, this must match the current delay.
    pub start_delay: u8,
    /// Maximum number of hours the race is allowed to run for. Must be in `1..=72`.
    /// If the race has already started, this must match the current limit.
    pub time_limit: u8,
    pub time_limit_auto_complete: bool,
    /// Bots always have permission to set this field.
    ///
    /// If the race has already started, this cannot be changed.
    pub streaming_required: bool,
    /// If the race has already started, this cannot be changed.
    pub auto_start: bool,
    pub allow_comments: bool,
    pub hide_comments: bool,
    pub allow_prerace_chat: bool,
    pub allow_midrace_chat: bool,
    pub allow_non_entrant_chat: bool,
    /// Number of seconds to hold a message for before displaying it. Doesn't affect race monitors or moderators. Must be in `0..=90`.
    pub chat_message_delay: u8,
}
230
231impl StartRace {
232    fn form(&self) -> BTreeMap<&'static str, Cow<'_, str>> {
233        let Self {
234            goal,
235            goal_is_custom,
236            team_race,
237            invitational,
238            unlisted,
239            partitionable,
240            hide_entrants,
241            ranked,
242            info_user,
243            info_bot,
244            require_even_teams,
245            start_delay,
246            time_limit,
247            time_limit_auto_complete,
248            streaming_required,
249            auto_start,
250            allow_comments,
251            hide_comments,
252            allow_prerace_chat,
253            allow_midrace_chat,
254            allow_non_entrant_chat,
255            chat_message_delay,
256        } = self;
257        collect![
258            if *goal_is_custom { "custom_goal" } else { "goal" } => Cow::Borrowed(&**goal),
259            "team_race" => form_bool(team_race),
260            "invitational" => form_bool(invitational),
261            "unlisted" => form_bool(unlisted),
262            "partitionable" => form_bool(partitionable),
263            "hide_entrants" => form_bool(hide_entrants),
264            "ranked" => form_bool(ranked),
265            "info_user" => Cow::Borrowed(&**info_user),
266            "info_bot" => Cow::Borrowed(&**info_bot),
267            "require_even_teams" => form_bool(require_even_teams),
268            "start_delay" => Cow::Owned(start_delay.to_string()),
269            "time_limit" => Cow::Owned(time_limit.to_string()),
270            "time_limit_auto_complete" => form_bool(time_limit_auto_complete),
271            "streaming_required" => form_bool(streaming_required),
272            "auto_start" => form_bool(auto_start),
273            "allow_comments" => form_bool(allow_comments),
274            "hide_comments" => form_bool(hide_comments),
275            "allow_prerace_chat" => form_bool(allow_prerace_chat),
276            "allow_midrace_chat" => form_bool(allow_midrace_chat),
277            "allow_non_entrant_chat" => form_bool(allow_non_entrant_chat),
278            "chat_message_delay" => Cow::Owned(chat_message_delay.to_string()),
279        ]
280    }
281
282    /// Creates a race room with the specified configuration and returns its slug.
283    ///
284    /// An access token can be obtained using [`authorize`].
285    pub async fn start(&self, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
286        self.start_with_host(&HostInfo::default(), access_token, client, category).await
287    }
288
289    pub async fn start_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
290        let response = client.post(host_info.http_uri(&format!("/o/{category}/startrace"))?)
291            .bearer_auth(access_token)
292            .form(&self.form())
293            .send().await?
294            .detailed_error_for_status().await?;
295        let location = response
296            .headers()
297            .get("location").ok_or(Error::MissingLocationHeader)?
298            .to_str()?;
299        let (_, location_category, slug) = regex_captures!("^/([^/]+)/([^/]+)$", location).ok_or(Error::LocationFormat)?;
300        if location_category != category { return Err(Error::LocationCategory) }
301        Ok(slug.to_owned())
302    }
303
304    /// Edits the given race room.
305    ///
306    /// Due to a limitation of the racetime.gg API, all fields including ones that should remain the same must be specified.
307    ///
308    /// An access token can be obtained using [`authorize`].
309    pub async fn edit(&self, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
310        self.edit_with_host(&HostInfo::default(), access_token, client, category, race_slug).await
311    }
312
313    pub async fn edit_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
314        client.post(host_info.http_uri(&format!("/o/{category}/{race_slug}/edit"))?)
315            .bearer_auth(access_token)
316            .form(&self.form())
317            .send().await?
318            .detailed_error_for_status().await?;
319        Ok(())
320    }
321}