oauth_axum/
lib.rs

1//! # oauth-axum
2//!
//! This crate is a wrapper around the `oauth2` crate that ships with every provider's configuration already done, making it easy to use in your Axum project.
//! The intention is to add all providers from this list that have OAuth2 available: <https://en.wikipedia.org/wiki/List_of_OAuth_providers>.
5//!
6//! # Usage
7//!
//! Using it is simple: create a new instance of one of the providers:
9//!
10//! - CustomProvider
11//! - GithubProvider
12//! - DiscordProvider
13//! - TwitterProvider
14//! - GoogleProvider
15//! - MicrosoftProvider
16//! - FacebookProvider
17//! - SpotifyProvider
18//!
//! in your project, passing the following to its `new` function:
20//!
21//!   - **client_id:** Unique ID from the app created in your provider
//!   - **client_secret:** Secret token of your app in the provider; this token must be kept hidden from users
23//!   - **redirect_url:** URL from your backend that will accept the return from the provider
24//!
25//!   If you are using **``CustomProvider``** you need to pass:
26//!
//!   - **auth_url:** URL of your provider where the user grants your app permission to access their account
28//!   - **token_url:** URL that is used to generate the auth token
29//!
30//! The structure of this project is separated into two steps:
31//!
32//! ### 1. Generate the URL
33//!
//! This step creates the URL used to redirect the user to the provider, where they authorize your app to access their information.
35//!
//! The URL has this format (GitHub example): `https://github.com/login/oauth/authorize?response_type=code&client_id={CLIENT_ID}&state={RANDOM_STATE}&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&redirect_uri={REDIRECT_URL}&scope={SCOPES}`
37//!
//! This step is important because it generates the VERIFIER field, which you need to save somewhere (memory, db...) together with the state field; the state will be your ID to look up the verifier in the second step.
39//!
40//! ### 2. Callback URL
41//!
//! After the user accepts the auth request at the provider, the provider redirects the user to the URL you configured there as ``redirect_url``. It is important that the same URL is set in the oauth-axum params; if it is not the same, an error will happen.
//! This redirect carries two query parameters, CODE and STATE. The token is generated from the code and verifier fields, which is why you need to save the verifier and state together in the first step.
//! After that, you will have a token to access the provider's API.
45//!
46//! ## Example
47//!
//! This method is for a small project that runs as a single instance of Axum. It saves the state and verifier in memory, where they can be read back when the callback URL is called.
49//!
//! ```rust,ignore
51//! mod utils;
52//! use std::sync::Arc;
53//!
54//! use axum::extract::Query;
55//! use axum::Router;
56//! use axum::{routing::get, Extension};
57//! use oauth_axum::providers::twitter::TwitterProvider;
58//! use oauth_axum::{CustomProvider, OAuthClient};
59//!
60//! use crate::utils::memory_db_util::AxumState;
61//!
62//! #[derive(Clone, serde::Deserialize)]
63//! pub struct QueryAxumCallback {
64//!     pub code: String,
65//!     pub state: String,
66//! }
67//!
68//! #[tokio::main]
69//! async fn main() {
70//!     dotenv::from_filename("examples/.env").ok();
71//!     println!("Starting server...");
72//!
73//!     let state = Arc::new(AxumState::new());
74//!     let app = Router::new()
75//!         .route("/", get(create_url))
76//!         .route("/api/v1/twitter/callback", get(callback))
77//!         .layer(Extension(state.clone()));
78//!
79//!     println!("🚀 Server started successfully");
80//!     let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
81//!         .await
82//!         .unwrap();
83//!     axum::serve(listener, app).await.unwrap();
84//! }
85//!
86//! fn get_client() -> CustomProvider {
87//!     TwitterProvider::new(
88//!         std::env::var("TWITTER_CLIENT_ID").expect("TWITTER_CLIENT_ID must be set"),
89//!         std::env::var("TWITTER_SECRET").expect("TWITTER_SECRET must be set"),
90//!         "http://localhost:3000/api/v1/twitter/callback".to_string(),
91//!     )
92//! }
93//!
94//! pub async fn create_url(Extension(state): Extension<Arc<AxumState>>) -> String {
95//!     let state_oauth = get_client()
96//!         .generate_url(
97//!             Vec::from(["users.read".to_string()]),
98//!             |state_e| async move {
99//!                 //SAVE THE DATA IN THE DB OR MEMORY
100//!                 //state should be your ID
101//!                 state.set(state_e.state, state_e.verifier);
102//!             },
103//!         )
104//!         .await
105//!         .ok()
106//!         .unwrap()
107//!         .state
108//!         .unwrap();
109//!
110//!     state_oauth.url_generated.unwrap()
111//! }
112//!
113//! pub async fn callback(
114//!     Extension(state): Extension<Arc<AxumState>>,
115//!     Query(queries): Query<QueryAxumCallback>,
116//! ) -> String {
117//!     println!("{:?}", state.clone().get_all_items());
118//!     // GET DATA FROM DB OR MEMORY
119//!     // get data using state as ID
120//!     let item = state.get(queries.state.clone());
121//!     get_client()
122//!         .generate_token(queries.code, item.unwrap())
123//!         .await
124//!         .ok()
//!         .unwrap()
126//! }
127//! ```
128//!
129//! # Next Steps of Development
130//!
131//! - Add all tests
132//! - Add more Providers
133//!
134
135pub mod error;
136pub mod providers;
137
138use async_trait::async_trait;
139use error::OauthError;
140use std::future::Future;
141
142use oauth2::reqwest::async_http_client;
143use oauth2::{
144    basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
145    Scope, TokenUrl,
146};
147use oauth2::{AuthorizationCode, PkceCodeVerifier, TokenResponse};
148
149#[derive(Clone)]
150pub struct CustomProvider {
151    pub auth_url: String,
152    pub token_url: String,
153    pub client_id: String,
154    pub client_secret: String,
155    pub redirect_url: String,
156    pub state: Option<StateAuth>,
157}
158
/// Selector with two variants, `DB` and `MEMORY`.
///
/// NOTE(review): nothing in this file reads this enum; presumably it names
/// where the state/verifier pair is persisted — confirm against callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodExecute {
    DB,
    MEMORY,
}
164
/// Values produced when an authorization URL is generated.
///
/// The `state` is meant to be stored together with the `verifier` so the
/// verifier can be looked up again when the provider calls back.
#[derive(Debug, Clone)]
pub struct StateAuth {
    /// The generated authorization URL, when one is available.
    pub url_generated: Option<String>,
    /// Secret of the CSRF token embedded in the authorization URL.
    pub state: String,
    /// Secret of the PKCE code verifier matching the URL's code challenge.
    pub verifier: String,
}
171
172impl CustomProvider {
173    pub fn new(
174        auth_url: String,
175        token_url: String,
176        client_id: String,
177        client_secret: String,
178        redirect_url: String,
179    ) -> Self {
180        CustomProvider {
181            auth_url,
182            token_url,
183            client_id,
184            client_secret,
185            redirect_url,
186            state: None,
187        }
188    }
189}
190
191/// OAuthClient is the main struct of the lib, it will handle all the connection with the provider
192#[async_trait]
193pub trait OAuthClient {
194    fn get_client(&self) -> Result<BasicClient, OauthError>;
195
196    /// Get fields data from generated URL
197    /// # Return
198    /// StateAuth - The state, verifier and url_generated
199    fn get_state(&self) -> Option<StateAuth>;
200
201    /// Generate the URL to redirect the user to the provider
202    /// # Arguments
203    /// * `scopes` - Vec<String> - The scopes that you want to access in the provider
204    /// * `save` - F - The function that will use to save your state in the db/memory
205    async fn generate_url<F, Fut>(
206        mut self,
207        scopes: Vec<String>,
208        save: F,
209    ) -> Result<Box<Self>, OauthError>
210    where
211        F: FnOnce(StateAuth) -> Fut + Send,
212        Fut: Future<Output = ()> + Send;
213
214    /// Generate the token from the code and verifier
215    /// # Arguments
216    /// * `code` - String - The code that the provider will return after the user accept the auth
217    /// * `verifier` - String - The verifier that was generated in the first step
218    /// # Return
219    /// The token generated
220    async fn generate_token(&self, code: String, verifier: String) -> Result<String, OauthError>;
221}
222
223#[async_trait]
224impl OAuthClient for CustomProvider {
225    fn get_client(&self) -> Result<BasicClient, OauthError> {
226        Ok(BasicClient::new(
227            ClientId::new(self.client_id.clone()),
228            Some(ClientSecret::new(self.client_secret.clone())),
229            AuthUrl::new(self.auth_url.clone()).map_err(|_| OauthError::AuthUrlCreationFailed)?,
230            Some(TokenUrl::new(self.token_url.clone()).unwrap()),
231        )
232        .set_redirect_uri(RedirectUrl::new(self.redirect_url.clone()).unwrap()))
233    }
234
235    fn get_state(&self) -> Option<StateAuth> {
236        self.state.clone()
237    }
238
239    async fn generate_url<F, Fut>(
240        mut self,
241        scopes: Vec<String>,
242        save: F,
243    ) -> Result<Box<Self>, OauthError>
244    where
245        F: FnOnce(StateAuth) -> Fut + Send,
246        Fut: Future<Output = ()> + Send,
247    {
248        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
249
250        let binding = self.get_client();
251        let (auth_url, csrf_token) = binding?
252            .authorize_url(CsrfToken::new_random)
253            .add_scopes(scopes.into_iter().map(Scope::new).collect::<Vec<Scope>>())
254            .set_pkce_challenge(pkce_challenge)
255            .url();
256
257        let state = StateAuth {
258            url_generated: Some(auth_url.to_string()),
259            state: csrf_token.secret().to_string(),
260            verifier: pkce_verifier.secret().to_string(),
261        };
262
263        self.state = Some(state.clone());
264        save(state).await;
265
266        Ok(Box::new(self.clone()))
267    }
268
269    async fn generate_token(&self, code: String, verifier: String) -> Result<String, OauthError> {
270        let token = self
271            .get_client()?
272            .exchange_code(AuthorizationCode::new(code.clone()))
273            .set_pkce_verifier(PkceCodeVerifier::new(verifier.clone()))
274            .request_async(async_http_client)
275            .await
276            .map_err(|_| OauthError::TokenRequestFailed)?;
277        Ok(token.access_token().secret().to_string())
278    }
279}