use std::fs;
use std::io;
use std::path::Path;
use purecrypto::rng::{OsRng, RngCore};
use super::format::{
format_entry, format_host_pattern, parse_line, patterns_match, Entry, HostSpec, Marker,
ParsedLine,
};
use super::hash::{check_hashed, hash_new, parse_hashed};
/// Outcome of checking a presented host key against a [`KnownHosts`] database.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can compare results directly
/// (e.g. `result == LookupResult::Match`) and store copies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LookupResult {
    /// A non-revoked entry for the host lists exactly the presented key type and blob.
    Match,
    /// The presented key is not accepted for this host.
    ///
    /// Two cases produce this:
    /// - an `@revoked` entry for the host carries exactly this key, in which case
    ///   `expected` holds just that revoked key; or
    /// - the host has non-revoked entries but none of them carry this key, in which case
    ///   `expected` lists every key recorded for the host.
    Mismatch {
        /// `(key type, key blob)` pairs recorded for the host.
        expected: Vec<(String, Vec<u8>)>,
    },
    /// No non-revoked entry matches the host.
    Unknown,
}
/// Editable, order-preserving model of an SSH `known_hosts` file.
///
/// Each line read from the file occupies one slot. Entries can therefore be added, removed
/// or rewritten while comments and unrecognised lines are written back unchanged.
pub struct KnownHosts {
    /// One slot per original line, followed by any entries appended since loading.
    lines: Vec<Slot>,
}
/// Storage for a single line of the database.
enum Slot {
    /// Text that `parse_line` did not classify as an entry; re-emitted as-is plus a newline.
    Verbatim(String),
    /// A parsed host-key entry; re-serialised with `format_entry` on output.
    Entry(Entry),
    /// Tombstone left behind by `KnownHosts::remove`; produces no output.
    Removed,
}
impl Default for KnownHosts {
fn default() -> Self {
Self::new()
}
}
impl KnownHosts {
pub fn new() -> Self {
Self { lines: Vec::new() }
}
pub fn from_bytes(data: &[u8]) -> Self {
let mut out = Self::new();
for raw in std::str::from_utf8(data).unwrap_or("").lines() {
match parse_line(raw) {
ParsedLine::Entry(e) => out.lines.push(Slot::Entry(e)),
ParsedLine::Verbatim(s) => out.lines.push(Slot::Verbatim(s)),
}
}
out
}
pub fn load(path: impl AsRef<Path>) -> io::Result<Self> {
match fs::read(path) {
Ok(bytes) => Ok(Self::from_bytes(&bytes)),
Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(Self::new()),
Err(e) => Err(e),
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::new();
for slot in &self.lines {
match slot {
Slot::Verbatim(s) => {
out.extend_from_slice(s.as_bytes());
out.push(b'\n');
}
Slot::Entry(e) => {
out.extend_from_slice(format_entry(e).as_bytes());
out.push(b'\n');
}
Slot::Removed => {}
}
}
out
}
pub fn save(&self, path: impl AsRef<Path>) -> io::Result<()> {
let path = path.as_ref();
let tmp = path.with_extension({
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
if ext.is_empty() {
"tmp".to_string()
} else {
format!("{ext}.tmp")
}
});
write_private_file(&tmp, &self.to_bytes())?;
fs::rename(&tmp, path)?;
Ok(())
}
pub fn lookup(&self, host: &str, port: u16, key_type: &str, key_blob: &[u8]) -> LookupResult {
for slot in &self.lines {
if let Slot::Entry(e) = slot {
if e.marker == Some(Marker::Revoked)
&& host_field_matches(&e.host_spec, host, port)
&& e.key_type == key_type
&& e.key_blob == key_blob
{
return LookupResult::Mismatch {
expected: vec![(e.key_type.clone(), e.key_blob.clone())],
};
}
}
}
let mut host_matched = false;
let mut expected: Vec<(String, Vec<u8>)> = Vec::new();
for slot in &self.lines {
let e = match slot {
Slot::Entry(e) => e,
_ => continue,
};
if !host_field_matches(&e.host_spec, host, port) {
continue;
}
if e.marker == Some(Marker::Revoked) {
continue;
}
host_matched = true;
if e.key_type == key_type && e.key_blob == key_blob {
return LookupResult::Match;
}
expected.push((e.key_type.clone(), e.key_blob.clone()));
}
if host_matched {
LookupResult::Mismatch { expected }
} else {
LookupResult::Unknown
}
}
pub fn add(&mut self, host: &str, port: u16, key_type: &str, key_blob: &[u8], hashed: bool) {
let host_spec = if hashed {
let mut rng = OsRng;
let (_salt, token) = hash_new(&mut rng, host, port);
HostSpec::Hashed(token)
} else {
HostSpec::Patterns(vec![format_host_pattern(host, port)])
};
self.lines.push(Slot::Entry(Entry {
marker: None,
host_spec,
key_type: key_type.to_string(),
key_blob: key_blob.to_vec(),
comment: String::new(),
}));
}
pub fn remove(&mut self, host: &str, port: u16) -> usize {
let mut removed = 0usize;
for slot in self.lines.iter_mut() {
if let Slot::Entry(e) = slot {
if host_field_matches(&e.host_spec, host, port) {
removed += 1;
*slot = Slot::Removed;
}
}
}
removed
}
pub fn find(&self, host: &str, port: u16) -> Vec<&Entry> {
self.lines
.iter()
.filter_map(|s| match s {
Slot::Entry(e) if host_field_matches(&e.host_spec, host, port) => Some(e),
_ => None,
})
.collect()
}
pub fn hash_in_place(&mut self) {
let mut rng = OsRng;
let mut new_lines: Vec<Slot> = Vec::with_capacity(self.lines.len());
for slot in self.lines.drain(..) {
match slot {
Slot::Entry(e) => match e.host_spec {
HostSpec::Hashed(_) => new_lines.push(Slot::Entry(e)),
HostSpec::Patterns(pats) => {
for pat in pats {
let (host, port) = split_host_port(&pat);
let mut salt = [0u8; super::hash::SALT_LEN];
rng.fill_bytes(&mut salt);
let token = super::hash::encode_hashed(
&salt,
&super::hash::format_host(&host, port),
);
new_lines.push(Slot::Entry(Entry {
marker: e.marker,
host_spec: HostSpec::Hashed(token),
key_type: e.key_type.clone(),
key_blob: e.key_blob.clone(),
comment: e.comment.clone(),
}));
}
}
},
other => new_lines.push(other),
}
}
self.lines = new_lines;
}
}
/// Writes `data` to `path`, creating the file or truncating an existing one.
///
/// On Unix the file ends up with mode `0o600` (owner read/write only) and is `fsync`ed.
/// `OpenOptionsExt::mode` only takes effect when the file is newly created, so the mode is
/// also set explicitly before writing. Otherwise a pre-existing file (e.g. a stale temporary
/// left by an earlier failed save) would keep whatever broader permissions it already had.
///
/// On other platforms this is a plain `fs::write` with no permission handling.
///
/// # Errors
/// Propagates any I/O error from opening, changing permissions, writing or syncing.
fn write_private_file(path: &Path, data: &[u8]) -> io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write as _;
        use std::os::unix::fs::OpenOptionsExt as _;
        use std::os::unix::fs::PermissionsExt;
        let mut f = fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // Tighten permissions before any new data lands in a file that may have pre-existed.
        f.set_permissions(fs::Permissions::from_mode(0o600))?;
        f.write_all(data)?;
        f.sync_all()?;
        Ok(())
    }
    #[cfg(not(unix))]
    {
        fs::write(path, data)
    }
}
/// Reports whether an entry's host field selects `host`:`port`.
///
/// Hashed tokens are decoded with `parse_hashed` and verified with `check_hashed`; a token
/// that fails to decode never matches. Plain pattern lists are delegated to
/// `patterns_match`.
fn host_field_matches(spec: &HostSpec, host: &str, port: u16) -> bool {
    match spec {
        HostSpec::Hashed(token) => {
            if let Some((salt, hash)) = parse_hashed(token) {
                check_hashed(&salt, &hash, host, port)
            } else {
                false
            }
        }
        HostSpec::Patterns(pats) => patterns_match(pats, host, port),
    }
}
/// Splits a host pattern into `(host, port)`.
///
/// - `[host]:port` yields the bracketed host and the port.
/// - If the bracketed form is followed by anything other than `:<valid u16>`, the bracketed
///   host is still returned, with port 22.
/// - Any other pattern, including a `[` with no closing `]`, is returned whole with port 22.
fn split_host_port(pat: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 22;
    let bracketed = pat
        .strip_prefix('[')
        .and_then(|inner| inner.rfind(']').map(|close| (&inner[..close], &inner[close + 1..])));
    match bracketed {
        Some((name, tail)) => {
            let port = tail
                .strip_prefix(':')
                .and_then(|digits| digits.parse::<u16>().ok())
                .unwrap_or(DEFAULT_PORT);
            (name.to_string(), port)
        }
        None => (pat.to_string(), DEFAULT_PORT),
    }
}