use crate::ty::{Attribute, TokenIter};
use crate::utils::spanned_error;
use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Span, TokenStream, TokenTree};
/// Extension methods for [`TokenStream`].
pub trait TokenStreamExt {
    /// Consumes the stream and returns it as a [`TokenIter`] for token-by-token parsing.
    fn into_token_iter(self) -> TokenIter;
}
/// Parsing helpers layered on top of any iterator of [`TokenTree`]s.
///
/// Every fallible method reports failure as `Err(TokenStream)`. That stream is
/// presumably a ready-to-emit compile error, like the ones `spanned_error` builds
/// for the `LiteralExt` impl in this module.
/// NOTE(review): the implementation is not part of this section. The per-method
/// summaries below are inferred from names and signatures; confirm them against
/// the impl.
pub trait TokenIterExt: Iterator<Item = TokenTree> {
    /// Parses attributes at the current position into a list of [`Attribute`]s.
    fn parse_attributes(&mut self) -> Result<Vec<Attribute>, TokenStream>;

    /// Parses a visibility qualifier; nothing is returned on success.
    fn parse_visibility(&mut self) -> Result<(), TokenStream>;

    /// Parses a path, returning its text together with an associated [`Span`].
    fn parse_path(&mut self) -> Result<(String, Span), TokenStream>;

    /// Expects a [`Group`] delimited by `expect` and returns a [`TokenIter`]
    /// (presumably over the group's inner tokens).
    fn expect_group(&mut self, expect: Delimiter) -> Result<TokenIter, TokenStream>;

    /// Expects an identifier matching `expect`.
    fn expect_ident(&mut self, expect: &str) -> Result<(), TokenStream>;

    /// Expects a punctuation character matching `expect`.
    fn expect_punct(&mut self, expect: char) -> Result<(), TokenStream>;

    /// Attempts to take a [`Group`] token.
    fn try_group(&mut self) -> Result<Group, TokenStream>;

    /// Attempts to take an [`Ident`] token.
    fn try_ident(&mut self) -> Result<Ident, TokenStream>;

    /// Attempts to take a [`Literal`] token.
    fn try_lit(&mut self) -> Result<Literal, TokenStream>;

    /// Attempts to take a [`Punct`] token.
    fn try_punct(&mut self) -> Result<Punct, TokenStream>;
}
/// Extension trait for obtaining a [`Span`] to attach diagnostics to.
pub trait TokenTreeExt {
    /// Returns the span associated with `self`.
    fn as_span(&self) -> Span;
}
/// Conversions from a [`Literal`] token into plain Rust values.
pub trait LiteralExt {
    /// Interprets the literal as a `char` literal, or returns an error token stream.
    fn as_char(&self) -> Result<char, TokenStream>;

    /// Interprets the literal as a string literal and returns its contents, or
    /// returns an error token stream.
    fn as_string(&self) -> Result<String, TokenStream>;
}
impl TokenStreamExt for TokenStream {
    /// Takes the stream's owning token iterator and wraps it in a
    /// [`Peekable`](std::iter::Peekable) adapter so one token of lookahead is available.
    fn into_token_iter(self) -> TokenIter {
        let tokens = IntoIterator::into_iter(self);
        Iterator::peekable(tokens)
    }
}
impl TokenTreeExt for Option<TokenTree> {
    /// Returns the span of the contained token tree, or [`Span::call_site`] when
    /// there is no token (`None`).
    ///
    /// Delegates to [`TokenTree::span`], which already dispatches over every
    /// variant (group, ident, punct, literal), instead of re-matching each variant
    /// by hand.
    fn as_span(&self) -> Span {
        self.as_ref().map_or_else(Span::call_site, TokenTree::span)
    }
}
impl LiteralExt for Literal {
fn as_char(&self) -> Result<char, TokenStream> {
let string = self.to_string();
if !string.starts_with('\'') || !string.ends_with('\'') {
return Err(spanned_error("Expected char literal", self.span()));
}
string
.chars()
.nth(1)
.ok_or_else(|| spanned_error("Expected char literal", self.span()))
}
fn as_string(&self) -> Result<String, TokenStream> {
let string = self.to_string();
if !string.starts_with('"') || !string.ends_with('"') {
return Err(spanned_error("Expected string literal", self.span()));
}
Ok(string[1..string.len() - 1]
.replace(r#"\""#, r#"""#)
.replace(r"\n", "\n")
.replace(r"\r", "\r")
.replace(r"\t", "\t")
.replace(r"\'", "'")
.replace(r"\\", r"\"))
}
}