use core::ffi::{c_char, c_int};
use core::ptr;
use std::ffi::CStr;
use std::net::ToSocketAddrs;
use std::slice;
use std::time::Duration;
use super::common::{
catch, cstr_to_str, map_error, PCSSH_ERR_BUFFER_TOO_SMALL, PCSSH_ERR_CONNECT,
PCSSH_ERR_INVALID_ARGUMENT, PCSSH_OK,
};
use crate::auth::ClientCredential;
use crate::client::{Client, Config, HostKeyPolicy};
use crate::error::Error;
use crate::key::PrivateKey;
use crate::shared::SharedClient;
/// Opaque client handle handed across the C ABI.
///
/// Instances are heap-allocated by `pcssh_client_connect` (via `Box::into_raw`) and are
/// reclaimed by `pcssh_client_free` (via `Box::from_raw`), so each handle must be freed
/// exactly once and not used afterwards. The type is not `#[repr(C)]`; C code should treat
/// it purely as an opaque pointer.
pub struct PcSshClient {
    /// Underlying connection; the FFI entry points reach it through `with_client`.
    pub(crate) inner: SharedClient,
}
#[no_mangle]
pub unsafe extern "C" fn pcssh_client_connect(
host: *const c_char,
port: u16,
timeout_ms: i32,
out: *mut *mut PcSshClient,
) -> c_int {
catch(|| {
if out.is_null() {
return PCSSH_ERR_INVALID_ARGUMENT;
}
unsafe { *out = ptr::null_mut() };
let host_str = match unsafe { cstr_to_str(host) } {
Some(s) => s,
None => return PCSSH_ERR_INVALID_ARGUMENT,
};
let addr = format!("{host_str}:{port}");
let addrs = match addr.to_socket_addrs() {
Ok(a) => a,
Err(_) => return PCSSH_ERR_CONNECT,
};
let timeout = if timeout_ms > 0 {
Some(Duration::from_millis(timeout_ms as u64))
} else {
None
};
let mut last_err: Option<Error> = None;
for sa in addrs {
let cfg = Config {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout,
};
match Client::connect(sa, cfg) {
Ok(c) => {
let boxed = Box::new(PcSshClient {
inner: SharedClient::from(c),
});
unsafe { *out = Box::into_raw(boxed) };
return PCSSH_OK;
}
Err(e) => last_err = Some(e),
}
}
match last_err {
Some(Error::Io(_)) => PCSSH_ERR_CONNECT,
Some(e) => map_error(&e),
None => PCSSH_ERR_CONNECT,
}
})
}
/// Authenticates the session behind `client` as `user` using password authentication.
///
/// Returns:
/// - `PCSSH_OK` on success.
/// - `PCSSH_ERR_INVALID_ARGUMENT` for a null `client`, or a `user`/`password` that
///   `cstr_to_str` rejects (for example null).
/// - Otherwise, the `map_error` code for the failure reported by `authenticate_password`.
///
/// # Safety
/// `client` must be null or a live handle from `pcssh_client_connect`. `user` and
/// `password` must be null or NUL-terminated strings valid for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn pcssh_client_auth_password(
    client: *mut PcSshClient,
    user: *const c_char,
    password: *const c_char,
) -> c_int {
    catch(|| {
        if client.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let user_name = match unsafe { cstr_to_str(user) } {
            None => return PCSSH_ERR_INVALID_ARGUMENT,
            Some(s) => s,
        };
        // Decoded only once `user` has been accepted; the pointers are validated in order.
        let secret = match unsafe { cstr_to_str(password) } {
            None => return PCSSH_ERR_INVALID_ARGUMENT,
            Some(s) => s,
        };
        // SAFETY: `client` is non-null and, per the contract, points at a live handle.
        let handle = unsafe { &*client };
        handle
            .inner
            .with_client(|session| session.authenticate_password(user_name, secret))
            .map_or_else(|err| map_error(&err), |()| PCSSH_OK)
    })
}
/// Authenticates the session behind `client` as `user` with an OpenSSH-format private key.
///
/// Exactly `private_key_pem_len` bytes are read from `private_key_pem`; no NUL terminator
/// is needed and the bytes must be valid UTF-8. A null or empty `passphrase` is passed to
/// `PrivateKey::parse_openssh_pem` as `None`. Otherwise its bytes up to the NUL are passed
/// as the passphrase. The passphrase is borrowed in place rather than copied, so this
/// function leaves no extra heap copy of the secret behind.
///
/// Returns:
/// - `PCSSH_OK` on success.
/// - `PCSSH_ERR_INVALID_ARGUMENT` for:
///   - a null `client` or key pointer;
///   - a key length above `isize::MAX`;
///   - a `user` rejected by `cstr_to_str`;
///   - non-UTF-8 key text.
/// - Otherwise, the `map_error` code from key parsing, key conversion or authentication.
///
/// # Safety
/// - `client` must be null or a live handle from `pcssh_client_connect`.
/// - `user` must be null or a NUL-terminated string.
/// - `private_key_pem` must be readable for `private_key_pem_len` bytes.
/// - `passphrase` must be null or a NUL-terminated string.
///
/// All of these must remain valid for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn pcssh_client_auth_publickey(
    client: *mut PcSshClient,
    user: *const c_char,
    private_key_pem: *const c_char,
    private_key_pem_len: usize,
    passphrase: *const c_char,
) -> c_int {
    catch(|| {
        // `slice::from_raw_parts` is only sound for byte lengths up to `isize::MAX`; reject a
        // bogus length with an error instead of triggering UB (or a debug-build abort).
        if client.is_null() || private_key_pem.is_null() || private_key_pem_len > isize::MAX as usize
        {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        let user_s = match unsafe { cstr_to_str(user) } {
            Some(s) => s,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        // SAFETY: the pointer is non-null and the length is bounded above. The caller
        // guarantees `private_key_pem_len` readable bytes for the duration of this call.
        let pem_bytes =
            unsafe { slice::from_raw_parts(private_key_pem.cast::<u8>(), private_key_pem_len) };
        let pem_str = match core::str::from_utf8(pem_bytes) {
            Ok(s) => s,
            Err(_) => return PCSSH_ERR_INVALID_ARGUMENT,
        };
        // Borrow the passphrase straight from the caller's buffer instead of copying it into
        // a heap `Vec`. Null and empty both mean "no passphrase".
        let passphrase_opt: Option<&[u8]> = if passphrase.is_null() {
            None
        } else {
            // SAFETY: non-null; the caller guarantees a NUL-terminated string.
            let bytes = unsafe { CStr::from_ptr(passphrase) }.to_bytes();
            if bytes.is_empty() {
                None
            } else {
                Some(bytes)
            }
        };
        let priv_key = match PrivateKey::parse_openssh_pem(pem_str, passphrase_opt) {
            Ok(k) => k,
            Err(e) => return map_error(&e),
        };
        let hk = match priv_key.into_host_key() {
            Ok(h) => h,
            Err(e) => return map_error(&e),
        };
        // SAFETY: `client` is non-null and, per the contract, points at a live handle.
        let c = unsafe { &*client };
        match c
            .inner
            .with_client(|cl| cl.authenticate(user_s, vec![ClientCredential::PublicKey(hk)]))
        {
            Ok(()) => PCSSH_OK,
            Err(e) => map_error(&e),
        }
    })
}
/// Runs `command` on the connected server and copies its captured output into caller buffers.
///
/// Once the command has run, these are always filled in:
/// - `*stdout_out_len` and `*stderr_out_len` receive the full output sizes.
/// - `*exit_status_out` receives the exit status (converted with `as i32`), or `-1` when
///   none was reported.
///
/// Output is copied only when *both* streams fit their capacities. Otherwise nothing is
/// copied and `PCSSH_ERR_BUFFER_TOO_SMALL` is returned. The command has already executed at
/// that point, so retrying with larger buffers runs it again. Copied output is raw bytes
/// and is not NUL-terminated.
///
/// Returns `PCSSH_ERR_INVALID_ARGUMENT` when:
/// - `client` or an out-pointer is null;
/// - a buffer is null with non-zero capacity;
/// - `command` is rejected by `cstr_to_str`.
///
/// Execution failures are mapped with `map_error`.
///
/// # Safety
/// - `client` must be a live handle from `pcssh_client_connect`.
/// - `command` must be null or a NUL-terminated string.
/// - Each non-null buffer must be writable for its stated capacity.
/// - The three out-pointers must be valid for writes.
// The C ABI needs every buffer, capacity and out-pointer as a separate parameter.
#[allow(clippy::too_many_arguments)]
#[no_mangle]
pub unsafe extern "C" fn pcssh_client_exec(
    client: *mut PcSshClient,
    command: *const c_char,
    stdout_buf: *mut u8,
    stdout_cap: usize,
    stdout_out_len: *mut usize,
    stderr_buf: *mut u8,
    stderr_cap: usize,
    stderr_out_len: *mut usize,
    exit_status_out: *mut i32,
) -> c_int {
    catch(|| {
        let missing_required = client.is_null()
            || stdout_out_len.is_null()
            || stderr_out_len.is_null()
            || exit_status_out.is_null();
        // A null buffer is acceptable only when the caller advertises zero capacity.
        let bad_stdout = stdout_buf.is_null() && stdout_cap != 0;
        let bad_stderr = stderr_buf.is_null() && stderr_cap != 0;
        if missing_required || bad_stdout || bad_stderr {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }

        let cmd_s = match unsafe { cstr_to_str(command) } {
            None => return PCSSH_ERR_INVALID_ARGUMENT,
            Some(s) => s,
        };
        // SAFETY: `client` is non-null and, per the contract, points at a live handle.
        let session = unsafe { &*client };
        let result = match session.inner.with_client(|cl| cl.exec(cmd_s)) {
            Err(e) => return map_error(&e),
            Ok(o) => o,
        };

        let stdout_len = result.stdout.len();
        let stderr_len = result.stderr.len();
        // Report sizes and status before the capacity check so a caller that receives
        // PCSSH_ERR_BUFFER_TOO_SMALL learns how much room it needs.
        // SAFETY: all three out-pointers were checked non-null above.
        unsafe {
            stdout_out_len.write(stdout_len);
            stderr_out_len.write(stderr_len);
            exit_status_out.write(result.exit_status.map_or(-1, |code| code as i32));
        }

        let fits = stdout_len <= stdout_cap && stderr_len <= stderr_cap;
        if !fits {
            return PCSSH_ERR_BUFFER_TOO_SMALL;
        }
        // SAFETY: each length is <= its capacity, so a non-empty copy targets a non-null
        // buffer that the caller guarantees is writable for that many bytes.
        if stdout_len != 0 {
            unsafe { ptr::copy_nonoverlapping(result.stdout.as_ptr(), stdout_buf, stdout_len) };
        }
        if stderr_len != 0 {
            unsafe { ptr::copy_nonoverlapping(result.stderr.as_ptr(), stderr_buf, stderr_len) };
        }
        PCSSH_OK
    })
}
/// Releases a handle previously returned by `pcssh_client_connect`.
///
/// A null `client` is ignored. Any panic raised while dropping the handle is caught and
/// discarded so it can never unwind into the C caller.
///
/// # Safety
/// `client` must be null or a pointer obtained from `pcssh_client_connect` that has not
/// already been freed. It must not be used after this call.
#[no_mangle]
pub unsafe extern "C" fn pcssh_client_free(client: *mut PcSshClient) {
    if !client.is_null() {
        let reclaim = std::panic::AssertUnwindSafe(move || {
            // SAFETY: `client` came from `Box::into_raw` and, per the contract, is freed
            // only once.
            drop(unsafe { Box::from_raw(client) });
        });
        let _ = std::panic::catch_unwind(reclaim);
    }
}
/// Unit tests for the argument-validation paths of the C entry points.
///
/// These tests never need a live SSH server.
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CString;

    #[test]
    fn free_null_is_safe() {
        unsafe { pcssh_client_free(ptr::null_mut()) };
    }

    #[test]
    fn connect_rejects_null_out() {
        let host = CString::new("127.0.0.1").unwrap();
        let rc = unsafe { pcssh_client_connect(host.as_ptr(), 22, 100, ptr::null_mut()) };
        assert_eq!(rc, PCSSH_ERR_INVALID_ARGUMENT);
    }

    #[test]
    fn connect_rejects_null_host() {
        let mut out: *mut PcSshClient = ptr::null_mut();
        let rc = unsafe { pcssh_client_connect(ptr::null(), 22, 100, &mut out) };
        assert_eq!(rc, PCSSH_ERR_INVALID_ARGUMENT);
        assert!(out.is_null());
    }

    #[test]
    fn connect_to_unbound_port_fails() {
        // Assumes nothing listens on 127.0.0.1:1 on the test machine.
        let host = CString::new("127.0.0.1").unwrap();
        let mut out: *mut PcSshClient = ptr::null_mut();
        let rc = unsafe { pcssh_client_connect(host.as_ptr(), 1, 500, &mut out) };
        assert!(rc < 0, "expected failure, got {rc}");
        assert!(out.is_null());
    }

    #[test]
    fn auth_password_rejects_null_client() {
        let user = CString::new("user").unwrap();
        let pass = CString::new("secret").unwrap();
        let rc =
            unsafe { pcssh_client_auth_password(ptr::null_mut(), user.as_ptr(), pass.as_ptr()) };
        assert_eq!(rc, PCSSH_ERR_INVALID_ARGUMENT);
    }

    #[test]
    fn auth_publickey_rejects_null_client() {
        let user = CString::new("user").unwrap();
        let pem = b"-----BEGIN OPENSSH PRIVATE KEY-----";
        let rc = unsafe {
            pcssh_client_auth_publickey(
                ptr::null_mut(),
                user.as_ptr(),
                pem.as_ptr().cast::<c_char>(),
                pem.len(),
                ptr::null(),
            )
        };
        assert_eq!(rc, PCSSH_ERR_INVALID_ARGUMENT);
    }

    #[test]
    fn exec_rejects_null_client_without_touching_outputs() {
        let cmd = CString::new("true").unwrap();
        let mut out_len = usize::MAX;
        let mut err_len = usize::MAX;
        let mut status = 7i32;
        let rc = unsafe {
            pcssh_client_exec(
                ptr::null_mut(),
                cmd.as_ptr(),
                ptr::null_mut(),
                0,
                &mut out_len,
                ptr::null_mut(),
                0,
                &mut err_len,
                &mut status,
            )
        };
        assert_eq!(rc, PCSSH_ERR_INVALID_ARGUMENT);
        // Argument errors return before any out-pointer is written.
        assert_eq!(out_len, usize::MAX);
        assert_eq!(err_len, usize::MAX);
        assert_eq!(status, 7);
    }
}