use std::fmt;
use crate::emit::{Emit, PtxWriter};
use crate::ir::Register;
use crate::types::PtxType;
/// A memory-related PTX instruction that this module can emit as text.
///
/// Every variant is written out as one instruction by the `Emit`
/// implementation in this file. Variants carrying a `ty` field append
/// `ty.ptx_suffix()` to the mnemonic (for example `ld.global` + `.f32`).
#[derive(Debug, Clone)]
pub enum MemoryOp {
/// Parameter load, emitted as `ld.param<suffix> dst, [param_name]`.
LdParam {
/// Destination register; emitted as the first operand.
dst: Register,
/// Parameter name; emitted verbatim inside square brackets.
param_name: String,
/// Type whose suffix is appended to the `ld.param` mnemonic.
ty: PtxType,
},
/// Global-space load, emitted as `ld.global<suffix> dst, [addr]`.
LdGlobal {
/// Destination register; emitted as the first operand.
dst: Register,
/// Register emitted inside square brackets as the address operand.
addr: Register,
/// Type whose suffix is appended to the `ld.global` mnemonic.
ty: PtxType,
},
/// Global-space store, emitted as `st.global<suffix> [addr], src`.
///
/// Note the operand order: the bracketed address comes first, the
/// source register second.
StGlobal {
/// Register emitted inside square brackets as the address operand.
addr: Register,
/// Source register; emitted as the second operand.
src: Register,
/// Type whose suffix is appended to the `st.global` mnemonic.
ty: PtxType,
},
/// Shared-space load, emitted as `ld.shared<suffix> dst, [addr]`.
LdShared {
/// Destination register; emitted as the first operand.
dst: Register,
/// Register emitted inside square brackets as the address operand.
addr: Register,
/// Type whose suffix is appended to the `ld.shared` mnemonic.
ty: PtxType,
},
/// Shared-space store, emitted as `st.shared<suffix> [addr], src`.
StShared {
/// Register emitted inside square brackets as the address operand.
addr: Register,
/// Source register; emitted as the second operand.
src: Register,
/// Type whose suffix is appended to the `st.shared` mnemonic.
ty: PtxType,
},
/// Address conversion, emitted as `cvta.to.global.u64 dst, src`.
///
/// The `.u64` suffix is fixed by the emitter; this variant carries no
/// type field.
CvtaToGlobal {
/// Destination register; emitted as the first operand.
dst: Register,
/// Source register; emitted as the second operand.
src: Register,
},
}
impl Emit for MemoryOp {
fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
match self {
MemoryOp::LdParam {
dst,
param_name,
ty,
} => {
let mnemonic = format!("ld.param{}", ty.ptx_suffix());
let addr = format!("[{param_name}]");
w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr])
}
MemoryOp::LdGlobal { dst, addr, ty } => {
let mnemonic = format!("ld.global{}", ty.ptx_suffix());
let addr_str = format!("[{addr}]");
w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr_str])
}
MemoryOp::StGlobal { addr, src, ty } => {
let mnemonic = format!("st.global{}", ty.ptx_suffix());
let addr_str = format!("[{addr}]");
w.instruction(&mnemonic, &[&addr_str as &dyn fmt::Display, src])
}
MemoryOp::LdShared { dst, addr, ty } => {
let mnemonic = format!("ld.shared{}", ty.ptx_suffix());
let addr_str = format!("[{addr}]");
w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr_str])
}
MemoryOp::StShared { addr, src, ty } => {
let mnemonic = format!("st.shared{}", ty.ptx_suffix());
let addr_str = format!("[{addr}]");
w.instruction(&mnemonic, &[&addr_str as &dyn fmt::Display, src])
}
MemoryOp::CvtaToGlobal { dst, src } => {
w.instruction("cvta.to.global.u64", &[dst as &dyn fmt::Display, src])
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RegKind;

    /// Builds a `Register` directly from its three fields.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Emits `op` into a fresh writer (after one `indent()` call) and checks
    /// that the finished text equals `expected`.
    #[track_caller]
    fn assert_emits(op: &MemoryOp, expected: &str) {
        let mut writer = PtxWriter::new();
        writer.indent();
        op.emit(&mut writer).unwrap();
        assert_eq!(writer.finish(), expected);
    }

    #[test]
    fn emit_ld_param_u64() {
        let load = MemoryOp::LdParam {
            dst: reg(RegKind::Rd, 1, PtxType::U64),
            param_name: "vector_add_param_0".to_string(),
            ty: PtxType::U64,
        };
        assert_emits(&load, " ld.param.u64 %rd1, [vector_add_param_0];\n");
    }

    #[test]
    fn emit_ld_param_u32() {
        let load = MemoryOp::LdParam {
            dst: reg(RegKind::R, 2, PtxType::U32),
            param_name: "vector_add_param_3".to_string(),
            ty: PtxType::U32,
        };
        assert_emits(&load, " ld.param.u32 %r2, [vector_add_param_3];\n");
    }

    #[test]
    fn emit_cvta_to_global() {
        let convert = MemoryOp::CvtaToGlobal {
            dst: reg(RegKind::Rd, 4, PtxType::U64),
            src: reg(RegKind::Rd, 1, PtxType::U64),
        };
        assert_emits(&convert, " cvta.to.global.u64 %rd4, %rd1;\n");
    }

    #[test]
    fn emit_ld_global_f32() {
        let load = MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 1, PtxType::F32),
            addr: reg(RegKind::Rd, 8, PtxType::U64),
            ty: PtxType::F32,
        };
        assert_emits(&load, " ld.global.f32 %f1, [%rd8];\n");
    }

    #[test]
    fn emit_st_global_f32() {
        let store = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 10, PtxType::U64),
            src: reg(RegKind::F, 3, PtxType::F32),
            ty: PtxType::F32,
        };
        assert_emits(&store, " st.global.f32 [%rd10], %f3;\n");
    }

    /// Same load as above, but dispatched through the `PtxInstruction`
    /// wrapper rather than calling `MemoryOp::emit` directly.
    #[test]
    fn memory_via_ptx_instruction() {
        use crate::ir::PtxInstruction;

        let wrapped = PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F32,
        });
        let mut writer = PtxWriter::new();
        writer.indent();
        wrapped.emit(&mut writer).unwrap();
        assert_eq!(writer.finish(), " ld.global.f32 %f0, [%rd0];\n");
    }

    #[test]
    fn emit_ld_shared_f32() {
        let load = MemoryOp::LdShared {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::F32,
        };
        assert_emits(&load, " ld.shared.f32 %f0, [%r0];\n");
    }

    #[test]
    fn emit_st_shared_f32() {
        let store = MemoryOp::StShared {
            addr: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 1, PtxType::F32),
            ty: PtxType::F32,
        };
        assert_emits(&store, " st.shared.f32 [%r0], %f1;\n");
    }

    /// Checks only the relative position of the two operands of a store,
    /// independent of the surrounding formatting.
    #[test]
    fn st_global_operand_order() {
        let store = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        };
        let mut writer = PtxWriter::new();
        writer.indent();
        store.emit(&mut writer).unwrap();

        let text = writer.finish();
        let address_at = text.find("[%rd0]").expect("address not found");
        let source_at = text.find("%f0").expect("source not found");
        assert!(
            address_at < source_at,
            "store operand order wrong: address must come before source in PTX"
        );
    }
}