rocket_versioning/
lib.rs

1#[cfg(test)]
2#[macro_use] extern crate rocket;
3
4use rocket::{http::Status, request::{self, Request, FromRequest}};
5
/// Request guard that pins a route to API version `MAJOR.MINOR`.
///
/// The version lives in the type's const parameters; the `major` and `minor`
/// fields mirror them so a handler receiving the guard can read the matched
/// version at run time via [`Versioning::major`] / [`Versioning::minor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Versioning<const MAJOR: u64, const MINOR: u64> {
    major: u64,
    minor: u64,
}

impl<const MAJOR: u64, const MINOR: u64> Versioning<MAJOR, MINOR> {
    /// Builds the guard value for version `MAJOR.MINOR`.
    pub const fn new() -> Versioning<MAJOR, MINOR> {
        Versioning {
            major: MAJOR,
            minor: MINOR,
        }
    }

    /// Major component of the version this guard represents (always `MAJOR`).
    pub const fn major(&self) -> u64 {
        self.major
    }

    /// Minor component of the version this guard represents (always `MINOR`).
    pub const fn minor(&self) -> u64 {
        self.minor
    }
}

/// Equivalent to [`Versioning::new`]; provided because `new` takes no arguments.
impl<const MAJOR: u64, const MINOR: u64> Default for Versioning<MAJOR, MINOR> {
    fn default() -> Self {
        Self::new()
    }
}
20
21#[derive(Debug)]
22pub enum VersionError {
23    SemverError(semver::Error),
24    NotExists,
25}
26
27#[rocket::async_trait]
28impl<'r, const MAJOR: u64, const MINOR: u64> FromRequest<'r> for Versioning<MAJOR, MINOR> {
29    type Error = &'r VersionError;
30
31    async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
32        let version = req.local_cache(|| {
33            let ver = req.headers().get_one("api-version").ok_or(VersionError::NotExists)?;
34            semver::Version::parse(ver).map_err(VersionError::SemverError)
35        });
36        match version {
37            Err(err) => request::Outcome::Failure((Status::NotFound, err)),
38            Ok(version) => {
39                if version.major == MAJOR && version.minor == MINOR {
40                    request::Outcome::Success(Self::new())
41                } else {
42                    request::Outcome::Forward(())
43                }
44            }
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rocket::http::{Header, Status};
    use rocket::local::blocking::Client;
    use rocket::{Build, Rocket};

    use super::Versioning;

    // Three routes share one path; ranks order the attempts, and each
    // `Versioning` guard forwards on a version mismatch so the next is tried.
    #[get("/versioning", rank = 4)]
    fn versioning_1_0(_v: Versioning<1, 0>) -> &'static str {
        "v1.0"
    }

    #[get("/versioning", rank = 3)]
    fn versioning_1_1(_v: Versioning<1, 1>) -> &'static str {
        "v1.1"
    }

    #[get("/versioning", rank = 2)]
    fn versioning_2_1(_v: Versioning<2, 1>) -> &'static str {
        "v2.1"
    }

    /// Builds the instance under test. Deliberately not `#[launch]`: that
    /// attribute also generates a `main` fn, which is dead code in a test module.
    fn rocket() -> Rocket<Build> {
        rocket::build().mount("/", routes![versioning_1_0, versioning_1_1, versioning_2_1])
    }

    /// Dispatches `GET /versioning`, adding an `Api-Version` header when
    /// `version` is `Some`, and returns the response status and body.
    fn get_versioning(client: &Client, version: Option<&str>) -> (Status, Option<String>) {
        let mut request = client.get("/versioning");
        if let Some(version) = version {
            request = request.header(Header::new("Api-Version", version.to_string()));
        }
        let response = request.dispatch();
        (response.status(), response.into_string())
    }

    #[test]
    fn test_versioning() {
        let client = Client::tracked(rocket()).expect("invalid rocket instance");
        let cases: &[(&str, &str)] = &[
            ("1.0.0", "v1.0"),
            ("1.1.0", "v1.1"),
            ("2.1.0", "v2.1"),
            // Only major.minor are compared, so the patch level is ignored.
            ("1.0.7", "v1.0"),
        ];
        for &(header, body) in cases {
            let (status, text) = get_versioning(&client, Some(header));
            assert_eq!(status, Status::Ok, "Api-Version: {}", header);
            assert_eq!(text.as_deref(), Some(body), "Api-Version: {}", header);
        }
    }

    #[test]
    fn test_unroutable_versions() {
        let client = Client::tracked(rocket()).expect("invalid rocket instance");
        // No route accepts 2.0.
        assert_eq!(get_versioning(&client, Some("2.0.0")).0, Status::NotFound);
        // Missing and malformed headers make the guard fail with NotFound.
        assert_eq!(get_versioning(&client, None).0, Status::NotFound);
        assert_eq!(get_versioning(&client, Some("not-a-version")).0, Status::NotFound);
    }
}