use crate::session::authenticator::Authenticator;
use super::*;
/// State for driving an SMB2 SESSION_SETUP exchange.
///
/// The behaviour that differs between creating a new session and binding a channel to an
/// existing one is supplied by the `T: SessionSetupProperties` type parameter.
pub(crate) struct SessionSetup<'a, T: SessionSetupProperties> {
    // --- Borrowed connection context ---
    /// Upstream used to send/receive before a channel handler exists, and to reach the worker.
    upstream: &'a ChannelUpstream,
    /// Negotiated connection information (seeds the preauth hash, provides capabilities).
    conn_info: &'a Arc<ConnectionInfo>,
    /// Id given to the `ChannelInfo` created once authentication completes.
    new_channel_id: u32,

    // --- Authentication progress ---
    /// Produces the security buffers sent in each request and the final session key.
    authenticator: Authenticator,
    /// Preauth-integrity hash state, advanced with every exchanged message.
    preauth_hash: Option<PreauthHashState>,
    /// Most recent response from the server; its buffer feeds the next authenticator step.
    last_setup_response: Option<SessionSetupResponse>,
    /// Session flags from the latest response; must be set for the setup to succeed.
    flags: Option<SessionFlags>,

    // --- Produced state ---
    /// Message handler bound to the session being set up, once one exists.
    handler: Option<ChannelMessageHandler>,
    /// Temporary slot for the channel built after key exchange.
    channel: Option<ChannelInfo>,
    /// The session-and-channel pair handed back to the caller on success.
    result: Option<Arc<RwLock<SessionAndChannel>>>,

    /// Ties the setup strategy `T` to this type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
#[maybe_async]
impl<'a, T> SessionSetup<'a, T>
where
    T: SessionSetupProperties,
{
    /// Creates a session setup driver.
    ///
    /// The preauth hash is seeded from `conn_info.preauth_hash`. When `primary_session` is
    /// given, its session is installed immediately (via `set_session`), and a clone of its
    /// channel is attached to the new result. The setup then works on that existing session
    /// instead of creating one from the server's first response.
    ///
    /// # Errors
    /// Fails if [`Authenticator::build`] fails, if a lock cannot be acquired, or if
    /// registering the session with the upstream worker fails.
    ///
    /// # Panics
    /// Panics if `primary_session` has no channel set.
    pub async fn new(
        identity: sspi::AuthIdentity,
        upstream: &'a ChannelUpstream,
        conn_info: &'a Arc<ConnectionInfo>,
        new_channel_id: u32,
        primary_session: Option<&Arc<RwLock<SessionAndChannel>>>,
    ) -> crate::Result<Self> {
        let authenticator = Authenticator::build(identity, conn_info)?;
        let mut result = Self {
            last_setup_response: None,
            flags: None,
            result: None,
            handler: None,
            preauth_hash: Some(conn_info.preauth_hash.clone()),
            authenticator,
            upstream,
            conn_info,
            channel: None,
            new_channel_id,
            _phantom: std::marker::PhantomData,
        };
        if let Some(primary_session) = primary_session {
            // Copy what we need out of the primary session, then release its read lock
            // before awaiting on the worker inside `set_session()`.
            let (session, channel) = {
                let primary_session = primary_session.read().await?;
                let channel = primary_session
                    .channel
                    .as_ref()
                    .expect("A properly initialized session is expected in session setup.")
                    .clone();
                (primary_session.session.clone(), channel)
            };
            // With this feature enabled, the copied channel is passed through
            // `with_binding(true)` before being attached.
            #[cfg(feature = "ksmbd-multichannel-compat")]
            let channel = channel.with_binding(true);
            result.set_session(session).await?;
            result
                .result
                .as_ref()
                .expect("Should have been set up by set_session()")
                .write()
                .await?
                .channel = Some(channel);
        }
        Ok(result)
    }

    /// Runs the whole setup exchange and returns the resulting session-and-channel.
    ///
    /// # Errors
    /// On any failure, the error is logged and `T::error_cleanup` is run. A cleanup failure
    /// is only logged. The original error is then returned.
    pub(crate) async fn setup(&mut self) -> crate::Result<Arc<RwLock<SessionAndChannel>>> {
        log::debug!(
            "Setting up session for user {} (@{}).",
            self.authenticator.user_name().account_name(),
            self.authenticator.user_name().domain_name().unwrap_or("")
        );
        let result = self._setup_loop().await;
        match result {
            Ok(()) => Ok(self
                .result
                .take()
                .expect("A successful setup loop always leaves a session in `result`")),
            Err(e) => {
                log::error!("Failed to setup session: {}", e);
                if let Err(ce) = T::error_cleanup(self).await {
                    log::error!("Failed to cleanup after setup error: {}", ce);
                }
                Err(e)
            }
        }
    }

    /// Exchanges SESSION_SETUP request/response pairs until the authenticator is done.
    ///
    /// Each iteration does the following:
    /// 1. Produce the next security buffer, starting from the previous response's buffer
    ///    or from an empty one on the first round.
    /// 2. Send that buffer.
    /// 3. If authentication just completed, finalize the preauth hash and build the channel.
    /// 4. Receive the matching response, and create session state on first contact if none
    ///    exists yet.
    /// 5. When authentication is done, the final response must be signed or encrypted,
    ///    unless the server reports a guest/null session. Otherwise, fold the response into
    ///    the preauth hash.
    ///
    /// After the loop, `T::on_setup_success` is invoked.
    async fn _setup_loop(&mut self) -> crate::Result<()> {
        while !self.authenticator.is_authenticated()? {
            let next_buf = match self.last_setup_response.as_ref() {
                Some(response) => self.authenticator.next(&response.buffer).await?,
                None => self.authenticator.next(&[]).await?,
            };
            let is_auth_done = self.authenticator.is_authenticated()?;
            let request = self.send_setup_request(next_buf).await?;
            if is_auth_done {
                // The last request has been hashed; the final hash derives the channel keys.
                self.preauth_hash = self.preauth_hash.take().unwrap().finish().into();
                self.make_channel().await?;
            }
            let response = self.receive_setup_response(request.msg_id).await?;
            let message_form = response.form;
            let session_id = response.message.header.session_id;
            let session_setup_response = response.message.content.to_sessionsetup()?;
            if self.result.is_none() {
                log::trace!("Creating session state with id {session_id}.");
                self.set_session(T::init_session(self, session_id).await?)
                    .await?;
            }
            if is_auth_done {
                if !session_setup_response
                    .session_flags
                    .is_guest_or_null_session()
                    && !message_form.signed_or_encrypted()
                {
                    return Err(Error::InvalidMessage(
                        "Expected a signed message!".to_string(),
                    ));
                }
            } else {
                self.next_preauth_hash(&response.raw);
            }
            self.flags = Some(session_setup_response.session_flags);
            self.last_setup_response = Some(session_setup_response);
        }
        self.flags.ok_or_else(|| {
            Error::InvalidState("Failed to complete authentication properly.".to_string())
        })?;
        log::trace!("setup success, finishing up.");
        T::on_setup_success(self).await?;
        Ok(())
    }

    /// Wraps `session` into a new `SessionAndChannel` and installs it as the setup result.
    ///
    /// Also creates the setup message handler for it and notifies the upstream worker.
    ///
    /// # Errors
    /// Fails if a lock cannot be acquired, if the handler cannot be created, if no worker
    /// is available, or if the worker rejects the session.
    async fn set_session(&mut self, session: Arc<RwLock<SessionInfo>>) -> crate::Result<()> {
        let session_id = session.read().await?.id();
        let result = SessionAndChannel::new(session_id, session);
        let session = Arc::new(RwLock::new(result));
        let setup_handler = ChannelMessageHandler::make_for_setup(&session, self.upstream).await?;
        self.handler = Some(setup_handler);
        // Propagate a missing worker as an error instead of panicking.
        self.upstream
            .worker()
            .ok_or_else(|| Error::InvalidState("Worker not available!".to_string()))?
            .session_started(&session)
            .await?;
        self.result = Some(session);
        Ok(())
    }

    /// Receives the response to the request with id `for_msg_id`.
    ///
    /// The expected status is `Success` once authentication is done, and
    /// `MoreProcessingRequired` before that. Security validation is skipped only while
    /// authentication is still in progress and no channel has been attached yet.
    async fn receive_setup_response(&mut self, for_msg_id: u64) -> crate::Result<IncomingMessage> {
        let is_auth_done = self.authenticator.is_authenticated()?;
        let expected_status = if is_auth_done {
            &[Status::Success]
        } else {
            &[Status::MoreProcessingRequired]
        };
        let roptions = ReceiveOptions::new()
            .with_status(expected_status)
            .with_msg_id_filter(for_msg_id);
        let channel_set_up = match self.result.as_ref() {
            Some(result) => result.read().await?.channel.is_some(),
            None => false,
        };
        let skip_security_validation = !is_auth_done && !channel_set_up;
        if let Some(handler) = &self.handler {
            log::trace!(
                "setup loop: receiving with channel handler; skip_security_validation={skip_security_validation}"
            );
            handler
                .recvo_internal(roptions, skip_security_validation)
                .await
        } else {
            assert!(skip_security_validation);
            log::trace!("setup loop: receiving with upstream handler");
            self.upstream.handler.recvo(roptions).await
        }
    }

    /// Builds the request via `T::make_request` and sends it.
    ///
    /// The session's handler is used if one exists; otherwise the upstream sends it. The
    /// raw sent bytes are folded into the preauth hash.
    async fn send_setup_request(&mut self, buf: Vec<u8>) -> crate::Result<SendMessageResult> {
        let request = T::make_request(self, buf).await?;
        let send_result = if let Some(handler) = self.handler.as_ref() {
            log::trace!("setup loop: sending with channel handler");
            handler.sendo(request).await?
        } else {
            log::trace!("setup loop: sending with upstream handler");
            self.upstream.sendo(request).await?
        };
        self.next_preauth_hash(
            send_result
                .raw
                .as_ref()
                .expect("setup requests are sent with return_raw_data enabled"),
        );
        Ok(send_result)
    }

    /// Runs `T::on_session_key_exchanged`, then builds the channel and attaches it to the
    /// result.
    ///
    /// The channel is built from the session key and the final preauth hash.
    async fn make_channel(&mut self) -> crate::Result<()> {
        T::on_session_key_exchanged(self).await?;
        log::trace!("Session keys are set.");
        let channel_info = ChannelInfo::new(
            self.new_channel_id,
            &self.session_key()?,
            &self.preauth_hash_value(),
            self.conn_info,
        )?;
        self.channel = Some(channel_info);
        let mut session_lock = self
            .result
            .as_ref()
            .expect("Session state must exist before the channel is created")
            .write()
            .await?;
        session_lock.set_channel(self.channel.take().unwrap());
        log::trace!("Channel for current setup has been initialized");
        Ok(())
    }

    /// Returns the session key from the authenticator.
    fn session_key(&self) -> crate::Result<KeyToDerive> {
        self.authenticator.session_key()
    }

    /// Returns a copy of the final preauth hash, if `unwrap_final_hash` yields one.
    ///
    /// NOTE(review): this presumably is `None` when the hash is not finalized or not used
    /// by the dialect. Confirm against `PreauthHashState`.
    fn preauth_hash_value(&self) -> Option<PreauthHashValue> {
        self.preauth_hash
            .as_ref()
            .unwrap()
            .unwrap_final_hash()
            .copied()
    }

    /// Folds `data` into the preauth hash state and returns the updated state.
    fn next_preauth_hash(&mut self, data: &IoVec) -> &PreauthHashState {
        // Move the state out and replace it, instead of cloning it just to overwrite it.
        self.preauth_hash = self.preauth_hash.take().map(|hash| hash.next(data));
        self.preauth_hash.as_ref().unwrap()
    }

    /// The upstream this setup communicates through.
    pub fn upstream(&self) -> &'a ChannelUpstream {
        self.upstream
    }

    /// The negotiated connection information for this setup.
    pub fn conn_info(&self) -> &'a Arc<ConnectionInfo> {
        self.conn_info
    }
}
/// Strategy hooks that specialize [`SessionSetup`] (e.g. new session vs. channel binding).
#[maybe_async(AFIT)]
pub(crate) trait SessionSetupProperties {
    /// Called by `SessionSetup::setup` after the setup loop fails.
    ///
    /// An error returned here is only logged by the caller.
    async fn error_cleanup<T>(setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
    where
        T: SessionSetupProperties;

    /// Builds a SESSION_SETUP request carrying `buffer`, with signing enabled.
    ///
    /// The DFS capability is advertised when `dfs` is true. Raw data is requested back so
    /// the sent bytes can be added to the preauth hash.
    fn _make_default_request(buffer: Vec<u8>, dfs: bool) -> OutgoingMessage {
        OutgoingMessage::new(
            SessionSetupRequest::new(
                buffer,
                SessionSecurityMode::new().with_signing_enabled(true),
                SetupRequestFlags::new(),
                NegotiateCapabilities::new().with_dfs(dfs),
            )
            .into(),
        )
        .with_return_raw_data(true)
    }

    /// Builds the request for one setup round.
    ///
    /// The default advertises DFS exactly when the negotiated capabilities include it.
    async fn make_request<T>(
        setup: &mut SessionSetup<'_, T>,
        buffer: Vec<u8>,
    ) -> crate::Result<OutgoingMessage>
    where
        T: SessionSetupProperties,
    {
        let has_dfs = setup.conn_info().negotiation.caps.dfs();
        Ok(Self::_make_default_request(buffer, has_dfs))
    }

    /// Creates session state for `_session_id` when the setup has none yet.
    async fn init_session<T>(
        _setup: &'_ SessionSetup<'_, T>,
        _session_id: u64,
    ) -> crate::Result<Arc<RwLock<SessionInfo>>>
    where
        T: SessionSetupProperties;

    /// Called once authentication completes, before the channel is built.
    ///
    /// The default does nothing.
    async fn on_session_key_exchanged<T>(_setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
    where
        T: SessionSetupProperties,
    {
        Ok(())
    }

    /// Called after the setup loop finishes successfully.
    async fn on_setup_success<T>(_setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
    where
        T: SessionSetupProperties;
}
pub(crate) struct SmbSessionBind;
#[maybe_async(AFIT)]
impl SessionSetupProperties for SmbSessionBind {
async fn make_request<T>(
_setup: &mut SessionSetup<'_, T>,
buffer: Vec<u8>,
) -> crate::Result<OutgoingMessage>
where
T: SessionSetupProperties,
{
let has_dfs = _setup.conn_info().negotiation.caps.dfs();
let mut request = Self::_make_default_request(buffer, has_dfs);
request
.message
.content
.as_mut_sessionsetup()
.unwrap()
.flags
.set_binding(true);
Ok(request)
}
async fn error_cleanup<T>(setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
where
T: SessionSetupProperties,
{
if setup.result.is_none() {
log::warn!("No session to cleanup in binding.");
return Ok(());
}
setup
.upstream
.worker()
.ok_or_else(|| Error::InvalidState("Worker not available!".to_string()))?
.session_ended(setup.result.as_ref().unwrap())
.await
}
async fn init_session<T>(
_setup: &SessionSetup<'_, T>,
_session_id: u64,
) -> crate::Result<Arc<RwLock<SessionInfo>>>
where
T: SessionSetupProperties,
{
panic!("(Primary) Session should be provided in construction, rather than during setup!");
}
async fn on_setup_success<T>(_setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
where
T: SessionSetupProperties,
{
Ok(())
}
}
/// Setup strategy that establishes a brand-new session.
pub(crate) struct SmbSessionNew;
#[maybe_async(AFIT)]
impl SessionSetupProperties for SmbSessionNew {
    /// Invalidates the partially set-up session and notifies the worker that it ended.
    ///
    /// Nothing is done if no session was created yet.
    async fn error_cleanup<T>(setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
    where
        T: SessionSetupProperties,
    {
        let Some(session) = setup.result.as_ref() else {
            log::trace!("No session to cleanup in setup.");
            return Ok(());
        };
        log::trace!("Invalidating session before cleanup.");
        // Both guards are temporaries and are released at the end of this statement.
        session.read().await?.session.write().await?.invalidate();
        setup
            .upstream
            .worker()
            .ok_or_else(|| Error::InvalidState("Worker not available!".to_string()))?
            .session_ended(session)
            .await
    }

    /// Configures the session with the session key, the final preauth hash and `conn_info`.
    async fn on_session_key_exchanged<T>(setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
    where
        T: SessionSetupProperties,
    {
        log::trace!("Session keys exchanged. Setting up session state.");
        // Gather the inputs before taking any locks.
        let session_key = setup.session_key()?;
        let preauth_hash = setup.preauth_hash_value();
        setup
            .result
            .as_ref()
            .expect("Session state is created before keys are exchanged")
            .read()
            .await?
            .session
            .write()
            .await?
            .setup(&session_key, &preauth_hash, setup.conn_info)
    }

    /// Marks the session ready, using the flags from the final setup response.
    async fn on_setup_success<T>(setup: &mut SessionSetup<'_, T>) -> crate::Result<()>
    where
        T: SessionSetupProperties,
    {
        log::trace!("Session setup successful");
        let flags = setup
            .flags
            .expect("The setup loop verifies flags are set before reporting success");
        let result = setup
            .result
            .as_ref()
            .expect("A session exists after a successful setup loop")
            .read()
            .await?;
        let mut session = result.session.write().await?;
        session.ready(flags, setup.conn_info)
    }

    /// Creates fresh session state for the id assigned by the server.
    async fn init_session<T>(
        _setup: &SessionSetup<'_, T>,
        session_id: u64,
    ) -> crate::Result<Arc<RwLock<SessionInfo>>>
    where
        T: SessionSetupProperties,
    {
        Ok(Arc::new(RwLock::new(SessionInfo::new(session_id))))
    }
}