use crate::rust_macro_graphql_token_source::RustMacroGraphQLTokenSource;
use libgraphql_parser::token::GraphQLToken;
use libgraphql_parser::token::GraphQLTokenKind;
use libgraphql_parser::token::GraphQLTriviaToken;
use proc_macro2::TokenStream;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
pub(crate) fn token_kinds_match(
a: &GraphQLTokenKind<'static>,
b: &GraphQLTokenKind<'_>,
) -> bool {
use GraphQLTokenKind::*;
match (a, b) {
(Ampersand, Ampersand)
| (At, At)
| (Bang, Bang)
| (Colon, Colon)
| (CurlyBraceClose, CurlyBraceClose)
| (CurlyBraceOpen, CurlyBraceOpen)
| (Dollar, Dollar)
| (Ellipsis, Ellipsis)
| (Equals, Equals)
| (ParenClose, ParenClose)
| (ParenOpen, ParenOpen)
| (Pipe, Pipe)
| (SquareBracketClose, SquareBracketClose)
| (SquareBracketOpen, SquareBracketOpen) => true,
(True, True) | (False, False) | (Null, Null) => true,
(Eof, Eof) => true,
(Name(a), Name(b)) => a == b,
(IntValue(a), IntValue(b)) => a == b,
(FloatValue(a), FloatValue(b)) => a == b,
(StringValue(a), StringValue(b)) => a == b,
(Error(e1), Error(e2)) => {
e1.message == e2.message
&& e1.error_notes.len() == e2.error_notes.len()
&& e1.error_notes.iter().zip(e2.error_notes.iter()).all(|(a, b)| {
a.kind == b.kind && a.message == b.message
})
},
_ => false,
}
}
pub(crate) fn trivia_kinds_match(
a: &GraphQLTriviaToken<'static>,
b: &GraphQLTriviaToken<'_>,
) -> bool {
match (a, b) {
(
GraphQLTriviaToken::Comma { .. },
GraphQLTriviaToken::Comma { .. },
) => true,
(
GraphQLTriviaToken::Comment { value: va, .. },
GraphQLTriviaToken::Comment { value: vb, .. },
) => va == vb,
_ => false,
}
}
/// Lexes `input` through the Rust-token path.
///
/// `input` is first parsed into a `proc_macro2::TokenStream`, and then a
/// `RustMacroGraphQLTokenSource` over that stream is drained into a `Vec`.
/// Each call uses its own fresh, empty span map.
///
/// # Panics
///
/// Panics if `input` cannot be parsed as Rust tokens.
pub(crate) fn tokenize_via_rust(
    input: &str,
) -> Vec<GraphQLToken<'static>> {
    let span_map = Rc::new(RefCell::new(HashMap::new()));
    let rust_tokens = TokenStream::from_str(input)
        .expect("Failed to parse as Rust tokens");
    RustMacroGraphQLTokenSource::new(rust_tokens, span_map).collect()
}
/// Lexes `input` with the string-based lexer (`StrGraphQLTokenSource`).
///
/// Whitespace is not retained as trivia (`retain_whitespace: false`). Every
/// other configuration option keeps its `Default` value. The returned
/// tokens are tied to the lifetime of `input`.
pub(crate) fn tokenize_via_str(
    input: &str,
) -> Vec<GraphQLToken<'_>> {
    use libgraphql_parser::token::StrGraphQLTokenSource;
    use libgraphql_parser::token::StrGraphQLTokenSourceConfig;

    let lexer = StrGraphQLTokenSource::with_config(
        input,
        StrGraphQLTokenSourceConfig {
            retain_whitespace: false,
            ..Default::default()
        },
    );
    lexer.collect()
}
pub(crate) fn assert_parity(input: &str) {
let rust_tokens = tokenize_via_rust(input);
let str_tokens = tokenize_via_str(input);
assert_eq!(
rust_tokens.len(),
str_tokens.len(),
"Token count mismatch for input: {input:?}\n\
Rust tokens: {rust_kinds:?}\n\
Str tokens: {str_kinds:?}",
rust_kinds = rust_tokens
.iter()
.map(|t| format!("{:?}", t.kind))
.collect::<Vec<_>>(),
str_kinds = str_tokens
.iter()
.map(|t| format!("{:?}", t.kind))
.collect::<Vec<_>>(),
);
for (i, (rt, st)) in rust_tokens.iter().zip(str_tokens.iter()).enumerate() {
assert!(
token_kinds_match(&rt.kind, &st.kind),
"Token kind mismatch at position {i} for input: \
{input:?}\n Rust: {:?}\n Str: {:?}",
rt.kind,
st.kind,
);
assert_eq!(
rt.preceding_trivia.len(),
st.preceding_trivia.len(),
"Trivia count mismatch at position {i} for input: \
{input:?}\n Rust trivia: {:?}\n Str trivia: {:?}",
rt.preceding_trivia,
st.preceding_trivia,
);
for (j, (rtv, stv)) in rt
.preceding_trivia
.iter()
.zip(st.preceding_trivia.iter())
.enumerate()
{
assert!(
trivia_kinds_match(rtv, stv),
"Trivia mismatch at position {i}, trivia {j} \
for input: {input:?}\n Rust: {rtv:?}\n \
Str: {stv:?}",
);
}
}
}