1use std::collections::HashSet;
4
5use serde::de::DeserializeOwned;
6
7use super::*;
8
/// Default base URL of the MBTA V3 API, used by [`Client::without_key`] and [`Client::with_key`].
pub const BASE_URL: &str = "https://api-v3.mbta.com";
11
/// Generates a `Client` method for an MBTA collection endpoint.
///
/// The generated method is named `func`, requests the path `stringify!(func)`, and rejects any
/// query parameter whose name is not in `allowed_query_params` before a request is sent.
#[doc(hidden)]
#[macro_export]
macro_rules! mbta_endpoint_multiple {
    (model=$model:ident, func=$func:ident, allowed_query_params=$allowed_query_params:expr) => {
        impl Client {
            #[doc = concat!("Returns ", stringify!($func), " in the MBTA system.")]
            #[doc = ""]
            #[doc = "# Arguments"]
            #[doc = ""]
            #[doc = "* `query_params` - `(name, value)` query parameter pairs; every name must be one of:"]
            #[doc = concat!("`", stringify!($allowed_query_params), "`")]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Returns `ClientError::InvalidQueryParam` for the first parameter whose name is not"]
            #[doc = "allowed, without sending a request. Otherwise propagates errors from sending the"]
            #[doc = "request and decoding its JSON body."]
            #[doc = ""]
            #[doc = "# Examples"]
            #[doc = ""]
            #[doc = "```ignore"]
            #[doc = "let client = Client::without_key();"]
            #[doc = "let query_params = [(\"page[limit]\", \"10\")];"]
            #[doc = concat!("let ", stringify!($func), "_response = client.", stringify!($func), "(&query_params);")]
            #[doc = concat!("if let Ok(", stringify!($func), ") = ", stringify!($func), "_response {")]
            #[doc = concat!("    for item in ", stringify!($func), ".data {")]
            #[doc = "        // ... work with item ..."]
            #[doc = "    }"]
            #[doc = "}"]
            #[doc = "```"]
            pub fn $func<K: AsRef<str>, V: AsRef<str>>(&self, query_params: &[(K, V)]) -> Result<Response<$model>, ClientError> {
                // Borrow the allowed names instead of allocating a `String` for each one per call.
                let allowed_query_params: HashSet<&str> = $allowed_query_params.into_iter().collect();
                // Validate locally so an invalid query never reaches the network.
                if let Some((k, v)) = query_params
                    .iter()
                    .find(|(k, _)| !allowed_query_params.contains(k.as_ref()))
                {
                    return Err(ClientError::InvalidQueryParam {
                        name: k.as_ref().to_string(),
                        value: v.as_ref().to_string(),
                    });
                }
                self.get(stringify!($func), query_params)
            }
        }
    };
}
66
/// Generates a `Client` method that fetches a single MBTA resource by id.
///
/// The generated method is named `func` and requests `{endpoint}/{id}` with no query
/// parameters. `allowed_query_params` is accepted for symmetry with `mbta_endpoint_multiple`
/// but is not used by the generated method.
#[doc(hidden)]
#[macro_export]
macro_rules! mbta_endpoint_single {
    (model=$model:ident, func=$func:ident, endpoint=$endpoint:expr, allowed_query_params=$allowed_query_params:expr) => {
        impl Client {
            #[doc = concat!("Returns a ", stringify!($func), " in the MBTA system given its id.")]
            #[doc = ""]
            #[doc = "# Arguments"]
            #[doc = ""]
            #[doc = concat!("* `id` - the id of the ", stringify!($func), " to return")]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Propagates errors from sending the request and decoding its JSON body."]
            #[doc = ""]
            #[doc = "# Examples"]
            #[doc = ""]
            #[doc = "```ignore"]
            #[doc = "let client = Client::without_key();"]
            #[doc = concat!("let ", stringify!($func), "_response = client.", stringify!($func), "(id);")]
            #[doc = concat!("if let Ok(item) = ", stringify!($func), "_response {")]
            #[doc = "    // ... work with item.data ..."]
            #[doc = "}"]
            #[doc = "```"]
            pub fn $func(&self, id: &str) -> Result<Response<$model>, ClientError> {
                // The id is the final path segment under the endpoint.
                let resource_path = format!("{}/{}", $endpoint, id);
                // Single-resource lookups never send query parameters.
                let no_query_params: &[(&str, &str)] = &[];
                self.get(&resource_path, no_query_params)
            }
        }
    };
}
99
// Collection endpoints.
//
// Each invocation generates a `Client::<func>(&query_params)` method. That method sends a GET
// request to `{base_url}/<func>`, using the stringified `func` name as the path. It returns
// `ClientError::InvalidQueryParam` for any parameter name that is not in its
// `allowed_query_params` list.
//
// Each list is also rendered verbatim into the generated method's docs (via `stringify!`).
// NOTE(review): the lists presumably mirror the filters documented for each MBTA V3 API
// endpoint. Confirm against the current API reference before adding or removing entries.
mbta_endpoint_multiple!(
    model = Alerts,
    func = alerts,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[activity]",
        "filter[route_type]",
        "filter[direction_id]",
        "filter[route]",
        "filter[stop]",
        "filter[trip]",
        "filter[facility]",
        "filter[id]",
        "filter[banner]",
        "filter[datetime]",
        "filter[lifecycle]",
        "filter[severity]",
    ]
);
mbta_endpoint_multiple!(
    model = Facilities,
    func = facilities,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[stop]", "filter[type]",]
);
mbta_endpoint_multiple!(
    model = Lines,
    func = lines,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[id]",]
);
mbta_endpoint_multiple!(
    model = LiveFacilities,
    func = live_facilities,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[id]",]
);
mbta_endpoint_multiple!(
    model = Predictions,
    func = predictions,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[latitude]",
        "filter[longitude]",
        "filter[radius]",
        "filter[direction_id]",
        "filter[route_type]",
        "filter[stop]",
        "filter[route]",
        "filter[trip]",
        "filter[route_pattern]",
    ]
);
mbta_endpoint_multiple!(
    model = Routes,
    func = routes,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[stop]",
        "filter[type]",
        "filter[direction_id]",
        "filter[date]",
        "filter[id]",
    ]
);
mbta_endpoint_multiple!(
    model = RoutePatterns,
    func = route_patterns,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[id]",
        "filter[route]",
        "filter[direction_id]",
        "filter[stop]",
    ]
);
mbta_endpoint_multiple!(
    model = Schedules,
    func = schedules,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[date]",
        "filter[direction_id]",
        "filter[route_type]",
        "filter[min_time]",
        "filter[max_time]",
        "filter[route]",
        "filter[stop]",
        "filter[trip]",
        "filter[stop_sequence]",
    ]
);
mbta_endpoint_multiple!(
    model = Services,
    func = services,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[id]", "filter[route]",]
);
mbta_endpoint_multiple!(
    model = Shapes,
    func = shapes,
    allowed_query_params = ["page[offset]", "page[limit]", "sort", "filter[route]",]
);
mbta_endpoint_multiple!(
    model = Stops,
    func = stops,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[date]",
        "filter[direction_id]",
        "filter[latitude]",
        "filter[longitude]",
        "filter[radius]",
        "filter[id]",
        "filter[route_type]",
        "filter[route]",
        "filter[service]",
        "filter[location_type]",
    ]
);
mbta_endpoint_multiple!(
    model = Trips,
    func = trips,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[date]",
        "filter[direction_id]",
        "filter[route]",
        "filter[route_pattern]",
        "filter[id]",
        "filter[name]",
    ]
);
mbta_endpoint_multiple!(
    model = Vehicles,
    func = vehicles,
    allowed_query_params = [
        "page[offset]",
        "page[limit]",
        "sort",
        "filter[id]",
        "filter[trip]",
        "filter[label]",
        "filter[route]",
        "filter[direction_id]",
        "filter[route_type]",
    ]
);
258
// Single-resource endpoints.
//
// Each invocation generates a `Client::<func>(id)` method that sends a GET request to
// `{base_url}/<endpoint>/<id>` with no query parameters. The `allowed_query_params` argument
// is required by the macro pattern but is not used by the generated method.
mbta_endpoint_single!(model = Alert, func = alert, endpoint = "alerts", allowed_query_params = []);
mbta_endpoint_single!(model = Facility, func = facility, endpoint = "facilities", allowed_query_params = []);
mbta_endpoint_single!(model = Line, func = line, endpoint = "lines", allowed_query_params = []);
mbta_endpoint_single!(model = Route, func = route, endpoint = "routes", allowed_query_params = []);
mbta_endpoint_single!(model = RoutePattern, func = route_pattern, endpoint = "route_patterns", allowed_query_params = []);
mbta_endpoint_single!(model = Service, func = service, endpoint = "services", allowed_query_params = []);
mbta_endpoint_single!(model = Shape, func = shape, endpoint = "shapes", allowed_query_params = []);
mbta_endpoint_single!(model = Stop, func = stop, endpoint = "stops", allowed_query_params = []);
mbta_endpoint_single!(model = Trip, func = trip, endpoint = "trips", allowed_query_params = []);
mbta_endpoint_single!(model = Vehicle, func = vehicle, endpoint = "vehicles", allowed_query_params = []);
269
/// Client for the MBTA V3 API.
///
/// Construct one with [`Client::without_key`], [`Client::with_key`], or [`Client::with_url`].
/// The per-endpoint request methods are generated by the `mbta_endpoint_multiple` and
/// `mbta_endpoint_single` macros in this file.
#[derive(Debug, Clone, PartialEq)]
pub struct Client {
    /// API key sent in the `x-api-key` header of every request, when present.
    api_key: Option<String>,
    /// URL that endpoint paths are appended to; [`BASE_URL`] unless set via `with_url`.
    base_url: String,
}
278
279impl Client {
280 pub fn without_key() -> Self {
284 Self {
285 api_key: None,
286 base_url: BASE_URL.into(),
287 }
288 }
289
290 pub fn with_key<S: Into<String>>(api_key: S) -> Self {
296 Self {
297 api_key: Some(api_key.into()),
298 base_url: BASE_URL.into(),
299 }
300 }
301
302 pub fn with_url<S: Into<String>>(base_url: S) -> Self {
309 Self {
310 api_key: None,
311 base_url: base_url.into(),
312 }
313 }
314
315 fn get<T: DeserializeOwned, K: AsRef<str>, V: AsRef<str>>(
322 &self,
323 endpoint: &str,
324 query_params: &[(K, V)],
325 ) -> Result<Response<T>, ClientError> {
326 let path = format!("{}/{}", self.base_url, endpoint);
327 let request = ureq::get(&path);
328 let request = match &self.api_key {
329 Some(key) => request.set("x-api-key", key),
330 None => request,
331 };
332 let request = query_params.iter().fold(request, |r, (k, v)| r.query(k.as_ref(), v.as_ref()));
333 let response: Response<T> = request.call()?.into_json()?;
334 Ok(response)
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    use rstest::*;

    #[rstest]
    fn test_client_without_key() {
        let expected = Client {
            api_key: None,
            base_url: "https://api-v3.mbta.com".into(),
        };

        let actual = Client::without_key();

        assert_eq!(actual, expected);
    }

    #[rstest]
    fn test_client_with_key() {
        let expected = Client {
            api_key: Some("test key".into()),
            base_url: "https://api-v3.mbta.com".into(),
        };

        let actual = Client::with_key("test key");

        assert_eq!(actual, expected);
    }

    #[rstest]
    fn test_client_with_url() {
        let expected = Client {
            api_key: None,
            base_url: "https://foobar.com".into(),
        };

        let actual = Client::with_url("https://foobar.com");

        assert_eq!(actual, expected);
    }

    #[rstest]
    fn test_invalid_query_param_is_rejected_before_request() {
        // Validation returns early, so this test needs no network. The client points at a
        // local address so that, if validation were ever skipped, the real API is never contacted.
        let client = Client::with_url("http://127.0.0.1:9");

        let result = client.alerts(&[("page[limit]", "5"), ("filter[not_a_filter]", "1")]);

        match result {
            Err(ClientError::InvalidQueryParam { name, value }) => {
                assert_eq!(name, "filter[not_a_filter]");
                assert_eq!(value, "1");
            }
            _ => panic!("expected ClientError::InvalidQueryParam for an unknown filter"),
        }
    }
}