1use {
8 std::{
9 borrow::Cow,
10 collections::BTreeMap,
11 num::NonZeroU16,
12 },
13 collect_mac::collect,
14 itertools::Itertools as _,
15 lazy_regex::regex_captures,
16 serde::Deserialize,
17 tokio::net::ToSocketAddrs,
18 url::Url,
19};
20pub use crate::{
21 bot::Bot,
22 builder::BotBuilder,
23 handler::RaceHandler,
24};
25
26pub mod bot;
27mod builder;
28pub mod handler;
29pub mod model;
30
/// Hostname of the public racetime.gg server; used by `HostInfo::default`.
const RACETIME_HOST: &str = "racetime.gg";

/// Unsigned duration type used throughout this crate (an alias of the standard `Duration`),
/// e.g. for the access-token lifetime returned by `authorize`.
pub type UDuration = core::time::Duration;
35
36#[derive(Debug, thiserror::Error)]
37pub enum Error {
38 #[error(transparent)] Custom(#[from] Box<dyn std::error::Error + Send + Sync>),
39 #[error(transparent)] HeaderToStr(#[from] reqwest::header::ToStrError),
40 #[error(transparent)] InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
41 #[error(transparent)] Io(#[from] std::io::Error),
42 #[error(transparent)] Json(#[from] serde_json::Error),
43 #[error(transparent)] Reqwest(#[from] reqwest::Error),
44 #[error(transparent)] Task(#[from] tokio::task::JoinError),
45 #[error(transparent)] UrlParse(#[from] url::ParseError),
46 #[error("websocket connection closed by the server")]
47 EndOfStream,
48 #[error("the startrace location did not match the input category")]
49 LocationCategory,
50 #[error("the startrace location header did not have the expected format")]
51 LocationFormat,
52 #[error("the startrace response did not include a location header")]
53 MissingLocationHeader,
54 #[error("{inner}, body:\n\n{}", .text.as_ref().map(|text| text.clone()).unwrap_or_else(|e| e.to_string()))]
55 ResponseStatus {
56 #[source]
57 inner: reqwest::Error,
58 headers: reqwest::header::HeaderMap,
59 text: reqwest::Result<String>,
60 },
61 #[error("server errors:{}", .0.into_iter().map(|msg| format!("\n• {msg}")).format(""))]
62 Server(Vec<String>),
63 #[error("WebSocket error: {0}")]
64 Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
65 #[error("expected text message from websocket, but received {0:?}")]
66 UnexpectedMessageType(tokio_tungstenite::tungstenite::Message),
67}
68
69trait ReqwestResponseExt: Sized {
70 async fn detailed_error_for_status(self) -> Result<Self, Error>;
72}
73
74impl ReqwestResponseExt for reqwest::Response {
75 async fn detailed_error_for_status(self) -> Result<Self, Error> {
76 match self.error_for_status_ref() {
77 Ok(_) => Ok(self),
78 Err(inner) => Err(Error::ResponseStatus {
79 headers: self.headers().clone(),
80 text: self.text().await,
81 inner,
82 }),
83 }
84 }
85}
86
87pub trait ResultExt {
89 type Ok;
90
91 fn to_racetime(self) -> Result<Self::Ok, Error>;
93}
94
95impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt for Result<T, E> {
96 type Ok = T;
97
98 fn to_racetime(self) -> Result<T, Error> {
99 self.map_err(|e| Error::Custom(Box::new(e)))
100 }
101}
102
/// Location of a racetime.gg-compatible server and how to connect to it.
#[derive(Clone, Debug)]
pub struct HostInfo {
    /// Host used in the HTTP(S)/WebSocket URLs and for the WebSocket socket address.
    pub hostname: Cow<'static, str>,
    /// Port used for both HTTP(S) and WebSocket connections.
    pub port: NonZeroU16,
    /// `true` selects `https`/`wss`, `false` selects `http`/`ws`.
    pub secure: bool,
}
109
110impl HostInfo {
111 pub fn new(hostname: impl Into<Cow<'static, str>>, port: NonZeroU16, secure: bool) -> Self {
112 Self {
113 hostname: hostname.into(),
114 secure, port,
115 }
116 }
117
118 fn http_protocol(&self) -> &'static str {
119 match self.secure {
120 true => "https",
121 false => "http",
122 }
123 }
124
125 fn websocket_protocol(&self) -> &'static str {
126 match self.secure {
127 true => "wss",
128 false => "ws",
129 }
130 }
131 fn http_uri(&self, url: &str) -> Result<Url, Error> {
132 uri(self.http_protocol(), &self.hostname, self.port, url)
133 }
134
135 fn websocket_uri(&self, url: &str) -> Result<Url, Error> {
136 uri(self.websocket_protocol(), &self.hostname, self.port, url)
137 }
138
139 fn websocket_socketaddrs(&self) -> impl ToSocketAddrs + '_ {
140 (&*self.hostname, self.port.get())
141 }
142}
143
144impl Default for HostInfo {
145 fn default() -> Self {
147 Self {
148 hostname: Cow::Borrowed(RACETIME_HOST),
149 port: NonZeroU16::new(443).unwrap(),
150 secure: true,
151 }
152 }
153}
154
155fn uri(proto: &str, host: &str, port: NonZeroU16, url: &str) -> Result<Url, Error> {
157 Ok(format!("{proto}://{host}:{port}{url}").parse()?)
158}
159
160pub async fn authorize(client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
162 authorize_with_host(&HostInfo::default(), client_id, client_secret, client).await
163}
164
165pub async fn authorize_with_host(host_info: &HostInfo, client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
166 #[derive(Deserialize)]
167 struct AuthResponse {
168 access_token: String,
169 expires_in: Option<u64>,
170 }
171
172 let data = client.post(host_info.http_uri("/o/token")?)
173 .form(&collect![as BTreeMap<_, _>:
174 "client_id" => client_id,
175 "client_secret" => client_secret,
176 "grant_type" => "client_credentials",
177 ])
178 .send().await?
179 .detailed_error_for_status().await?
180 .json::<AuthResponse>().await?;
181 Ok((
182 data.access_token,
183 UDuration::from_secs(data.expires_in.unwrap_or(36000)),
184 ))
185}
186
/// Renders a boolean form field as the literal string `"true"` or `"false"` (never allocates).
fn form_bool(value: &bool) -> Cow<'static, str> {
    let rendered = match *value {
        true => "true",
        false => "false",
    };
    Cow::Borrowed(rendered)
}
190
/// Settings for creating (`StartRace::start`) or editing (`StartRace::edit`) a race room.
///
/// Every field is sent as a form field of the same name, except `goal`, whose key
/// depends on `goal_is_custom`. Booleans are sent as `"true"`/`"false"` and numbers
/// in decimal.
#[derive(Debug, Clone)]
pub struct StartRace {
    /// The goal. It is sent as `custom_goal` if `goal_is_custom` is set, otherwise as `goal`.
    pub goal: String,
    /// Whether `goal` is custom goal text rather than one of the category's goals.
    pub goal_is_custom: bool,
    pub team_race: bool,
    pub invitational: bool,
    pub unlisted: bool,
    pub partitionable: bool,
    pub hide_entrants: bool,
    pub ranked: bool,
    /// Sent as `info_user`.
    pub info_user: String,
    /// Sent as `info_bot`.
    pub info_bot: String,
    pub require_even_teams: bool,
    /// Sent as `start_delay`. NOTE(review): units and valid range are defined by the
    /// racetime.gg API (presumably seconds) — confirm against its documentation.
    pub start_delay: u8,
    /// Sent as `time_limit`. NOTE(review): units and valid range are defined by the
    /// racetime.gg API (presumably hours) — confirm against its documentation.
    pub time_limit: u8,
    pub time_limit_auto_complete: bool,
    pub streaming_required: bool,
    pub auto_start: bool,
    pub allow_comments: bool,
    pub hide_comments: bool,
    pub allow_prerace_chat: bool,
    pub allow_midrace_chat: bool,
    pub allow_non_entrant_chat: bool,
    /// Sent as `chat_message_delay`. NOTE(review): presumably seconds — confirm against the
    /// racetime.gg API documentation.
    pub chat_message_delay: u8,
}
230
231impl StartRace {
232 fn form(&self) -> BTreeMap<&'static str, Cow<'_, str>> {
233 let Self {
234 goal,
235 goal_is_custom,
236 team_race,
237 invitational,
238 unlisted,
239 partitionable,
240 hide_entrants,
241 ranked,
242 info_user,
243 info_bot,
244 require_even_teams,
245 start_delay,
246 time_limit,
247 time_limit_auto_complete,
248 streaming_required,
249 auto_start,
250 allow_comments,
251 hide_comments,
252 allow_prerace_chat,
253 allow_midrace_chat,
254 allow_non_entrant_chat,
255 chat_message_delay,
256 } = self;
257 collect![
258 if *goal_is_custom { "custom_goal" } else { "goal" } => Cow::Borrowed(&**goal),
259 "team_race" => form_bool(team_race),
260 "invitational" => form_bool(invitational),
261 "unlisted" => form_bool(unlisted),
262 "partitionable" => form_bool(partitionable),
263 "hide_entrants" => form_bool(hide_entrants),
264 "ranked" => form_bool(ranked),
265 "info_user" => Cow::Borrowed(&**info_user),
266 "info_bot" => Cow::Borrowed(&**info_bot),
267 "require_even_teams" => form_bool(require_even_teams),
268 "start_delay" => Cow::Owned(start_delay.to_string()),
269 "time_limit" => Cow::Owned(time_limit.to_string()),
270 "time_limit_auto_complete" => form_bool(time_limit_auto_complete),
271 "streaming_required" => form_bool(streaming_required),
272 "auto_start" => form_bool(auto_start),
273 "allow_comments" => form_bool(allow_comments),
274 "hide_comments" => form_bool(hide_comments),
275 "allow_prerace_chat" => form_bool(allow_prerace_chat),
276 "allow_midrace_chat" => form_bool(allow_midrace_chat),
277 "allow_non_entrant_chat" => form_bool(allow_non_entrant_chat),
278 "chat_message_delay" => Cow::Owned(chat_message_delay.to_string()),
279 ]
280 }
281
282 pub async fn start(&self, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
286 self.start_with_host(&HostInfo::default(), access_token, client, category).await
287 }
288
289 pub async fn start_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
290 let response = client.post(host_info.http_uri(&format!("/o/{category}/startrace"))?)
291 .bearer_auth(access_token)
292 .form(&self.form())
293 .send().await?
294 .detailed_error_for_status().await?;
295 let location = response
296 .headers()
297 .get("location").ok_or(Error::MissingLocationHeader)?
298 .to_str()?;
299 let (_, location_category, slug) = regex_captures!("^/([^/]+)/([^/]+)$", location).ok_or(Error::LocationFormat)?;
300 if location_category != category { return Err(Error::LocationCategory) }
301 Ok(slug.to_owned())
302 }
303
304 pub async fn edit(&self, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
310 self.edit_with_host(&HostInfo::default(), access_token, client, category, race_slug).await
311 }
312
313 pub async fn edit_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
314 client.post(host_info.http_uri(&format!("/o/{category}/{race_slug}/edit"))?)
315 .bearer_auth(access_token)
316 .form(&self.form())
317 .send().await?
318 .detailed_error_for_status().await?;
319 Ok(())
320 }
321}