racetime/
lib.rs

1//! Utilities for creating chat bots for [racetime.gg](https://racetime.gg/).
2//!
3//! The main entry point is [`Bot::run`]. You can also create new race rooms using [`StartRace::start`].
4//!
5//! For documentation, see also <https://github.com/racetimeGG/racetime-app/wiki/Category-bots>.
6
7use {
8    std::{
9        borrow::Cow,
10        collections::BTreeMap,
11        num::NonZeroU16,
12    },
13    collect_mac::collect,
14    itertools::Itertools as _,
15    lazy_regex::regex_captures,
16    serde::Deserialize,
17    tokio::net::ToSocketAddrs,
18    url::Url,
19};
20pub use crate::{
21    bot::Bot,
22    builder::BotBuilder,
23    handler::RaceHandler,
24};
25
26pub mod bot;
27mod builder;
28pub mod handler;
29pub mod model;
30
/// Hostname of the racetime.gg server, used as the hostname of the default [`HostInfo`].
const RACETIME_HOST: &'static str = "racetime.gg";
32
/// An unsigned duration.
///
/// This is an alias (reexport) of the standard library's [`std::time::Duration`].
pub type UDuration = ::std::time::Duration;
35
36#[derive(Debug, thiserror::Error)]
37pub enum Error {
38    #[error(transparent)] Custom(#[from] Box<dyn std::error::Error + Send + Sync>),
39    #[error(transparent)] HeaderToStr(#[from] reqwest::header::ToStrError),
40    #[error(transparent)] InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
41    #[error(transparent)] Io(#[from] std::io::Error),
42    #[error(transparent)] Json(#[from] serde_json::Error),
43    #[error(transparent)] Reqwest(#[from] reqwest::Error),
44    #[error(transparent)] Task(#[from] tokio::task::JoinError),
45    #[error(transparent)] UrlParse(#[from] url::ParseError),
46    #[error("websocket connection closed by the server")]
47    EndOfStream,
48    #[error("the startrace location did not match the input category")]
49    LocationCategory,
50    #[error("the startrace location header did not have the expected format")]
51    LocationFormat,
52    #[error("the startrace response did not include a location header")]
53    MissingLocationHeader,
54    #[error("{inner}, body:\n\n{}", .text.as_ref().map(|text| text.clone()).unwrap_or_else(|e| e.to_string()))]
55    ResponseStatus {
56        #[source]
57        inner: reqwest::Error,
58        headers: reqwest::header::HeaderMap,
59        text: reqwest::Result<String>,
60    },
61    #[error("server errors:{}", .0.into_iter().map(|msg| format!("\n• {msg}")).format(""))]
62    Server(Vec<String>),
63    #[error("WebSocket error: {0}")]
64    Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
65    #[error("expected text message from websocket, but received {0:?}")]
66    UnexpectedMessageType(tokio_tungstenite::tungstenite::Message),
67}
68
69trait ReqwestResponseExt: Sized {
70    /// Like `error_for_status` but includes response headers and text in the error.
71    async fn detailed_error_for_status(self) -> Result<Self, Error>;
72}
73
74impl ReqwestResponseExt for reqwest::Response {
75    async fn detailed_error_for_status(self) -> Result<Self, Error> {
76        match self.error_for_status_ref() {
77            Ok(_) => Ok(self),
78            Err(inner) => Err(Error::ResponseStatus {
79                headers: self.headers().clone(),
80                text: self.text().await,
81                inner,
82            }),
83        }
84    }
85}
86
87/// A convenience trait for converting results to use this crate's [`Error`] type.
88pub trait ResultExt {
89    type Ok;
90
91    /// Convert the error to this crate's [`Error`] type using the [`Error::Custom`] variant.
92    fn to_racetime(self) -> Result<Self::Ok, Error>;
93}
94
95impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt for Result<T, E> {
96    type Ok = T;
97
98    fn to_racetime(self) -> Result<T, Error> {
99        self.map_err(|e| Error::Custom(Box::new(e)))
100    }
101}
102
/// Describes the server to connect to: hostname, port, and whether secure protocols are used.
#[derive(Debug)]
#[derive(Clone)]
pub struct HostInfo {
    /// The hostname of the server.
    pub hostname: Cow<'static, str>,
    /// The port of the server, included in every generated URI.
    pub port: NonZeroU16,
    /// Selects `https`/`wss` when `true` and `http`/`ws` when `false`.
    pub secure: bool,
}
109
110impl HostInfo {
111    pub fn new(hostname: impl Into<Cow<'static, str>>, port: NonZeroU16, secure: bool) -> Self {
112        Self {
113            hostname: hostname.into(),
114            secure, port,
115        }
116    }
117
118    fn http_protocol(&self) -> &'static str {
119        match self.secure {
120            true => "https",
121            false => "http",
122        }
123    }
124
125    fn websocket_protocol(&self) -> &'static str {
126        match self.secure {
127            true => "wss",
128            false => "ws",
129        }
130    }
131    fn http_uri(&self, url: &str) -> Result<Url, Error>  {
132        uri(self.http_protocol(), &self.hostname, self.port, url)
133    }
134
135    fn websocket_uri(&self, url: &str) -> Result<Url, Error> {
136        uri(self.websocket_protocol(), &self.hostname, self.port, url)
137    }
138
139    fn websocket_socketaddrs(&self) -> impl ToSocketAddrs + '_ {
140        (&*self.hostname, self.port.get())
141    }
142}
143
144impl Default for HostInfo {
145    /// Returns the host info for racetime.gg.
146    fn default() -> Self {
147        Self {
148            hostname: Cow::Borrowed(RACETIME_HOST),
149            port: NonZeroU16::new(443).unwrap(),
150            secure: true,
151        }
152    }
153}
154
155/// Generate a URI from the given protocol and URL path fragment.
156fn uri(proto: &str, host: &str, port: NonZeroU16, url: &str) -> Result<Url, Error> {
157    Ok(format!("{proto}://{host}:{port}{url}").parse()?)
158}
159
160/// Get an OAuth2 token from the authentication server.
161pub async fn authorize(client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
162    authorize_with_host(&HostInfo::default(), client_id, client_secret, client).await
163}
164
165pub async fn authorize_with_host(host_info: &HostInfo, client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
166    #[derive(Deserialize)]
167    struct AuthResponse {
168        access_token: String,
169        expires_in: Option<u64>,
170    }
171
172    let data = client.post(host_info.http_uri("/o/token")?)
173        .form(&collect![as BTreeMap<_, _>:
174            "client_id" => client_id,
175            "client_secret" => client_secret,
176            "grant_type" => "client_credentials",
177        ])
178        .send().await?
179        .detailed_error_for_status().await?
180        .json::<AuthResponse>().await?;
181    Ok((
182        data.access_token,
183        UDuration::from_secs(data.expires_in.unwrap_or(36000)),
184    ))
185}
186
/// Renders a boolean as the borrowed form value `true` or `false`.
fn form_bool(value: bool) -> Cow<'static, str> {
    match value {
        true => Cow::Borrowed("true"),
        false => Cow::Borrowed("false"),
    }
}
190
/// The settings of a race room, submitted as form fields when creating ([`StartRace::start`]) or editing ([`StartRace::edit`]) a race.
#[derive(Debug, Clone)]
pub struct StartRace {
    /// If the race has already started, this must match the current goal.
    pub goal: String,
    /// Whether [`goal`](Self::goal) is submitted as a custom goal (form field `custom_goal`) rather than as the `goal` field.
    ///
    /// If the race has already started, this must match the current goal.
    pub goal_is_custom: bool,
    pub team_race: bool,
    /// If editing the race, this must match the current state. Use [`RaceContext::set_invitational`](handler::RaceContext::set_invitational) or [`RaceContext::set_open`](handler::RaceContext::set_open) instead.
    pub invitational: bool,
    /// Bots always have permission to set this field.
    pub unlisted: bool,
    pub ranked: bool,
    pub info_user: String,
    pub info_bot: String,
    pub require_even_teams: bool,
    /// Number of seconds the countdown should run for. Must be in `10..=60`.
    /// If the race has already started, this must match the current delay.
    pub start_delay: u8,
    /// Maximum number of hours the race is allowed to run for. Must be in `1..=72`.
    /// If the race has already started, this must match the current limit.
    pub time_limit: u8,
    pub time_limit_auto_complete: bool,
    /// Bots always have permission to set this field.
    ///
    /// If the race has already started, this cannot be changed.
    pub streaming_required: bool,
    /// If the race has already started, this cannot be changed.
    pub auto_start: bool,
    pub allow_comments: bool,
    pub hide_comments: bool,
    pub allow_prerace_chat: bool,
    pub allow_midrace_chat: bool,
    pub allow_non_entrant_chat: bool,
    /// Number of seconds to hold a message for before displaying it. Doesn't affect race monitors or moderators. Must be in `0..=90`.
    pub chat_message_delay: u8,
}
226
227impl StartRace {
228    fn form(&self) -> BTreeMap<&'static str, Cow<'_, str>> {
229        let start_delay = self.start_delay.to_string();
230        let time_limit = self.time_limit.to_string();
231        let chat_message_delay = self.chat_message_delay.to_string();
232        collect![
233            if self.goal_is_custom { "custom_goal" } else { "goal" } => Cow::Borrowed(&*self.goal),
234            "team_race" => form_bool(self.team_race),
235            "invitational" => form_bool(self.invitational),
236            "unlisted" => form_bool(self.unlisted),
237            "ranked" => form_bool(self.ranked),
238            "info_user" => Cow::Borrowed(&*self.info_user),
239            "info_bot" => Cow::Borrowed(&*self.info_bot),
240            "require_even_teams" => form_bool(self.require_even_teams),
241            "start_delay" => Cow::Owned(start_delay),
242            "time_limit" => Cow::Owned(time_limit),
243            "time_limit_auto_complete" => form_bool(self.time_limit_auto_complete),
244            "streaming_required" => form_bool(self.streaming_required),
245            "auto_start" => form_bool(self.auto_start),
246            "allow_comments" => form_bool(self.allow_comments),
247            "hide_comments" => form_bool(self.hide_comments),
248            "allow_prerace_chat" => form_bool(self.allow_prerace_chat),
249            "allow_midrace_chat" => form_bool(self.allow_midrace_chat),
250            "allow_non_entrant_chat" => form_bool(self.allow_non_entrant_chat),
251            "chat_message_delay" => Cow::Owned(chat_message_delay),
252        ]
253    }
254
255    /// Creates a race room with the specified configuration and returns its slug.
256    ///
257    /// An access token can be obtained using [`authorize`].
258    pub async fn start(&self, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
259        self.start_with_host(&HostInfo::default(), access_token, client, category).await
260    }
261
262    pub async fn start_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
263        let response = client.post(host_info.http_uri(&format!("/o/{category}/startrace"))?)
264            .bearer_auth(access_token)
265            .form(&self.form())
266            .send().await?
267            .detailed_error_for_status().await?;
268        let location = response
269            .headers()
270            .get("location").ok_or(Error::MissingLocationHeader)?
271            .to_str()?;
272        let (_, location_category, slug) = regex_captures!("^/([^/]+)/([^/]+)$", location).ok_or(Error::LocationFormat)?;
273        if location_category != category { return Err(Error::LocationCategory) }
274        Ok(slug.to_owned())
275    }
276
277    /// Edits the given race room.
278    ///
279    /// Due to a limitation of the racetime.gg API, all fields including ones that should remain the same must be specified.
280    ///
281    /// An access token can be obtained using [`authorize`].
282    pub async fn edit(&self, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
283        self.edit_with_host(&HostInfo::default(), access_token, client, category, race_slug).await
284    }
285
286    pub async fn edit_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
287        client.post(host_info.http_uri(&format!("/o/{category}/{race_slug}/edit"))?)
288            .bearer_auth(access_token)
289            .form(&self.form())
290            .send().await?
291            .detailed_error_for_status().await?;
292        Ok(())
293    }
294}