ps-parser 1.0.1

The PowerShell Parser
Documentation
use std::{collections::HashMap, sync::LazyLock};

use thiserror_no_std::Error;

use super::Val;
use crate::parser::value::ValError;
#[derive(Error, Debug, PartialEq, Clone)]
pub enum BitwiseError {
    #[error("{0} not defined for {1}")]
    NotDefined(String, String),
    #[error("Failed casting to int: {0}")]
    CastToInt(ValError),
}

impl From<ValError> for BitwiseError {
    fn from(value: ValError) -> Self {
        Self::CastToInt(value)
    }
}

/// Result type shared by every bitwise operator implementation in this module.
type BitwiseResult<T> = Result<T, BitwiseError>;

/// Signature of a binary bitwise operator: consumes the left and right
/// operands and yields the resulting value or a [`BitwiseError`].
pub(crate) type BitwisePredType = fn(Val, Val) -> BitwiseResult<Val>;

/// Namespace for looking up a bitwise operator implementation by its
/// PowerShell name (see `BitwisePred::get`).
pub(crate) struct BitwisePred;

impl BitwisePred {
    /// Looks up the implementation of a binary bitwise operator by name.
    ///
    /// The lookup lowercases `name` (ASCII only) first, so `-band`, `-Band`
    /// and `-BAND` all resolve to the same function. Returns `None` unless
    /// `name` is one of `-band`, `-bor`, `-bxor`, `-shl`, `-shr`.
    pub(crate) fn get(name: &str) -> Option<BitwisePredType> {
        // Must be a `static`, not an associated `const`: a `const LazyLock` is
        // a brand-new, uninitialised value at every use site, which rebuilt
        // (and discarded) the whole HashMap on every lookup. Rust has no
        // associated statics, so the table lives here, initialised once.
        static BITWISE_PRED_MAP: LazyLock<HashMap<&'static str, BitwisePredType>> =
            LazyLock::new(|| {
                HashMap::from([
                    ("-band", band as _),
                    ("-bor", bor as _),
                    ("-bxor", bxor as _),
                    ("-shl", shl as _),
                    ("-shr", shr as _),
                ])
            });

        BITWISE_PRED_MAP
            .get(name.to_ascii_lowercase().as_str())
            .copied()
    }
}

/// Converts both operands of a bitwise operator to `i64`.
///
/// `op_name` is only used to build the error message.
///
/// # Errors
///
/// * [`BitwiseError::NotDefined`] if the left operand is a runtime object or a
///   runtime type. The right operand is not inspected for this case.
/// * [`BitwiseError::CastToInt`] if either operand fails `cast_to_int`. The
///   left operand is converted first.
pub fn prepare(a: Val, b: Val, op_name: &str) -> BitwiseResult<(i64, i64)> {
    match a {
        Val::RuntimeObject(obj) => Err(BitwiseError::NotDefined(op_name.to_string(), obj.name())),
        Val::RuntimeType(ty) => Err(BitwiseError::NotDefined(op_name.to_string(), ty.name())),
        lhs => {
            let lhs = lhs.cast_to_int()?;
            let rhs = b.cast_to_int()?;
            Ok((lhs, rhs))
        }
    }
}

/// Integer core of `-band`: the bitwise AND of `a` and `b`.
fn band_imp(a: i64, b: i64) -> i64 {
    std::ops::BitAnd::bitand(a, b)
}

/// PowerShell `-band`: converts both operands to integers and ANDs them.
///
/// # Errors
///
/// Propagates any error from [`prepare`].
pub fn band(a: Val, b: Val) -> BitwiseResult<Val> {
    prepare(a, b, "-band").map(|(lhs, rhs)| Val::Int(band_imp(lhs, rhs)))
}

/// Integer core of `-bor`: the bitwise OR of `a` and `b`.
fn bor_imp(a: i64, b: i64) -> i64 {
    std::ops::BitOr::bitor(a, b)
}

/// PowerShell `-bor`: converts both operands to integers and ORs them.
///
/// # Errors
///
/// Propagates any error from [`prepare`].
pub fn bor(a: Val, b: Val) -> BitwiseResult<Val> {
    prepare(a, b, "-bor").map(|(lhs, rhs)| Val::Int(bor_imp(lhs, rhs)))
}

/// Integer core of `-bxor`: the bitwise exclusive OR of `a` and `b`.
fn bxor_imp(a: i64, b: i64) -> i64 {
    std::ops::BitXor::bitxor(a, b)
}

/// PowerShell `-bxor`: converts both operands to integers and XORs them.
///
/// # Errors
///
/// Propagates any error from [`prepare`].
pub fn bxor(a: Val, b: Val) -> BitwiseResult<Val> {
    prepare(a, b, "-bxor").map(|(lhs, rhs)| Val::Int(bxor_imp(lhs, rhs)))
}

/// Integer core of `-shl`: shifts `a` left by `b` bits.
///
/// Only the low six bits of `b` are used as the shift count (`b & 0x3f`).
/// This matches .NET's shift semantics for 64-bit integers. Without the mask,
/// a user-supplied count outside `0..64` (e.g. `1 -shl 64` or `1 -shl -1`)
/// made `<<` panic with "attempt to shift left with overflow" in debug builds.
fn shl_imp(a: i64, b: i64) -> i64 {
    a << (b & 0x3f)
}

/// PowerShell `-shl`: converts both operands to integers and shifts the left
/// operand left by the right operand (via `shl_imp`).
///
/// # Errors
///
/// Propagates any error from [`prepare`].
pub fn shl(a: Val, b: Val) -> BitwiseResult<Val> {
    prepare(a, b, "-shl").map(|(value, count)| Val::Int(shl_imp(value, count)))
}

/// Integer core of `-shr`: arithmetic (sign-propagating) right shift of `a`
/// by `b` bits.
///
/// Only the low six bits of `b` are used as the shift count (`b & 0x3f`).
/// This matches .NET's shift semantics for 64-bit integers. Without the mask,
/// a user-supplied count outside `0..64` (e.g. `96 -shr 64` or `96 -shr -1`)
/// made `>>` panic with "attempt to shift right with overflow" in debug builds.
fn shr_imp(a: i64, b: i64) -> i64 {
    a >> (b & 0x3f)
}

/// PowerShell `-shr`: converts both operands to integers and shifts the left
/// operand right by the right operand (via `shr_imp`).
///
/// # Errors
///
/// Propagates any error from [`prepare`].
pub fn shr(a: Val, b: Val) -> BitwiseResult<Val> {
    prepare(a, b, "-shr").map(|(value, count)| Val::Int(shr_imp(value, count)))
}

#[cfg(test)]
mod tests {
    use crate::PowerShellSession;

    /// Runs every `(script, expected)` pair, in order, through one fresh
    /// session. Each `safe_eval` must succeed and render exactly `expected`.
    fn check_all(cases: &[(&'static str, &'static str)]) {
        let mut session = PowerShellSession::new();
        for &(script, expected) in cases {
            assert_eq!(
                session.safe_eval(script).unwrap(),
                expected.to_string(),
                "script: {script:?}"
            );
        }
    }

    #[test]
    fn test_band() {
        check_all(&[
            (r#" 5 -band 4 "#, "4"),
            (r#" 5 -band 2 "#, "0"),
            (r#" 5 -Band 9 "#, "1"),
        ]);
    }

    #[test]
    fn test_bor() {
        check_all(&[
            (r#" 5 -bOr 4 "#, "5"),
            (r#" 5 -bor 2 "#, "7"),
            (r#" 5 -bor 9 "#, "13"),
            (r#" 5 -bor 5 -band 4 "#, "4"),
            (r#" 5 -band 5 -bor 4 "#, "5"),
            (r#" 6 -bor 5 -bor 4 "#, "7"),
        ]);
    }

    #[test]
    fn test_bxor() {
        check_all(&[
            (r#" 5 -bxor 4 "#, "1"),
            (r#" 5 -bxor 2 "#, "7"),
            (r#" 5 -bxor 9 "#, "12"),
        ]);
    }

    #[test]
    fn test_shl() {
        check_all(&[
            (r#" 5 -shl 4 "#, "80"),
            (r#" -5 -shl 2 "#, "-20"),
            (r#" "5.5" -shl 3.5 "#, "96"),
            (r#" "+5.5" -sHl 3.5 "#, "96"),
            (r#" "5.5as" -shl 3.5 "#, ""),
        ]);
    }

    #[test]
    fn test_shr() {
        check_all(&[
            (r#" 96 -shr 4 "#, "6"),
            (r#" -96 -shr 2 "#, "-24"),
            (r#" "96.5" -shr 3.5 "#, "6"),
            (r#" "+96.5" -shr 3.5 "#, "6"),
            (r#" "96.5as" -shr 3.5 "#, ""),
        ]);
    }

    #[test]
    fn test_bnot() {
        check_all(&[
            (r#" -bnot 4 "#, "-5"),
            (r#" -bnot -95 "#, "94"),
            (r#" [int] "96.5" "#, "96"),
            (r#" [int] "97.5" "#, "98"),
            (r#" -bnot "96.5" "#, "-97"),
            (r#" -bnot "97" "#, "-98"),
            (r#" -bnot "+96.5" "#, "-97"),
            (r#" -bnot "96.5as" "#, ""),
            (r#" [float]"+96.51e1" "#, "965.1"),
            (r#" [int]"+96.51e1" "#, "965"),
            (r#" [int]"+96e1" "#, "960"),
            (r#" [int]"0x96" "#, "150"),
            (r#" [int]"0x96e1" "#, "38625"),
            (r#" [int]"+0x96e1" "#, ""),
            (r#" -bnot "+96.51e1" "#, "-966"),
        ]);
    }
}