Skip to main content

axum_negotiate_layer/
lib.rs

1//! axum-negotiate-layer provides middleware for authenticating connections over the Microsoft "HTTP Negotiate" extension.
2//!
3//! # Features
4//!
5//! - [`NegotiateMiddleware`]: A [`tower::Service`] object that uses the [`NegotiateInfo`] attached to the connection to authenticate that connection
6//! - [`NegotiateLayer`]: A [`tower::Layer`] for the above mentioned service
//! - An [`Authenticated`] request extension object to get information about authenticated clients (so far only the user identity)
//! - An extension to the standard [`axum::serve::Listener`] (with feature `http1`) that adds negotiation info to every connection.
//!   Because SPNEGO is not a native HTTP authentication scheme and authenticates the connection rather than individual requests, the
//!   negotiation info has to be attached to every connection handed to axum, either via this extension or by manually providing it as a
//!   `ConnectInfo` extension when driving the connection loop yourself.
11//!
12//! # Usage
//! The middleware and layer take an optional Kerberos service principal name (SPN) for the Router in question.
14//!
//! ```rust,no_run
//! use axum::{routing::get, Router};
//! use axum_negotiate_layer::{NegotiateInfo, NegotiateLayer};
18//! use tokio::net::TcpListener;
19//!
20//! #[tokio::main]
21//! async fn main() {
22//!     let router = Router::new()
23//!         .route("/", get(hello))
24//!         .layer(NegotiateLayer::new(Some("HTTP/example.com")))
25//!         .into_make_service_with_connect_info::<NegotiateInfo>();
26//!     let listener = TcpListener::bind("127.0.0.1:80").await.unwrap();
27//! }
28//! # async fn hello() {}
29//! ```
30//!
//! In the most convenient setup, shown above, the layer ensures that every route added to the router before it requires authentication.
32//! The [`Router::into_make_service_with_connect_info`](axum::Router::into_make_service_with_connect_info) call is mandatory for this layer to work
33//! on the used Router, otherwise the layer will panic.
34//!
35//! ## Axum handler usage example
36//!
37//! ```rust
38//! # use axum_negotiate_layer::Authenticated;
39//! async fn hello(a: Authenticated) -> String {
40//!     format!("Hello, {}!", a.client().unwrap_or("whoever".to_owned()))
41//! }
42//! ```
43//!
44//! Alternatively, this works:
45//! ```rust
46//! # use axum::Extension;
47//! # use axum_negotiate_layer::Authenticated;
48//! async fn hello(Extension(a): Extension<Authenticated>) -> String {
49//!     format!("Hello, {}!", a.client().unwrap_or("whoever".to_owned()))
50//! }
51//! ```
52//!
//! When the [`Authenticated`] object is obtained from the request extensions or extracted directly, authentication is guaranteed for
//! this route, because the object is only constructed by this crate after a successful handshake.
55use axum::{
56    extract::{ConnectInfo, FromRequestParts, Request, connect_info::Connected},
57    http::{
58        HeaderMap, HeaderValue, StatusCode,
59        header::{AUTHORIZATION, CONNECTION, WWW_AUTHENTICATE},
60        request::Parts,
61    },
62    response::{IntoResponse, Response},
63};
64use base64::{Engine, prelude::BASE64_STANDARD};
65use futures_util::future::BoxFuture;
66use kenobi::{
67    channel_bindings::Channel,
68    cred::{Credentials, Inbound},
69    mech::Mechanism,
70    server::{PendingServerContext, ServerBuilder, ServerContext},
71};
72use sspi::handle_sspi;
73use std::{
74    convert::Infallible,
75    fmt::Debug,
76    ops::{Deref, DerefMut},
77    sync::{Arc, RwLock},
78    task::Poll,
79};
80use tower::{Layer, Service};
81
82#[cfg(feature = "http1")]
83mod listener;
84mod sspi;
85#[cfg(feature = "http1")]
86pub use listener::{HasNegotiateInfo, Negotiator, WithNegotiateInfo};
87
88#[derive(Default)]
89enum NegotiateState {
90    #[default]
91    Unauthorized,
92    Pending(PendingServerContext<Inbound>),
93    Authenticated(ServerContext<Inbound>),
94}
95impl NegotiateState {
96    fn is_authenticated(&self) -> bool {
97        matches!(self, Self::Authenticated(_))
98    }
99}
100impl Debug for NegotiateState {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        match self {
103            Self::Authenticated(_) => f.write_str("Authenticated"),
104            Self::Pending(_) => f.write_str("Pending"),
105            Self::Unauthorized => f.write_str("Unauthenticated"),
106        }
107    }
108}
109
110/// [`Extension`](axum::Extension) or Extractor type that gets set after successful Authentication
111// This struct can only be created by the middleware in this crate or cloned from an
112// existing one. Extracting it directly panics when the Layer has not been applied yet.
113#[derive(Debug, Clone)]
114pub struct Authenticated(Arc<RwLock<NegotiateState>>);
115impl Authenticated {
116    fn call<T>(&self, f: impl Fn(&mut ServerContext<Inbound>) -> T) -> T {
117        let mut guard = self.0.write().unwrap();
118        match guard.deref_mut() {
119            NegotiateState::Authenticated(x) => f(x),
120            _ => unreachable!("Authenticated only exists after successful authentication"),
121        }
122    }
123    pub fn client(&mut self) -> Option<String> {
124        self.call(|x| Some(x.client_name().to_string()))
125    }
126}
127impl<S: Sync> FromRequestParts<S> for Authenticated {
128    type Rejection = Infallible;
129    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
130        let auth = get_state_from_extension(parts).0;
131        let Ok(au) = auth.read() else {
132            #[cfg(feature = "tracing")]
133            tracing::error!("Concurrency error, multiple threads accessing the same NegotiateInfo");
134            panic!()
135        };
136        if au.is_authenticated() {
137            Ok(Authenticated(auth.clone()))
138        } else {
139            #[cfg(feature = "tracing")]
140            tracing::error!(r#"NegotiateInfo not authorized. Probably extracted "Authenticated" outside of layer"#);
141            panic!("NegotiateInfo was not authorized. you may have extracted `Authenticated` outside of the layer")
142        }
143    }
144}
145
146fn get_state_from_extension(parts: &Parts) -> (Arc<RwLock<NegotiateState>>, Option<ChannelBindings>) {
147    match parts.extensions.get::<ConnectInfo<NegotiateInfo>>().cloned() {
148        Some(ConnectInfo(NegotiateInfo { auth, channel })) => (auth, channel),
149        None => {
150            #[cfg(feature = "tracing")]
151            tracing::error!("Panicking due to no ConnectInfo given");
152            panic!(
153                "No NegotiateInfo ConnectInfo was given. you may have forgotten to use into_make_service_with_connect_info"
154            )
155        }
156    }
157}
158/// Type that must be set via [`Router::into_make_service_with_connect_info`](axum::Router::into_make_service_with_connect_info).
159///
160/// Without this, the [`NegotiateLayer`] will not work
161#[derive(Clone, Debug, Default)]
162pub struct NegotiateInfo {
163    auth: Arc<RwLock<NegotiateState>>,
164    channel: Option<ChannelBindings>,
165}
166impl Connected<NegotiateInfo> for NegotiateInfo {
167    fn connect_info(value: NegotiateInfo) -> Self {
168        value
169    }
170}
171impl NegotiateInfo {
172    #[must_use]
173    /// You should probably only have to use this if you drive the IO loop yourself instead of using [`axum::serve()`]
174    pub fn new() -> Self {
175        Self::default()
176    }
177    pub fn with_channel<C: Channel>(self, c: &C) -> Result<NegotiateInfo, C::Error> {
178        let channel = match c.channel_bindings() {
179            Err(e) => return Err(e),
180            Ok(bindings) => ChannelBindings(bindings.map(|ar| ar.into())),
181        };
182        Ok(NegotiateInfo {
183            auth: self.auth,
184            channel: Some(channel),
185        })
186    }
187}
188
189#[derive(Debug, Clone)]
190pub struct ChannelBindings(Option<Arc<[u8]>>);
191impl Channel for ChannelBindings {
192    type Error = Infallible;
193    fn channel_bindings(&self) -> Result<Option<Vec<u8>>, Self::Error> {
194        Ok(self.0.as_ref().map(|ar| ar.to_vec()))
195    }
196}
197
198/// [`Layer`] which will enforce authentication
199///
200/// The SPN must be correctly installed in the local realm
201///
202/// Also a [`ConnectInfo`] extension must have been set on the router.
203#[derive(Clone)]
204pub struct NegotiateLayer {
205    spn: Option<String>,
206}
207impl NegotiateLayer {
208    #[must_use]
209    pub fn new(spn: Option<&str>) -> Self {
210        Self {
211            spn: spn.map(ToOwned::to_owned),
212        }
213    }
214}
215impl<S> Layer<S> for NegotiateLayer {
216    type Service = NegotiateMiddleware<S>;
217
218    fn layer(&self, inner: S) -> Self::Service {
219        NegotiateMiddleware::new(inner, self.spn.as_deref())
220    }
221}
/// Middleware to enforce authentication
///
/// A layer may be made from this via [`NegotiateLayer::new`]
///
/// This middleware will not work without the [`NegotiateInfo`] [`ConnectInfo`] object.
/// If there is no such connection information set, this middleware will panic.
#[derive(Clone)]
pub struct NegotiateMiddleware<S> {
    /// The wrapped service that receives authenticated requests.
    inner: S,
    /// Service principal name passed to the inbound credentials lookup.
    spn: Option<String>,
}
impl<S> NegotiateMiddleware<S> {
    /// Wraps `service`, authenticating connections against `spn`.
    #[must_use]
    pub fn new(service: S, spn: Option<&str>) -> NegotiateMiddleware<S> {
        Self {
            spn: spn.map(String::from),
            inner: service,
        }
    }
}
240impl<S> Service<Request> for NegotiateMiddleware<S>
241where
242    S: Service<Request, Response = Response> + Send + 'static,
243    S::Future: Send + 'static,
244{
245    type Response = S::Response;
246    type Error = S::Error;
247    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
248    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
249        self.inner.poll_ready(cx)
250    }
251    fn call(&mut self, req: Request) -> Self::Future {
252        let (mut parts, body) = req.into_parts();
253        let (auth, channel) = get_state_from_extension(&parts);
254        // If anyone moves this .read() call around remember to not accidentally deadlock
255        // with the write() call below
256        if auth.read().unwrap().deref().is_authenticated() {
257            let request = Request::from_parts(parts, body);
258            return Box::pin(self.inner.call(request));
259        }
260        let token = match extract_token(&parts.headers) {
261            Ok(token) => token,
262            Err(response) => {
263                return Box::pin(async { Ok(response) });
264            }
265        };
266        let mut lock = auth.write().unwrap();
267        let step_result = match std::mem::take(&mut *lock) {
268            NegotiateState::Authenticated(_) => unreachable!(),
269            NegotiateState::Pending(context) => handle_sspi(context, token),
270            NegotiateState::Unauthorized => {
271                #[cfg(feature = "tracing")]
272                tracing::debug!(spn = self.spn.as_deref(), "Getting local SPNEGO credentials");
273                let cred = match Credentials::inbound(self.spn.as_deref(), Mechanism::Spnego) {
274                    Ok(cred) => cred,
275                    Err(_e) => {
276                        #[cfg(feature = "tracing")]
277                        tracing::error!(error = ?_e, "Failed to create credentials handle");
278                        let response = failed_to_create_context().into_response();
279                        return Box::pin(async move { Ok(response) });
280                    }
281                };
282                let builder = ServerBuilder::new_from_credentials(cred).with_mutual_auth();
283                let builder_with_bindings = if let Some(channel) = channel {
284                    #[cfg(feature = "tracing")]
285                    if channel.0.is_some() {
286                        tracing::trace!("Adding channel bindings");
287                    } else {
288                        tracing::warn!("channel bindings provided but were empty");
289                    }
290                    builder.bind_to_channel(&channel).expect("infallible")
291                } else {
292                    builder
293                };
294                handle_sspi(builder_with_bindings, token)
295            }
296        };
297        match step_result {
298            StepResult::Finished(f, maybe_token) => {
299                parts.extensions.insert(Authenticated(auth.clone()));
300                let request = Request::from_parts(parts, body);
301                let next_future = self.inner.call(request);
302                *lock = NegotiateState::Authenticated(f);
303                Box::pin(async move {
304                    let mut response = next_future.await?;
305                    if let Some(token) = maybe_token {
306                        response
307                            .headers_mut()
308                            .append(WWW_AUTHENTICATE, to_negotiate_header(&token));
309                    }
310                    Ok(response)
311                })
312            }
313            StepResult::ContinueWith(server_context, response) => {
314                *lock = NegotiateState::Pending(server_context);
315                Box::pin(async move { Ok(response) })
316            }
317            StepResult::Error(response) => {
318                *lock = NegotiateState::Unauthorized;
319                Box::pin(async { Ok(response) })
320            }
321        }
322    }
323}
324
325fn to_negotiate_header(token_bytes: &[u8]) -> HeaderValue {
326    let encoded = BASE64_STANDARD.encode(token_bytes.as_ref());
327    HeaderValue::from_str(&format!("Negotiate {encoded}")).expect("Base64-string should be valid header material")
328}
329
330enum StepResult {
331    Finished(ServerContext<Inbound>, Option<Box<[u8]>>),
332    ContinueWith(PendingServerContext<Inbound>, Response),
333    Error(Response),
334}
335
336#[allow(clippy::result_large_err)]
337fn extract_token(headers: &HeaderMap) -> Result<&str, Response> {
338    let Some(authorization) = headers.get(AUTHORIZATION) else {
339        return Err(unauthorized("No Authorization given"));
340    };
341    let Some(token) = authorization
342        .to_str()
343        .ok()
344        .and_then(|with_prefix| with_prefix.strip_prefix("Negotiate "))
345    else {
346        return Err(unauthorized("Invalid Authorization Header"));
347    };
348    Ok(token)
349}
350
351fn www_authenticate_map() -> HeaderMap {
352    let mut map = HeaderMap::new();
353    map.insert(WWW_AUTHENTICATE, HeaderValue::from_static("Negotiate"));
354    map.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
355    map
356}
357
358fn unauthorized(message: &str) -> Response {
359    (StatusCode::UNAUTHORIZED, www_authenticate_map(), message.to_owned()).into_response()
360}
361
362fn failed_to_create_context() -> Response {
363    (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
364}