auth_framework/server/core/
client_registry.rs1use crate::errors::{AuthError, Result};
7use crate::storage::core::AuthStorage;
8use crate::storage::memory::InMemoryStorage;
9use std::sync::Arc;
10
11pub use crate::client::{ClientConfig, ClientType};
14
15#[derive(Clone)]
17pub struct ClientRegistry {
18 storage: Arc<dyn AuthStorage>,
19}
20
21impl ClientRegistry {
22 pub async fn new(storage: Arc<dyn AuthStorage>) -> Result<Self> {
24 Ok(Self { storage })
25 }
26
27 pub async fn register_client(&self, config: ClientConfig) -> Result<ClientConfig> {
29 self.validate_client_config(&config)?;
31
32 let client_key = format!("oauth_client:{}", config.client_id);
34 let client_data = serde_json::to_string(&config)
35 .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
36
37 self.storage
38 .store_kv(&client_key, client_data.as_bytes(), None)
39 .await?;
40
41 Ok(config)
42 }
43
44 pub async fn get_client(&self, client_id: &str) -> Result<Option<ClientConfig>> {
46 let client_key = format!("oauth_client:{}", client_id);
47
48 if let Some(client_data) = self.storage.get_kv(&client_key).await? {
49 let client_str = std::str::from_utf8(&client_data)
50 .map_err(|e| AuthError::internal(format!("Invalid UTF-8 in client data: {}", e)))?;
51 let config: ClientConfig = serde_json::from_str(client_str)
52 .map_err(|e| AuthError::internal(format!("Failed to deserialize client: {}", e)))?;
53 Ok(Some(config))
54 } else {
55 Ok(None)
56 }
57 }
58
59 pub async fn update_client(&self, client_id: &str, config: ClientConfig) -> Result<()> {
61 if config.client_id != client_id {
63 return Err(AuthError::validation("Client ID mismatch"));
64 }
65
66 self.validate_client_config(&config)?;
68
69 let client_key = format!("oauth_client:{}", client_id);
71 let client_data = serde_json::to_string(&config)
72 .map_err(|e| AuthError::internal(format!("Failed to serialize client: {}", e)))?;
73
74 self.storage
75 .store_kv(&client_key, client_data.as_bytes(), None)
76 .await?;
77
78 Ok(())
79 }
80
81 pub async fn delete_client(&self, client_id: &str) -> Result<()> {
83 let client_key = format!("oauth_client:{}", client_id);
84 self.storage.delete_kv(&client_key).await?;
85 Ok(())
86 }
87
88 pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result<bool> {
90 if let Some(client) = self.get_client(client_id).await? {
91 Ok(client.redirect_uris.contains(&redirect_uri.to_string()))
92 } else {
93 Ok(false)
94 }
95 }
96
97 pub async fn validate_scope(&self, client_id: &str, scope: &str) -> Result<bool> {
99 if let Some(client) = self.get_client(client_id).await? {
100 Ok(client.authorized_scopes.contains(&scope.to_string()))
101 } else {
102 Ok(false)
103 }
104 }
105
106 pub async fn validate_grant_type(&self, client_id: &str, grant_type: &str) -> Result<bool> {
108 if let Some(client) = self.get_client(client_id).await? {
109 Ok(client
110 .authorized_grant_types
111 .contains(&grant_type.to_string()))
112 } else {
113 Ok(false)
114 }
115 }
116
117 pub async fn authenticate_client(&self, client_id: &str, client_secret: &str) -> Result<bool> {
119 if let Some(client) = self.get_client(client_id).await? {
120 match (&client.client_type, &client.client_secret) {
121 (ClientType::Confidential, Some(stored_secret)) => {
122 Ok(crate::security::secure_utils::constant_time_compare(
124 client_secret.as_bytes(),
125 stored_secret.as_bytes(),
126 ))
127 }
128 (ClientType::Public, None) => {
129 Ok(true)
131 }
132 _ => Ok(false),
133 }
134 } else {
135 Ok(false)
136 }
137 }
138
139 fn validate_client_config(&self, config: &ClientConfig) -> Result<()> {
141 if config.client_id.is_empty() {
143 return Err(AuthError::validation("Client ID cannot be empty"));
144 }
145
146 if config.client_type == ClientType::Confidential && config.client_secret.is_none() {
148 return Err(AuthError::validation(
149 "Confidential clients must have a client secret",
150 ));
151 }
152
153 if config.client_type == ClientType::Public && config.client_secret.is_some() {
155 return Err(AuthError::validation(
156 "Public clients must not have a client secret",
157 ));
158 }
159
160 if config.redirect_uris.is_empty() {
162 return Err(AuthError::validation(
163 "At least one redirect URI must be provided",
164 ));
165 }
166
167 for uri in &config.redirect_uris {
169 if uri.is_empty() {
170 return Err(AuthError::validation("Redirect URI cannot be empty"));
171 }
172
173 if !uri.starts_with("https://") && !uri.starts_with("http://localhost") {
175 return Err(AuthError::validation(
176 "Redirect URIs must use HTTPS (except localhost)",
177 ));
178 }
179 }
180
181 if config.authorized_scopes.is_empty() {
183 return Err(AuthError::validation(
184 "At least one authorized scope must be provided",
185 ));
186 }
187
188 if config.authorized_grant_types.is_empty() {
190 return Err(AuthError::validation(
191 "At least one authorized grant type must be provided",
192 ));
193 }
194
195 if config.authorized_response_types.is_empty() {
197 return Err(AuthError::validation(
198 "At least one authorized response type must be provided",
199 ));
200 }
201
202 Ok(())
203 }
204}
205
206impl Default for ClientRegistry {
207 fn default() -> Self {
208 let storage =
210 if std::env::var("CLIENT_REGISTRY_STORAGE").unwrap_or_default() == "persistent" {
211 Arc::new(InMemoryStorage::new())
213 } else {
214 Arc::new(InMemoryStorage::new())
216 };
217
218 Self { storage }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::InMemoryStorage;

    /// Builds a registry over its own fresh in-memory storage.
    async fn new_registry() -> ClientRegistry {
        let storage = Arc::new(InMemoryStorage::new());
        ClientRegistry::new(storage).await.unwrap()
    }

    /// A confidential client that passes every validation rule.
    ///
    /// Validation tests start from this baseline and break exactly one rule, so a
    /// rejection is caused by the rule under test and not by an earlier check
    /// tripping on a `Default` field.
    fn valid_confidential_config() -> ClientConfig {
        ClientConfig {
            client_id: "test_client".to_string(),
            client_type: ClientType::Confidential,
            client_secret: Some("test_secret".to_string()),
            redirect_uris: vec!["https://example.com/callback".to_string()].into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_client_registry_operations() {
        let registry = new_registry().await;
        let client_config = valid_confidential_config();

        let registered_client = registry
            .register_client(client_config.clone())
            .await
            .unwrap();
        assert_eq!(registered_client.client_id, "test_client");

        let retrieved_client = registry.get_client("test_client").await.unwrap().unwrap();
        assert_eq!(retrieved_client.client_id, "test_client");
        assert_eq!(retrieved_client.client_type, ClientType::Confidential);

        assert!(registry
            .authenticate_client("test_client", "test_secret")
            .await
            .unwrap());
        assert!(!registry
            .authenticate_client("test_client", "wrong_secret")
            .await
            .unwrap());

        assert!(registry
            .validate_redirect_uri("test_client", "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri("test_client", "https://malicious.com/callback")
            .await
            .unwrap());

        registry.delete_client("test_client").await.unwrap();
        let deleted_client = registry.get_client("test_client").await.unwrap();
        assert!(deleted_client.is_none());
    }

    #[tokio::test]
    async fn test_client_validation() {
        let registry = new_registry().await;

        // The baseline must be accepted, otherwise the rejections below prove nothing.
        assert!(registry
            .register_client(valid_confidential_config())
            .await
            .is_ok());

        let empty_id = ClientConfig {
            client_id: String::new(),
            ..valid_confidential_config()
        };
        assert!(registry.register_client(empty_id).await.is_err());

        let confidential_without_secret = ClientConfig {
            client_secret: None,
            ..valid_confidential_config()
        };
        assert!(registry
            .register_client(confidential_without_secret)
            .await
            .is_err());

        let public_with_secret = ClientConfig {
            client_type: ClientType::Public,
            ..valid_confidential_config()
        };
        assert!(registry.register_client(public_with_secret).await.is_err());

        let no_redirect_uris = ClientConfig {
            redirect_uris: vec![].into(),
            ..valid_confidential_config()
        };
        assert!(registry.register_client(no_redirect_uris).await.is_err());

        let empty_redirect_uri = ClientConfig {
            redirect_uris: vec![String::new()].into(),
            ..valid_confidential_config()
        };
        assert!(registry.register_client(empty_redirect_uri).await.is_err());

        let plain_http = ClientConfig {
            redirect_uris: vec!["http://example.com/callback".to_string()].into(),
            ..valid_confidential_config()
        };
        assert!(registry.register_client(plain_http).await.is_err());
    }

    #[tokio::test]
    async fn test_localhost_http_redirect_allowed() {
        let registry = new_registry().await;
        let config = ClientConfig {
            client_id: "local_client".to_string(),
            redirect_uris: vec!["http://localhost:8080/callback".to_string()].into(),
            ..valid_confidential_config()
        };
        assert!(registry.register_client(config).await.is_ok());
    }

    #[tokio::test]
    async fn test_public_client_authentication() {
        let registry = new_registry().await;
        let config = ClientConfig {
            client_id: "public_client".to_string(),
            client_type: ClientType::Public,
            client_secret: None,
            ..valid_confidential_config()
        };
        registry.register_client(config).await.unwrap();

        assert!(registry
            .authenticate_client("public_client", "")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_update_client_rejects_id_mismatch() {
        let registry = new_registry().await;
        registry
            .register_client(valid_confidential_config())
            .await
            .unwrap();

        assert!(registry
            .update_client("other_client", valid_confidential_config())
            .await
            .is_err());
        assert!(registry
            .update_client("test_client", valid_confidential_config())
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_unknown_client_is_rejected() {
        let registry = new_registry().await;

        assert!(registry.get_client("missing").await.unwrap().is_none());
        assert!(!registry
            .authenticate_client("missing", "secret")
            .await
            .unwrap());
        assert!(!registry
            .validate_redirect_uri("missing", "https://example.com/callback")
            .await
            .unwrap());
        assert!(!registry.validate_scope("missing", "read").await.unwrap());
        assert!(!registry
            .validate_grant_type("missing", "authorization_code")
            .await
            .unwrap());
    }
}