1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//! There are two key structs in this module: [`OAuth2Flow`] and [`OAuth2`].
//! [`OAuth2Flow`] takes the user through the OAuth2 authorization-code flow, and [`OAuth2`]
//! is a middleware used to authorize requests.
use std::sync::RwLock;
pub use middleware::{OAuth2, TokenType, RefreshData};
pub use refresh::RefreshConfig;

mod middleware;
mod refresh;
mod step2_exchange;
mod step1_init;

pub use step1_init::{Initialize, AccessType, PromptType};
use step1_init::{InitializeParams};
use httpclient::{Uri, client, Result, InMemoryResult};
pub use crate::step2_exchange::{ExchangeData, ExchangeResponse, RedirectData};
use httpclient::InMemoryResponseExt;

/// The main entry point for taking the user through OAuth2 flow.
/// The main entry point for taking the user through OAuth2 flow.
///
/// Holds the client credentials and the provider endpoints. The methods are used in step order:
/// [`Self::create_authorization_url`], then [`Self::extract_code`] and [`Self::exchange`], then
/// [`Self::middleware_from_exchange`] (or [`Self::bearer_middleware`]).
pub struct OAuth2Flow {
    /// OAuth2 client identifier. Sent in the authorization URL and in the code exchange, and
    /// copied into the [`OAuth2`] middleware.
    pub client_id: String,
    /// OAuth2 client secret. Sent in the code exchange and copied into the [`OAuth2`] middleware;
    /// it is not included in the authorization URL.
    pub client_secret: String,

    /// The endpoint to initialize the flow. (Step 1)
    /// Query parameters are appended to it to build the authorization URL.
    pub init_endpoint: String,
    /// The endpoint to exchange the code for an access token. (Step 2)
    pub exchange_endpoint: String,
    /// The endpoint to refresh the access token.
    /// Copied into the [`OAuth2`] middleware, which refreshes the token as needed.
    pub refresh_endpoint: String,

    /// The URI the user is redirected back to with the authorization code. Sent in both the
    /// authorization URL and the code exchange.
    pub redirect_uri: String,
}

impl OAuth2Flow {
    /// Step 1: Send the user to the authorization URL.
    ///
    /// After performing the exchange, you will get an [`ExchangeResponse`]. Depending on the [`PromptType`]
    /// provided here, that response may not contain a refresh_token.
    ///
    /// If the value is select_account, it will have a refresh_token only on the first exchange. Afterward, it will be missing.
    ///
    /// If the value is consent, the response will always have a refresh_token. The reason to avoid consent
    /// except when necessary is because it will require the user to re-accept the permissions (i.e. longer user flow, causing drop-off).
    ///
    pub fn create_authorization_url(&self, init: Initialize) -> Uri {
        let params = InitializeParams {
            client_id: &self.client_id,
            redirect_uri: &self.redirect_uri,
            response_type: "code",
            scope: init.scope,
            access_type: init.access_type,
            state: init.state,
            prompt: init.prompt,
        };
        let params = serde_qs::to_string(&params).unwrap();
        let endpoint = self.init_endpoint.as_str();
        let uri = format!("{endpoint}?{params}");
        uri.parse().unwrap()
    }

    /// Step 2a: Extract the code from the redirect URL.
    /// `url` can either be the full url, or a path_and_query string, e.g. "/foo?code=abc&state=def"
    /// The input will be url-decoded (percent-decoded).
    pub fn extract_code(&self, url: &str) -> Result<RedirectData> {
        let uri: Uri = url.parse().unwrap();
        let query = uri.query().unwrap();
        let params = serde_qs::from_str::<RedirectData>(query).unwrap();
        Ok(params)
    }

    pub fn create_exchange_data(&self, code: String) -> ExchangeData {
        ExchangeData {
            code,
            client_id: &self.client_id,
            redirect_uri: &self.redirect_uri,
            client_secret: &self.client_secret,
            grant_type: "authorization_code",
        }
    }

    /// Step 2b: Using RedirectedParams.code, POST to the exchange_endpoint to get the access token.
    pub async fn exchange(&self, code: String) -> InMemoryResult<ExchangeResponse> {
        let data = self.create_exchange_data(code);
        let res = client().post(&self.exchange_endpoint)
            .form(data)
            .await?;
        Ok(res.json()?)
    }

    /// Step 3: Use the exchange response to create a middleware. You can also use `bearer_middleware`.
    /// This method can fail if the ExchangeResponse is missing the refresh token. This will happen in "re-auth"
    /// situations when prompt="consent" was not used. See [`Self::create_authorization_url`] docs for more.
    ///
    /// As the middleware makes requests, the access_token will be refreshed automatically when it expires.
    /// If you want to store the updated access_token (recommended), set the [`OAuth2`] `callback` field.
    pub fn middleware_from_exchange(&self, exchange: ExchangeResponse) -> Result<OAuth2, MissingRefreshToken> {
        let refresh_token = exchange.refresh_token.ok_or(MissingRefreshToken)?;
        Ok(OAuth2 {
            refresh_endpoint: self.refresh_endpoint.clone(),
            client_id: self.client_id.clone(),
            client_secret: self.client_secret.clone(),
            token_type: exchange.token_type,
            access_token: RwLock::new(exchange.access_token),
            refresh_token,
            callback: None,
        })
    }

    pub fn bearer_middleware(&self, access: String, refresh: String) -> OAuth2 {
        OAuth2 {
            refresh_endpoint: self.refresh_endpoint.clone(),
            client_id: self.client_id.clone(),
            client_secret: self.client_secret.clone(),
            token_type: TokenType::Bearer,
            access_token: RwLock::new(access),
            refresh_token: refresh,
            callback: None,
        }
    }
}

/// Error returned by [`OAuth2Flow::middleware_from_exchange`] when the [`ExchangeResponse`]
/// has no `refresh_token`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MissingRefreshToken;

impl MissingRefreshToken {
    /// Explanation shown by both the `Display` and `Debug` impls.
    const MESSAGE: &'static str = concat!(
        "`refresh_token` missing from ExchangeResponse. ",
        "This will happen on re-authorization if you did not use `prompt=consent`. ",
        "See `OAuth2Flow::create_authorization_url` docs for more.",
    );
}

impl std::fmt::Display for MissingRefreshToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(Self::MESSAGE)
    }
}

// `Debug` deliberately prints the full explanation rather than the type name, so that a panic
// from `unwrap()`/`expect()` on a `Result<_, MissingRefreshToken>` is self-explanatory.
impl std::fmt::Debug for MissingRefreshToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(Self::MESSAGE)
    }
}

impl std::error::Error for MissingRefreshToken {}


#[cfg(test)]
mod tests {
    use super::*;

    /// A flow with every setting empty; `extract_code` never reads the flow's fields.
    fn blank_flow() -> OAuth2Flow {
        OAuth2Flow {
            client_id: String::new(),
            client_secret: String::new(),
            init_endpoint: String::new(),
            exchange_endpoint: String::new(),
            refresh_endpoint: String::new(),
            redirect_uri: String::new(),
        }
    }

    #[test]
    fn test_extract_code() {
        let flow = blank_flow();

        // Full URL form: the percent-encoded `/` (%2F) must come back decoded.
        let from_full_url = flow.extract_code("http://localhost:3000/?code=4%2F0AY0").unwrap();
        assert_eq!(from_full_url.code, "4/0AY0");

        // Path-and-query form, also carrying a `state` value.
        let from_path = flow.extract_code("/foo?code=4%2F0AY0&state=123").unwrap();
        assert_eq!(from_path.code, "4/0AY0");
        assert_eq!(from_path.state, Some("123".to_string()));
    }
}