1use anyhow::Result;
2use std::collections::{btree_map::Entry, BTreeMap, HashMap, HashSet};
3use std::ops::Deref;
4use std::path::Path;
5use cosmian_wit_parser::abi::Abi;
6use cosmian_wit_parser::*;
7
8pub use cosmian_wit_parser;
9mod ns;
10
11pub use ns::Ns;
12
/// Which side of an interface is being generated.
///
/// [`Generator::generate_one`] uses this to route every function of the
/// interface to either [`Generator::import`] or [`Generator::export`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Direction {
    /// Functions are passed to [`Generator::import`].
    Import,
    /// Functions are passed to [`Generator::export`].
    Export,
}
41
42pub trait Generator {
43 fn preprocess_all(&mut self, imports: &[Interface], exports: &[Interface]) {
44 drop((imports, exports));
45 }
46
47 fn preprocess_one(&mut self, iface: &Interface, dir: Direction) {
48 drop((iface, dir));
49 }
50
51 fn type_record(
52 &mut self,
53 iface: &Interface,
54 id: TypeId,
55 name: &str,
56 record: &Record,
57 docs: &Docs,
58 );
59 fn type_variant(
60 &mut self,
61 iface: &Interface,
62 id: TypeId,
63 name: &str,
64 variant: &Variant,
65 docs: &Docs,
66 );
67 fn type_resource(&mut self, iface: &Interface, ty: ResourceId);
68 fn type_alias(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs);
69 fn type_list(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs);
70 fn type_pointer(
71 &mut self,
72 iface: &Interface,
73 id: TypeId,
74 name: &str,
75 const_: bool,
76 ty: &Type,
77 docs: &Docs,
78 );
79 fn type_builtin(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs);
80 fn type_push_buffer(
81 &mut self,
82 iface: &Interface,
83 id: TypeId,
84 name: &str,
85 ty: &Type,
86 docs: &Docs,
87 );
88 fn type_pull_buffer(
89 &mut self,
90 iface: &Interface,
91 id: TypeId,
92 name: &str,
93 ty: &Type,
94 docs: &Docs,
95 );
96 fn import(&mut self, iface: &Interface, func: &Function);
98 fn export(&mut self, iface: &Interface, func: &Function);
99
100 fn finish_one(&mut self, iface: &Interface, files: &mut Files);
101
102 fn finish_all(&mut self, files: &mut Files) {
103 drop(files);
104 }
105
106 fn generate_one(&mut self, iface: &Interface, dir: Direction, files: &mut Files) {
107 self.preprocess_one(iface, dir);
108
109 for (id, ty) in iface.types.iter() {
110 let name = match &ty.name {
112 Some(name) => name,
113 None => continue,
114 };
115 match &ty.kind {
116 TypeDefKind::Record(record) => self.type_record(iface, id, name, record, &ty.docs),
117 TypeDefKind::Variant(variant) => {
118 self.type_variant(iface, id, name, variant, &ty.docs)
119 }
120 TypeDefKind::List(t) => self.type_list(iface, id, name, t, &ty.docs),
121 TypeDefKind::PushBuffer(t) => self.type_push_buffer(iface, id, name, t, &ty.docs),
122 TypeDefKind::PullBuffer(t) => self.type_pull_buffer(iface, id, name, t, &ty.docs),
123 TypeDefKind::Type(t) => self.type_alias(iface, id, name, t, &ty.docs),
124 TypeDefKind::Pointer(t) => self.type_pointer(iface, id, name, false, t, &ty.docs),
125 TypeDefKind::ConstPointer(t) => {
126 self.type_pointer(iface, id, name, true, t, &ty.docs)
127 }
128 }
129 }
130
131 for (id, _resource) in iface.resources.iter() {
132 self.type_resource(iface, id);
133 }
134
135 for f in iface.functions.iter() {
140 match dir {
141 Direction::Import => self.import(iface, &f),
142 Direction::Export => self.export(iface, &f),
143 }
144 }
145
146 self.finish_one(iface, files)
147 }
148
149 fn generate_all(&mut self, imports: &[Interface], exports: &[Interface], files: &mut Files) {
150 self.preprocess_all(imports, exports);
151
152 for imp in imports {
153 self.generate_one(imp, Direction::Import, files);
154 }
155
156 for exp in exports {
157 self.generate_one(exp, Direction::Export, files);
158 }
159
160 self.finish_all(files);
161 }
162}
163
164#[derive(Default)]
165pub struct Types {
166 type_info: HashMap<TypeId, TypeInfo>,
167 handle_dtors: HashSet<ResourceId>,
168 dtor_funcs: HashSet<String>,
169}
170
/// Summary flags describing how a type is used and what it (transitively)
/// contains, as computed by `Types`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TypeInfo {
    /// The type is reachable from some function parameter.
    pub param: bool,

    /// The type is reachable from some function result.
    pub result: bool,

    /// The type is, or transitively contains, a list.
    pub has_list: bool,

    /// The type transitively contains a resource handle.
    pub has_handle: bool,

    /// The type is, or transitively contains, a push buffer.
    pub has_push_buffer: bool,

    /// The type is, or transitively contains, a pull buffer.
    pub has_pull_buffer: bool,
}
193
194impl std::ops::BitOrAssign for TypeInfo {
195 fn bitor_assign(&mut self, rhs: Self) {
196 self.param |= rhs.param;
197 self.result |= rhs.result;
198 self.has_list |= rhs.has_list;
199 self.has_handle |= rhs.has_handle;
200 self.has_push_buffer |= rhs.has_push_buffer;
201 self.has_pull_buffer |= rhs.has_pull_buffer;
202 }
203}
204
205impl Types {
206 pub fn analyze(&mut self, iface: &Interface) {
207 for (t, _) in iface.types.iter() {
208 self.type_id_info(iface, t);
209 }
210 for f in iface.functions.iter() {
211 for (_, ty) in f.params.iter() {
212 self.set_param_result_ty(iface, ty, true, false);
213 }
214 for (_, ty) in f.results.iter() {
215 self.set_param_result_ty(iface, ty, false, true);
216 }
217 self.maybe_set_preview1_dtor(iface, f);
218 }
219 }
220
221 fn maybe_set_preview1_dtor(&mut self, iface: &Interface, f: &Function) {
222 match f.abi {
223 Abi::Preview1 => {}
224 _ => return,
225 }
226
227 if f.params.len() != 1 {
229 return;
230 }
231
232 let name = f.name.as_str();
234 let prefix = match name.strip_suffix("_close") {
235 Some(prefix) => prefix,
236 None => return,
237 };
238
239 let resource = match find_handle(iface, &f.params[0].1) {
242 Some(id) => id,
243 None => return,
244 };
245 if iface.resources[resource].name != prefix {
246 return;
247 }
248
249 self.handle_dtors.insert(resource);
250 self.dtor_funcs.insert(f.name.to_string());
251
252 fn find_handle(iface: &Interface, ty: &Type) -> Option<ResourceId> {
253 match ty {
254 Type::Handle(r) => Some(*r),
255 Type::Id(id) => match &iface.types[*id].kind {
256 TypeDefKind::Type(t) => find_handle(iface, t),
257 _ => None,
258 },
259 _ => None,
260 }
261 }
262 }
263
264 pub fn get(&self, id: TypeId) -> TypeInfo {
265 self.type_info[&id]
266 }
267
268 pub fn has_preview1_dtor(&self, resource: ResourceId) -> bool {
269 self.handle_dtors.contains(&resource)
270 }
271
272 pub fn is_preview1_dtor_func(&self, func: &Function) -> bool {
273 self.dtor_funcs.contains(&func.name)
274 }
275
276 pub fn type_id_info(&mut self, iface: &Interface, ty: TypeId) -> TypeInfo {
277 if let Some(info) = self.type_info.get(&ty) {
278 return *info;
279 }
280 let mut info = TypeInfo::default();
281 match &iface.types[ty].kind {
282 TypeDefKind::Record(r) => {
283 for field in r.fields.iter() {
284 info |= self.type_info(iface, &field.ty);
285 }
286 }
287 TypeDefKind::Variant(v) => {
288 for case in v.cases.iter() {
289 if let Some(ty) = &case.ty {
290 info |= self.type_info(iface, ty);
291 }
292 }
293 }
294 TypeDefKind::List(ty) => {
295 info = self.type_info(iface, ty);
296 info.has_list = true;
297 }
298 TypeDefKind::PushBuffer(ty) => {
299 info = self.type_info(iface, ty);
300 info.has_push_buffer = true;
301 }
302 TypeDefKind::PullBuffer(ty) => {
303 info = self.type_info(iface, ty);
304 info.has_pull_buffer = true;
305 }
306 TypeDefKind::ConstPointer(ty) | TypeDefKind::Pointer(ty) | TypeDefKind::Type(ty) => {
307 info = self.type_info(iface, ty)
308 }
309 }
310 self.type_info.insert(ty, info);
311 return info;
312 }
313
314 pub fn type_info(&mut self, iface: &Interface, ty: &Type) -> TypeInfo {
315 let mut info = TypeInfo::default();
316 match ty {
317 Type::Handle(_) => info.has_handle = true,
318 Type::Id(id) => return self.type_id_info(iface, *id),
319 _ => {}
320 }
321 info
322 }
323
324 fn set_param_result_id(&mut self, iface: &Interface, ty: TypeId, param: bool, result: bool) {
325 match &iface.types[ty].kind {
326 TypeDefKind::Record(r) => {
327 for field in r.fields.iter() {
328 self.set_param_result_ty(iface, &field.ty, param, result)
329 }
330 }
331 TypeDefKind::Variant(v) => {
332 for case in v.cases.iter() {
333 if let Some(ty) = &case.ty {
334 self.set_param_result_ty(iface, ty, param, result)
335 }
336 }
337 }
338 TypeDefKind::List(ty)
339 | TypeDefKind::PushBuffer(ty)
340 | TypeDefKind::PullBuffer(ty)
341 | TypeDefKind::Pointer(ty)
342 | TypeDefKind::ConstPointer(ty) => self.set_param_result_ty(iface, ty, param, result),
343 TypeDefKind::Type(ty) => self.set_param_result_ty(iface, ty, param, result),
344 }
345 }
346
347 fn set_param_result_ty(&mut self, iface: &Interface, ty: &Type, param: bool, result: bool) {
348 match ty {
349 Type::Id(id) => {
350 self.type_id_info(iface, *id);
351 let info = self.type_info.get_mut(id).unwrap();
352 if (param && !info.param) || (result && !info.result) {
353 info.param = info.param || param;
354 info.result = info.result || result;
355 self.set_param_result_id(iface, *id, param, result);
356 }
357 }
358 _ => {}
359 }
360 }
361}
362
/// An in-memory set of generated output files, kept sorted by name.
#[derive(Default)]
pub struct Files {
    /// File name -> accumulated contents.
    files: BTreeMap<String, Vec<u8>>,
}

impl Files {
    /// Appends `contents` to the file called `name`, creating the file first
    /// if it does not exist yet.
    pub fn push(&mut self, name: &str, contents: &[u8]) {
        match self.files.entry(name.to_owned()) {
            Entry::Occupied(mut existing) => existing.get_mut().extend_from_slice(contents),
            Entry::Vacant(slot) => {
                slot.insert(contents.to_vec());
            }
        }
    }

    /// Iterates over `(name, contents)` pairs in ascending name order.
    pub fn iter(&self) -> impl Iterator<Item = (&'_ str, &'_ [u8])> {
        self.files
            .iter()
            .map(|(name, bytes)| (name.as_str(), bytes.as_slice()))
    }
}
384
385pub fn load(path: impl AsRef<Path>) -> Result<Interface> {
386 Interface::parse_file(path)
387}
388
/// A growable buffer of generated source text that re-indents appended lines
/// based on `{` at the end and `}` at the start of each line.
#[derive(Default)]
pub struct Source {
    /// Accumulated text.
    s: String,
    /// Current nesting depth, in units of `Source::INDENT`.
    indent: usize,
}

impl Source {
    /// Text written once per nesting level at the start of every new line.
    /// Dedenting removes exactly this text, so the two always agree.
    const INDENT: &'static str = " ";

    /// Appends `src`, re-indenting it line by line.
    ///
    /// - A single-line `src` is appended untrimmed. In multi-line input, each
    ///   line's leading whitespace is dropped and replaced by the buffer's
    ///   current indentation.
    /// - A line ending in `{` opens a nesting level.
    /// - A line starting with `}` closes one and pulls itself back by one
    ///   indentation unit. Unmatched `}` lines leave the depth at zero instead
    ///   of underflowing.
    /// - A newline is written between lines, and after the last line only if
    ///   `src` ends with `\n`.
    pub fn push_str(&mut self, src: &str) {
        let lines = src.lines().collect::<Vec<_>>();
        for (i, &line) in lines.iter().enumerate() {
            let trimmed = line.trim();
            if trimmed.starts_with('}') && self.s.ends_with(Self::INDENT) {
                // Remove exactly one level of the indentation written by the
                // previous `newline`, never the newline character itself.
                self.s.truncate(self.s.len() - Self::INDENT.len());
            }
            self.s.push_str(if lines.len() == 1 {
                line
            } else {
                line.trim_start()
            });
            if trimmed.ends_with('{') {
                self.indent += 1;
            }
            if trimmed.starts_with('}') {
                // Saturate: an unmatched `}` must not wrap the depth around.
                self.indent = self.indent.saturating_sub(1);
            }
            if i != lines.len() - 1 || src.ends_with('\n') {
                self.newline();
            }
        }
    }

    /// Increases the nesting depth by `amt` levels.
    pub fn indent(&mut self, amt: usize) {
        self.indent += amt;
    }

    /// Decreases the nesting depth by `amt` levels, stopping at zero.
    pub fn deindent(&mut self, amt: usize) {
        self.indent = self.indent.saturating_sub(amt);
    }

    /// Writes a newline followed by the indentation for the current depth.
    fn newline(&mut self) {
        self.s.push('\n');
        for _ in 0..self.indent {
            self.s.push_str(Self::INDENT);
        }
    }

    /// Direct mutable access to the underlying text. Writes made through it
    /// do not update the tracked nesting depth.
    pub fn as_mut_string(&mut self) -> &mut String {
        &mut self.s
    }
}
440
441impl Deref for Source {
442 type Target = str;
443 fn deref(&self) -> &str {
444 &self.s
445 }
446}
447
448impl From<Source> for String {
449 fn from(s: Source) -> String {
450 s.s
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::{Generator, Source};

    #[test]
    fn simple_append() {
        let mut s = Source::default();
        s.push_str("x");
        assert_eq!(s.s, "x");
        s.push_str("y");
        assert_eq!(s.s, "xy");
        s.push_str("z ");
        assert_eq!(s.s, "xyz ");
        // Single-line input is appended untrimmed, so the trailing space of
        // "z " and the leading space of " a " both survive.
        s.push_str(" a ");
        assert_eq!(s.s, "xyz  a ");
        s.push_str("\na");
        assert_eq!(s.s, "xyz  a \na");
    }

    #[test]
    fn newline_remap() {
        let mut s = Source::default();
        s.push_str("function() {\n");
        s.push_str("y\n");
        s.push_str("}\n");
        assert_eq!(s.s, "function() {\n y\n}\n");
    }

    #[test]
    fn if_else() {
        let mut s = Source::default();
        s.push_str("if() {\n");
        s.push_str("y\n");
        s.push_str("} else if () {\n");
        s.push_str("z\n");
        s.push_str("}\n");
        assert_eq!(s.s, "if() {\n y\n} else if () {\n z\n}\n");
    }

    #[test]
    fn trim_ws() {
        let mut s = Source::default();
        // Multi-line input: each line's leading whitespace is replaced by the
        // buffer's own indentation.
        s.push_str(
            "function() {
                x
            }",
        );
        assert_eq!(s.s, "function() {\n x\n}");
    }

    #[test]
    fn generator_is_object_safe() {
        fn _assert(_: &dyn Generator) {}
    }
}