axum_oidc/
lib.rs

1#![deny(unsafe_code)]
2#![deny(clippy::unwrap_used)]
3#![deny(warnings)]
4#![doc = include_str!("../README.md")]
5
6use crate::error::Error;
7use http::Uri;
8use openidconnect::{
9    core::{
10        CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod,
11        CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType,
12        CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm,
13        CoreJwsSigningAlgorithm, CoreResponseMode, CoreResponseType, CoreRevocableToken,
14        CoreRevocationErrorResponse, CoreSubjectIdentifierType, CoreTokenIntrospectionResponse,
15        CoreTokenType,
16    },
17    AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, HttpRequest,
18    HttpResponse, IdTokenFields, IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken,
19    StandardErrorResponse, StandardTokenResponse,
20};
21use serde::{de::DeserializeOwned, Deserialize, Serialize};
22
23pub mod error;
24mod extractor;
25mod middleware;
26
27pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout};
28pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
29
/// Key under which this crate keeps its state in the user's session.
///
/// NOTE(review): the reads/writes happen in the middleware/extractor modules (not shown);
/// presumably the stored value is an `OidcSession`.
const SESSION_KEY: &str = "axum-oidc";
31
32pub trait AdditionalClaims:
33    openidconnect::AdditionalClaims + Clone + Sync + Send + Serialize + DeserializeOwned
34{
35}
36
37type OidcTokenResponse<AC> = StandardTokenResponse<
38    IdTokenFields<
39        AC,
40        EmptyExtraTokenFields,
41        CoreGenderClaim,
42        CoreJweContentEncryptionAlgorithm,
43        CoreJwsSigningAlgorithm,
44        CoreJsonWebKeyType,
45    >,
46    CoreTokenType,
47>;
48
49pub type IdToken<AZ> = openidconnect::IdToken<
50    AZ,
51    CoreGenderClaim,
52    CoreJweContentEncryptionAlgorithm,
53    CoreJwsSigningAlgorithm,
54    CoreJsonWebKeyType,
55>;
56
57type Client<AC> = openidconnect::Client<
58    AC,
59    CoreAuthDisplay,
60    CoreGenderClaim,
61    CoreJweContentEncryptionAlgorithm,
62    CoreJwsSigningAlgorithm,
63    CoreJsonWebKeyType,
64    CoreJsonWebKeyUse,
65    CoreJsonWebKey,
66    CoreAuthPrompt,
67    StandardErrorResponse<CoreErrorResponseType>,
68    OidcTokenResponse<AC>,
69    CoreTokenType,
70    CoreTokenIntrospectionResponse,
71    CoreRevocableToken,
72    CoreRevocationErrorResponse,
73>;
74
75pub type ProviderMetadata = openidconnect::ProviderMetadata<
76    AdditionalProviderMetadata,
77    CoreAuthDisplay,
78    CoreClientAuthMethod,
79    CoreClaimName,
80    CoreClaimType,
81    CoreGrantType,
82    CoreJweContentEncryptionAlgorithm,
83    CoreJweKeyManagementAlgorithm,
84    CoreJwsSigningAlgorithm,
85    CoreJsonWebKeyType,
86    CoreJsonWebKeyUse,
87    CoreJsonWebKey,
88    CoreResponseMode,
89    CoreResponseType,
90    CoreSubjectIdentifierType,
91>;
92
/// A boxed, type-erased error that can be sent to and shared between threads.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
94
95/// OpenID Connect Client
96#[derive(Clone)]
97pub struct OidcClient<AC: AdditionalClaims> {
98    scopes: Vec<String>,
99    client_id: String,
100    client: Client<AC>,
101    http_client: reqwest::Client,
102    application_base_url: Uri,
103    end_session_endpoint: Option<Uri>,
104}
105
106impl<AC: AdditionalClaims> OidcClient<AC> {
107    /// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
108    pub fn from_provider_metadata(
109        provider_metadata: ProviderMetadata,
110        application_base_url: Uri,
111        client_id: String,
112        client_secret: Option<String>,
113        scopes: Vec<String>,
114    ) -> Result<Self, Error> {
115        let end_session_endpoint = provider_metadata
116            .additional_metadata()
117            .end_session_endpoint
118            .clone()
119            .map(Uri::from_maybe_shared)
120            .transpose()
121            .map_err(Error::InvalidEndSessionEndpoint)?;
122        let client = Client::from_provider_metadata(
123            provider_metadata,
124            ClientId::new(client_id.clone()),
125            client_secret.map(ClientSecret::new),
126        );
127        Ok(Self {
128            scopes,
129            client,
130            client_id,
131            application_base_url,
132            end_session_endpoint,
133            http_client: reqwest::Client::default(),
134        })
135    }
136    /// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
137    pub fn from_provider_metadata_and_client(
138        provider_metadata: ProviderMetadata,
139        application_base_url: Uri,
140        client_id: String,
141        client_secret: Option<String>,
142        scopes: Vec<String>,
143        http_client: reqwest::Client,
144    ) -> Result<Self, Error> {
145        let end_session_endpoint = provider_metadata
146            .additional_metadata()
147            .end_session_endpoint
148            .clone()
149            .map(Uri::from_maybe_shared)
150            .transpose()
151            .map_err(Error::InvalidEndSessionEndpoint)?;
152        let client = Client::from_provider_metadata(
153            provider_metadata,
154            ClientId::new(client_id.clone()),
155            client_secret.map(ClientSecret::new),
156        );
157        Ok(Self {
158            scopes,
159            client,
160            client_id,
161            application_base_url,
162            end_session_endpoint,
163            http_client,
164        })
165    }
166
167    /// create a new [`OidcClient`] by fetching the required information from the
168    /// `/.well-known/openid-configuration` endpoint of the issuer.
169    pub async fn discover_new(
170        application_base_url: Uri,
171        issuer: String,
172        client_id: String,
173        client_secret: Option<String>,
174        scopes: Vec<String>,
175    ) -> Result<Self, Error> {
176        let client = reqwest::Client::default();
177        Self::discover_new_with_client(
178            application_base_url,
179            issuer,
180            client_id,
181            client_secret,
182            scopes,
183            &client,
184        )
185        .await
186    }
187
188    /// create a new [`OidcClient`] by fetching the required information from the
189    /// `/.well-known/openid-configuration` endpoint of the issuer using the provided
190    /// `reqwest::Client`.
191    pub async fn discover_new_with_client(
192        application_base_url: Uri,
193        issuer: String,
194        client_id: String,
195        client_secret: Option<String>,
196        scopes: Vec<String>,
197        //TODO remove borrow with next breaking version
198        client: &reqwest::Client,
199    ) -> Result<Self, Error> {
200        // modified version of `openidconnect::reqwest::async_client::async_http_client`.
201        let async_http_client = |request: HttpRequest| async move {
202            let mut request_builder = client
203                .request(request.method, request.url.as_str())
204                .body(request.body);
205            for (name, value) in &request.headers {
206                request_builder = request_builder.header(name.as_str(), value.as_bytes());
207            }
208            let request = request_builder
209                .build()
210                .map_err(openidconnect::reqwest::Error::Reqwest)?;
211
212            let response = client
213                .execute(request)
214                .await
215                .map_err(openidconnect::reqwest::Error::Reqwest)?;
216
217            let status_code = response.status();
218            let headers = response.headers().to_owned();
219            let chunks = response
220                .bytes()
221                .await
222                .map_err(openidconnect::reqwest::Error::Reqwest)?;
223            Ok(HttpResponse {
224                status_code,
225                headers,
226                body: chunks.to_vec(),
227            })
228        };
229
230        let provider_metadata =
231            ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
232        Self::from_provider_metadata_and_client(
233            provider_metadata,
234            application_base_url,
235            client_id,
236            client_secret,
237            scopes,
238            client.clone(),
239        )
240    }
241}
242
243/// an empty struct to be used as the default type for the additional claims generic
244#[derive(Deserialize, Serialize, Debug, Clone, Copy, Default)]
245pub struct EmptyAdditionalClaims {}
246impl AdditionalClaims for EmptyAdditionalClaims {}
247impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {}
248
249/// response data of the openid issuer after login
250#[derive(Debug, Deserialize)]
251struct OidcQuery {
252    code: String,
253    state: String,
254    #[allow(dead_code)]
255    session_state: Option<String>,
256}
257
258/// oidc session
259#[derive(Serialize, Deserialize, Debug)]
260#[serde(bound = "AC: Serialize + DeserializeOwned")]
261struct OidcSession<AC: AdditionalClaims> {
262    nonce: Nonce,
263    csrf_token: CsrfToken,
264    pkce_verifier: PkceCodeVerifier,
265    authenticated: Option<AuthenticatedSession<AC>>,
266    refresh_token: Option<RefreshToken>,
267}
268
269#[derive(Serialize, Deserialize, Debug)]
270#[serde(bound = "AC: Serialize + DeserializeOwned")]
271struct AuthenticatedSession<AC: AdditionalClaims> {
272    id_token: IdToken<AC>,
273    access_token: AccessToken,
274}
275
276/// additional metadata that is discovered on client creation via the
277/// `.well-knwon/openid-configuration` endpoint.
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct AdditionalProviderMetadata {
280    end_session_endpoint: Option<String>,
281}
282impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {}
283
/// Response extension flag that signals the [`OidcAuthLayer`] that the session should be
/// cleared.
///
/// It carries no data; its presence in the response extensions is the signal.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ClearSessionFlag;