use std::cell::RefCell;
use proc_macro2::{TokenStream, TokenTree};
use serde::{de::Error, de::Visitor, Deserialize};
/// A [`TokenStream`] captured verbatim from a `serde_tokenstream` input.
///
/// Deserializing this type ignores the (empty) bytes offered by the
/// deserializer and instead takes the tokens from a thread-local side
/// channel filled by `set_wrapper_tokens` elsewhere in this crate.
#[derive(Debug)]
pub struct TokenStreamWrapper(TokenStream);

impl TokenStreamWrapper {
    /// Consumes the wrapper, yielding the captured [`TokenStream`].
    pub fn into_inner(self) -> TokenStream {
        let Self(stream) = self;
        stream
    }
}
impl<'de> Deserialize<'de> for TokenStreamWrapper {
    /// Requests bytes from the deserializer; `WrapperVisitor` answers with the
    /// tokens stashed in the thread-local side channel, which are then wrapped.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_bytes(WrapperVisitor).map(Self)
    }
}
impl std::ops::Deref for TokenStreamWrapper {
type Target = TokenStream;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// A value of type `P` parsed with `syn` from a `serde_tokenstream` input.
///
/// The `syn::parse::Parse` requirement is placed on the impls that actually
/// need it (deserialization) rather than on the type itself. That keeps the
/// wrapper and its plain accessors usable in generic code without dragging
/// the bound along.
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct ParseWrapper<P>(P);

impl<P> ParseWrapper<P> {
    /// Consumes the wrapper, yielding the parsed value.
    pub fn into_inner(self) -> P {
        self.0
    }
}
impl<'de, P: syn::parse::Parse> Deserialize<'de> for ParseWrapper<P> {
    /// Fetches the side-channel tokens via `WrapperVisitor` and parses them as `P`.
    ///
    /// On a parse failure, the full `syn::Error` (not just its text) is stashed
    /// with `set_parse_error` so it can be recovered later via
    /// `take_parse_error`. The serde error returned here carries only the
    /// message string.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let tokens = deserializer.deserialize_bytes(WrapperVisitor)?;
        syn::parse2::<P>(tokens).map(Self).map_err(|parse_err| {
            let text = parse_err.to_string();
            set_parse_error(parse_err);
            D::Error::custom(text)
        })
    }
}
/// Lets a `ParseWrapper<P>` be used wherever a `&P` is expected.
impl<P: syn::parse::Parse> std::ops::Deref for ParseWrapper<P> {
    type Target = P;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
/// Serde visitor that ignores the bytes it is handed and returns the
/// [`TokenStream`] stored in the thread-local side channel.
struct WrapperVisitor;

impl<'de> Visitor<'de> for WrapperVisitor {
    type Value = TokenStream;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("TokenStream")
    }

    /// Returns the side-channel tokens.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is non-empty, or if no tokens are stored for this
    /// thread (via `take_wrapper_tokens`).
    fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        // An empty slice is the expected marker. Real payload bytes mean some
        // other deserializer is driving this visitor, so the side channel
        // cannot be trusted. Each line of the message ends in `\` so the panic
        // text is a single line with no embedded newlines or indentation.
        assert!(
            bytes.is_empty(),
            "visit_bytes should always be called with an empty slice \
             (a side channel is used to pass the actual TokenStream; \
             was TokenStreamWrapper or ParseWrapper used outside of a \
             serde_tokenstream context?)"
        );
        Ok(take_wrapper_tokens())
    }
}
thread_local! {
static WRAPPER_TOKENS: RefCell<Option<TokenStream>> = Default::default();
static PARSE_ERROR: RefCell<Option<syn::Error>> = Default::default();
}
pub(crate) fn set_wrapper_tokens(tokens: Vec<TokenTree>) {
WRAPPER_TOKENS.with(|cell| {
let mut cell = cell.borrow_mut();
assert!(cell.is_none(), "set_wrapper_tokens requires TLS to be unset");
*cell = Some(tokens.into_iter().collect());
});
}
/// Removes and returns the tokens previously stored by `set_wrapper_tokens`,
/// leaving the slot empty.
///
/// # Panics
///
/// Panics if no tokens are stored for the current thread.
fn take_wrapper_tokens() -> TokenStream {
    WRAPPER_TOKENS.with(|cell| {
        // Each message line ends in `\` so the panic text stays on one line
        // instead of embedding a newline and the next line's indentation.
        cell.borrow_mut().take().expect(
            "take_wrapper_tokens requires TLS to be set \
             (was TokenStreamWrapper or ParseWrapper used \
             outside of a serde_tokenstream context?)",
        )
    })
}
/// Records `err` as this thread's latest parse failure.
/// Any earlier error that was never taken is discarded.
fn set_parse_error(err: syn::Error) {
    PARSE_ERROR.with(|slot| {
        let _ = slot.replace(Some(err));
    });
}
/// Removes and returns the parse error recorded on this thread, if any.
pub(crate) fn take_parse_error() -> Option<syn::Error> {
    PARSE_ERROR.with(|slot| slot.replace(None))
}