use std::collections::HashMap;
use std::marker::PhantomData;
/// Marker types naming the two things an inference pass could target.
///
/// Both items are field-less unit structs; the visible portion of this file
/// attaches no methods or trait impls to them.
pub mod what_to_infer {
/// Unit marker type. NOTE(review): presumably labels inference of the active
/// lane set — confirm against the accompanying design notes.
pub struct ActiveSetInference;
/// Unit marker type. NOTE(review): presumably labels inference of the
/// protocol (operation sequence) — confirm against the accompanying design notes.
pub struct ProtocolInference;
}
/// Decidability criterion for full (whole-program) inference.
pub mod full_inference {
    /// Unit marker type; no behavior is attached to it in this module.
    pub struct DecidableRestrictions;

    /// Reports whether inference is considered decidable for a program.
    ///
    /// Each argument flags the presence of one feature. The result is `true`
    /// only when none of the three flags is set, and `false` as soon as any
    /// single one of them is.
    pub fn is_inference_decidable(
        has_data_dependent_shuffle: bool,
        has_unbounded_loops: bool,
        has_higher_order: bool,
    ) -> bool {
        let blockers = [
            has_data_dependent_shuffle,
            has_unbounded_loops,
            has_higher_order,
        ];
        !blockers.iter().any(|&present| present)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_decidability_conditions() {
            // (shuffle, loops, higher_order) -> expected verdict
            let table = [
                ((false, false, false), true),
                ((true, false, false), false),
                ((false, true, false), false),
                ((false, false, true), false),
            ];
            for ((shuffle, loops, higher_order), expected) in table {
                assert_eq!(
                    is_inference_decidable(shuffle, loops, higher_order),
                    expected
                );
            }
        }
    }
}
/// Step-by-step protocol inference that tracks a 32-bit active-lane mask.
///
/// A [`LocalInferrer`](local_inference::LocalInferrer) records operations as
/// they are reported and rejects a shuffle whenever the current mask is not
/// "all lanes".
pub mod local_inference {
    /// Mask value with all 32 bits set, i.e. every lane active.
    const ALL_LANES: u32 = 0xFFFF_FFFF;

    /// Snapshot of an inferrer: the active mask plus the operations recorded
    /// so far.
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub struct InferredPoint {
        pub active_mask: u32,
        pub protocol: Vec<ProtocolOp>,
    }

    /// One recorded protocol operation.
    ///
    /// `LocalInferrer` emits `Shuffle`, `Diverge` and `Merge`; nothing in this
    /// module emits `Sync`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub enum ProtocolOp {
        Shuffle { mask: u32 },
        Diverge { predicate: String },
        Merge,
        Sync,
    }

    /// Bookkeeping for one open `diverge`.
    struct DivergeFrame {
        /// Active mask at the moment `diverge` was called.
        parent_mask: u32,
        /// Lanes that entered the true branch (`parent_mask & true_mask`).
        true_mask: u32,
    }

    /// Incremental inferrer: starts with all lanes active and an empty protocol.
    pub struct LocalInferrer {
        current_mask: u32,
        protocol: Vec<ProtocolOp>,
        diverge_stack: Vec<DivergeFrame>,
    }

    impl Default for LocalInferrer {
        /// Same as [`LocalInferrer::new`].
        fn default() -> Self {
            Self::new()
        }
    }

    impl LocalInferrer {
        /// Creates an inferrer with every lane active, no recorded operations
        /// and no open diverge.
        pub fn new() -> Self {
            LocalInferrer {
                current_mask: ALL_LANES,
                protocol: Vec::new(),
                diverge_stack: Vec::new(),
            }
        }

        /// Records a shuffle with the given `mask`.
        ///
        /// # Errors
        /// Returns a message containing the current mask when not all lanes
        /// are active; nothing is recorded in that case.
        pub fn shuffle(&mut self, mask: u32) -> Result<(), String> {
            if self.current_mask != ALL_LANES {
                return Err(format!(
                    "Shuffle requires All lanes, but active set is 0x{:08X}",
                    self.current_mask
                ));
            }
            self.protocol.push(ProtocolOp::Shuffle { mask });
            Ok(())
        }

        /// Opens a divergence: records a `Diverge` op labelled `predicate` and
        /// narrows the active mask to the lanes that are also in `true_mask`.
        pub fn diverge(&mut self, predicate: &str, true_mask: u32) {
            let parent_mask = self.current_mask;
            let taken = parent_mask & true_mask;
            self.diverge_stack.push(DivergeFrame {
                parent_mask,
                true_mask: taken,
            });
            self.protocol.push(ProtocolOp::Diverge {
                predicate: predicate.to_string(),
            });
            self.current_mask = taken;
        }

        /// Switches to the false branch of the innermost open diverge: the
        /// active mask becomes the parent's lanes minus the true-branch lanes.
        ///
        /// The false mask is computed from the masks stored when the diverge
        /// was opened, not from the current mask, so calling this more than
        /// once is harmless. Does nothing when no diverge is open.
        pub fn else_branch(&mut self) {
            if let Some(frame) = self.diverge_stack.last() {
                self.current_mask = frame.parent_mask & !frame.true_mask;
            }
        }

        /// Closes the innermost diverge, restores the mask that was active
        /// before it, and records a `Merge` op.
        ///
        /// # Errors
        /// Fails when there is no open diverge to close.
        pub fn merge(&mut self) -> Result<(), String> {
            let frame = self
                .diverge_stack
                .pop()
                .ok_or_else(|| "Merge without matching diverge".to_string())?;
            self.current_mask = frame.parent_mask;
            self.protocol.push(ProtocolOp::Merge);
            Ok(())
        }

        /// Consumes the inferrer and returns the recorded operations.
        ///
        /// # Errors
        /// Fails, reporting the number of open diverges, if any diverge has
        /// not been merged.
        pub fn finish(self) -> Result<Vec<ProtocolOp>, String> {
            if !self.diverge_stack.is_empty() {
                return Err(format!(
                    "Unclosed diverge: {} pending",
                    self.diverge_stack.len()
                ));
            }
            Ok(self.protocol)
        }

        /// Returns a snapshot (a clone) of the current mask and recorded ops.
        pub fn current_state(&self) -> InferredPoint {
            InferredPoint {
                active_mask: self.current_mask,
                protocol: self.protocol.clone(),
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_simple_shuffle_inference() {
            let mut inf = LocalInferrer::new();
            inf.shuffle(1).unwrap();
            inf.shuffle(2).unwrap();
            let protocol = inf.finish().unwrap();
            assert_eq!(protocol.len(), 2);
            assert_eq!(protocol[0], ProtocolOp::Shuffle { mask: 1 });
            assert_eq!(protocol[1], ProtocolOp::Shuffle { mask: 2 });
        }

        #[test]
        fn test_diverge_merge_inference() {
            let mut inf = LocalInferrer::new();
            inf.diverge("even", 0x55555555);
            assert_eq!(inf.current_mask, 0x55555555);
            inf.merge().unwrap();
            assert_eq!(inf.current_mask, 0xFFFFFFFF);
            let protocol = inf.finish().unwrap();
            assert_eq!(protocol.len(), 2);
        }

        #[test]
        fn test_shuffle_after_diverge_fails() {
            let mut inf = LocalInferrer::new();
            inf.diverge("even", 0x55555555);
            let result = inf.shuffle(1);
            assert!(result.is_err());
        }

        #[test]
        fn test_else_branch_is_idempotent() {
            let mut inf = LocalInferrer::new();
            inf.diverge("even", 0x55555555);
            inf.else_branch();
            assert_eq!(inf.current_mask, 0xAAAAAAAA);
            inf.else_branch();
            assert_eq!(inf.current_mask, 0xAAAAAAAA);
        }
    }
}
/// Skeleton of a bidirectional checker: an inference direction, a checking
/// direction, and a `BiChecker` that remembers which direction it is in.
pub mod bidirectional {
    use super::*;

    /// Context handed to [`Infer::infer`]: an active mask plus named protocols.
    pub struct InferContext {
        pub active_mask: u32,
        pub variables: HashMap<String, Protocol>,
    }

    /// Context handed to [`Check::check`]; it has the same fields as
    /// [`InferContext`].
    pub struct CheckContext {
        pub active_mask: u32,
        pub variables: HashMap<String, Protocol>,
    }

    /// Output of the inference direction: the operations and the final mask.
    #[derive(Clone, Debug)]
    pub struct InferredProtocol {
        pub ops: Vec<local_inference::ProtocolOp>,
        pub final_mask: u32,
    }

    /// A protocol description: operations plus input and output masks.
    #[derive(Clone, Debug)]
    pub struct Protocol {
        pub ops: Vec<local_inference::ProtocolOp>,
        pub input_mask: u32,
        pub output_mask: u32,
    }

    /// Inference direction: produce a protocol from a term and a context.
    pub trait Infer {
        fn infer(&self, ctx: &InferContext) -> Result<InferredProtocol, String>;
    }

    /// Checking direction: validate a term against an `expected` protocol.
    pub trait Check {
        fn check(&self, ctx: &CheckContext, expected: &Protocol) -> Result<(), String>;
    }

    /// Direction the checker is currently operating in.
    enum Mode {
        Infer,
        Check(Protocol),
    }

    /// Holder for the current direction (inferring, or checking against a
    /// stored protocol).
    pub struct BiChecker {
        mode: Mode,
    }

    impl BiChecker {
        /// Creates a checker in inference mode.
        pub fn infer() -> Self {
            Self { mode: Mode::Infer }
        }

        /// Creates a checker in checking mode against `expected`.
        pub fn check(expected: Protocol) -> Self {
            let mode = Mode::Check(expected);
            Self { mode }
        }

        /// Puts the checker into checking mode against `annotation`,
        /// replacing whatever mode it was in before.
        pub fn switch_to_check(&mut self, annotation: Protocol) {
            *self = Self::check(annotation);
        }
    }

    /// Unit marker type; no behavior is attached to it in this module.
    pub struct ModeSwitching;

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_bidirectional_creation() {
            let everyone = 0xFFFFFFFF;
            let expected = Protocol {
                ops: Vec::new(),
                input_mask: everyone,
                output_mask: everyone,
            };
            let _inferring = BiChecker::infer();
            let _checking = BiChecker::check(expected);
        }
    }
}
/// Protocol-first workflow: describe a protocol as data, then generate a code
/// skeleton from the description.
pub mod protocol_first {
    /// Tree-shaped protocol description.
    #[derive(Clone, Debug)]
    pub enum ProtocolSpec {
        /// Terminates a chain.
        End,
        /// One shuffle with `mask`, followed by the `then` continuation.
        Shuffle { mask: u32, then: Box<ProtocolSpec> },
        /// A split on `predicate` with one sub-protocol per branch.
        Diverge {
            predicate: String,
            true_branch: Box<ProtocolSpec>,
            false_branch: Box<ProtocolSpec>,
        },
        /// Several specs rendered one after another.
        Seq(Vec<ProtocolSpec>),
    }

    impl ProtocolSpec {
        /// Renders this spec as skeleton source text.
        ///
        /// Every emitted line is prefixed with `indent` space characters;
        /// the branches of a `Diverge` are rendered at `indent + 1`.
        pub fn generate_skeleton(&self, indent: usize) -> String {
            let mut rendered = String::new();
            self.render_into(indent, &mut rendered);
            rendered
        }

        /// Appends the skeleton for `self` at nesting `depth` to `out`.
        fn render_into(&self, depth: usize, out: &mut String) {
            let lead = " ".repeat(depth);
            match self {
                Self::End => {
                    out.push_str(&lead);
                    out.push_str("// protocol complete\n");
                }
                Self::Shuffle { mask, then } => {
                    out.push_str(&format!(
                        "{}let data = warp.shuffle_xor(data, {});\n",
                        lead, mask
                    ));
                    // The continuation stays at the same nesting level.
                    then.render_into(depth, out);
                }
                Self::Diverge {
                    predicate,
                    true_branch,
                    false_branch,
                } => {
                    out.push_str(&format!(
                        "{}let (true_warp, false_warp) = warp.diverge(|lane| {});\n",
                        lead, predicate
                    ));

                    out.push_str(&lead);
                    out.push_str("// true branch:\n");
                    true_branch.render_into(depth + 1, out);

                    out.push_str(&lead);
                    out.push_str("// false branch:\n");
                    false_branch.render_into(depth + 1, out);

                    out.push_str(&lead);
                    out.push_str("let warp = merge(true_warp, false_warp);\n");
                }
                Self::Seq(specs) => {
                    for spec in specs {
                        spec.render_into(depth, out);
                    }
                }
            }
        }
    }

    /// Builds a five-stage sequence of shuffles with masks 1, 2, 4, 8 and 16,
    /// each stage ending in `End`.
    pub fn butterfly_protocol() -> ProtocolSpec {
        let stages = (0..5)
            .map(|stage| ProtocolSpec::Shuffle {
                mask: 1u32 << stage,
                then: Box::new(ProtocolSpec::End),
            })
            .collect();
        ProtocolSpec::Seq(stages)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_generate_skeleton() {
            let tail = ProtocolSpec::Shuffle {
                mask: 2,
                then: Box::new(ProtocolSpec::End),
            };
            let chain = ProtocolSpec::Shuffle {
                mask: 1,
                then: Box::new(tail),
            };
            let code = chain.generate_skeleton(0);
            for needle in ["shuffle_xor(data, 1)", "shuffle_xor(data, 2)"] {
                assert!(code.contains(needle));
            }
        }

        #[test]
        fn test_butterfly_skeleton() {
            let code = butterfly_protocol().generate_skeleton(0);
            for needle in ["shuffle_xor(data, 1)", "shuffle_xor(data, 16)"] {
                assert!(code.contains(needle));
            }
        }
    }
}
/// Gradual typing sketch: a runtime-checked warp (`DynWarp`) next to a
/// statically tagged one (`StaticWarp<S>`), with `GradualWarp::ascribe` as the
/// checked bridge from the former to the latter.
pub mod gradual {
    use super::*;

    /// Warp handle whose active-lane mask is only known at run time.
    #[derive(Clone, Debug)]
    pub struct DynWarp {
        active_mask: u32,
    }

    impl DynWarp {
        /// Creates a warp with all 32 mask bits set.
        pub fn all() -> Self {
            DynWarp {
                active_mask: 0xFFFFFFFF,
            }
        }

        /// Runtime-checked shuffle.
        ///
        /// Placeholder implementation: both arguments are ignored and the
        /// success value is always `0`; only the "all lanes active" check is
        /// performed.
        ///
        /// # Errors
        /// Fails with a message containing the current mask when not every
        /// lane is active.
        pub fn shuffle(&self, _data: i32, _mask: u32) -> Result<i32, String> {
            if self.active_mask != 0xFFFFFFFF {
                return Err(format!(
                    "Runtime error: shuffle requires All lanes, got 0x{:08X}",
                    self.active_mask
                ));
            }
            Ok(0)
        }

        /// Splits this warp into `(true_warp, false_warp)`: the active lanes
        /// inside `predicate_mask` and the active lanes outside it.
        pub fn diverge(&self, predicate_mask: u32) -> (DynWarp, DynWarp) {
            let true_warp = DynWarp {
                active_mask: self.active_mask & predicate_mask,
            };
            let false_warp = DynWarp {
                active_mask: self.active_mask & !predicate_mask,
            };
            (true_warp, false_warp)
        }

        /// Joins two warps into one whose mask is the union of both.
        ///
        /// # Errors
        /// Fails when the two masks share any lane.
        pub fn merge(self, other: DynWarp) -> Result<DynWarp, String> {
            if self.active_mask & other.active_mask != 0 {
                return Err("Runtime error: merge of overlapping warps".to_string());
            }
            Ok(DynWarp {
                active_mask: self.active_mask | other.active_mask,
            })
        }

        /// Returns the current active-lane mask.
        pub fn get_mask(&self) -> u32 {
            self.active_mask
        }
    }

    /// Unit marker type; no behavior is attached to it in this module.
    pub struct GradualMigration;

    /// Either a statically tagged warp or a dynamically checked one.
    pub enum GradualWarp<S: ActiveSet> {
        Static(StaticWarp<S>),
        Dynamic(DynWarp),
    }

    /// Type-level description of an active set, carrying its mask as a constant.
    pub trait ActiveSet {
        const MASK: u32;
    }

    /// The active set in which all 32 lanes are active.
    pub struct All;

    impl ActiveSet for All {
        const MASK: u32 = 0xFFFFFFFF;
    }

    /// Warp whose active set is recorded in its type parameter; it stores no
    /// runtime data.
    pub struct StaticWarp<S: ActiveSet> {
        _marker: PhantomData<S>,
    }

    impl<S: ActiveSet> GradualWarp<S> {
        /// Checked conversion from a dynamic warp to the static set `S`.
        ///
        /// When the runtime mask equals `S::MASK` the check has established
        /// the static type, so the result is the `Static` variant. Returning
        /// `Dynamic` here would throw the verified information away.
        ///
        /// # Errors
        /// Fails with a message showing the expected and actual masks when
        /// they differ.
        pub fn ascribe(dyn_warp: DynWarp) -> Result<Self, String> {
            if dyn_warp.active_mask == S::MASK {
                Ok(GradualWarp::Static(StaticWarp {
                    _marker: PhantomData,
                }))
            } else {
                Err(format!(
                    "Type ascription failed: expected 0x{:08X}, got 0x{:08X}",
                    S::MASK,
                    dyn_warp.active_mask
                ))
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_dyn_warp_shuffle_all() {
            let warp = DynWarp::all();
            assert!(warp.shuffle(42, 1).is_ok());
        }

        #[test]
        fn test_dyn_warp_shuffle_partial_fails() {
            let warp = DynWarp {
                active_mask: 0x55555555,
            };
            assert!(warp.shuffle(42, 1).is_err());
        }

        #[test]
        fn test_dyn_warp_diverge_merge() {
            let warp = DynWarp::all();
            let (even, odd) = warp.diverge(0x55555555);
            assert_eq!(even.active_mask, 0x55555555);
            assert_eq!(odd.active_mask, 0xAAAAAAAA);
            let merged = even.merge(odd).unwrap();
            assert_eq!(merged.active_mask, 0xFFFFFFFF);
        }

        #[test]
        fn test_gradual_ascription() {
            let dyn_warp = DynWarp::all();
            let result: Result<GradualWarp<All>, _> = GradualWarp::ascribe(dyn_warp);
            assert!(matches!(result, Ok(GradualWarp::Static(_))));
        }
    }
}
/// Holds a single marker type; the visible portion of this file attaches no
/// behavior to it.
pub mod recommendation {
/// Unit marker type with no fields. NOTE(review): presumably stands for the
/// recommended combination of the approaches sketched in the sibling
/// modules — confirm against the accompanying design notes.
pub struct CombinedApproach;
}
/// Holds a single marker type; the visible portion of this file attaches no
/// behavior to it.
pub mod literature {
/// Unit marker type with no fields. NOTE(review): presumably a placeholder
/// for literature references — confirm against the accompanying design notes.
pub struct References;
}
// Cross-module tests exercising `local_inference` and `protocol_first` together.
#[cfg(test)]
mod tests {
    use super::*;

    /// Five consecutive shuffles on a fresh inferrer are all accepted and all
    /// recorded.
    #[test]
    fn test_full_inference_reduction() {
        let mut inferrer = local_inference::LocalInferrer::new();
        for stage in 0..5 {
            inferrer.shuffle(1u32 << stage).unwrap();
        }
        let recorded = inferrer.finish().unwrap();
        assert_eq!(recorded.len(), 5);
    }

    /// Shuffles are rejected inside a diverge and accepted again after the
    /// matching merge.
    #[test]
    fn test_diverge_with_different_work() {
        let everyone = 0xFFFFFFFF;
        let evens = 0x55555555;
        let mut inferrer = local_inference::LocalInferrer::new();
        assert_eq!(inferrer.current_state().active_mask, everyone);

        inferrer.diverge("lane % 2 == 0", evens);
        assert_eq!(inferrer.current_state().active_mask, evens);
        let inside = inferrer.shuffle(1);
        assert!(inside.is_err());

        inferrer.merge().unwrap();
        assert_eq!(inferrer.current_state().active_mask, everyone);
        let after = inferrer.shuffle(1);
        assert!(after.is_ok());
    }

    /// The butterfly spec renders to a skeleton that mentions `shuffle_xor`.
    #[test]
    fn test_protocol_first_matches_inference() {
        let rendered = protocol_first::butterfly_protocol().generate_skeleton(0);
        assert!(rendered.contains("shuffle_xor"));
    }
}