1use std::sync::Arc;
2
3use async_trait::async_trait;
4use shield::{Action, Method, ShieldError, User, erased_method};
5
6use crate::{
7 actions::{OidcSignInAction, OidcSignInCallbackAction, OidcSignOutAction},
8 options::OidcOptions,
9 provider::OidcProvider,
10 session::OidcSession,
11 storage::OidcStorage,
12};
13
/// Identifier of the OIDC authentication method, as returned by `OidcMethod::id`.
pub const OIDC_METHOD_ID: &str = "oidc";
15
/// Authentication method backed by OIDC providers.
///
/// Providers come from two places: the in-memory `providers` list and the
/// `storage` backend. Lookups consult the in-memory list before storage.
pub struct OidcMethod<U: User> {
    /// Options cloned into the sign-in callback action.
    options: OidcOptions,
    /// In-memory providers, searched by `id` before storage is queried.
    providers: Vec<OidcProvider>,
    /// Shared storage backend, used for provider lookups and handed to the
    /// sign-in callback action.
    storage: Arc<dyn OidcStorage<U>>,
}
21
22impl<U: User> OidcMethod<U> {
23 pub fn new<S: OidcStorage<U> + 'static>(storage: S) -> Self {
24 Self {
25 options: OidcOptions::default(),
26 providers: vec![],
27 storage: Arc::new(storage),
28 }
29 }
30
31 pub fn with_options(mut self, options: OidcOptions) -> Self {
32 self.options = options;
33 self
34 }
35
36 pub fn with_providers<I: IntoIterator<Item = OidcProvider>>(mut self, providers: I) -> Self {
37 self.providers = providers.into_iter().collect();
38 self
39 }
40
41 async fn oidc_provider_by_id_or_slug(
42 &self,
43 provider_id: &str,
44 ) -> Result<Option<OidcProvider>, ShieldError> {
45 if let Some(provider) = self
46 .providers
47 .iter()
48 .find(|provider| provider.id == provider_id)
49 {
50 return Ok(Some(provider.clone()));
51 }
52
53 if let Some(provider) = self
54 .storage
55 .oidc_provider_by_id_or_slug(provider_id)
56 .await?
57 {
58 return Ok(Some(provider));
59 }
60
61 Ok(None)
62 }
63}
64
65#[async_trait]
66impl<U: User + 'static> Method for OidcMethod<U> {
67 type Provider = OidcProvider;
68 type Session = OidcSession;
69
70 fn id(&self) -> String {
71 OIDC_METHOD_ID.to_owned()
72 }
73
74 fn actions(&self) -> Vec<Box<dyn Action<Self::Provider, Self::Session>>> {
75 vec![
76 Box::new(OidcSignInAction),
77 Box::new(OidcSignInCallbackAction::new(
78 self.options.clone(),
79 self.storage.clone(),
80 )),
81 Box::new(OidcSignOutAction),
82 ]
83 }
84
85 async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
86 Ok(self
87 .providers
88 .iter()
89 .cloned()
90 .chain(self.storage.oidc_providers().await?)
91 .collect())
92 }
93
94 async fn provider_by_id(
95 &self,
96 provider_id: Option<&str>,
97 ) -> Result<Option<Self::Provider>, ShieldError> {
98 if let Some(provider_id) = provider_id {
99 self.oidc_provider_by_id_or_slug(provider_id).await
100 } else {
101 Ok(None)
102 }
103 }
104}
105
// NOTE(review): presumably generates the type-erased counterpart of the
// `Method` impl for `OidcMethod<U>` — confirm against the definition of
// `shield::erased_method!`.
erased_method!(OidcMethod, <U: User>);