1use axum::{
56 extract::{ConnectInfo, FromRequestParts, Request, connect_info::Connected},
57 http::{
58 HeaderMap, HeaderValue, StatusCode,
59 header::{AUTHORIZATION, CONNECTION, WWW_AUTHENTICATE},
60 request::Parts,
61 },
62 response::{IntoResponse, Response},
63};
64use base64::{Engine, prelude::BASE64_STANDARD};
65use futures_util::future::BoxFuture;
66use kenobi::{
67 channel_bindings::Channel,
68 cred::{Credentials, Inbound},
69 mech::Mechanism,
70 server::{PendingServerContext, ServerBuilder, ServerContext},
71};
72use sspi::handle_sspi;
73use std::{
74 convert::Infallible,
75 fmt::Debug,
76 ops::{Deref, DerefMut},
77 sync::{Arc, RwLock},
78 task::Poll,
79};
80use tower::{Layer, Service};
81
82#[cfg(feature = "http1")]
83mod listener;
84mod sspi;
85#[cfg(feature = "http1")]
86pub use listener::{HasNegotiateInfo, Negotiator, WithNegotiateInfo};
87
88#[derive(Default)]
89enum NegotiateState {
90 #[default]
91 Unauthorized,
92 Pending(PendingServerContext<Inbound>),
93 Authenticated(ServerContext<Inbound>),
94}
95impl NegotiateState {
96 fn is_authenticated(&self) -> bool {
97 matches!(self, Self::Authenticated(_))
98 }
99}
100impl Debug for NegotiateState {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 match self {
103 Self::Authenticated(_) => f.write_str("Authenticated"),
104 Self::Pending(_) => f.write_str("Pending"),
105 Self::Unauthorized => f.write_str("Unauthenticated"),
106 }
107 }
108}
109
110#[derive(Debug, Clone)]
114pub struct Authenticated(Arc<RwLock<NegotiateState>>);
115impl Authenticated {
116 fn call<T>(&self, f: impl Fn(&mut ServerContext<Inbound>) -> T) -> T {
117 let mut guard = self.0.write().unwrap();
118 match guard.deref_mut() {
119 NegotiateState::Authenticated(x) => f(x),
120 _ => unreachable!("Authenticated only exists after successful authentication"),
121 }
122 }
123 pub fn client(&mut self) -> Option<String> {
124 self.call(|x| Some(x.client_name().to_string()))
125 }
126}
127impl<S: Sync> FromRequestParts<S> for Authenticated {
128 type Rejection = Infallible;
129 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
130 let auth = get_state_from_extension(parts).0;
131 let Ok(au) = auth.read() else {
132 #[cfg(feature = "tracing")]
133 tracing::error!("Concurrency error, multiple threads accessing the same NegotiateInfo");
134 panic!()
135 };
136 if au.is_authenticated() {
137 Ok(Authenticated(auth.clone()))
138 } else {
139 #[cfg(feature = "tracing")]
140 tracing::error!(r#"NegotiateInfo not authorized. Probably extracted "Authenticated" outside of layer"#);
141 panic!("NegotiateInfo was not authorized. you may have extracted `Authenticated` outside of the layer")
142 }
143 }
144}
145
146fn get_state_from_extension(parts: &Parts) -> (Arc<RwLock<NegotiateState>>, Option<ChannelBindings>) {
147 match parts.extensions.get::<ConnectInfo<NegotiateInfo>>().cloned() {
148 Some(ConnectInfo(NegotiateInfo { auth, channel })) => (auth, channel),
149 None => {
150 #[cfg(feature = "tracing")]
151 tracing::error!("Panicking due to no ConnectInfo given");
152 panic!(
153 "No NegotiateInfo ConnectInfo was given. you may have forgotten to use into_make_service_with_connect_info"
154 )
155 }
156 }
157}
158#[derive(Clone, Debug, Default)]
162pub struct NegotiateInfo {
163 auth: Arc<RwLock<NegotiateState>>,
164 channel: Option<ChannelBindings>,
165}
166impl Connected<NegotiateInfo> for NegotiateInfo {
167 fn connect_info(value: NegotiateInfo) -> Self {
168 value
169 }
170}
171impl NegotiateInfo {
172 #[must_use]
173 pub fn new() -> Self {
175 Self::default()
176 }
177 pub fn with_channel<C: Channel>(self, c: &C) -> Result<NegotiateInfo, C::Error> {
178 let channel = match c.channel_bindings() {
179 Err(e) => return Err(e),
180 Ok(bindings) => ChannelBindings(bindings.map(|ar| ar.into())),
181 };
182 Ok(NegotiateInfo {
183 auth: self.auth,
184 channel: Some(channel),
185 })
186 }
187}
188
189#[derive(Debug, Clone)]
190pub struct ChannelBindings(Option<Arc<[u8]>>);
191impl Channel for ChannelBindings {
192 type Error = Infallible;
193 fn channel_bindings(&self) -> Result<Option<Vec<u8>>, Self::Error> {
194 Ok(self.0.as_ref().map(|ar| ar.to_vec()))
195 }
196}
197
198#[derive(Clone)]
204pub struct NegotiateLayer {
205 spn: Option<String>,
206}
207impl NegotiateLayer {
208 #[must_use]
209 pub fn new(spn: Option<&str>) -> Self {
210 Self {
211 spn: spn.map(ToOwned::to_owned),
212 }
213 }
214}
215impl<S> Layer<S> for NegotiateLayer {
216 type Service = NegotiateMiddleware<S>;
217
218 fn layer(&self, inner: S) -> Self::Service {
219 NegotiateMiddleware::new(inner, self.spn.as_deref())
220 }
221}
/// Service that performs the SPNEGO handshake for each connection and forwards requests to
/// `inner` once the connection is authenticated.
#[derive(Clone)]
pub struct NegotiateMiddleware<S> {
    inner: S,
    spn: Option<String>,
}
impl<S> NegotiateMiddleware<S> {
    /// Wraps `service`, acquiring credentials for `spn` when a handshake starts.
    #[must_use]
    pub fn new(service: S, spn: Option<&str>) -> NegotiateMiddleware<S> {
        Self {
            inner: service,
            spn: spn.map(String::from),
        }
    }
}
240impl<S> Service<Request> for NegotiateMiddleware<S>
241where
242 S: Service<Request, Response = Response> + Send + 'static,
243 S::Future: Send + 'static,
244{
245 type Response = S::Response;
246 type Error = S::Error;
247 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
248 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
249 self.inner.poll_ready(cx)
250 }
251 fn call(&mut self, req: Request) -> Self::Future {
252 let (mut parts, body) = req.into_parts();
253 let (auth, channel) = get_state_from_extension(&parts);
254 if auth.read().unwrap().deref().is_authenticated() {
257 let request = Request::from_parts(parts, body);
258 return Box::pin(self.inner.call(request));
259 }
260 let token = match extract_token(&parts.headers) {
261 Ok(token) => token,
262 Err(response) => {
263 return Box::pin(async { Ok(response) });
264 }
265 };
266 let mut lock = auth.write().unwrap();
267 let step_result = match std::mem::take(&mut *lock) {
268 NegotiateState::Authenticated(_) => unreachable!(),
269 NegotiateState::Pending(context) => handle_sspi(context, token),
270 NegotiateState::Unauthorized => {
271 #[cfg(feature = "tracing")]
272 tracing::debug!(spn = self.spn.as_deref(), "Getting local SPNEGO credentials");
273 let cred = match Credentials::inbound(self.spn.as_deref(), Mechanism::Spnego) {
274 Ok(cred) => cred,
275 Err(_e) => {
276 #[cfg(feature = "tracing")]
277 tracing::error!(error = ?_e, "Failed to create credentials handle");
278 let response = failed_to_create_context().into_response();
279 return Box::pin(async move { Ok(response) });
280 }
281 };
282 let builder = ServerBuilder::new_from_credentials(cred).with_mutual_auth();
283 let builder_with_bindings = if let Some(channel) = channel {
284 #[cfg(feature = "tracing")]
285 if channel.0.is_some() {
286 tracing::trace!("Adding channel bindings");
287 } else {
288 tracing::warn!("channel bindings provided but were empty");
289 }
290 builder.bind_to_channel(&channel).expect("infallible")
291 } else {
292 builder
293 };
294 handle_sspi(builder_with_bindings, token)
295 }
296 };
297 match step_result {
298 StepResult::Finished(f, maybe_token) => {
299 parts.extensions.insert(Authenticated(auth.clone()));
300 let request = Request::from_parts(parts, body);
301 let next_future = self.inner.call(request);
302 *lock = NegotiateState::Authenticated(f);
303 Box::pin(async move {
304 let mut response = next_future.await?;
305 if let Some(token) = maybe_token {
306 response
307 .headers_mut()
308 .append(WWW_AUTHENTICATE, to_negotiate_header(&token));
309 }
310 Ok(response)
311 })
312 }
313 StepResult::ContinueWith(server_context, response) => {
314 *lock = NegotiateState::Pending(server_context);
315 Box::pin(async move { Ok(response) })
316 }
317 StepResult::Error(response) => {
318 *lock = NegotiateState::Unauthorized;
319 Box::pin(async { Ok(response) })
320 }
321 }
322 }
323}
324
325fn to_negotiate_header(token_bytes: &[u8]) -> HeaderValue {
326 let encoded = BASE64_STANDARD.encode(token_bytes.as_ref());
327 HeaderValue::from_str(&format!("Negotiate {encoded}")).expect("Base64-string should be valid header material")
328}
329
/// Outcome of feeding one client token to a security context (produced by `handle_sspi`).
enum StepResult {
    /// The handshake completed: the established context, plus an optional final token that the
    /// middleware appends to the inner service's response as `WWW-Authenticate`.
    Finished(ServerContext<Inbound>, Option<Box<[u8]>>),
    /// The handshake needs another round trip: the pending context to store for the connection,
    /// and the response to send back to the client.
    ContinueWith(PendingServerContext<Inbound>, Response),
    /// The step failed: the response to send back; the middleware resets the connection's state
    /// to `Unauthorized`.
    Error(Response),
}
335
336#[allow(clippy::result_large_err)]
337fn extract_token(headers: &HeaderMap) -> Result<&str, Response> {
338 let Some(authorization) = headers.get(AUTHORIZATION) else {
339 return Err(unauthorized("No Authorization given"));
340 };
341 let Some(token) = authorization
342 .to_str()
343 .ok()
344 .and_then(|with_prefix| with_prefix.strip_prefix("Negotiate "))
345 else {
346 return Err(unauthorized("Invalid Authorization Header"));
347 };
348 Ok(token)
349}
350
351fn www_authenticate_map() -> HeaderMap {
352 let mut map = HeaderMap::new();
353 map.insert(WWW_AUTHENTICATE, HeaderValue::from_static("Negotiate"));
354 map.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
355 map
356}
357
358fn unauthorized(message: &str) -> Response {
359 (StatusCode::UNAUTHORIZED, www_authenticate_map(), message.to_owned()).into_response()
360}
361
362fn failed_to_create_context() -> Response {
363 (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
364}