use std::sync::{Arc, Mutex};
use crate::error::Result;
use crate::hostkey::HostKey;
use super::client::Agent;
use super::protocol::{SSH_AGENT_RSA_SHA2_256, SSH_AGENT_RSA_SHA2_512};
/// A `HostKey` implementation that delegates signing to an `Agent`.
///
/// Built with `AgentHostKey::from_identity` from the public key blob of an
/// agent identity; `sign` forwards the message, this blob and `flags` to the
/// agent.
pub struct AgentHostKey {
/// Shared, mutex-guarded agent handle; locked while a sign request is made.
agent: Arc<Mutex<Agent>>,
/// Public key blob of the identity. Its leading length-prefixed string names
/// the key type; the blob is handed to the agent on `sign` and returned
/// (cloned) by `public_blob`.
key_blob: Vec<u8>,
/// Signature algorithm name reported by `algorithm()` and compared against
/// the algorithm prefix of signatures returned by the agent.
algorithm: &'static str,
/// Flags passed to the agent's sign call: one of the `SSH_AGENT_RSA_SHA2_*`
/// constants for RSA keys, `0` for every other supported key type.
flags: u32,
}
impl AgentHostKey {
pub fn from_identity(agent: Arc<Mutex<Agent>>, key_blob: Vec<u8>) -> Result<Self> {
let algo = first_string(&key_blob).ok_or(crate::error::Error::Format(
"agent identity blob: missing algorithm prefix",
))?;
let (algorithm, flags): (&'static str, u32) = match algo.as_str() {
"ssh-ed25519" => ("ssh-ed25519", 0),
"ecdsa-sha2-nistp256" => ("ecdsa-sha2-nistp256", 0),
"ecdsa-sha2-nistp384" => ("ecdsa-sha2-nistp384", 0),
"ecdsa-sha2-nistp521" => ("ecdsa-sha2-nistp521", 0),
"ssh-rsa" => ("rsa-sha2-256", SSH_AGENT_RSA_SHA2_256),
"rsa-sha2-256" => ("rsa-sha2-256", SSH_AGENT_RSA_SHA2_256),
"rsa-sha2-512" => ("rsa-sha2-512", SSH_AGENT_RSA_SHA2_512),
other => {
return Err(crate::error::Error::Format({
let _ = other;
"agent identity blob: unsupported algorithm"
}));
}
};
Ok(Self {
agent,
key_blob,
algorithm,
flags,
})
}
pub fn with_rsa_hash(mut self, hash: RsaHash) -> Self {
match self.algorithm {
"rsa-sha2-256" | "rsa-sha2-512" => {
let (algo, flags) = match hash {
RsaHash::Sha256 => ("rsa-sha2-256", SSH_AGENT_RSA_SHA2_256),
RsaHash::Sha512 => ("rsa-sha2-512", SSH_AGENT_RSA_SHA2_512),
};
self.algorithm = algo;
self.flags = flags;
}
_ => {}
}
self
}
}
/// Hash function used when an RSA key signs through the agent.
///
/// Passed to `AgentHostKey::with_rsa_hash` to choose between the
/// `rsa-sha2-256` and `rsa-sha2-512` signature algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RsaHash {
    /// Sign with `rsa-sha2-256`.
    Sha256,
    /// Sign with `rsa-sha2-512`.
    Sha512,
}
impl HostKey for AgentHostKey {
fn algorithm(&self) -> &'static str {
self.algorithm
}
fn public_blob(&self) -> Vec<u8> {
self.key_blob.clone()
}
fn sign(&self, msg: &[u8]) -> Result<Vec<u8>> {
let sig_blob = {
let mut agent = self
.agent
.lock()
.map_err(|_| crate::error::Error::Protocol("agent mutex poisoned"))?;
agent.sign(&self.key_blob, msg, self.flags)?
};
if let Some(algo) = first_string(&sig_blob) {
if algo != self.algorithm {
return Err(crate::error::Error::Protocol(
"agent: signature algorithm mismatch",
));
}
}
Ok(sig_blob)
}
}
/// Reads the first length-prefixed string from `buf`.
///
/// The expected layout is a 4-byte big-endian length followed by that many
/// bytes. Bytes after the string are ignored. Invalid UTF-8 in the string is
/// replaced lossily rather than rejected.
///
/// Returns `None` when `buf` is shorter than the 4-byte prefix or shorter than
/// the length the prefix announces.
fn first_string(buf: &[u8]) -> Option<String> {
    let prefix = buf.get(..4)?;
    let len = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
    // Slice relative to the payload start instead of computing `4 + len`,
    // which can overflow `usize` on 32-bit targets for a hostile prefix.
    let payload = buf.get(4..)?.get(..len)?;
    Some(String::from_utf8_lossy(payload).into_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fake key blob: big-endian length prefix, the algorithm name,
    /// then a fixed placeholder standing in for the key body.
    fn make_blob(algo: &str) -> Vec<u8> {
        let prefix = (algo.len() as u32).to_be_bytes();
        [&prefix[..], algo.as_bytes(), &b"key-body-placeholder"[..]].concat()
    }

    #[test]
    fn first_string_handles_short_blob() {
        let cases: [(&[u8], Option<&str>); 3] = [
            // Nothing to read at all.
            (&b""[..], None),
            // Prefix announces five bytes, only two follow.
            (&b"\x00\x00\x00\x05ab"[..], None),
            // Trailing bytes after the string are ignored.
            (&b"\x00\x00\x00\x03foo+padding"[..], Some("foo")),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(first_string(input).as_deref(), *expected);
        }
    }

    // NOTE(review): despite its name, this test never calls `from_identity`;
    // it only checks that `first_string` recovers the algorithm name written
    // by `make_blob`.
    #[test]
    fn from_identity_picks_correct_algorithm() {
        for algo in ["ssh-ed25519", "rsa-sha2-512"].iter() {
            let blob = make_blob(algo);
            assert_eq!(first_string(&blob).as_deref(), Some(*algo));
        }
    }
}