axum_negotiate/
lib.rs

1#![deny(unsafe_code)]
2#![deny(clippy::unwrap_used)]
3
4use std::{borrow::Borrow, ops::Deref};
5
6use async_trait::async_trait;
7use libgssapi::{
8    context::{SecurityContext, ServerCtx},
9    credential::{Cred, CredUsage},
10    name::Name,
11    oid::{OidSet, GSS_MECH_SPNEGO, GSS_NT_KRB5_PRINCIPAL},
12};
13use log::{debug, error};
14
15use base64::{engine::general_purpose::STANDARD, Engine as _};
16
17use axum_core::{
18    extract::FromRequestParts,
19    response::{IntoResponse, Response},
20};
21use futures_util::future::BoxFuture;
22use http::{
23    header::{AUTHORIZATION, WWW_AUTHENTICATE},
24    request::Parts,
25    HeaderValue, Request, StatusCode,
26};
27use thiserror::Error;
28use tower_layer::Layer;
29use tower_service::Service;
30
31pub trait NextMiddlewareError: std::error::Error + IntoResponse + Send + Sync {
32    fn box_into_response(self: Box<Self>) -> Response;
33}
34impl<T: std::error::Error + IntoResponse + Send + Sync> NextMiddlewareError for T {
35    fn box_into_response(self: Box<Self>) -> Response {
36        self.into_response()
37    }
38}
39pub type NextMiddlewareBoxError = Box<dyn NextMiddlewareError>;
40
/// Errors produced by the negotiate layer, its middleware and the [`Upn`] extractor.
///
/// Each variant is turned into an HTTP response by the `IntoResponse` impl below.
#[derive(Error, Debug)]
pub enum Error {
    /// The service principal name given to [`NegotiateAuthLayer::new`] contains characters
    /// the layer does not accept.
    #[error("invalid characters in spn")]
    InvalidSpn,

    /// The wrapped service failed in `poll_ready` or `call`; the response is produced by
    /// that error itself.
    #[error("next middleware: {0}")]
    NextMiddleware(NextMiddlewareBoxError),

    /// A libgssapi call failed (name import/canonicalization, credential acquisition,
    /// context step or source-name lookup).
    #[error("libgssapi: {0}")]
    GssApi(#[from] libgssapi::error::Error),

    /// The client's token did not complete the security context in a single step; further
    /// round trips are not implemented.
    #[error("multistage spnego is requested but currently not supported")]
    MultipassSpnego,

    /// The `Authorization` header is missing, is not a visible-ASCII string, or does not
    /// carry `Negotiate` credentials.
    #[error("invalid authorization header")]
    InvalidAuthorizationHeader,

    /// The `Negotiate` credentials are not valid base64.
    #[error("invalid gssapi_data")]
    InvalidGssapiData,

    /// The [`Upn`] extractor found no `Upn` in the request extensions.
    #[error("UPN extension not found in request")]
    UpnExtensionNotFound,
}
64impl IntoResponse for Error {
65    fn into_response(self) -> Response {
66        match self {
67            Self::InvalidSpn | Self::MultipassSpnego | Self::GssApi(_) => {
68                (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
69            }
70            Self::NextMiddleware(error) => error.box_into_response(),
71            Self::InvalidGssapiData => (StatusCode::BAD_REQUEST, "bad request").into_response(),
72            Self::UpnExtensionNotFound | Self::InvalidAuthorizationHeader => {
73                let mut response = (StatusCode::UNAUTHORIZED, "unauthorized").into_response();
74                response
75                    .headers_mut()
76                    .insert(WWW_AUTHENTICATE, HeaderValue::from_static("Negotiate"));
77                response
78            }
79        }
80    }
81}
82
/// Tower layer that requires SPNEGO ("Negotiate") HTTP authentication for the wrapped
/// service.
///
/// Holds the service principal name (SPN) under which acceptor credentials are acquired
/// when requests are handled; construct it with `NegotiateAuthLayer::new`.
#[derive(Clone, Debug)]
pub struct NegotiateAuthLayer {
    spn: String,
}
87
88impl NegotiateAuthLayer {
89    pub fn new(spn: String) -> Result<Self, Error> {
90        //TODO: check if libgssapi really can't handle utf16 characters. remove the ascii check if
91        //it does.
92        if spn.is_ascii() {
93            Ok(Self { spn })
94        } else {
95            Err(Error::InvalidSpn)
96        }
97    }
98}
99
100impl<I> Layer<I> for NegotiateAuthLayer {
101    type Service = NegotiateAuthLayerMiddleware<I>;
102
103    fn layer(&self, inner: I) -> Self::Service {
104        Self::Service {
105            inner,
106            spn: self.spn.to_owned(),
107        }
108    }
109}
110
/// The user principal name (UPN) of the authenticated client.
///
/// The negotiate middleware inserts it into the request extensions after a successful
/// handshake; handlers can then take it as an extractor.
///
/// Equality, ordering and hashing are derived from the inner string, so they agree with the
/// `Borrow<str>` impl (e.g. a `HashSet<Upn>` can be queried with a `&str`).
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Upn(pub Box<str>);
114
115#[async_trait]
116impl<S: Send + Sync> FromRequestParts<S> for Upn {
117    type Rejection = Error;
118
119    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
120        parts
121            .extensions
122            .get::<Self>()
123            .cloned()
124            .ok_or(Error::UpnExtensionNotFound)
125    }
126}
127
128impl AsRef<str> for Upn {
129    fn as_ref(&self) -> &str {
130        &self.0
131    }
132}
133impl Borrow<str> for Upn {
134    fn borrow(&self) -> &str {
135        &self.0
136    }
137}
138impl Deref for Upn {
139    type Target = str;
140
141    fn deref(&self) -> &Self::Target {
142        &self.0
143    }
144}
145
/// Service produced by `NegotiateAuthLayer`: authenticates each request via SPNEGO before
/// forwarding it to the wrapped service.
#[derive(Clone, Debug)]
pub struct NegotiateAuthLayerMiddleware<I> {
    /// The wrapped service that handles authenticated requests.
    inner: I,
    /// Service principal name used to acquire acceptor credentials for each request.
    spn: String,
}
151
152impl<I, B> Service<Request<B>> for NegotiateAuthLayerMiddleware<I>
153where
154    I: Service<Request<B>, Response = Response> + Clone + Send + 'static,
155    I::Error: NextMiddlewareError,
156    I::Future: Send + 'static,
157    B: Send + 'static,
158{
159    type Response = I::Response;
160
161    type Error = Error;
162
163    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
164
165    fn poll_ready(
166        &mut self,
167        cx: &mut std::task::Context<'_>,
168    ) -> std::task::Poll<Result<(), Self::Error>> {
169        self.inner
170            .poll_ready(cx)
171            .map_err(|e| Error::NextMiddleware(Box::new(e)))
172    }
173
174    fn call(&mut self, mut req: Request<B>) -> Self::Future {
175        let inner = self.inner.clone();
176        let mut inner = std::mem::replace(&mut self.inner, inner);
177
178        let spn = self.spn.clone();
179
180        Box::pin(async move {
181            let Some(authorization_header) = req
182                .headers()
183                .get(AUTHORIZATION)
184                .and_then(|x| x.to_str().ok())
185            else {
186                debug!("authorization header not present");
187                return Err(Error::InvalidAuthorizationHeader);
188            };
189
190            let Some(gssapi_data) = authorization_header.strip_prefix("Negotiate ") else {
191                debug!("authorization header has no prefix \"Negotiate\"");
192                return Err(Error::InvalidAuthorizationHeader);
193            };
194
195            let Ok(gssapi_data) = STANDARD.decode(gssapi_data) else {
196                debug!("authorization header gssapi_data contains invalid base64");
197                return Err(Error::InvalidGssapiData);
198            };
199
200            let mut ctx = new_server_ctx(&spn)?;
201
202            let token = ctx.step(&gssapi_data)?;
203
204            if !ctx.is_complete() {
205                error!("currently only 2-pass SPNEGO is supported");
206                return Err(Error::MultipassSpnego);
207            };
208
209            let upn = ctx.source_name()?.to_string();
210            req.extensions_mut().insert(Upn(upn.into()));
211
212            let mut response = inner
213                .call(req)
214                .await
215                .map_err(|x| Error::NextMiddleware(Box::new(x)))?;
216
217            response.headers_mut().insert(
218                WWW_AUTHENTICATE,
219                format!(
220                    "Negotiate {}",
221                    token.map(|x| STANDARD.encode(&*x)).unwrap_or_default()
222                )
223                .parse()
224                .expect("base64 to be ascii"),
225            );
226
227            Ok(response)
228        })
229    }
230}
231
232fn new_server_ctx(principal: &str) -> Result<ServerCtx, Error> {
233    let name = Name::new(principal.as_bytes(), Some(&GSS_NT_KRB5_PRINCIPAL))?
234        .canonicalize(Some(&GSS_MECH_SPNEGO))?;
235    let cred = {
236        let mut s = OidSet::new()?;
237        s.add(&GSS_MECH_SPNEGO)?;
238        Cred::acquire(Some(&name), None, CredUsage::Accept, Some(&s))?
239    };
240    Ok(ServerCtx::new(cred))
241}