use std::env;
use std::error::Error;
use std::ffi::OsStr;
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use err_context::prelude::*;
use libc::c_int;
use log::{debug, error, warn};
use serde::de::{Deserializer, Error as DeError, Unexpected};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};
use crate::AnyError;
/// Turns `path` into an absolute path anchored at the current working directory.
///
/// `path` is appended to the current directory (an already absolute `path` replaces it, as
/// with [`PathBuf::push`]). If the combined path can be canonicalized, the canonical form is
/// returned. Otherwise (for example, the target does not exist) the un-canonicalized combination
/// is returned.
///
/// If the current directory cannot be read, a warning is logged and `path` is used relative
/// to an empty base. In that case the result is *not* guaranteed to be absolute.
pub fn absolute_from_os_str(path: &OsStr) -> PathBuf {
    let base = match env::current_dir() {
        Ok(dir) => dir,
        Err(e) => {
            warn!(
                "Some paths may not be turned to absolute. Couldn't read current dir: {}",
                e,
            );
            PathBuf::new()
        }
    };
    let joined = base.join(path);
    // Canonicalization fails for nonexistent paths; fall back to the plain join then.
    joined.canonicalize().unwrap_or(joined)
}
/// Error produced by `key_val` when the option string has no `=` separator.
#[derive(Copy, Clone, Debug)]
pub struct MissingEquals;

impl Display for MissingEquals {
    fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
        fmt.write_str("Missing = in map option")
    }
}

impl Error for MissingEquals {}
/// Parses a `key=value` option into a typed `(key, value)` pair.
///
/// The input is split at the *first* `=`. Everything after it, including any further `=`
/// characters, belongs to the value. The key is parsed before the value, so if both sides are
/// malformed, the key's error is the one returned.
///
/// # Errors
///
/// * [`MissingEquals`] if `opt` contains no `=`.
/// * The `FromStr` error of `K` or `V` if the corresponding side fails to parse.
pub fn key_val<K, V>(opt: &str) -> Result<(K, V), AnyError>
where
    K: FromStr,
    K::Err: Error + Send + Sync + 'static,
    V: FromStr,
    V::Err: Error + Send + Sync + 'static,
{
    let (raw_key, raw_val) = opt
        .find('=')
        // `=` is a single ASCII byte, so `at + 1` is always a char boundary.
        .map(|at| (&opt[..at], &opt[at + 1..]))
        .ok_or(MissingEquals)?;
    let key: K = raw_key.parse()?;
    let val: V = raw_val.parse()?;
    Ok((key, val))
}
/// A wrapper that masks its content in `Debug` output and when serialized.
///
/// Deserialization is transparent: the inner `T` is read as-is. `Debug` prints `"******"` and
/// `Serialize` emits the string `******`, regardless of the wrapped value. The real value is
/// reachable through the public `.0` field or via `Deref`/`DerefMut`.
#[derive(Clone, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "cfg-help", derive(structdoc::StructDoc))]
#[repr(transparent)]
#[serde(transparent)]
pub struct Hidden<T>(pub T);

impl<T> From<T> for Hidden<T> {
    fn from(val: T) -> Self {
        Self(val)
    }
}

impl<T> Deref for Hidden<T> {
    type Target = T;
    fn deref(&self) -> &T {
        let Hidden(inner) = self;
        inner
    }
}

impl<T> DerefMut for Hidden<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Hidden(inner) = self;
        inner
    }
}

impl<T> Debug for Hidden<T> {
    fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
        // Quoted, so the mask looks like a string field in `{:?}` output.
        fmt.write_str("\"******\"")
    }
}

impl<T> Serialize for Hidden<T> {
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        const MASK: &str = "******";
        s.serialize_str(MASK)
    }
}
/// Serializes a [`Duration`] as a human-readable string produced by
/// `humantime::format_duration`.
///
/// Meant for `#[serde(serialize_with = "...")]`; the reading counterpart is
/// `deserialize_duration`.
pub fn serialize_duration<S: Serializer>(dur: &Duration, s: S) -> Result<S::Ok, S::Error> {
    let rendered = humantime::format_duration(*dur).to_string();
    s.serialize_str(&rendered)
}
/// Deserializes a [`Duration`] from a human-readable string, using `humantime::parse_duration`.
///
/// Meant for `#[serde(deserialize_with = "...")]`.
///
/// # Errors
///
/// Propagates the deserializer's error if the input is not a string. Returns an
/// `invalid_value` error quoting the input if humantime cannot parse it.
pub fn deserialize_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
    let text = String::deserialize(d)?;
    match humantime::parse_duration(&text) {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(DeError::invalid_value(
            Unexpected::Str(&text),
            &"Human readable duration",
        )),
    }
}
/// Deserializes an optional [`Duration`] from an optional human-readable string.
///
/// A missing or null value (as reported by `Option::<String>::deserialize`) yields `None`. A
/// present string is parsed with `humantime::parse_duration`.
///
/// # Errors
///
/// Returns an `invalid_value` error quoting the input if a present string cannot be parsed.
pub fn deserialize_opt_duration<'de, D: Deserializer<'de>>(
    d: D,
) -> Result<Option<Duration>, D::Error> {
    Option::<String>::deserialize(d)?
        .map(|raw| {
            humantime::parse_duration(&raw).map_err(|_| {
                DeError::invalid_value(Unexpected::Str(&raw), &"Human readable duration")
            })
        })
        // Option<Result<..>> -> Result<Option<..>>: `None` stays `Ok(None)`.
        .transpose()
}
pub fn serialize_opt_duration<S: Serializer>(
dur: &Option<Duration>,
s: S,
) -> Result<S::Ok, S::Error> {
match dur {
Some(d) => serialize_duration(d, s),
None => s.serialize_none(),
}
}
/// Resets the termination signals to their default handling.
///
/// For each signal in `signal_hook::consts::TERM_SIGNALS`, this registers
/// `signal_hook::flag::register_conditional_default` with a condition flag that is always
/// `true`. A registration failure is logged at error level and does not stop the loop.
///
/// Deprecated in favour of `support_emergency_shutdown`.
#[deprecated(note = "Abstraction at the wrong place. Use support_emergency_shutdown instead.")]
#[doc(hidden)]
pub fn cleanup_signals() {
    debug!("Resetting termination signal handlers to defaults");
    for &sig in signal_hook::consts::TERM_SIGNALS.iter() {
        let always = Arc::new(AtomicBool::new(true));
        if let Err(e) = signal_hook::flag::register_conditional_default(sig, always) {
            let name = signal_hook::low_level::signal_name(sig).unwrap_or_default();
            error!(
                "Failed to register forced shutdown signal {}/{}: {}",
                name, sig, e
            );
        }
    }
}
/// Installs staged handling of the termination signals and returns the shared flag.
///
/// For every signal in `signal_hook::consts::TERM_SIGNALS`, two handlers are registered on the
/// returned flag (initially `false`):
/// 1. `register_conditional_shutdown(sig, 2, flag)`
/// 2. `register(sig, flag)`
///
/// Per signal-hook's documented pattern, this means the first signal only sets the flag.
/// A later one, seeing the flag already set, terminates the process with exit code 2.
///
/// # Errors
///
/// Returns the first registration failure, with context naming the signal. Handlers already
/// installed for earlier signals are left in place.
pub fn support_emergency_shutdown() -> Result<Arc<AtomicBool>, AnyError> {
    /// Registers both stages for `sig` on `flag`, in the required order.
    fn install_staged(sig: c_int, flag: &Arc<AtomicBool>) -> Result<(), AnyError> {
        // The conditional shutdown goes first so that, on the first delivery, it still sees
        // `false` before the plain `register` handler flips the flag.
        signal_hook::flag::register_conditional_shutdown(sig, 2, Arc::clone(flag))?;
        signal_hook::flag::register(sig, Arc::clone(flag))?;
        Ok(())
    }

    let requested = Arc::new(AtomicBool::new(false));
    for &sig in signal_hook::consts::TERM_SIGNALS.iter() {
        let name = signal_hook::low_level::signal_name(sig).unwrap_or_default();
        debug!("Installing emergency shutdown support for {}/{}", name, sig);
        install_staged(sig, &requested).with_context(|_| {
            format!(
                "Failed to install staged shutdown handler for {}/{}",
                name, sig
            )
        })?;
    }
    Ok(requested)
}
/// Returns `true` if `v` equals `T::default()`.
///
/// Handy as a serde `skip_serializing_if` predicate.
pub fn is_default<T: Default + PartialEq>(v: &T) -> bool {
    let fallback = T::default();
    *v == fallback
}
/// Returns the referenced boolean.
///
/// Takes `&bool` so it can be used as a serde `skip_serializing_if` predicate.
pub fn is_true(v: &bool) -> bool {
    let &flag = v;
    flag
}
/// Guard that flushes the global `log` logger when it is dropped.
///
/// NOTE(review): presumably held across shutdown paths so buffered records are not lost.
/// Confirm against the crate-internal call sites.
pub(crate) struct FlushGuard;

impl Drop for FlushGuard {
    fn drop(&mut self) {
        let logger = log::logger();
        logger.flush();
    }
}
#[cfg(test)]
mod tests {
    use std::ffi::OsString;
    use std::net::{AddrParseError, IpAddr};
    use std::num::ParseIntError;

    use super::*;

    /// Relative inputs resolve to absolute paths anchored at the current directory, whether
    /// or not the target exists.
    #[test]
    fn abs() {
        let cwd = env::current_dir().unwrap();

        let up = absolute_from_os_str(&OsString::from(".."));
        assert!(up.is_absolute());
        assert!(cwd.starts_with(up));

        // A nonexistent path cannot be canonicalized, so it comes back as plain `cwd/<name>`.
        let missing = absolute_from_os_str(&OsString::from("this-likely-doesn't-exist"));
        assert!(missing.is_absolute());
        assert!(missing.starts_with(cwd));
    }

    /// Well-formed `key=value` strings parse into the requested types.
    #[test]
    fn key_val_success() {
        let words: (String, String) = key_val("hello=world").unwrap();
        assert_eq!(("hello".to_owned(), "world".to_owned()), words);

        let addr: IpAddr = "192.0.2.1".parse().unwrap();
        assert_eq!(("ip".to_owned(), addr), key_val("ip=192.0.2.1").unwrap());

        assert_eq!(("count".to_owned(), 4), key_val("count=4").unwrap());
    }

    /// Only the first `=` separates key from value; later ones stay in the value.
    #[test]
    fn key_val_extra_equals() {
        let pair: (String, String) = key_val("greeting=hello=world").unwrap();
        assert_eq!(("greeting".to_owned(), "hello=world".to_owned()), pair);
    }

    /// Parse failures surface the underlying `FromStr` error, downcastable from `AnyError`.
    #[test]
    fn key_val_parse_fail() {
        let bad_value = key_val::<String, IpAddr>("hello=192.0.2.1.0").unwrap_err();
        bad_value
            .downcast_ref::<AddrParseError>()
            .expect("Different error returned");

        let bad_key = key_val::<usize, String>("hello=world").unwrap_err();
        bad_key
            .downcast_ref::<ParseIntError>()
            .expect("Different error returned");
    }

    /// An option without `=` is rejected with `MissingEquals`.
    #[test]
    fn key_val_missing_eq() {
        let err = key_val::<String, String>("no equal sign").unwrap_err();
        err.downcast_ref::<MissingEquals>()
            .expect("Different error returned");
    }
}