1use crate::http::security::user::User;
26use std::collections::HashMap;
27use std::sync::Arc;
28use std::time::Duration;
29
/// Record kept per user by [`MockLdapClient`]: `(password, attributes, groups)`.
type MockUserEntry = (
    String,                       // password
    HashMap<String, Vec<String>>, // attributes
    Vec<String>,                  // groups
);
32
/// Connection, search, and role-mapping settings for LDAP authentication.
///
/// Build one with [`LdapConfig::new`] or [`LdapConfig::active_directory`] plus the
/// builder methods. `Debug` is implemented by hand so that `bind_password` is
/// printed as `<redacted>` instead of leaking the credential into logs.
#[derive(Clone)]
pub struct LdapConfig {
    /// Directory server URL (default `ldap://localhost:389`).
    pub url: String,
    /// Root DN appended to the user and group search bases.
    pub base_dn: String,
    /// User search base, relative to `base_dn` (default `ou=users`).
    pub user_search_base: String,
    /// User search filter; every `{0}` is substituted with the login name.
    pub user_search_filter: String,
    /// Group search base, relative to `base_dn` (default `ou=groups`).
    pub group_search_base: String,
    /// Group search filter; every `{0}` is substituted with the user's DN.
    pub group_search_filter: String,
    /// Group attribute intended to supply role names (default `cn`).
    pub group_role_attribute: String,
    /// Optional DN to bind as (e.g. a service account).
    pub bind_dn: Option<String>,
    /// Password for `bind_dn`; never shown by `Debug`.
    pub bind_password: Option<String>,
    /// Optional pattern for building a user DN directly; every `{0}` is
    /// substituted with the login name.
    pub user_dn_pattern: Option<String>,
    /// Connection timeout.
    pub connect_timeout: Duration,
    /// Per-operation timeout.
    pub operation_timeout: Duration,
    /// Whether StartTLS should be used.
    pub use_starttls: bool,
    /// Prefix prepended to every role derived from a group (default `ROLE_`).
    pub role_prefix: String,
    /// Whether role names are upper-cased before the prefix is applied.
    pub convert_to_uppercase: bool,
    /// Attribute holding the login name (default `uid`).
    pub username_attribute: String,
    /// Attribute holding the e-mail address (default `mail`).
    pub email_attribute: String,
    /// Attribute holding the display name (default `cn`).
    pub display_name_attribute: String,
    /// Extra mappings from LDAP attribute name to user attribute name.
    pub attribute_mappings: HashMap<String, String>,
}

impl std::fmt::Debug for LdapConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a bind password is configured, never its value.
        let bind_password = self.bind_password.as_ref().map(|_| "<redacted>");
        f.debug_struct("LdapConfig")
            .field("url", &self.url)
            .field("base_dn", &self.base_dn)
            .field("user_search_base", &self.user_search_base)
            .field("user_search_filter", &self.user_search_filter)
            .field("group_search_base", &self.group_search_base)
            .field("group_search_filter", &self.group_search_filter)
            .field("group_role_attribute", &self.group_role_attribute)
            .field("bind_dn", &self.bind_dn)
            .field("bind_password", &bind_password)
            .field("user_dn_pattern", &self.user_dn_pattern)
            .field("connect_timeout", &self.connect_timeout)
            .field("operation_timeout", &self.operation_timeout)
            .field("use_starttls", &self.use_starttls)
            .field("role_prefix", &self.role_prefix)
            .field("convert_to_uppercase", &self.convert_to_uppercase)
            .field("username_attribute", &self.username_attribute)
            .field("email_attribute", &self.email_attribute)
            .field("display_name_attribute", &self.display_name_attribute)
            .field("attribute_mappings", &self.attribute_mappings)
            .finish()
    }
}
75
76impl Default for LdapConfig {
77 fn default() -> Self {
78 Self {
79 url: "ldap://localhost:389".to_string(),
80 base_dn: String::new(),
81 user_search_base: "ou=users".to_string(),
82 user_search_filter: "(uid={0})".to_string(),
83 group_search_base: "ou=groups".to_string(),
84 group_search_filter: "(member={0})".to_string(),
85 group_role_attribute: "cn".to_string(),
86 bind_dn: None,
87 bind_password: None,
88 user_dn_pattern: None,
89 connect_timeout: Duration::from_secs(5),
90 operation_timeout: Duration::from_secs(10),
91 use_starttls: false,
92 role_prefix: "ROLE_".to_string(),
93 convert_to_uppercase: true,
94 username_attribute: "uid".to_string(),
95 email_attribute: "mail".to_string(),
96 display_name_attribute: "cn".to_string(),
97 attribute_mappings: HashMap::new(),
98 }
99 }
100}
101
102impl LdapConfig {
103 pub fn new(url: impl Into<String>) -> Self {
105 Self {
106 url: url.into(),
107 ..Default::default()
108 }
109 }
110
111 pub fn active_directory(url: impl Into<String>, domain: impl Into<String>) -> Self {
113 let domain = domain.into();
114 let base_dn = domain
115 .split('.')
116 .map(|part| format!("dc={}", part))
117 .collect::<Vec<_>>()
118 .join(",");
119
120 Self {
121 url: url.into(),
122 base_dn,
123 user_search_filter: "(sAMAccountName={0})".to_string(),
124 group_search_filter: "(member:1.2.840.113556.1.4.1941:={0})".to_string(),
125 username_attribute: "sAMAccountName".to_string(),
126 display_name_attribute: "displayName".to_string(),
127 ..Default::default()
128 }
129 }
130
131 pub fn base_dn(mut self, dn: impl Into<String>) -> Self {
133 self.base_dn = dn.into();
134 self
135 }
136
137 pub fn user_search_base(mut self, base: impl Into<String>) -> Self {
139 self.user_search_base = base.into();
140 self
141 }
142
143 pub fn user_search_filter(mut self, filter: impl Into<String>) -> Self {
145 self.user_search_filter = filter.into();
146 self
147 }
148
149 pub fn group_search_base(mut self, base: impl Into<String>) -> Self {
151 self.group_search_base = base.into();
152 self
153 }
154
155 pub fn group_search_filter(mut self, filter: impl Into<String>) -> Self {
157 self.group_search_filter = filter.into();
158 self
159 }
160
161 pub fn bind_dn(mut self, dn: impl Into<String>) -> Self {
163 self.bind_dn = Some(dn.into());
164 self
165 }
166
167 pub fn bind_password(mut self, password: impl Into<String>) -> Self {
169 self.bind_password = Some(password.into());
170 self
171 }
172
173 pub fn user_dn_pattern(mut self, pattern: impl Into<String>) -> Self {
175 self.user_dn_pattern = Some(pattern.into());
176 self
177 }
178
179 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
181 self.connect_timeout = timeout;
182 self
183 }
184
185 pub fn operation_timeout(mut self, timeout: Duration) -> Self {
187 self.operation_timeout = timeout;
188 self
189 }
190
191 pub fn use_starttls(mut self, use_tls: bool) -> Self {
193 self.use_starttls = use_tls;
194 self
195 }
196
197 pub fn role_prefix(mut self, prefix: impl Into<String>) -> Self {
199 self.role_prefix = prefix.into();
200 self
201 }
202
203 pub fn convert_to_uppercase(mut self, convert: bool) -> Self {
205 self.convert_to_uppercase = convert;
206 self
207 }
208
209 pub fn map_attribute(
211 mut self,
212 ldap_attr: impl Into<String>,
213 user_attr: impl Into<String>,
214 ) -> Self {
215 self.attribute_mappings
216 .insert(ldap_attr.into(), user_attr.into());
217 self
218 }
219
220 pub fn full_user_search_base(&self) -> String {
222 if self.user_search_base.is_empty() {
223 self.base_dn.clone()
224 } else {
225 format!("{},{}", self.user_search_base, self.base_dn)
226 }
227 }
228
229 pub fn full_group_search_base(&self) -> String {
231 if self.group_search_base.is_empty() {
232 self.base_dn.clone()
233 } else {
234 format!("{},{}", self.group_search_base, self.base_dn)
235 }
236 }
237
238 pub fn build_user_filter(&self, username: &str) -> String {
240 self.user_search_filter.replace("{0}", username)
241 }
242
243 pub fn build_group_filter(&self, user_dn: &str) -> String {
245 self.group_search_filter.replace("{0}", user_dn)
246 }
247
248 pub fn build_user_dn(&self, username: &str) -> Option<String> {
250 self.user_dn_pattern
251 .as_ref()
252 .map(|pattern| pattern.replace("{0}", username))
253 }
254}
255
/// Outcome of an LDAP authentication attempt, as returned by
/// [`LdapOperations::authenticate`].
#[derive(Debug, Clone)]
pub struct LdapAuthResult {
    /// `true` when the credentials were accepted.
    pub success: bool,
    /// DN of the authenticated entry; `None` on failure.
    pub user_dn: Option<String>,
    /// Attributes of the user entry, keyed by attribute name (multi-valued).
    pub attributes: HashMap<String, Vec<String>>,
    /// Group DNs the user belongs to, e.g. `cn=admins,ou=groups,dc=example,dc=com`.
    pub groups: Vec<String>,
    /// Failure description; `None` on success.
    pub error: Option<String>,
}

impl LdapAuthResult {
    /// Builds a successful result for `user_dn` with the given attributes and no
    /// groups (add them with [`LdapAuthResult::with_groups`]).
    pub fn success(user_dn: String, attributes: HashMap<String, Vec<String>>) -> Self {
        let no_groups = Vec::new();
        Self {
            error: None,
            groups: no_groups,
            attributes,
            user_dn: Some(user_dn),
            success: true,
        }
    }

    /// Builds a failed result carrying `error` and no DN, attributes or groups.
    pub fn failure(error: impl Into<String>) -> Self {
        let message: String = error.into();
        Self {
            success: false,
            user_dn: None,
            attributes: HashMap::new(),
            groups: Vec::new(),
            error: Some(message),
        }
    }

    /// Replaces the group list and returns the updated result.
    pub fn with_groups(self, groups: Vec<String>) -> Self {
        Self { groups, ..self }
    }

    /// Returns the first value of attribute `name`, or `None` if the attribute is
    /// missing or has no values.
    pub fn get_attribute(&self, name: &str) -> Option<&str> {
        let values = self.get_attribute_values(name)?;
        values.first().map(String::as_str)
    }

    /// Returns every value of attribute `name`, if present.
    pub fn get_attribute_values(&self, name: &str) -> Option<&Vec<String>> {
        self.attributes.get(name)
    }
}
313
/// Errors reported by LDAP configuration, connection, bind, search and
/// authentication steps.
#[derive(Debug, Clone)]
pub enum LdapError {
    /// Connecting to the directory server failed.
    ConnectionFailed(String),
    /// A bind request failed.
    BindFailed(String),
    /// No entry exists for the requested user; carries the username.
    UserNotFound(String),
    /// The supplied credentials were not accepted.
    AuthenticationFailed(String),
    /// A directory search failed.
    SearchFailed(String),
    /// The configuration is invalid or incomplete.
    ConfigurationError(String),
    /// An operation did not complete in time.
    Timeout,
    /// A TLS-level failure.
    TlsError(String),
}

impl std::fmt::Display for LdapError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ConnectionFailed(detail) => ("LDAP connection failed", detail),
            Self::BindFailed(detail) => ("LDAP bind failed", detail),
            Self::UserNotFound(detail) => ("User not found", detail),
            Self::AuthenticationFailed(detail) => ("Authentication failed", detail),
            Self::SearchFailed(detail) => ("LDAP search failed", detail),
            Self::ConfigurationError(detail) => ("Configuration error", detail),
            Self::TlsError(detail) => ("TLS error", detail),
            // The only variant without a detail message.
            Self::Timeout => return f.write_str("LDAP operation timed out"),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for LdapError {}
351
352#[cfg_attr(feature = "ldap", async_trait::async_trait)]
357pub trait LdapOperations: Send + Sync {
358 async fn connect(&self) -> Result<(), LdapError>;
360
361 async fn bind(&self, dn: &str, password: &str) -> Result<(), LdapError>;
363
364 async fn search(
366 &self,
367 base: &str,
368 filter: &str,
369 attrs: &[&str],
370 ) -> Result<Vec<LdapAuthResult>, LdapError>;
371
372 async fn authenticate(
374 &self,
375 username: &str,
376 password: &str,
377 ) -> Result<LdapAuthResult, LdapError>;
378}
379
380#[derive(Clone)]
385pub struct LdapAuthenticator {
386 config: Arc<LdapConfig>,
387 #[cfg(feature = "ldap")]
388 client: Arc<dyn LdapOperations>,
389}
390
391impl LdapAuthenticator {
392 #[cfg(feature = "ldap")]
394 pub fn new<C: LdapOperations + 'static>(config: LdapConfig, client: C) -> Self {
395 Self {
396 config: Arc::new(config),
397 client: Arc::new(client),
398 }
399 }
400
401 pub fn with_config(config: LdapConfig) -> Self {
403 Self {
404 config: Arc::new(config),
405 #[cfg(feature = "ldap")]
406 client: Arc::new(MockLdapClient::new()),
407 }
408 }
409
410 pub fn config(&self) -> &LdapConfig {
412 &self.config
413 }
414
415 #[cfg(feature = "ldap")]
417 pub async fn authenticate(&self, username: &str, password: &str) -> Result<User, LdapError> {
418 let result = self.client.authenticate(username, password).await?;
419
420 if !result.success {
421 return Err(LdapError::AuthenticationFailed(
422 result.error.unwrap_or_else(|| "Unknown error".to_string()),
423 ));
424 }
425
426 let user = self.build_user_from_result(username, &result);
428 Ok(user)
429 }
430
431 fn build_user_from_result(&self, username: &str, result: &LdapAuthResult) -> User {
433 let roles: Vec<String> = result
435 .groups
436 .iter()
437 .filter_map(|group_dn| {
438 group_dn
440 .split(',')
441 .next()
442 .and_then(|cn_part| cn_part.strip_prefix("cn=").or(cn_part.strip_prefix("CN=")))
443 .map(|cn| {
444 let role = if self.config.convert_to_uppercase {
445 cn.to_uppercase()
446 } else {
447 cn.to_string()
448 };
449 format!("{}{}", self.config.role_prefix, role)
450 })
451 })
452 .collect();
453
454 let _display_name = result
456 .get_attribute(&self.config.display_name_attribute)
457 .unwrap_or(username);
458
459 let mut user = User::new(username.to_string(), String::new());
461
462 if !roles.is_empty() {
464 user = user.roles(&roles);
465 }
466
467 if let Some(email) = result.get_attribute(&self.config.email_attribute) {
469 user = user.authorities(&[format!("email:{}", email)]);
470 }
471
472 if let Some(ref dn) = result.user_dn {
474 user = user.authorities(&[format!("dn:{}", dn)]);
475 }
476
477 user
478 }
479}
480
481#[derive(Default)]
483pub struct MockLdapClient {
484 users: std::sync::RwLock<HashMap<String, MockUserEntry>>,
485}
486
487impl MockLdapClient {
488 pub fn new() -> Self {
490 Self::default()
491 }
492
493 pub fn add_user(
495 &self,
496 username: &str,
497 password: &str,
498 attributes: HashMap<String, Vec<String>>,
499 groups: Vec<String>,
500 ) {
501 let mut users = self.users.write().unwrap();
502 users.insert(
503 username.to_string(),
504 (password.to_string(), attributes, groups),
505 );
506 }
507}
508
509#[cfg_attr(feature = "ldap", async_trait::async_trait)]
510impl LdapOperations for MockLdapClient {
511 async fn connect(&self) -> Result<(), LdapError> {
512 Ok(())
513 }
514
515 async fn bind(&self, _dn: &str, _password: &str) -> Result<(), LdapError> {
516 Ok(())
517 }
518
519 async fn search(
520 &self,
521 _base: &str,
522 _filter: &str,
523 _attrs: &[&str],
524 ) -> Result<Vec<LdapAuthResult>, LdapError> {
525 Ok(Vec::new())
526 }
527
528 async fn authenticate(
529 &self,
530 username: &str,
531 password: &str,
532 ) -> Result<LdapAuthResult, LdapError> {
533 let users = self.users.read().unwrap();
534
535 match users.get(username) {
536 Some((stored_password, attributes, groups)) if stored_password == password => {
537 Ok(LdapAuthResult::success(
538 format!("uid={},ou=users,dc=example,dc=com", username),
539 attributes.clone(),
540 )
541 .with_groups(groups.clone()))
542 }
543 Some(_) => Err(LdapError::AuthenticationFailed(
544 "Invalid password".to_string(),
545 )),
546 None => Err(LdapError::UserNotFound(username.to_string())),
547 }
548 }
549}
550
551pub trait LdapContextMapper: Send + Sync {
553 fn map_user(&self, username: &str, result: &LdapAuthResult, config: &LdapConfig) -> User;
555}
556
557#[derive(Default)]
559pub struct DefaultLdapContextMapper;
560
561impl LdapContextMapper for DefaultLdapContextMapper {
562 fn map_user(&self, username: &str, result: &LdapAuthResult, config: &LdapConfig) -> User {
563 let roles: Vec<String> = result
564 .groups
565 .iter()
566 .filter_map(|group_dn| {
567 group_dn
568 .split(',')
569 .next()
570 .and_then(|cn_part| cn_part.strip_prefix("cn=").or(cn_part.strip_prefix("CN=")))
571 .map(|cn| {
572 let role = if config.convert_to_uppercase {
573 cn.to_uppercase()
574 } else {
575 cn.to_string()
576 };
577 format!("{}{}", config.role_prefix, role)
578 })
579 })
580 .collect();
581
582 User::new(username.to_string(), String::new()).roles(&roles)
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-valued attribute map from `(name, value)` pairs.
    fn attrs(pairs: &[(&str, &str)]) -> HashMap<String, Vec<String>> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), vec![value.to_string()]))
            .collect()
    }

    #[test]
    fn test_ldap_config_builder() {
        let config = LdapConfig::new("ldap://localhost:389")
            .base_dn("dc=example,dc=com")
            .user_search_filter("(uid={0})")
            .bind_dn("cn=admin,dc=example,dc=com")
            .bind_password("secret");

        assert_eq!(config.url, "ldap://localhost:389");
        assert_eq!(config.base_dn, "dc=example,dc=com");
        assert_eq!(config.user_search_filter, "(uid={0})");
        assert_eq!(config.bind_dn.as_deref(), Some("cn=admin,dc=example,dc=com"));
        assert_eq!(config.bind_password.as_deref(), Some("secret"));
    }

    #[test]
    fn test_full_search_bases() {
        let config = LdapConfig::new("ldap://localhost").base_dn("dc=example,dc=com");
        assert_eq!(config.full_user_search_base(), "ou=users,dc=example,dc=com");
        assert_eq!(config.full_group_search_base(), "ou=groups,dc=example,dc=com");

        let config = config.user_search_base("");
        assert_eq!(config.full_user_search_base(), "dc=example,dc=com");
    }

    #[test]
    fn test_active_directory_config() {
        let config = LdapConfig::active_directory("ldap://dc.example.com", "example.com");

        assert_eq!(config.url, "ldap://dc.example.com");
        assert_eq!(config.base_dn, "dc=example,dc=com");
        assert_eq!(config.user_search_filter, "(sAMAccountName={0})");
        assert_eq!(
            config.group_search_filter,
            "(member:1.2.840.113556.1.4.1941:={0})"
        );
        assert_eq!(config.username_attribute, "sAMAccountName");
        assert_eq!(config.display_name_attribute, "displayName");
    }

    #[test]
    fn test_build_user_filter() {
        let config = LdapConfig::new("ldap://localhost").user_search_filter("(uid={0})");

        assert_eq!(config.build_user_filter("john"), "(uid=john)");
    }

    #[test]
    fn test_build_group_filter() {
        let config = LdapConfig::new("ldap://localhost");

        assert_eq!(
            config.build_group_filter("uid=john,ou=users,dc=example,dc=com"),
            "(member=uid=john,ou=users,dc=example,dc=com)"
        );
    }

    #[test]
    fn test_build_user_dn() {
        let config = LdapConfig::new("ldap://localhost")
            .user_dn_pattern("uid={0},ou=users,dc=example,dc=com");

        assert_eq!(
            config.build_user_dn("john").as_deref(),
            Some("uid=john,ou=users,dc=example,dc=com")
        );
        assert_eq!(LdapConfig::new("ldap://localhost").build_user_dn("john"), None);
    }

    #[test]
    fn test_ldap_auth_result() {
        let result = LdapAuthResult::success(
            "uid=john,dc=example,dc=com".to_string(),
            attrs(&[("cn", "John Doe"), ("mail", "john@example.com")]),
        )
        .with_groups(vec!["cn=admins,ou=groups,dc=example,dc=com".to_string()]);

        assert!(result.success);
        assert_eq!(result.user_dn.as_deref(), Some("uid=john,dc=example,dc=com"));
        assert_eq!(result.get_attribute("cn"), Some("John Doe"));
        assert_eq!(result.get_attribute("mail"), Some("john@example.com"));
        assert_eq!(result.get_attribute("missing"), None);
        assert_eq!(result.groups.len(), 1);
        assert!(result.error.is_none());
    }

    #[test]
    fn test_ldap_auth_result_failure() {
        let result = LdapAuthResult::failure("bad credentials");

        assert!(!result.success);
        assert_eq!(result.error.as_deref(), Some("bad credentials"));
        assert!(result.user_dn.is_none());
        assert!(result.attributes.is_empty());
        assert!(result.groups.is_empty());
    }

    #[tokio::test]
    async fn test_mock_ldap_client() {
        let client = MockLdapClient::new();
        client.add_user(
            "testuser",
            "password123",
            attrs(&[("cn", "Test User")]),
            vec!["cn=users,ou=groups,dc=example,dc=com".to_string()],
        );

        let ok = client
            .authenticate("testuser", "password123")
            .await
            .expect("valid credentials must authenticate");
        assert!(ok.success);
        assert_eq!(ok.get_attribute("cn"), Some("Test User"));
        assert_eq!(
            ok.groups,
            vec!["cn=users,ou=groups,dc=example,dc=com".to_string()]
        );

        assert!(matches!(
            client.authenticate("testuser", "wrongpass").await,
            Err(LdapError::AuthenticationFailed(_))
        ));
        assert!(matches!(
            client.authenticate("unknown", "password").await,
            Err(LdapError::UserNotFound(name)) if name == "unknown"
        ));
    }

    #[test]
    fn test_ldap_error_display() {
        assert_eq!(
            LdapError::ConnectionFailed("Connection refused".to_string()).to_string(),
            "LDAP connection failed: Connection refused"
        );
        assert_eq!(
            LdapError::UserNotFound("john".to_string()).to_string(),
            "User not found: john"
        );
        assert_eq!(LdapError::Timeout.to_string(), "LDAP operation timed out");
    }
}