1use super::LimitTracker;
2use crate::cst;
3use crate::cst::CstNode;
4use crate::Error;
5use crate::SyntaxElement;
6use crate::SyntaxKind;
7use crate::SyntaxNode;
8use rowan::GreenNode;
9use rowan::GreenNodeBuilder;
10use std::fmt;
11use std::marker::PhantomData;
12use std::slice::Iter;
13
14pub(crate) enum SyntaxTreeWrapper {
52 Document(SyntaxTree<cst::Document>),
53 FieldSet(SyntaxTree<cst::SelectionSet>),
54 Type(SyntaxTree<cst::Type>),
55}
56
57#[derive(PartialEq, Eq, Clone)]
58pub struct SyntaxTree<T: CstNode = cst::Document> {
59 pub(crate) green: GreenNode,
60 pub(crate) errors: Vec<crate::Error>,
61 pub(crate) recursion_limit: LimitTracker,
62 pub(crate) token_limit: LimitTracker,
63 _phantom: PhantomData<fn() -> T>,
64}
65
66const _: () = {
67 fn assert_send<T: Send>() {}
68 fn assert_sync<T: Sync>() {}
69 let _ = assert_send::<SyntaxTree>;
70 let _ = assert_sync::<SyntaxTree>;
71};
72
73impl<T: CstNode> SyntaxTree<T> {
74 pub fn errors(&self) -> Iter<'_, crate::Error> {
76 self.errors.iter()
77 }
78
79 pub fn recursion_limit(&self) -> LimitTracker {
81 self.recursion_limit
82 }
83
84 pub fn token_limit(&self) -> LimitTracker {
86 self.token_limit
87 }
88
89 pub fn green(&self) -> GreenNode {
90 self.green.clone()
91 }
92
93 pub(crate) fn syntax_node(&self) -> SyntaxNode {
94 rowan::SyntaxNode::new_root(self.green.clone())
95 }
96}
97
98impl SyntaxTree<cst::Document> {
99 pub fn document(&self) -> cst::Document {
101 cst::Document {
102 syntax: self.syntax_node(),
103 }
104 }
105}
106
107impl SyntaxTree<cst::SelectionSet> {
108 pub fn field_set(&self) -> cst::SelectionSet {
111 cst::SelectionSet {
112 syntax: self.syntax_node(),
113 }
114 }
115}
116
117impl SyntaxTree<cst::Type> {
118 pub fn ty(&self) -> cst::Type {
121 match self.syntax_node().kind() {
122 SyntaxKind::NAMED_TYPE => cst::Type::NamedType(cst::NamedType {
123 syntax: self.syntax_node(),
124 }),
125 SyntaxKind::LIST_TYPE => cst::Type::ListType(cst::ListType {
126 syntax: self.syntax_node(),
127 }),
128 SyntaxKind::NON_NULL_TYPE => cst::Type::NonNullType(cst::NonNullType {
129 syntax: self.syntax_node(),
130 }),
131 _ => unreachable!("this should only return Type node"),
132 }
133 }
134}
135
136impl<T: CstNode> fmt::Debug for SyntaxTree<T> {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 fn print(f: &mut fmt::Formatter<'_>, indent: usize, element: SyntaxElement) -> fmt::Result {
139 let kind: SyntaxKind = element.kind();
140 write!(f, "{:indent$}", "", indent = indent)?;
141 match element {
142 rowan::NodeOrToken::Node(node) => {
143 writeln!(f, "- {:?}@{:?}", kind, node.text_range())?;
144 for child in node.children_with_tokens() {
145 print(f, indent + 4, child)?;
146 }
147 Ok(())
148 }
149
150 rowan::NodeOrToken::Token(token) => {
151 writeln!(
152 f,
153 "- {:?}@{:?} {:?}",
154 kind,
155 token.text_range(),
156 token.text()
157 )
158 }
159 }
160 }
161
162 fn print_err(f: &mut fmt::Formatter<'_>, errors: Vec<Error>) -> fmt::Result {
163 for err in errors {
164 writeln!(f, "- {err:?}")?;
165 }
166
167 write!(f, "")
168 }
169
170 fn print_recursion_limit(
171 f: &mut fmt::Formatter<'_>,
172 recursion_limit: LimitTracker,
173 ) -> fmt::Result {
174 write!(f, "{recursion_limit:?}")
175 }
176
177 print(f, 0, self.syntax_node().into())?;
178 print_err(f, self.errors.clone())?;
179 print_recursion_limit(f, self.recursion_limit)
180 }
181}
182
183#[derive(Debug)]
184pub(crate) struct SyntaxTreeBuilder {
185 builder: GreenNodeBuilder<'static>,
186}
187
188impl SyntaxTreeBuilder {
189 pub(crate) fn new() -> Self {
191 Self {
192 builder: GreenNodeBuilder::new(),
193 }
194 }
195
196 pub(crate) fn checkpoint(&self) -> rowan::Checkpoint {
197 self.builder.checkpoint()
198 }
199
200 pub(crate) fn start_node(&mut self, kind: SyntaxKind) {
202 self.builder.start_node(rowan::SyntaxKind(kind as u16));
203 }
204
205 pub(crate) fn finish_node(&mut self) {
207 self.builder.finish_node();
208 }
209
210 pub(crate) fn wrap_node(&mut self, checkpoint: rowan::Checkpoint, kind: SyntaxKind) {
211 self.builder
212 .start_node_at(checkpoint, rowan::SyntaxKind(kind as u16));
213 }
214
215 pub(crate) fn token(&mut self, kind: SyntaxKind, text: &str) {
217 self.builder.token(rowan::SyntaxKind(kind as u16), text);
218 }
219
220 pub(crate) fn finish_document(
221 self,
222 errors: Vec<Error>,
223 recursion_limit: LimitTracker,
224 token_limit: LimitTracker,
225 ) -> SyntaxTreeWrapper {
226 SyntaxTreeWrapper::Document(SyntaxTree {
227 green: self.builder.finish(),
228 errors,
230 recursion_limit,
232 token_limit,
233 _phantom: PhantomData,
234 })
235 }
236
237 pub(crate) fn finish_selection_set(
238 self,
239 errors: Vec<Error>,
240 recursion_limit: LimitTracker,
241 token_limit: LimitTracker,
242 ) -> SyntaxTreeWrapper {
243 SyntaxTreeWrapper::FieldSet(SyntaxTree {
244 green: self.builder.finish(),
245 errors,
247 recursion_limit,
249 token_limit,
250 _phantom: PhantomData,
251 })
252 }
253
254 pub(crate) fn finish_type(
255 self,
256 errors: Vec<Error>,
257 recursion_limit: LimitTracker,
258 token_limit: LimitTracker,
259 ) -> SyntaxTreeWrapper {
260 SyntaxTreeWrapper::Type(SyntaxTree {
261 green: self.builder.finish(),
262 errors,
264 recursion_limit,
266 token_limit,
267 _phantom: PhantomData,
268 })
269 }
270}
271
#[cfg(test)]
mod test {
    use crate::cst::Definition;
    use crate::Parser;

    #[test]
    fn directive_name() {
        let input = "directive @example(isTreat: Boolean, treatKind: String) on FIELD | MUTATION";
        let parser = Parser::new(input);
        let cst = parser.parse();
        assert_eq!(0, cst.errors().len());

        let doc = cst.document();

        // Collect the names instead of asserting inside the loop, so the test fails
        // if no directive definition is produced at all.
        let names: Vec<String> = doc
            .definitions()
            .filter_map(|def| match def {
                Definition::DirectiveDefinition(directive) => {
                    Some(directive.name().unwrap().text().to_string())
                }
                _ => None,
            })
            .collect();
        assert_eq!(names, ["example"]);
    }

    #[test]
    fn object_type_definition() {
        let input = "
        type ProductDimension {
          size: String
          weight: Float @tag(name: \"hi from inventory value type field\")
        }
        ";
        let parser = Parser::new(input);
        let cst = parser.parse();
        assert_eq!(0, cst.errors().len());

        let doc = cst.document();

        let mut object_types_seen = 0;
        for def in doc.definitions() {
            if let Definition::ObjectTypeDefinition(object_type) = def {
                object_types_seen += 1;
                assert_eq!(object_type.name().unwrap().text(), "ProductDimension");

                let field_names: Vec<String> = object_type
                    .fields_definition()
                    .unwrap()
                    .field_definitions()
                    .map(|field_def| field_def.name().unwrap().text().to_string())
                    .collect();
                assert_eq!(field_names, ["size", "weight"]);
            }
        }
        // Guard against the loop body never running.
        assert_eq!(object_types_seen, 1);
    }
}