#![allow(clippy::missing_safety_doc)]
use super::noise::{Tunn, TunnResult};
use base64::{decode, encode};
use hex::encode as encode_hex;
use libc::{raise, SIGSEGV};
use rand_core::OsRng;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::serialization::KeyBytes;
use std::ffi::{CStr, CString};
use std::os::raw::c_char;
use std::panic;
use std::ptr;
use std::ptr::null_mut;
use std::slice;
use std::sync::Once;
/// One-time guard so that `new_tunnel` installs its process-wide panic hook at most once.
static PANIC_HOOK: Once = Once::new();
/// Operation a C caller should perform after a `wireguard_*` call.
///
/// The explicit discriminants are exposed through `#[repr(C)]` to C callers, so they are part of
/// the ABI and must not be renumbered. The values of `WRITE_TO_TUNNEL_IPV4` (4) and
/// `WRITE_TO_TUNNEL_IPV6` (6) mirror the IP version in their names.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum result_type {
    /// No output was produced; the accompanying `size` is 0.
    WIREGUARD_DONE = 0,
    /// A packet of `size` bytes was produced for the network side.
    WRITE_TO_NETWORK = 1,
    /// The operation failed; the accompanying `size` carries the numeric error code.
    WIREGUARD_ERROR = 2,
    /// An IPv4 packet of `size` bytes was produced for the tunnel side.
    WRITE_TO_TUNNEL_IPV4 = 4,
    /// An IPv6 packet of `size` bytes was produced for the tunnel side.
    WRITE_TO_TUNNEL_IPV6 = 6,
}

/// Result value returned by value across the C ABI from the `wireguard_*` functions.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct wireguard_result {
    /// What the caller should do next.
    pub op: result_type,
    /// Length in bytes of the produced packet, 0 for `WIREGUARD_DONE`,
    /// or the error code for `WIREGUARD_ERROR`.
    pub size: usize,
}
/// Tunnel statistics snapshot returned by value from `wireguard_stats`.
///
/// Field order is significant: the struct is `#[repr(C)]` and read directly by C callers.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct stats {
    /// Whole seconds since the last handshake, or -1 when no handshake has been recorded.
    pub time_since_last_handshake: i64,
    /// Transmitted byte counter as reported by the tunnel.
    pub tx_bytes: usize,
    /// Received byte counter as reported by the tunnel.
    pub rx_bytes: usize,
    /// Estimated loss as reported by the tunnel (units/scale not visible here — TODO confirm).
    pub estimated_loss: f32,
    /// Estimated round-trip time, or -1 when unknown (units not visible here — TODO confirm).
    pub estimated_rtt: i32,
    /// Always zero-filled; presumably space reserved for future fields — TODO confirm.
    reserved: [u8; 56],
}
/// Flattens a borrowed `TunnResult` into the C-facing `wireguard_result`.
///
/// Only the length of any output slice is kept (the slice itself is dropped). For errors the
/// numeric value of the error is stored in `size`.
impl<'a> From<TunnResult<'a>> for wireguard_result {
    fn from(res: TunnResult<'a>) -> Self {
        let (op, size) = match res {
            TunnResult::Done => (result_type::WIREGUARD_DONE, 0),
            TunnResult::Err(e) => (result_type::WIREGUARD_ERROR, e as usize),
            TunnResult::WriteToNetwork(packet) => (result_type::WRITE_TO_NETWORK, packet.len()),
            TunnResult::WriteToTunnelV4(packet, _) => {
                (result_type::WRITE_TO_TUNNEL_IPV4, packet.len())
            }
            TunnResult::WriteToTunnelV6(packet, _) => {
                (result_type::WRITE_TO_TUNNEL_IPV6, packet.len())
            }
        };
        wireguard_result { op, size }
    }
}
/// A raw 32-byte X25519 key (secret or public), passed by value across the C ABI.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct x25519_key {
    /// Key bytes in the form produced by the x25519-dalek `to_bytes()` methods.
    pub key: [u8; 32],
}
/// Generates a fresh X25519 static secret using the operating-system RNG and returns its bytes.
#[no_mangle]
pub extern "C" fn x25519_secret_key() -> x25519_key {
    let secret = StaticSecret::new(OsRng);
    let key = secret.to_bytes();
    x25519_key { key }
}
/// Derives the X25519 public key that corresponds to `private_key`.
#[no_mangle]
pub extern "C" fn x25519_public_key(private_key: x25519_key) -> x25519_key {
    let key = PublicKey::from(&StaticSecret::from(private_key.key)).to_bytes();
    x25519_key { key }
}
/// Encodes `key` as a heap-allocated, NUL-terminated base64 string.
///
/// The returned pointer must be released with `x25519_key_to_str_free` (cast to mutable).
/// The `unwrap` cannot fire in practice: base64 output never contains a NUL byte.
#[no_mangle]
pub extern "C" fn x25519_key_to_base64(key: x25519_key) -> *const c_char {
    let c_string = CString::new(encode(&key.key)).unwrap();
    c_string.into_raw()
}
/// Encodes `key` as a heap-allocated, NUL-terminated hexadecimal string.
///
/// The returned pointer must be released with `x25519_key_to_str_free` (cast to mutable).
/// The `unwrap` cannot fire in practice: hex output never contains a NUL byte.
#[no_mangle]
pub extern "C" fn x25519_key_to_hex(key: x25519_key) -> *const c_char {
    let hex_string = encode_hex(&key.key);
    CString::new(hex_string).unwrap().into_raw()
}
/// Releases a string returned by `x25519_key_to_base64` or `x25519_key_to_hex`.
///
/// Passing a null pointer is a no-op, mirroring C's `free(NULL)`.
///
/// # Safety
/// `stringified_key` must be null or a pointer obtained from one of the functions above that
/// has not already been freed.
#[no_mangle]
pub unsafe extern "C" fn x25519_key_to_str_free(stringified_key: *mut c_char) {
    // `CString::from_raw` on null would compute strlen(NULL): undefined behavior.
    if stringified_key.is_null() {
        return;
    }
    drop(CString::from_raw(stringified_key));
}
/// Checks whether `key` is a plausible base64-encoded X25519 key.
///
/// Returns 1 when `key` is valid UTF-8, decodes as base64 to exactly 32 bytes, and those bytes
/// are not all zero; returns 0 otherwise (including when `key` is null).
///
/// # Safety
/// `key` must be null or point to a valid NUL-terminated string.
#[no_mangle]
pub unsafe extern "C" fn check_base64_encoded_x25519_key(key: *const c_char) -> i32 {
    // `CStr::from_ptr` on null is undefined behavior; treat it as "not a key".
    if key.is_null() {
        return 0;
    }
    let utf8_key = match CStr::from_ptr(key).to_str() {
        Ok(string) => string,
        Err(_) => return 0,
    };
    let bytes = match decode(utf8_key) {
        Ok(bytes) => bytes,
        Err(_) => return 0,
    };
    // OR every byte together without an early exit, so the scan does not stop at the
    // first non-zero byte.
    let any_nonzero = bytes.iter().fold(0u8, |acc, &b| acc | b) != 0;
    i32::from(bytes.len() == 32 && any_nonzero)
}
/// Parses a NUL-terminated C string into [`KeyBytes`].
///
/// Returns `None` if `key` is null, is not valid UTF-8, or is rejected by `KeyBytes`' parser.
///
/// # Safety
/// `key` must be null or point to a valid NUL-terminated string.
unsafe fn key_bytes_from_c_str(key: *const c_char) -> Option<KeyBytes> {
    if key.is_null() {
        return None;
    }
    CStr::from_ptr(key).to_str().ok()?.parse::<KeyBytes>().ok()
}

/// Creates a new tunnel and returns an owning pointer to it, or null on any failure.
///
/// * `static_private` – our static private key, as text parsed by `KeyBytes`
///   (accepted encodings are defined by `KeyBytes`' `FromStr` — TODO confirm base64/hex).
/// * `server_static_public` – the peer's static public key, same format.
/// * `preshared_key` – optional preshared key; null means "none", but a non-null value that
///   fails to parse is an error.
/// * `keep_alive` – 0 is passed to `Tunn::new` as `None`; any other value as `Some(keep_alive)`.
/// * `index` – forwarded unchanged to `Tunn::new`.
///
/// On the first successful call this also installs a process-wide panic hook that raises
/// `SIGSEGV`, presumably so a panic crashes the host instead of unwinding into C — TODO confirm.
///
/// The returned pointer must be released with `tunnel_free`.
///
/// # Safety
/// Each key pointer must be null or point to a valid NUL-terminated string.
#[no_mangle]
pub unsafe extern "C" fn new_tunnel(
    static_private: *const c_char,
    server_static_public: *const c_char,
    preshared_key: *const c_char,
    keep_alive: u16,
    index: u32,
) -> *mut Tunn {
    // Null key pointers were previously passed straight to `CStr::from_ptr` (UB);
    // the helper now rejects them, matching how `preshared_key` was already handled.
    let private_key = match key_bytes_from_c_str(static_private) {
        Some(key) => StaticSecret::from(key.0),
        None => return ptr::null_mut(),
    };
    let public_key = match key_bytes_from_c_str(server_static_public) {
        Some(key) => PublicKey::from(key.0),
        None => return ptr::null_mut(),
    };
    let preshared_key = if preshared_key.is_null() {
        None
    } else {
        match key_bytes_from_c_str(preshared_key) {
            Some(key) => Some(key.0),
            None => return null_mut(),
        }
    };
    let keep_alive = Some(keep_alive).filter(|&value| value != 0);

    let tunnel = match Tunn::new(
        private_key,
        public_key,
        preshared_key,
        keep_alive,
        index,
        None,
    ) {
        Ok(t) => t,
        Err(_) => return ptr::null_mut(),
    };

    PANIC_HOOK.call_once(|| {
        panic::set_hook(Box::new(move |_| {
            raise(SIGSEGV);
        }));
    });

    Box::into_raw(tunnel)
}
/// Destroys a tunnel created by `new_tunnel`. Passing null is a no-op.
///
/// # Safety
/// `tunnel` must be null or a pointer returned by `new_tunnel` that has not already been freed.
/// It must not be used after this call.
#[no_mangle]
pub unsafe extern "C" fn tunnel_free(tunnel: *mut Tunn) {
    // `Box::from_raw` on null is undefined behavior.
    if tunnel.is_null() {
        return;
    }
    // Explicit drop: `Box::from_raw` is `#[must_use]`, and dropping the box frees the tunnel.
    drop(Box::from_raw(tunnel));
}
#[no_mangle]
pub unsafe extern "C" fn wireguard_write(
tunnel: *mut Tunn,
src: *const u8,
src_size: u32,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let src = slice::from_raw_parts(src, src_size as usize);
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.encapsulate(src, dst))
}
#[no_mangle]
pub unsafe extern "C" fn wireguard_read(
tunnel: *mut Tunn,
src: *const u8,
src_size: u32,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let src = slice::from_raw_parts(src, src_size as usize);
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.decapsulate(None, src, dst))
}
#[no_mangle]
pub unsafe extern "C" fn wireguard_tick(
tunnel: *mut Tunn,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.update_timers(dst))
}
#[no_mangle]
pub unsafe extern "C" fn wireguard_force_handshake(
tunnel: *mut Tunn,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.format_handshake_initiation(dst, true))
}
/// Returns a C-layout snapshot of `Tunn::stats()`.
///
/// A missing last-handshake time is reported as -1 (otherwise whole seconds). A missing RTT
/// estimate is also reported as -1. The reserved tail is zero-filled.
///
/// # Safety
/// `tunnel` must be a live pointer from `new_tunnel` (null panics).
#[no_mangle]
pub unsafe extern "C" fn wireguard_stats(tunnel: *mut Tunn) -> stats {
    let tunn = tunnel.as_ref().unwrap();
    let (last_handshake, tx_bytes, rx_bytes, estimated_loss, rtt) = tunn.stats();
    let time_since_last_handshake = match last_handshake {
        Some(elapsed) => elapsed.as_secs() as i64,
        None => -1,
    };
    let estimated_rtt = match rtt {
        Some(r) => r as i32,
        None => -1,
    };
    stats {
        time_since_last_handshake,
        tx_bytes,
        rx_bytes,
        estimated_loss,
        estimated_rtt,
        reserved: [0u8; 56],
    }
}