Skip to main content

mlbt_api/
client.rs

1use crate::live::LiveResponse;
2use crate::schedule::ScheduleResponse;
3use crate::season::{GameType, SeasonInfo, SeasonsResponse};
4use crate::standings::StandingsResponse;
5use crate::stats::StatsResponse;
6use crate::teams::{SportId, TeamsResponse};
7use crate::win_probability::WinProbabilityResponse;
8use std::fmt;
9use std::time::Duration;
10
11use chrono::{DateTime, Datelike, Local, NaiveDate};
12use derive_builder::Builder;
13use reqwest::Client;
14use serde::de::DeserializeOwned;
15
16pub type ApiResult<T> = Result<T, ApiError>;
17
18const BASE_URL: &str = "https://statsapi.mlb.com/api/";
19
20/// MLB API object
21#[derive(Builder, Debug, Clone)]
22#[allow(clippy::upper_case_acronyms)]
23pub struct MLBApi {
24    #[builder(default = "Client::new()")]
25    client: Client,
26    #[builder(default = "Duration::from_secs(10)")]
27    timeout: Duration,
28    #[builder(setter(into), default = "String::from(BASE_URL)")]
29    base_url: String,
30}
31
32#[derive(Debug)]
33pub enum ApiError {
34    Network(reqwest::Error, String),
35    API(reqwest::Error, String),
36    Parsing(reqwest::Error, String),
37}
38
39impl ApiError {
40    pub fn log(&self) -> String {
41        match self {
42            ApiError::Network(e, url) => format!("Network error for {url}: {e:?}"),
43            ApiError::API(e, url) => format!("API error for {url}: {e:?}"),
44            ApiError::Parsing(e, url) => format!("Parsing error for {url}: {e:?}"),
45        }
46    }
47}
48
/// The available stat groups. These are taken from the "meta" endpoint:
/// <https://statsapi.mlb.com/api/v1/statGroups>
///
/// Only `Hitting` and `Pitching` are needed for now; the remaining groups are listed in the
/// comments below for reference.
///
/// The type is a plain `Copy` enum, so it also derives equality and hashing to allow callers to
/// compare groups and use them as map or set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StatGroup {
    /// Displayed as `hitting`.
    Hitting,
    /// Displayed as `pitching`.
    Pitching,
    // Fielding,
    // Catching,
    // Running,
    // Game,
    // Team,
    // Streak,
}
63
64/// Display the StatGroup in all lowercase.
65impl fmt::Display for StatGroup {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        match self {
68            StatGroup::Hitting => write!(f, "hitting"),
69            StatGroup::Pitching => write!(f, "pitching"),
70        }
71    }
72}
73
74impl StatGroup {
75    /// The default sort stat for player leaderboards.
76    pub fn default_sort_stat(&self) -> &'static str {
77        match self {
78            StatGroup::Hitting => "plateAppearances",
79            StatGroup::Pitching => "inningsPitched",
80        }
81    }
82}
83
84impl MLBApi {
85    pub async fn get_todays_schedule(&self) -> ApiResult<ScheduleResponse> {
86        let url = format!(
87            "{}v1/schedule?sportId=1,51&hydrate=linescore",
88            self.base_url
89        );
90        self.get(url).await
91    }
92
93    pub async fn get_schedule_date(&self, date: NaiveDate) -> ApiResult<ScheduleResponse> {
94        let url = format!(
95            "{}v1/schedule?sportId=1,51&hydrate=linescore&date={}",
96            self.base_url,
97            date.format("%Y-%m-%d")
98        );
99        self.get(url).await
100    }
101
102    pub async fn get_live_data(&self, game_id: u64) -> ApiResult<LiveResponse> {
103        if game_id == 0 {
104            return Ok(LiveResponse::default());
105        }
106        let url = format!(
107            "{}v1.1/game/{}/feed/live?language=en",
108            self.base_url, game_id
109        );
110        self.get(url).await
111    }
112
113    pub async fn get_win_probability(&self, game_id: u64) -> ApiResult<WinProbabilityResponse> {
114        if game_id == 0 {
115            return Ok(WinProbabilityResponse::default());
116        }
117        let url = format!(
118            "{}v1/game/{}/winProbability?fields=homeTeamWinProbability&fields=awayTeamWinProbability&fields=homeTeamWinProbabilityAdded&fields=atBatIndex&fields=about&fields=inning&fields=isTopInning&fields=captivatingIndex&fields=leverageIndex",
119            self.base_url, game_id
120        );
121        self.get(url).await
122    }
123
124    /// Fetch season info from the MLB API for a given year.
125    pub async fn get_season_info(&self, year: i32) -> ApiResult<Option<SeasonInfo>> {
126        let url = format!("{}v1/seasons/{}?sportId=1", self.base_url, year);
127        let resp = self.get::<SeasonsResponse>(url).await?;
128        Ok(resp.seasons.into_iter().next())
129    }
130
131    pub async fn get_standings(
132        &self,
133        date: NaiveDate,
134        game_type: GameType,
135    ) -> ApiResult<StandingsResponse> {
136        let url = match game_type {
137            GameType::SpringTraining => format!(
138                "{}v1/standings?sportId=1&season={}&standingsType=springTraining&leagueId=103,104&hydrate=team",
139                self.base_url,
140                date.year(),
141            ),
142            GameType::RegularSeason => format!(
143                "{}v1/standings?sportId=1&season={}&date={}&leagueId=103,104&hydrate=team",
144                self.base_url,
145                date.year(),
146                date.format("%Y-%m-%d"),
147            ),
148        };
149        self.get(url).await
150    }
151
152    pub async fn get_team_stats(
153        &self,
154        group: StatGroup,
155        game_type: GameType,
156    ) -> ApiResult<StatsResponse> {
157        let local: DateTime<Local> = Local::now();
158        let mut url = format!(
159            "{}v1/teams/stats?sportId=1&stats=season&season={}&group={}",
160            self.base_url,
161            local.year(),
162            group
163        );
164        if game_type == GameType::SpringTraining {
165            url.push_str("&gameType=S");
166        }
167        self.get(url).await
168    }
169
170    pub async fn get_team_stats_on_date(
171        &self,
172        group: StatGroup,
173        date: NaiveDate,
174        game_type: GameType,
175    ) -> ApiResult<StatsResponse> {
176        let mut url = format!(
177            "{}v1/teams/stats?sportId=1&stats=byDateRange&season={}&endDate={}&group={}",
178            self.base_url,
179            date.year(),
180            date.format("%Y-%m-%d"),
181            group
182        );
183        if game_type == GameType::SpringTraining {
184            url.push_str("&gameType=S");
185        }
186        self.get(url).await
187    }
188
189    pub async fn get_player_stats(
190        &self,
191        group: StatGroup,
192        game_type: GameType,
193    ) -> ApiResult<StatsResponse> {
194        let local: DateTime<Local> = Local::now();
195        let sort = group.default_sort_stat();
196        let mut url = format!(
197            "{}v1/stats?sportId=1&stats=season&season={}&group={}&limit=300&sortStat={}&order=desc",
198            self.base_url,
199            local.year(),
200            group,
201            sort
202        );
203        if game_type == GameType::SpringTraining {
204            url.push_str("&gameType=S&playerPool=ALL");
205        }
206        self.get(url).await
207    }
208
209    pub async fn get_player_stats_on_date(
210        &self,
211        group: StatGroup,
212        date: NaiveDate,
213        game_type: GameType,
214    ) -> ApiResult<StatsResponse> {
215        let sort = group.default_sort_stat();
216        // Spring training doesn't work well with byDateRange, use season instead.
217        let url = match game_type {
218            GameType::SpringTraining => format!(
219                "{}v1/stats?sportId=1&stats=season&season={}&group={}&limit=300&sortStat={}&order=desc&gameType=S&playerPool=ALL",
220                self.base_url,
221                date.year(),
222                group,
223                sort
224            ),
225            GameType::RegularSeason => format!(
226                "{}v1/stats?sportId=1&stats=byDateRange&season={}&endDate={}&group={}&limit=300&sortStat={}&order=desc",
227                self.base_url,
228                date.year(),
229                date.format("%Y-%m-%d"),
230                group,
231                sort
232            ),
233        };
234        self.get(url).await
235    }
236
237    pub async fn get_teams(&self, sport_ids: &[SportId]) -> ApiResult<TeamsResponse> {
238        let ids: Vec<String> = sport_ids.iter().map(|id| id.to_string()).collect();
239        let url = format!(
240            "{}v1/teams?sportIds={}&fields=teams,id,name,division,teamName,abbreviation,sport",
241            self.base_url,
242            ids.join(",")
243        );
244        self.get(url).await
245    }
246
247    async fn get<T: Default + DeserializeOwned>(&self, url: String) -> ApiResult<T> {
248        let response = self
249            .client
250            .get(&url)
251            .timeout(self.timeout)
252            .send()
253            .await
254            .map_err(|err| ApiError::Network(err, url.clone()))?;
255
256        let status = response.status();
257        match response.error_for_status() {
258            Ok(res) => res
259                .json::<T>()
260                .await
261                .map_err(|err| ApiError::Parsing(err, url.clone())),
262            // 400-5xx returns errors
263            Err(err) => {
264                if status.is_client_error() {
265                    // just swallow 4xx responses
266                    Ok(T::default())
267                } else {
268                    Err(ApiError::API(err, url.clone()))
269                }
270            }
271        }
272    }
273}
274
/// Every `StatGroup` must be displayed as its all-lowercase name.
#[test]
fn test_stat_group_lowercase() {
    let expected = [
        (StatGroup::Hitting, "hitting"),
        (StatGroup::Pitching, "pitching"),
    ];
    for (group, name) in expected.iter() {
        assert_eq!(name.to_string(), group.to_string());
    }
}