use crate::{c_str_to_rust, memory, sys};
use crate::memory::Malloc;
use core::{cmp, mem, ptr};
/// Maximum variable-name length, in bytes, that the stack-buffer based functions in this module
/// handle (1 KiB). Their buffers are `NAME_LEN + 1` bytes so the NUL terminator always fits.
const NAME_LEN: usize = 1 << 10;
/// Reason an environment-variable lookup failed.
///
/// This is a plain fieldless enum, so it is cheap to copy and can be compared directly
/// (`err == EnvError::NotPresent`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnvError {
    /// The requested variable is not present in the environment.
    NotPresent,
    /// The variable exists but its value could not be converted into a Rust `&str`.
    NotUnicode,
}
/// Looks up the environment variable `name`.
///
/// `name` is copied into a NUL-terminated stack buffer of `NAME_LEN + 1` bytes and passed to
/// `sys::getenv` (presumably a binding to the C runtime's `getenv` — TODO confirm).
///
/// # Errors
///
/// * [`EnvError::NotPresent`] if the variable is not set. This is also returned if `name` cannot
///   be passed to C unchanged: it is longer than `NAME_LEN` bytes or contains a NUL byte.
/// * [`EnvError::NotUnicode`] if `c_str_to_rust` fails to convert the value (presumably because
///   it is not valid UTF-8).
///
/// NOTE(review): the returned `&'static str` borrows memory owned by the C environment. Per
/// `getenv(3)` that storage may be overwritten by later `setenv`/`unsetenv` calls, so the
/// `'static` lifetime is not actually guaranteed — consider returning an owned value.
pub fn get_var(name: &str) -> Result<&'static str, EnvError> {
    debug_assert!(name.len() > 0, "Empty variable name makes no sense");
    debug_assert!(name.len() <= NAME_LEN);

    // Truncating an over-long name, or letting C stop at an interior NUL, would look up a
    // *different* variable in release builds (where the debug_assert is compiled out).
    if name.len() > NAME_LEN || name.as_bytes().contains(&0) {
        return Err(EnvError::NotPresent);
    }

    let mut name_buff = mem::MaybeUninit::<[i8; NAME_LEN + 1]>::uninit();
    let len = cmp::min(NAME_LEN, name.len());
    // SAFETY: `len <= NAME_LEN`, so `len` bytes plus the terminating NUL fit in the
    // `NAME_LEN + 1`-byte buffer; `name` and the stack buffer never overlap.
    unsafe {
        let name_ptr = name_buff.as_mut_ptr() as *mut i8;
        ptr::copy_nonoverlapping(name.as_ptr() as *const i8, name_ptr, len);
        ptr::write(name_ptr.add(len), 0);
    }

    // SAFETY: the buffer holds a NUL-terminated string whose bytes up to the NUL are initialized.
    let result = unsafe { sys::getenv(name_buff.as_ptr() as *const i8) };
    if result.is_null() {
        return Err(EnvError::NotPresent);
    }
    // SAFETY: `result` is non-null and was returned by getenv, so it points to a C string.
    unsafe { c_str_to_rust(result as *const u8).map_err(|_| EnvError::NotUnicode) }
}
// Win32 entry points used by the Windows implementations of `set_var` / `unset_var` below.
#[cfg(windows)]
extern "system" {
    /// Converts `multi_byte_len` bytes of `multi_byte`, encoded in `code_page`, to UTF-16.
    ///
    /// With a null `wide` and `wide_len == 0` it only reports how many UTF-16 units are needed;
    /// otherwise it returns how many units it wrote. A return of 0 signals failure.
    pub fn MultiByteToWideChar(
        code_page: libc::c_uint,
        flags: libc::c_ulong,
        multi_byte: *const i8,
        multi_byte_len: libc::c_int,
        wide: *mut u16,
        wide_len: libc::c_int,
    ) -> libc::c_int;

    /// Sets the process environment variable `name` to `value` (both NUL-terminated UTF-16).
    /// A null `value` deletes the variable. Returns non-zero on success.
    pub fn SetEnvironmentVariableW(name: *const u16, value: *const u16) -> libc::c_int;
}
/// Sets the environment variable `name` to `value` using `libc::setenv`, overwriting any
/// existing value.
///
/// Returns `true` on success. Returns `false` in two cases:
///
/// * Without touching the environment, if `name` is longer than `NAME_LEN` bytes, or if `name`
///   or `value` contains a NUL byte. C would otherwise see a different string than the caller
///   passed.
/// * If `setenv` itself fails, e.g. because `name` contains `=`.
///
/// NOTE(review): POSIX does not require `setenv` to be thread-safe; callers must avoid
/// concurrent environment access from other threads.
#[cfg(not(windows))]
pub fn set_var(name: &str, value: &str) -> bool {
    debug_assert!(name.len() > 0, "Empty variable name makes no sense");
    debug_assert!(value.len() > 0, "Empty variable value makes no sense");
    debug_assert!(name.len() <= NAME_LEN);

    // In release builds, truncating an over-long name or stopping at an interior NUL would
    // silently set a different variable (or a shortened value), so refuse instead.
    if name.len() > NAME_LEN || name.as_bytes().contains(&0) || value.as_bytes().contains(&0) {
        return false;
    }

    let mut name_buff = mem::MaybeUninit::<[i8; NAME_LEN + 1]>::uninit();
    let len = cmp::min(NAME_LEN, name.len());
    // SAFETY: `len <= NAME_LEN`, so `len` bytes plus the terminating NUL fit in the
    // `NAME_LEN + 1`-byte buffer; `name` and the stack buffer never overlap.
    unsafe {
        let name_ptr = name_buff.as_mut_ptr() as *mut i8;
        ptr::copy_nonoverlapping(name.as_ptr() as *const i8, name_ptr, len);
        ptr::write(name_ptr.add(len), 0);
    }

    // `value` has no length bound, so its NUL-terminated copy is heap-allocated. setenv copies
    // both strings (see setenv(3)), so this buffer need not outlive the call.
    let value_store =
        memory::Box::malloc(mem::size_of::<u8>() * value.len() + mem::size_of::<u8>());
    // SAFETY: `value_store` was requested with `value.len() + 1` bytes, enough for the bytes plus
    // the NUL (assumes `memory::Box::malloc` never yields a null/short allocation — TODO confirm).
    // Both C strings are NUL-terminated when passed to setenv.
    unsafe {
        ptr::copy_nonoverlapping(
            value.as_ptr() as *const i8,
            value_store.cast::<i8>(),
            value.len(),
        );
        ptr::write(value_store.cast::<i8>().add(value.len()), 0);
        libc::setenv(name_buff.as_ptr() as *const i8, value_store.const_cast::<i8>(), 1) == 0
    }
}
/// Sets the environment variable `name` to `value` using `SetEnvironmentVariableW`.
///
/// Both strings are converted from UTF-8 (code page 65001) to NUL-terminated UTF-16 first.
///
/// Returns `true` on success. Returns `false` if:
///
/// * either string contains a NUL byte;
/// * either string is longer than `c_int::MAX` bytes;
/// * conversion fails — this includes an empty string, which `MultiByteToWideChar` rejects;
/// * `SetEnvironmentVariableW` fails.
#[cfg(windows)]
pub fn set_var(name: &str, value: &str) -> bool {
    const U16_SIZE: usize = mem::size_of::<u16>();
    // Code page identifier for UTF-8 (`CP_UTF8` in the Win32 headers).
    const CP_UTF8: libc::c_uint = 65001;
    // Largest byte count that survives the cast to `c_int` without wrapping negative.
    const MAX_LEN: usize = libc::c_int::MAX as usize;
    debug_assert!(name.len() > 0, "Empty variable name makes no sense");
    debug_assert!(value.len() > 0, "Empty variable value makes no sense");

    // A wrapped-negative length (in particular -1, which MultiByteToWideChar treats as
    // "NUL-terminated input") would make it read past the end of the slice. An interior NUL
    // would make Windows see a shorter name/value than the caller passed.
    if name.len() > MAX_LEN
        || value.len() > MAX_LEN
        || name.as_bytes().contains(&0)
        || value.as_bytes().contains(&0)
    {
        return false;
    }
    let name_len = name.len() as libc::c_int;
    let value_len = value.len() as libc::c_int;

    // SAFETY: a null output buffer with size 0 only queries the required UTF-16 length; `name`
    // is valid for `name_len` bytes.
    let name_size = unsafe {
        MultiByteToWideChar(CP_UTF8, 0, name.as_ptr() as *const _, name_len, ptr::null_mut(), 0)
    };
    if name_size == 0 {
        return false;
    }
    // SAFETY: as above, for `value`.
    let value_size = unsafe {
        MultiByteToWideChar(CP_UTF8, 0, value.as_ptr() as *const _, value_len, ptr::null_mut(), 0)
    };
    if value_size == 0 {
        return false;
    }

    // One extra u16 per buffer for the NUL terminator written below.
    let name_store = memory::Box::malloc(U16_SIZE * name_size as usize + U16_SIZE);
    let value_store = memory::Box::malloc(U16_SIZE * value_size as usize + U16_SIZE);
    // SAFETY: each store holds `size + 1` u16 slots. MultiByteToWideChar writes at most `size`
    // units and returns how many it wrote, so index `result` is in bounds for the NUL.
    unsafe {
        let mut result = MultiByteToWideChar(
            CP_UTF8,
            0,
            name.as_ptr() as *const _,
            name_len,
            name_store.cast::<u16>(),
            name_size,
        );
        if result == 0 {
            return false;
        }
        ptr::write(name_store.cast::<u16>().offset(result as isize), 0);
        result = MultiByteToWideChar(
            CP_UTF8,
            0,
            value.as_ptr() as *const _,
            value_len,
            value_store.cast::<u16>(),
            value_size,
        );
        if result == 0 {
            return false;
        }
        ptr::write(value_store.cast::<u16>().offset(result as isize), 0);
        SetEnvironmentVariableW(name_store.const_cast::<u16>(), value_store.const_cast::<u16>()) != 0
    }
}
/// Removes the environment variable `name` using `libc::unsetenv`.
///
/// Returns `true` on success, which includes the variable not having been set. Returns `false`
/// in two cases:
///
/// * Without touching the environment, if `name` is longer than `NAME_LEN` bytes or contains a
///   NUL byte. Either would otherwise remove a *different* variable.
/// * If `unsetenv` fails, e.g. because `name` contains `=`.
#[cfg(not(windows))]
pub fn unset_var(name: &str) -> bool {
    debug_assert!(name.len() > 0, "Empty variable name makes no sense");
    debug_assert!(name.len() <= NAME_LEN);

    // In release builds the debug_assert is gone; truncating the name or letting C stop at an
    // interior NUL would unset the wrong variable, so refuse instead.
    if name.len() > NAME_LEN || name.as_bytes().contains(&0) {
        return false;
    }

    let mut name_buff = mem::MaybeUninit::<[i8; NAME_LEN + 1]>::uninit();
    let len = cmp::min(NAME_LEN, name.len());
    // SAFETY: `len <= NAME_LEN`, so `len` bytes plus the terminating NUL fit in the
    // `NAME_LEN + 1`-byte buffer; `name` and the stack buffer never overlap.
    unsafe {
        let name_ptr = name_buff.as_mut_ptr() as *mut i8;
        ptr::copy_nonoverlapping(name.as_ptr() as *const i8, name_ptr, len);
        ptr::write(name_ptr.add(len), 0);
    }
    // SAFETY: the buffer now holds a NUL-terminated C string.
    unsafe { libc::unsetenv(name_buff.as_ptr() as *const i8) == 0 }
}
/// Removes the environment variable `name` by passing a null value to `SetEnvironmentVariableW`.
///
/// `name` is converted from UTF-8 (code page 65001) to NUL-terminated UTF-16 first.
///
/// Returns `true` on success. Returns `false` if `name` contains a NUL byte, is longer than
/// `c_int::MAX` bytes, fails conversion, or if `SetEnvironmentVariableW` fails.
#[cfg(windows)]
pub fn unset_var(name: &str) -> bool {
    const U16_SIZE: usize = mem::size_of::<u16>();
    // Code page identifier for UTF-8 (`CP_UTF8` in the Win32 headers).
    const CP_UTF8: libc::c_uint = 65001;
    // Largest byte count that survives the cast to `c_int` without wrapping negative.
    const MAX_LEN: usize = libc::c_int::MAX as usize;
    debug_assert!(name.len() > 0, "Empty variable name makes no sense");

    // A wrapped-negative length (-1 means "NUL-terminated input" to MultiByteToWideChar) would
    // read past the slice. An interior NUL would make Windows delete a different variable.
    if name.len() > MAX_LEN || name.as_bytes().contains(&0) {
        return false;
    }
    let name_len = name.len() as libc::c_int;

    // SAFETY: a null output buffer with size 0 only queries the required UTF-16 length; `name`
    // is valid for `name_len` bytes.
    let name_size = unsafe {
        MultiByteToWideChar(CP_UTF8, 0, name.as_ptr() as *const _, name_len, ptr::null_mut(), 0)
    };
    if name_size == 0 {
        return false;
    }

    // One extra u16 for the NUL terminator written below.
    let name_store = memory::Box::malloc(U16_SIZE * name_size as usize + U16_SIZE);
    // SAFETY: the store holds `name_size + 1` u16 slots. MultiByteToWideChar writes at most
    // `name_size` units and returns how many it wrote, so index `result` is in bounds.
    unsafe {
        let result = MultiByteToWideChar(
            CP_UTF8,
            0,
            name.as_ptr() as *const _,
            name_len,
            name_store.cast::<u16>(),
            name_size,
        );
        if result == 0 {
            return false;
        }
        ptr::write(name_store.cast::<u16>().offset(result as isize), 0);
        SetEnvironmentVariableW(name_store.const_cast::<u16>(), ptr::null()) != 0
    }
}