use proc_macro2::{
Delimiter, Ident, Span, TokenStream,
TokenTree::{self, Punct},
};
use quote::{format_ident, quote, quote_spanned};
use std::default::Default;
use std::result::Result::{self, Err, Ok};
use std::stringify;
use std::vec::Vec;
/// One bitstruct binding declared in the `slicefields` attribute.
#[derive(Debug)]
struct BitStruct {
    /// Name of the generated getter; the generated setter is named `set_<name>`.
    name: Ident,
    /// Slices in declaration order; the first slice fills the lowest bits of the value,
    /// each following slice is placed directly above the previous one.
    slices: Vec<BitSlice>,
    /// Type returned by the getter and accepted by the setter.
    ty: Ident,
}
/// One slice of a bitstruct: bits `start..end` of the struct field named `base`.
#[derive(Debug)]
struct BitSlice {
    /// Name of the struct field the bits are read from / written to (`self.<base>`).
    base: Ident,
    /// First bit of the slice within the field (inclusive).
    start: usize,
    /// End bit of the slice within the field (exclusive).
    end: usize,
    /// Declared type of the field `base`; `None` until `parse_item` finds the field.
    ty: Option<Ident>,
}
/// Code-generation switches taken from the optional leading option group of the attribute.
///
/// Both switches default to `false`, which selects plain pointer reads and writes.
#[derive(Debug, Default)]
struct Options {
    /// Access the backing fields with volatile reads/writes.
    volatile: bool,
    /// Access the backing fields with unaligned reads/writes.
    unaligned: bool,
}
/// Asserts that `$enum` (an `Option` of some enum) holds the tuple variant
/// `$expected_variant`; the variant's payload is discarded.
///
/// # Panics
/// Panics with a message naming the expected variant and the `Debug` form of the actual
/// value when `$enum` is `None` or holds any other variant.
macro_rules! check_token {
    ($enum:expr, $expected_variant:path) => {{
        let token = $enum;
        if let Some($expected_variant(_)) = token {
            // Expected variant: nothing else to do.
        } else {
            panic!(
                "Encountered an illegal token! Expected {}, found {:?}",
                stringify!($expected_variant),
                token
            );
        }
    }};
}
/// Unwraps the next token `$enum` (an `Option<TokenTree>`) as the tuple variant
/// `$expected_variant`, evaluating to the variant's payload.
///
/// On a mismatch it makes the *enclosing function* return `Err(TokenStream)` holding a
/// `compile_error!` whose message mentions `$expected_name`:
/// * a token of another kind yields an error spanned to that token;
/// * `None` (no token left) yields an unspanned error.
///
/// The error is emitted as a bare `compile_error!(...);` rather than wrapped in a block:
/// the error stream replaces the annotated item, and a bare `{ ... }` block is not a valid
/// item, which would hide the message behind a parse error.
macro_rules! parse_token {
    ($enum:expr, $expected_variant:path, $expected_name:literal) => {
        match $enum {
            Some($expected_variant(item)) => item,
            Some(tt) => {
                let msg = format!("Expected {}, found {:?}", $expected_name, tt);
                return Err(quote::quote_spanned! {
                    tt.span()=> compile_error!(#msg);
                });
            }
            None => {
                let msg = format!("Expected {} but found delimiter", $expected_name);
                return Err(quote::quote! {
                    compile_error!(#msg);
                });
            }
        }
    };
}
pub(crate) fn parse_attr(attr: TokenStream) -> Result<(Vec<BitStruct>, Options), TokenStream> {
let mut bitstructs = Vec::<BitStruct>::new();
let mut iter = attr.into_iter();
let mut next_tt = iter.next();
let mut options = Options::default();
next_tt = match next_tt {
Some(TokenTree::Group(group)) => {
let mut group = group.stream().into_iter();
while let Some(TokenTree::Ident(option)) = group.next() {
match option.to_string().as_str() {
"volatile" => {
options.volatile = true;
}
"unaligned" => {
options.unaligned = true;
}
s @ _ => {
let fmt = format!("Expected valid option, found '{}'!", s);
return Err(quote! {
compile_error!(#fmt);
});
}
}
if let Some(TokenTree::Punct(_)) = group.next() {
continue;
}
}
iter.next()
}
None => {
return Err(quote! {
compile_error!("bitstruct can't parse empty atttribute!");
});
}
tt @ _ => tt,
};
while let Some(tt) = next_tt {
if let TokenTree::Ident(name) = tt {
let group = parse_token!(iter.next(), TokenTree::Group, "bitstruct definition");
let mut iter = group.stream().into_iter();
let ty = parse_token!(iter.next(), TokenTree::Ident, "bitstruct type");
let mut bitstruct = BitStruct {
name,
slices: Vec::new(),
ty,
};
while let Some(tt) = iter.next() {
if let TokenTree::Punct(_) = tt {
let slice_name = parse_token!(iter.next(), TokenTree::Ident, "slice name");
let mut bitslice = BitSlice {
base: slice_name,
start: 0,
end: 0,
ty: None,
};
let slice_limits =
parse_token!(iter.next(), TokenTree::Group, "range expression");
let mut iter = slice_limits.stream().into_iter();
bitslice.start =
parse_token!(iter.next(), TokenTree::Literal, "start of slice")
.to_string()
.parse()
.unwrap();
check_token!(iter.next(), TokenTree::Punct);
check_token!(iter.next(), TokenTree::Punct);
bitslice.end = parse_token!(iter.next(), TokenTree::Literal, "end of slice")
.to_string()
.parse()
.unwrap();
bitstruct.slices.push(bitslice);
}
}
bitstructs.push(bitstruct);
} else {
return Err(quote_spanned! {
tt.span()=> {
compile_error!("Expected begin of a new binding!");
}
});
}
if let Some(tt) = iter.next() {
if let TokenTree::Punct(_) = tt {
next_tt = iter.next();
continue;
}
return Err(quote_spanned! {
tt.span()=> {
compile_error!("Expected either start of a new bitslice or \")\"");
}
});
} else {
break;
}
}
Ok((bitstructs, options))
}
pub(crate) fn parse_item(
item: &TokenStream,
bitstructs: &mut Vec<BitStruct>,
) -> Result<Ident, TokenStream> {
let mut iter = item.clone().into_iter();
while let Some(tt) = iter.next() {
if let TokenTree::Ident(keyword) = tt {
if keyword.to_string() == "struct" {
break;
}
}
}
let struct_name = if let Some(tt) = iter.next() {
if let TokenTree::Ident(struct_name) = tt {
struct_name
} else {
panic!("Expected valid struct name in struct declaration!");
}
} else {
panic!("Expected struct declaration but didn't find struct keyword!");
};
while let Some(tt) = iter.next() {
let member_group = if let TokenTree::Group(group) = tt {
if group.delimiter() == Delimiter::Brace {
group
} else {
continue;
}
} else {
continue;
};
let mut tokens = member_group.stream().into_iter().peekable();
let mut last_token = None;
while let Some(ref tt) = tokens.peek() {
if let Punct(punct) = tt {
if punct.as_char() == ':' {
if let Some(TokenTree::Ident(ref member_name)) = last_token {
if let Some(TokenTree::Ident(type_name)) = tokens.nth(1) {
for bitstruct in &mut *bitstructs {
for slice in &mut bitstruct.slices {
if slice.base.to_string() == member_name.to_string() {
slice.ty = Some(type_name.clone());
}
}
}
}
}
}
}
last_token = tokens.next();
}
}
Ok(struct_name)
}
#[proc_macro_attribute]
pub fn slicefields(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let attr = TokenStream::from(attr);
let item = TokenStream::from(item);
#[cfg(feature = "debug")]
println!("attr: {:#?}", attr);
#[cfg(feature = "debug")]
println!("item: {:#?}", item);
let (mut bitstructs, options) = match parse_attr(attr) {
Ok(ret) => ret,
Err(ts) => {
return proc_macro::TokenStream::from(ts);
}
};
let struct_name = match parse_item(&item, &mut bitstructs) {
Ok(struct_name) => struct_name,
Err(ts) => {
return proc_macro::TokenStream::from(ts);
}
};
let mut function_implementations = Vec::new();
for (i1, bitstruct) in bitstructs.into_iter().enumerate() {
let bitstruct_name = bitstruct.name;
let output_ty = bitstruct.ty;
let mut bitstruct_size = 0;
let mut bitstruct_slice_get = Vec::new();
let mut bitstruct_slice_set = Vec::new();
for slice in &bitstruct.slices {
bitstruct_size += slice.end - slice.start;
}
let mut slice_current_offset = 0;
for (i2, slice) in bitstruct.slices.into_iter().enumerate() {
let output_ty = output_ty.clone();
let slice_base = slice.base;
let slice_ty = if let Some(ty) = slice.ty {
ty
} else {
panic!("Couldn't find struct member {}", slice_base);
};
let slice_start = slice.start;
let slice_end = slice.end;
let size_check_name = format_ident!("__size_check_{}_{}_{}", slice_base, i1, i2);
let slice_bitops_gen = if output_ty.to_string() == "bool" {
slice_bitops_bool
} else {
slice_bitops_naive
};
let slice_bitops = slice_bitops_gen(
slice_current_offset,
bitstruct_name.clone(),
bitstruct_size,
output_ty,
slice_base,
slice_ty.clone(),
slice_start,
slice_end,
&options,
);
let slice_bitops_get = slice_bitops.0;
let slice_bitops_set = slice_bitops.1;
let size_check_msg = format!(
"checking whether the range is {}..{} is valid and fits into a type of {}",
slice_start,
slice_end,
slice_ty.to_string()
);
bitstruct_slice_get.push(quote! {
const #size_check_name: () = {
if #slice_end < #slice_start ||
#slice_end > core::mem::size_of::<#slice_ty>() * 8 + 1
{
panic!(#size_check_msg);
}
()
};
#slice_bitops_get
});
bitstruct_slice_set.push(quote! {
#slice_bitops_set
});
slice_current_offset += slice_end - slice_start;
}
let bitstruct_set_ident = format_ident!("set_{}", bitstruct_name);
let size_check_msg = format!(
"checking whether {} can fit into a type of {}",
bitstruct_name.to_string(),
output_ty.to_string()
);
let (slice_get, slice_set) = gen_slice_ops(&options);
function_implementations.push(quote! {
pub fn #bitstruct_name(&self) -> #output_ty {
#slice_get
const __SIZE_CHECK: () = {
if core::mem::size_of::<#output_ty>() * 8 < #bitstruct_size {
panic!(#size_check_msg);
}
()
};
let mut ret = <#output_ty as core::default::Default>::default();
#(#bitstruct_slice_get)*
ret
}
pub fn #bitstruct_set_ident(&mut self, val: #output_ty) {
#slice_get
#slice_set
#(#bitstruct_slice_set)*
}
});
}
let struct_implementation = quote! {
#[allow(non_upper_case_globals)]
#[allow(dead_code)]
#[allow(unused_braces)]
#[allow(redundant_semicolon)]
#[allow(non_snake_case)]
impl #struct_name {
#(#function_implementations)*
}
};
#[cfg(not(feature = "debug"))]
return proc_macro::TokenStream::from(quote! {
#item
#struct_implementation
});
#[cfg(feature = "debug")]
{
let result = proc_macro::TokenStream::from(quote! {
#item
#struct_implementation
});
println!("\n[BITSTRUCT CODEGEN DEBUG]:\n{}\n", result.to_string());
result
}
}
/// Identifier of the read helper defined by `gen_slice_ops` inside generated accessors.
const SLICE_GET_IDENT: &str = "__BITSTRUCT_SLICE_GET";
/// Identifier of the write helper defined by `gen_slice_ops` inside generated setters.
const SLICE_SET_IDENT: &str = "__BITSTRUCT_SLICE_SET";
/// Generates the field access helpers used by every generated accessor.
///
/// Returns `(get_helper, set_helper)`. These define
/// `unsafe fn __BITSTRUCT_SLICE_GET<T: Copy>(ptr: *const T) -> T` and
/// `unsafe fn __BITSTRUCT_SLICE_SET<T: Copy>(ptr: *mut T, slice: T)`. The access mode is
/// selected by `options`:
/// * volatile only: `core::ptr::read_volatile` / `write_volatile`;
/// * unaligned only: `core::ptr::read_unaligned` / `write_unaligned`;
/// * both: `core::intrinsics::unaligned_volatile_load` / `_store`. This needs the
///   `unstable` crate feature; without it the first element is a `compile_error!` and the
///   second is empty;
/// * neither: a plain dereference.
fn gen_slice_ops(options: &Options) -> (TokenStream, TokenStream) {
    let slice_get_ident = Ident::new(SLICE_GET_IDENT, Span::call_site());
    let slice_set_ident = Ident::new(SLICE_SET_IDENT, Span::call_site());
    // Bodies of the read and write helpers for the selected access mode.
    let (read_body, write_body) = match (options.volatile, options.unaligned) {
        (true, true) => {
            #[cfg(not(feature = "unstable"))]
            return (
                quote! {
                    compile_error!("Unaligned and volatile slices only work \
                    on nightly and when using the 'unstable' crate feature");
                },
                quote! {},
            );
            #[allow(unreachable_code)]
            (
                quote! { core::intrinsics::unaligned_volatile_load(ptr) },
                quote! { core::intrinsics::unaligned_volatile_store(ptr, slice) },
            )
        }
        (true, false) => (
            quote! { core::ptr::read_volatile(ptr) },
            quote! { core::ptr::write_volatile(ptr, slice) },
        ),
        (false, true) => (
            quote! { core::ptr::read_unaligned(ptr) },
            quote! { core::ptr::write_unaligned(ptr, slice) },
        ),
        (false, false) => (quote! { *ptr }, quote! { (*ptr) = slice; }),
    };
    (
        quote! {
            #[inline(always)]
            unsafe fn #slice_get_ident<T: Copy>(ptr: *const T) -> T {
                #read_body
            }
        },
        quote! {
            #[inline(always)]
            unsafe fn #slice_set_ident<T: Copy>(ptr: *mut T, slice: T) {
                #write_body
            }
        },
    )
}
/// Generates the getter and setter statements for one slice of a non-`bool` bitstruct.
///
/// The slice occupies bits `slice_start..slice_end` of the field `self.<slice_base>` (type
/// `slice_ty`). In the bitstruct value (type `bitstruct_ty`) it occupies bits
/// `slice_current_offset..slice_current_offset + width`.
///
/// Returns `(getter, setter)`:
/// * getter statements read the field through `__BITSTRUCT_SLICE_GET`, mask the slice,
///   move it to its offset and OR it into the accessor's local `ret`;
/// * setter statements clear the slice in the field and store the matching bits of the
///   setter's `val` through `__BITSTRUCT_SLICE_SET`.
///
/// The shift is always done in the wider of the two types, before narrowing, so no bits
/// of the slice are lost by the cast.
fn slice_bitops_naive(
    slice_current_offset: usize,
    bitstruct: Ident,
    _bitstruct_size: usize,
    bitstruct_ty: Ident,
    slice_base: Ident,
    slice_ty: Ident,
    slice_start: usize,
    slice_end: usize,
    _options: &Options,
) -> (TokenStream, TokenStream) {
    /// `width` consecutive one bits shifted left by `shift`, truncated to 128 bits.
    ///
    /// Unlike parsing a binary string this cannot fail: an empty slice yields 0 and bits
    /// beyond bit 127 are dropped, leaving oversize layouts to the generated size checks.
    fn ones_mask(width: usize, shift: usize) -> u128 {
        let ones = if width >= 128 {
            u128::MAX
        } else {
            (1u128 << width) - 1
        };
        if shift >= 128 {
            0
        } else {
            ones << shift
        }
    }

    // An inverted range yields an empty mask; the generated range check reports it.
    let width = slice_end.saturating_sub(slice_start);
    let mask = ones_mask(width, slice_start);
    let bitstruct_mask = ones_mask(width, slice_current_offset);
    let mask_ident = format_ident!(
        "__BITSTRUCT_MASK_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let mask_set_ident = format_ident!(
        "__BITSTRUCT_MASK_SET_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let bitstruct_mask_ident = format_ident!(
        "__BITSTRUCT_MASK_BITSTRUCT_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let bitstruct_shift_precast_ident = format_ident!(
        "__BITSTRUCT_SHIFT_PRECAST_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let bitstruct_shift_postcast_ident = format_ident!(
        "__BITSTRUCT_SHIFT_POSTCAST_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let bitstruct_shift_cast_ident = format_ident!(
        "__BITSTRUCT_SHIFT_CAST_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let slice_get_ident = Ident::new(SLICE_GET_IDENT, Span::call_site());
    let slice_set_ident = Ident::new(SLICE_SET_IDENT, Span::call_site());
    // Direction of the move between the field position and the bitstruct position.
    let (shift_get_op, shift_set_op) = if slice_current_offset > slice_start {
        (
            quote! { << (#slice_current_offset - #slice_start) },
            quote! { >> (#slice_current_offset - #slice_start) },
        )
    } else {
        (
            quote! { >> (#slice_start - #slice_current_offset) },
            quote! { << (#slice_start - #slice_current_offset) },
        )
    };
    (
        quote! {
            const #mask_ident: #slice_ty = #mask as #slice_ty;
            const #bitstruct_shift_cast_ident: fn(#slice_ty) -> #bitstruct_ty = {
                if core::mem::size_of::<#bitstruct_ty>() < core::mem::size_of::<#slice_ty>() {
                    #bitstruct_shift_precast_ident
                } else {
                    #bitstruct_shift_postcast_ident
                }
            };
            let compl = unsafe {
                #slice_get_ident(core::ptr::addr_of!(self.#slice_base)) & #mask_ident
            };
            #[inline(always)]
            fn #bitstruct_shift_precast_ident(val: #slice_ty) -> #bitstruct_ty {
                (val #shift_get_op) as #bitstruct_ty
            }
            #[inline(always)]
            fn #bitstruct_shift_postcast_ident(val: #slice_ty) -> #bitstruct_ty {
                (val as #bitstruct_ty) #shift_get_op
            }
            ret = ret | #bitstruct_shift_cast_ident(compl);
        },
        quote! {
            const #mask_set_ident: #slice_ty = !(#mask as #slice_ty);
            const #bitstruct_mask_ident: #bitstruct_ty = #bitstruct_mask as #bitstruct_ty;
            const #bitstruct_shift_cast_ident: fn(#bitstruct_ty) -> #slice_ty = {
                if core::mem::size_of::<#slice_ty>() < core::mem::size_of::<#bitstruct_ty>() {
                    #bitstruct_shift_precast_ident
                } else {
                    #bitstruct_shift_postcast_ident
                }
            };
            #[inline(always)]
            fn #bitstruct_shift_precast_ident(val: #bitstruct_ty) -> #slice_ty {
                (val #shift_set_op) as #slice_ty
            }
            #[inline(always)]
            fn #bitstruct_shift_postcast_ident(val: #bitstruct_ty) -> #slice_ty {
                (val as #slice_ty) #shift_set_op
            }
            unsafe {
                #slice_set_ident(core::ptr::addr_of_mut!(self.#slice_base),
                    (#slice_get_ident(core::ptr::addr_of!(self.#slice_base)) & #mask_set_ident) |
                    #bitstruct_shift_cast_ident(val & #bitstruct_mask_ident)
                );
            }
        },
    )
}
/// Generates the getter and setter statements for one slice of a `bool` bitstruct.
///
/// Only bit `slice_start` of the field `self.<slice_base>` (type `slice_ty`) is used.
/// The getter statement assigns the accessor's local `ret` for the first slice (offset 0)
/// and ORs into it for later slices. The setter statement clears that bit and stores the
/// setter's `val` there. Both go through the `__BITSTRUCT_SLICE_GET` /
/// `__BITSTRUCT_SLICE_SET` helpers.
fn slice_bitops_bool(
    slice_current_offset: usize,
    bitstruct: Ident,
    _bitstruct_size: usize,
    _bitstruct_ty: Ident,
    slice_base: Ident,
    slice_ty: Ident,
    slice_start: usize,
    _slice_end: usize,
    _options: &Options,
) -> (TokenStream, TokenStream) {
    // Binary "1" followed by `slice_start` zeros: the single bit at position `slice_start`.
    let single_bit = u128::from_str_radix(&format!("1{}", "0".repeat(slice_start)), 2).unwrap();
    let read_mask = format_ident!(
        "__BITSTRUCT_MASK_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let clear_mask = format_ident!(
        "__BITSTRUCT_MASK_SET_{}_{}_{}",
        bitstruct,
        slice_base,
        slice_start
    );
    let load_fn = Ident::new(SLICE_GET_IDENT, Span::call_site());
    let store_fn = Ident::new(SLICE_SET_IDENT, Span::call_site());

    // Block expression that is `true` when the bit is set in the field.
    let bit_is_set = quote! {
        {
            const #read_mask: #slice_ty = #single_bit as #slice_ty;
            (
                (unsafe { #load_fn(core::ptr::addr_of!(self.#slice_base)) } & #read_mask )
                >> #slice_start
            ) != 0
        }
    };
    // The first slice initialises `ret`; every later slice is OR-ed into it.
    let getter = match slice_current_offset {
        0 => quote! {
            ret = #bit_is_set;
        },
        _ => quote! {
            ret = ret || #bit_is_set;
        },
    };
    // Clear the bit, then insert `val` (converted to the field type) at the same position.
    let setter = quote! {
        const #clear_mask: #slice_ty = !(#single_bit as #slice_ty);
        unsafe {
            #store_fn(core::ptr::addr_of_mut!(self.#slice_base),
                (#load_fn(core::ptr::addr_of!(self.#slice_base))
                & #clear_mask) |
                (<#slice_ty>::from(val) << #slice_start)
            );
        }
    };
    (getter, setter)
}