oauth2_broker/provider/
strategy.rs

1//! Provider strategy hooks that customize token exchanges.
2//!
3//! Implementations decorate outgoing token requests and normalize error mapping
4//! without tying flows to any particular HTTP client.
5
6// std
7use std::collections::BTreeMap;
8// self
9use crate::{_prelude::*, provider::descriptor::GrantType};
10
11/// Strategy hook that allows providers to decorate requests and classify errors.
12///
13/// Implementors are required to be `Send + Sync`, and the hooks intentionally use
14/// crate-owned data types so downstream crates never depend on reqwest-specific
15/// structures.  Override only what you need—`augment_token_request` has a default
16/// no-op implementation.
17pub trait ProviderStrategy: Send + Sync {
18	/// Maps low-level HTTP/JSON errors into the broker taxonomy for a token request.
19	fn classify_token_error(&self, ctx: &ProviderErrorContext) -> ProviderErrorKind;
20
21	/// Gives providers a chance to add custom form parameters before dispatching.
22	///
23	/// The default implementation does nothing, which is enough for most providers.
24	/// Override the hook when a provider requires extra fields (audience, resource,
25	/// etc.).  The method works on a plain `BTreeMap` so implementations remain HTTP
26	/// client agnostic.
27	fn augment_token_request(&self, _grant: GrantType, _form: &mut BTreeMap<String, String>) {}
28}
29
/// Broker-level categories that provider strategies sort token failures into.
///
/// Returned by [`ProviderStrategy::classify_token_error`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProviderErrorKind {
	/// The provider refused the authorization grant (e.g. a bad code or refresh token).
	InvalidGrant,
	/// The provider did not accept the client's authentication.
	InvalidClient,
	/// The requested scopes go beyond what the token covers.
	InsufficientScope,
	/// A temporary failure; the request is worth retrying.
	Transient,
}
42
43/// Context passed to provider strategies when classifying token errors.
44///
45/// The struct intentionally keeps only primitive data (status codes, OAuth fields,
46/// body preview) so strategies stay completely decoupled from any HTTP client
47/// (e.g., reqwest).  Builders on the flows side populate the context before
48/// invoking [`ProviderStrategy::classify_token_error`].
49#[derive(Clone, Debug, PartialEq, Eq)]
50pub struct ProviderErrorContext {
51	/// Grant type associated with the failing request.
52	pub grant_type: GrantType,
53	/// HTTP status code returned by the provider, when available.
54	pub http_status: Option<u16>,
55	/// Provider-supplied OAuth `error` field.
56	pub oauth_error: Option<String>,
57	/// Provider-supplied OAuth `error_description` field.
58	pub error_description: Option<String>,
59	/// Preview of the response body for non-JSON payloads.
60	pub body_preview: Option<String>,
61	/// Indicates whether the failure originated from the network/transport layer.
62	pub network_error: bool,
63}
64impl ProviderErrorContext {
65	const BODY_PREVIEW_LIMIT: usize = 256;
66
67	/// Creates a new context scoped to the provided grant type.
68	pub fn new(grant_type: GrantType) -> Self {
69		Self {
70			grant_type,
71			http_status: None,
72			oauth_error: None,
73			error_description: None,
74			body_preview: None,
75			network_error: false,
76		}
77	}
78
79	/// Convenience constructor for transport-level/network failures.
80	pub fn network_failure(grant_type: GrantType) -> Self {
81		let mut ctx = Self::new(grant_type);
82
83		ctx.network_error = true;
84
85		ctx
86	}
87
88	/// Overrides the network error flag.
89	pub fn with_network_error(mut self, network_error: bool) -> Self {
90		self.network_error = network_error;
91
92		self
93	}
94
95	/// Adds an HTTP status code (e.g., 400, 401, 500).
96	pub fn with_http_status(mut self, status: u16) -> Self {
97		self.http_status = Some(status);
98
99		self
100	}
101
102	/// Adds the OAuth error code string returned by the provider.
103	pub fn with_oauth_error(mut self, error: impl Into<String>) -> Self {
104		self.oauth_error = Some(error.into());
105
106		self
107	}
108
109	/// Adds the OAuth `error_description` field.
110	pub fn with_error_description(mut self, description: impl Into<String>) -> Self {
111		self.error_description = Some(description.into());
112
113		self
114	}
115
116	/// Adds a body preview for providers that return non-JSON payloads.
117	pub fn with_body_preview(mut self, body: impl Into<String>) -> Self {
118		self.body_preview = Some(truncate_preview(body.into()));
119
120		self
121	}
122}
123
/// Built-in strategy driven by RFC-guided heuristics.
///
/// Classification moves from the most to the least specific signal: first the
/// structured OAuth fields (`error`, then `error_description`), then text hints
/// in the body preview, and finally the HTTP status code. A network-level failure
/// short-circuits all of this and is always reported as transient.
#[derive(Debug, Default)]
pub struct DefaultProviderStrategy;
impl Display for DefaultProviderStrategy {
	/// Writes the fixed identifier `default-provider-strategy`; width/fill
	/// options on the formatter are not applied.
	fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
		write!(f, "default-provider-strategy")
	}
}
136impl ProviderStrategy for DefaultProviderStrategy {
137	fn classify_token_error(&self, ctx: &ProviderErrorContext) -> ProviderErrorKind {
138		if ctx.network_error {
139			return ProviderErrorKind::Transient;
140		}
141
142		if let Some(kind) =
143			classify_oauth_error(ctx.oauth_error.as_deref(), ctx.error_description.as_deref())
144		{
145			return kind;
146		}
147		if let Some(kind) = classify_body(ctx.body_preview.as_deref()) {
148			return kind;
149		}
150
151		classify_status(ctx.http_status)
152	}
153}
154
155fn truncate_preview(body: String) -> String {
156	if body.chars().count() <= ProviderErrorContext::BODY_PREVIEW_LIMIT {
157		return body;
158	}
159
160	let mut buf = String::new();
161
162	for (idx, ch) in body.chars().enumerate() {
163		if idx >= ProviderErrorContext::BODY_PREVIEW_LIMIT {
164			buf.push('…');
165
166			break;
167		}
168		buf.push(ch);
169	}
170
171	buf
172}
173
174fn classify_oauth_error(
175	oauth_error: Option<&str>,
176	error_description: Option<&str>,
177) -> Option<ProviderErrorKind> {
178	oauth_error
179		.and_then(match_exact_value)
180		.or_else(|| error_description.and_then(match_exact_value))
181		.or_else(|| classify_body(error_description))
182}
183
184fn match_exact_value(value: &str) -> Option<ProviderErrorKind> {
185	if value.eq_ignore_ascii_case("invalid_grant") || value.eq_ignore_ascii_case("access_denied") {
186		Some(ProviderErrorKind::InvalidGrant)
187	} else if value.eq_ignore_ascii_case("invalid_client")
188		|| value.eq_ignore_ascii_case("unauthorized_client")
189	{
190		Some(ProviderErrorKind::InvalidClient)
191	} else if value.eq_ignore_ascii_case("invalid_scope")
192		|| value.eq_ignore_ascii_case("insufficient_scope")
193	{
194		Some(ProviderErrorKind::InsufficientScope)
195	} else if value.eq_ignore_ascii_case("temporarily_unavailable")
196		|| value.eq_ignore_ascii_case("server_error")
197	{
198		Some(ProviderErrorKind::Transient)
199	} else {
200		None
201	}
202}
203
204fn classify_body(body: Option<&str>) -> Option<ProviderErrorKind> {
205	let body = body?;
206	let lowered = body.to_ascii_lowercase();
207
208	match lowered.as_str() {
209		text if text.contains("invalid_grant") => Some(ProviderErrorKind::InvalidGrant),
210		text if text.contains("invalid_client") => Some(ProviderErrorKind::InvalidClient),
211		text if text.contains("insufficient_scope") || text.contains("invalid_scope") =>
212			Some(ProviderErrorKind::InsufficientScope),
213		text if text.contains("temporarily_unavailable") || text.contains("retry") =>
214			Some(ProviderErrorKind::Transient),
215		_ => None,
216	}
217}
218
219fn classify_status(status: Option<u16>) -> ProviderErrorKind {
220	match status {
221		Some(400 | 404 | 410) => ProviderErrorKind::InvalidGrant,
222		Some(401) => ProviderErrorKind::InvalidClient,
223		Some(403) => ProviderErrorKind::InsufficientScope,
224		Some(429) => ProviderErrorKind::Transient,
225		Some(code) if code >= 500 => ProviderErrorKind::Transient,
226		_ => ProviderErrorKind::Transient,
227	}
228}