use std::error;
use std::fmt;
use std::os::raw::{c_int, c_void};
use std::panic;
use std::{ptr, result, slice};
use bindings::{
secure_session_connect, secure_session_create, secure_session_destroy,
secure_session_generate_connect_request, secure_session_get_remote_id,
secure_session_is_established, secure_session_receive, secure_session_send, secure_session_t,
secure_session_unwrap, secure_session_user_callbacks_t, secure_session_wrap, STATE_ESTABLISHED,
STATE_IDLE, STATE_NEGOTIATING,
};
use crate::error::{themis_status_t, Error, ErrorKind, Result};
use crate::keys::{EcdsaPrivateKey, EcdsaPublicKey};
use crate::utils::into_raw_parts;
/// Safe wrapper around a native Themis Secure Session object.
///
/// The native session is created in `SecureSession::new` and destroyed when this value is
/// dropped.
pub struct SecureSession {
// Raw handle returned by `secure_session_create`.
session: *mut secure_session_t,
// Callback table, transport object and last transport error. Pointers into this box are
// handed to the native library in `new`, so it must stay alive (and at the same heap
// address) for as long as `session` exists.
#[allow(dead_code)]
context: Box<SecureSessionContext>,
}
/// State shared between a `SecureSession` and the C callbacks registered for it.
struct SecureSessionContext {
// Callback table passed to `secure_session_create`; its `user_data` field points back at
// this very structure.
callbacks: secure_session_user_callbacks_t,
// User-provided transport that the callbacks delegate to.
transport: Box<dyn SecureSessionTransport>,
// Error returned by the most recent failed `send_data`/`receive_data` transport call,
// stashed by the callback so the Rust caller can report it.
last_error: Option<TransportError>,
}
// Allows a `SecureSession` to be moved to another thread. `Sync` is not implemented.
//
// NOTE(review): the session owns a `Box<dyn SecureSessionTransport>`, and neither the trait
// nor `SecureSession::new` requires the transport to be `Send`. As written, a transport
// holding thread-bound state (e.g. an `Rc`) could be moved across threads through this impl.
// Confirm whether a `Send` bound on the transport is intended.
unsafe impl Send for SecureSession {}
/// Callback interface through which a `SecureSession` interacts with the application.
///
/// Only `get_public_key_for_id` is required; every other method has a default
/// implementation.
#[allow(unused_variables)]
pub trait SecureSessionTransport {
/// Returns the public key of the peer identified by `id`.
///
/// Returning `None` makes the key lookup callback report a failure to Themis.
fn get_public_key_for_id(&mut self, id: &[u8]) -> Option<EcdsaPublicKey>;
/// Sends `data` to the peer. On success the returned byte count is passed back to Themis.
///
/// The default implementation always fails with an unspecified transport error.
fn send_data(&mut self, data: &[u8]) -> result::Result<usize, TransportError> {
Err(TransportError::unspecified())
}
/// Receives data from the peer into `data`. On success the returned byte count is passed
/// back to Themis.
///
/// The default implementation always fails with an unspecified transport error.
fn receive_data(&mut self, data: &mut [u8]) -> result::Result<usize, TransportError> {
Err(TransportError::unspecified())
}
/// Notification that the session has moved to `state`.
///
/// The default implementation does nothing.
fn state_changed(&mut self, state: SecureSessionState) {}
}
/// Error reported by a `SecureSessionTransport` implementation.
///
/// Create one with `TransportError::new`, `TransportError::unspecified`, or by converting
/// any `std::error::Error + Send + Sync + 'static` value via `From`.
pub struct TransportError {
// Actual payload; kept private so the representation can change freely.
inner: TransportErrorInner,
}
/// Private representation of a `TransportError`.
enum TransportErrorInner {
// No details available.
Unspecified,
// Human-readable description only.
Simple(String),
// Arbitrary underlying error supplied by the transport.
Custom(Box<dyn error::Error + Send + Sync>),
}
/// User-facing rendering: a fixed prefix, followed by the description or the underlying
/// error when one is available.
impl fmt::Display for TransportError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.inner {
            TransportErrorInner::Unspecified => f.write_str("Secure Session transport failed"),
            TransportErrorInner::Simple(ref description) => {
                write!(f, "Secure Session transport failed: {}", description)
            }
            TransportErrorInner::Custom(ref cause) => {
                write!(f, "Secure Session transport failed: {}", cause)
            }
        }
    }
}
/// Diagnostic rendering that names the internal variant and debug-prints its payload.
impl fmt::Debug for TransportError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.inner {
            TransportErrorInner::Unspecified => f.write_str("TransportError::Unspecified"),
            TransportErrorInner::Simple(ref description) => {
                write!(f, "TransportError::Simple({:?})", description)
            }
            TransportErrorInner::Custom(ref cause) => {
                write!(f, "TransportError::Custom({:?})", cause)
            }
        }
    }
}
/// Wraps any thread-safe, `'static` standard error as a transport error, so transports can
/// use `?` on their own error types.
impl<T> From<T> for TransportError
where
    T: error::Error + Send + Sync + 'static,
{
    fn from(error: T) -> Self {
        // The box is coerced to a trait object when stored in the `Custom` variant.
        let cause = Box::new(error);
        Self {
            inner: TransportErrorInner::Custom(cause),
        }
    }
}
impl TransportError {
    /// Creates a transport error carrying a human-readable description.
    pub fn new(description: impl Into<String>) -> TransportError {
        let inner = TransportErrorInner::Simple(description.into());
        Self { inner }
    }

    /// Creates a transport error without any further details.
    pub fn unspecified() -> TransportError {
        let inner = TransportErrorInner::Unspecified;
        Self { inner }
    }
}
/// Session states reported to `SecureSessionTransport::state_changed`.
///
/// Each variant corresponds to one of the `STATE_*` constants of the native library
/// (see `SecureSessionState::from_int`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecureSessionState {
/// Maps from `STATE_IDLE`.
Idle,
/// Maps from `STATE_NEGOTIATING`.
Negotiating,
/// Maps from `STATE_ESTABLISHED`.
Established,
}
impl SecureSessionState {
fn from_int(state: c_int) -> Option<Self> {
match state as u32 {
STATE_IDLE => Some(SecureSessionState::Idle),
STATE_NEGOTIATING => Some(SecureSessionState::Negotiating),
STATE_ESTABLISHED => Some(SecureSessionState::Established),
_ => None,
}
}
}
impl SecureSession {
pub fn new(
id: impl AsRef<[u8]>,
key: &EcdsaPrivateKey,
transport: impl SecureSessionTransport + 'static,
) -> Result<Self> {
if id.as_ref().is_empty() {
return Err(Error::with_kind(ErrorKind::InvalidParameter));
}
let (id_ptr, id_len) = into_raw_parts(id.as_ref());
let (key_ptr, key_len) = into_raw_parts(key.as_ref());
let mut context = Box::new(SecureSessionContext {
callbacks: secure_session_user_callbacks_t {
send_data: Some(send_data),
receive_data: Some(receive_data),
state_changed: Some(state_changed),
get_public_key_for_id: Some(get_public_key_for_id),
user_data: std::ptr::null_mut(),
},
transport: Box::new(transport),
last_error: None,
});
context.callbacks.user_data = context_as_user_data(&context);
let session = unsafe {
secure_session_create(
id_ptr as *const c_void,
id_len,
key_ptr as *const c_void,
key_len,
&context.callbacks,
)
};
if session.is_null() {
return Err(Error::with_kind(ErrorKind::NoMemory));
}
Ok(Self { session, context })
}
/// Reports whether the native library considers this session established.
pub fn is_established(&self) -> bool {
    let session = self.session;
    unsafe { secure_session_is_established(session) }
}
pub fn remote_peer_id(&self) -> Result<Option<Vec<u8>>> {
let mut id = Vec::new();
let mut id_len = 0;
unsafe {
let status = secure_session_get_remote_id(self.session, ptr::null_mut(), &mut id_len);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::BufferTooSmall {
return Err(error);
}
}
id.reserve(id_len);
unsafe {
let status = secure_session_get_remote_id(self.session, id.as_mut_ptr(), &mut id_len);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::Success {
return Err(error);
}
debug_assert!(id_len <= id.capacity());
id.set_len(id_len);
}
Ok(if id.is_empty() { None } else { Some(id) })
}
pub fn connect(&mut self) -> Result<()> {
unsafe {
let status = secure_session_connect(self.session);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::Success {
return Err(error);
}
}
Ok(())
}
pub fn negotiate(&mut self) -> Result<()> {
unsafe {
let result = secure_session_receive(self.session, ptr::null_mut(), 0);
if result == TRANSPORT_FAILURE {
let error = self.context.last_error.take().expect("missing error");
return Err(Error::from_transport_error(error));
}
if result == TRANSPORT_OVERFLOW {
return Err(Error::with_kind(ErrorKind::BufferTooSmall));
}
if result == TRANSPORT_PANIC {
return Err(Error::from_transport_error(TransportError::unspecified()));
}
let error = Error::from_session_status(result as themis_status_t);
if error.kind() != ErrorKind::Success {
return Err(error);
}
}
Ok(())
}
const THEMIX_MAX_ERROR: isize = 21;
pub fn send(&mut self, message: impl AsRef<[u8]>) -> Result<()> {
let (message_ptr, message_len) = into_raw_parts(message.as_ref());
unsafe {
let length =
secure_session_send(self.session, message_ptr as *const c_void, message_len);
if length == TRANSPORT_FAILURE {
let error = self.context.last_error.take().expect("missing error");
return Err(Error::from_transport_error(error));
}
if length == TRANSPORT_OVERFLOW {
return Err(Error::with_kind(ErrorKind::BufferTooSmall));
}
if length == TRANSPORT_PANIC {
return Err(Error::from_transport_error(TransportError::unspecified()));
}
if length <= Self::THEMIX_MAX_ERROR {
return Err(Error::from_session_status(length as themis_status_t));
}
}
Ok(())
}
pub fn receive(&mut self, max_len: usize) -> Result<Vec<u8>> {
let mut message = Vec::with_capacity(max_len);
unsafe {
let length = secure_session_receive(
self.session,
message.as_mut_ptr() as *mut c_void,
message.capacity(),
);
if length == TRANSPORT_FAILURE {
let error = self.context.last_error.take().expect("missing error");
return Err(Error::from_transport_error(error));
}
if length == TRANSPORT_OVERFLOW {
return Err(Error::with_kind(ErrorKind::BufferTooSmall));
}
if length == TRANSPORT_PANIC {
return Err(Error::from_transport_error(TransportError::unspecified()));
}
if length <= Self::THEMIX_MAX_ERROR {
return Err(Error::from_session_status(length as themis_status_t));
}
debug_assert!(length as usize <= message.capacity());
message.set_len(length as usize);
}
Ok(message)
}
pub fn connect_request(&mut self) -> Result<Vec<u8>> {
let mut output = Vec::new();
let mut output_len = 0;
unsafe {
let status = secure_session_generate_connect_request(
self.session,
ptr::null_mut(),
&mut output_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::BufferTooSmall {
return Err(error);
}
}
output.reserve(output_len);
unsafe {
let status = secure_session_generate_connect_request(
self.session,
output.as_mut_ptr() as *mut c_void,
&mut output_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::Success {
return Err(error);
}
debug_assert!(output_len <= output.capacity());
output.set_len(output_len);
}
Ok(output)
}
pub fn negotiate_reply(&mut self, wrapped: impl AsRef<[u8]>) -> Result<Vec<u8>> {
let (wrapped_ptr, wrapped_len) = into_raw_parts(wrapped.as_ref());
let mut message = Vec::new();
let mut message_len = 0;
unsafe {
let status = secure_session_unwrap(
self.session,
wrapped_ptr as *const c_void,
wrapped_len,
ptr::null_mut(),
&mut message_len,
);
let error = Error::from_session_status(status);
if error.kind() == ErrorKind::Success {
return Ok(message);
}
if error.kind() != ErrorKind::BufferTooSmall {
return Err(error);
}
}
message.reserve(message_len);
unsafe {
let status = secure_session_unwrap(
self.session,
wrapped_ptr as *const c_void,
wrapped_len,
message.as_mut_ptr() as *mut c_void,
&mut message_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::SessionSendOutputToPeer {
assert_ne!(error.kind(), ErrorKind::Success);
return Err(error);
}
debug_assert!(message_len <= message.capacity());
message.set_len(message_len);
}
Ok(message)
}
pub fn wrap(&mut self, message: impl AsRef<[u8]>) -> Result<Vec<u8>> {
let (message_ptr, message_len) = into_raw_parts(message.as_ref());
let mut wrapped = Vec::new();
let mut wrapped_len = 0;
unsafe {
let status = secure_session_wrap(
self.session,
message_ptr as *const c_void,
message_len,
ptr::null_mut(),
&mut wrapped_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::BufferTooSmall {
return Err(error);
}
}
wrapped.reserve(wrapped_len);
unsafe {
let status = secure_session_wrap(
self.session,
message_ptr as *const c_void,
message_len,
wrapped.as_mut_ptr() as *mut c_void,
&mut wrapped_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::Success {
return Err(error);
}
debug_assert!(wrapped_len <= wrapped.capacity());
wrapped.set_len(wrapped_len);
}
Ok(wrapped)
}
pub fn unwrap(&mut self, wrapped: impl AsRef<[u8]>) -> Result<Vec<u8>> {
let (wrapped_ptr, wrapped_len) = into_raw_parts(wrapped.as_ref());
let mut message = Vec::new();
let mut message_len = 0;
unsafe {
let status = secure_session_unwrap(
self.session,
wrapped_ptr as *const c_void,
wrapped_len,
ptr::null_mut(),
&mut message_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::BufferTooSmall {
return Err(error);
}
}
message.reserve(message_len);
unsafe {
let status = secure_session_unwrap(
self.session,
wrapped_ptr as *const c_void,
wrapped_len,
message.as_mut_ptr() as *mut c_void,
&mut message_len,
);
let error = Error::from_session_status(status);
if error.kind() != ErrorKind::Success {
return Err(error);
}
debug_assert!(message_len <= message.capacity());
message.set_len(message_len);
}
Ok(message)
}
}
/// Debug output shows only the raw session handle; the context is omitted.
impl fmt::Debug for SecureSession {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut repr = f.debug_struct("SecureSession");
        repr.field("session", &self.session);
        repr.finish()
    }
}
/// Produces the opaque `user_data` pointer for the callback table: the address of the boxed
/// context, type-erased to `*mut c_void`.
#[allow(clippy::borrowed_box)]
fn context_as_user_data(context: &Box<SecureSessionContext>) -> *mut c_void {
    let raw: *const SecureSessionContext = &**context;
    raw as *mut c_void
}
/// Recovers the session context from the opaque `user_data` pointer given to a callback.
///
/// # Safety
///
/// `ptr` must be a pointer produced by `context_as_user_data` for a context that is still
/// alive, and no other reference to that context may be in use for the chosen lifetime.
unsafe fn user_data_as_context<'a>(ptr: *mut c_void) -> &'a mut SecureSessionContext {
    let context = ptr as *mut SecureSessionContext;
    &mut *context
}
// Sentinel values returned by the `send_data`/`receive_data` callbacks instead of a byte
// count. `negotiate`, `send` and `receive` translate them back into errors.

// The transport returned an error; it is stored in the context's `last_error`.
const TRANSPORT_FAILURE: isize = -1;
// The transport reported a byte count that does not fit into `isize`.
const TRANSPORT_OVERFLOW: isize = -2;
// The transport panicked; the panic was caught inside the callback.
const TRANSPORT_PANIC: isize = -3;
/// C callback that forwards outgoing data to the transport's `send_data`.
///
/// Returns the number of bytes sent, or one of the `TRANSPORT_*` sentinels. Panics raised by
/// the transport are caught here and reported as `TRANSPORT_PANIC`.
unsafe extern "C" fn send_data(
    data_ptr: *const u8,
    data_len: usize,
    user_data: *mut c_void,
) -> isize {
    let guarded = panic::catch_unwind(|| {
        let context = user_data_as_context(user_data);
        let outgoing = byte_slice_from_ptr(data_ptr, data_len);
        let failure = match context.transport.send_data(outgoing) {
            Ok(count) => return as_isize(count).unwrap_or(TRANSPORT_OVERFLOW),
            Err(failure) => failure,
        };
        // Keep the error so the Rust caller can retrieve and report it.
        context.last_error = Some(failure);
        TRANSPORT_FAILURE
    });
    match guarded {
        Ok(code) => code,
        Err(_) => TRANSPORT_PANIC,
    }
}
/// C callback that asks the transport's `receive_data` to fill the provided buffer.
///
/// Returns the number of bytes received, or one of the `TRANSPORT_*` sentinels. Panics
/// raised by the transport are caught here and reported as `TRANSPORT_PANIC`.
unsafe extern "C" fn receive_data(
    data_ptr: *mut u8,
    data_len: usize,
    user_data: *mut c_void,
) -> isize {
    let guarded = panic::catch_unwind(|| {
        let context = user_data_as_context(user_data);
        let incoming = byte_slice_from_ptr_mut(data_ptr, data_len);
        let failure = match context.transport.receive_data(incoming) {
            Ok(count) => return as_isize(count).unwrap_or(TRANSPORT_OVERFLOW),
            Err(failure) => failure,
        };
        // Keep the error so the Rust caller can retrieve and report it.
        context.last_error = Some(failure);
        TRANSPORT_FAILURE
    });
    match guarded {
        Ok(code) => code,
        Err(_) => TRANSPORT_PANIC,
    }
}
/// C callback that notifies the transport about a session state change.
///
/// Unknown state codes are ignored, and panics raised by the transport are swallowed.
unsafe extern "C" fn state_changed(event: c_int, user_data: *mut c_void) {
    let notify = || {
        let context = user_data_as_context(user_data);
        match SecureSessionState::from_int(event) {
            Some(new_state) => context.transport.state_changed(new_state),
            None => {}
        }
    };
    let _ = panic::catch_unwind(notify);
}
// Return codes of the `get_public_key_for_id` callback.

// The key was found and copied into the output buffer.
const GET_PUBLIC_KEY_SUCCESS: c_int = 0;
// The key is unknown, does not fit into the output buffer, or the transport panicked.
const GET_PUBLIC_KEY_FAILURE: c_int = -1;
/// C callback that asks the transport for a peer's public key and copies it into the
/// buffer supplied by the native library.
///
/// Returns `GET_PUBLIC_KEY_SUCCESS` when the key was copied, and `GET_PUBLIC_KEY_FAILURE`
/// when the transport has no key for the id, the key is larger than the buffer, or the
/// transport panicked.
unsafe extern "C" fn get_public_key_for_id(
    id_ptr: *const c_void,
    id_len: usize,
    key_ptr: *mut c_void,
    key_len: usize,
    user_data: *mut c_void,
) -> c_int {
    let guarded = panic::catch_unwind(|| {
        let peer_id = byte_slice_from_ptr(id_ptr as *const u8, id_len);
        let buffer = byte_slice_from_ptr_mut(key_ptr as *mut u8, key_len);
        let context = user_data_as_context(user_data);
        let found = match context.transport.get_public_key_for_id(peer_id) {
            Some(found) => found,
            None => return GET_PUBLIC_KEY_FAILURE,
        };
        let key_bytes: &[u8] = found.as_ref();
        if key_bytes.len() > buffer.len() {
            return GET_PUBLIC_KEY_FAILURE;
        }
        buffer[..key_bytes.len()].copy_from_slice(key_bytes);
        GET_PUBLIC_KEY_SUCCESS
    });
    match guarded {
        Ok(code) => code,
        Err(_) => GET_PUBLIC_KEY_FAILURE,
    }
}
#[doc(hidden)]
impl Drop for SecureSession {
fn drop(&mut self) {
unsafe {
let status = secure_session_destroy(self.session);
let error = Error::from_session_status(status);
if (cfg!(debug) || cfg!(test)) && error.kind() != ErrorKind::Success {
panic!("secure_session_destroy() failed: {}", error);
}
}
}
}
/// Converts `n` to `isize`, returning `None` when the value does not fit.
fn as_isize(n: usize) -> Option<isize> {
    let limit = isize::max_value() as usize;
    match n {
        too_big if too_big > limit => None,
        fits => Some(fits as isize),
    }
}
/// Builds a byte slice from a raw pointer and length received over FFI.
///
/// A null pointer is replaced with a dangling non-null one before the slice is created.
///
/// # Safety
///
/// Unless `ptr` is null, it must be valid for reads of `len` bytes for the whole lifetime
/// `'a`. NOTE(review): a null `ptr` is presumably only ever paired with `len == 0` — confirm
/// against the callers.
unsafe fn byte_slice_from_ptr<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
    let start = escape_null_ptr(ptr as *mut u8);
    slice::from_raw_parts(start, len)
}
/// Builds a mutable byte slice from a raw pointer and length received over FFI.
///
/// A null pointer is replaced with a dangling non-null one before the slice is created.
///
/// # Safety
///
/// Unless `ptr` is null, it must be valid for reads and writes of `len` bytes for the whole
/// lifetime `'a`, and nothing else may access that memory meanwhile. NOTE(review): a null
/// `ptr` is presumably only ever paired with `len == 0` — confirm against the callers.
unsafe fn byte_slice_from_ptr_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
    let start = escape_null_ptr(ptr);
    slice::from_raw_parts_mut(start, len)
}
/// Returns `ptr` unchanged unless it is null, in which case a dangling (non-null, well
/// aligned) pointer is returned instead.
fn escape_null_ptr<T>(ptr: *mut T) -> *mut T {
    if !ptr.is_null() {
        return ptr;
    }
    ptr::NonNull::dangling().as_ptr()
}