use alloc::boxed::Box;
use alloc::collections::VecDeque;
use alloc::string::String;
use alloc::vec::Vec;
use crate::error::{Error, Result};
use super::message::{
AuthMethodPayload, ServiceAccept, ServiceRequest, UserauthBanner, UserauthFailure,
UserauthInfoRequest, UserauthInfoResponse, UserauthPkOk, UserauthRequest,
SSH_MSG_SERVICE_ACCEPT, SSH_MSG_USERAUTH_BANNER, SSH_MSG_USERAUTH_FAILURE,
SSH_MSG_USERAUTH_PK_OK, SSH_MSG_USERAUTH_SUCCESS,
};
/// Supplies answers for the SSH `keyboard-interactive` authentication method.
///
/// A responder is attached to a [`ClientCredential::KeyboardInteractive`] credential and is
/// consulted once for every information request the server sends while that method is
/// in progress.
pub trait KeyboardInteractiveResponder: Send {
    /// Answers one round of server prompts.
    ///
    /// * `name` / `instruction` — the text fields of the server's information request.
    /// * `prompts` — one `(prompt_text, flag)` pair per question; the `bool` is presumably
    ///   the RFC 4256 `echo` flag (whether the answer may be shown) — confirm in `message`.
    ///
    /// The returned vector must contain exactly one answer per prompt, in prompt order;
    /// the authenticator rejects a reply of any other length with a protocol error.
    fn respond(
        &mut self,
        name: &str,
        instruction: &str,
        prompts: &[(String, bool)],
    ) -> Vec<String>;
}
/// One way of proving the user's identity, queued on a `ClientAuth` with `add_credential`.
pub enum ClientCredential {
    /// The `none` method: an authentication request that carries no secret.
    None,
    /// The `password` method, sending the contained password.
    Password(String),
    /// The `publickey` method; the key supplies its algorithm name, public blob and
    /// signatures over the authentication data.
    PublicKey(Box<dyn crate::hostkey::HostKey>),
    /// The `keyboard-interactive` method; the responder answers each round of prompts.
    KeyboardInteractive(Box<dyn KeyboardInteractiveResponder>),
}
impl ClientCredential {
fn method_name(&self) -> &'static str {
match self {
ClientCredential::None => "none",
ClientCredential::Password(_) => "password",
ClientCredential::PublicKey(_) => "publickey",
ClientCredential::KeyboardInteractive(_) => "keyboard-interactive",
}
}
}
/// What the caller should do after feeding a packet to `ClientAuth`.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, compare and store steps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientStep {
    /// Transmit this encoded payload to the server.
    Send(Vec<u8>),
    /// The server accepted authentication.
    Success,
    /// No further queued credential can be attempted; authentication is over.
    Failed {
        /// Method names from the server's last failure message (empty if none was kept).
        continuations: Vec<String>,
        /// The `partial_success` flag from the server's last failure message.
        partial_success: bool,
    },
    /// The server sent a banner; it is passed through for the caller to handle as it sees fit.
    Banner {
        /// Banner text.
        message: String,
        /// Language tag that accompanied the banner.
        language: String,
    },
    /// Nothing to do: the packet arrived after authentication had already finished.
    Idle,
}
/// Position of the client in the authentication exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// `start` has not been called yet.
    Initial,
    /// The `ssh-userauth` service request was sent; waiting for SERVICE_ACCEPT.
    AwaitingServiceAccept,
    /// An unsigned public-key probe was sent; waiting for PK_OK (or SUCCESS/FAILURE).
    AwaitingPkOk,
    /// The signed public-key request was sent; waiting for SUCCESS/FAILURE.
    AwaitingPkResult,
    /// A password request was sent; waiting for SUCCESS/FAILURE.
    AwaitingPasswordResult,
    /// A `none` request was sent; waiting for SUCCESS/FAILURE.
    AwaitingNoneResult,
    /// Keyboard-interactive is in progress: info requests, SUCCESS or FAILURE may arrive.
    AwaitingKbdintResult,
    /// Authentication has finished, successfully or not.
    Done,
}
/// Client side of SSH user authentication, driven one incoming packet at a time.
///
/// Obtain the opening service request from `start`, then pass every auth-layer payload
/// to `on_packet` and act on the returned [`ClientStep`].
pub struct ClientAuth {
    /// User name placed in every authentication request.
    user: String,
    /// Service requested on the user's behalf once authenticated (set in `new`).
    service: &'static str,
    /// Session identifier included in the data signed for public-key authentication.
    session_id: Vec<u8>,
    /// Credentials not yet attempted, in queue order.
    credentials: VecDeque<ClientCredential>,
    /// Credential whose request is currently outstanding, if any.
    current: Option<ClientCredential>,
    /// Method names from the most recent USERAUTH_FAILURE; initially empty.
    server_continuations: Vec<String>,
    /// `partial_success` flag from the most recent USERAUTH_FAILURE.
    last_partial_success: bool,
    /// Current position in the exchange.
    state: State,
}
impl ClientAuth {
pub fn new(user: impl Into<String>, session_id: Vec<u8>) -> Self {
Self {
user: user.into(),
service: "ssh-connection",
session_id,
credentials: VecDeque::new(),
current: None,
server_continuations: Vec::new(),
last_partial_success: false,
state: State::Initial,
}
}
pub fn add_credential(&mut self, cred: ClientCredential) {
self.credentials.push_back(cred);
}
pub fn start(&mut self) -> Vec<u8> {
self.state = State::AwaitingServiceAccept;
ServiceRequest {
service: "ssh-userauth".into(),
}
.encode()
}
pub fn on_packet(&mut self, payload: &[u8]) -> Result<ClientStep> {
if payload.is_empty() {
return Err(Error::Format("auth: empty payload"));
}
let msg_type = payload[0];
if msg_type == SSH_MSG_USERAUTH_BANNER {
let banner = UserauthBanner::decode(payload)?;
return Ok(ClientStep::Banner {
message: banner.message,
language: banner.language,
});
}
match self.state {
State::Initial => Err(Error::Protocol("auth: client not started")),
State::AwaitingServiceAccept => {
if msg_type != SSH_MSG_SERVICE_ACCEPT {
return Err(Error::Protocol("auth: expected SERVICE_ACCEPT"));
}
let accept = ServiceAccept::decode(payload)?;
if accept.service != "ssh-userauth" {
return Err(Error::Protocol("auth: wrong service accepted"));
}
self.advance_to_next_credential()
}
State::AwaitingPkOk => self.on_pk_probe_reply(payload),
State::AwaitingPkResult => self.on_auth_result(payload),
State::AwaitingPasswordResult => self.on_auth_result(payload),
State::AwaitingNoneResult => self.on_auth_result(payload),
State::AwaitingKbdintResult => self.on_kbdint_reply(payload),
State::Done => Ok(ClientStep::Idle),
}
}
fn advance_to_next_credential(&mut self) -> Result<ClientStep> {
loop {
let cred = match self.credentials.pop_front() {
Some(c) => c,
None => {
self.state = State::Done;
return Ok(ClientStep::Failed {
continuations: core::mem::take(&mut self.server_continuations),
partial_success: self.last_partial_success,
});
}
};
if !self.server_allows(&cred) {
continue;
}
self.current = Some(cred);
return self.emit_current_request();
}
}
fn server_allows(&self, cred: &ClientCredential) -> bool {
if self.server_continuations.is_empty() {
return true;
}
let name = cred.method_name();
self.server_continuations.iter().any(|m| m == name)
}
fn emit_current_request(&mut self) -> Result<ClientStep> {
let cred = self
.current
.as_ref()
.ok_or(Error::Protocol("auth: no current credential"))?;
let (method, next_state) = match cred {
ClientCredential::None => (AuthMethodPayload::None, State::AwaitingNoneResult),
ClientCredential::Password(pw) => (
AuthMethodPayload::Password {
new_password: None,
password: pw.clone(),
},
State::AwaitingPasswordResult,
),
ClientCredential::PublicKey(hk) => (
AuthMethodPayload::PublicKey {
signature_present: false,
algorithm: hk.algorithm().into(),
public_blob: hk.public_blob(),
signature: None,
},
State::AwaitingPkOk,
),
ClientCredential::KeyboardInteractive(_) => (
AuthMethodPayload::KeyboardInteractive {
language_tag: String::new(),
submethods: String::new(),
},
State::AwaitingKbdintResult,
),
};
let req = UserauthRequest {
user: self.user.clone(),
service: self.service.into(),
method,
};
self.state = next_state;
Ok(ClientStep::Send(req.encode()))
}
fn on_pk_probe_reply(&mut self, payload: &[u8]) -> Result<ClientStep> {
let msg_type = payload[0];
if msg_type == SSH_MSG_USERAUTH_PK_OK {
let pk_ok = UserauthPkOk::decode(payload)?;
self.send_pk_signed(&pk_ok)
} else if msg_type == SSH_MSG_USERAUTH_FAILURE {
self.on_auth_result(payload)
} else if msg_type == SSH_MSG_USERAUTH_SUCCESS {
self.state = State::Done;
self.current = None;
Ok(ClientStep::Success)
} else {
Err(Error::Protocol(
"auth: unexpected packet after publickey probe",
))
}
}
fn send_pk_signed(&mut self, pk_ok: &UserauthPkOk) -> Result<ClientStep> {
let cred = self
.current
.as_ref()
.ok_or(Error::Protocol("auth: pk-ok without current credential"))?;
let hk = match cred {
ClientCredential::PublicKey(hk) => hk,
_ => return Err(Error::Protocol("auth: pk-ok for non-publickey credential")),
};
if hk.algorithm() != pk_ok.algorithm {
return Err(Error::Protocol("auth: pk-ok algorithm mismatch"));
}
let public_blob = hk.public_blob();
if public_blob != pk_ok.public_blob {
return Err(Error::Protocol("auth: pk-ok public-key mismatch"));
}
let signed = super::message::publickey_signed_data(
&self.session_id,
&self.user,
self.service,
hk.algorithm(),
&public_blob,
);
let signature = hk.sign(&signed)?;
let req = UserauthRequest {
user: self.user.clone(),
service: self.service.into(),
method: AuthMethodPayload::PublicKey {
signature_present: true,
algorithm: hk.algorithm().into(),
public_blob,
signature: Some(signature),
},
};
self.state = State::AwaitingPkResult;
Ok(ClientStep::Send(req.encode()))
}
fn on_auth_result(&mut self, payload: &[u8]) -> Result<ClientStep> {
let msg_type = payload[0];
if msg_type == SSH_MSG_USERAUTH_SUCCESS {
super::message::decode_success(payload)?;
self.state = State::Done;
self.current = None;
Ok(ClientStep::Success)
} else if msg_type == SSH_MSG_USERAUTH_FAILURE {
let failure = UserauthFailure::decode(payload)?;
self.server_continuations = failure.continuations;
self.last_partial_success = failure.partial_success;
self.current = None;
self.advance_to_next_credential()
} else {
Err(Error::Protocol("auth: unexpected packet for auth result"))
}
}
fn on_kbdint_reply(&mut self, payload: &[u8]) -> Result<ClientStep> {
let msg_type = payload[0];
match msg_type {
SSH_MSG_USERAUTH_SUCCESS => {
super::message::decode_success(payload)?;
self.state = State::Done;
self.current = None;
Ok(ClientStep::Success)
}
SSH_MSG_USERAUTH_FAILURE => {
let failure = UserauthFailure::decode(payload)?;
self.server_continuations = failure.continuations;
self.last_partial_success = failure.partial_success;
self.current = None;
self.advance_to_next_credential()
}
60 => {
let info = UserauthInfoRequest::decode(payload)?;
let cred = self
.current
.as_mut()
.ok_or(Error::Protocol("auth: kbdint without current credential"))?;
let responder = match cred {
ClientCredential::KeyboardInteractive(r) => r,
_ => return Err(Error::Protocol("auth: kbdint reply on wrong credential")),
};
let responses = responder.respond(&info.name, &info.instruction, &info.prompts);
if responses.len() != info.prompts.len() {
return Err(Error::Protocol("auth: wrong number of kbdint responses"));
}
let resp = UserauthInfoResponse { responses };
Ok(ClientStep::Send(resp.encode()))
}
_ => Err(Error::Protocol("auth: unexpected packet in kbdint")),
}
}
}