1use {
8 std::{
9 borrow::Cow,
10 collections::BTreeMap,
11 num::NonZeroU16,
12 },
13 collect_mac::collect,
14 itertools::Itertools as _,
15 lazy_regex::regex_captures,
16 serde::Deserialize,
17 tokio::net::ToSocketAddrs,
18 url::Url,
19};
20pub use crate::{
21 bot::Bot,
22 builder::BotBuilder,
23 handler::RaceHandler,
24};
25
26pub mod bot;
27mod builder;
28pub mod handler;
29pub mod model;
30
/// Hostname of the public racetime.gg server; used by `HostInfo::default`.
const RACETIME_HOST: &str = "racetime.gg";

/// Alias for [`std::time::Duration`]; e.g. the token lifetime returned by [`authorize`].
pub type UDuration = std::time::Duration;
35
36#[derive(Debug, thiserror::Error)]
37pub enum Error {
38 #[error(transparent)] Custom(#[from] Box<dyn std::error::Error + Send + Sync>),
39 #[error(transparent)] HeaderToStr(#[from] reqwest::header::ToStrError),
40 #[error(transparent)] InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
41 #[error(transparent)] Io(#[from] std::io::Error),
42 #[error(transparent)] Json(#[from] serde_json::Error),
43 #[error(transparent)] Reqwest(#[from] reqwest::Error),
44 #[error(transparent)] Task(#[from] tokio::task::JoinError),
45 #[error(transparent)] UrlParse(#[from] url::ParseError),
46 #[error("websocket connection closed by the server")]
47 EndOfStream,
48 #[error("the startrace location did not match the input category")]
49 LocationCategory,
50 #[error("the startrace location header did not have the expected format")]
51 LocationFormat,
52 #[error("the startrace response did not include a location header")]
53 MissingLocationHeader,
54 #[error("{inner}, body:\n\n{}", .text.as_ref().map(|text| text.clone()).unwrap_or_else(|e| e.to_string()))]
55 ResponseStatus {
56 #[source]
57 inner: reqwest::Error,
58 headers: reqwest::header::HeaderMap,
59 text: reqwest::Result<String>,
60 },
61 #[error("server errors:{}", .0.into_iter().map(|msg| format!("\n• {msg}")).format(""))]
62 Server(Vec<String>),
63 #[error("WebSocket error: {0}")]
64 Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
65 #[error("expected text message from websocket, but received {0:?}")]
66 UnexpectedMessageType(tokio_tungstenite::tungstenite::Message),
67}
68
69trait ReqwestResponseExt: Sized {
70 async fn detailed_error_for_status(self) -> Result<Self, Error>;
72}
73
74impl ReqwestResponseExt for reqwest::Response {
75 async fn detailed_error_for_status(self) -> Result<Self, Error> {
76 match self.error_for_status_ref() {
77 Ok(_) => Ok(self),
78 Err(inner) => Err(Error::ResponseStatus {
79 headers: self.headers().clone(),
80 text: self.text().await,
81 inner,
82 }),
83 }
84 }
85}
86
87pub trait ResultExt {
89 type Ok;
90
91 fn to_racetime(self) -> Result<Self::Ok, Error>;
93}
94
95impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt for Result<T, E> {
96 type Ok = T;
97
98 fn to_racetime(self) -> Result<T, Error> {
99 self.map_err(|e| Error::Custom(Box::new(e)))
100 }
101}
102
/// Where to reach a racetime server: hostname, port, and whether to use TLS.
///
/// Two `HostInfo`s compare equal when all fields are equal, regardless of whether the hostname
/// is stored borrowed or owned.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HostInfo {
    /// The server's hostname.
    pub hostname: Cow<'static, str>,
    /// The TCP port, used for both HTTP and WebSocket connections.
    pub port: NonZeroU16,
    /// `true` selects `https`/`wss`; `false` selects `http`/`ws`.
    pub secure: bool,
}
109
110impl HostInfo {
111 pub fn new(hostname: impl Into<Cow<'static, str>>, port: NonZeroU16, secure: bool) -> Self {
112 Self {
113 hostname: hostname.into(),
114 secure, port,
115 }
116 }
117
118 fn http_protocol(&self) -> &'static str {
119 match self.secure {
120 true => "https",
121 false => "http",
122 }
123 }
124
125 fn websocket_protocol(&self) -> &'static str {
126 match self.secure {
127 true => "wss",
128 false => "ws",
129 }
130 }
131 fn http_uri(&self, url: &str) -> Result<Url, Error> {
132 uri(self.http_protocol(), &self.hostname, self.port, url)
133 }
134
135 fn websocket_uri(&self, url: &str) -> Result<Url, Error> {
136 uri(self.websocket_protocol(), &self.hostname, self.port, url)
137 }
138
139 fn websocket_socketaddrs(&self) -> impl ToSocketAddrs + '_ {
140 (&*self.hostname, self.port.get())
141 }
142}
143
144impl Default for HostInfo {
145 fn default() -> Self {
147 Self {
148 hostname: Cow::Borrowed(RACETIME_HOST),
149 port: NonZeroU16::new(443).unwrap(),
150 secure: true,
151 }
152 }
153}
154
155fn uri(proto: &str, host: &str, port: NonZeroU16, url: &str) -> Result<Url, Error> {
157 Ok(format!("{proto}://{host}:{port}{url}").parse()?)
158}
159
160pub async fn authorize(client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
162 authorize_with_host(&HostInfo::default(), client_id, client_secret, client).await
163}
164
165pub async fn authorize_with_host(host_info: &HostInfo, client_id: &str, client_secret: &str, client: &reqwest::Client) -> Result<(String, UDuration), Error> {
166 #[derive(Deserialize)]
167 struct AuthResponse {
168 access_token: String,
169 expires_in: Option<u64>,
170 }
171
172 let data = client.post(host_info.http_uri("/o/token")?)
173 .form(&collect![as BTreeMap<_, _>:
174 "client_id" => client_id,
175 "client_secret" => client_secret,
176 "grant_type" => "client_credentials",
177 ])
178 .send().await?
179 .detailed_error_for_status().await?
180 .json::<AuthResponse>().await?;
181 Ok((
182 data.access_token,
183 UDuration::from_secs(data.expires_in.unwrap_or(36000)),
184 ))
185}
186
/// Renders a boolean as the literal form value `"true"` or `"false"` (always borrowed).
fn form_bool(value: bool) -> Cow<'static, str> {
    let text = match value {
        true => "true",
        false => "false",
    };
    Cow::Borrowed(text)
}
190
/// Race room settings, submitted as a form both when opening a race (`StartRace::start`) and
/// when editing one (`StartRace::edit`).
///
/// Each field is sent as a form field of the same name, except that the goal's key depends on
/// `goal_is_custom`. Booleans are sent as `"true"`/`"false"`, numbers in decimal.
#[derive(Debug, Clone)]
pub struct StartRace {
    /// The race goal; sent under `custom_goal` or `goal` depending on `goal_is_custom`.
    pub goal: String,
    /// Selects the form key for `goal`: `custom_goal` when `true`, `goal` when `false`.
    pub goal_is_custom: bool,
    pub team_race: bool,
    pub invitational: bool,
    pub unlisted: bool,
    pub ranked: bool,
    pub info_user: String,
    pub info_bot: String,
    pub require_even_teams: bool,
    // NOTE(review): unit not visible here (presumably seconds) — confirm against the racetime.gg API.
    pub start_delay: u8,
    // NOTE(review): unit not visible here (presumably hours) — confirm against the racetime.gg API.
    pub time_limit: u8,
    pub time_limit_auto_complete: bool,
    pub streaming_required: bool,
    pub auto_start: bool,
    pub allow_comments: bool,
    pub hide_comments: bool,
    pub allow_prerace_chat: bool,
    pub allow_midrace_chat: bool,
    pub allow_non_entrant_chat: bool,
    // NOTE(review): unit not visible here (presumably seconds) — confirm against the racetime.gg API.
    pub chat_message_delay: u8,
}
226
227impl StartRace {
228 fn form(&self) -> BTreeMap<&'static str, Cow<'_, str>> {
229 let start_delay = self.start_delay.to_string();
230 let time_limit = self.time_limit.to_string();
231 let chat_message_delay = self.chat_message_delay.to_string();
232 collect![
233 if self.goal_is_custom { "custom_goal" } else { "goal" } => Cow::Borrowed(&*self.goal),
234 "team_race" => form_bool(self.team_race),
235 "invitational" => form_bool(self.invitational),
236 "unlisted" => form_bool(self.unlisted),
237 "ranked" => form_bool(self.ranked),
238 "info_user" => Cow::Borrowed(&*self.info_user),
239 "info_bot" => Cow::Borrowed(&*self.info_bot),
240 "require_even_teams" => form_bool(self.require_even_teams),
241 "start_delay" => Cow::Owned(start_delay),
242 "time_limit" => Cow::Owned(time_limit),
243 "time_limit_auto_complete" => form_bool(self.time_limit_auto_complete),
244 "streaming_required" => form_bool(self.streaming_required),
245 "auto_start" => form_bool(self.auto_start),
246 "allow_comments" => form_bool(self.allow_comments),
247 "hide_comments" => form_bool(self.hide_comments),
248 "allow_prerace_chat" => form_bool(self.allow_prerace_chat),
249 "allow_midrace_chat" => form_bool(self.allow_midrace_chat),
250 "allow_non_entrant_chat" => form_bool(self.allow_non_entrant_chat),
251 "chat_message_delay" => Cow::Owned(chat_message_delay),
252 ]
253 }
254
255 pub async fn start(&self, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
259 self.start_with_host(&HostInfo::default(), access_token, client, category).await
260 }
261
262 pub async fn start_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str) -> Result<String, Error> {
263 let response = client.post(host_info.http_uri(&format!("/o/{category}/startrace"))?)
264 .bearer_auth(access_token)
265 .form(&self.form())
266 .send().await?
267 .detailed_error_for_status().await?;
268 let location = response
269 .headers()
270 .get("location").ok_or(Error::MissingLocationHeader)?
271 .to_str()?;
272 let (_, location_category, slug) = regex_captures!("^/([^/]+)/([^/]+)$", location).ok_or(Error::LocationFormat)?;
273 if location_category != category { return Err(Error::LocationCategory) }
274 Ok(slug.to_owned())
275 }
276
277 pub async fn edit(&self, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
283 self.edit_with_host(&HostInfo::default(), access_token, client, category, race_slug).await
284 }
285
286 pub async fn edit_with_host(&self, host_info: &HostInfo, access_token: &str, client: &reqwest::Client, category: &str, race_slug: &str) -> Result<(), Error> {
287 client.post(host_info.http_uri(&format!("/o/{category}/{race_slug}/edit"))?)
288 .bearer_auth(access_token)
289 .form(&self.form())
290 .send().await?
291 .detailed_error_for_status().await?;
292 Ok(())
293 }
294}