1use crate::error::{Error, Result};
25use url::Url;
26
27#[cfg(feature = "api-key")]
28use api_keys_simplified::{
29 ApiKey, ApiKeyManagerV0, Environment, ExposeSecret, HashConfig, KeyConfig, KeyStatus,
30 SecureString,
31};
32
33#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
35pub struct OAuthConfig {
36 pub enabled: bool,
38 pub client_id: Option<String>,
40 pub client_secret: Option<String>,
42 pub redirect_uri: Option<String>,
44 pub authorization_endpoint: Option<String>,
46 pub token_endpoint: Option<String>,
48 pub scopes: Vec<String>,
50 pub provider: OAuthProvider,
52}
53
54#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, PartialEq)]
56pub enum OAuthProvider {
57 Custom,
59 GitHub,
61 Google,
63 Keycloak,
65}
66
67impl Default for OAuthConfig {
68 fn default() -> Self {
69 Self {
70 enabled: false,
71 client_id: None,
72 client_secret: None,
73 redirect_uri: None,
74 authorization_endpoint: None,
75 token_endpoint: None,
76 scopes: vec![
77 "openid".to_string(),
78 "profile".to_string(),
79 "email".to_string(),
80 ],
81 provider: OAuthProvider::Custom,
82 }
83 }
84}
85
86impl OAuthConfig {
87 #[must_use]
89 pub fn github(client_id: String, client_secret: String, redirect_uri: String) -> Self {
90 Self {
91 enabled: true,
92 client_id: Some(client_id),
93 client_secret: Some(client_secret),
94 redirect_uri: Some(redirect_uri),
95 authorization_endpoint: Some("https://github.com/login/oauth/authorize".to_string()),
96 token_endpoint: Some("https://github.com/login/oauth/access_token".to_string()),
97 scopes: vec!["read:user".to_string(), "user:email".to_string()],
98 provider: OAuthProvider::GitHub,
99 }
100 }
101
102 #[must_use]
104 pub fn google(client_id: String, client_secret: String, redirect_uri: String) -> Self {
105 Self {
106 enabled: true,
107 client_id: Some(client_id),
108 client_secret: Some(client_secret),
109 redirect_uri: Some(redirect_uri),
110 authorization_endpoint: Some(
111 "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
112 ),
113 token_endpoint: Some("https://oauth2.googleapis.com/token".to_string()),
114 scopes: vec![
115 "openid".to_string(),
116 "https://www.googleapis.com/auth/userinfo.profile".to_string(),
117 "https://www.googleapis.com/auth/userinfo.email".to_string(),
118 ],
119 provider: OAuthProvider::Google,
120 }
121 }
122
123 #[must_use]
125 pub fn keycloak(
126 client_id: String,
127 client_secret: String,
128 redirect_uri: String,
129 base_url: &str,
130 realm: &str,
131 ) -> Self {
132 let base = base_url.trim_end_matches('/');
133 Self {
134 enabled: true,
135 client_id: Some(client_id),
136 client_secret: Some(client_secret),
137 redirect_uri: Some(redirect_uri),
138 authorization_endpoint: Some(format!(
139 "{base}/realms/{realm}/protocol/openid-connect/auth"
140 )),
141 token_endpoint: Some(format!(
142 "{base}/realms/{realm}/protocol/openid-connect/token"
143 )),
144 scopes: vec![
145 "openid".to_string(),
146 "profile".to_string(),
147 "email".to_string(),
148 ],
149 provider: OAuthProvider::Keycloak,
150 }
151 }
152
153 pub fn validate(&self) -> Result<()> {
155 if !self.enabled {
156 return Ok(());
157 }
158
159 if self.client_id.is_none() {
160 return Err(Error::config("client_id", "is required"));
161 }
162
163 if self.client_secret.is_none() {
164 return Err(Error::config("client_secret", "is required"));
165 }
166
167 if self.redirect_uri.is_none() {
168 return Err(Error::config("redirect_uri", "is required"));
169 }
170
171 if self.authorization_endpoint.is_none() {
172 return Err(Error::config("authorization_endpoint", "is required"));
173 }
174
175 if self.token_endpoint.is_none() {
176 return Err(Error::config("token_endpoint", "is required"));
177 }
178
179 if let Some(uri) = &self.redirect_uri {
181 Url::parse(uri)
182 .map_err(|e| Error::config("redirect_uri", format!("Invalid URL: {e}")))?;
183 }
184
185 if let Some(endpoint) = &self.authorization_endpoint {
186 Url::parse(endpoint).map_err(|e| {
187 Error::config("authorization_endpoint", format!("Invalid URL: {e}"))
188 })?;
189 }
190
191 if let Some(endpoint) = &self.token_endpoint {
192 Url::parse(endpoint)
193 .map_err(|e| Error::config("token_endpoint", format!("Invalid URL: {e}")))?;
194 }
195
196 Ok(())
197 }
198
199 #[cfg(feature = "auth")]
201 pub fn to_mcp_config(&self) -> Result<()> {
202 if !self.enabled {
203 return Err(Error::config("oauth", "is not enabled"));
204 }
205
206 Ok(())
208 }
209
210 #[cfg(not(feature = "auth"))]
212 pub fn to_mcp_config(&self) -> Result<()> {
213 Err(Error::config("oauth", "feature is not enabled"))
214 }
215}
216
/// Output of `ApiKeyConfig::generate_key`.
///
/// `key` is the plaintext secret; `hash` is the form accepted in
/// `ApiKeyConfig::keys` for verification.
#[cfg(feature = "api-key")]
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct GeneratedApiKey {
    /// Plaintext key as produced by the key manager.
    pub key: String,
    /// Key identifier reported by the key manager.
    pub key_id: String,
    /// Hash of `key` as reported by the key manager.
    pub hash: String,
}
231
/// Settings for API-key authentication.
///
/// Entries in `keys` may be `$argon2…` hashes, `legacy:$argon2…` hashes, or
/// plaintext keys (accepted through a plaintext fallback). Because plaintext
/// entries are secrets, `Debug` reports only how many keys are configured.
#[cfg(feature = "api-key")]
#[derive(Clone, serde::Deserialize, serde::Serialize)]
pub struct ApiKeyConfig {
    /// Whether API-key authentication is active; when `false`, every key is accepted.
    pub enabled: bool,
    /// Accepted key material (hashed or plaintext, see above).
    pub keys: Vec<String>,
    /// Header carrying the key; defaults to `X-API-Key`.
    #[serde(default = "default_header_name")]
    pub header_name: String,
    /// Query parameter carrying the key; defaults to `api_key`.
    #[serde(default = "default_query_param_name")]
    pub query_param_name: String,
    /// Whether the query parameter may be used; defaults to `false`.
    #[serde(default)]
    pub allow_query_param: bool,
    /// Prefix handed to the key manager; defaults to `sk`.
    #[serde(default = "default_key_prefix")]
    pub key_prefix: String,
}

#[cfg(feature = "api-key")]
impl std::fmt::Debug for ApiKeyConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyConfig")
            .field("enabled", &self.enabled)
            // Plaintext keys are secrets: show the count, never the values.
            .field("keys", &format_args!("<{} redacted>", self.keys.len()))
            .field("header_name", &self.header_name)
            .field("query_param_name", &self.query_param_name)
            .field("allow_query_param", &self.allow_query_param)
            .field("key_prefix", &self.key_prefix)
            .finish()
    }
}
258
/// Serde default for `ApiKeyConfig::header_name`.
#[cfg(feature = "api-key")]
fn default_header_name() -> String {
    String::from("X-API-Key")
}
263
/// Serde default for `ApiKeyConfig::query_param_name`.
#[cfg(feature = "api-key")]
fn default_query_param_name() -> String {
    String::from("api_key")
}
268
/// Serde default for `ApiKeyConfig::key_prefix`.
#[cfg(feature = "api-key")]
fn default_key_prefix() -> String {
    String::from("sk")
}
273
#[cfg(feature = "api-key")]
impl Default for ApiKeyConfig {
    /// Disabled, with no keys, query-parameter lookup off, and the same
    /// defaults serde uses for the header name, query parameter and key prefix.
    fn default() -> Self {
        Self {
            key_prefix: default_key_prefix(),
            allow_query_param: false,
            query_param_name: default_query_param_name(),
            header_name: default_header_name(),
            keys: vec![],
            enabled: false,
        }
    }
}
287
#[cfg(feature = "api-key")]
impl ApiKeyConfig {
    /// Builds the manager used to generate keys and to verify `$argon2…` hashes.
    fn manager(&self) -> Result<ApiKeyManagerV0> {
        ApiKeyManagerV0::init_default_config(self.key_prefix.clone())
            .map_err(|e| Error::initialization("api_key_manager", e.to_string()))
    }

    /// Builds the manager used for `legacy:$argon2…` hashes: same prefix, but
    /// with the key checksum disabled.
    // NOTE(review): presumably the checksum is disabled because legacy keys were
    // not produced by the generator and carry no checksum — confirm.
    fn legacy_manager(&self) -> Result<ApiKeyManagerV0> {
        ApiKeyManagerV0::init(
            self.key_prefix.clone(),
            KeyConfig::default().disable_checksum(),
            HashConfig::default(),
            // NOTE(review): the meaning of this duration is defined by
            // `ApiKeyManagerV0::init` — confirm against the crate docs.
            std::time::Duration::from_secs(10),
        )
        .map_err(|e| Error::initialization("api_key_manager", e.to_string()))
    }

    /// Returns `true` for values in Argon2 PHC form (`$argon2…`).
    fn looks_like_hash(value: &str) -> bool {
        value.starts_with("$argon2")
    }

    /// Returns `true` for legacy hashes, stored as `legacy:` followed by a PHC string.
    fn looks_like_legacy_hash(value: &str) -> bool {
        value.starts_with("legacy:$argon2")
    }

    /// Compares a provided key with a plaintext stored key using `SecureString` equality.
    // NOTE(review): the comparison comes from `SecureStringExt`; presumably
    // constant-time — confirm in `api_keys_simplified`.
    fn verify_plaintext_fallback(key: &str, stored_key: &str) -> bool {
        use api_keys_simplified::SecureStringExt;

        let provided = SecureString::from(key.to_string());
        let expected = SecureString::from(stored_key.to_string());

        provided.eq(&expected)
    }

    /// Hashes a plaintext key with the legacy manager and tags it `legacy:`.
    fn hash_legacy_key(&self, key: &str) -> Result<String> {
        let manager = self.legacy_manager()?;
        // A throwaway key is generated only so its PHC string can be passed to
        // `into_hashed_with_phc`.
        // NOTE(review): presumably this is how a fresh salt/parameter set is
        // obtained; generating and hashing a whole key for it is costly — check
        // whether the crate exposes a cheaper source.
        let seed = self.generate_key()?;
        let secure = SecureString::from(key.to_string());
        let hasher = manager.hasher();
        let api_key = ApiKey::new(secure)
            .into_hashed_with_phc(hasher, &seed.hash)
            .map_err(|e| Error::initialization("api_key_hashing", e.to_string()))?;
        Ok(format!("legacy:{}", api_key.expose_hash().hash()))
    }

    /// Checks that an enabled configuration is usable.
    ///
    /// A disabled configuration is always valid. Configuring no keys, or empty
    /// entries in `keys`, only logs a warning.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if `header_name` or `key_prefix` is empty,
    /// or if `allow_query_param` is set while `query_param_name` is empty, and an
    /// initialization error if the key manager cannot be built.
    pub fn validate(&self) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }

        if self.keys.is_empty() {
            tracing::warn!("API key authentication is enabled but no keys are configured");
        } else if self.keys.iter().any(String::is_empty) {
            // Empty entries are ignored by `is_valid_key`, so flag the likely misconfiguration.
            tracing::warn!("API key configuration contains empty entries; they will never match");
        }

        if self.header_name.is_empty() {
            return Err(Error::config("header_name", "cannot be empty"));
        }

        if self.allow_query_param && self.query_param_name.is_empty() {
            return Err(Error::config(
                "query_param_name",
                "cannot be empty when allow_query_param is true",
            ));
        }

        if self.key_prefix.is_empty() {
            return Err(Error::config("key_prefix", "cannot be empty"));
        }

        let _ = self.manager()?;

        Ok(())
    }

    /// Returns whether `key` matches any configured key.
    ///
    /// Always `true` when authentication is disabled. When enabled, an empty
    /// `key` is rejected outright. Otherwise each stored entry is checked
    /// according to its form: a legacy hash, a current hash, or plaintext.
    /// A manager that fails to initialise makes the corresponding hashed
    /// entries unmatchable rather than erroring.
    #[must_use]
    pub fn is_valid_key(&self, key: &str) -> bool {
        if !self.enabled {
            return true;
        }

        // An empty credential must never authenticate. Without this guard, an
        // empty entry in `keys` (e.g. from an unset environment variable) would
        // match an empty header through the plaintext fallback.
        if key.is_empty() {
            return false;
        }

        let manager = self.manager().ok();
        let legacy_manager = self.legacy_manager().ok();
        let provided_key = SecureString::from(key.to_string());

        self.keys.iter().any(|stored| {
            if Self::looks_like_legacy_hash(stored) {
                if let Some(legacy_manager) = &legacy_manager {
                    let stored_hash = stored.trim_start_matches("legacy:");
                    matches!(
                        legacy_manager.verify(&provided_key, stored_hash),
                        Ok(KeyStatus::Valid)
                    )
                } else {
                    false
                }
            } else if Self::looks_like_hash(stored) {
                if let Some(manager) = &manager {
                    matches!(manager.verify(&provided_key, stored), Ok(KeyStatus::Valid))
                } else {
                    false
                }
            } else {
                Self::verify_plaintext_fallback(key, stored)
            }
        })
    }

    /// Generates a new key (production environment) and returns it with its id and hash.
    ///
    /// # Errors
    ///
    /// Returns an initialization error if the key manager cannot be built or
    /// key generation fails.
    pub fn generate_key(&self) -> Result<GeneratedApiKey> {
        let manager = self.manager()?;

        let key = manager
            .generate(Environment::production())
            .map_err(|e| Error::initialization("api_key_generation", e.to_string()))?;

        Ok(GeneratedApiKey {
            key: key.key().expose_secret().to_string(),
            key_id: key.expose_hash().key_id().to_owned(),
            hash: key.expose_hash().hash().to_owned(),
        })
    }

    /// Returns `key` unchanged if it already looks like a (legacy) hash;
    /// otherwise hashes it as a legacy key.
    ///
    /// # Errors
    ///
    /// Propagates errors from building the managers or hashing.
    pub fn normalize_key_material(&self, key: &str) -> Result<String> {
        if Self::looks_like_hash(key) || Self::looks_like_legacy_hash(key) {
            Ok(key.to_string())
        } else {
            self.hash_legacy_key(key)
        }
    }
}
436
437#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
439pub struct AuthConfig {
440 pub oauth: OAuthConfig,
442 #[cfg(feature = "api-key")]
444 pub api_key: ApiKeyConfig,
445}
446
447impl AuthConfig {
448 pub fn validate(&self) -> Result<()> {
450 self.oauth.validate()?;
451 #[cfg(feature = "api-key")]
452 self.api_key.validate()?;
453 Ok(())
454 }
455
456 #[must_use]
458 #[cfg(feature = "api-key")]
459 pub fn is_enabled(&self) -> bool {
460 self.oauth.enabled || self.api_key.enabled
461 }
462
463 #[must_use]
465 #[cfg(not(feature = "api-key"))]
466 pub fn is_enabled(&self) -> bool {
467 self.oauth.enabled
468 }
469}
470
/// The mechanism through which a request was authenticated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthProvider {
    /// Not authenticated.
    None,
    /// Authenticated via OAuth.
    OAuth,
    /// Authenticated via an API key.
    #[cfg(feature = "api-key")]
    ApiKey,
}
482
483#[derive(Debug, Clone)]
485pub struct AuthContext {
486 pub provider: AuthProvider,
488 pub user_id: Option<String>,
490 pub user_email: Option<String>,
492 #[cfg(feature = "api-key")]
494 pub api_key_id: Option<String>,
495}
496
497impl AuthContext {
498 #[must_use]
500 pub fn new(provider: AuthProvider) -> Self {
501 Self {
502 provider,
503 user_id: None,
504 user_email: None,
505 #[cfg(feature = "api-key")]
506 api_key_id: None,
507 }
508 }
509
510 #[must_use]
512 pub fn is_authenticated(&self) -> bool {
513 !matches!(self.provider, AuthProvider::None)
514 }
515}
516
517#[derive(Default)]
519pub struct AuthManager {
520 config: OAuthConfig,
521 #[cfg(feature = "api-key")]
522 api_key_config: ApiKeyConfig,
523}
524
525impl AuthManager {
526 pub fn new(config: OAuthConfig) -> Result<Self> {
528 config.validate()?;
529 Ok(Self {
530 config,
531 #[cfg(feature = "api-key")]
532 api_key_config: ApiKeyConfig::default(),
533 })
534 }
535
536 #[cfg(feature = "api-key")]
538 pub fn with_config(config: AuthConfig) -> Result<Self> {
539 config.validate()?;
540 Ok(Self {
541 config: config.oauth,
542 api_key_config: config.api_key,
543 })
544 }
545
546 #[must_use]
548 #[cfg(feature = "api-key")]
549 pub fn is_enabled(&self) -> bool {
550 self.config.enabled || self.api_key_config.enabled
551 }
552
553 #[must_use]
555 #[cfg(not(feature = "api-key"))]
556 pub fn is_enabled(&self) -> bool {
557 self.config.enabled
558 }
559
560 #[must_use]
562 pub fn config(&self) -> &OAuthConfig {
563 &self.config
564 }
565
566 #[cfg(feature = "api-key")]
568 #[must_use]
569 pub fn api_key_config(&self) -> &ApiKeyConfig {
570 &self.api_key_config
571 }
572
573 #[cfg(feature = "api-key")]
575 #[must_use]
576 pub fn validate_api_key(&self, key: &str) -> bool {
577 self.api_key_config.is_valid_key(key)
578 }
579
580 #[cfg(feature = "api-key")]
586 pub fn generate_api_key(&self) -> Result<GeneratedApiKey> {
587 self.api_key_config.generate_key()
588 }
589
590 #[cfg(feature = "api-key")]
592 #[must_use]
593 pub fn extract_api_key_from_headers(
594 &self,
595 headers: &std::collections::HashMap<String, String>,
596 ) -> Option<String> {
597 headers.get(&self.api_key_config.header_name).cloned()
598 }
599}
600
601#[derive(Default)]
603pub struct TokenStore {
604 tokens: std::sync::RwLock<std::collections::HashMap<String, TokenInfo>>,
605}
606
607#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
609pub struct TokenInfo {
610 pub access_token: String,
612 pub refresh_token: Option<String>,
614 pub expires_at: chrono::DateTime<chrono::Utc>,
616 pub scopes: Vec<String>,
618 pub user_id: Option<String>,
620 pub user_email: Option<String>,
622}
623
624impl TokenStore {
625 #[must_use]
627 pub fn new() -> Self {
628 Self::default()
629 }
630
631 pub fn store_token(&self, key: String, token: TokenInfo) {
633 let mut tokens = self.tokens.write().unwrap();
634 tokens.insert(key, token);
635 }
636
637 pub fn get_token(&self, key: &str) -> Option<TokenInfo> {
639 let tokens = self.tokens.read().unwrap();
640 tokens.get(key).cloned()
641 }
642
643 pub fn remove_token(&self, key: &str) {
645 let mut tokens = self.tokens.write().unwrap();
646 tokens.remove(key);
647 }
648
649 pub fn cleanup_expired(&self) {
651 let now = chrono::Utc::now();
652 let mut tokens = self.tokens.write().unwrap();
653 tokens.retain(|_, token| token.expires_at > now);
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_oauth_config_default() {
        let config = OAuthConfig::default();
        assert!(!config.enabled);
        assert!(config.client_id.is_none());
    }

    #[test]
    fn test_oauth_config_github() {
        let config = OAuthConfig::github(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost:8080/callback".to_string(),
        );
        assert!(config.enabled);
        assert_eq!(config.provider, OAuthProvider::GitHub);
    }

    #[test]
    fn test_oauth_config_keycloak_builds_realm_endpoints() {
        // Trailing slashes on the base URL must not produce `//realms`.
        let config = OAuthConfig::keycloak(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost:8080/callback".to_string(),
            "https://sso.example.com//",
            "demo",
        );
        assert_eq!(config.provider, OAuthProvider::Keycloak);
        assert_eq!(
            config.authorization_endpoint.as_deref(),
            Some("https://sso.example.com/realms/demo/protocol/openid-connect/auth")
        );
        assert_eq!(
            config.token_endpoint.as_deref(),
            Some("https://sso.example.com/realms/demo/protocol/openid-connect/token")
        );
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_oauth_config_validate() {
        let config = OAuthConfig::default();
        assert!(config.validate().is_ok());

        let config = OAuthConfig {
            enabled: true,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_oauth_config_validate_rejects_invalid_url() {
        let config = OAuthConfig::github(
            "client_id".to_string(),
            "client_secret".to_string(),
            "not a url".to_string(),
        );
        assert!(config.validate().is_err());
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_config_default() {
        let config = ApiKeyConfig::default();
        assert!(!config.enabled);
        assert!(config.keys.is_empty());
        assert_eq!(config.header_name, "X-API-Key");
        assert_eq!(config.key_prefix, "sk");
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_config_validate() {
        let config = ApiKeyConfig::default();
        assert!(config.validate().is_ok());

        let config = ApiKeyConfig {
            enabled: true,
            header_name: String::new(),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_is_valid() {
        let key_config = ApiKeyConfig {
            enabled: true,
            ..Default::default()
        };
        let generated = key_config.generate_key().unwrap();

        let config = ApiKeyConfig {
            enabled: true,
            keys: vec![generated.hash.clone()],
            ..Default::default()
        };

        assert!(config.is_valid_key(&generated.key));
        assert!(!config.is_valid_key("invalid_key"));
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_disabled_allows_all() {
        let config = ApiKeyConfig::default();
        assert!(!config.enabled);

        assert!(config.is_valid_key("any_key"));
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_plaintext_fallback() {
        let config = ApiKeyConfig {
            enabled: true,
            keys: vec!["legacy_plaintext_key".to_string()],
            ..Default::default()
        };

        assert!(config.is_valid_key("legacy_plaintext_key"));
        assert!(!config.is_valid_key("legacy_plaintext_key_2"));
        assert!(!config.is_valid_key(""));
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_legacy_hashed_verification() {
        let config = ApiKeyConfig::default();
        let legacy_hash = config
            .normalize_key_material("legacy_plaintext_key")
            .unwrap();

        let enabled_config = ApiKeyConfig {
            enabled: true,
            keys: vec![legacy_hash],
            ..Default::default()
        };

        assert!(enabled_config.is_valid_key("legacy_plaintext_key"));
        assert!(!enabled_config.is_valid_key("legacy_plaintext_key_2"));
    }

    #[cfg(feature = "api-key")]
    #[test]
    fn test_api_key_generate_key_returns_hash_and_key() {
        let config = ApiKeyConfig::default();
        let generated = config.generate_key().unwrap();

        // The previous `"sk-" || "sk_" || "sk"` check reduced to the plain prefix.
        assert!(generated.key.starts_with(&config.key_prefix));
        assert!(!generated.key_id.is_empty());
        assert!(generated.hash.starts_with("$argon2"));
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(!config.is_enabled());
    }

    #[test]
    fn test_auth_context() {
        let ctx = AuthContext::new(AuthProvider::None);
        assert!(!ctx.is_authenticated());

        let ctx = AuthContext::new(AuthProvider::OAuth);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn test_token_store() {
        let store = TokenStore::new();
        let token = TokenInfo {
            access_token: "test_token".to_string(),
            refresh_token: None,
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            scopes: vec!["read".to_string()],
            user_id: None,
            user_email: None,
        };

        store.store_token("key".to_string(), token.clone());
        assert!(store.get_token("key").is_some());
        assert!(store.get_token("nonexistent").is_none());

        store.remove_token("key");
        assert!(store.get_token("key").is_none());
    }

    #[test]
    fn test_token_store_cleanup_expired() {
        let store = TokenStore::new();
        let make = |offset: chrono::Duration| TokenInfo {
            access_token: "token".to_string(),
            refresh_token: None,
            expires_at: chrono::Utc::now() + offset,
            scopes: Vec::new(),
            user_id: None,
            user_email: None,
        };

        store.store_token("live".to_string(), make(chrono::Duration::hours(1)));
        store.store_token("stale".to_string(), make(-chrono::Duration::hours(1)));

        store.cleanup_expired();

        assert!(store.get_token("live").is_some());
        assert!(store.get_token("stale").is_none());
    }
}