oauth2_broker/provider/
strategy.rs1use std::collections::BTreeMap;
8use crate::{_prelude::*, provider::descriptor::GrantType};
10
11pub trait ProviderStrategy: Send + Sync {
18 fn classify_token_error(&self, ctx: &ProviderErrorContext) -> ProviderErrorKind;
20
21 fn augment_token_request(&self, _grant: GrantType, _form: &mut BTreeMap<String, String>) {}
28}
29
/// Outcome categories a [`ProviderStrategy`] can assign to a failed token request.
///
/// `Hash` is derived alongside the equality traits so kinds can key hash-based collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ProviderErrorKind {
    /// Reported by [`DefaultProviderStrategy`] for `invalid_grant` / `access_denied` codes and
    /// for HTTP 400, 404, and 410 when nothing more specific is available.
    InvalidGrant,
    /// Reported by [`DefaultProviderStrategy`] for `invalid_client` / `unauthorized_client`
    /// codes and for HTTP 401.
    InvalidClient,
    /// Reported by [`DefaultProviderStrategy`] for `invalid_scope` / `insufficient_scope` codes
    /// and for HTTP 403.
    InsufficientScope,
    /// Reported by [`DefaultProviderStrategy`] for network failures, `temporarily_unavailable` /
    /// `server_error` codes, and any status that is not mapped to another kind.
    Transient,
}
42
43#[derive(Clone, Debug, PartialEq, Eq)]
50pub struct ProviderErrorContext {
51 pub grant_type: GrantType,
53 pub http_status: Option<u16>,
55 pub oauth_error: Option<String>,
57 pub error_description: Option<String>,
59 pub body_preview: Option<String>,
61 pub network_error: bool,
63}
64impl ProviderErrorContext {
65 const BODY_PREVIEW_LIMIT: usize = 256;
66
67 pub fn new(grant_type: GrantType) -> Self {
69 Self {
70 grant_type,
71 http_status: None,
72 oauth_error: None,
73 error_description: None,
74 body_preview: None,
75 network_error: false,
76 }
77 }
78
79 pub fn network_failure(grant_type: GrantType) -> Self {
81 let mut ctx = Self::new(grant_type);
82
83 ctx.network_error = true;
84
85 ctx
86 }
87
88 pub fn with_network_error(mut self, network_error: bool) -> Self {
90 self.network_error = network_error;
91
92 self
93 }
94
95 pub fn with_http_status(mut self, status: u16) -> Self {
97 self.http_status = Some(status);
98
99 self
100 }
101
102 pub fn with_oauth_error(mut self, error: impl Into<String>) -> Self {
104 self.oauth_error = Some(error.into());
105
106 self
107 }
108
109 pub fn with_error_description(mut self, description: impl Into<String>) -> Self {
111 self.error_description = Some(description.into());
112
113 self
114 }
115
116 pub fn with_body_preview(mut self, body: impl Into<String>) -> Self {
118 self.body_preview = Some(truncate_preview(body.into()));
119
120 self
121 }
122}
123
/// Stateless [`ProviderStrategy`] that classifies errors from the context alone.
///
/// It keeps the trait's no-op `augment_token_request` and displays as
/// `default-provider-strategy`. Being a unit struct, it also derives the cheap value traits
/// (`Clone`, `Copy`, equality, `Hash`) so it can be copied, compared, and stored freely.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct DefaultProviderStrategy;
131impl Display for DefaultProviderStrategy {
132 fn fmt(&self, f: &mut Formatter) -> FmtResult {
133 f.write_str("default-provider-strategy")
134 }
135}
136impl ProviderStrategy for DefaultProviderStrategy {
137 fn classify_token_error(&self, ctx: &ProviderErrorContext) -> ProviderErrorKind {
138 if ctx.network_error {
139 return ProviderErrorKind::Transient;
140 }
141
142 if let Some(kind) =
143 classify_oauth_error(ctx.oauth_error.as_deref(), ctx.error_description.as_deref())
144 {
145 return kind;
146 }
147 if let Some(kind) = classify_body(ctx.body_preview.as_deref()) {
148 return kind;
149 }
150
151 classify_status(ctx.http_status)
152 }
153}
154
155fn truncate_preview(body: String) -> String {
156 if body.chars().count() <= ProviderErrorContext::BODY_PREVIEW_LIMIT {
157 return body;
158 }
159
160 let mut buf = String::new();
161
162 for (idx, ch) in body.chars().enumerate() {
163 if idx >= ProviderErrorContext::BODY_PREVIEW_LIMIT {
164 buf.push('…');
165
166 break;
167 }
168 buf.push(ch);
169 }
170
171 buf
172}
173
174fn classify_oauth_error(
175 oauth_error: Option<&str>,
176 error_description: Option<&str>,
177) -> Option<ProviderErrorKind> {
178 oauth_error
179 .and_then(match_exact_value)
180 .or_else(|| error_description.and_then(match_exact_value))
181 .or_else(|| classify_body(error_description))
182}
183
184fn match_exact_value(value: &str) -> Option<ProviderErrorKind> {
185 if value.eq_ignore_ascii_case("invalid_grant") || value.eq_ignore_ascii_case("access_denied") {
186 Some(ProviderErrorKind::InvalidGrant)
187 } else if value.eq_ignore_ascii_case("invalid_client")
188 || value.eq_ignore_ascii_case("unauthorized_client")
189 {
190 Some(ProviderErrorKind::InvalidClient)
191 } else if value.eq_ignore_ascii_case("invalid_scope")
192 || value.eq_ignore_ascii_case("insufficient_scope")
193 {
194 Some(ProviderErrorKind::InsufficientScope)
195 } else if value.eq_ignore_ascii_case("temporarily_unavailable")
196 || value.eq_ignore_ascii_case("server_error")
197 {
198 Some(ProviderErrorKind::Transient)
199 } else {
200 None
201 }
202}
203
204fn classify_body(body: Option<&str>) -> Option<ProviderErrorKind> {
205 let body = body?;
206 let lowered = body.to_ascii_lowercase();
207
208 match lowered.as_str() {
209 text if text.contains("invalid_grant") => Some(ProviderErrorKind::InvalidGrant),
210 text if text.contains("invalid_client") => Some(ProviderErrorKind::InvalidClient),
211 text if text.contains("insufficient_scope") || text.contains("invalid_scope") =>
212 Some(ProviderErrorKind::InsufficientScope),
213 text if text.contains("temporarily_unavailable") || text.contains("retry") =>
214 Some(ProviderErrorKind::Transient),
215 _ => None,
216 }
217}
218
219fn classify_status(status: Option<u16>) -> ProviderErrorKind {
220 match status {
221 Some(400 | 404 | 410) => ProviderErrorKind::InvalidGrant,
222 Some(401) => ProviderErrorKind::InvalidClient,
223 Some(403) => ProviderErrorKind::InsufficientScope,
224 Some(429) => ProviderErrorKind::Transient,
225 Some(code) if code >= 500 => ProviderErrorKind::Transient,
226 _ => ProviderErrorKind::Transient,
227 }
228}