/// Text encodings for which a fused GPU decode+scan program can be built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusedEncoding {
    /// Base64-encoded text.
    Base64,
    /// Hex-encoded text.
    Hex,
}

impl FusedEncoding {
    /// Short lowercase name of the encoding: `"base64"` or `"hex"`.
    #[must_use]
    pub fn label(self) -> &'static str {
        match self {
            FusedEncoding::Hex => "hex",
            FusedEncoding::Base64 => "base64",
        }
    }
}
/// The pair of fused decode+scan GPU programs plus the parameter they were built with.
///
/// A slot is `None` when building that program failed.
pub struct FusedDecodeScanPrograms {
    /// Fused base64-decode then Aho-Corasick scan program, if it could be built.
    pub base64_program: Option<vyre::Program>,
    /// Fused hex-decode then Aho-Corasick scan program, if it could be built.
    pub hex_program: Option<vyre::Program>,
    /// The `state_count` value that was passed to the program builders.
    pub state_count: u32,
}
pub fn build_fused_programs(state_count: u32, input_len: u32) -> FusedDecodeScanPrograms {
let base64_program = std::panic::catch_unwind(|| {
vyre_libs::decode::base64_decode_then_aho_corasick(
"haystack",
"decoded",
"transitions",
"accept",
"matches",
input_len,
state_count,
)
})
.ok();
let hex_program = std::panic::catch_unwind(|| {
vyre_libs::decode::hex_decode_then_aho_corasick(
"haystack",
"decoded",
"transitions",
"accept",
"matches",
input_len,
state_count,
)
})
.ok();
if base64_program.is_none() {
tracing::debug!(
target: "keyhog::gpu",
"fused base64 decode+scan program build failed - will use CPU decode path"
);
}
if hex_program.is_none() {
tracing::debug!(
target: "keyhog::gpu",
"fused hex decode+scan program build failed - will use CPU decode path"
);
}
FusedDecodeScanPrograms {
base64_program,
hex_program,
state_count,
}
}
impl FusedDecodeScanPrograms {
    /// Returns the fused program for `encoding`, or `None` if that program was not built.
    #[must_use]
    pub fn program_for(&self, encoding: FusedEncoding) -> Option<&vyre::Program> {
        let slot = match encoding {
            FusedEncoding::Hex => &self.hex_program,
            FusedEncoding::Base64 => &self.base64_program,
        };
        slot.as_ref()
    }

    /// Reports whether at least one of the two fused programs is present.
    #[must_use]
    pub fn any_available(&self) -> bool {
        !(self.base64_program.is_none() && self.hex_program.is_none())
    }
}
/// Heuristically guesses whether `data` is hex- or base64-encoded text.
///
/// Byte classes are tallied over at most the first 256 bytes. ASCII space, tab, CR and
/// LF are treated as separators and ignored. Any other byte outside the base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`, `=`) counts as "other".
///
/// Returns:
/// - `None` for empty input, or when more than 20% of the sampled bytes are "other".
/// - `Some(FusedEncoding::Hex)` when every sampled alphabet byte is a hex digit and the
///   number of non-separator bytes in the *whole* input is even.
/// - `Some(FusedEncoding::Base64)` when the sample contains at least one base64-only byte
///   (`g-z`, `G-Z`, `+`, `/`, `=`). Padding and length alignment are not required.
/// - `None` otherwise, for example an odd number of hex digits or separators only.
#[must_use]
pub fn detect_encoding(data: &[u8]) -> Option<FusedEncoding> {
    /// Separator bytes that are skipped both when classifying and when measuring length.
    fn is_separator(b: u8) -> bool {
        matches!(b, b'\n' | b'\r' | b' ' | b'\t')
    }

    if data.is_empty() {
        return None;
    }
    let sample = &data[..data.len().min(256)];
    let mut hex_chars = 0usize;
    let mut b64_chars = 0usize;
    let mut other = 0usize;
    for &b in sample {
        match b {
            b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F' => {
                // Hex digits are also members of the base64 alphabet.
                hex_chars += 1;
                b64_chars += 1;
            }
            b'g'..=b'z' | b'G'..=b'Z' | b'+' | b'/' | b'=' => b64_chars += 1,
            b if is_separator(b) => {}
            _ => other += 1,
        }
    }
    if other * 5 > sample.len() {
        return None;
    }
    if hex_chars > 0 && hex_chars == b64_chars {
        // Each decoded byte is two hex digits. Measure parity on the payload only, so a
        // trailing newline or line wrapping does not hide otherwise valid hex.
        let payload_len = data.iter().filter(|&&b| !is_separator(b)).count();
        return (payload_len % 2 == 0).then_some(FusedEncoding::Hex);
    }
    // Here b64_chars >= hex_chars always holds. Strictly greater means at least one
    // base64-only byte was seen.
    (b64_chars > hex_chars).then_some(FusedEncoding::Base64)
}