Skip to main content

faucet_source_rest/auth/
mod.rs

1//! Authentication strategies for REST APIs.
2
3pub mod api_key;
4pub mod basic;
5pub mod bearer;
6pub mod custom;
7pub mod oauth2;
8pub mod token_endpoint;
9
10use faucet_core::FaucetError;
11use reqwest::header::HeaderMap;
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15pub use token_endpoint::ResponseValidator;
16
17/// Supported authentication methods.
18#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
19#[serde(tag = "type", content = "config", rename_all = "snake_case")]
20pub enum Auth {
21    None,
22    /// Bearer token in the `Authorization` header.
23    Bearer {
24        token: String,
25    },
26    Basic {
27        username: String,
28        password: String,
29    },
30    /// API key sent in a request header.
31    ApiKey {
32        header: String,
33        value: String,
34    },
35    /// API key sent as a query parameter (e.g. `?api_key=secret`).
36    ///
37    /// Some APIs require the key in the URL rather than a header. The `param`
38    /// field is the query parameter name, and `value` is the key itself.
39    ApiKeyQuery {
40        param: String,
41        value: String,
42    },
43    #[serde(rename = "oauth2")]
44    OAuth2 {
45        token_url: String,
46        client_id: String,
47        client_secret: String,
48        scopes: Vec<String>,
49        /// Fraction of `expires_in` after which the cached token is considered
50        /// expired and a new one is fetched. Must be in `(0.0, 1.0]`.
51        /// Defaults to `0.9` (refresh after 90 % of the token lifetime).
52        expiry_ratio: f64,
53    },
54    /// Fetch a token from an arbitrary HTTP endpoint.
55    ///
56    /// The endpoint is called, the token is extracted from the JSON response
57    /// using `token_path` (a JSONPath expression), and then used as a Bearer
58    /// token (or in a custom header if `header_name` is set).
59    ///
60    /// Tokens are cached and refreshed automatically when `expiry_path`
61    /// is provided and the server returns an expiry value.
62    TokenEndpoint {
63        /// URL of the token endpoint.
64        url: String,
65        /// HTTP method for the token request (e.g. GET, POST).
66        #[serde(with = "crate::serde_helpers::http_method")]
67        #[schemars(with = "String")]
68        method: reqwest::Method,
69        /// Headers to send with the token request (e.g. API keys, content type).
70        #[serde(skip, default)]
71        headers: HeaderMap,
72        /// Optional JSON body for the token request.
73        body: Option<serde_json::Value>,
74        /// JSONPath expression to extract the token string from the response.
75        token_path: String,
76        /// Optional JSONPath expression to extract the expiry (in seconds)
77        /// from the response. When absent, the token is cached indefinitely.
78        expiry_path: Option<String>,
79        /// Fraction of the expiry after which the token is proactively refreshed.
80        /// Must be in `(0.0, 1.0]`. Defaults to `0.9`.
81        expiry_ratio: f64,
82        /// Optional callback to decide whether the token endpoint response is
83        /// successful. Receives the HTTP status code. When `None`, defaults to
84        /// `status.is_success()` (2xx).
85        #[serde(skip, default)]
86        response_validator: Option<ResponseValidator>,
87    },
88    /// Arbitrary headers attached to every request (e.g. multi-tenant routing,
89    /// API keys split across several headers).
90    Custom {
91        headers: HashMap<String, String>,
92    },
93}
94
95impl Auth {
96    /// Apply header-based auth to the request headers.
97    ///
98    /// `ApiKeyQuery` is a no-op here — it is applied as a query parameter by
99    /// `RestStream::execute_request` instead.
100    pub fn apply(&self, headers: &mut HeaderMap) -> Result<(), FaucetError> {
101        match self {
102            Auth::None | Auth::ApiKeyQuery { .. } => Ok(()),
103            Auth::Bearer { token } => bearer::apply(headers, token),
104            Auth::Basic { username, password } => basic::apply(headers, username, password),
105            Auth::ApiKey { header, value } => api_key::apply(headers, header, value),
106            // OAuth2 is resolved to Auth::Bearer by RestStream before apply() is called.
107            // If apply() is reached with an OAuth2 variant, it means the caller bypassed
108            // RestStream — return a clear error rather than silently sending no auth.
109            Auth::OAuth2 { .. } => Err(FaucetError::Auth(
110                "OAuth2 auth must be resolved to a bearer token before applying; \
111                 use RestStream (which resolves it automatically) or call \
112                 fetch_oauth2_token() and construct Auth::Bearer { token } directly"
113                    .into(),
114            )),
115            // TokenEndpoint is resolved to Auth::Bearer by RestStream before apply().
116            Auth::TokenEndpoint { .. } => Err(FaucetError::Auth(
117                "TokenEndpoint auth must be resolved to a bearer token before applying; \
118                 use RestStream (which resolves it automatically) or call \
119                 fetch_token_from_endpoint() and construct Auth::Bearer { token } directly"
120                    .into(),
121            )),
122            Auth::Custom { headers: extra } => custom::apply(headers, extra),
123        }
124    }
125}
126
127pub use oauth2::fetch_oauth2_token;
128pub use token_endpoint::fetch_token_from_endpoint;
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `auth` to an empty header map and returns the populated map.
    /// Panics if `apply` fails.
    fn applied(auth: &Auth) -> HeaderMap {
        let mut out = HeaderMap::new();
        auth.apply(&mut out).unwrap();
        out
    }

    /// Serializes `auth` to a JSON value and deserializes it back.
    fn json_round_trip(auth: &Auth) -> Auth {
        let encoded = serde_json::to_value(auth).unwrap();
        serde_json::from_value(encoded).unwrap()
    }

    #[test]
    fn auth_serializes_as_type_config() {
        let encoded = serde_json::to_value(&Auth::Bearer { token: "t".into() }).unwrap();
        let expected = serde_json::json!({"type": "bearer", "config": {"token": "t"}});
        assert_eq!(encoded, expected);

        let decoded: Auth = serde_json::from_value(encoded).unwrap();
        match decoded {
            Auth::Bearer { token } => assert_eq!(token, "t"),
            other => panic!("expected Auth::Bearer after round trip, got {other:?}"),
        }
    }

    #[test]
    fn auth_unit_variant_has_no_config() {
        let encoded = serde_json::to_value(&Auth::None).unwrap();
        assert_eq!(encoded, serde_json::json!({"type": "none"}));
    }

    #[test]
    fn auth_snake_case_discriminators() {
        let key = Auth::ApiKey {
            header: "X-Key".into(),
            value: "v".into(),
        };
        let encoded = serde_json::to_value(&key).unwrap();
        assert_eq!(encoded["type"], "api_key");
        assert_eq!(encoded["config"]["header"], "X-Key");
    }

    #[test]
    fn auth_none_is_noop() {
        assert!(applied(&Auth::None).is_empty());
    }

    #[test]
    fn auth_bearer_sets_authorization_header() {
        let map = applied(&Auth::Bearer {
            token: "my-token".into(),
        });
        assert_eq!(map.get("authorization").unwrap(), "Bearer my-token");
    }

    #[test]
    fn auth_basic_sets_authorization_header() {
        let map = applied(&Auth::Basic {
            username: "user".into(),
            password: "pass".into(),
        });
        let rendered = map
            .get("authorization")
            .expect("authorization header present")
            .to_str()
            .unwrap();
        assert!(rendered.starts_with("Basic "));
    }

    #[test]
    fn auth_api_key_sets_custom_header() {
        let map = applied(&Auth::ApiKey {
            header: "X-Api-Key".into(),
            value: "secret".into(),
        });
        assert_eq!(map.get("x-api-key").unwrap(), "secret");
    }

    #[test]
    fn auth_api_key_query_is_noop_on_apply() {
        let query = Auth::ApiKeyQuery {
            param: "api_key".into(),
            value: "secret".into(),
        };
        assert!(applied(&query).is_empty());
    }

    #[test]
    fn auth_oauth2_errors_on_direct_apply() {
        let oauth = Auth::OAuth2 {
            token_url: "https://auth.example.com/token".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            scopes: Vec::new(),
            expiry_ratio: 0.9,
        };
        let outcome = oauth.apply(&mut HeaderMap::new());
        assert!(matches!(outcome, Err(FaucetError::Auth(_))));
    }

    #[test]
    fn auth_token_endpoint_errors_on_direct_apply() {
        let endpoint = Auth::TokenEndpoint {
            url: "https://auth.example.com/token".into(),
            method: reqwest::Method::POST,
            headers: HeaderMap::new(),
            body: None,
            token_path: "$.token".into(),
            expiry_path: None,
            expiry_ratio: 0.9,
            response_validator: None,
        };
        let outcome = endpoint.apply(&mut HeaderMap::new());
        assert!(matches!(outcome, Err(FaucetError::Auth(_))));
    }

    #[test]
    fn auth_custom_headers() {
        let custom = Auth::Custom {
            headers: HashMap::from([("x-custom".to_string(), "value".to_string())]),
        };
        let map = applied(&custom);
        assert_eq!(map.get("x-custom").unwrap(), "value");
    }

    #[test]
    fn auth_custom_round_trips_through_json() {
        let original = Auth::Custom {
            headers: HashMap::from([
                ("x-tenant".to_string(), "acme".to_string()),
                ("x-region".to_string(), "us".to_string()),
            ]),
        };
        let map = applied(&json_round_trip(&original));
        assert_eq!(map.get("x-tenant").unwrap(), "acme");
        assert_eq!(map.get("x-region").unwrap(), "us");
    }

    #[test]
    fn auth_bearer_round_trips_through_json() {
        let original = Auth::Bearer {
            token: "tok".into(),
        };
        let map = applied(&json_round_trip(&original));
        assert_eq!(map.get("authorization").unwrap(), "Bearer tok");
    }

    #[test]
    fn auth_debug_format() {
        assert_eq!(format!("{:?}", Auth::None), "None");

        let rendered = format!(
            "{:?}",
            Auth::Bearer {
                token: "tok".into(),
            }
        );
        assert!(rendered.contains("Bearer"));
    }

    #[test]
    fn auth_clone() {
        let original = Auth::Bearer {
            token: "token".into(),
        };
        let copy = original.clone();
        let map = applied(&copy);
        assert_eq!(map.get("authorization").unwrap(), "Bearer token");
    }
}