// oxide_auth_db/primitives/db_registrar.rs

1use std::borrow::Cow;
2use std::iter::Extend;
3use once_cell::sync::Lazy;
4use oxide_auth::primitives::registrar::{
5    Argon2, BoundClient, Client, EncodedClient, PasswordPolicy, RegisteredClient, Registrar,
6    RegistrarError,
7};
8use oxide_auth::primitives::prelude::{ClientUrl, PreGrant, Scope};
9use crate::db_service::DataSource;
10use r2d2_redis::redis::RedisError;
11
12/// A database client service which implemented Registrar.
13/// db: repository service to query stored clients or regist new client.
14/// password_policy: to encode client_secret.
15pub struct DBRegistrar {
16    pub repo: DataSource,
17    password_policy: Option<Box<dyn PasswordPolicy>>,
18}
19
20/// methods to search and regist clients from DataSource.
21/// which should be implemented for all DataSource type.
22pub trait OauthClientDBRepository {
23    fn list(&self) -> anyhow::Result<Vec<EncodedClient>>;
24
25    fn find_client_by_id(&self, id: &str) -> anyhow::Result<EncodedClient>;
26
27    fn regist_from_encoded_client(&self, client: EncodedClient) -> anyhow::Result<()>;
28}
29
30///////////////////////////////////////////////////////////////////////////////////////////////////
31//                             Implementations of DB Registrars                                  //
32///////////////////////////////////////////////////////////////////////////////////////////////////
33
34static DEFAULT_PASSWORD_POLICY: Lazy<Argon2> = Lazy::new(|| Argon2::default());
35
36impl DBRegistrar {
37    /// Create an DB connection recording to features.
38    pub fn new(url: String, max_pool_size: u32, client_prefix: String) -> Result<Self, RedisError> {
39        let repo = DataSource::new(url, max_pool_size, client_prefix)?;
40        Ok(DBRegistrar {
41            repo,
42            password_policy: None,
43        })
44    }
45
46    /// Insert or update the client record.
47    pub fn register_client(&mut self, client: Client) -> Result<(), RegistrarError> {
48        let password_policy = Self::current_policy(&self.password_policy);
49        let encoded_client = client.encode(password_policy);
50
51        self.repo
52            .regist_from_encoded_client(encoded_client)
53            .map_err(|_e| RegistrarError::Unspecified)
54    }
55
56    /// Change how passwords are encoded while stored.
57    pub fn set_password_policy<P: PasswordPolicy + 'static>(&mut self, new_policy: P) {
58        self.password_policy = Some(Box::new(new_policy))
59    }
60
61    // This is not an instance method because it needs to borrow the box but register needs &mut
62    fn current_policy<'a>(policy: &'a Option<Box<dyn PasswordPolicy>>) -> &'a dyn PasswordPolicy {
63        policy
64            .as_ref()
65            .map(|boxed| &**boxed)
66            .unwrap_or(&*DEFAULT_PASSWORD_POLICY)
67    }
68}
69
70impl Extend<Client> for DBRegistrar {
71    fn extend<I>(&mut self, iter: I)
72    where
73        I: IntoIterator<Item = Client>,
74    {
75        iter.into_iter().for_each(|client| {
76            let _ = self.register_client(client);
77        })
78    }
79}
80
81impl Registrar for DBRegistrar {
82    fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
83        let client = match self.repo.find_client_by_id(bound.client_id.as_ref()) {
84            Ok(detail) => detail,
85            _ => return Err(RegistrarError::Unspecified),
86        };
87        // Perform exact matching as motivated in the rfc
88        let registered_url = match bound.redirect_uri {
89            None => client.redirect_uri.clone(),
90            Some(ref url) => {
91                let original = std::iter::once(&client.redirect_uri);
92                let alternatives = client.additional_redirect_uris.iter();
93                if let Some(registered) = original
94                    .chain(alternatives)
95                    .find(|&registered| *registered == *url.as_ref())
96                {
97                    registered.clone()
98                } else {
99                    return Err(RegistrarError::Unspecified);
100                }
101            }
102        };
103        Ok(BoundClient {
104            client_id: bound.client_id,
105            redirect_uri: Cow::Owned(registered_url),
106        })
107    }
108
109    fn negotiate<'a>(
110        &self, bound: BoundClient<'a>, _scope: Option<Scope>,
111    ) -> Result<PreGrant, RegistrarError> {
112        let client = self
113            .repo
114            .find_client_by_id(&bound.client_id)
115            .map_err(|_e| RegistrarError::Unspecified)?;
116        Ok(PreGrant {
117            client_id: bound.client_id.into_owned(),
118            redirect_uri: bound.redirect_uri.into_owned(),
119            scope: client.default_scope,
120        })
121    }
122
123    fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
124        let password_policy = Self::current_policy(&self.password_policy);
125
126        let client = self
127            .repo
128            .find_client_by_id(client_id)
129            .map_err(|_e| RegistrarError::Unspecified);
130        client.and_then(|op_client| {
131            RegisteredClient::new(&op_client, password_policy).check_authentication(passphrase)
132        })?;
133        Ok(())
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use oxide_auth::primitives::registrar::{ExactUrl, RegisteredUrl};
    use std::str::FromStr;

    /// Registrar backed by the Redis test database (db 3, `client:` key prefix).
    fn redis_registrar() -> DBRegistrar {
        DBRegistrar::new(
            String::from("redis://localhost/3"),
            32,
            String::from("client:"),
        )
        .unwrap()
    }

    /// Exact-match registered URL for `url`.
    fn exact_url(url: &str) -> RegisteredUrl {
        RegisteredUrl::from(ExactUrl::new(url.parse().unwrap()).unwrap())
    }

    #[test]
    fn public_client() {
        let policy = Argon2::default();
        let client = Client::public(
            "ClientId",
            RegisteredUrl::Exact(ExactUrl::from_str("https://example.com").unwrap()),
            "default".parse().unwrap(),
        )
        .encode(&policy);
        let client = RegisteredClient::new(&client, &policy);

        // Providing no authentication data is ok
        assert!(client.check_authentication(None).is_ok());
        // Any authentication data is a fail
        assert!(client.check_authentication(Some(b"")).is_err());
    }

    #[test]
    fn confidential_client() {
        let policy = Argon2::default();
        let pass = b"AB3fAj6GJpdxmEVeNCyPoA==";
        let client = Client::confidential(
            "ClientId",
            RegisteredUrl::Exact(ExactUrl::from_str("https://example.com").unwrap()),
            "default".parse().unwrap(),
            pass,
        )
        .encode(&policy);
        let client = RegisteredClient::new(&client, &policy);
        assert!(client.check_authentication(None).is_err());
        assert!(client.check_authentication(Some(pass)).is_ok());
        assert!(client.check_authentication(Some(b"not the passphrase")).is_err());
        assert!(client.check_authentication(Some(b"")).is_err());
    }

    #[test]
    fn with_additional_redirect_uris() {
        if crate::requires_redis_and_should_skip() {
            return;
        }

        let client_id = "ClientId";
        let default_scope = "default-scope".parse().unwrap();
        let client = Client::public(client_id, exact_url("https://example.com/foo"), default_scope)
            .with_additional_redirect_uris(vec![exact_url("https://example.com/bar")]);
        let mut db_registrar = redis_registrar();
        db_registrar.register_client(client).unwrap();

        // Without an explicit redirect_uri the primary registered URI is used.
        assert_eq!(
            db_registrar
                .bound_redirect(ClientUrl {
                    client_id: Cow::from(client_id),
                    redirect_uri: None,
                })
                .unwrap()
                .redirect_uri,
            Cow::Owned::<RegisteredUrl>(exact_url("https://example.com/foo"))
        );

        // The primary URI matches when requested explicitly.
        assert_eq!(
            db_registrar
                .bound_redirect(ClientUrl {
                    client_id: Cow::from(client_id),
                    redirect_uri: Some(Cow::Borrowed(&"https://example.com/foo".parse().unwrap()))
                })
                .unwrap()
                .redirect_uri,
            Cow::Owned::<RegisteredUrl>(exact_url("https://example.com/foo"))
        );

        // An additional URI matches too.
        assert_eq!(
            db_registrar
                .bound_redirect(ClientUrl {
                    client_id: Cow::from(client_id),
                    redirect_uri: Some(Cow::Borrowed(&"https://example.com/bar".parse().unwrap()))
                })
                .unwrap()
                .redirect_uri,
            Cow::Owned::<RegisteredUrl>(exact_url("https://example.com/bar"))
        );

        // An unregistered URI is rejected.
        assert!(db_registrar
            .bound_redirect(ClientUrl {
                client_id: Cow::from(client_id),
                redirect_uri: Some(Cow::Borrowed(&"https://example.com/baz".parse().unwrap()))
            })
            .is_err());
    }

    #[test]
    fn client_service() {
        if crate::requires_redis_and_should_skip() {
            return;
        }

        let mut oauth_service = redis_registrar();
        let public_id = "PublicClientId";
        let client_url = "https://example.com";

        let private_id = "PrivateClientId";
        let private_passphrase = b"WOJJCcS8WyS2aGmJK6ZADg==";

        let public_client = Client::public(
            public_id,
            exact_url(client_url),
            "default".parse().unwrap(),
        );

        oauth_service.register_client(public_client).unwrap();
        oauth_service
            .check(public_id, None)
            .expect("Authorization of public client has changed");
        oauth_service
            .check(public_id, Some(b""))
            .err()
            .expect("Authorization with password succeeded");

        let private_client = Client::confidential(
            private_id,
            exact_url(client_url),
            "default".parse().unwrap(),
            private_passphrase,
        );

        oauth_service.register_client(private_client).unwrap();

        oauth_service
            .check(private_id, Some(private_passphrase))
            .expect("Authorization with right password did not succeed");
        oauth_service
            .check(private_id, Some(b"Not the private passphrase"))
            .err()
            .expect("Authorization succeed with wrong password");
    }
}