use std::sync::Arc;
use vyre::ir::model::expr::Ident;
use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use vyre_primitives::graph::csr_forward_traverse::bitset_words;
use vyre_primitives::graph::program_graph::{
ProgramGraphShape, NAME_EDGE_KIND_MASK, NAME_EDGE_OFFSETS, NAME_EDGE_TARGETS, NAME_NODES,
NAME_NODE_TAGS,
};
use vyre_primitives::predicate::edge_kind;
/// Stable operation id for this primitive. It is passed to
/// `assert_security_inputs`, used as the generator of the program's top-level
/// `Region`, and keys the harness `OpEntry` / `ConvergenceContract`
/// registrations in this file.
const OP_ID: &str = "vyre-libs::security::sanitized_by";
#[must_use]
pub fn sanitized_by(
shape: ProgramGraphShape,
frontier_in: &str,
sanitizers_in: &str,
frontier_out: &str,
) -> Program {
crate::security::assert_security_inputs(
OP_ID,
shape.node_count,
&[
("frontier_in", frontier_in),
("sanitizers_in", sanitizers_in),
("frontier_out", frontier_out),
],
);
let words = bitset_words(shape.node_count);
let clean_buf = format!("__sanitized_by_clean__{}", frontier_in);
let t = Expr::InvocationId { axis: 0 };
let clean_word = Expr::bitand(
Expr::load(frontier_in, Expr::var("word_idx")),
Expr::bitnot(Expr::load(sanitizers_in, Expr::var("word_idx"))),
);
let mut body = vec![Node::if_then(
Expr::lt(t.clone(), Expr::u32(words)),
vec![
Node::let_bind("word_idx", t.clone()),
Node::store(&clean_buf, t.clone(), clean_word.clone()),
],
)];
body.push(Node::if_then(
Expr::lt(t.clone(), Expr::u32(shape.node_count)),
vec![
Node::let_bind("src", t.clone()),
Node::let_bind("word_idx", Expr::shr(Expr::var("src"), Expr::u32(5))),
Node::let_bind(
"bit_mask",
Expr::shl(Expr::u32(1), Expr::bitand(Expr::var("src"), Expr::u32(31))),
),
Node::let_bind("clean_word", clean_word),
Node::if_then(
Expr::ne(
Expr::bitand(Expr::var("clean_word"), Expr::var("bit_mask")),
Expr::u32(0),
),
vec![
Node::let_bind(
"edge_start",
Expr::load(NAME_EDGE_OFFSETS, Expr::var("src")),
),
Node::let_bind(
"edge_end",
Expr::load(NAME_EDGE_OFFSETS, Expr::add(Expr::var("src"), Expr::u32(1))),
),
Node::loop_for(
"e",
Expr::var("edge_start"),
Expr::var("edge_end"),
vec![
Node::let_bind(
"kind_mask",
Expr::load(NAME_EDGE_KIND_MASK, Expr::var("e")),
),
Node::if_then(
Expr::ne(
Expr::bitand(
Expr::var("kind_mask"),
Expr::u32(crate::security::flows_to::FLOWS_TO_MASK),
),
Expr::u32(0),
),
vec![
Node::let_bind(
"dst",
Expr::load(NAME_EDGE_TARGETS, Expr::var("e")),
),
Node::if_then(
Expr::lt(Expr::var("dst"), Expr::u32(shape.node_count)),
vec![
Node::let_bind(
"dst_word_idx",
Expr::shr(Expr::var("dst"), Expr::u32(5)),
),
Node::let_bind(
"dst_bit",
Expr::shl(
Expr::u32(1),
Expr::bitand(Expr::var("dst"), Expr::u32(31)),
),
),
Node::let_bind(
"_prev",
Expr::atomic_or(
frontier_out,
Expr::var("dst_word_idx"),
Expr::var("dst_bit"),
),
),
],
),
],
),
],
),
],
),
],
));
Program::wrapped(
vec![
BufferDecl::storage(frontier_in, 0, BufferAccess::ReadOnly, DataType::U32)
.with_count(words),
BufferDecl::storage(sanitizers_in, 1, BufferAccess::ReadOnly, DataType::U32)
.with_count(words),
BufferDecl::storage(&clean_buf, 2, BufferAccess::ReadWrite, DataType::U32)
.with_count(words),
BufferDecl::storage(NAME_NODES, 3, BufferAccess::ReadOnly, DataType::U32)
.with_count(shape.node_count),
BufferDecl::storage(NAME_EDGE_OFFSETS, 4, BufferAccess::ReadOnly, DataType::U32)
.with_count(shape.node_count.saturating_add(1)),
BufferDecl::storage(NAME_EDGE_TARGETS, 5, BufferAccess::ReadOnly, DataType::U32)
.with_count(shape.edge_count.max(1)),
BufferDecl::storage(
NAME_EDGE_KIND_MASK,
6,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(shape.edge_count.max(1)),
BufferDecl::storage(NAME_NODE_TAGS, 7, BufferAccess::ReadOnly, DataType::U32)
.with_count(shape.node_count),
BufferDecl::storage(frontier_out, 8, BufferAccess::ReadWrite, DataType::U32)
.with_count(words),
],
[256, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}],
)
}
// Harness registration: builds the op on a 4-node / 3-edge graph and checks it
// against a hand-computed fixture.
inventory::submit! {
    crate::harness::OpEntry {
        id: OP_ID,
        build: || sanitized_by(ProgramGraphShape::new(4, 3), "fin", "san", "fout"),
        // Inputs in binding order. The graph is the chain 0 -> 1 -> 2 -> 3
        // (CSR offsets [0, 1, 2, 3, 3], targets [1, 2, 3]) with ASSIGNMENT
        // edges. Node 0 is tainted and node 1 is a sanitizer.
        test_inputs: Some(|| {
            let pack = |words: &[u32]| vyre_primitives::wire::pack_u32_slice(words);
            let frontier_in = pack(&[0b0001]);
            let sanitizers = pack(&[0b0010]);
            let clean_scratch = pack(&[0b0000]);
            let nodes = pack(&[0, 0, 0, 0]);
            let edge_offsets = pack(&[0, 1, 2, 3, 3]);
            let edge_targets = pack(&[1, 2, 3]);
            let edge_kinds = pack(&[
                edge_kind::ASSIGNMENT,
                edge_kind::ASSIGNMENT,
                edge_kind::ASSIGNMENT,
            ]);
            let node_tags = pack(&[0, 1, 0, 0]);
            let frontier_out = pack(&[0b0001]);
            vec![vec![
                frontier_in,
                sanitizers,
                clean_scratch,
                nodes,
                edge_offsets,
                edge_targets,
                edge_kinds,
                node_tags,
                frontier_out,
            ]]
        }),
        // Expected contents of the two read-write buffers, in binding order:
        // - the clean scratch: 0b0001 & !0b0010 = 0b0001;
        // - `frontier_out`, which gains node 1: 0b0001 | 0b0010 = 0b0011.
        expected_output: Some(|| {
            let pack = |words: &[u32]| vyre_primitives::wire::pack_u32_slice(words);
            let clean_scratch = pack(&[0b0001]);
            let frontier_out = pack(&[0b0011]);
            vec![vec![clean_scratch, frontier_out]]
        }),
        category: Some("security"),
    }
}
// Registers a convergence contract for `OP_ID` with the test harness.
// NOTE(review): `max_iterations` presumably bounds how many times the harness
// may re-dispatch this single-hop program while iterating toward a fixed
// point. Confirm against `crate::harness::ConvergenceContract`.
inventory::submit! {
    crate::harness::ConvergenceContract {
        op_id: OP_ID,
        max_iterations: 4096,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_primitives::predicate::edge_kind;

    /// Evaluates `program` with the reference interpreter. Returns the first
    /// little-endian `u32` of the second output buffer, which is `fout` for
    /// this op.
    fn first_fout_word<I>(program: &Program, inputs: Vec<I>) -> u32
    where
        vyre_reference::value::Value: From<I>,
    {
        let values: Vec<vyre_reference::value::Value> = inputs
            .into_iter()
            .map(vyre_reference::value::Value::from)
            .collect();
        let outputs = vyre_reference::reference_eval(program, &values).unwrap();
        u32::from_le_bytes(outputs[1].to_bytes()[0..4].try_into().unwrap())
    }

    #[test]
    fn sanitized_by_declares_sanitizer_buffer() {
        let program = sanitized_by(ProgramGraphShape::new(4, 3), "fin", "san", "fout");
        let declared = |wanted: &str| program.buffers().iter().any(|b| b.name() == wanted);
        assert!(declared("fin"), "frontier_in must be declared");
        assert!(declared("san"), "sanitizers_in must be declared");
        assert!(declared("fout"), "frontier_out must be declared");
    }

    #[test]
    fn sanitized_by_uses_dataflow_mask_not_universal() {
        use crate::security::flows_to::FLOWS_TO_MASK;
        for &excluded in &[edge_kind::CONTROL, edge_kind::DOMINANCE] {
            assert_eq!(FLOWS_TO_MASK & excluded, 0);
        }
    }

    #[test]
    fn sanitized_by_program_uses_non_degenerate_shape() {
        let program = sanitized_by(ProgramGraphShape::new(64, 128), "fin", "san", "fout");
        let frontier = program
            .buffers()
            .iter()
            .find(|b| b.name() == "fin")
            .expect("Fix: fin buffer");
        assert!(
            frontier.count >= 2,
            "bitset_words(64) = 2; count {} suggests degenerate shape",
            frontier.count
        );
    }

    #[test]
    fn sanitized_by_marks_sanitizer_when_taint_arrives_at_it() {
        let program = sanitized_by(ProgramGraphShape::new(4, 3), "fin", "san", "fout");
        let pack = |words: &[u32]| vyre_primitives::wire::pack_u32_slice(words);
        // Chain 0 -> 1 -> 2 -> 3; node 0 tainted, node 1 is the sanitizer.
        let word = first_fout_word(
            &program,
            vec![
                pack(&[0b0001]),
                pack(&[0b0010]),
                pack(&[0b0000]),
                pack(&[0, 0, 0, 0]),
                pack(&[0, 1, 2, 3, 3]),
                pack(&[1, 2, 3]),
                pack(&[
                    edge_kind::ASSIGNMENT,
                    edge_kind::ASSIGNMENT,
                    edge_kind::ASSIGNMENT,
                ]),
                pack(&[0, 1, 0, 0]),
                pack(&[0b0001]),
            ],
        );
        assert_eq!(
            word, 0b0011,
            "sanitized_by must mark the sanitizer when taint arrives at it; \
             observability of 'taint hit this sanitizer' is the entire point - \
             without it, downstream SARIF/audit consumers cannot distinguish \
             'sanitized at node 1' from 'never reached node 1'."
        );
    }

    #[test]
    fn sanitized_by_blocks_propagation_from_sanitizer_node() {
        let program = sanitized_by(ProgramGraphShape::new(3, 2), "fin", "san", "fout");
        let pack = |words: &[u32]| vyre_primitives::wire::pack_u32_slice(words);
        // Chain 0 -> 1 -> 2; node 1 is both tainted and a sanitizer.
        let word = first_fout_word(
            &program,
            vec![
                pack(&[0b0010]),
                pack(&[0b0010]),
                pack(&[0b0000]),
                pack(&[0, 0, 0]),
                pack(&[0, 1, 2, 2]),
                pack(&[1, 2]),
                pack(&[edge_kind::ASSIGNMENT, edge_kind::ASSIGNMENT]),
                pack(&[0, 1, 0]),
                pack(&[0b0010]),
            ],
        );
        assert_eq!(
            word, 0b0010,
            "sanitized_by must NOT propagate from sanitizer node 1; fout should remain {{1}}"
        );
    }

    #[test]
    #[should_panic(expected = "node_count must be positive")]
    fn sanitized_by_zero_node_count_should_panic() {
        let _ = sanitized_by(ProgramGraphShape::new(0, 0), "fin", "san", "fout");
    }

    #[test]
    #[should_panic(expected = "empty buffer name")]
    fn sanitized_by_empty_buffer_name_should_panic() {
        let _ = sanitized_by(ProgramGraphShape::new(4, 3), "", "san", "fout");
    }
}