use core::ffi::{c_char, c_int};
use core::ptr;
use std::ffi::c_void;
use std::path::PathBuf;
use std::slice;
use std::sync::{Arc, Mutex};
use super::client::PcSshClient;
use super::common::{
catch, cstr_to_str, map_error, PCSSH_ERR_BUFFER_TOO_SMALL, PCSSH_ERR_CONNECT,
PCSSH_ERR_GENERIC, PCSSH_ERR_INVALID_ARGUMENT, PCSSH_ERR_IO, PCSSH_OK,
};
use crate::client::{Client, Config, HostKeyPolicy, KnownHostsPolicy, TofuAction};
use crate::error::Error;
use crate::known_hosts::{KnownHosts, LookupResult};
use crate::shared::SharedClient;
/// Value written to `out_result` by `pcssh_known_hosts_lookup` when the store
/// reports `LookupResult::Match`.
pub const PCSSH_KH_MATCH: c_int = 0;
/// Value written to `out_result` by `pcssh_known_hosts_lookup` when the store
/// reports `LookupResult::Mismatch { .. }`.
pub const PCSSH_KH_MISMATCH: c_int = 1;
/// Value written to `out_result` by `pcssh_known_hosts_lookup` when the store
/// reports `LookupResult::Unknown`.
pub const PCSSH_KH_UNKNOWN: c_int = 2;
/// `on_unknown` value for `pcssh_client_connect_known_hosts` that selects
/// `TofuAction::Reject`.
pub const PCSSH_TOFU_REJECT: c_int = 0;
/// `on_unknown` value for `pcssh_client_connect_known_hosts` that selects
/// `TofuAction::Accept`.
pub const PCSSH_TOFU_ACCEPT: c_int = 1;
/// `on_unknown` value for `pcssh_client_connect_known_hosts` that selects
/// `TofuAction::Prompt`; a non-null `prompt_cb` is required with this value.
pub const PCSSH_TOFU_PROMPT: c_int = 2;
/// Heap-allocated handle returned to C by the `pcssh_known_hosts_*` constructors
/// and released with `pcssh_known_hosts_free`.
pub struct PcSshKnownHosts {
// Shared, mutex-guarded store. `pcssh_client_connect_known_hosts` clones this
// `Arc` into the connection's `KnownHostsPolicy`, so the store can outlive the
// handle that created it.
pub(crate) inner: Arc<Mutex<KnownHosts>>,
}
/// Optional C callback used by `pcssh_client_connect_known_hosts` when
/// `on_unknown` is `PCSSH_TOFU_PROMPT`.
///
/// * `ctx` - the `prompt_ctx` pointer given at connect time, passed back unchanged.
/// * `host`, `algorithm` - NUL-terminated strings built from temporary `CString`s;
///   they are only valid for the duration of the call.
/// * `port` - port of the host being checked.
/// * `key_blob`, `key_blob_len` - pointer and length of the host key bytes.
///
/// Return value: `0` makes the wrapping closure return `false`; any other value
/// makes it return `true` (presumably "trust this key" - confirm against
/// `TofuAction::Prompt`).
pub type PcSshTofuPromptCb = Option<
unsafe extern "C" fn(
ctx: *mut c_void,
host: *const c_char,
port: u16,
algorithm: *const c_char,
key_blob: *const u8,
key_blob_len: usize,
) -> c_int,
>;
/// Creates an empty known-hosts store and transfers ownership of the handle to
/// the caller through `out`.
///
/// Returns `PCSSH_OK` on success, or `PCSSH_ERR_INVALID_ARGUMENT` when `out` is
/// null. The handle must eventually be passed to `pcssh_known_hosts_free`.
///
/// # Safety
///
/// `out` must be null or valid for writing one pointer.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_new(out: *mut *mut PcSshKnownHosts) -> c_int {
    catch(|| {
        if out.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let store = Arc::new(Mutex::new(KnownHosts::new()));
        let handle = Box::into_raw(Box::new(PcSshKnownHosts { inner: store }));
        // SAFETY: `out` is non-null (checked above) and the caller guarantees it
        // is writable.
        unsafe { out.write(handle) };
        PCSSH_OK
    })
}
/// Loads a known-hosts store from the file at `path` and returns a new handle
/// through `out`.
///
/// `*out` is set to null before anything else is attempted, so it is null on
/// every failure path.
///
/// Returns:
/// * `PCSSH_OK` - the store was loaded and `*out` holds the new handle.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `out` is null, or `cstr_to_str` rejected `path`.
/// * `PCSSH_ERR_IO` - `KnownHosts::load` returned an error.
///
/// # Safety
///
/// `out` must be null or valid for writing one pointer; `path` must satisfy the
/// requirements of `cstr_to_str`.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_load(
    path: *const c_char,
    out: *mut *mut PcSshKnownHosts,
) -> c_int {
    catch(|| {
        if out.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // SAFETY: `out` is non-null and writable per the caller's contract.
        unsafe { out.write(ptr::null_mut()) };

        // Convert the path, then load; each stage maps its failure to a status code.
        let loaded = unsafe { cstr_to_str(path) }
            .ok_or(PCSSH_ERR_INVALID_ARGUMENT)
            .and_then(|file| KnownHosts::load(file).map_err(|_| PCSSH_ERR_IO));

        match loaded {
            Ok(store) => {
                let handle = Box::new(PcSshKnownHosts {
                    inner: Arc::new(Mutex::new(store)),
                });
                // SAFETY: same pointer that was written to above.
                unsafe { out.write(Box::into_raw(handle)) };
                PCSSH_OK
            }
            Err(code) => code,
        }
    })
}
/// Writes the store behind `kh` to the file at `path`.
///
/// The store's mutex is held while `KnownHosts::save` runs.
///
/// Returns:
/// * `PCSSH_OK` - the save succeeded.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `kh` is null, or `cstr_to_str` rejected `path`.
/// * `PCSSH_ERR_GENERIC` - the store's mutex is poisoned.
/// * `PCSSH_ERR_IO` - `KnownHosts::save` returned an error.
///
/// # Safety
///
/// `kh` must be null or a live handle from this module; `path` must satisfy the
/// requirements of `cstr_to_str`.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_save(
    kh: *const PcSshKnownHosts,
    path: *const c_char,
) -> c_int {
    catch(|| {
        if kh.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let dest = if let Some(p) = unsafe { cstr_to_str(path) } {
            p
        } else {
            return PCSSH_ERR_INVALID_ARGUMENT;
        };
        // SAFETY: `kh` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*kh };
        handle.inner.lock().map_or(PCSSH_ERR_GENERIC, |store| {
            if store.save(dest).is_ok() {
                PCSSH_OK
            } else {
                PCSSH_ERR_IO
            }
        })
    })
}
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_from_bytes(
buf: *const u8,
len: usize,
out: *mut *mut PcSshKnownHosts,
) -> c_int {
catch(|| {
if out.is_null() || (buf.is_null() && len != 0) {
return PCSSH_ERR_INVALID_ARGUMENT;
}
let bytes = if len == 0 {
&[][..]
} else {
unsafe { slice::from_raw_parts(buf, len) }
};
let kh = KnownHosts::from_bytes(bytes);
let boxed = PcSshKnownHosts {
inner: Arc::new(Mutex::new(kh)),
};
unsafe { *out = Box::into_raw(Box::new(boxed)) };
PCSSH_OK
})
}
/// Serializes the store behind `kh` into the caller-provided buffer.
///
/// The required size is always written to `*out_len` once the store has been
/// serialized, including when the buffer turns out to be too small. This lets
/// callers pass `cap == 0` (with a null `buf`) to query the size first.
///
/// Returns:
/// * `PCSSH_OK` - `*out_len` bytes were copied into `buf`.
/// * `PCSSH_ERR_BUFFER_TOO_SMALL` - `cap` is smaller than the required size;
///   nothing was copied.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `kh` or `out_len` is null, or `buf` is null
///   while `cap` is non-zero.
/// * `PCSSH_ERR_GENERIC` - the store's mutex is poisoned.
///
/// # Safety
///
/// `kh` must be null or a live handle; `out_len` must be null or writable; when
/// `cap` is non-zero, `buf` must be valid for writing `cap` bytes.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_to_bytes(
    kh: *const PcSshKnownHosts,
    buf: *mut u8,
    cap: usize,
    out_len: *mut usize,
) -> c_int {
    catch(|| {
        let args_ok = !kh.is_null() && !out_len.is_null() && (!buf.is_null() || cap == 0);
        if !args_ok {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // SAFETY: `kh` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*kh };
        let store = if let Ok(guard) = handle.inner.lock() {
            guard
        } else {
            return PCSSH_ERR_GENERIC;
        };
        let serialized = store.to_bytes();
        let required = serialized.len();
        // SAFETY: `out_len` is non-null and writable per the caller's contract.
        unsafe { out_len.write(required) };
        if cap < required {
            return PCSSH_ERR_BUFFER_TOO_SMALL;
        }
        if required != 0 {
            // SAFETY: `required <= cap` and `required != 0` imply `cap != 0`, so
            // `buf` is non-null and has room for `required` bytes.
            unsafe { ptr::copy_nonoverlapping(serialized.as_ptr(), buf, required) };
        }
        PCSSH_OK
    })
}
/// Looks up a host key in the store behind `kh` and reports the outcome through
/// `out_result` as one of `PCSSH_KH_MATCH`, `PCSSH_KH_MISMATCH` or
/// `PCSSH_KH_UNKNOWN`.
///
/// `*out_result` is only written when the function returns `PCSSH_OK`.
///
/// Returns:
/// * `PCSSH_OK` - the lookup ran and `*out_result` is set.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `kh` or `out_result` is null, `key_blob` is
///   null while `key_blob_len` is non-zero, `key_blob_len` exceeds `isize::MAX`,
///   or `cstr_to_str` rejected `host` or `algorithm`.
/// * `PCSSH_ERR_GENERIC` - the store's mutex is poisoned.
///
/// # Safety
///
/// `kh` must be null or a live handle; `out_result` must be null or writable;
/// `host` and `algorithm` must satisfy the requirements of `cstr_to_str`; when
/// `key_blob_len` is non-zero, `key_blob` must point to that many readable bytes.
// The C signature needs all seven parameters; there is no struct to bundle them.
#[allow(clippy::too_many_arguments)]
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_lookup(
    kh: *const PcSshKnownHosts,
    host: *const c_char,
    port: u16,
    algorithm: *const c_char,
    key_blob: *const u8,
    key_blob_len: usize,
    out_result: *mut c_int,
) -> c_int {
    catch(|| {
        if kh.is_null() || out_result.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        if key_blob.is_null() && key_blob_len != 0 {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // `slice::from_raw_parts` is undefined behaviour for slices larger than
        // `isize::MAX` bytes, so reject such lengths instead of trusting the caller.
        if key_blob_len > isize::MAX as usize {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let host_s = match unsafe { cstr_to_str(host) } {
            Some(s) => s,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        let alg_s = match unsafe { cstr_to_str(algorithm) } {
            Some(s) => s,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        let blob: &[u8] = if key_blob_len == 0 {
            // A null `key_blob` is allowed for an empty key; never build a slice from it.
            &[]
        } else {
            // SAFETY: `key_blob` is non-null here, the length is within
            // `isize::MAX`, and the caller guarantees that many readable bytes.
            unsafe { slice::from_raw_parts(key_blob, key_blob_len) }
        };
        // SAFETY: `kh` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*kh };
        let store = match handle.inner.lock() {
            Ok(guard) => guard,
            Err(_) => return PCSSH_ERR_GENERIC,
        };
        let outcome = match store.lookup(host_s, port, alg_s, blob) {
            LookupResult::Match => PCSSH_KH_MATCH,
            LookupResult::Mismatch { .. } => PCSSH_KH_MISMATCH,
            LookupResult::Unknown => PCSSH_KH_UNKNOWN,
        };
        // SAFETY: `out_result` is non-null and writable per the caller's contract.
        unsafe { *out_result = outcome };
        PCSSH_OK
    })
}
/// Adds a host key entry to the store behind `kh`.
///
/// `hash_host` is treated as a boolean: any non-zero value is forwarded to
/// `KnownHosts::add` as `true`.
///
/// Returns:
/// * `PCSSH_OK` - `KnownHosts::add` was called.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `kh` is null, `key_blob` is null while
///   `key_blob_len` is non-zero, `key_blob_len` exceeds `isize::MAX`, or
///   `cstr_to_str` rejected `host` or `algorithm`.
/// * `PCSSH_ERR_GENERIC` - the store's mutex is poisoned.
///
/// # Safety
///
/// `kh` must be null or a live handle; `host` and `algorithm` must satisfy the
/// requirements of `cstr_to_str`; when `key_blob_len` is non-zero, `key_blob`
/// must point to that many readable bytes.
// The C signature needs all seven parameters; there is no struct to bundle them.
#[allow(clippy::too_many_arguments)]
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_add(
    kh: *mut PcSshKnownHosts,
    host: *const c_char,
    port: u16,
    algorithm: *const c_char,
    key_blob: *const u8,
    key_blob_len: usize,
    hash_host: c_int,
) -> c_int {
    catch(|| {
        if kh.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        if key_blob.is_null() && key_blob_len != 0 {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // `slice::from_raw_parts` is undefined behaviour for slices larger than
        // `isize::MAX` bytes, so reject such lengths instead of trusting the caller.
        if key_blob_len > isize::MAX as usize {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let host_s = match unsafe { cstr_to_str(host) } {
            Some(s) => s,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        let alg_s = match unsafe { cstr_to_str(algorithm) } {
            Some(s) => s,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        let blob: &[u8] = if key_blob_len == 0 {
            // A null `key_blob` is allowed for an empty key; never build a slice from it.
            &[]
        } else {
            // SAFETY: `key_blob` is non-null here, the length is within
            // `isize::MAX`, and the caller guarantees that many readable bytes.
            unsafe { slice::from_raw_parts(key_blob, key_blob_len) }
        };
        // SAFETY: `kh` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*kh };
        let mut store = match handle.inner.lock() {
            Ok(guard) => guard,
            Err(_) => return PCSSH_ERR_GENERIC,
        };
        store.add(host_s, port, alg_s, blob, hash_host != 0);
        PCSSH_OK
    })
}
/// Removes the entries for `host`/`port` from the store behind `kh`.
///
/// The value returned by `KnownHosts::remove` is written to `*out_removed` when
/// that pointer is non-null; passing null skips the report.
///
/// Returns:
/// * `PCSSH_OK` - `KnownHosts::remove` was called.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `kh` is null, or `cstr_to_str` rejected `host`.
/// * `PCSSH_ERR_GENERIC` - the store's mutex is poisoned.
///
/// # Safety
///
/// `kh` must be null or a live handle; `host` must satisfy the requirements of
/// `cstr_to_str`; `out_removed` must be null or valid for writing a `usize`.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_remove(
    kh: *mut PcSshKnownHosts,
    host: *const c_char,
    port: u16,
    out_removed: *mut usize,
) -> c_int {
    catch(|| {
        if kh.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let name = if let Some(s) = unsafe { cstr_to_str(host) } {
            s
        } else {
            return PCSSH_ERR_INVALID_ARGUMENT;
        };
        // SAFETY: `kh` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*kh };
        let mut store = if let Ok(guard) = handle.inner.lock() {
            guard
        } else {
            return PCSSH_ERR_GENERIC;
        };
        let removed = store.remove(name, port);
        // SAFETY: `as_mut` yields `None` for null; otherwise the caller guarantees
        // the pointer is valid for writing a `usize`.
        if let Some(slot) = unsafe { out_removed.as_mut() } {
            *slot = removed;
        }
        PCSSH_OK
    })
}
/// Calls `KnownHosts::hash_in_place` on the store behind `kh`.
///
/// Returns `PCSSH_OK` on success, `PCSSH_ERR_INVALID_ARGUMENT` when `kh` is
/// null, and `PCSSH_ERR_GENERIC` when the store's mutex is poisoned.
///
/// # Safety
///
/// `kh` must be null or a live handle from this module.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_hash_in_place(kh: *mut PcSshKnownHosts) -> c_int {
    catch(|| {
        // SAFETY: `as_ref` yields `None` for null; otherwise the caller guarantees
        // `kh` is a live handle.
        let handle = match unsafe { kh.as_ref() } {
            Some(h) => h,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        match handle.inner.lock() {
            Ok(mut store) => {
                store.hash_in_place();
                PCSSH_OK
            }
            Err(_) => PCSSH_ERR_GENERIC,
        }
    })
}
/// Releases a handle created by one of the `pcssh_known_hosts_*` constructors.
///
/// A null `kh` is a no-op. Any panic raised while dropping the handle is caught
/// and discarded so it cannot unwind into the C caller.
///
/// # Safety
///
/// `kh` must be null or a pointer obtained from `Box::into_raw` in this module
/// that has not been freed yet.
#[no_mangle]
pub unsafe extern "C" fn pcssh_known_hosts_free(kh: *mut PcSshKnownHosts) {
    if !kh.is_null() {
        let release = std::panic::AssertUnwindSafe(|| {
            // SAFETY: `kh` is non-null and, per the caller's contract, came from
            // `Box::into_raw` and has not been freed.
            drop(unsafe { Box::from_raw(kh) });
        });
        let _ = std::panic::catch_unwind(release);
    }
}
/// Connects to `host:port` with host-key verification backed by the store
/// behind `kh`, and returns a new client handle through `out_client`.
///
/// * `timeout_ms` - values `<= 0` mean no timeout is configured.
/// * `on_unknown` - one of `PCSSH_TOFU_REJECT`, `PCSSH_TOFU_ACCEPT`,
///   `PCSSH_TOFU_PROMPT`; `PCSSH_TOFU_PROMPT` requires a non-null `prompt_cb`.
/// * `prompt_ctx` - passed back unchanged as the first callback argument.
/// * `save_path` - optional (may be null); forwarded to the policy as `save_path`.
/// * `hash_new` - non-zero is forwarded to the policy as `hash_new = true`.
///
/// The mismatch action is always `TofuAction::Reject`. `*out_client` is set to
/// null before the remaining arguments are validated.
///
/// Returns:
/// * `PCSSH_OK` - connected; `*out_client` holds the new handle.
/// * `PCSSH_ERR_INVALID_ARGUMENT` - `out_client` or `kh` is null, `cstr_to_str`
///   rejected `host` or a non-null `save_path`, `on_unknown` is not one of the
///   three constants, or `PCSSH_TOFU_PROMPT` was given without a callback.
/// * `PCSSH_ERR_CONNECT` - `Client::connect_to_host` failed with `Error::Io`.
/// * otherwise whatever `map_error` returns for the error.
///
/// # Safety
///
/// `out_client` must be null or writable; `kh` must be null or a live handle;
/// `host` and `save_path` must satisfy the requirements of `cstr_to_str`;
/// `prompt_cb`/`prompt_ctx` must remain usable for as long as the policy can
/// invoke the callback.
#[allow(clippy::too_many_arguments)]
#[no_mangle]
pub unsafe extern "C" fn pcssh_client_connect_known_hosts(
host: *const c_char,
port: u16,
timeout_ms: i32,
kh: *mut PcSshKnownHosts,
on_unknown: c_int,
prompt_cb: PcSshTofuPromptCb,
prompt_ctx: *mut c_void,
save_path: *const c_char,
hash_new: c_int,
out_client: *mut *mut PcSshClient,
) -> c_int {
catch(|| {
if out_client.is_null() {
return PCSSH_ERR_INVALID_ARGUMENT;
}
// Clear the output first so every failure path leaves a null handle behind.
unsafe { *out_client = ptr::null_mut() };
if kh.is_null() {
return PCSSH_ERR_INVALID_ARGUMENT;
}
let host_s = match unsafe { cstr_to_str(host) } {
Some(s) => s,
None => return PCSSH_ERR_INVALID_ARGUMENT,
};
// A null `save_path` is valid and means "no path"; a non-null one must convert.
let save_path_opt: Option<PathBuf> = if save_path.is_null() {
None
} else {
match unsafe { cstr_to_str(save_path) } {
Some(s) => Some(PathBuf::from(s)),
None => return PCSSH_ERR_INVALID_ARGUMENT,
}
};
// The `> 0` guard makes the `as u64` cast lossless.
let timeout = if timeout_ms > 0 {
Some(std::time::Duration::from_millis(timeout_ms as u64))
} else {
None
};
let on_unknown = match on_unknown {
PCSSH_TOFU_REJECT => TofuAction::Reject,
PCSSH_TOFU_ACCEPT => TofuAction::Accept,
PCSSH_TOFU_PROMPT => match prompt_cb {
Some(cb) => {
// The context pointer is carried as an integer inside the closure and cast
// back at call time - presumably so the closure satisfies the thread-safety
// bounds of `TofuAction::Prompt`; confirm against its definition.
let ctx_addr = prompt_ctx as usize;
TofuAction::Prompt(Arc::new(
move |host: &str, port: u16, alg: &str, blob: &[u8]| -> bool {
// `CString::new` fails on interior NUL bytes; in that case the closure
// returns `false` without calling the C callback.
let h_cs = match std::ffi::CString::new(host) {
Ok(c) => c,
Err(_) => return false,
};
let a_cs = match std::ffi::CString::new(alg) {
Ok(c) => c,
Err(_) => return false,
};
// The string pointers are only valid until `h_cs`/`a_cs` drop at the end of
// this closure call.
let rc = unsafe {
cb(
ctx_addr as *mut c_void,
h_cs.as_ptr(),
port,
a_cs.as_ptr(),
blob.as_ptr(),
blob.len(),
)
};
// NOTE(review): every non-zero return yields `true`, including negative
// values. If callbacks may return negative error codes this fails open;
// confirm the documented callback contract before tightening to
// `rc == PCSSH_TOFU_ACCEPT`.
rc != 0
},
))
}
None => return PCSSH_ERR_INVALID_ARGUMENT,
},
_ => return PCSSH_ERR_INVALID_ARGUMENT,
};
let kh_handle = unsafe { &*kh };
// The policy shares the same store as the handle via a cloned `Arc`.
let store = Arc::clone(&kh_handle.inner);
let policy = KnownHostsPolicy {
store,
save_path: save_path_opt,
hash_new: hash_new != 0,
on_unknown,
on_mismatch: TofuAction::Reject,
};
let cfg = Config {
host_key_policy: HostKeyPolicy::KnownHosts(policy),
timeout,
};
match Client::connect_to_host(host_s, port, cfg) {
Ok(c) => {
let boxed = Box::new(PcSshClient {
inner: SharedClient::from(c),
});
unsafe { *out_client = Box::into_raw(boxed) };
PCSSH_OK
}
// I/O failures get the dedicated connect code; everything else goes through
// `map_error`.
Err(Error::Io(_)) => PCSSH_ERR_CONNECT,
Err(e) => map_error(&e),
}
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{CStr, CString};

    /// Port used for every entry in these tests.
    const PORT: u16 = 22;

    /// Builds a `CString` from a literal that is known to contain no NUL bytes.
    fn cstring(s: &str) -> CString {
        CString::new(s).expect("test literals contain no interior NUL")
    }

    /// Creates an empty store, asserting that construction succeeded.
    fn new_store() -> *mut PcSshKnownHosts {
        let mut h: *mut PcSshKnownHosts = ptr::null_mut();
        let rc = unsafe { pcssh_known_hosts_new(&mut h) };
        assert_eq!(rc, PCSSH_OK);
        assert!(!h.is_null());
        h
    }

    /// Adds an unhashed entry on `PORT`, asserting that the call succeeded.
    fn add_key(h: *mut PcSshKnownHosts, host: &CStr, alg: &CStr, blob: &[u8]) {
        let rc = unsafe {
            pcssh_known_hosts_add(
                h,
                host.as_ptr(),
                PORT,
                alg.as_ptr(),
                blob.as_ptr(),
                blob.len(),
                0,
            )
        };
        assert_eq!(rc, PCSSH_OK);
    }

    /// Runs a lookup on `PORT`, asserts the call succeeded, and returns the outcome.
    fn lookup_key(h: *const PcSshKnownHosts, host: &CStr, alg: &CStr, blob: &[u8]) -> c_int {
        // Start from a value that is none of the PCSSH_KH_* constants so an
        // unwritten result cannot be mistaken for a real outcome.
        let mut outcome: c_int = -1;
        let rc = unsafe {
            pcssh_known_hosts_lookup(
                h,
                host.as_ptr(),
                PORT,
                alg.as_ptr(),
                blob.as_ptr(),
                blob.len(),
                &mut outcome,
            )
        };
        assert_eq!(rc, PCSSH_OK);
        outcome
    }

    #[test]
    fn free_null_is_safe() {
        unsafe { pcssh_known_hosts_free(ptr::null_mut()) };
    }

    #[test]
    fn new_then_free() {
        let h = new_store();
        unsafe { pcssh_known_hosts_free(h) };
    }

    #[test]
    fn lookup_empty_is_unknown() {
        let h = new_store();
        let host = cstring("example.com");
        let alg = cstring("ssh-ed25519");
        let blob = [1u8, 2, 3];
        assert_eq!(lookup_key(h, &host, &alg, &blob), PCSSH_KH_UNKNOWN);
        unsafe { pcssh_known_hosts_free(h) };
    }

    #[test]
    fn add_then_lookup_matches() {
        let h = new_store();
        let host = cstring("example.com");
        let alg = cstring("ssh-ed25519");
        let blob = [9u8, 8, 7, 6];
        add_key(h, &host, &alg, &blob);
        assert_eq!(lookup_key(h, &host, &alg, &blob), PCSSH_KH_MATCH);
        unsafe { pcssh_known_hosts_free(h) };
    }

    #[test]
    fn add_wrong_key_is_mismatch() {
        let h = new_store();
        let host = cstring("example.com");
        let alg = cstring("ssh-ed25519");
        let blob_good = [1u8, 2, 3];
        let blob_bad = [4u8, 5, 6];
        add_key(h, &host, &alg, &blob_good);
        assert_eq!(lookup_key(h, &host, &alg, &blob_bad), PCSSH_KH_MISMATCH);
        unsafe { pcssh_known_hosts_free(h) };
    }

    #[test]
    fn to_bytes_then_from_bytes_roundtrip() {
        let h = new_store();
        let host = cstring("a.example");
        let alg = cstring("ssh-ed25519");
        let blob = [1u8, 2, 3, 4];
        add_key(h, &host, &alg, &blob);

        // First call with no buffer only reports the required size.
        let mut need: usize = 0;
        let rc = unsafe { pcssh_known_hosts_to_bytes(h, ptr::null_mut(), 0, &mut need) };
        assert_eq!(rc, PCSSH_ERR_BUFFER_TOO_SMALL);
        assert!(need > 0);

        let mut buf = vec![0u8; need];
        let mut got: usize = 0;
        let rc = unsafe { pcssh_known_hosts_to_bytes(h, buf.as_mut_ptr(), buf.len(), &mut got) };
        assert_eq!(rc, PCSSH_OK);
        assert_eq!(got, need);

        let mut h2: *mut PcSshKnownHosts = ptr::null_mut();
        let rc = unsafe { pcssh_known_hosts_from_bytes(buf.as_ptr(), got, &mut h2) };
        assert_eq!(rc, PCSSH_OK);
        assert!(!h2.is_null());

        assert_eq!(lookup_key(h2, &host, &alg, &blob), PCSSH_KH_MATCH);
        unsafe {
            pcssh_known_hosts_free(h);
            pcssh_known_hosts_free(h2);
        };
    }

    #[test]
    fn remove_then_lookup_unknown() {
        let h = new_store();
        let host = cstring("rm.example");
        let alg = cstring("ssh-ed25519");
        let blob = [7u8; 8];
        add_key(h, &host, &alg, &blob);

        let mut n: usize = 0;
        let rc = unsafe { pcssh_known_hosts_remove(h, host.as_ptr(), PORT, &mut n) };
        assert_eq!(rc, PCSSH_OK);
        assert_eq!(n, 1);

        assert_eq!(lookup_key(h, &host, &alg, &blob), PCSSH_KH_UNKNOWN);
        unsafe { pcssh_known_hosts_free(h) };
    }
}