use core::ffi::{c_char, c_int};
use core::panic::AssertUnwindSafe;
use core::ptr;
use std::ffi::CStr;
use std::panic::catch_unwind;
use crate::error::Error;
// Integer status codes returned by this module's FFI-facing helpers
// (`map_error`, `catch`). Zero means success; every failure code is negative.
// `pcssh_error_message` turns any of these into a static, NUL-terminated string.

/// Success ("ok").
pub const PCSSH_OK: c_int = 0;
/// Failure without a more specific code ("generic error"); `map_error` reports
/// `Error::Unsupported` with this value.
pub const PCSSH_ERR_GENERIC: c_int = -1;
/// "buffer too small". Not produced by `map_error`; presumably returned directly
/// by entry points that write into caller-provided buffers — TODO confirm.
pub const PCSSH_ERR_BUFFER_TOO_SMALL: c_int = -2;
/// "invalid argument". Not produced by `map_error`; presumably returned directly
/// by entry points that validate their inputs — TODO confirm.
pub const PCSSH_ERR_INVALID_ARGUMENT: c_int = -3;
/// "I/O error"; `map_error` reports `Error::Io` with this value.
pub const PCSSH_ERR_IO: c_int = -4;
/// "connect failed". Not produced by `map_error`.
pub const PCSSH_ERR_CONNECT: c_int = -5;
/// "key exchange failed"; `map_error` reports `Error::NoCommonAlgorithm` and all
/// cryptographic failures (`BadMac`, `BadTag`, `BadPadding`, `BadSignature`,
/// `Crypto`) with this value.
pub const PCSSH_ERR_KEX: c_int = -6;
/// "authentication failed"; `map_error` reports `Error::AuthFailed` with this value.
pub const PCSSH_ERR_AUTH_FAILED: c_int = -7;
/// "host key rejected"; `map_error` reports `Error::HostKeyRejected` with this value.
pub const PCSSH_ERR_HOSTKEY_REJECTED: c_int = -8;
/// "protocol error"; `map_error` reports `Error::Protocol` and
/// `Error::BadChannelState` with this value.
pub const PCSSH_ERR_PROTOCOL: c_int = -9;
/// "parse error"; `map_error` reports `Error::Format` with this value.
pub const PCSSH_ERR_PARSE: c_int = -10;
/// "caught panic at FFI boundary"; returned by `catch` when the wrapped closure panics.
pub const PCSSH_ERR_PANIC: c_int = -99;
pub(crate) fn map_error(err: &Error) -> c_int {
match err {
Error::Io(_) => PCSSH_ERR_IO,
Error::Format(_) => PCSSH_ERR_PARSE,
Error::NoCommonAlgorithm(_) => PCSSH_ERR_KEX,
Error::Protocol(_) => PCSSH_ERR_PROTOCOL,
Error::BadMac
| Error::BadTag
| Error::BadPadding
| Error::BadSignature
| Error::Crypto(_) => PCSSH_ERR_KEX,
Error::HostKeyRejected => PCSSH_ERR_HOSTKEY_REJECTED,
Error::AuthFailed => PCSSH_ERR_AUTH_FAILED,
Error::BadChannelState => PCSSH_ERR_PROTOCOL,
Error::Unsupported(_) => PCSSH_ERR_GENERIC,
}
}
/// Borrows a C string as UTF-8 text.
///
/// Returns `None` when `ptr` is null or when the bytes before the terminating
/// NUL are not valid UTF-8. The two cases are not distinguished.
///
/// # Safety
///
/// When non-null, `ptr` must satisfy the contract of [`CStr::from_ptr`]:
/// - It must point to a NUL-terminated buffer that is valid for reads up to and
///   including the terminator.
/// - That memory must stay alive and unmodified for the caller-chosen lifetime `'a`.
pub(crate) unsafe fn cstr_to_str<'a>(ptr: *const c_char) -> Option<&'a str> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: `ptr` is non-null here, and the caller upholds the remaining
        // requirements of `CStr::from_ptr` (see `# Safety` above).
        let c_string: &'a CStr = unsafe { CStr::from_ptr(ptr) };
        c_string.to_str().ok()
    }
}
/// Runs `f` and returns its status code, converting any panic that escapes it
/// into `PCSSH_ERR_PANIC` so the panic is reported as a code instead of
/// propagating ("caught panic at FFI boundary").
///
/// The panic payload is released inside a second `catch_unwind`. A payload whose
/// destructor panics would otherwise start a fresh unwind from this function,
/// which is exactly what the guard exists to prevent. If that secondary panic
/// occurs, its own payload is leaked rather than dropped.
pub(crate) fn catch<F: FnOnce() -> c_int>(f: F) -> c_int {
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(code) => code,
        Err(payload) => {
            // Dropping the payload runs arbitrary `Drop` code, which may itself panic.
            if let Err(nested) = catch_unwind(AssertUnwindSafe(move || drop(payload))) {
                // Leaking a boxed payload is preferable to unwinding into the caller.
                std::mem::forget(nested);
            }
            PCSSH_ERR_PANIC
        }
    }
}
#[no_mangle]
pub extern "C" fn pcssh_error_message(code: c_int) -> *const c_char {
let s: &'static [u8] = match code {
PCSSH_OK => b"ok\0",
PCSSH_ERR_GENERIC => b"generic error\0",
PCSSH_ERR_BUFFER_TOO_SMALL => b"buffer too small\0",
PCSSH_ERR_INVALID_ARGUMENT => b"invalid argument\0",
PCSSH_ERR_IO => b"I/O error\0",
PCSSH_ERR_CONNECT => b"connect failed\0",
PCSSH_ERR_KEX => b"key exchange failed\0",
PCSSH_ERR_AUTH_FAILED => b"authentication failed\0",
PCSSH_ERR_HOSTKEY_REJECTED => b"host key rejected\0",
PCSSH_ERR_PROTOCOL => b"protocol error\0",
PCSSH_ERR_PARSE => b"parse error\0",
PCSSH_ERR_PANIC => b"caught panic at FFI boundary\0",
_ => return ptr::null(),
};
s.as_ptr() as *const c_char
}
/// Returns the crate version (`CARGO_PKG_VERSION` at build time) as a static,
/// NUL-terminated C string. The pointer refers to data embedded in the binary and
/// must not be freed.
#[no_mangle]
pub extern "C" fn pcssh_version() -> *const c_char {
    // `concat!` appends the terminator at compile time, so no allocation is needed.
    static VERSION_WITH_NUL: &str = concat!(env!("CARGO_PKG_VERSION"), "\0");
    VERSION_WITH_NUL.as_ptr().cast::<c_char>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every status code declared in this module.
    const ALL_CODES: [c_int; 12] = [
        PCSSH_OK,
        PCSSH_ERR_GENERIC,
        PCSSH_ERR_BUFFER_TOO_SMALL,
        PCSSH_ERR_INVALID_ARGUMENT,
        PCSSH_ERR_IO,
        PCSSH_ERR_CONNECT,
        PCSSH_ERR_KEX,
        PCSSH_ERR_AUTH_FAILED,
        PCSSH_ERR_HOSTKEY_REJECTED,
        PCSSH_ERR_PROTOCOL,
        PCSSH_ERR_PARSE,
        PCSSH_ERR_PANIC,
    ];

    #[test]
    fn version_is_nul_terminated() {
        let p = pcssh_version();
        assert!(!p.is_null());
        let s = unsafe { CStr::from_ptr(p) };
        let v = s.to_str().unwrap();
        assert_eq!(v, env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn error_message_known_codes() {
        // Each known code must have a non-empty, valid UTF-8 message, and no two
        // codes may share a message (otherwise C callers cannot tell them apart).
        let mut seen = HashSet::new();
        for code in ALL_CODES {
            let p = pcssh_error_message(code);
            assert!(!p.is_null(), "missing message for {code}");
            let msg = unsafe { CStr::from_ptr(p) }.to_str().unwrap();
            assert!(!msg.is_empty(), "empty message for {code}");
            assert!(seen.insert(msg), "duplicate message for {code}");
        }
    }

    #[test]
    fn error_message_unknown_returns_null() {
        // Values adjacent to the valid range, plus far-out values.
        for code in [1, -11, -98, -100, 12345, -12345] {
            assert!(pcssh_error_message(code).is_null(), "unexpected message for {code}");
        }
    }

    #[test]
    fn cstr_to_str_handles_null_valid_and_invalid_input() {
        let hello = b"hello\0";
        let empty = b"\0";
        let invalid_utf8 = b"\xff\xfe\0";
        // SAFETY: every non-null pointer below refers to a NUL-terminated byte
        // literal that lives for the whole test.
        unsafe {
            assert_eq!(cstr_to_str(ptr::null()), None);
            assert_eq!(cstr_to_str(hello.as_ptr().cast::<c_char>()), Some("hello"));
            assert_eq!(cstr_to_str(empty.as_ptr().cast::<c_char>()), Some(""));
            assert_eq!(cstr_to_str(invalid_utf8.as_ptr().cast::<c_char>()), None);
        }
    }

    #[test]
    fn catch_returns_panic_code_on_unwind() {
        let rc = catch(|| panic!("boom"));
        assert_eq!(rc, PCSSH_ERR_PANIC);
    }

    #[test]
    fn catch_passes_through_normal_return() {
        let rc = catch(|| PCSSH_OK);
        assert_eq!(rc, PCSSH_OK);
        let rc = catch(|| PCSSH_ERR_IO);
        assert_eq!(rc, PCSSH_ERR_IO);
    }
}