1use std::sync::Arc;
2
3use async_trait::async_trait;
4use shield::{Action, Method, ShieldError, User, erased_method};
5
6use crate::{
7 actions::{OauthSignInAction, OauthSignInCallbackAction, OauthSignOutAction},
8 options::OauthOptions,
9 provider::OauthProvider,
10 session::OauthSession,
11 storage::OauthStorage,
12};
13
/// Identifier of the OAuth method, returned (as an owned `String`) by
/// `OauthMethod`'s `Method::id` implementation.
pub const OAUTH_METHOD_ID: &str = "oauth";
15
/// OAuth authentication method.
///
/// Resolves providers from two sources: a list configured in code (see
/// `OauthMethod::with_providers`) and an `OauthStorage` backend. Configured providers are
/// consulted first; storage is the fallback.
pub struct OauthMethod<U: User> {
    /// Method options; cloned into the sign-in callback action built by `Method::actions`.
    options: OauthOptions,
    /// Providers configured in code; searched before `storage` when resolving a provider id.
    providers: Vec<OauthProvider>,
    /// Storage backend; the `Arc` is shared with the sign-in callback action.
    storage: Arc<dyn OauthStorage<U>>,
}
21
22impl<U: User> OauthMethod<U> {
23 pub fn new<S: OauthStorage<U> + 'static>(storage: S) -> Self {
24 Self {
25 options: OauthOptions::default(),
26 providers: vec![],
27 storage: Arc::new(storage),
28 }
29 }
30
31 pub fn with_options(mut self, options: OauthOptions) -> Self {
32 self.options = options;
33 self
34 }
35
36 pub fn with_providers<I: IntoIterator<Item = OauthProvider>>(mut self, providers: I) -> Self {
37 self.providers = providers.into_iter().collect();
38 self
39 }
40
41 async fn oauth_provider_by_id_or_slug(
42 &self,
43 provider_id: &str,
44 ) -> Result<Option<OauthProvider>, ShieldError> {
45 if let Some(provider) = self
46 .providers
47 .iter()
48 .find(|provider| provider.id == provider_id)
49 {
50 return Ok(Some(provider.clone()));
51 }
52
53 if let Some(provider) = self
54 .storage
55 .oauth_provider_by_id_or_slug(provider_id)
56 .await?
57 {
58 return Ok(Some(provider));
59 }
60
61 Ok(None)
62 }
63}
64
65#[async_trait]
66impl<U: User + 'static> Method for OauthMethod<U> {
67 type Provider = OauthProvider;
68 type Session = OauthSession;
69
70 fn id(&self) -> String {
71 OAUTH_METHOD_ID.to_owned()
72 }
73
74 fn actions(&self) -> Vec<Box<dyn Action<Self::Provider, Self::Session>>> {
75 vec![
76 Box::new(OauthSignInAction),
77 Box::new(OauthSignInCallbackAction::new(
78 self.options.clone(),
79 self.storage.clone(),
80 )),
81 Box::new(OauthSignOutAction),
82 ]
83 }
84
85 async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
86 Ok(self
87 .providers
88 .iter()
89 .cloned()
90 .chain(self.storage.oauth_providers().await?)
91 .collect())
92 }
93
94 async fn provider_by_id(
95 &self,
96 provider_id: Option<&str>,
97 ) -> Result<Option<Self::Provider>, ShieldError> {
98 if let Some(provider_id) = provider_id {
99 self.oauth_provider_by_id_or_slug(provider_id).await
100 } else {
101 Ok(None)
102 }
103 }
104}
105
// Invokes shield's `erased_method!` macro for `OauthMethod`. Presumably this generates the
// type-erased wrapper used to register the method; confirm against the macro definition.
// NOTE(review): the macro is given `<U: User>`, while the `Method` impl above requires
// `U: User + 'static`. Verify whether the macro expansion needs the `'static` bound too.
erased_method!(OauthMethod, <U: User>);