1use std::sync::{Arc, RwLock, RwLockReadGuard};
25
26use derive_more::{From, IsVariant, Unwrap};
27
28use crate::span::Spanned;
29
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
34#[derive(Default, Clone, Debug, PartialEq)]
35pub struct TranslationUnit {
36 #[cfg(feature = "imports")]
37 pub imports: Vec<Import>,
38 pub global_directives: Vec<GlobalDirective>,
39 pub global_declarations: Vec<GlobalDeclaration>,
40}
41
/// A shared, renameable identifier.
///
/// Every clone of an `Ident` points at the same name storage, so
/// [`Ident::rename`] on any clone is observed by all of them.
///
/// Equality and hashing are by *identity*, not by text: two `Ident`s are equal
/// only when they share the same allocation. Two idents created separately
/// with the same name compare unequal.
///
/// NOTE(review): with the `serde` feature, serde serializes the inner value of
/// an `Arc` once per occurrence, so sharing between clones is presumably not
/// preserved across a serialize/deserialize round-trip. Confirm if that matters.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct Ident(Arc<RwLock<String>>);

impl Ident {
    /// Creates an identifier with its own, unshared name storage.
    pub fn new(name: String) -> Ident {
        Self(Arc::new(RwLock::new(name)))
    }

    /// Returns a read guard over the current name.
    ///
    /// Avoid holding the guard while calling [`Ident::rename`] on any clone from
    /// the same thread. `std` leaves re-locking a held `RwLock` unspecified
    /// (it may deadlock or panic).
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub fn name(&self) -> RwLockReadGuard<'_, String> {
        let lock: &RwLock<String> = &self.0;
        lock.read().unwrap()
    }

    /// Replaces the name seen by this ident and every clone sharing it.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub fn rename(&mut self, name: String) {
        let mut slot = self.0.write().unwrap();
        *slot = name;
    }

    /// Number of `Ident` handles currently sharing this name storage.
    pub fn use_count(&self) -> usize {
        Arc::strong_count(&self.0)
    }
}

impl PartialEq for Ident {
    /// Identity comparison: true when both handles share one allocation.
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(Arc::as_ptr(&self.0), Arc::as_ptr(&other.0))
    }
}

impl Eq for Ident {}

impl std::hash::Hash for Ident {
    /// Hashes the shared allocation's address, matching the identity-based `eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::ptr::hash(Arc::as_ptr(&self.0), state)
    }
}
87
/// An import statement: a path plus what it brings into scope.
/// Only exists with the `imports` feature.
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct Import {
    /// Attributes on the import; only with the `attributes` feature.
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
    /// The path being imported from.
    pub path: std::path::PathBuf,
    /// Either a single item or a group of nested imports.
    pub content: ImportContent,
}

/// Payload of an [`Import`].
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, IsVariant, Unwrap)]
pub enum ImportContent {
    /// A single imported item.
    Item(ImportItem),
    /// A group of nested imports.
    Collection(Vec<Import>),
}

/// One imported item, optionally bound under a different name.
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct ImportItem {
    /// Name of the item being imported.
    pub ident: Ident,
    /// Replacement name for the item, if one was given.
    pub rename: Option<Ident>,
}
113
114#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
115#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
116pub enum GlobalDirective {
117 Diagnostic(DiagnosticDirective),
118 Enable(EnableDirective),
119 Requires(RequiresDirective),
120}
121
122#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
123#[derive(Clone, Debug, PartialEq)]
124pub struct DiagnosticDirective {
125 #[cfg(feature = "attributes")]
126 pub attributes: Attributes,
127 pub severity: DiagnosticSeverity,
128 pub rule_name: String,
129}
130
131#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
132#[derive(Clone, Debug, PartialEq, Eq, IsVariant)]
133pub enum DiagnosticSeverity {
134 Error,
135 Warning,
136 Info,
137 Off,
138}
139
/// An `enable` directive listing extension names.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct EnableDirective {
    /// Attributes on the directive; only with the `attributes` feature.
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
    /// Extension names, kept as plain strings in list order.
    pub extensions: Vec<String>,
}

/// A `requires` directive listing required names.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct RequiresDirective {
    /// Attributes on the directive; only with the `attributes` feature.
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
    /// Required names, kept as plain strings in list order.
    pub extensions: Vec<String>,
}
155
156#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
157#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
158pub enum GlobalDeclaration {
159 Void,
160 Declaration(Declaration),
161 TypeAlias(TypeAlias),
162 Struct(Struct),
163 Function(Function),
164 ConstAssert(ConstAssert),
165}
166
167#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
168#[derive(Clone, Debug, PartialEq)]
169pub struct Declaration {
170 pub attributes: Attributes,
171 pub kind: DeclarationKind,
172 pub ident: Ident,
173 pub ty: Option<TypeExpression>,
174 pub initializer: Option<ExpressionNode>,
175}
176
177#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
178#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
179pub enum DeclarationKind {
180 Const,
181 Override,
182 Let,
183 Var(Option<AddressSpace>), }
185
186#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
187#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
188pub enum AddressSpace {
189 Function,
190 Private,
191 Workgroup,
192 Uniform,
193 Storage(Option<AccessMode>),
194 Handle, }
196
/// Access mode, used by `AddressSpace::Storage`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AccessMode {
    /// Read-only access.
    Read,
    /// Write-only access.
    Write,
    /// Read and write access.
    ReadWrite,
}
204
205#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
206#[derive(Clone, Debug, PartialEq)]
207pub struct TypeAlias {
208 #[cfg(feature = "attributes")]
209 pub attributes: Attributes,
210 pub ident: Ident,
211 pub ty: TypeExpression,
212}
213
214#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
215#[derive(Clone, Debug, PartialEq)]
216pub struct Struct {
217 #[cfg(feature = "attributes")]
218 pub attributes: Attributes,
219 pub ident: Ident,
220 pub members: Vec<StructMember>,
221}
222
223#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
224#[derive(Clone, Debug, PartialEq)]
225pub struct StructMember {
226 pub attributes: Attributes,
227 pub ident: Ident,
228 pub ty: TypeExpression,
229}
230
231#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
232#[derive(Clone, Debug, PartialEq)]
233pub struct Function {
234 pub attributes: Attributes,
235 pub ident: Ident,
236 pub parameters: Vec<FormalParameter>,
237 pub return_attributes: Attributes,
238 pub return_type: Option<TypeExpression>,
239 pub body: CompoundStatement,
240}
241
242#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
243#[derive(Clone, Debug, PartialEq)]
244pub struct FormalParameter {
245 pub attributes: Attributes,
246 pub ident: Ident,
247 pub ty: TypeExpression,
248}
249
250#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
251#[derive(Clone, Debug, PartialEq)]
252pub struct ConstAssert {
253 #[cfg(feature = "attributes")]
254 pub attributes: Attributes,
255 pub expression: ExpressionNode,
256}
257
258#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
259#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
260pub enum BuiltinValue {
261 VertexIndex,
262 InstanceIndex,
263 Position,
264 FrontFacing,
265 FragDepth,
266 SampleIndex,
267 SampleMask,
268 LocalInvocationId,
269 LocalInvocationIndex,
270 GlobalInvocationId,
271 WorkgroupId,
272 NumWorkgroups,
273}
274
275#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
276#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
277pub enum InterpolationType {
278 Perspective,
279 Linear,
280 Flat,
281}
282
283#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
284#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
285pub enum InterpolationSampling {
286 Center,
287 Centroid,
288 Sample,
289 First,
290 Either,
291}
292
293#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
294#[derive(Clone, Debug, PartialEq)]
295pub struct DiagnosticAttribute {
296 pub severity: DiagnosticSeverity,
297 pub rule: String,
298}
299
300#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
301#[derive(Clone, Debug, PartialEq)]
302pub struct InterpolateAttribute {
303 pub ty: InterpolationType,
304 pub sampling: Option<InterpolationSampling>,
305}
306
307#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
308#[derive(Clone, Debug, PartialEq)]
309pub struct WorkgroupSizeAttribute {
310 pub x: ExpressionNode,
311 pub y: Option<ExpressionNode>,
312 pub z: Option<ExpressionNode>,
313}
314
315#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
316#[derive(Clone, Debug, PartialEq)]
317pub struct CustomAttribute {
318 pub name: String,
319 pub arguments: Option<Vec<ExpressionNode>>,
320}
321
322#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
323#[derive(Clone, Debug, PartialEq, IsVariant, Unwrap)]
324pub enum Attribute {
325 Align(ExpressionNode),
326 Binding(ExpressionNode),
327 BlendSrc(ExpressionNode),
328 Builtin(BuiltinValue),
329 Const,
330 Diagnostic(DiagnosticAttribute),
331 Group(ExpressionNode),
332 Id(ExpressionNode),
333 Interpolate(InterpolateAttribute),
334 Invariant,
335 Location(ExpressionNode),
336 MustUse,
337 Size(ExpressionNode),
338 WorkgroupSize(WorkgroupSizeAttribute),
339 Vertex,
340 Fragment,
341 Compute,
342 #[cfg(feature = "condcomp")]
343 If(ExpressionNode),
344 #[cfg(feature = "generics")]
345 Type(TypeConstraint),
346 Custom(CustomAttribute),
347}
348
349#[cfg(feature = "generics")]
350#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
351#[derive(Clone, Debug, PartialEq, From)]
352pub struct TypeConstraint {
353 pub ident: Ident,
354 pub variants: Vec<TypeExpression>,
355}
356
357pub type Attributes = Vec<Attribute>;
358
359#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
360#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
361pub enum Expression {
362 Literal(LiteralExpression),
363 Parenthesized(ParenthesizedExpression),
364 NamedComponent(NamedComponentExpression),
365 Indexing(IndexingExpression),
366 Unary(UnaryExpression),
367 Binary(BinaryExpression),
368 FunctionCall(FunctionCallExpression),
369 TypeOrIdentifier(TypeExpression),
370}
371
372pub type ExpressionNode = Spanned<Expression>;
373
374#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
375#[derive(Clone, Copy, Debug, PartialEq, From, IsVariant, Unwrap)]
376pub enum LiteralExpression {
377 Bool(bool),
378 AbstractInt(i64),
379 AbstractFloat(f64),
380 I32(i32),
381 U32(u32),
382 F32(f32),
383 #[from(skip)]
384 F16(f32),
385}
386
387#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
388#[derive(Clone, Debug, PartialEq)]
389pub struct ParenthesizedExpression {
390 pub expression: ExpressionNode,
391}
392
393#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
394#[derive(Clone, Debug, PartialEq)]
395pub struct NamedComponentExpression {
396 pub base: ExpressionNode,
397 pub component: Ident,
398}
399
400#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
401#[derive(Clone, Debug, PartialEq)]
402pub struct IndexingExpression {
403 pub base: ExpressionNode,
404 pub index: ExpressionNode,
405}
406
407#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
408#[derive(Clone, Debug, PartialEq)]
409pub struct UnaryExpression {
410 pub operator: UnaryOperator,
411 pub operand: ExpressionNode,
412}
413
414#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
415#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
416pub enum UnaryOperator {
417 LogicalNegation,
418 Negation,
419 BitwiseComplement,
420 AddressOf,
421 Indirection,
422}
423
424#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
425#[derive(Clone, Debug, PartialEq)]
426pub struct BinaryExpression {
427 pub operator: BinaryOperator,
428 pub left: ExpressionNode,
429 pub right: ExpressionNode,
430}
431
432#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
433#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
434pub enum BinaryOperator {
435 ShortCircuitOr,
436 ShortCircuitAnd,
437 Addition,
438 Subtraction,
439 Multiplication,
440 Division,
441 Remainder,
442 Equality,
443 Inequality,
444 LessThan,
445 LessThanEqual,
446 GreaterThan,
447 GreaterThanEqual,
448 BitwiseOr,
449 BitwiseAnd,
450 BitwiseXor,
451 ShiftLeft,
452 ShiftRight,
453}
454
455#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
456#[derive(Clone, Debug, PartialEq)]
457pub struct FunctionCall {
458 pub ty: TypeExpression,
459 pub arguments: Vec<ExpressionNode>,
460}
461
462pub type FunctionCallExpression = FunctionCall;
463
464#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
465#[derive(Clone, Debug, PartialEq)]
466pub struct TypeExpression {
467 #[cfg(feature = "imports")]
468 pub path: Option<std::path::PathBuf>,
469 pub ident: Ident,
470 pub template_args: TemplateArgs,
471}
472
473#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
474#[derive(Clone, Debug, PartialEq)]
475pub struct TemplateArg {
476 pub expression: ExpressionNode,
477}
478pub type TemplateArgs = Option<Vec<TemplateArg>>;
479
480#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
481#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
482pub enum Statement {
483 Void,
484 Compound(CompoundStatement),
485 Assignment(AssignmentStatement),
486 Increment(IncrementStatement),
487 Decrement(DecrementStatement),
488 If(IfStatement),
489 Switch(SwitchStatement),
490 Loop(LoopStatement),
491 For(ForStatement),
492 While(WhileStatement),
493 Break(BreakStatement),
494 Continue(ContinueStatement),
495 Return(ReturnStatement),
496 Discard(DiscardStatement),
497 FunctionCall(FunctionCallStatement),
498 ConstAssert(ConstAssertStatement),
499 Declaration(DeclarationStatement),
500}
501
502pub type StatementNode = Spanned<Statement>;
503
504#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
505#[derive(Clone, Debug, PartialEq)]
506pub struct CompoundStatement {
507 pub attributes: Attributes,
508 pub statements: Vec<StatementNode>,
509}
510
511#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
512#[derive(Clone, Debug, PartialEq)]
513pub struct AssignmentStatement {
514 #[cfg(feature = "attributes")]
515 pub attributes: Attributes,
516 pub operator: AssignmentOperator,
517 pub lhs: ExpressionNode,
518 pub rhs: ExpressionNode,
519}
520
521#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
522#[derive(Clone, Debug, PartialEq, Eq, IsVariant)]
523pub enum AssignmentOperator {
524 Equal,
525 PlusEqual,
526 MinusEqual,
527 TimesEqual,
528 DivisionEqual,
529 ModuloEqual,
530 AndEqual,
531 OrEqual,
532 XorEqual,
533 ShiftRightAssign,
534 ShiftLeftAssign,
535}
536
537#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
538#[derive(Clone, Debug, PartialEq)]
539pub struct IncrementStatement {
540 #[cfg(feature = "attributes")]
541 pub attributes: Attributes,
542 pub expression: ExpressionNode,
543}
544
545#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
546#[derive(Clone, Debug, PartialEq)]
547pub struct DecrementStatement {
548 #[cfg(feature = "attributes")]
549 pub attributes: Attributes,
550 pub expression: ExpressionNode,
551}
552
553#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
554#[derive(Clone, Debug, PartialEq)]
555pub struct IfStatement {
556 pub attributes: Attributes,
557 pub if_clause: IfClause,
558 pub else_if_clauses: Vec<ElseIfClause>,
559 pub else_clause: Option<ElseClause>,
560}
561
562#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
563#[derive(Clone, Debug, PartialEq)]
564pub struct IfClause {
565 pub expression: ExpressionNode,
566 pub body: CompoundStatement,
567}
568
569#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
570#[derive(Clone, Debug, PartialEq)]
571pub struct ElseIfClause {
572 #[cfg(feature = "attributes")]
573 pub attributes: Attributes,
574 pub expression: ExpressionNode,
575 pub body: CompoundStatement,
576}
577
578#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
579#[derive(Clone, Debug, PartialEq)]
580pub struct ElseClause {
581 #[cfg(feature = "attributes")]
582 pub attributes: Attributes,
583 pub body: CompoundStatement,
584}
585
586#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
587#[derive(Clone, Debug, PartialEq)]
588pub struct SwitchStatement {
589 pub attributes: Attributes,
590 pub expression: ExpressionNode,
591 pub body_attributes: Attributes,
592 pub clauses: Vec<SwitchClause>,
593}
594
595#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
596#[derive(Clone, Debug, PartialEq)]
597pub struct SwitchClause {
598 #[cfg(feature = "attributes")]
599 pub attributes: Attributes,
600 pub case_selectors: Vec<CaseSelector>,
601 pub body: CompoundStatement,
602}
603
604#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
605#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
606pub enum CaseSelector {
607 Default,
608 Expression(ExpressionNode),
609}
610
611#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
612#[derive(Clone, Debug, PartialEq)]
613pub struct LoopStatement {
614 pub attributes: Attributes,
615 pub body: CompoundStatement,
616 pub continuing: Option<ContinuingStatement>,
620}
621
622#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
623#[derive(Clone, Debug, PartialEq)]
624pub struct ContinuingStatement {
625 #[cfg(feature = "attributes")]
626 pub attributes: Attributes,
627 pub body: CompoundStatement,
628 pub break_if: Option<BreakIfStatement>,
632}
633
634#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
635#[derive(Clone, Debug, PartialEq)]
636pub struct BreakIfStatement {
637 #[cfg(feature = "attributes")]
638 pub attributes: Attributes,
639 pub expression: ExpressionNode,
640}
641
642#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
643#[derive(Clone, Debug, PartialEq)]
644pub struct ForStatement {
645 pub attributes: Attributes,
646 pub initializer: Option<StatementNode>,
647 pub condition: Option<ExpressionNode>,
648 pub update: Option<StatementNode>,
649 pub body: CompoundStatement,
650}
651
652#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
653#[derive(Clone, Debug, PartialEq)]
654pub struct WhileStatement {
655 pub attributes: Attributes,
656 pub condition: ExpressionNode,
657 pub body: CompoundStatement,
658}
659
/// A `break` statement. It has no payload beyond feature-gated attributes.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct BreakStatement {
    /// Attributes on the statement; only with the `attributes` feature.
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
}

/// A `continue` statement. It has no payload beyond feature-gated attributes.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct ContinueStatement {
    /// Attributes on the statement; only with the `attributes` feature.
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
}
673
674#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
675#[derive(Clone, Debug, PartialEq)]
676pub struct ReturnStatement {
677 #[cfg(feature = "attributes")]
678 pub attributes: Attributes,
679 pub expression: Option<ExpressionNode>,
680}
681
/// A `discard` statement. It has no payload beyond feature-gated attributes.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct DiscardStatement {
    /// Attributes on the statement; only with the `attributes` feature.
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
}
688
689#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
690#[derive(Clone, Debug, PartialEq)]
691pub struct FunctionCallStatement {
692 #[cfg(feature = "attributes")]
693 pub attributes: Attributes,
694 pub call: FunctionCall,
695}
696
697pub type ConstAssertStatement = ConstAssert;
698
699pub type DeclarationStatement = Declaration;