mbta_rs/
client.rs

//! Synchronous client for interacting with the MBTA V3 API.
2
3use std::collections::HashSet;
4
5use serde::de::DeserializeOwned;
6
7use super::*;
8
/// Default base URL of the MBTA V3 API, used by [`Client::without_key`] and [`Client::with_key`].
/// Endpoint paths are appended to it (separated by `/`) when requests are built.
pub const BASE_URL: &str = "https://api-v3.mbta.com";
11
/// Declarative macro that implements an MBTA client endpoint returning a list of objects.
///
/// Expands to an `impl Client` block containing `pub fn $func(&self, query_params)`. The method
/// rejects any query parameter whose name is not listed in `allowed_query_params`, and otherwise
/// sends a `GET` request to the `$func` endpoint and deserializes the body into `Response<$model>`.
/// The `allowed_query_params` expression is also rendered verbatim into the generated docs.
#[doc(hidden)]
#[macro_export]
macro_rules! mbta_endpoint_multiple {
    (model=$model:ident, func=$func:ident, allowed_query_params=$allowed_query_params:expr) => {
        impl Client {
            #[doc = concat!("Returns ", stringify!($func), " in the MBTA system.")]
            ///
            /// Consult the [API swagger docs](https://api-v3.mbta.com/docs/swagger/index.html) for each parameter's meaning and which are required,
            /// but the request will fail if you include any that are *not* the ones specified below
            /// (we limit them to avoid any return type behaviors that we currently can't support).
            ///
            /// # Allowed Query Parameters
            ///
            #[doc = concat!("`", stringify!($allowed_query_params), "`")]
            ///
            /// # Arguments
            ///
            /// * `query_params` - a slice of pairings of query parameter names to values
            ///
            /// # Errors
            ///
            /// Returns `ClientError::InvalidQueryParam` for the first parameter whose name is not
            /// allowed (no request is sent in that case), or an error if the request itself or the
            /// deserialization of its response fails.
            ///
            /// ```
            /// # use std::env;
            /// # use mbta_rs::Client;
            /// #
            /// # let client = match env::var("MBTA_TOKEN") {
            /// #     Ok(token) => Client::with_key(token),
            /// #     Err(_) => Client::without_key()
            /// # };
            /// #
            /// # let query_params = [
            /// #     ("page[limit]", "3")
            /// # ];
            #[doc = concat!("let ", stringify!($func), "_response = client.", stringify!($func), "(&query_params);\n")]
            #[doc = concat!("if let Ok(", stringify!($func), ") = ", stringify!($func), "_response {\n")]
            #[doc = concat!("    for item in ", stringify!($func), ".data {\n")]
            ///         println!("{}", item.id);
            ///     }
            /// }
            /// ```
            pub fn $func<K: AsRef<str>, V: AsRef<str>>(&self, query_params: &[(K, V)]) -> Result<Response<$model>, ClientError> {
                // Borrow the allowed names rather than allocating a `String` per name on every call;
                // `.iter().copied()` works for arrays and slices on every edition.
                let allowed_query_params: HashSet<&str> = $allowed_query_params.iter().copied().collect();
                for (k, v) in query_params {
                    if !allowed_query_params.contains(k.as_ref()) {
                        return Err(ClientError::InvalidQueryParam {
                            name: k.as_ref().to_string(),
                            value: v.as_ref().to_string(),
                        });
                    }
                }
                self.get(stringify!($func), query_params)
            }
        }
    };
}
66
/// Declarative macro that implements an MBTA client endpoint returning a single object by id.
///
/// Expands to an `impl Client` block containing `pub fn $func(&self, id: &str)`, which sends a
/// `GET` request to `$endpoint/{id}` and deserializes the body into `Response<$model>`.
/// `allowed_query_params` is accepted for symmetry with `mbta_endpoint_multiple` but is not used:
/// single-object requests are always sent without query parameters.
#[doc(hidden)]
#[macro_export]
macro_rules! mbta_endpoint_single {
    (model=$model:ident, func=$func:ident, endpoint=$endpoint:expr, allowed_query_params=$allowed_query_params:expr) => {
        impl Client {
            #[doc = concat!("Returns the ", stringify!($func), " in the MBTA system with the given id.")]
            ///
            /// # Arguments
            ///
            #[doc = concat!("* `id` - the id of the ", stringify!($func), " to return")]
            ///
            /// # Errors
            ///
            /// Returns an error if the request fails or its response cannot be deserialized.
            ///
            /// ```
            /// # use std::env;
            /// # use mbta_rs::Client;
            /// #
            /// # let client = match env::var("MBTA_TOKEN") {
            /// #     Ok(token) => Client::with_key(token),
            /// #     Err(_) => Client::without_key()
            /// # };
            /// #
            /// # let id = "";
            #[doc = concat!("let ", stringify!($func), "_response = client.", stringify!($func), "(id);\n")]
            #[doc = concat!("if let Ok(item) = ", stringify!($func), "_response {\n")]
            ///     println!("{}", item.data.id);
            /// }
            /// ```
            pub fn $func(&self, id: &str) -> Result<Response<$model>, ClientError> {
                // No query parameters: the concrete `String` types only satisfy `get`'s generics.
                self.get::<$model, String, String>(&format!("{}/{}", $endpoint, id), &[])
            }
        }
    };
}
99
// List endpoints. Each invocation generates a `Client` method named after `func` that sends a
// `GET` request to the endpoint of the same name and deserializes the body into
// `Response<model>`. Query parameters whose names are not in `allowed_query_params` are rejected
// with `ClientError::InvalidQueryParam` before any request is made. Each array is also rendered
// verbatim (via `stringify!`) into the generated method's documentation.

// `Client::alerts` -> GET alerts
mbta_endpoint_multiple!(
    model = Alerts,
    func = alerts,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[activity]",
        "filter[route_type]",
        "filter[direction_id]",
        "filter[route]",
        "filter[stop]",
        "filter[trip]",
        "filter[facility]",
        "filter[id]",
        "filter[banner]",
        "filter[datetime]",
        "filter[lifecycle]",
        "filter[severity]",
    ]
);
// `Client::facilities` -> GET facilities
mbta_endpoint_multiple!(
    model = Facilities,
    func = facilities,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[stop]", "filter[type]",]
);
// `Client::lines` -> GET lines
mbta_endpoint_multiple!(
    model = Lines,
    func = lines,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[id]",]
);
// `Client::live_facilities` -> GET live_facilities
mbta_endpoint_multiple!(
    model = LiveFacilities,
    func = live_facilities,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[id]",]
);
// `Client::predictions` -> GET predictions
mbta_endpoint_multiple!(
    model = Predictions,
    func = predictions,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[latitude]",
        "filter[longitude]",
        "filter[radius]",
        "filter[direction_id]",
        "filter[route_type]",
        "filter[stop]",
        "filter[route]",
        "filter[trip]",
        "filter[route_pattern]",
    ]
);
// `Client::routes` -> GET routes
mbta_endpoint_multiple!(
    model = Routes,
    func = routes,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[stop]",
        "filter[type]",
        "filter[direction_id]",
        "filter[date]",
        "filter[id]",
    ]
);
// `Client::route_patterns` -> GET route_patterns
mbta_endpoint_multiple!(
    model = RoutePatterns,
    func = route_patterns,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[id]",
        "filter[route]",
        "filter[direction_id]",
        "filter[stop]",
    ]
);
// `Client::schedules` -> GET schedules
mbta_endpoint_multiple!(
    model = Schedules,
    func = schedules,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[date]",
        "filter[direction_id]",
        "filter[route_type]",
        "filter[min_time]",
        "filter[max_time]",
        "filter[route]",
        "filter[stop]",
        "filter[trip]",
        "filter[stop_sequence]",
    ]
);
// `Client::services` -> GET services
mbta_endpoint_multiple!(
    model = Services,
    func = services,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[id]", "filter[route]",]
);
// `Client::shapes` -> GET shapes
mbta_endpoint_multiple!(
    model = Shapes,
    func = shapes,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[route]",]
);
// `Client::stops` -> GET stops
mbta_endpoint_multiple!(
    model = Stops,
    func = stops,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[date]",
        "filter[direction_id]",
        "filter[latitude]",
        "filter[longitude]",
        "filter[radius]",
        "filter[id]",
        "filter[route_type]",
        "filter[route]",
        "filter[service]",
        "filter[location_type]",
    ]
);
// `Client::trips` -> GET trips
mbta_endpoint_multiple!(
    model = Trips,
    func = trips,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[date]",
        "filter[direction_id]",
        "filter[route]",
        "filter[route_pattern]",
        "filter[id]",
        "filter[name]",
    ]
);
// `Client::vehicles` -> GET vehicles
mbta_endpoint_multiple!(
    model = Vehicles,
    func = vehicles,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[id]",
        "filter[trip]",
        "filter[label]",
        "filter[route]",
        "filter[direction_id]",
        "filter[route_type]",
    ]
);
258
// Single-object endpoints. Each invocation generates a `Client` method named after `func` that
// takes an `id`, sends a `GET` request to `{endpoint}/{id}`, and deserializes the body into
// `Response<model>`. `allowed_query_params` is matched by the macro but not used by its
// expansion; these requests are always sent without query parameters.
mbta_endpoint_single!(model = Alert, func = alert, endpoint = "alerts", allowed_query_params = []);
mbta_endpoint_single!(model = Facility, func = facility, endpoint = "facilities", allowed_query_params = []);
mbta_endpoint_single!(model = Line, func = line, endpoint = "lines", allowed_query_params = []);
mbta_endpoint_single!(model = Route, func = route, endpoint = "routes", allowed_query_params = []);
mbta_endpoint_single!(model = RoutePattern, func = route_pattern, endpoint = "route_patterns", allowed_query_params = []);
mbta_endpoint_single!(model = Service, func = service, endpoint = "services", allowed_query_params = []);
mbta_endpoint_single!(model = Shape, func = shape, endpoint = "shapes", allowed_query_params = []);
mbta_endpoint_single!(model = Stop, func = stop, endpoint = "stops", allowed_query_params = []);
mbta_endpoint_single!(model = Trip, func = trip, endpoint = "trips", allowed_query_params = []);
mbta_endpoint_single!(model = Vehicle, func = vehicle, endpoint = "vehicles", allowed_query_params = []);
269
/// Synchronous client for interacting with the MBTA V3 API.
///
/// Construct one with [`Client::without_key`], [`Client::with_key`], or (for mocking/testing)
/// [`Client::with_url`].
// NOTE(review): the derived `Debug` prints `api_key` verbatim; avoid logging clients with `{:?}`
// or consider a redacting `Debug` impl.
#[derive(Debug, Clone, PartialEq)]
pub struct Client {
    /// Optional API key; when present it is sent with every request in the `x-api-key` header.
    api_key: Option<String>,
    /// Base URL that endpoint paths are appended to when building request URLs.
    base_url: String,
}
278
279impl Client {
280    /// Create a [Client] without an API key.
281    ///
282    /// > "Without an api key in the query string or as a request header, requests will be tracked by IP address and have stricter rate limit." - Massachusetts Bay Transportation Authority
283    pub fn without_key() -> Self {
284        Self {
285            api_key: None,
286            base_url: BASE_URL.into(),
287        }
288    }
289
290    /// Create a [Client] with an API key.
291    ///
292    /// # Arguments
293    ///
294    /// * `api_key` - the API key to use
295    pub fn with_key<S: Into<String>>(api_key: S) -> Self {
296        Self {
297            api_key: Some(api_key.into()),
298            base_url: BASE_URL.into(),
299        }
300    }
301
302    /// Create a [Client] with a custom base URL and no API key.
303    /// This method should only be used for mocking/testing purposes.
304    ///
305    /// # Arguments
306    ///
307    /// * `base_url` - the base URL to use
308    pub fn with_url<S: Into<String>>(base_url: S) -> Self {
309        Self {
310            api_key: None,
311            base_url: base_url.into(),
312        }
313    }
314
315    /// Helper method for making generalized `GET` requests to any endpoint with any query parameters.
316    /// Presumes that all query parameters given are valid.
317    ///
318    /// # Arguments
319    ///
320    /// * query_params - a slice of pairings of query parameter names to values
321    fn get<T: DeserializeOwned, K: AsRef<str>, V: AsRef<str>>(
322        &self,
323        endpoint: &str,
324        query_params: &[(K, V)],
325    ) -> Result<Response<T>, ClientError> {
326        let path = format!("{}/{}", self.base_url, endpoint);
327        let request = ureq::get(&path);
328        let request = match &self.api_key {
329            Some(key) => request.set("x-api-key", key),
330            None => request,
331        };
332        let request = query_params.iter().fold(request, |r, (k, v)| r.query(k.as_ref(), v.as_ref()));
333        let response: Response<T> = request.call()?.into_json()?;
334        Ok(response)
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    use rstest::*;

    /// Builds the `Client` value a constructor is expected to produce.
    fn expected_client(api_key: Option<&str>, base_url: &str) -> Client {
        Client {
            api_key: api_key.map(str::to_owned),
            base_url: base_url.to_owned(),
        }
    }

    #[rstest]
    fn test_client_without_key() {
        let client = Client::without_key();
        assert_eq!(client, expected_client(None, "https://api-v3.mbta.com"));
    }

    #[rstest]
    fn test_client_with_key() {
        let client = Client::with_key("test key");
        assert_eq!(client, expected_client(Some("test key"), "https://api-v3.mbta.com"));
    }

    #[rstest]
    fn test_client_with_url() {
        let client = Client::with_url("https://foobar.com");
        assert_eq!(client, expected_client(None, "https://foobar.com"));
    }
}