// oxide_auth_db/db_service/redis.rs

1use crate::primitives::db_registrar::OauthClientDBRepository;
2
3use oxide_auth::primitives::prelude::Scope;
4use oxide_auth::primitives::registrar::{ClientType, EncodedClient, RegisteredUrl, ExactUrl};
5
6use r2d2_redis::r2d2::Pool;
7use r2d2_redis::redis::{Commands, RedisError, ErrorKind};
8use r2d2_redis::RedisConnectionManager;
9use std::str::FromStr;
10use serde::{Serialize, Deserialize};
11use url::Url;
12
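// Formerly a hard-coded constant ("TODO: parameterize"); the client key prefix is now
// passed to `RedisDataSource::new` and kept in the `client_prefix` field:
// pub const CLIENT_PREFIX: &str = "client:";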
15
16/// redis datasource to Client entries.
17#[derive(Debug, Clone)]
18pub struct RedisDataSource {
19    url: String,
20    pool: Pool<RedisConnectionManager>,
21    client_prefix: String,
22}
23
24/// A client whose credentials have been wrapped by a password policy.
25///
26/// This provides a standard encoding for `Registrars` who wish to store their clients and makes it
27/// possible to test password policies.
28#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct StringfiedEncodedClient {
30    /// The id of this client. If this is was registered at a `Registrar`, this should be a key
31    /// to the instance.
32    pub client_id: String,
33
34    /// The registered redirect uri.
35    /// Unlike `additional_redirect_uris`, this is registered as the default redirect uri
36    /// and will be replaced if, for example, no `redirect_uri` is specified in the request parameter.
37    pub redirect_uri: String,
38
39    /// The redirect uris that can be registered in addition to the `redirect_uri`.
40    /// If you want to register multiple redirect uris, register them together with `redirect_uri`.
41    pub additional_redirect_uris: Vec<String>,
42
43    /// The scope the client gets if none was given.
44    pub default_scope: Option<String>,
45
46    /// client_secret, for authentication.
47    pub client_secret: Option<String>,
48}
49
50impl StringfiedEncodedClient {
51    pub fn to_encoded_client(&self) -> anyhow::Result<EncodedClient> {
52        let redirect_uri = RegisteredUrl::from(ExactUrl::from_str(&self.redirect_uri)?);
53        let uris = &self.additional_redirect_uris;
54        let additional_redirect_uris = uris.iter().fold(vec![], |mut us, u| {
55            us.push(RegisteredUrl::from(ExactUrl::from_str(u).unwrap()));
56            us
57        });
58
59        let client_type = match &self.client_secret {
60            None => ClientType::Public,
61            Some(secret) => ClientType::Confidential {
62                passdata: secret.to_owned().into_bytes(),
63            },
64        };
65
66        Ok(EncodedClient {
67            client_id: (&self.client_id).parse().unwrap(),
68            redirect_uri,
69            additional_redirect_uris,
70            default_scope: Scope::from_str(
71                self.default_scope.as_ref().unwrap_or(&"".to_string()).as_ref(),
72            )
73            .unwrap(),
74            encoded_client: client_type,
75        })
76    }
77
78    pub fn from_encoded_client(encoded_client: &EncodedClient) -> Self {
79        let additional_redirect_uris = encoded_client
80            .additional_redirect_uris
81            .iter()
82            .map(|u| u.to_owned().as_str().parse().unwrap())
83            .collect();
84        let default_scope = Some(encoded_client.default_scope.to_string());
85        let client_secret = match &encoded_client.encoded_client {
86            ClientType::Public => None,
87            ClientType::Confidential { passdata } => Some(String::from_utf8(passdata.to_vec()).unwrap()),
88        };
89        StringfiedEncodedClient {
90            client_id: encoded_client.client_id.to_owned(),
91            redirect_uri: encoded_client.redirect_uri.to_owned().as_str().parse().unwrap(),
92            additional_redirect_uris,
93            default_scope,
94            client_secret,
95        }
96    }
97}
98
99impl RedisDataSource {
100    pub fn new(url: String, max_pool_size: u32, client_prefix: String) -> Result<Self, RedisError> {
101        let manager = r2d2_redis::RedisConnectionManager::new(url.as_str())?;
102        let pool = Pool::builder().max_size(max_pool_size).build(manager);
103        match pool {
104            Ok(pool) => Ok(RedisDataSource {
105                url,
106                pool,
107                client_prefix,
108            }),
109            Err(_e) => Err(RedisError::from((ErrorKind::ClientError, "Build pool error."))),
110        }
111    }
112
113    pub fn new_with_url(
114        url: Url, max_pool_size: u32, client_prefix: String,
115    ) -> Result<Self, RedisError> {
116        RedisDataSource::new(url.into(), max_pool_size, client_prefix)
117    }
118
119    pub fn get_url(&self) -> String {
120        self.url.clone()
121    }
122    pub fn get_pool(&self) -> Pool<RedisConnectionManager> {
123        self.pool.clone()
124    }
125}
126
127impl RedisDataSource {
128    /// users can regist to redis a custom client struct which can be Serialized and Deserialized.
129    pub fn regist(&self, detail: &StringfiedEncodedClient) -> anyhow::Result<()> {
130        let mut pool = self.pool.get()?;
131        let client_str = serde_json::to_string(&detail)?;
132        pool.set(&(self.client_prefix.to_owned() + &detail.client_id), client_str)?;
133        Ok(())
134    }
135}
136
137impl OauthClientDBRepository for RedisDataSource {
138    fn list(&self) -> anyhow::Result<Vec<EncodedClient>> {
139        let mut encoded_clients: Vec<EncodedClient> = vec![];
140        let mut r = self.pool.get()?;
141        let keys = r.keys::<&str, Vec<String>>(&self.client_prefix)?;
142        for key in keys {
143            let clients_str = r.get::<String, String>(key)?;
144            let stringfied_client = serde_json::from_str::<StringfiedEncodedClient>(&clients_str)?;
145            encoded_clients.push(stringfied_client.to_encoded_client()?);
146        }
147        Ok(encoded_clients)
148    }
149
150    fn find_client_by_id(&self, id: &str) -> anyhow::Result<EncodedClient> {
151        let mut r = self.pool.get()?;
152        let client_str = r.get::<&str, String>(&(self.client_prefix.to_owned() + id))?;
153        let stringfied_client = serde_json::from_str::<StringfiedEncodedClient>(&client_str)?;
154        Ok(stringfied_client.to_encoded_client()?)
155    }
156
157    fn regist_from_encoded_client(&self, client: EncodedClient) -> anyhow::Result<()> {
158        let detail = StringfiedEncodedClient::from_encoded_client(&client);
159        self.regist(&detail)
160    }
161}