reqwest_negotiate/
lib.rs

1//! Kerberos/SPNEGO Negotiate authentication for reqwest.
2//!
3//! This crate provides an extension trait for [`reqwest::RequestBuilder`] that adds
4//! Kerberos SPNEGO (Negotiate) authentication support using the system's GSSAPI library.
5//!
6//! # Prerequisites
7//!
8//! - A valid Kerberos ticket (obtained via `kinit` or similar)
9//! - GSSAPI libraries installed on your system (`libgssapi_krb5` on Linux, Heimdal on macOS)
10//!
11//! # Basic Example
12//!
13//! ```no_run
14//! use reqwest_negotiate::NegotiateAuthExt;
15//!
16//! #[tokio::main]
17//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
18//!     let client = reqwest::Client::new();
19//!
20//!     let response = client
21//!         .get("https://api.example.com/protected")
22//!         .negotiate_auth()? // Uses default credentials and derives SPN from URL
23//!         .send()
24//!         .await?;
25//!
26//!     println!("Status: {}", response.status());
27//!     Ok(())
28//! }
29//! ```
30//!
31//! # Mutual Authentication
32//!
33//! For high-security environments, you can verify the server's identity:
34//!
35//! ```no_run
36//! use reqwest_negotiate::NegotiateAuthExt;
37//!
38//! #[tokio::main]
39//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
40//!     let client = reqwest::Client::new();
41//!
42//!     let (builder, mut ctx) = client
43//!         .get("https://api.example.com/protected")
44//!         .negotiate_auth_mutual()?;
45//!
46//!     let response = builder.send().await?;
47//!
48//!     // Verify the server proved its identity
49//!     ctx.verify_response(&response)?;
50//!
51//!     println!("Status: {}", response.status());
52//!     Ok(())
53//! }
54//! ```
55//!
56//! # Custom Service Principal
57//!
58//! If the service principal name (SPN) differs from the standard `HTTP/<hostname>`:
59//!
60//! ```no_run
61//! use reqwest_negotiate::NegotiateAuthExt;
62//!
63//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
64//! let client = reqwest::Client::new();
65//!
66//! let response = client
67//!     .get("https://api.example.com/protected")
68//!     .negotiate_auth_with_spn("HTTP/custom.principal@REALM.COM")?
69//!     .send()
70//!     .await?;
71//! # Ok(())
72//! # }
73//! ```
74
75use base64::Engine;
76use base64::engine::general_purpose::STANDARD as BASE64;
77use cross_krb5::{ClientCtx, InitiateFlags, PendingClientCtx, Step};
78use reqwest::header::{AUTHORIZATION, HeaderValue, WWW_AUTHENTICATE};
79use reqwest::{RequestBuilder, Response};
80
/// Builds the default service principal name for an HTTP host.
///
/// The result is always `HTTP/` followed by `host` verbatim, i.e. the
/// standard `HTTP/<hostname>` form. Callers whose service is registered
/// under a different SPN pass it explicitly instead.
fn spn_from_host(host: &str) -> String {
    const SERVICE_PREFIX: &str = "HTTP/";
    [SERVICE_PREFIX, host].concat()
}
87
88/// Parses a Negotiate token from a WWW-Authenticate header value.
89///
90/// Returns the decoded token bytes, or an error if the header is malformed.
91fn parse_negotiate_header(header_value: &str) -> Result<Vec<u8>, NegotiateError> {
92    let token_b64 = header_value
93        .strip_prefix("Negotiate ")
94        .ok_or(NegotiateError::InvalidTokenFormat)?;
95
96    if token_b64.is_empty() {
97        return Err(NegotiateError::MissingMutualAuthToken);
98    }
99
100    BASE64
101        .decode(token_b64)
102        .map_err(|_| NegotiateError::InvalidTokenFormat)
103}
104
105/// Errors that can occur during Negotiate authentication.
106#[derive(Debug, thiserror::Error)]
107pub enum NegotiateError {
108    /// Failed to create the service principal name.
109    #[error("failed to create service name: {0}")]
110    NameError(String),
111
112    /// Failed to acquire Kerberos credentials.
113    #[error("failed to acquire credentials: {0}")]
114    CredentialError(String),
115
116    /// Failed to initialize or step the security context.
117    #[error("failed to initialize security context: {0}")]
118    ContextError(String),
119
120    /// The request URL is missing a host component.
121    #[error("request URL is missing host")]
122    MissingHost,
123
124    /// The request could not be built.
125    #[error("failed to build request: {0}")]
126    BuildError(#[from] reqwest::Error),
127
128    /// Server did not provide a mutual authentication token.
129    #[error("server did not provide mutual authentication token")]
130    MissingMutualAuthToken,
131
132    /// Failed to verify the server's authentication token.
133    #[error("failed to verify server token: {0}")]
134    MutualAuthFailed(String),
135
136    /// Invalid token format in server response.
137    #[error("invalid token format in WWW-Authenticate header")]
138    InvalidTokenFormat,
139}
140
141/// Internal state for the context - either pending or complete.
142enum ContextState {
143    Pending(PendingClientCtx),
144    Complete(ClientCtx),
145}
146
147/// Holds the GSSAPI context for mutual authentication verification.
148///
149/// After sending a request with [`NegotiateAuthExt::negotiate_auth_mutual`],
150/// use this context to verify the server's response token.
151pub struct NegotiateContext {
152    state: Option<ContextState>,
153}
154
155impl NegotiateContext {
156    /// Verifies the server's mutual authentication token from the response.
157    ///
158    /// Call this after receiving a response to confirm the server's identity.
159    /// The server's token is extracted from the `WWW-Authenticate: Negotiate <token>` header.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if:
164    /// - The response doesn't contain a `WWW-Authenticate: Negotiate` header
165    /// - The token is malformed
166    /// - The server's identity cannot be verified
167    ///
168    /// # Example
169    ///
170    /// ```no_run
171    /// use reqwest_negotiate::NegotiateAuthExt;
172    ///
173    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
174    /// let client = reqwest::Client::new();
175    ///
176    /// let (builder, mut ctx) = client
177    ///     .get("https://api.example.com/protected")
178    ///     .negotiate_auth_mutual()?;
179    ///
180    /// let response = builder.send().await?;
181    /// ctx.verify_response(&response)?;
182    /// # Ok(())
183    /// # }
184    /// ```
185    pub fn verify_response(&mut self, response: &Response) -> Result<(), NegotiateError> {
186        let header = response
187            .headers()
188            .get(WWW_AUTHENTICATE)
189            .ok_or(NegotiateError::MissingMutualAuthToken)?;
190
191        let header_str = header
192            .to_str()
193            .map_err(|_| NegotiateError::InvalidTokenFormat)?;
194
195        let token = parse_negotiate_header(header_str)?;
196
197        // Take ownership of the state to step it
198        let current_state = self.state.take().ok_or(NegotiateError::ContextError(
199            "context already consumed".into(),
200        ))?;
201
202        match current_state {
203            ContextState::Pending(pending) => match pending.step(&token) {
204                Ok(Step::Continue((new_pending, _))) => {
205                    self.state = Some(ContextState::Pending(new_pending));
206                    Ok(())
207                }
208                Ok(Step::Finished((ctx, _))) => {
209                    self.state = Some(ContextState::Complete(ctx));
210                    Ok(())
211                }
212                Err(e) => Err(NegotiateError::MutualAuthFailed(e.to_string())),
213            },
214            ContextState::Complete(ctx) => {
215                // Already complete, restore state
216                self.state = Some(ContextState::Complete(ctx));
217                Ok(())
218            }
219        }
220    }
221
222    /// Checks if the security context is fully established.
223    ///
224    /// Returns `true` if mutual authentication is complete.
225    pub fn is_complete(&self) -> bool {
226        matches!(self.state, Some(ContextState::Complete(_)))
227    }
228}
229
230/// Extension trait that adds Negotiate authentication to [`reqwest::RequestBuilder`].
231pub trait NegotiateAuthExt {
232    /// Adds Negotiate authentication using the default Kerberos credentials.
233    ///
234    /// The service principal name (SPN) is derived from the request URL as `HTTP/<hostname>`.
235    ///
236    /// This method does not verify the server's identity. For mutual authentication,
237    /// use [`negotiate_auth_mutual`](Self::negotiate_auth_mutual) instead.
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if:
242    /// - The URL has no host component
243    /// - No valid Kerberos credentials are available
244    /// - The GSSAPI context initialization fails
245    fn negotiate_auth(self) -> Result<RequestBuilder, NegotiateError>;
246
247    /// Adds Negotiate authentication with a custom service principal name.
248    ///
249    /// Use this when the service is registered with a non-standard SPN.
250    ///
251    /// This method does not verify the server's identity. For mutual authentication,
252    /// use [`negotiate_auth_mutual_with_spn`](Self::negotiate_auth_mutual_with_spn) instead.
253    ///
254    /// # Arguments
255    ///
256    /// * `spn` - The service principal name (e.g., `HTTP/service.example.com@REALM.COM`)
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if:
261    /// - No valid Kerberos credentials are available
262    /// - The GSSAPI context initialization fails
263    fn negotiate_auth_with_spn(self, spn: &str) -> Result<RequestBuilder, NegotiateError>;
264
265    /// Adds Negotiate authentication and returns a context for mutual authentication.
266    ///
267    /// The service principal name (SPN) is derived from the request URL as `HTTP/<hostname>`.
268    ///
269    /// After sending the request, call [`NegotiateContext::verify_response`] to verify
270    /// the server's identity.
271    ///
272    /// # Errors
273    ///
274    /// Returns an error if:
275    /// - The URL has no host component
276    /// - No valid Kerberos credentials are available
277    /// - The GSSAPI context initialization fails
278    ///
279    /// # Example
280    ///
281    /// ```no_run
282    /// use reqwest_negotiate::NegotiateAuthExt;
283    ///
284    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
285    /// let client = reqwest::Client::new();
286    ///
287    /// let (builder, mut ctx) = client
288    ///     .get("https://api.example.com/protected")
289    ///     .negotiate_auth_mutual()?;
290    ///
291    /// let response = builder.send().await?;
292    /// ctx.verify_response(&response)?;
293    /// # Ok(())
294    /// # }
295    /// ```
296    fn negotiate_auth_mutual(self) -> Result<(RequestBuilder, NegotiateContext), NegotiateError>;
297
298    /// Adds Negotiate authentication with a custom SPN and returns a context for mutual authentication.
299    ///
300    /// After sending the request, call [`NegotiateContext::verify_response`] to verify
301    /// the server's identity.
302    ///
303    /// # Arguments
304    ///
305    /// * `spn` - The service principal name (e.g., `HTTP/service.example.com@REALM.COM`)
306    ///
307    /// # Errors
308    ///
309    /// Returns an error if:
310    /// - No valid Kerberos credentials are available
311    /// - The GSSAPI context initialization fails
312    fn negotiate_auth_mutual_with_spn(
313        self,
314        spn: &str,
315    ) -> Result<(RequestBuilder, NegotiateContext), NegotiateError>;
316}
317
318impl NegotiateAuthExt for RequestBuilder {
319    fn negotiate_auth(self) -> Result<RequestBuilder, NegotiateError> {
320        let (builder, _ctx) = self.negotiate_auth_mutual()?;
321        Ok(builder)
322    }
323
324    fn negotiate_auth_with_spn(self, spn: &str) -> Result<RequestBuilder, NegotiateError> {
325        let (builder, _ctx) = self.negotiate_auth_mutual_with_spn(spn)?;
326        Ok(builder)
327    }
328
329    fn negotiate_auth_mutual(self) -> Result<(RequestBuilder, NegotiateContext), NegotiateError> {
330        // Build a temporary copy to inspect the URL
331        let request = self
332            .try_clone()
333            .ok_or_else(|| NegotiateError::ContextError("request body not clonable".into()))?
334            .build()?;
335
336        let host = request
337            .url()
338            .host_str()
339            .ok_or(NegotiateError::MissingHost)?;
340        let spn = spn_from_host(host);
341
342        add_negotiate_header_with_ctx(self, &spn)
343    }
344
345    fn negotiate_auth_mutual_with_spn(
346        self,
347        spn: &str,
348    ) -> Result<(RequestBuilder, NegotiateContext), NegotiateError> {
349        add_negotiate_header_with_ctx(self, spn)
350    }
351}
352
353fn add_negotiate_header_with_ctx(
354    builder: RequestBuilder,
355    spn: &str,
356) -> Result<(RequestBuilder, NegotiateContext), NegotiateError> {
357    let (token, ctx) = generate_negotiate_token_with_ctx(spn)?;
358    let header_value = format!("Negotiate {}", BASE64.encode(&token));
359
360    let builder = builder.header(
361        AUTHORIZATION,
362        HeaderValue::from_str(&header_value)
363            .map_err(|e| NegotiateError::ContextError(e.to_string()))?,
364    );
365
366    Ok((builder, ctx))
367}
368
369fn generate_negotiate_token_with_ctx(
370    spn: &str,
371) -> Result<(Vec<u8>, NegotiateContext), NegotiateError> {
372    // Initialize the client context - cross-krb5 handles credential acquisition
373    // ClientCtx::new returns (PendingClientCtx, initial_token)
374    let (pending_ctx, token) = ClientCtx::new(InitiateFlags::empty(), None, spn, None)
375        .map_err(|e| NegotiateError::ContextError(e.to_string()))?;
376
377    let ctx = NegotiateContext {
378        state: Some(ContextState::Pending(pending_ctx)),
379    };
380
381    Ok((token.to_vec(), ctx))
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    mod spn_derivation {
        use super::*;

        #[test]
        fn simple_hostname() {
            assert_eq!(spn_from_host("api.example.com"), "HTTP/api.example.com");
        }

        #[test]
        fn hostname_with_subdomain() {
            assert_eq!(
                spn_from_host("service.internal.example.com"),
                "HTTP/service.internal.example.com"
            );
        }

        #[test]
        fn localhost() {
            assert_eq!(spn_from_host("localhost"), "HTTP/localhost");
        }

        #[test]
        fn ip_address() {
            assert_eq!(spn_from_host("192.168.1.1"), "HTTP/192.168.1.1");
        }
    }

    mod negotiate_header_parsing {
        use super::*;

        #[test]
        fn valid_token() {
            // "hello world" base64 encoded
            let header = "Negotiate aGVsbG8gd29ybGQ=";
            let result = parse_negotiate_header(header).unwrap();
            assert_eq!(result, b"hello world");
        }

        #[test]
        fn valid_empty_looking_but_valid_base64() {
            // Single byte base64 encoded
            let header = "Negotiate QQ==";
            let result = parse_negotiate_header(header).unwrap();
            assert_eq!(result, b"A");
        }

        #[test]
        fn missing_negotiate_prefix() {
            let header = "Basic dXNlcjpwYXNz";
            let result = parse_negotiate_header(header);
            assert!(matches!(result, Err(NegotiateError::InvalidTokenFormat)));
        }

        #[test]
        fn wrong_prefix_case() {
            // This parser matches the scheme name case-sensitively, so a
            // lowercase `negotiate` is rejected. NOTE: RFC 7235 section 2.1
            // treats auth-scheme names as case-insensitive; this test pins the
            // current behavior, not a protocol requirement.
            let header = "negotiate aGVsbG8gd29ybGQ=";
            let result = parse_negotiate_header(header);
            assert!(matches!(result, Err(NegotiateError::InvalidTokenFormat)));
        }

        #[test]
        fn scheme_without_separator_before_token() {
            // The scheme must be followed by a space before the token.
            let header = "NegotiateaGVsbG8gd29ybGQ=";
            let result = parse_negotiate_header(header);
            assert!(matches!(result, Err(NegotiateError::InvalidTokenFormat)));
        }

        #[test]
        fn empty_token_after_prefix() {
            let header = "Negotiate ";
            let result = parse_negotiate_header(header);
            assert!(matches!(
                result,
                Err(NegotiateError::MissingMutualAuthToken)
            ));
        }

        #[test]
        fn invalid_base64() {
            let header = "Negotiate !!!not-valid-base64!!!";
            let result = parse_negotiate_header(header);
            assert!(matches!(result, Err(NegotiateError::InvalidTokenFormat)));
        }

        #[test]
        fn negotiate_only_no_space() {
            let header = "Negotiate";
            let result = parse_negotiate_header(header);
            assert!(matches!(result, Err(NegotiateError::InvalidTokenFormat)));
        }

        #[test]
        fn binary_token_roundtrip() {
            // Simulate a realistic SPNEGO token (just random bytes for testing)
            let original_bytes: Vec<u8> = vec![0x60, 0x82, 0x01, 0x00, 0x06, 0x09, 0x2a];
            let encoded = BASE64.encode(&original_bytes);
            let header = format!("Negotiate {}", encoded);

            let result = parse_negotiate_header(&header).unwrap();
            assert_eq!(result, original_bytes);
        }
    }

    mod error_display {
        use super::*;

        #[test]
        fn name_error_displays_context() {
            let err = NegotiateError::NameError("invalid principal".to_string());
            let display = format!("{}", err);
            assert!(display.contains("service name"));
            assert!(display.contains("invalid principal"));
        }

        #[test]
        fn credential_error_displays_context() {
            let err = NegotiateError::CredentialError("no credentials".to_string());
            let display = format!("{}", err);
            assert!(display.contains("credentials"));
            assert!(display.contains("no credentials"));
        }

        #[test]
        fn context_error_includes_reason() {
            let err = NegotiateError::ContextError("context already consumed".to_string());
            let display = err.to_string();
            assert!(display.contains("security context"));
            assert!(display.contains("context already consumed"));
        }

        #[test]
        fn missing_host_is_descriptive() {
            let err = NegotiateError::MissingHost;
            let display = format!("{}", err);
            assert!(display.contains("host"));
        }

        #[test]
        fn missing_mutual_auth_token_is_descriptive() {
            let display = NegotiateError::MissingMutualAuthToken.to_string();
            assert!(display.contains("mutual authentication token"));
        }

        #[test]
        fn mutual_auth_failed_includes_reason() {
            let err = NegotiateError::MutualAuthFailed("token expired".to_string());
            let display = format!("{}", err);
            assert!(display.contains("token expired"));
        }

        #[test]
        fn invalid_token_format_mentions_header() {
            let display = NegotiateError::InvalidTokenFormat.to_string();
            assert!(display.contains("WWW-Authenticate"));
        }
    }

    mod error_traits {
        use super::*;

        #[test]
        fn errors_are_send() {
            fn assert_send<T: Send>() {}
            assert_send::<NegotiateError>();
        }

        #[test]
        fn errors_are_sync() {
            fn assert_sync<T: Sync>() {}
            assert_sync::<NegotiateError>();
        }

        #[test]
        fn errors_implement_std_error() {
            fn assert_error<T: std::error::Error>() {}
            assert_error::<NegotiateError>();
        }
    }
}