oauth2_broker/flows/
common.rs1use crate::{
5 _prelude::*,
6 auth::{PrincipalId, ScopeSet, TenantId, TokenRecord, TokenRecordBuilderError},
7 error::ConfigError,
8 flows::Broker,
9 http::TokenHttpClient,
10 oauth::TransportErrorMapper,
11 store::StoreKey,
12};
13
14#[derive(Clone, Debug)]
17pub struct CachedTokenRequest {
18 pub tenant: TenantId,
20 pub principal: PrincipalId,
22 pub scope: ScopeSet,
24 pub force: bool,
26 pub preemptive_window: Duration,
28}
29impl CachedTokenRequest {
30 const DEFAULT_PREEMPTIVE_WINDOW: Duration = Duration::seconds(60);
31
32 pub fn new(tenant: TenantId, principal: PrincipalId, scope: ScopeSet) -> Self {
34 Self {
35 tenant,
36 principal,
37 scope,
38 force: false,
39 preemptive_window: Self::DEFAULT_PREEMPTIVE_WINDOW,
40 }
41 }
42
43 pub fn force_refresh(mut self) -> Self {
45 self.force = true;
46
47 self
48 }
49
50 pub fn with_force(mut self, force: bool) -> Self {
52 self.force = force;
53
54 self
55 }
56
57 pub fn with_preemptive_window(mut self, window: Duration) -> Self {
59 self.preemptive_window = if window.is_negative() { Duration::ZERO } else { window };
60
61 self
62 }
63
64 pub fn should_refresh(&self, record: &TokenRecord, now: OffsetDateTime) -> bool {
66 if self.force || record.is_revoked() || record.is_expired_at(now) {
67 return true;
68 }
69
70 let effective_window = self.effective_preemptive_window();
71
72 if effective_window.is_zero() {
73 return false;
74 }
75
76 let remaining = record.expires_at - now;
77
78 remaining <= effective_window
79 }
80
81 fn effective_preemptive_window(&self) -> Duration {
82 self.preemptive_window.checked_sub(self.preemptive_jitter()).unwrap_or(Duration::ZERO)
83 }
84
85 fn preemptive_jitter(&self) -> Duration {
86 let window_secs = self.preemptive_window.whole_seconds();
87
88 if window_secs <= 1 {
89 return Duration::ZERO;
90 }
91
92 let modulus = u64::try_from(window_secs).unwrap_or(u64::MAX);
93
94 if modulus == 0 {
95 return Duration::ZERO;
96 }
97
98 let jitter_secs = self.jitter_seed() % modulus;
99
100 if jitter_secs == 0 {
101 return Duration::ZERO;
102 }
103
104 let clamped = i64::try_from(jitter_secs).unwrap_or(i64::MAX);
105
106 Duration::seconds(clamped)
107 }
108
109 fn jitter_seed(&self) -> u64 {
110 let mut hasher = DefaultHasher::new();
111
112 self.tenant.hash(&mut hasher);
113 self.principal.hash(&mut hasher);
114 self.scope.hash(&mut hasher);
115
116 hasher.finish()
117 }
118}
119
120pub(crate) fn format_scope(scope: &ScopeSet, delimiter: char) -> Option<String> {
122 if scope.is_empty() {
123 return None;
124 }
125 if delimiter == ' ' {
126 return Some(scope.normalized());
127 }
128
129 let mut buf = String::new();
130
131 for (idx, value) in scope.iter().enumerate() {
132 if idx > 0 {
133 buf.push(delimiter);
134 }
135
136 buf.push_str(value);
137 }
138
139 Some(buf)
140}
141
142pub(crate) fn flow_guard<C, M>(broker: &Broker<C, M>, key: &StoreKey) -> Arc<AsyncMutex<()>>
144where
145 C: ?Sized + TokenHttpClient,
146 M: ?Sized + TransportErrorMapper<C::TransportError>,
147{
148 let mut guards = broker.flow_guards.lock();
149
150 guards.entry(key.clone()).or_insert_with(|| Arc::new(AsyncMutex::new(()))).clone()
151}
152
153pub(crate) fn map_token_builder_error(err: TokenRecordBuilderError) -> Error {
155 ConfigError::from(err).into()
156}
157
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports, including `ScopeSet`.
    use super::*;

    /// `format_scope` honours both the space delimiter and custom delimiters.
    #[test]
    fn scope_formatting_handles_custom_delimiters() {
        let scope = ScopeSet::new(["email", "profile"]).expect("Failed to build test scope.");
        let cases = [(' ', "email profile"), (',', "email,profile")];

        for (delimiter, expected) in cases {
            assert_eq!(format_scope(&scope, delimiter).as_deref(), Some(expected));
        }
    }
}