Skip to main content

dawn_codegen/
parser.rs

1use heck::{ToKebabCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
2use serde::Deserialize;
3use std::collections::{HashMap, HashSet};
4use std::fs;
5use std::hash::{Hash, Hasher};
6use std::path::Path;
7
8/// The root structure of dawn.json
9#[derive(Debug, Deserialize)]
10pub struct DawnApi {
11    #[serde(rename = "_comment")]
12    pub comment: Option<Vec<String>>,
13
14    #[serde(rename = "_doc")]
15    pub doc: Option<String>,
16
17    #[serde(rename = "_metadata")]
18    pub metadata: ApiMetadata,
19
20    /// All other entries are definitions keyed by their canonical names
21    #[serde(flatten)]
22    pub definitions: HashMap<String, Definition>,
23}
24
25impl DawnApi {
26    /// Filter definitions by tags
27    pub fn filter_by_tags(&self, enabled_tags: &[String]) -> DawnApi {
28        let mut filtered_definitions = HashMap::new();
29
30        for (name, def) in &self.definitions {
31            if Self::should_include_definition(def, enabled_tags) {
32                filtered_definitions.insert(name.clone(), (*def).clone());
33            }
34        }
35
36        DawnApi {
37            comment: self.comment.clone(),
38            doc: self.doc.clone(),
39            metadata: self.metadata.clone(),
40            definitions: filtered_definitions,
41        }
42    }
43
44    /// Check if a definition should be included based on tags
45    fn should_include_definition(def: &Definition, enabled_tags: &[String]) -> bool {
46        let def_tags = match def {
47            Definition::Native(d) => &d.tags,
48            Definition::Typedef(d) => &d.tags,
49            Definition::Enum(d) => &d.tags,
50            Definition::Bitmask(d) => &d.tags,
51            Definition::FunctionPointer(d) => &d.tags,
52            Definition::Structure(d) => &d.tags,
53            Definition::Object(d) => &d.tags,
54            Definition::Constant(d) => &d.tags,
55            Definition::Function(d) => &d.tags,
56            Definition::Callback(d) => &d.tags,
57            Definition::CallbackFunction(d) => &d.tags,
58            Definition::CallbackInfo(d) => &d.tags,
59        };
60
61        // If no tags specified, include by default
62        if def_tags.is_empty() {
63            return true;
64        }
65
66        // If any tag matches enabled tags, include
67        def_tags.iter().any(|tag| enabled_tags.contains(tag))
68    }
69
70    /// Get all definitions of a specific category
71    pub fn enums(&self) -> Vec<(&String, &EnumDef)> {
72        self.definitions
73            .iter()
74            .filter_map(|(name, def)| match def {
75                Definition::Enum(enum_def) => Some((name, enum_def)),
76                _ => None,
77            })
78            .collect()
79    }
80
81    pub fn bitmasks(&self) -> Vec<(&String, &BitmaskDef)> {
82        self.definitions
83            .iter()
84            .filter_map(|(name, def)| match def {
85                Definition::Bitmask(bitmask_def) => Some((name, bitmask_def)),
86                _ => None,
87            })
88            .collect()
89    }
90
91    pub fn structures(&self) -> Vec<(&String, &StructureDef)> {
92        self.definitions
93            .iter()
94            .filter_map(|(name, def)| match def {
95                Definition::Structure(struct_def) => Some((name, struct_def)),
96                _ => None,
97            })
98            .collect()
99    }
100
101    pub fn extensions(&self) -> HashMap<&String, HashSet<Extension<'_>>> {
102        let mut extensions = HashMap::new();
103        for (name, def) in self.structures() {
104            for chain_root in &def.chain_roots {
105                extensions
106                    .entry(chain_root)
107                    .or_insert(HashSet::new())
108                    .insert(Extension {
109                        ty: name,
110                        tags: &def.tags,
111                    });
112            }
113        }
114        extensions
115    }
116
117    pub fn objects(&self) -> Vec<(&String, &ObjectDef)> {
118        self.definitions
119            .iter()
120            .filter_map(|(name, def)| match def {
121                Definition::Object(object_def) => Some((name, object_def)),
122                _ => None,
123            })
124            .collect()
125    }
126
127    pub fn functions(&self) -> Vec<(&String, &FunctionDef)> {
128        self.definitions
129            .iter()
130            .filter_map(|(name, def)| match def {
131                Definition::Function(func_def) => Some((name, func_def)),
132                _ => None,
133            })
134            .collect()
135    }
136
137    pub fn callbacks(&self) -> Vec<(&String, &CallbackDef)> {
138        self.definitions
139            .iter()
140            .filter_map(|(name, def)| match def {
141                Definition::Callback(callback_def) => Some((name, callback_def)),
142                _ => None,
143            })
144            .collect()
145    }
146
147    pub fn callback_functions(&self) -> Vec<(&String, &CallbackFunctionDef)> {
148        self.definitions
149            .iter()
150            .filter_map(|(name, def)| match def {
151                Definition::CallbackFunction(callback_func_def) => Some((name, callback_func_def)),
152                _ => None,
153            })
154            .collect()
155    }
156
157    pub fn callback_infos(&self) -> Vec<(&String, &CallbackInfoDef)> {
158        self.definitions
159            .iter()
160            .filter_map(|(name, def)| match def {
161                Definition::CallbackInfo(callback_info_def) => Some((name, callback_info_def)),
162                _ => None,
163            })
164            .collect()
165    }
166
167    pub fn constants(&self) -> Vec<(&String, &ConstantDef)> {
168        self.definitions
169            .iter()
170            .filter_map(|(name, def)| match def {
171                Definition::Constant(const_def) => Some((name, const_def)),
172                _ => None,
173            })
174            .collect()
175    }
176}
177
178/// Metadata about the API
179#[derive(Debug, Clone, Deserialize)]
180pub struct ApiMetadata {
181    pub api: String,
182    pub namespace: String,
183    pub c_prefix: Option<String>,
184    pub proc_table_prefix: String,
185    pub native_namespace: String,
186    pub copyright_year: Option<String>,
187}
188
189/// A definition can be one of many types
190#[derive(Debug, Clone, Deserialize)]
191#[serde(tag = "category")]
192pub enum Definition {
193    #[serde(rename = "native")]
194    Native(NativeType),
195
196    #[serde(rename = "typedef")]
197    Typedef(TypedefDef),
198
199    #[serde(rename = "enum")]
200    Enum(EnumDef),
201
202    #[serde(rename = "bitmask")]
203    Bitmask(BitmaskDef),
204
205    #[serde(rename = "function pointer")]
206    FunctionPointer(FunctionPointerDef),
207
208    #[serde(rename = "structure")]
209    Structure(StructureDef),
210
211    #[serde(rename = "object")]
212    Object(ObjectDef),
213
214    #[serde(rename = "constant")]
215    Constant(ConstantDef),
216
217    #[serde(rename = "function")]
218    Function(FunctionDef),
219
220    #[serde(rename = "callback")]
221    Callback(CallbackDef),
222
223    #[serde(rename = "callback function")]
224    CallbackFunction(CallbackFunctionDef),
225
226    #[serde(rename = "callback info")]
227    CallbackInfo(CallbackInfoDef),
228}
229
230/// Native type definition
231#[derive(Debug, Clone, Deserialize)]
232pub struct NativeType {
233    #[serde(default)]
234    pub tags: Vec<String>,
235
236    #[serde(rename = "wire transparent", default = "default_true")]
237    pub wire_transparent: bool,
238
239    #[serde(rename = "wasm type")]
240    pub wasm_type: Option<String>,
241
242    #[serde(rename = "is nullable pointer")]
243    pub is_nullable_pointer: Option<bool>,
244}
245
246/// Typedef definition
247#[derive(Debug, Clone, Deserialize)]
248pub struct TypedefDef {
249    #[serde(default)]
250    pub tags: Vec<String>,
251
252    #[serde(rename = "type")]
253    pub target_type: String,
254}
255
/// A structure that may be chained onto some chain root, together with that structure's tags.
///
/// Identity is the structure name alone. `PartialEq` and `Hash` ignore `tags`, so a
/// `HashSet<Extension>` holds at most one entry per structure name. The type holds only
/// shared references, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct Extension<'a> {
    /// Name of the extending structure.
    pub ty: &'a String,
    /// Tags of the extending structure.
    pub tags: &'a Vec<String>,
}

impl PartialEq for Extension<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.ty == other.ty
    }
}

impl Hash for Extension<'_> {
    // Must hash exactly the fields compared by `eq`, so tags are deliberately excluded.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.ty.hash(state);
    }
}

impl Eq for Extension<'_> {}
272
273/// Enum definition
274#[derive(Debug, Clone, Deserialize)]
275pub struct EnumDef {
276    #[serde(default)]
277    pub tags: Vec<String>,
278
279    pub values: Vec<EnumValueDef>,
280
281    #[serde(rename = "emscripten_no_enum_table")]
282    pub emscripten_no_enum_table: Option<bool>,
283
284    #[serde(rename = "emscripten_string_to_int")]
285    pub emscripten_string_to_int: Option<bool>,
286}
287
288/// Bitmask definition - similar to enum but for bitflags
289#[derive(Debug, Clone, Deserialize)]
290pub struct BitmaskDef {
291    #[serde(default)]
292    pub tags: Vec<String>,
293
294    pub values: Vec<EnumValueDef>,
295
296    #[serde(rename = "emscripten_no_enum_table")]
297    pub emscripten_no_enum_table: Option<bool>,
298}
299
300/// An enum or bitmask value
301#[derive(Debug, Clone, Deserialize)]
302pub struct EnumValueDef {
303    pub name: String,
304    pub value: serde_json::Value, // Can be number or string
305
306    #[serde(default)]
307    pub tags: Vec<String>,
308
309    pub jsrepr: Option<String>,
310
311    #[serde(default = "default_true")]
312    pub valid: bool,
313
314    #[serde(rename = "emscripten_string_to_int")]
315    pub emscripten_string_to_int: Option<bool>,
316}
317
318/// Function pointer definition
319#[derive(Debug, Clone, Deserialize)]
320pub struct FunctionPointerDef {
321    #[serde(default)]
322    pub tags: Vec<String>,
323
324    returns: Option<ReturnType>,
325    args: Vec<RecordMember>,
326}
327
328/// Structure definition
329#[derive(Debug, Clone, Deserialize)]
330pub struct StructureDef {
331    #[serde(default)]
332    pub tags: Vec<String>,
333
334    pub members: Vec<RecordMember>,
335
336    #[serde(default)]
337    pub extensible: ExtensibleType,
338
339    pub chained: Option<String>, // "in" or "out"
340
341    #[serde(rename = "chain roots", default)]
342    pub chain_roots: Vec<String>,
343
344    #[serde(rename = "_comment")]
345    pub comment: Option<String>,
346
347    pub out: Option<bool>,
348}
349
/// Value of an `extensible` key, which `dawn.json` writes either as a boolean or as a
/// direction string (`"in"` / `"out"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtensibleType {
    /// Direction string; always counts as extensible.
    Direction(String),
    /// Plain boolean flag.
    Bool(bool),
}

impl ExtensibleType {
    /// `true` for any direction, otherwise the boolean flag itself.
    pub fn extensible(&self) -> bool {
        match *self {
            ExtensibleType::Bool(flag) => flag,
            ExtensibleType::Direction(_) => true,
        }
    }
}

impl Default for ExtensibleType {
    /// A missing `extensible` key means "not extensible".
    fn default() -> Self {
        Self::Bool(false)
    }
}
372
373impl<'de> Deserialize<'de> for ExtensibleType {
374    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
375    where
376        D: serde::Deserializer<'de>,
377    {
378        use serde::de::Error;
379
380        let value = serde_json::Value::deserialize(deserializer)?;
381
382        match value {
383            serde_json::Value::String(s) => match s.as_str() {
384                "in" | "out" => Ok(ExtensibleType::Direction(s)),
385                _ => Err(D::Error::custom(format!("Invalid direction value: {}", s))),
386            },
387            serde_json::Value::Bool(b) => Ok(ExtensibleType::Bool(b)),
388            _ => Err(D::Error::custom(
389                "Expected string or boolean for extensible field",
390            )),
391        }
392    }
393}
394
395/// Represents the length field which can be either a string (field reference) or a number (literal size)
396#[derive(Debug, Clone, PartialEq, Deserialize)]
397#[serde(untagged)]
398pub enum LengthValue {
399    String(String),
400    Number(u32),
401}
402
403/// Represents the returns field which can be either a simple string or an object with type and optional fields
404#[derive(Debug, Clone, PartialEq, Deserialize)]
405#[serde(untagged)]
406pub enum ReturnType {
407    Simple(String),
408    Complex {
409        #[serde(rename = "type")]
410        return_type: String,
411        #[serde(default)]
412        optional: bool,
413    },
414}
415
416impl LengthValue {
417    /// Check if the length is a string reference
418    pub fn is_string(&self) -> bool {
419        matches!(self, LengthValue::String(_))
420    }
421
422    /// Check if the length is a numeric value
423    pub fn is_number(&self) -> bool {
424        matches!(self, LengthValue::Number(_))
425    }
426
427    /// Get the string value if it's a string reference
428    pub fn as_string(&self) -> Option<&str> {
429        match self {
430            LengthValue::String(s) => Some(s),
431            LengthValue::Number(_) => None,
432        }
433    }
434
435    /// Get the numeric value if it's a number
436    pub fn as_number(&self) -> Option<u32> {
437        match self {
438            LengthValue::String(_) => None,
439            LengthValue::Number(n) => Some(*n),
440        }
441    }
442}
443
444impl ExtensibleType {
445    pub fn is_extensible(&self) -> bool {
446        match self {
447            ExtensibleType::Direction(_) => true,
448            ExtensibleType::Bool(b) => *b,
449        }
450    }
451
452    pub fn direction(&self) -> Option<&str> {
453        match self {
454            ExtensibleType::Direction(dir) => Some(dir),
455            ExtensibleType::Bool(_) => None,
456        }
457    }
458
459    pub fn is_input(&self) -> bool {
460        matches!(self, ExtensibleType::Direction(dir) if dir == "in")
461    }
462
463    pub fn is_output(&self) -> bool {
464        matches!(self, ExtensibleType::Direction(dir) if dir == "out")
465    }
466
467    pub fn is_boolean(&self) -> bool {
468        matches!(self, ExtensibleType::Bool(_))
469    }
470}
471
472impl ReturnType {
473    pub fn get_type(&self) -> &str {
474        match self {
475            ReturnType::Simple(s) => s,
476            ReturnType::Complex { return_type, .. } => return_type,
477        }
478    }
479
480    pub fn is_optional(&self) -> bool {
481        match self {
482            ReturnType::Simple(_) => false,
483            ReturnType::Complex { optional, .. } => *optional,
484        }
485    }
486}
487
488impl FunctionPointerDef {
489    /// Get the return type
490    pub fn returns(&self) -> Option<&ReturnType> {
491        self.returns.as_ref()
492    }
493
494    /// Get the arguments
495    pub fn args(&self) -> &[RecordMember] {
496        &self.args
497    }
498}
499
500/// Object definition (like WebGPU handles)
501#[derive(Debug, Clone, Deserialize)]
502pub struct ObjectDef {
503    #[serde(default)]
504    pub tags: Vec<String>,
505
506    pub methods: Vec<MethodDef>,
507
508    #[serde(rename = "no autolock")]
509    pub no_autolock: Option<bool>,
510}
511
512/// Constant definition
513#[derive(Debug, Clone, Deserialize)]
514pub struct ConstantDef {
515    #[serde(default)]
516    pub tags: Vec<String>,
517
518    #[serde(rename = "type")]
519    pub const_type: String,
520
521    pub value: serde_json::Value,
522
523    pub cpp_value: Option<String>,
524}
525
526/// Function definition
527#[derive(Debug, Clone, Deserialize)]
528pub struct FunctionDef {
529    #[serde(default)]
530    pub tags: Vec<String>,
531
532    pub returns: Option<ReturnType>,
533    pub args: Vec<RecordMember>,
534
535    #[serde(rename = "_comment")]
536    pub comment: Option<String>,
537}
538
539impl FunctionDef {
540    /// Get the return type
541    pub fn returns(&self) -> Option<&ReturnType> {
542        self.returns.as_ref()
543    }
544
545    /// Get the arguments
546    pub fn args(&self) -> &[RecordMember] {
547        &self.args
548    }
549}
550
551/// Callback definition
552#[derive(Debug, Clone, Deserialize)]
553pub struct CallbackDef {
554    #[serde(default)]
555    pub tags: Vec<String>,
556
557    returns: Option<ReturnType>,
558    args: Vec<RecordMember>,
559}
560
561impl CallbackDef {
562    /// Get the return type
563    pub fn returns(&self) -> Option<&ReturnType> {
564        self.returns.as_ref()
565    }
566
567    /// Get the arguments
568    pub fn args(&self) -> &[RecordMember] {
569        &self.args
570    }
571}
572
573/// Callback function definition
574#[derive(Debug, Clone, Deserialize)]
575pub struct CallbackFunctionDef {
576    #[serde(default)]
577    pub tags: Vec<String>,
578
579    pub args: Vec<RecordMember>,
580}
581
582/// Callback info definition
583#[derive(Debug, Clone, Deserialize)]
584pub struct CallbackInfoDef {
585    #[serde(default)]
586    pub tags: Vec<String>,
587
588    pub members: Vec<RecordMember>,
589}
590
591/// A method on an object
592#[derive(Debug, Clone, Deserialize)]
593pub struct MethodDef {
594    pub name: String,
595
596    #[serde(default)]
597    pub tags: Vec<String>,
598
599    pub returns: Option<ReturnType>,
600
601    #[serde(default)]
602    pub args: Vec<RecordMember>,
603
604    #[serde(rename = "no autolock")]
605    pub no_autolock: Option<bool>,
606
607    pub extensible: Option<ExtensibleType>,
608}
609
610impl MethodDef {
611    /// Get the return type
612    pub fn returns(&self) -> Option<&ReturnType> {
613        self.returns.as_ref()
614    }
615
616    /// Get the arguments
617    pub fn args(&self) -> &[RecordMember] {
618        &self.args
619    }
620}
621
622/// A record member (used in function arguments, struct members, etc.)
623#[derive(Debug, Clone, Deserialize)]
624pub struct RecordMember {
625    pub name: String,
626
627    #[serde(rename = "type")]
628    pub member_type: String,
629
630    #[serde(default)]
631    pub annotation: Annotation,
632
633    pub length: Option<LengthValue>,
634
635    #[serde(default)]
636    pub optional: bool,
637
638    pub default: Option<serde_json::Value>,
639
640    #[serde(rename = "wire_is_data_only", default)]
641    pub wire_is_data_only: bool,
642
643    #[serde(rename = "skip_serialize", default)]
644    pub skip_serialize: bool,
645
646    #[serde(rename = "no_default")]
647    pub no_default: Option<bool>,
648
649    #[serde(rename = "array_element_optional")]
650    pub array_element_optional: Option<bool>,
651}
652
/// Pointer annotation of a record member, parsed from its `annotation` string.
///
/// `Hash` is derived so annotations can be used as map or set keys.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Annotation {
    /// `"*"`: a pointer to mutable data.
    MutPtr,
    /// `"const*"`: a pointer to const data.
    ConstPtr,
    /// `"const*const*"`: the double-pointer form, const at both levels.
    ConstConstPtr,
    /// No pointer: the member is passed by value.
    ///
    /// This is the default when the annotation is absent.
    #[default]
    Value,
}

impl Annotation {
    /// `true` only for `"*"`.
    pub fn is_mut_ptr(&self) -> bool {
        matches!(self, Annotation::MutPtr)
    }

    /// `true` for both `"const*"` and `"const*const*"`.
    pub fn is_const_ptr(&self) -> bool {
        matches!(self, Annotation::ConstPtr | Annotation::ConstConstPtr)
    }

    /// `true` only for `"const*const*"`.
    pub fn is_const_const_ptr(&self) -> bool {
        matches!(self, Annotation::ConstConstPtr)
    }

    /// `true` when there is no pointer annotation.
    pub fn is_value(&self) -> bool {
        matches!(self, Annotation::Value)
    }
}
683
684impl<'de> Deserialize<'de> for Annotation {
685    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
686    where
687        D: serde::Deserializer<'de>,
688    {
689        let annotation = String::deserialize(deserializer)?;
690
691        match annotation.as_str() {
692            "*" => Ok(Annotation::MutPtr),
693            "const*" => Ok(Annotation::ConstPtr),
694            "const*const*" => Ok(Annotation::ConstConstPtr),
695            _ => Ok(Annotation::Value),
696        }
697    }
698}
699
700pub struct DawnJsonParser;
701
702impl DawnJsonParser {
703    pub fn parse_file<P: AsRef<Path>>(path: P) -> Result<DawnApi, Box<dyn std::error::Error>> {
704        let content = fs::read_to_string(path)?;
705        Self::parse_string(&content)
706    }
707
708    pub fn parse_string(content: &str) -> Result<DawnApi, Box<dyn std::error::Error>> {
709        let api: DawnApi = serde_json::from_str(content)?;
710        Ok(api)
711    }
712}
713
714pub struct Name {
715    pub canonical_name: String,
716}
717
718impl Name {
719    pub fn new(canonical_name: &str) -> Self {
720        Self {
721            canonical_name: canonical_name.to_string(),
722        }
723    }
724
725    pub fn camel_case(&self) -> String {
726        self.canonical_name.to_upper_camel_case()
727    }
728
729    pub fn snake_case(&self) -> String {
730        let name = self.canonical_name.to_snake_case();
731        if name == "type" {
732            "r#type".into()
733        } else {
734            name
735        }
736    }
737
738    pub fn shouty_snake_case(&self) -> String {
739        self.canonical_name.to_shouty_snake_case()
740    }
741
742    pub fn kebab_case(&self) -> String {
743        self.canonical_name.to_kebab_case()
744    }
745}
746
/// Serde `default = "default_true"` helper for boolean keys that mean `true` when absent.
fn default_true() -> bool {
    true
}