Skip to main content

openauth_core/
context.rs

1//! Request and runtime context contracts.
2
3pub mod request_state;
4
5mod builder;
6mod origins;
7mod plugins;
8mod secrets;
9
10use crate::auth::trusted_origins::{matches_origin_pattern, OriginMatchSettings};
11use crate::cookies::AuthCookies;
12use crate::db::{DbAdapter, DbSchema};
13use crate::env::logger::Logger;
14use crate::error::OpenAuthError;
15use crate::options::{
16    BackgroundTaskFuture, BackgroundTaskRunner, DynamicRateLimitPathRule, HybridRateLimitOptions,
17    OpenAuthOptions, RateLimitPathRule, RateLimitStorageOption, RateLimitStore, SecondaryStorage,
18};
19use crate::plugin::{AuthPlugin, PluginErrorCode};
20use crate::rate_limit::GovernorMemoryRateLimitStore;
21use http::Request;
22use openauth_oauth::oauth2::SocialOAuthProvider;
23use std::collections::BTreeMap;
24use std::fmt;
25use std::sync::Arc;
26use std::time::Duration;
27
28pub use builder::{
29    create_auth_context, create_auth_context_with_adapter, create_auth_context_with_environment,
30    create_auth_context_with_environment_and_adapter,
31};
32pub use secrets::SecretMaterial;
33
34use origins::push_trusted_origin;
35
36#[derive(Clone)]
37pub struct AuthContext {
38    pub app_name: String,
39    pub base_url: String,
40    pub base_path: String,
41    pub options: OpenAuthOptions,
42    pub auth_cookies: AuthCookies,
43    pub session_config: SessionConfig,
44    pub secret: String,
45    pub secret_config: SecretMaterial,
46    pub password: PasswordContext,
47    pub rate_limit: RateLimitContext,
48    pub trusted_origins: Vec<String>,
49    pub disabled_paths: Vec<String>,
50    pub plugins: Vec<AuthPlugin>,
51    pub adapter: Option<Arc<dyn DbAdapter>>,
52    pub secondary_storage: Option<Arc<dyn SecondaryStorage>>,
53    pub background_tasks: Option<Arc<dyn BackgroundTaskRunner>>,
54    pub social_providers: BTreeMap<String, Arc<dyn SocialOAuthProvider>>,
55    pub db_schema: DbSchema,
56    pub plugin_error_codes: BTreeMap<String, PluginErrorCode>,
57    pub plugin_database_hooks: Vec<crate::plugin::PluginDatabaseHook>,
58    pub plugin_migrations: Vec<crate::plugin::PluginMigration>,
59    pub logger: Logger,
60}
61
/// Environment values read during context initialization.
///
/// `Debug` is written by hand so secret values never leak into logs. A set value prints as
/// `Some("<redacted>")` and an unset one as `None`.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct AuthEnvironment {
    /// Populated from `OPENAUTH_SECRET` by [`AuthEnvironment::from_process`].
    pub openauth_secret: Option<String>,
    /// Populated from `OPENAUTH_SECRETS` by [`AuthEnvironment::from_process`].
    pub openauth_secrets: Option<String>,
}

impl fmt::Debug for AuthEnvironment {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reveal only whether each secret is present, never its contents.
        const REDACTED: &str = "<redacted>";
        let secret = self.openauth_secret.as_ref().map(|_| REDACTED);
        let secrets = self.openauth_secrets.as_ref().map(|_| REDACTED);
        f.debug_struct("AuthEnvironment")
            .field("openauth_secret", &secret)
            .field("openauth_secrets", &secrets)
            .finish()
    }
}

impl AuthEnvironment {
    /// Captures `OPENAUTH_SECRET` and `OPENAUTH_SECRETS` from the current process environment.
    ///
    /// A variable becomes `None` when it is unset or its value is not valid Unicode, because
    /// `std::env::var` fails in both cases. An empty value is kept as `Some("")`.
    pub fn from_process() -> Self {
        let lookup = |name: &str| std::env::var(name).ok();
        Self {
            openauth_secret: lookup("OPENAUTH_SECRET"),
            openauth_secrets: lookup("OPENAUTH_SECRETS"),
        }
    }
}
93
/// Session lifetime settings.
///
/// NOTE(review): the `u64` durations are presumably seconds. Confirm against the code that
/// builds and consumes this config.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionConfig {
    /// Session update age (semantics defined by the session code, not visible here).
    pub update_age: u64,
    /// Session expiry duration.
    pub expires_in: u64,
    /// Window within which a session counts as "fresh".
    pub fresh_age: u64,
    /// Cookie refresh-cache flag.
    pub cookie_refresh_cache: bool,
}
101
102#[derive(Clone)]
103pub struct PasswordContext {
104    pub config: PasswordPolicy,
105    pub hash: fn(&str) -> Result<String, OpenAuthError>,
106    pub verify: fn(&str, &str) -> Result<bool, OpenAuthError>,
107}
108
/// Bounds on acceptable password length.
///
/// NOTE(review): whether lengths count bytes or characters, and whether the bounds are
/// inclusive, is decided by the enforcing code, which is not visible here.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PasswordPolicy {
    /// Lower length bound.
    pub min_password_length: usize,
    /// Upper length bound.
    pub max_password_length: usize,
}
114
115#[derive(Clone)]
116pub struct RateLimitContext {
117    pub enabled: bool,
118    pub window: u64,
119    pub max: u64,
120    pub storage: RateLimitStorageOption,
121    pub custom_rules: Vec<RateLimitPathRule>,
122    pub dynamic_rules: Vec<DynamicRateLimitPathRule>,
123    pub plugin_rules: Vec<crate::plugin::PluginRateLimitRule>,
124    pub custom_store: Option<Arc<dyn RateLimitStore>>,
125    pub hybrid: HybridRateLimitOptions,
126    pub memory_cleanup_interval: Option<Duration>,
127    pub memory_store: Arc<GovernorMemoryRateLimitStore>,
128}
129
130impl AuthContext {
131    pub fn adapter(&self) -> Option<Arc<dyn DbAdapter>> {
132        self.adapter.clone()
133    }
134
135    pub fn secondary_storage(&self) -> Option<Arc<dyn SecondaryStorage>> {
136        self.secondary_storage.clone()
137    }
138
139    pub fn run_background_task(&self, task: BackgroundTaskFuture) -> bool {
140        let Some(runner) = &self.background_tasks else {
141            return false;
142        };
143        runner.spawn(task);
144        true
145    }
146
147    pub fn social_provider(&self, id: &str) -> Option<Arc<dyn SocialOAuthProvider>> {
148        self.social_providers.get(id).cloned()
149    }
150
151    pub fn has_plugin(&self, id: &str) -> bool {
152        self.plugins.iter().any(|plugin| plugin.id == id)
153    }
154
155    pub fn is_trusted_origin(&self, url: &str, settings: Option<OriginMatchSettings>) -> bool {
156        self.trusted_origins
157            .iter()
158            .any(|origin| matches_origin_pattern(url, origin, settings))
159    }
160
161    pub fn trusted_origins_for_request(
162        &self,
163        request: Option<&Request<Vec<u8>>>,
164    ) -> Result<Vec<String>, OpenAuthError> {
165        let mut origins = self.trusted_origins.clone();
166        if let Some(provider) = self.options.trusted_origins.provider() {
167            for origin in provider.trusted_origins(request)? {
168                push_trusted_origin(&mut origins, origin);
169            }
170        }
171        Ok(origins)
172    }
173
174    pub fn is_trusted_origin_for_request(
175        &self,
176        url: &str,
177        settings: Option<OriginMatchSettings>,
178        request: Option<&Request<Vec<u8>>>,
179    ) -> Result<bool, OpenAuthError> {
180        Ok(self
181            .trusted_origins_for_request(request)?
182            .iter()
183            .any(|origin| matches_origin_pattern(url, origin, settings)))
184    }
185}