ocaml_build/
lib.rs

1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4#[cfg(feature = "dune")]
5mod dune;
6
7#[cfg(feature = "dune")]
8pub use dune::Dune;
9use syn::MetaList;
10use syn::__private::ToTokens;
11
/// Declarations collected from one Rust source file.
struct Source {
    /// Location of the `.rs` file the declarations were extracted from.
    path: PathBuf,
    /// OCaml `external ...` declarations, one string per annotated function.
    functions: Vec<String>,
    /// OCaml `type ...` declarations, one string per annotated type.
    types: Vec<String>,
}
17
18pub struct Sigs {
19    base_dir: PathBuf,
20    output: PathBuf,
21    source: Vec<Source>,
22}
23
/// Removes the surrounding double quotes from the textual form of a string
/// literal.
///
/// At most one quote is removed from each end, so quote characters that are
/// part of the literal's content (for example a trailing `\"`) are kept.
/// Input without surrounding quotes is returned unchanged.
fn strip_quotes(s: &str) -> &str {
    let s = s.strip_prefix('"').unwrap_or(s);
    s.strip_suffix('"').unwrap_or(s)
}
27
/// Converts a `CamelCase` identifier into `snake_case`.
///
/// An underscore is inserted before every uppercase character except the
/// very first character, and each character is lowercased with
/// `to_ascii_lowercase` (so non-ASCII characters are left as they are).
fn snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::new(), |mut out, (idx, ch)| {
            // Every processed character appends to `out`, so `idx > 0` is
            // the same as "the output is not empty yet".
            if idx > 0 && ch.is_uppercase() {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
            out
        })
}
38
39fn handle(attrs: Vec<syn::Attribute>, mut f: impl FnMut(&str)) {
40    for attr in attrs {
41        let attr_name = attr
42            .path()
43            .segments
44            .iter()
45            .map(|x| x.ident.to_string())
46            .collect::<Vec<_>>()
47            .join("::");
48        if attr_name == "sig" || attr_name == "ocaml::sig" {
49            match &attr.meta {
50                // #[sig] or #[ocaml::sig]
51                syn::Meta::Path(_) => f(""),
52                // #[ocaml::sig("...")]
53                syn::Meta::List(MetaList {
54                    path: _,
55                    delimiter: _,
56                    tokens,
57                }) => match &tokens.clone().into_iter().collect::<Vec<_>>()[..] {
58                    [proc_macro2::TokenTree::Literal(ref sig)] => {
59                        let s = sig.to_string();
60                        let ty = strip_quotes(&s);
61                        f(ty)
62                    }
63                    [] => f(""),
64                    x => {
65                        panic!("Invalid signature: {x:?}");
66                    }
67                },
68                syn::Meta::NameValue(x) => panic!("Invalid signature: {}", x.into_token_stream()),
69            }
70        }
71    }
72}
73
74impl Sigs {
75    pub fn new(p: impl AsRef<Path>) -> Sigs {
76        let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
77        let base_dir = root.join("src");
78        Sigs {
79            base_dir,
80            output: p.as_ref().to_path_buf(),
81            source: Vec::new(),
82        }
83    }
84
85    pub fn with_source_dir(mut self, p: impl AsRef<Path>) -> Sigs {
86        self.base_dir = p.as_ref().to_path_buf();
87        self
88    }
89
90    fn parse(&mut self, path: &Path) -> Result<(), std::io::Error> {
91        let files = std::fs::read_dir(path)?;
92
93        for file in files {
94            let file = file?;
95            if file.metadata()?.is_dir() {
96                self.parse(&file.path())?;
97                continue;
98            }
99
100            if Some(Some("rs")) != file.path().extension().map(|x| x.to_str()) {
101                continue;
102            }
103
104            let path = file.path();
105            let mut src = Source {
106                path: path.clone(),
107                functions: Vec::new(),
108                types: Vec::new(),
109            };
110            let s = std::fs::read_to_string(&path)?;
111            let t: syn::File = syn::parse_str(&s)
112                .unwrap_or_else(|_| panic!("Unable to parse input file: {}", path.display()));
113
114            for item in t.items {
115                match item {
116                    syn::Item::Fn(item_fn) => {
117                        let name = &item_fn.sig.ident;
118                        handle(item_fn.attrs, |ty| {
119                            let def = if item_fn.sig.inputs.len() > 5 {
120                                format!("external {name}: {ty} = \"{name}_bytecode\" \"{name}\"")
121                            } else {
122                                format!("external {name}: {ty} = \"{name}\"")
123                            };
124                            src.functions.push(def);
125                        });
126                    }
127                    syn::Item::Struct(item) => {
128                        let name = snake_case(&item.ident.to_string());
129                        handle(item.attrs, |ty| {
130                            let def = if ty.is_empty() {
131                                format!("type {name}")
132                            } else if !ty.trim_start().starts_with('{') {
133                                format!("type {}{name}{} = {ty}", '{', '}')
134                            } else {
135                                format!("type {name} = {ty}")
136                            };
137                            src.types.push(def);
138                        });
139                    }
140                    syn::Item::Enum(item) => {
141                        let name = snake_case(&item.ident.to_string());
142                        handle(item.attrs, |ty| {
143                            let def = if ty.is_empty() {
144                                format!("type {name}")
145                            } else {
146                                format!("type {name} = {ty}")
147                            };
148                            src.types.push(def);
149                        });
150                    }
151                    syn::Item::Type(item) => {
152                        let name = snake_case(&item.ident.to_string());
153                        handle(item.attrs, |_ty| src.types.push(format!("type {name}")));
154                    }
155                    _ => (),
156                }
157            }
158
159            if !src.functions.is_empty() || !src.types.is_empty() {
160                self.source.push(src);
161            }
162        }
163
164        Ok(())
165    }
166
167    fn generate_ml(&mut self) -> Result<(), std::io::Error> {
168        let mut f = std::fs::File::create(&self.output).unwrap();
169
170        writeln!(f, "(* Generated by ocaml-rs *)\n")?;
171        writeln!(f, "open! Bigarray")?;
172
173        for src in &self.source {
174            writeln!(
175                f,
176                "\n(* file: {} *)\n",
177                src.path.strip_prefix(&self.base_dir).unwrap().display()
178            )?;
179
180            for t in &src.types {
181                writeln!(f, "{t}")?;
182            }
183
184            for func in &src.functions {
185                writeln!(f, "{func}")?;
186            }
187        }
188
189        Ok(())
190    }
191
192    fn generate_mli(&mut self) -> Result<(), std::io::Error> {
193        let filename = self.output.with_extension("mli");
194        let mut f = std::fs::File::create(filename).unwrap();
195
196        writeln!(f, "(* Generated by ocaml-rs *)\n")?;
197        writeln!(f, "open! Bigarray")?;
198
199        for src in &self.source {
200            writeln!(
201                f,
202                "\n(* file: {} *)\n",
203                src.path.strip_prefix(&self.base_dir).unwrap().display()
204            )?;
205
206            for t in &src.types {
207                writeln!(f, "{t}")?;
208            }
209
210            for func in &src.functions {
211                writeln!(f, "{func}")?;
212            }
213        }
214
215        Ok(())
216    }
217
218    pub fn generate(mut self) -> Result<(), std::io::Error> {
219        let dir = self.base_dir.clone();
220        self.parse(&dir)?;
221
222        self.source.sort_by(|a, b| a.path.cmp(&b.path));
223        self.generate_ml()?;
224        self.generate_mli()
225    }
226}