oauth_axum/lib.rs
1//! # oauth-axum
2//!
//! This crate wraps the `oauth2` crate with the provider configuration already done, making it easy to integrate into your Axum project.
//! The goal is to support every provider from this list that offers OAuth2: <https://en.wikipedia.org/wiki/List_of_OAuth_providers>.
5//!
6//! # Usage
7//!
//! Using it is simple: create a new instance of one of these providers:
9//!
10//! - CustomProvider
11//! - GithubProvider
12//! - DiscordProvider
13//! - TwitterProvider
14//! - GoogleProvider
15//! - MicrosoftProvider
16//! - FacebookProvider
17//! - SpotifyProvider
18//!
//! and pass the following arguments to its `new` function:
20//!
//! - **client_id:** The unique ID of the app you created with the provider
//! - **client_secret:** The secret of your app at the provider; keep it hidden from your users
//! - **redirect_url:** The URL of your backend endpoint that receives the provider's redirect
24//!
//! If you are using **`CustomProvider`**, you also need to pass:
26//!
//! - **auth_url:** The provider's authorization URL, where the user grants your app access to their account
//! - **token_url:** The provider's URL used to exchange the authorization code for a token
29//!
//! The flow is split into two steps:
31//!
32//! ### 1. Generate the URL
33//!
//! This step creates the URL that sends the user to the provider, where they authorize your app to access their information.
35//!
//! The URL has this format (GitHub example): `https://github.com/login/oauth/authorize?response_type=code&client_id={CLIENT_ID}&state={RANDOM_STATE}&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&redirect_uri={REDIRECT_URL}&scope={SCOPES}`
37//!
//! This step matters because it generates the VERIFIER. You must store the verifier somewhere (memory, a database, ...) together with the STATE value; the state acts as the key used to look up the verifier in the second step.
39//!
40//! ### 2. Callback URL
41//!
//! After the user authorizes your app, the provider redirects them to the URL configured in the provider's settings (`redirect_url`). That URL must be the same one passed to oauth-axum; otherwise an error will occur.
//! The redirect carries two query parameters, CODE and STATE. The token is generated from the code and the verifier, which is why the first step must store the verifier together with the state.
//! After that, you have a token to access the provider's API.
45//!
46//! ## Example
47//!
//! This approach suits a small project running a single Axum instance. It keeps the state and verifier in memory, where the callback handler can read them.
49//!
//! ```rust,ignore
51//! mod utils;
52//! use std::sync::Arc;
53//!
54//! use axum::extract::Query;
55//! use axum::Router;
56//! use axum::{routing::get, Extension};
57//! use oauth_axum::providers::twitter::TwitterProvider;
58//! use oauth_axum::{CustomProvider, OAuthClient};
59//!
60//! use crate::utils::memory_db_util::AxumState;
61//!
62//! #[derive(Clone, serde::Deserialize)]
63//! pub struct QueryAxumCallback {
64//! pub code: String,
65//! pub state: String,
66//! }
67//!
68//! #[tokio::main]
69//! async fn main() {
70//! dotenv::from_filename("examples/.env").ok();
71//! println!("Starting server...");
72//!
73//! let state = Arc::new(AxumState::new());
74//! let app = Router::new()
75//! .route("/", get(create_url))
76//! .route("/api/v1/twitter/callback", get(callback))
77//! .layer(Extension(state.clone()));
78//!
79//! println!("🚀 Server started successfully");
80//! let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
81//! .await
82//! .unwrap();
83//! axum::serve(listener, app).await.unwrap();
84//! }
85//!
86//! fn get_client() -> CustomProvider {
87//! TwitterProvider::new(
88//! std::env::var("TWITTER_CLIENT_ID").expect("TWITTER_CLIENT_ID must be set"),
89//! std::env::var("TWITTER_SECRET").expect("TWITTER_SECRET must be set"),
90//! "http://localhost:3000/api/v1/twitter/callback".to_string(),
91//! )
92//! }
93//!
94//! pub async fn create_url(Extension(state): Extension<Arc<AxumState>>) -> String {
95//! let state_oauth = get_client()
96//! .generate_url(
97//! Vec::from(["users.read".to_string()]),
98//! |state_e| async move {
99//! //SAVE THE DATA IN THE DB OR MEMORY
100//! //state should be your ID
101//! state.set(state_e.state, state_e.verifier);
102//! },
103//! )
104//! .await
105//! .ok()
106//! .unwrap()
107//! .state
108//! .unwrap();
109//!
110//! state_oauth.url_generated.unwrap()
111//! }
112//!
113//! pub async fn callback(
114//! Extension(state): Extension<Arc<AxumState>>,
115//! Query(queries): Query<QueryAxumCallback>,
116//! ) -> String {
117//! println!("{:?}", state.clone().get_all_items());
118//! // GET DATA FROM DB OR MEMORY
119//! // get data using state as ID
120//! let item = state.get(queries.state.clone());
121//! get_client()
122//! .generate_token(queries.code, item.unwrap())
123//! .await
124//! .ok()
125//! .unwrap()
126//! }
127//! ```
128//!
129//! # Next Steps of Development
//!
//! - Add tests
132//! - Add more Providers
133//!
134
135pub mod error;
136pub mod providers;
137
138use async_trait::async_trait;
139use error::OauthError;
140use std::future::Future;
141
142use oauth2::reqwest::async_http_client;
143use oauth2::{
144 basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
145 Scope, TokenUrl,
146};
147use oauth2::{AuthorizationCode, PkceCodeVerifier, TokenResponse};
148
149#[derive(Clone)]
150pub struct CustomProvider {
151 pub auth_url: String,
152 pub token_url: String,
153 pub client_id: String,
154 pub client_secret: String,
155 pub redirect_url: String,
156 pub state: Option<StateAuth>,
157}
158
/// Storage backend selector: database or in-memory.
///
/// NOTE(review): this enum is not referenced in this file; presumably it
/// describes where the state/verifier pair is persisted — confirm against callers.
///
/// It is a plain unit-only enum, so `Copy`, `Debug` and equality are derived to make
/// it usable in comparisons and diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodExecute {
    /// Persist in a database.
    DB,
    /// Keep in process memory.
    MEMORY,
}
164
/// Data produced when an authorization URL is generated.
///
/// `state` is the key the caller stores `verifier` under, so it can be looked up
/// again when the provider calls back with the same `state`.
#[derive(Debug, Clone)]
pub struct StateAuth {
    /// Authorization URL the user should be redirected to, when one was generated.
    pub url_generated: Option<String>,
    /// CSRF state value sent to the provider and echoed back on the callback.
    pub state: String,
    /// PKCE code verifier required later to exchange the code for a token.
    pub verifier: String,
}
171
172impl CustomProvider {
173 pub fn new(
174 auth_url: String,
175 token_url: String,
176 client_id: String,
177 client_secret: String,
178 redirect_url: String,
179 ) -> Self {
180 CustomProvider {
181 auth_url,
182 token_url,
183 client_id,
184 client_secret,
185 redirect_url,
186 state: None,
187 }
188 }
189}
190
191/// OAuthClient is the main struct of the lib, it will handle all the connection with the provider
192#[async_trait]
193pub trait OAuthClient {
194 fn get_client(&self) -> Result<BasicClient, OauthError>;
195
196 /// Get fields data from generated URL
197 /// # Return
198 /// StateAuth - The state, verifier and url_generated
199 fn get_state(&self) -> Option<StateAuth>;
200
201 /// Generate the URL to redirect the user to the provider
202 /// # Arguments
203 /// * `scopes` - Vec<String> - The scopes that you want to access in the provider
204 /// * `save` - F - The function that will use to save your state in the db/memory
205 async fn generate_url<F, Fut>(
206 mut self,
207 scopes: Vec<String>,
208 save: F,
209 ) -> Result<Box<Self>, OauthError>
210 where
211 F: FnOnce(StateAuth) -> Fut + Send,
212 Fut: Future<Output = ()> + Send;
213
214 /// Generate the token from the code and verifier
215 /// # Arguments
216 /// * `code` - String - The code that the provider will return after the user accept the auth
217 /// * `verifier` - String - The verifier that was generated in the first step
218 /// # Return
219 /// The token generated
220 async fn generate_token(&self, code: String, verifier: String) -> Result<String, OauthError>;
221}
222
223#[async_trait]
224impl OAuthClient for CustomProvider {
225 fn get_client(&self) -> Result<BasicClient, OauthError> {
226 Ok(BasicClient::new(
227 ClientId::new(self.client_id.clone()),
228 Some(ClientSecret::new(self.client_secret.clone())),
229 AuthUrl::new(self.auth_url.clone()).map_err(|_| OauthError::AuthUrlCreationFailed)?,
230 Some(TokenUrl::new(self.token_url.clone()).unwrap()),
231 )
232 .set_redirect_uri(RedirectUrl::new(self.redirect_url.clone()).unwrap()))
233 }
234
235 fn get_state(&self) -> Option<StateAuth> {
236 self.state.clone()
237 }
238
239 async fn generate_url<F, Fut>(
240 mut self,
241 scopes: Vec<String>,
242 save: F,
243 ) -> Result<Box<Self>, OauthError>
244 where
245 F: FnOnce(StateAuth) -> Fut + Send,
246 Fut: Future<Output = ()> + Send,
247 {
248 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
249
250 let binding = self.get_client();
251 let (auth_url, csrf_token) = binding?
252 .authorize_url(CsrfToken::new_random)
253 .add_scopes(scopes.into_iter().map(Scope::new).collect::<Vec<Scope>>())
254 .set_pkce_challenge(pkce_challenge)
255 .url();
256
257 let state = StateAuth {
258 url_generated: Some(auth_url.to_string()),
259 state: csrf_token.secret().to_string(),
260 verifier: pkce_verifier.secret().to_string(),
261 };
262
263 self.state = Some(state.clone());
264 save(state).await;
265
266 Ok(Box::new(self.clone()))
267 }
268
269 async fn generate_token(&self, code: String, verifier: String) -> Result<String, OauthError> {
270 let token = self
271 .get_client()?
272 .exchange_code(AuthorizationCode::new(code.clone()))
273 .set_pkce_verifier(PkceCodeVerifier::new(verifier.clone()))
274 .request_async(async_http_client)
275 .await
276 .map_err(|_| OauthError::TokenRequestFailed)?;
277 Ok(token.access_token().secret().to_string())
278 }
279}