1#[cfg(test)]
2#[macro_use] extern crate rocket;
3
4use rocket::{http::Status, request::{self, Request, FromRequest}};
5
/// Request guard type tied to one API version, `MAJOR.MINOR`, chosen at the type level.
///
/// The `major`/`minor` fields mirror the const parameters so the version is also
/// available at runtime (via [`Versioning::major`] / [`Versioning::minor`] and `Debug`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Versioning<const MAJOR: u64, const MINOR: u64> {
    major: u64,
    minor: u64,
}

impl<const MAJOR: u64, const MINOR: u64> Versioning<MAJOR, MINOR> {
    /// Creates a value whose fields equal the const parameters `MAJOR` and `MINOR`.
    pub const fn new() -> Self {
        Self {
            major: MAJOR,
            minor: MINOR,
        }
    }

    /// Returns the major version this type represents (always `MAJOR`).
    pub const fn major(&self) -> u64 {
        self.major
    }

    /// Returns the minor version this type represents (always `MINOR`).
    pub const fn minor(&self) -> u64 {
        self.minor
    }
}

/// Same as [`Versioning::new`]; provided because `new` takes no arguments.
impl<const MAJOR: u64, const MINOR: u64> Default for Versioning<MAJOR, MINOR> {
    fn default() -> Self {
        Self::new()
    }
}
20
21#[derive(Debug)]
22pub enum VersionError {
23 SemverError(semver::Error),
24 NotExists,
25}
26
27#[rocket::async_trait]
28impl<'r, const MAJOR: u64, const MINOR: u64> FromRequest<'r> for Versioning<MAJOR, MINOR> {
29 type Error = &'r VersionError;
30
31 async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
32 let version = req.local_cache(|| {
33 let ver = req.headers().get_one("api-version").ok_or(VersionError::NotExists)?;
34 semver::Version::parse(ver).map_err(VersionError::SemverError)
35 });
36 match version {
37 Err(err) => request::Outcome::Failure((Status::NotFound, err)),
38 Ok(version) => {
39 if version.major == MAJOR && version.minor == MINOR {
40 request::Outcome::Success(Self::new())
41 } else {
42 request::Outcome::Forward(())
43 }
44 }
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use rocket::http::{Header, Status};
    use rocket::local::blocking::Client;

    use super::Versioning;

    #[get("/versioning", rank = 4)]
    fn versioning(_v: Versioning<1, 0>) -> String {
        "v1.0".to_string()
    }

    #[get("/versioning", rank = 3)]
    fn versioning_1_1(_v: Versioning<1, 1>) -> String {
        "v1.1".to_string()
    }

    #[get("/versioning", rank = 2)]
    fn versioning_2_1(_v: Versioning<2, 1>) -> String {
        "v2.1".to_string()
    }

    /// Builds the application under test.
    ///
    /// A plain constructor is used instead of `#[launch]`, which would also emit an
    /// unused async `main` inside this test module.
    fn rocket() -> rocket::Rocket<rocket::Build> {
        rocket::build().mount("/", routes![versioning, versioning_1_1, versioning_2_1])
    }

    #[test]
    fn test_versioning() {
        let client = Client::tracked(rocket()).expect("invalid rocket instance");

        for (version, body) in [("1.0.0", "v1.0"), ("1.1.0", "v1.1"), ("2.1.0", "v2.1")] {
            let response = client
                .get("/versioning")
                .header(Header::new("Api-Version", version))
                .dispatch();
            assert_eq!(response.status(), Status::Ok, "api-version {}", version);
            assert_eq!(response.into_string().as_deref(), Some(body));
        }

        // A well-formed version with no matching route is forwarded by every guard.
        let response = client
            .get("/versioning")
            .header(Header::new("Api-Version", "2.0.0"))
            .dispatch();
        assert_eq!(response.status(), Status::NotFound);
    }

    #[test]
    fn test_missing_version_header() {
        let client = Client::tracked(rocket()).expect("invalid rocket instance");
        let response = client.get("/versioning").dispatch();
        assert_eq!(response.status(), Status::NotFound);
    }
}