use rspirv::dr::{Instruction, Module, Operand};
use rspirv::spirv::{Decoration, Op, Word};
use rustc_span::{source_map::SourceMap, FileName, Pos, Span};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use std::path::PathBuf;
use std::{iter, slice};
/// Structured data that can be attached to a SPIR-V id by smuggling it through
/// an `OpDecorateString` instruction (with the `UserTypeGOOGLE` decoration).
///
/// The string operand of that instruction is `ENCODING_PREFIX` immediately
/// followed by the JSON serialization of the implementing type.
pub trait CustomDecoration: for<'de> Deserialize<'de> + Serialize {
    /// Marker placed in front of the JSON text; `try_decode` only accepts
    /// strings that start with it.
    const ENCODING_PREFIX: &'static str;

    /// Builds the `OpDecorateString` instruction that attaches `self` to `id`.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize `self`.
    fn encode(self, id: Word) -> Instruction {
        let json = serde_json::to_string(&self).unwrap();
        let encoded = format!("{}{}", Self::ENCODING_PREFIX, json);

        let operands = vec![
            Operand::IdRef(id),
            Operand::Decoration(Decoration::UserTypeGOOGLE),
            Operand::LiteralString(encoded),
        ];
        Instruction::new(Op::DecorateString, None, None, operands)
    }

    /// Recognizes an instruction produced by `encode` for this decoration type.
    ///
    /// Returns the decorated id together with the still-unparsed JSON payload,
    /// or `None` when the instruction is not an `OpDecorateString` carrying
    /// `UserTypeGOOGLE`, or when its string lacks `ENCODING_PREFIX`.
    ///
    /// The operands are indexed directly, so an `OpDecorateString` with fewer
    /// than the expected operands makes this panic.
    fn try_decode(inst: &Instruction) -> Option<(Word, LazilyDeserialized<'_, Self>)> {
        // Only look at the operands once the opcode is known to match.
        if inst.class.opcode != Op::DecorateString {
            return None;
        }
        if inst.operands[1].unwrap_decoration() != Decoration::UserTypeGOOGLE {
            return None;
        }

        let id = inst.operands[0].unwrap_id_ref();
        let encoded = inst.operands[2].unwrap_literal_string();
        encoded.strip_prefix(Self::ENCODING_PREFIX).map(|json| {
            let payload = LazilyDeserialized {
                json,
                _marker: PhantomData,
            };
            (id, payload)
        })
    }

    /// Lazily walks `module.annotations`, yielding every instruction that
    /// `try_decode` accepts (as `(id, payload)` pairs).
    fn decode_all(module: &Module) -> DecodeAllIter<'_, Self> {
        let annotations = module.annotations.iter();
        // The cast turns the fn item into the plain fn pointer that the
        // `DecodeAllIter` alias names.
        annotations.filter_map(Self::try_decode as fn(_) -> _)
    }

    /// Drops from `module.annotations` every instruction that `try_decode`
    /// accepts, keeping all the others.
    fn remove_all(module: &mut Module) {
        let is_unrelated = |inst: &Instruction| Self::try_decode(inst).is_none();
        module.annotations.retain(is_unrelated);
    }
}
/// Concrete iterator type returned by `CustomDecoration::decode_all`: a
/// `filter_map` over a slice of instructions, using a plain function pointer
/// (rather than a closure) so that the type can be named here.
type DecodeAllIter<'a, D> = iter::FilterMap<
slice::Iter<'a, Instruction>,
fn(&'a Instruction) -> Option<(Word, LazilyDeserialized<'a, D>)>,
>;
/// A borrowed, not-yet-parsed JSON payload for a custom decoration of type `D`.
///
/// Holding on to the raw text lets callers skip the cost of deserialization
/// unless they actually need the decoded value.
pub struct LazilyDeserialized<'a, D> {
    /// Raw JSON text, borrowed from the instruction it was found in.
    json: &'a str,
    /// Records the target type without storing a `D`.
    _marker: PhantomData<D>,
}

// `Clone`/`Copy` are written by hand on purpose: `#[derive(Copy, Clone)]` would
// add `D: Copy` / `D: Clone` bounds, even though no `D` is stored here. That
// would make this handle non-copyable for any decoration type that is not
// itself `Copy`. Both fields (`&str` and `PhantomData<D>`) are `Copy` for
// every `D`, so no bound is needed.
impl<D> Clone for LazilyDeserialized<'_, D> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<D> Copy for LazilyDeserialized<'_, D> {}
impl<'a, D> LazilyDeserialized<'a, D>
where
    D: Deserialize<'a>,
{
    /// Parses the stored JSON text into a `D`.
    ///
    /// # Panics
    ///
    /// Panics if the text is not valid JSON for `D`.
    pub fn deserialize(self) -> D {
        let Self { json, .. } = self;
        let parsed = serde_json::from_str::<D>(json);
        parsed.unwrap()
    }
}
/// Custom decoration carrying a textual `reason` and, optionally, the source
/// location it relates to.
// NOTE(review): presumably marks an id as a "zombie" (something that is only an
// error if it ends up being used) - confirm against the code consuming this.
#[derive(Deserialize, Serialize)]
pub struct ZombieDecoration {
/// Free-form explanation attached to the decorated id.
pub reason: String,
/// Optional source location. Because of `flatten`, the span's fields are
/// (de)serialized inline next to `reason` rather than under a `span` key.
#[serde(flatten)]
pub span: Option<SerializedSpan>,
}
// Uses the default `encode`/`try_decode`/`decode_all`/`remove_all` provided by
// `CustomDecoration`; only the prefix is specific to this decoration.
impl CustomDecoration for ZombieDecoration {
// Encoded strings for this decoration start with "Z", followed by the JSON.
const ENCODING_PREFIX: &'static str = "Z";
}
/// A source span in a form that can be (de)serialized: a file path, a hash of
/// that file's contents, and a pair of offsets into the file.
///
/// Built from a rustc `Span` by `SerializedSpan::from_rustc` and turned back
/// into one by `SerializedSpan::to_rustc`.
#[derive(Deserialize, Serialize)]
pub struct SerializedSpan {
// Local path of the source file containing the span.
file: PathBuf,
// Hash of the file's contents at serialization time; `to_rustc` refuses to
// produce a span if the reloaded file's hash differs.
hash: serde_adapters::SourceFileHash,
// Start and end of the span, as offsets relative to the file's `start_pos`
// (not absolute positions within the `SourceMap`).
lo: u32,
hi: u32,
}
mod serde_adapters {
use serde::{Deserialize, Serialize};
#[derive(Copy, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub enum SourceFileHashAlgorithm {
Md5,
Sha1,
Sha256,
}
impl From<rustc_span::SourceFileHashAlgorithm> for SourceFileHashAlgorithm {
fn from(kind: rustc_span::SourceFileHashAlgorithm) -> Self {
match kind {
rustc_span::SourceFileHashAlgorithm::Md5 => Self::Md5,
rustc_span::SourceFileHashAlgorithm::Sha1 => Self::Sha1,
rustc_span::SourceFileHashAlgorithm::Sha256 => Self::Sha256,
}
}
}
#[derive(Copy, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct SourceFileHash {
kind: SourceFileHashAlgorithm,
value: [u8; 32],
}
impl From<rustc_span::SourceFileHash> for SourceFileHash {
fn from(hash: rustc_span::SourceFileHash) -> Self {
let bytes = hash.hash_bytes();
let mut hash = Self {
kind: hash.kind.into(),
value: Default::default(),
};
hash.value[..bytes.len()].copy_from_slice(bytes);
hash
}
}
}
impl SerializedSpan {
    /// Converts a rustc `Span` into its serializable form.
    ///
    /// Returns `None` when the span is a dummy, when its bounds are inverted,
    /// when it does not lie entirely within the file that contains its start,
    /// or when that file is not a real file with a local path.
    pub fn from_rustc(span: Span, source_map: &SourceMap) -> Option<Self> {
        if span.is_dummy() {
            return None;
        }

        let (start, end) = (span.lo(), span.hi());
        if end < start {
            return None;
        }

        // The file is looked up from the start position only, so the end has
        // to be checked against the same file's bounds explicitly.
        let source_file = source_map.lookup_source_file(start);
        if start < source_file.start_pos || source_file.end_pos < end {
            return None;
        }

        let path = if let FileName::Real(real_name) = &source_file.name {
            real_name.local_path()?.to_path_buf()
        } else {
            return None;
        };

        // Offsets are stored relative to the start of the file.
        let rel_lo = (start - source_file.start_pos).to_u32();
        let rel_hi = (end - source_file.start_pos).to_u32();

        Some(Self {
            file: path,
            hash: source_file.src_hash.into(),
            lo: rel_lo,
            hi: rel_hi,
        })
    }

    /// Converts back into a rustc `Span` (with a root syntax context) by
    /// loading the recorded file into `source_map`.
    ///
    /// Returns `None` when the file cannot be loaded, or when its hash no
    /// longer matches the one recorded at serialization time.
    ///
    /// # Panics
    ///
    /// Panics if the recorded offsets are inverted or extend past the end of
    /// the (hash-matching) file.
    pub fn to_rustc(&self, source_map: &SourceMap) -> Option<Span> {
        let file = match source_map.load_file(&self.file) {
            Ok(loaded) => loaded,
            Err(_) => return None,
        };

        let unchanged = self.hash == file.src_hash.into();
        if !unchanged {
            return None;
        }

        assert!(self.lo <= self.hi && self.hi <= (file.end_pos.0 - file.start_pos.0));

        // Rebase the file-relative offsets onto where the file now lives in
        // this `SourceMap`.
        let abs_lo = file.start_pos + Pos::from_u32(self.lo);
        let abs_hi = file.start_pos + Pos::from_u32(self.hi);
        Some(Span::with_root_ctxt(abs_lo, abs_hi))
    }
}