1#![deny(unsafe_code)]
2#![deny(clippy::unwrap_used)]
3
4use std::{borrow::Borrow, ops::Deref};
5
6use async_trait::async_trait;
7use libgssapi::{
8 context::{SecurityContext, ServerCtx},
9 credential::{Cred, CredUsage},
10 name::Name,
11 oid::{OidSet, GSS_MECH_SPNEGO, GSS_NT_KRB5_PRINCIPAL},
12};
13use log::{debug, error};
14
15use base64::{engine::general_purpose::STANDARD, Engine as _};
16
17use axum_core::{
18 extract::FromRequestParts,
19 response::{IntoResponse, Response},
20};
21use futures_util::future::BoxFuture;
22use http::{
23 header::{AUTHORIZATION, WWW_AUTHENTICATE},
24 request::Parts,
25 HeaderValue, Request, StatusCode,
26};
27use thiserror::Error;
28use tower_layer::Layer;
29use tower_service::Service;
30
31pub trait NextMiddlewareError: std::error::Error + IntoResponse + Send + Sync {
32 fn box_into_response(self: Box<Self>) -> Response;
33}
34impl<T: std::error::Error + IntoResponse + Send + Sync> NextMiddlewareError for T {
35 fn box_into_response(self: Box<Self>) -> Response {
36 self.into_response()
37 }
38}
39pub type NextMiddlewareBoxError = Box<dyn NextMiddlewareError>;
40
41#[derive(Error, Debug)]
42pub enum Error {
43 #[error("invalid characters in spn")]
44 InvalidSpn,
45
46 #[error("next middleware: {0}")]
47 NextMiddleware(NextMiddlewareBoxError),
48
49 #[error("libgssapi: {0}")]
50 GssApi(#[from] libgssapi::error::Error),
51
52 #[error("multistage spnego is requested but currently not supported")]
53 MultipassSpnego,
54
55 #[error("invalid authorization header")]
56 InvalidAuthorizationHeader,
57
58 #[error("invalid gssapi_data")]
59 InvalidGssapiData,
60
61 #[error("UPN extension not found in request")]
62 UpnExtensionNotFound,
63}
64impl IntoResponse for Error {
65 fn into_response(self) -> Response {
66 match self {
67 Self::InvalidSpn | Self::MultipassSpnego | Self::GssApi(_) => {
68 (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
69 }
70 Self::NextMiddleware(error) => error.box_into_response(),
71 Self::InvalidGssapiData => (StatusCode::BAD_REQUEST, "bad request").into_response(),
72 Self::UpnExtensionNotFound | Self::InvalidAuthorizationHeader => {
73 let mut response = (StatusCode::UNAUTHORIZED, "unauthorized").into_response();
74 response
75 .headers_mut()
76 .insert(WWW_AUTHENTICATE, HeaderValue::from_static("Negotiate"));
77 response
78 }
79 }
80 }
81}
82
/// Tower [`Layer`] that protects the wrapped service with HTTP Negotiate (SPNEGO)
/// authentication.
///
/// Build it with [`NegotiateAuthLayer::new`]. The configured SPN is used to acquire the
/// server's acceptor credentials for every incoming request.
#[derive(Clone, Debug)]
pub struct NegotiateAuthLayer {
    /// Service principal name; validated as ASCII-only by `new`.
    spn: String,
}
87
88impl NegotiateAuthLayer {
89 pub fn new(spn: String) -> Result<Self, Error> {
90 if spn.is_ascii() {
93 Ok(Self { spn })
94 } else {
95 Err(Error::InvalidSpn)
96 }
97 }
98}
99
100impl<I> Layer<I> for NegotiateAuthLayer {
101 type Service = NegotiateAuthLayerMiddleware<I>;
102
103 fn layer(&self, inner: I) -> Self::Service {
104 Self::Service {
105 inner,
106 spn: self.spn.to_owned(),
107 }
108 }
109}
110
/// Client principal name reported by GSSAPI (`source_name`) after a successful Negotiate
/// exchange.
///
/// The middleware stores it in the request extensions; handlers can extract it directly.
/// `Eq`, `Ord` and `Hash` are derived from the inner string, so they agree with the
/// `Borrow<str>` impl. This lets a `Upn` key a `HashMap`/`BTreeMap` that can be queried
/// with a plain `&str`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Upn(pub Box<str>);
114
115#[async_trait]
116impl<S: Send + Sync> FromRequestParts<S> for Upn {
117 type Rejection = Error;
118
119 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
120 parts
121 .extensions
122 .get::<Self>()
123 .cloned()
124 .ok_or(Error::UpnExtensionNotFound)
125 }
126}
127
128impl AsRef<str> for Upn {
129 fn as_ref(&self) -> &str {
130 &self.0
131 }
132}
133impl Borrow<str> for Upn {
134 fn borrow(&self) -> &str {
135 &self.0
136 }
137}
138impl Deref for Upn {
139 type Target = str;
140
141 fn deref(&self) -> &Self::Target {
142 &self.0
143 }
144}
145
/// Service produced by [`NegotiateAuthLayer`].
///
/// Authenticates each request with SPNEGO before forwarding it to the wrapped service.
#[derive(Clone, Debug)]
pub struct NegotiateAuthLayerMiddleware<I> {
    /// The wrapped (inner) service.
    inner: I,
    /// Service principal name used to acquire acceptor credentials.
    spn: String,
}
151
152impl<I, B> Service<Request<B>> for NegotiateAuthLayerMiddleware<I>
153where
154 I: Service<Request<B>, Response = Response> + Clone + Send + 'static,
155 I::Error: NextMiddlewareError,
156 I::Future: Send + 'static,
157 B: Send + 'static,
158{
159 type Response = I::Response;
160
161 type Error = Error;
162
163 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
164
165 fn poll_ready(
166 &mut self,
167 cx: &mut std::task::Context<'_>,
168 ) -> std::task::Poll<Result<(), Self::Error>> {
169 self.inner
170 .poll_ready(cx)
171 .map_err(|e| Error::NextMiddleware(Box::new(e)))
172 }
173
174 fn call(&mut self, mut req: Request<B>) -> Self::Future {
175 let inner = self.inner.clone();
176 let mut inner = std::mem::replace(&mut self.inner, inner);
177
178 let spn = self.spn.clone();
179
180 Box::pin(async move {
181 let Some(authorization_header) = req
182 .headers()
183 .get(AUTHORIZATION)
184 .and_then(|x| x.to_str().ok())
185 else {
186 debug!("authorization header not present");
187 return Err(Error::InvalidAuthorizationHeader);
188 };
189
190 let Some(gssapi_data) = authorization_header.strip_prefix("Negotiate ") else {
191 debug!("authorization header has no prefix \"Negotiate\"");
192 return Err(Error::InvalidAuthorizationHeader);
193 };
194
195 let Ok(gssapi_data) = STANDARD.decode(gssapi_data) else {
196 debug!("authorization header gssapi_data contains invalid base64");
197 return Err(Error::InvalidGssapiData);
198 };
199
200 let mut ctx = new_server_ctx(&spn)?;
201
202 let token = ctx.step(&gssapi_data)?;
203
204 if !ctx.is_complete() {
205 error!("currently only 2-pass SPNEGO is supported");
206 return Err(Error::MultipassSpnego);
207 };
208
209 let upn = ctx.source_name()?.to_string();
210 req.extensions_mut().insert(Upn(upn.into()));
211
212 let mut response = inner
213 .call(req)
214 .await
215 .map_err(|x| Error::NextMiddleware(Box::new(x)))?;
216
217 response.headers_mut().insert(
218 WWW_AUTHENTICATE,
219 format!(
220 "Negotiate {}",
221 token.map(|x| STANDARD.encode(&*x)).unwrap_or_default()
222 )
223 .parse()
224 .expect("base64 to be ascii"),
225 );
226
227 Ok(response)
228 })
229 }
230}
231
232fn new_server_ctx(principal: &str) -> Result<ServerCtx, Error> {
233 let name = Name::new(principal.as_bytes(), Some(&GSS_NT_KRB5_PRINCIPAL))?
234 .canonicalize(Some(&GSS_MECH_SPNEGO))?;
235 let cred = {
236 let mut s = OidSet::new()?;
237 s.add(&GSS_MECH_SPNEGO)?;
238 Cred::acquire(Some(&name), None, CredUsage::Accept, Some(&s))?
239 };
240 Ok(ServerCtx::new(cred))
241}