#![cfg(unix)]
use core::ffi::{c_char, c_int};
use core::ptr;
use std::slice;
use super::common::{
catch, cstr_to_str, map_error, PCSSH_ERR_BUFFER_TOO_SMALL, PCSSH_ERR_GENERIC,
PCSSH_ERR_INVALID_ARGUMENT, PCSSH_OK,
};
use crate::agent::{Agent, AgentIdentity};
// Signature flag values exported to C callers.
// NOTE(review): these are presumably meant for the `flags` argument of
// `pcssh_agent_sign`, and the numeric values look like the SSH agent
// protocol's RSA SHA-2 signature flags — confirm against `Agent::sign`.

/// No signature flags set.
pub const PCSSH_AGENT_SIGN_DEFAULT: u32 = 0;
/// Flag named for RSA signatures using SHA-256 (`rsa-sha2-256`, judging by the name).
pub const PCSSH_AGENT_RSA_SHA2_256: u32 = 2;
/// Flag named for RSA signatures using SHA-512 (`rsa-sha2-512`, judging by the name).
pub const PCSSH_AGENT_RSA_SHA2_512: u32 = 4;
/// Opaque agent handle handed to C callers as a raw pointer.
///
/// Instances are boxed and leaked by the `pcssh_agent_connect*` functions and
/// reclaimed by `pcssh_agent_free`.
pub struct PcSshAgent {
/// The underlying agent connection used for identity listing and signing.
inner: Agent,
/// Cached identity list; `None` until `pcssh_agent_identity_count` loads it,
/// and reset to `None` by `pcssh_agent_refresh_identities`.
identities: Option<Vec<AgentIdentity>>,
}
/// Connects to the agent at `path` and hands an owned handle back through `out`.
///
/// `*out` is set to null before anything else happens, so it is null on every
/// failure path. On success it receives a heap-allocated [`PcSshAgent`] that
/// the caller must release with `pcssh_agent_free`.
///
/// # Returns
/// * `PCSSH_OK` on success.
/// * `PCSSH_ERR_INVALID_ARGUMENT` if `out` is null or `cstr_to_str` rejects `path`.
/// * Otherwise the code produced by `map_error` for the connection error.
///
/// # Safety
/// `out`, when non-null, must be valid for writing one pointer. `path` must
/// satisfy the requirements of `cstr_to_str` (presumably null or a valid
/// NUL-terminated C string — confirm against that helper).
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_connect(
    path: *const c_char,
    out: *mut *mut PcSshAgent,
) -> c_int {
    catch(|| {
        if out.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // SAFETY: `out` is non-null and the caller guarantees it is writable.
        unsafe { out.write(ptr::null_mut()) };

        // SAFETY: the caller upholds `cstr_to_str`'s contract for `path`.
        let decoded = unsafe { cstr_to_str(path) };
        let Some(socket_path) = decoded else {
            return PCSSH_ERR_INVALID_ARGUMENT;
        };

        let connection = match Agent::connect(socket_path) {
            Ok(connection) => connection,
            Err(err) => return map_error(&err),
        };

        // Ownership of the allocation moves to the caller until `pcssh_agent_free`.
        let handle = Box::into_raw(Box::new(PcSshAgent {
            inner: connection,
            identities: None,
        }));
        // SAFETY: `out` is non-null and the caller guarantees it is writable.
        unsafe { out.write(handle) };
        PCSSH_OK
    })
}
/// Connects to the agent selected by `Agent::connect_env` and returns a handle
/// through `out`.
///
/// `*out` is set to null first. When `Agent::connect_env` reports that there
/// is no agent to connect to (`Ok(None)` — e.g. `SSH_AUTH_SOCK` unset, as
/// exercised by this file's tests), the function still returns `PCSSH_OK` and
/// leaves `*out` null, so callers must check both the return code and the
/// pointer. A non-null handle must be released with `pcssh_agent_free`.
///
/// # Returns
/// * `PCSSH_OK` on success, with or without an agent.
/// * `PCSSH_ERR_INVALID_ARGUMENT` if `out` is null.
/// * Otherwise the code produced by `map_error` for the connection error.
///
/// # Safety
/// `out`, when non-null, must be valid for writing one pointer.
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_connect_env(out: *mut *mut PcSshAgent) -> c_int {
    catch(|| {
        if out.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // SAFETY: `out` is non-null and the caller guarantees it is writable.
        unsafe { out.write(ptr::null_mut()) };

        let connection = match Agent::connect_env() {
            Ok(Some(connection)) => connection,
            // No agent available: success, with `*out` left null.
            Ok(None) => return PCSSH_OK,
            Err(err) => return map_error(&err),
        };

        // Ownership of the allocation moves to the caller until `pcssh_agent_free`.
        let handle = Box::into_raw(Box::new(PcSshAgent {
            inner: connection,
            identities: None,
        }));
        // SAFETY: `out` is non-null and the caller guarantees it is writable.
        unsafe { out.write(handle) };
        PCSSH_OK
    })
}
/// Reports how many identities the agent holds, loading and caching the list
/// on first use.
///
/// If the handle has no cached list, `Agent::identities` is queried and the
/// result is stored on the handle; later calls reuse the cache until
/// `pcssh_agent_refresh_identities` clears it.
///
/// `*out_count` is written on every path once the arguments pass the null
/// check: it is `0` when the agent query fails, so callers never read an
/// uninitialised count.
///
/// # Returns
/// * `PCSSH_OK` on success.
/// * `PCSSH_ERR_INVALID_ARGUMENT` if `agent` or `out_count` is null.
/// * Otherwise the code produced by `map_error` for the agent error.
///
/// # Safety
/// `agent`, when non-null, must be a live handle obtained from
/// `pcssh_agent_connect*` with no other references active during the call.
/// `out_count`, when non-null, must be valid for writing one `usize`.
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_identity_count(
    agent: *mut PcSshAgent,
    out_count: *mut usize,
) -> c_int {
    catch(|| {
        if agent.is_null() || out_count.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // Define the out-parameter before any fallible work, matching how the
        // connect functions null their `out` pointer up front.
        // SAFETY: `out_count` is non-null and the caller guarantees it is writable.
        unsafe { *out_count = 0 };

        // SAFETY: `agent` is non-null and the caller guarantees exclusive access.
        let handle = unsafe { &mut *agent };
        if handle.identities.is_none() {
            match handle.inner.identities() {
                Ok(ids) => handle.identities = Some(ids),
                Err(e) => return map_error(&e),
            }
        }

        let count = match &handle.identities {
            Some(ids) => ids.len(),
            None => 0,
        };
        // SAFETY: `out_count` is non-null and the caller guarantees it is writable.
        unsafe { *out_count = count };
        PCSSH_OK
    })
}
/// Discards the identity list cached on the handle.
///
/// No agent request is made here; the cache is simply cleared, and
/// `pcssh_agent_identity_count` reloads it the next time it is called.
///
/// # Returns
/// `PCSSH_OK`, or `PCSSH_ERR_INVALID_ARGUMENT` if `agent` is null.
///
/// # Safety
/// `agent`, when non-null, must be a live handle obtained from
/// `pcssh_agent_connect*` with no other references active during the call.
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_refresh_identities(agent: *mut PcSshAgent) -> c_int {
    catch(|| {
        // SAFETY: `as_mut` performs the null check; for a non-null pointer the
        // caller guarantees it refers to a live, exclusively accessed handle.
        match unsafe { agent.as_mut() } {
            None => PCSSH_ERR_INVALID_ARGUMENT,
            Some(handle) => {
                handle.identities = None;
                PCSSH_OK
            }
        }
    })
}
/// Copies the algorithm name, comment and key blob of the cached identity at
/// `index` into caller-provided buffers.
///
/// The identity list must already be cached on the handle (it is loaded by
/// `pcssh_agent_identity_count`); this function never queries the agent.
///
/// Each field uses a buffer/capacity/length triple. The required byte lengths
/// are always stored in `*algorithm_len`, `*comment_len` and `*key_blob_len`
/// once the identity is found, even when a buffer is too small, so a caller
/// can pass zero capacities to size its buffers and then call again. Nothing
/// is copied unless all three buffers are large enough. The copied bytes are
/// not NUL-terminated.
///
/// The three length out-parameters are zeroed as soon as they pass the null
/// check, so they hold defined values on every error path as well.
///
/// # Returns
/// * `PCSSH_OK` when all three fields were copied.
/// * `PCSSH_ERR_BUFFER_TOO_SMALL` when any capacity is insufficient.
/// * `PCSSH_ERR_INVALID_ARGUMENT` if `agent` or a length pointer is null, a
///   buffer is null with a non-zero capacity, or `index` is out of range.
/// * `PCSSH_ERR_GENERIC` if no identity list is cached on the handle.
///
/// # Safety
/// `agent`, when non-null, must be a live handle obtained from
/// `pcssh_agent_connect*`. Each non-null length pointer must be valid for
/// writing one `usize`, and each non-null buffer must be valid for writing its
/// stated capacity in bytes.
// The flat C signature needs one pointer/capacity/length triple per field.
#[allow(clippy::too_many_arguments)]
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_identity(
    agent: *mut PcSshAgent,
    index: usize,
    algorithm_buf: *mut u8,
    algorithm_cap: usize,
    algorithm_len: *mut usize,
    comment_buf: *mut u8,
    comment_cap: usize,
    comment_len: *mut usize,
    key_blob_buf: *mut u8,
    key_blob_cap: usize,
    key_blob_len: *mut usize,
) -> c_int {
    catch(|| {
        if agent.is_null()
            || algorithm_len.is_null()
            || comment_len.is_null()
            || key_blob_len.is_null()
        {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // Give the length out-parameters defined values before any early return.
        // SAFETY: all three pointers are non-null and the caller guarantees
        // they are writable.
        unsafe {
            *algorithm_len = 0;
            *comment_len = 0;
            *key_blob_len = 0;
        }
        if (algorithm_buf.is_null() && algorithm_cap != 0)
            || (comment_buf.is_null() && comment_cap != 0)
            || (key_blob_buf.is_null() && key_blob_cap != 0)
        {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }

        // SAFETY: `agent` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*agent };
        let ids = match handle.identities.as_ref() {
            Some(ids) => ids,
            // The list has not been loaded (or was cleared by a refresh).
            None => return PCSSH_ERR_GENERIC,
        };
        let id = match ids.get(index) {
            Some(id) => id,
            None => return PCSSH_ERR_INVALID_ARGUMENT,
        };

        let alg = id.algorithm();
        let alg_bytes = alg.as_bytes();
        let comment = id.comment().as_bytes();
        let blob = id.key_blob();
        let need_alg = alg_bytes.len();
        let need_comment = comment.len();
        let need_blob = blob.len();

        // Report the required sizes before checking capacities so that a
        // too-small call doubles as a size query.
        // SAFETY: all three pointers are non-null and the caller guarantees
        // they are writable.
        unsafe {
            *algorithm_len = need_alg;
            *comment_len = need_comment;
            *key_blob_len = need_blob;
        }
        if need_alg > algorithm_cap || need_comment > comment_cap || need_blob > key_blob_cap {
            return PCSSH_ERR_BUFFER_TOO_SMALL;
        }

        // A non-zero `need_*` that fits implies a non-zero capacity, which the
        // argument check above only accepts together with a non-null buffer.
        if need_alg > 0 {
            // SAFETY: `algorithm_buf` is non-null with room for `need_alg` bytes.
            unsafe { ptr::copy_nonoverlapping(alg_bytes.as_ptr(), algorithm_buf, need_alg) };
        }
        if need_comment > 0 {
            // SAFETY: `comment_buf` is non-null with room for `need_comment` bytes.
            unsafe { ptr::copy_nonoverlapping(comment.as_ptr(), comment_buf, need_comment) };
        }
        if need_blob > 0 {
            // SAFETY: `key_blob_buf` is non-null with room for `need_blob` bytes.
            unsafe { ptr::copy_nonoverlapping(blob.as_ptr(), key_blob_buf, need_blob) };
        }
        PCSSH_OK
    })
}
/// Asks the agent to sign `data` with the key identified by `key_blob` and
/// copies the signature into `sig_buf`.
///
/// `flags` is forwarded unchanged to `Agent::sign`. When the agent returns a
/// signature, its byte length is stored in `*sig_len` even if `sig_buf` is too
/// small; in that case nothing is copied and the signature is discarded, so a
/// retry with a larger buffer issues a new signing request.
///
/// `*sig_len` is zeroed as soon as it passes the null check, so it holds a
/// defined value on argument and agent errors as well.
///
/// # Returns
/// * `PCSSH_OK` when the signature was copied.
/// * `PCSSH_ERR_BUFFER_TOO_SMALL` when `sig_cap` is smaller than the signature.
/// * `PCSSH_ERR_INVALID_ARGUMENT` if `agent` or `sig_len` is null, or any
///   buffer pointer is null while its length/capacity is non-zero.
/// * Otherwise the code produced by `map_error` for the signing error.
///
/// # Safety
/// `agent`, when non-null, must be a live handle obtained from
/// `pcssh_agent_connect*` with no other references active during the call.
/// `key_blob` and `data` must be valid for reading their stated lengths,
/// `sig_buf` must be valid for writing `sig_cap` bytes, and `sig_len`, when
/// non-null, must be valid for writing one `usize`.
// The flat C signature needs separate pointer/length pairs for each buffer.
#[allow(clippy::too_many_arguments)]
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_sign(
    agent: *mut PcSshAgent,
    key_blob: *const u8,
    key_blob_len: usize,
    data: *const u8,
    data_len: usize,
    flags: u32,
    sig_buf: *mut u8,
    sig_cap: usize,
    sig_len: *mut usize,
) -> c_int {
    catch(|| {
        if agent.is_null() || sig_len.is_null() {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }
        // Give the out-parameter a defined value before any early return.
        // SAFETY: `sig_len` is non-null and the caller guarantees it is writable.
        unsafe { *sig_len = 0 };
        if (key_blob.is_null() && key_blob_len != 0)
            || (data.is_null() && data_len != 0)
            || (sig_buf.is_null() && sig_cap != 0)
        {
            return PCSSH_ERR_INVALID_ARGUMENT;
        }

        // `slice::from_raw_parts` requires a non-null pointer even for empty
        // slices, so zero-length inputs are mapped to a static empty slice.
        let key_slice: &[u8] = if key_blob_len == 0 {
            &[]
        } else {
            // SAFETY: `key_blob` is non-null here and the caller guarantees it
            // is readable for `key_blob_len` bytes.
            unsafe { slice::from_raw_parts(key_blob, key_blob_len) }
        };
        let data_slice: &[u8] = if data_len == 0 {
            &[]
        } else {
            // SAFETY: `data` is non-null here and the caller guarantees it is
            // readable for `data_len` bytes.
            unsafe { slice::from_raw_parts(data, data_len) }
        };

        // SAFETY: `agent` is non-null and the caller guarantees exclusive access.
        let handle = unsafe { &mut *agent };
        let sig = match handle.inner.sign(key_slice, data_slice, flags) {
            Ok(sig) => sig,
            Err(e) => return map_error(&e),
        };

        let need = sig.len();
        // SAFETY: `sig_len` is non-null and the caller guarantees it is writable.
        unsafe { *sig_len = need };
        if need > sig_cap {
            return PCSSH_ERR_BUFFER_TOO_SMALL;
        }
        if need > 0 {
            // SAFETY: `need <= sig_cap` with `need > 0` implies `sig_cap != 0`,
            // so `sig_buf` is non-null and writable for `need` bytes.
            unsafe { ptr::copy_nonoverlapping(sig.as_ptr(), sig_buf, need) };
        }
        PCSSH_OK
    })
}
/// Releases a handle created by `pcssh_agent_connect` or
/// `pcssh_agent_connect_env`.
///
/// A null `agent` is accepted and ignored. Any panic raised while the handle
/// is being dropped is caught and discarded rather than propagated to the
/// caller.
///
/// # Safety
/// `agent`, when non-null, must have been produced by one of the connect
/// functions, must not have been freed already, and must not be used after
/// this call.
#[no_mangle]
pub unsafe extern "C" fn pcssh_agent_free(agent: *mut PcSshAgent) {
    if !agent.is_null() {
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            // SAFETY: `agent` is non-null and, per the caller's contract, came
            // from `Box::into_raw` in a connect function and is freed only once.
            drop(unsafe { Box::from_raw(agent) });
        }));
        // There is no return value to report a failure through, so the
        // outcome is deliberately ignored.
        drop(outcome);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{CString, OsString};

    /// Puts `SSH_AUTH_SOCK` back to its saved value when dropped, so the
    /// variable is restored even if the owning test panics on an assertion.
    struct AuthSockRestore(Option<OsString>);

    impl Drop for AuthSockRestore {
        fn drop(&mut self) {
            if let Some(value) = self.0.take() {
                unsafe { std::env::set_var("SSH_AUTH_SOCK", value) };
            }
        }
    }

    /// Freeing a null handle must be a no-op.
    #[test]
    fn free_null_is_safe() {
        unsafe { pcssh_agent_free(ptr::null_mut()) };
    }

    /// A null `out` pointer is rejected before any connection attempt.
    #[test]
    fn connect_rejects_null_out() {
        let p = CString::new("/nonexistent/socket").unwrap();
        let rc = unsafe { pcssh_agent_connect(p.as_ptr(), ptr::null_mut()) };
        assert_eq!(rc, PCSSH_ERR_INVALID_ARGUMENT);
    }

    /// A failed connection returns a negative code and leaves `out` null.
    #[test]
    fn connect_to_missing_socket_fails() {
        let p = CString::new("/nonexistent/puressh-agent-test-socket").unwrap();
        let mut out: *mut PcSshAgent = ptr::null_mut();
        let rc = unsafe { pcssh_agent_connect(p.as_ptr(), &mut out) };
        assert!(rc < 0, "expected failure, got {rc}");
        assert!(out.is_null());
    }

    /// With `SSH_AUTH_SOCK` unset, `connect_env` succeeds without a handle.
    ///
    /// NOTE: this mutates the process environment, which is shared with any
    /// test running concurrently in the same binary.
    #[test]
    fn connect_env_unset_returns_ok_null() {
        // Created before the variable is removed; restores it on every exit
        // path, including assertion failures.
        let _restore = AuthSockRestore(std::env::var_os("SSH_AUTH_SOCK"));
        unsafe { std::env::remove_var("SSH_AUTH_SOCK") };
        let mut out: *mut PcSshAgent = ptr::null_mut();
        let rc = unsafe { pcssh_agent_connect_env(&mut out) };
        assert_eq!(rc, PCSSH_OK);
        assert!(out.is_null());
    }
}