1use std::collections::{HashMap, HashSet};
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone)]
12pub struct AuthConfig {
13 pub enabled: bool,
15
16 pub jwt: Option<JwtConfig>,
18
19 pub oauth: Option<OAuthConfig>,
21
22 pub ldap: Option<LdapConfig>,
24
25 pub api_keys: Option<ApiKeyConfig>,
27
28 pub role_mapping: Vec<RoleMappingRule>,
30
31 pub default_role: Option<String>,
33
34 pub credentials: CredentialConfig,
36
37 pub session: SessionConfig,
39
40 pub rate_limit: AuthRateLimitConfig,
42
43 pub auth_methods: Vec<AuthMethod>,
45}
46
47impl Default for AuthConfig {
48 fn default() -> Self {
49 Self {
50 enabled: false,
51 jwt: None,
52 oauth: None,
53 ldap: None,
54 api_keys: None,
55 role_mapping: Vec::new(),
56 default_role: Some("db_minimal".to_string()),
57 credentials: CredentialConfig::default(),
58 session: SessionConfig::default(),
59 rate_limit: AuthRateLimitConfig::default(),
60 auth_methods: Vec::new(),
61 }
62 }
63}
64
65impl AuthConfig {
66 pub fn jwt(jwks_url: impl Into<String>) -> Self {
68 Self {
69 enabled: true,
70 jwt: Some(JwtConfig::new(jwks_url)),
71 ..Default::default()
72 }
73 }
74
75 pub fn api_keys() -> Self {
77 Self {
78 enabled: true,
79 api_keys: Some(ApiKeyConfig::default()),
80 ..Default::default()
81 }
82 }
83
84 pub fn builder() -> AuthConfigBuilder {
86 AuthConfigBuilder::new()
87 }
88}
89
90#[derive(Default)]
92pub struct AuthConfigBuilder {
93 config: AuthConfig,
94}
95
96impl AuthConfigBuilder {
97 pub fn new() -> Self {
98 Self {
99 config: AuthConfig {
100 enabled: true,
101 ..Default::default()
102 },
103 }
104 }
105
106 pub fn jwt(mut self, config: JwtConfig) -> Self {
107 self.config.jwt = Some(config);
108 self
109 }
110
111 pub fn oauth(mut self, config: OAuthConfig) -> Self {
112 self.config.oauth = Some(config);
113 self
114 }
115
116 pub fn ldap(mut self, config: LdapConfig) -> Self {
117 self.config.ldap = Some(config);
118 self
119 }
120
121 pub fn api_keys(mut self, config: ApiKeyConfig) -> Self {
122 self.config.api_keys = Some(config);
123 self
124 }
125
126 pub fn add_role_mapping(mut self, rule: RoleMappingRule) -> Self {
127 self.config.role_mapping.push(rule);
128 self
129 }
130
131 pub fn default_role(mut self, role: impl Into<String>) -> Self {
132 self.config.default_role = Some(role.into());
133 self
134 }
135
136 pub fn credentials(mut self, config: CredentialConfig) -> Self {
137 self.config.credentials = config;
138 self
139 }
140
141 pub fn session(mut self, config: SessionConfig) -> Self {
142 self.config.session = config;
143 self
144 }
145
146 pub fn build(self) -> AuthConfig {
147 self.config
148 }
149}
150
/// Settings for validating JSON Web Tokens against a JWKS endpoint.
#[derive(Clone, Debug)]
pub struct JwtConfig {
    /// URL of the JSON Web Key Set.
    pub jwks_url: String,

    /// Refresh interval for the key set (default: 1 hour).
    pub jwks_refresh_interval: Duration,

    /// Accepted issuers (empty by default).
    pub allowed_issuers: HashSet<String>,

    /// Audience tokens must carry, if any.
    pub required_audience: Option<String>,

    /// Allowed clock skew (default: 60 seconds).
    pub clock_skew: Duration,

    /// Claim holding the user identifier (default: `sub`).
    pub user_id_claim: String,

    /// Claim holding the user's roles (default: `roles`).
    pub roles_claim: Option<String>,

    /// Accepted signing algorithms (default: `RS256`, `ES256`).
    pub allowed_algorithms: Vec<String>,
}

impl Default for JwtConfig {
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(3600);
        const ONE_MINUTE: Duration = Duration::from_secs(60);

        let algorithms = ["RS256", "ES256"]
            .iter()
            .map(|alg| alg.to_string())
            .collect();

        JwtConfig {
            jwks_url: String::new(),
            jwks_refresh_interval: ONE_HOUR,
            allowed_issuers: HashSet::new(),
            required_audience: None,
            clock_skew: ONE_MINUTE,
            user_id_claim: String::from("sub"),
            roles_claim: Some(String::from("roles")),
            allowed_algorithms: algorithms,
        }
    }
}

impl JwtConfig {
    /// Default settings pointed at the JWKS endpoint `jwks_url`.
    pub fn new(jwks_url: impl Into<String>) -> Self {
        let jwks_url = jwks_url.into();
        Self {
            jwks_url,
            ..Self::default()
        }
    }

    /// Adds `issuer` to the accepted issuers.
    pub fn with_issuer(self, issuer: impl Into<String>) -> Self {
        let mut updated = self;
        updated.allowed_issuers.insert(issuer.into());
        updated
    }

    /// Sets the audience tokens must carry.
    pub fn with_audience(self, audience: impl Into<String>) -> Self {
        Self {
            required_audience: Some(audience.into()),
            ..self
        }
    }
}
212
/// OAuth 2.0 settings: introspection endpoint, client credentials and scopes.
///
/// `Debug` is written by hand so that `client_secret` never appears in logs,
/// panic messages, or `{:?}` dumps of an enclosing configuration.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Token introspection endpoint.
    pub introspection_url: String,

    /// OAuth client identifier.
    pub client_id: String,

    /// Client secret paired with `client_id`. Redacted in `Debug` output.
    pub client_secret: String,

    /// Token endpoint, if used.
    pub token_url: Option<String>,

    /// Scopes to request.
    pub scopes: Vec<String>,

    /// Cache TTL (default: 60 seconds). Presumably applies to introspection
    /// results; TODO confirm.
    pub cache_ttl: Duration,

    /// Scopes a token must carry.
    pub required_scopes: Vec<String>,

    /// Expected issuer.
    pub issuer: String,

    /// Authorization endpoint, if used.
    pub authorization_url: Option<String>,

    /// Expected audience, if any.
    pub audience: Option<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("introspection_url", &self.introspection_url)
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("token_url", &self.token_url)
            .field("scopes", &self.scopes)
            .field("cache_ttl", &self.cache_ttl)
            .field("required_scopes", &self.required_scopes)
            .field("issuer", &self.issuer)
            .field("authorization_url", &self.authorization_url)
            .field("audience", &self.audience)
            .finish()
    }
}

impl Default for OAuthConfig {
    fn default() -> Self {
        Self {
            introspection_url: String::new(),
            client_id: String::new(),
            client_secret: String::new(),
            token_url: None,
            scopes: Vec::new(),
            cache_ttl: Duration::from_secs(60),
            required_scopes: Vec::new(),
            issuer: String::new(),
            authorization_url: None,
            audience: None,
        }
    }
}

impl OAuthConfig {
    /// Default settings with the given introspection endpoint and client credentials.
    pub fn new(
        introspection_url: impl Into<String>,
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
    ) -> Self {
        Self {
            introspection_url: introspection_url.into(),
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            ..Default::default()
        }
    }
}
278
/// LDAP directory settings: bind account plus user and group lookup.
///
/// `Debug` is written by hand so that `bind_password` is never printed.
#[derive(Clone)]
pub struct LdapConfig {
    /// Directory server URL (default: `ldap://localhost:389`).
    pub server_url: String,

    /// DN used for the initial bind.
    pub bind_dn: String,

    /// Password for `bind_dn`. Redacted in `Debug` output.
    pub bind_password: String,

    /// Base DN for user searches.
    pub user_search_base: String,

    /// User search filter (default: `(uid={0})`). `{0}` is presumably replaced
    /// by the login name; TODO confirm.
    pub user_filter: String,

    /// Base DN for group searches, if any.
    pub group_search_base: Option<String>,

    /// Group attribute name (default: `memberOf`).
    pub group_attribute: String,

    /// Timeout (default: 10 seconds).
    pub timeout: Duration,

    /// Whether StartTLS is used (default: `false`).
    pub starttls: bool,
}

impl std::fmt::Debug for LdapConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LdapConfig")
            .field("server_url", &self.server_url)
            .field("bind_dn", &self.bind_dn)
            // Never print the bind password itself.
            .field("bind_password", &"<redacted>")
            .field("user_search_base", &self.user_search_base)
            .field("user_filter", &self.user_filter)
            .field("group_search_base", &self.group_search_base)
            .field("group_attribute", &self.group_attribute)
            .field("timeout", &self.timeout)
            .field("starttls", &self.starttls)
            .finish()
    }
}

impl Default for LdapConfig {
    fn default() -> Self {
        Self {
            server_url: "ldap://localhost:389".to_string(),
            bind_dn: String::new(),
            bind_password: String::new(),
            user_search_base: String::new(),
            user_filter: "(uid={0})".to_string(),
            group_search_base: None,
            group_attribute: "memberOf".to_string(),
            timeout: Duration::from_secs(10),
            starttls: false,
        }
    }
}
325
/// Settings for extracting API keys from requests.
#[derive(Clone, Debug)]
pub struct ApiKeyConfig {
    /// Header carrying the key (default: `X-API-Key`).
    pub header_name: String,

    /// Query parameter that may carry the key, if enabled.
    pub query_param: Option<String>,

    /// Key prefix (default: `hpk_`).
    pub prefix: Option<String>,

    /// Hash algorithm name (default: `sha256`).
    pub hash_algorithm: String,
}

impl Default for ApiKeyConfig {
    fn default() -> Self {
        let header_name = "X-API-Key".to_owned();
        let prefix = Some("hpk_".to_owned());
        ApiKeyConfig {
            header_name,
            query_param: None,
            prefix,
            hash_algorithm: "sha256".into(),
        }
    }
}
352
353#[derive(Debug, Clone)]
355pub struct RoleMappingRule {
356 pub name: String,
358
359 pub condition: RoleCondition,
361
362 pub db_role: String,
364
365 pub priority: i32,
367
368 pub assign_roles: Vec<String>,
370
371 pub permissions: Vec<String>,
373
374 pub conditions: Vec<RoleMappingCondition>,
376}
377
378impl RoleMappingRule {
379 pub fn new(condition: RoleCondition, db_role: impl Into<String>) -> Self {
380 Self {
381 name: String::new(),
382 condition,
383 db_role: db_role.into(),
384 priority: 0,
385 assign_roles: Vec::new(),
386 permissions: Vec::new(),
387 conditions: Vec::new(),
388 }
389 }
390
391 pub fn with_priority(mut self, priority: i32) -> Self {
392 self.priority = priority;
393 self
394 }
395}
396
/// Predicate deciding whether a [`RoleMappingRule`] applies to an identity.
#[derive(Clone, Debug)]
pub enum RoleCondition {
    /// JWT claim `name` with expected `value`.
    JwtClaim { name: String, value: String },

    /// JWT claim `name` with any of the listed `values`.
    JwtClaimAny { name: String, values: Vec<String> },

    /// An OAuth scope.
    OAuthScope(String),

    /// A group name.
    Group(String),

    /// An email domain.
    EmailDomain(String),

    /// A tenant identifier.
    TenantId(String),

    /// Conjunction of nested conditions.
    And(Vec<RoleCondition>),

    /// Disjunction of nested conditions.
    Or(Vec<RoleCondition>),

    /// Unconditional.
    Always,
}
427
/// Fine-grained condition attached to a [`RoleMappingRule`] through its
/// `conditions` list.
///
/// `And`, `Or` and `Not` compose other conditions; the remaining variants each
/// name a single attribute.
#[derive(Debug, Clone)]
pub enum RoleMappingCondition {
    /// Claim named `claim`, optionally with a specific `value`.
    HasClaim {
        claim: String,
        value: Option<String>,
    },

    /// Membership of `group`.
    InGroup { group: String },

    /// Holding `role`.
    HasRole { role: String },

    /// Belonging to tenant `tenant_id`.
    FromTenant { tenant_id: String },

    /// Authenticated via `method`.
    AuthMethod { method: String },

    /// Email in `domain`.
    EmailDomain { domain: String },

    /// Username matching `pattern` (pattern syntax defined by the evaluator; TODO confirm).
    UsernamePattern { pattern: String },

    /// All nested conditions.
    And {
        conditions: Vec<RoleMappingCondition>,
    },

    /// Any nested condition.
    Or {
        conditions: Vec<RoleMappingCondition>,
    },

    /// Negation of the nested condition.
    Not {
        condition: Box<RoleMappingCondition>,
    },
}

impl RoleMappingCondition {
    /// Builds a [`RoleMappingCondition::HasClaim`].
    pub fn has_claim(claim: impl Into<String>, value: Option<String>) -> Self {
        Self::HasClaim {
            claim: claim.into(),
            value,
        }
    }

    /// Builds a [`RoleMappingCondition::InGroup`].
    pub fn in_group(group: impl Into<String>) -> Self {
        Self::InGroup {
            group: group.into(),
        }
    }

    /// Builds a [`RoleMappingCondition::HasRole`].
    pub fn has_role(role: impl Into<String>) -> Self {
        Self::HasRole { role: role.into() }
    }

    /// Builds a [`RoleMappingCondition::AuthMethod`].
    pub fn auth_method(method: impl Into<String>) -> Self {
        Self::AuthMethod {
            method: method.into(),
        }
    }

    /// Builds a [`RoleMappingCondition::FromTenant`].
    pub fn from_tenant(tenant_id: impl Into<String>) -> Self {
        Self::FromTenant {
            tenant_id: tenant_id.into(),
        }
    }

    /// Builds a [`RoleMappingCondition::EmailDomain`].
    pub fn email_domain(domain: impl Into<String>) -> Self {
        Self::EmailDomain {
            domain: domain.into(),
        }
    }

    /// Builds a [`RoleMappingCondition::UsernamePattern`].
    pub fn username_pattern(pattern: impl Into<String>) -> Self {
        Self::UsernamePattern {
            pattern: pattern.into(),
        }
    }

    /// Builds a [`RoleMappingCondition::And`] from any iterable of conditions.
    pub fn and(conditions: impl IntoIterator<Item = Self>) -> Self {
        Self::And {
            conditions: conditions.into_iter().collect(),
        }
    }

    /// Builds a [`RoleMappingCondition::Or`] from any iterable of conditions.
    pub fn or(conditions: impl IntoIterator<Item = Self>) -> Self {
        Self::Or {
            conditions: conditions.into_iter().collect(),
        }
    }

    /// Builds a [`RoleMappingCondition::Not`], boxing `condition`.
    // Named `negate` rather than `not`: a `not(Self) -> Self` associated fn
    // trips clippy's `should_implement_trait` (it looks like `std::ops::Not`).
    pub fn negate(condition: Self) -> Self {
        Self::Not {
            condition: Box::new(condition),
        }
    }
}
500
501impl RoleCondition {
502 pub fn jwt_claim(name: impl Into<String>, value: impl Into<String>) -> Self {
503 Self::JwtClaim {
504 name: name.into(),
505 value: value.into(),
506 }
507 }
508
509 pub fn group(name: impl Into<String>) -> Self {
510 Self::Group(name.into())
511 }
512
513 pub fn email_domain(domain: impl Into<String>) -> Self {
514 Self::EmailDomain(domain.into())
515 }
516}
517
518#[derive(Debug, Clone)]
520pub struct CredentialConfig {
521 pub default_provider: CredentialProvider,
523
524 pub static_credentials: HashMap<String, Credentials>,
526
527 pub vault: Option<VaultConfig>,
529
530 pub aws_secrets: Option<AwsSecretsConfig>,
532
533 pub cache_ttl: Duration,
535}
536
537impl Default for CredentialConfig {
538 fn default() -> Self {
539 Self {
540 default_provider: CredentialProvider::Static,
541 static_credentials: HashMap::new(),
542 vault: None,
543 aws_secrets: None,
544 cache_ttl: Duration::from_secs(300),
545 }
546 }
547}
548
/// Source of database credentials.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CredentialProvider {
    /// Credentials from `CredentialConfig::static_credentials`.
    Static,
    /// Credentials from Vault.
    Vault,
    /// Credentials from AWS secrets.
    AwsSecrets,
}
556
/// A database username/password pair plus optional TTL and extra options.
///
/// `Debug` is written by hand so that `password` never appears in logs or
/// `{:?}` dumps. This also covers containers such as
/// `CredentialConfig::static_credentials`.
#[derive(Clone)]
pub struct Credentials {
    /// Database username.
    pub username: String,

    /// Database password. Redacted in `Debug` output.
    pub password: String,

    /// Validity period, if the credentials expire.
    pub ttl: Option<Duration>,

    /// Extra provider- or connection-specific options (empty by default).
    pub options: HashMap<String, String>,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            // Never print the password itself.
            .field("password", &"<redacted>")
            .field("ttl", &self.ttl)
            .field("options", &self.options)
            .finish()
    }
}

impl Credentials {
    /// Credentials with no TTL and no extra options.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            ttl: None,
            options: HashMap::new(),
        }
    }

    /// Sets the validity period.
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = Some(ttl);
        self
    }
}
588
/// Settings for fetching credentials from Vault.
#[derive(Debug, Clone)]
pub struct VaultConfig {
    /// Vault server address.
    pub address: String,

    /// How to authenticate to Vault. Secrets are redacted in its `Debug` output.
    pub auth_method: VaultAuthMethod,

    /// Vault role name.
    pub role: String,

    /// Path of the secret to read.
    pub secret_path: String,

    /// Whether the server's TLS certificate is verified.
    pub tls_verify: bool,
}

/// How to authenticate to Vault.
///
/// `Debug` is written by hand so that the static token and the AppRole
/// `secret_id` are never printed. Non-secret identifiers are still shown.
#[derive(Clone)]
pub enum VaultAuthMethod {
    /// Static Vault token.
    Token(String),
    /// Kubernetes auth with the given role.
    Kubernetes { role: String },
    /// AppRole auth with a role id and secret id.
    AppRole { role_id: String, secret_id: String },
}

impl std::fmt::Debug for VaultAuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Token(_) => f.debug_tuple("Token").field(&"<redacted>").finish(),
            Self::Kubernetes { role } => f.debug_struct("Kubernetes").field("role", role).finish(),
            Self::AppRole { role_id, .. } => f
                .debug_struct("AppRole")
                .field("role_id", role_id)
                .field("secret_id", &"<redacted>")
                .finish(),
        }
    }
}
615
/// Settings for fetching credentials from AWS secrets.
#[derive(Clone, Debug)]
pub struct AwsSecretsConfig {
    /// AWS region.
    pub region: String,

    /// Prefix applied to secret names.
    pub secret_prefix: String,

    /// Whether an IAM role is used for access.
    pub use_iam_role: bool,
}
628
/// Session lifetime and limit settings.
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Session timeout (default: 1 hour).
    pub timeout: Duration,

    /// Per-identity session limit (default: 10).
    pub max_sessions_per_identity: usize,

    /// Per-user session limit (default: 10).
    pub max_sessions_per_user: usize,

    /// Idle timeout (default: 30 minutes).
    pub idle_timeout: Duration,

    /// Absolute timeout (default: 24 hours).
    pub absolute_timeout: Duration,

    /// Whether cookies are marked secure (default: `true`).
    pub secure_cookies: bool,

    /// Session variables (empty by default).
    pub session_vars: HashMap<String, String>,

    /// Whether activity extends the session (default: `true`).
    pub extend_on_activity: bool,
}

impl Default for SessionConfig {
    fn default() -> Self {
        let minutes = |m: u64| Duration::from_secs(m * 60);
        SessionConfig {
            timeout: minutes(60),
            max_sessions_per_identity: 10,
            max_sessions_per_user: 10,
            idle_timeout: minutes(30),
            absolute_timeout: minutes(24 * 60),
            secure_cookies: true,
            session_vars: HashMap::new(),
            extend_on_activity: true,
        }
    }
}
671
/// Rate limits for authentication traffic.
#[derive(Clone, Debug)]
pub struct AuthRateLimitConfig {
    /// Whether rate limiting is on (default: `true`).
    pub enabled: bool,

    /// Attempts allowed per IP (default: 60).
    pub max_attempts_per_ip: u32,

    /// Failures allowed per IP (default: 10).
    pub max_failures_per_ip: u32,

    /// Lockout duration (default: 5 minutes).
    pub lockout_duration: Duration,

    /// Window length in seconds (default: 60).
    pub window_seconds: u64,

    /// Requests allowed per user (default: 120).
    pub max_requests_per_user: u32,

    /// Requests allowed per IP (default: 60).
    pub max_requests_per_ip: u32,
}

impl Default for AuthRateLimitConfig {
    fn default() -> Self {
        let per_ip: u32 = 60;
        AuthRateLimitConfig {
            enabled: true,
            max_attempts_per_ip: per_ip,
            max_failures_per_ip: 10,
            lockout_duration: Duration::from_secs(5 * 60),
            window_seconds: 60,
            max_requests_per_user: 2 * per_ip,
            max_requests_per_ip: per_ip,
        }
    }
}
710
/// Raw credential presented by a client, before validation.
///
/// `Debug` is written by hand so that tokens, API keys and passwords are never
/// printed. Only the variant, and the username for `Basic`, are shown.
#[derive(Clone)]
pub enum AuthType {
    /// Bearer JWT.
    Jwt(String),

    /// OAuth access token.
    OAuth(String),

    /// Username/password pair.
    Basic { username: String, password: String },

    /// API key.
    ApiKey(String),

    /// No credential supplied.
    None,
}

impl std::fmt::Debug for AuthType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Jwt(_) => f.debug_tuple("Jwt").field(&"<redacted>").finish(),
            Self::OAuth(_) => f.debug_tuple("OAuth").field(&"<redacted>").finish(),
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"<redacted>").finish(),
            Self::None => f.write_str("None"),
        }
    }
}
729
730#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
733pub enum AuthMethod {
734 Jwt,
736
737 OAuth,
739
740 Ldap,
742
743 ApiKey,
745
746 Basic,
748
749 Trust,
751
752 AgentToken,
754
755 Session,
757
758 Anonymous,
760}
761
762impl std::fmt::Display for AuthMethod {
763 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
764 match self {
765 Self::Jwt => write!(f, "jwt"),
766 Self::OAuth => write!(f, "oauth"),
767 Self::Ldap => write!(f, "ldap"),
768 Self::ApiKey => write!(f, "api_key"),
769 Self::Basic => write!(f, "basic"),
770 Self::Trust => write!(f, "trust"),
771 Self::AgentToken => write!(f, "agent_token"),
772 Self::Session => write!(f, "session"),
773 Self::Anonymous => write!(f, "anonymous"),
774 }
775 }
776}
777
778impl AuthMethod {
779 #[allow(clippy::should_implement_trait)]
781 pub fn from_str(s: &str) -> Option<Self> {
782 match s.to_lowercase().as_str() {
783 "jwt" => Some(Self::Jwt),
784 "oauth" => Some(Self::OAuth),
785 "ldap" => Some(Self::Ldap),
786 "api_key" | "apikey" => Some(Self::ApiKey),
787 "basic" => Some(Self::Basic),
788 "trust" => Some(Self::Trust),
789 "agent_token" | "agent" => Some(Self::AgentToken),
790 "session" => Some(Self::Session),
791 "anonymous" | "none" => Some(Self::Anonymous),
792 _ => None,
793 }
794 }
795}
796
797#[derive(Debug, Clone, Serialize, Deserialize)]
799pub struct Identity {
800 pub user_id: String,
802
803 pub name: Option<String>,
805
806 pub email: Option<String>,
808
809 pub roles: Vec<String>,
811
812 pub groups: Vec<String>,
814
815 pub tenant_id: Option<String>,
817
818 pub claims: HashMap<String, serde_json::Value>,
820
821 pub auth_method: String,
823
824 pub authenticated_at: chrono::DateTime<chrono::Utc>,
826}
827
828impl Identity {
829 pub fn new(user_id: impl Into<String>, auth_method: impl Into<String>) -> Self {
831 Self {
832 user_id: user_id.into(),
833 name: None,
834 email: None,
835 roles: Vec::new(),
836 groups: Vec::new(),
837 tenant_id: None,
838 claims: HashMap::new(),
839 auth_method: auth_method.into(),
840 authenticated_at: chrono::Utc::now(),
841 }
842 }
843
844 pub fn from_jwt_claims(claims: &JwtClaims) -> Self {
846 let mut identity = Self::new(&claims.sub, "jwt");
847 identity.name = claims.name.clone();
848 identity.email = claims.email.clone();
849 identity.roles = claims.roles.clone();
850 identity.tenant_id = claims.tenant_id.clone();
851 identity.claims = claims.custom.clone();
852 identity
853 }
854
855 pub fn has_role(&self, role: &str) -> bool {
857 self.roles.iter().any(|r| r == role)
858 }
859
860 pub fn in_group(&self, group: &str) -> bool {
862 self.groups.iter().any(|g| g == group)
863 }
864
865 pub fn is_admin(&self) -> bool {
867 self.has_role("admin") || self.has_role("db_admin")
868 }
869
870 pub fn get_claim(&self, name: &str) -> Option<&serde_json::Value> {
872 self.claims.get(name)
873 }
874
875 pub fn email_domain(&self) -> Option<&str> {
877 self.email.as_ref().and_then(|e| e.split('@').nth(1))
878 }
879
880 pub fn anonymous() -> Self {
882 Self {
883 user_id: "anonymous".to_string(),
884 name: None,
885 email: None,
886 roles: Vec::new(),
887 groups: Vec::new(),
888 tenant_id: None,
889 claims: HashMap::new(),
890 auth_method: "anonymous".to_string(),
891 authenticated_at: chrono::Utc::now(),
892 }
893 }
894}
895
896#[derive(Debug, Clone, Serialize, Deserialize)]
898pub struct JwtClaims {
899 pub sub: String,
901
902 pub iss: String,
904
905 pub aud: Option<Vec<String>>,
907
908 pub exp: i64,
910
911 pub iat: i64,
913
914 pub nbf: Option<i64>,
916
917 pub jti: Option<String>,
919
920 pub name: Option<String>,
922
923 pub email: Option<String>,
925
926 #[serde(default)]
928 pub roles: Vec<String>,
929
930 pub tenant_id: Option<String>,
932
933 #[serde(flatten)]
935 pub custom: HashMap<String, serde_json::Value>,
936}
937
938#[derive(Debug, Clone, Serialize, Deserialize)]
940pub struct AgentIdentity {
941 pub id: String,
943
944 pub agent_type: String,
946
947 pub allowed_tools: Vec<String>,
949
950 pub quota: AgentQuota,
952
953 pub conversation_id: Option<String>,
955
956 pub parent_identity: Option<String>,
958}
959
960#[derive(Debug, Clone, Serialize, Deserialize)]
962pub struct AgentQuota {
963 pub max_queries_per_conversation: u32,
965
966 pub max_rows_per_query: u32,
968
969 pub token_budget: u64,
971
972 pub allowed_tables: Option<Vec<String>>,
974}
975
976impl Default for AgentQuota {
977 fn default() -> Self {
978 Self {
979 max_queries_per_conversation: 100,
980 max_rows_per_query: 1000,
981 token_budget: 100000,
982 allowed_tables: None,
983 }
984 }
985}
986
/// Database access granted for a tool.
#[derive(Clone, Debug)]
pub struct ToolPermission {
    /// Database role used.
    pub db_role: String,

    /// Tables that may be accessed.
    pub allowed_tables: Vec<String>,

    /// Whether access is read-only.
    pub read_only: bool,
}
999
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_config_builder() {
        let jwks = JwtConfig::new("https://auth.example.com/.well-known/jwks.json");
        let admin_rule =
            RoleMappingRule::new(RoleCondition::jwt_claim("role", "admin"), "db_admin")
                .with_priority(100);

        let built = AuthConfig::builder()
            .jwt(jwks)
            .add_role_mapping(admin_rule)
            .default_role("db_minimal")
            .build();

        assert!(built.enabled);
        assert!(built.jwt.is_some());
        assert_eq!(built.role_mapping.len(), 1);
    }

    #[test]
    fn test_identity() {
        let who = Identity::new("user123", "jwt");

        assert_eq!(who.user_id, "user123");
        assert_eq!(who.auth_method, "jwt");
        assert!(!who.is_admin());
    }

    #[test]
    fn test_identity_roles() {
        let mut who = Identity::new("admin123", "jwt");
        who.roles
            .extend(["admin", "db_readwrite"].iter().map(|r| r.to_string()));

        assert!(who.is_admin());
        for held in &["admin", "db_readwrite"] {
            assert!(who.has_role(held));
        }
        assert!(!who.has_role("superuser"));
    }

    #[test]
    fn test_email_domain() {
        let mut who = Identity::new("user", "jwt");
        who.email = Some(String::from("alice@example.com"));

        assert_eq!(who.email_domain(), Some("example.com"));
    }

    #[test]
    fn test_credentials() {
        let one_hour = Duration::from_secs(3600);
        let creds = Credentials::new("dbuser", "password123").with_ttl(one_hour);

        assert_eq!(creds.username, "dbuser");
        assert!(creds.ttl.is_some());
    }

    #[test]
    fn test_role_mapping() {
        let condition = RoleCondition::group("developers");
        let rule = RoleMappingRule::new(condition, "db_readwrite").with_priority(50);

        assert_eq!(rule.db_role, "db_readwrite");
        assert_eq!(rule.priority, 50);
    }
}
1065}