1use crate::{
5 common::uppercase_first_letter,
6 indent::{IndentConfig, IndentedWriter},
7 CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use heck::SnakeCase;
11use include_dir::include_dir as include_directory;
12use phf::phf_set;
13use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
14use std::{
15 collections::BTreeMap,
16 io::{Result, Write},
17 path::PathBuf,
18};
19
/// OCaml code generator: writes the containers of a serde-reflection `Registry`
/// as OCaml type definitions, following a `CodeGeneratorConfig`.
pub struct CodeGenerator<'a> {
    /// Language-independent generation settings (module name, comments,
    /// custom code, serialization flag, ...).
    config: &'a CodeGeneratorConfig,
    /// Names of external modules to `open` at the top of the generated file;
    /// filled from the keys of `config.external_definitions`.
    libraries: Vec<String>,
}
24
/// Per-output state used while writing the OCaml definitions of one registry.
struct OCamlEmitter<'a, T> {
    /// Destination writer; `CodeGenerator::output` configures 2-space indentation.
    out: IndentedWriter<T>,
    /// Generator holding the configuration and the list of opened libraries.
    generator: &'a CodeGenerator<'a>,
    /// Module path (the module name split on `.`), used as the key prefix when
    /// looking up comments and custom code in the configuration.
    current_namespace: Vec<String>,
}
30
31impl<'a> CodeGenerator<'a> {
32 pub fn new(config: &'a CodeGeneratorConfig) -> Self {
33 if config.c_style_enums {
34 panic!("OCaml does not support generating c-style enums");
35 }
36 Self {
37 config,
38 libraries: config
39 .external_definitions
40 .keys()
41 .map(|k| k.to_string())
42 .collect::<Vec<_>>(),
43 }
44 }
45
46 pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
47 let current_namespace = self
48 .config
49 .module_name
50 .split('.')
51 .map(String::from)
52 .collect();
53 let mut emitter = OCamlEmitter {
54 out: IndentedWriter::new(out, IndentConfig::Space(2)),
55 generator: self,
56 current_namespace,
57 };
58 emitter.output_preamble()?;
59 let n = registry.len();
60 for (i, (name, format)) in registry.iter().enumerate() {
61 let first = i == 0;
62 let last = i == n - 1;
63 emitter.output_container(name, format, first, last)?;
64 }
65 for (name, _) in registry.iter() {
66 emitter.output_custom_code(name)?;
67 }
68 Ok(())
69 }
70}
71
72static KEYWORDS: phf::Set<&str> = phf_set! {
73 "and", "as", "assert", "asr",
74 "begin", "class", "constraint",
75 "do", "done", "downto", "else",
76 "end", "exception", "external",
77 "false", "for", "fun", "function",
78 "functor", "if", "in", "include",
79 "inherit", "initializer", "land",
80 "lazy", "let", "lor", "lsl",
81 "lsr", "lxor", "match", "method",
82 "mod", "module", "mutable", "new",
83 "nonrec", "object", "of", "open",
84 "or", "private", "rec", "sig",
85 "struct", "then", "to", "true",
86 "try", "type", "val", "virtual",
87 "when", "while", "with", "bool",
88 "string", "bytes", "char", "unit",
89 "option", "float", "list",
90 "int32", "int64"
91};
92
93impl<'a, T> OCamlEmitter<'a, T>
94where
95 T: Write,
96{
97 fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
98 let mut path = self.current_namespace.clone();
99 path.push(name.to_string());
100 if let Some(doc) = self.generator.config.comments.get(&path) {
101 writeln!(self.out, "(*")?;
102 self.out.indent();
103 write!(self.out, "{}", doc)?;
104 self.out.unindent();
105 writeln!(self.out, "*)")?;
106 }
107 Ok(())
108 }
109
110 fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
111 let mut path = self.current_namespace.clone();
112 path.push(name.to_string());
113 if let Some(code) = self.generator.config.custom_code.get(&path) {
114 write!(self.out, "\n{}", code)?;
115 }
116 Ok(())
117 }
118
119 fn output_preamble(&mut self) -> Result<()> {
120 for namespace in self.generator.libraries.iter() {
121 if !namespace.is_empty() {
122 writeln!(self.out, "open {}", uppercase_first_letter(namespace))?
123 }
124 }
125 Ok(())
126 }
127
128 fn safe_snake_case(&self, s: &str) -> String {
129 let s = s.to_snake_case();
130 if KEYWORDS.contains(&*s) {
131 s + "_"
132 } else {
133 s
134 }
135 }
136
137 fn output_format(&mut self, format: &Format, is_struct: bool) -> Result<()> {
138 use Format::*;
139 if is_struct {
140 write!(self.out, "(")?
141 }
142 match format {
143 Variable(_) => panic!("incorrect value"),
144 TypeName(s) => write!(self.out, "{}", self.safe_snake_case(s))?,
145 Unit => write!(self.out, "unit")?,
146 Bool => write!(self.out, "bool")?,
147 I8 => write!(self.out, "Stdint.int8")?,
148 I16 => write!(self.out, "Stdint.int16")?,
149 I32 => write!(self.out, "int32")?,
150 I64 => write!(self.out, "int64")?,
151 I128 => write!(self.out, "Stdint.int128")?,
152 U8 => write!(self.out, "Stdint.uint8")?,
153 U16 => write!(self.out, "Stdint.uint16")?,
154 U32 => write!(self.out, "Stdint.uint32")?,
155 U64 => write!(self.out, "Stdint.uint64")?,
156 U128 => write!(self.out, "Stdint.uint128")?,
157 F32 => write!(self.out, "(float [@float32])")?,
158 F64 => write!(self.out, "float")?,
159 Char => write!(self.out, "char")?,
160 Str => write!(self.out, "string")?,
161 Bytes => write!(self.out, "bytes")?,
162 Option(f) => {
163 self.output_format(f, false)?;
164 write!(self.out, " option")?
165 }
166 Seq(f) => {
167 self.output_format(f, false)?;
168 write!(self.out, " list")?
169 }
170 Map { key, value } => self.output_map(key, value)?,
171 Tuple(fs) => self.output_tuple(fs, false)?,
172 TupleArray { content, size } => {
173 write!(self.out, "(")?;
174 self.output_format(content, false)?;
175 write!(self.out, " array [@length {}])", size)?
176 }
177 }
178 if is_struct {
179 write!(self.out, " [@struct])")?
180 }
181 Ok(())
182 }
183
184 fn output_map(&mut self, key: &Format, value: &Format) -> Result<()> {
185 write!(self.out, "(")?;
186 self.output_format(key, false)?;
187 write!(self.out, ", ")?;
188 self.output_format(value, false)?;
189 write!(self.out, ") Serde.map")
190 }
191
192 fn output_tuple(&mut self, formats: &[Format], is_struct: bool) -> Result<()> {
193 if is_struct {
194 write!(self.out, "(")?
195 }
196 write!(self.out, "(")?;
197 let n = formats.len();
198 formats
199 .iter()
200 .enumerate()
201 .map(|(i, f)| {
202 self.output_format(f, false)?;
203 if i != n - 1 {
204 write!(self.out, " * ")
205 } else {
206 Ok(())
207 }
208 })
209 .collect::<Result<Vec<_>>>()?;
210 write!(self.out, ")")?;
211 if is_struct {
212 write!(self.out, " [@struct])")?
213 }
214 Ok(())
215 }
216
217 fn output_record(&mut self, formats: &[Named<Format>]) -> Result<()> {
218 writeln!(self.out, "{{")?;
219 self.out.indent();
220 formats
221 .iter()
222 .map(|f| {
223 self.output_comment(&f.name)?;
224 write!(self.out, "{}: ", self.safe_snake_case(&f.name))?;
225 self.output_format(&f.value, false)?;
226 writeln!(self.out, ";")
227 })
228 .collect::<Result<Vec<_>>>()?;
229 self.out.unindent();
230 write!(self.out, "}}")
231 }
232
233 fn output_variant(&mut self, format: &VariantFormat) -> Result<()> {
234 use VariantFormat::*;
235 match format {
236 Variable(_) => panic!("incorrect value"),
237 Unit => Ok(()),
238 NewType(f) => {
239 write!(self.out, " of ")?;
240 self.output_format(f, false)
241 }
242 Tuple(fields) if fields.is_empty() => Ok(()),
243 Tuple(fields) => {
244 write!(self.out, " of ")?;
245 self.output_tuple(fields, false)
246 }
247 Struct(fields) if fields.is_empty() => Ok(()),
248 Struct(fields) => {
249 write!(self.out, " of ")?;
250 self.output_record(fields)
251 }
252 }
253 }
254
255 fn output_enum(
256 &mut self,
257 name: &str,
258 formats: &BTreeMap<u32, Named<VariantFormat>>,
259 cyclic: bool,
260 ) -> Result<()> {
261 writeln!(self.out)?;
262 self.out.indent();
263 let c = if cyclic { " [@cyclic]" } else { "" };
264 formats
265 .iter()
266 .map(|(_, f)| {
267 self.output_comment(&f.name)?;
268 write!(self.out, "| {}_{}", name, f.name)?;
269 self.output_variant(&f.value)?;
270 writeln!(self.out, "{}", c)
271 })
272 .collect::<Result<Vec<_>>>()?;
273 self.out.unindent();
274 Ok(())
275 }
276
277 fn is_cyclic(name: &str, format: &Format) -> bool {
278 use Format::*;
279 match format {
280 TypeName(s) => name == s,
281 Option(f) => Self::is_cyclic(name, f),
282 Seq(f) => Self::is_cyclic(name, f),
283 Map { key, value } => Self::is_cyclic(name, key) || Self::is_cyclic(name, value),
284 Tuple(fs) => fs.iter().any(|f| Self::is_cyclic(name, f)),
285 TupleArray { content, size: _ } => Self::is_cyclic(name, content),
286 _ => false,
287 }
288 }
289
290 fn output_container(
291 &mut self,
292 name: &str,
293 format: &ContainerFormat,
294 first: bool,
295 last: bool,
296 ) -> Result<()> {
297 use ContainerFormat::*;
298 self.output_comment(name)?;
299 write!(
300 self.out,
301 "{} {} =",
302 if first { "type" } else { "\nand" },
303 self.safe_snake_case(name)
304 )?;
305 match format {
306 UnitStruct => {
307 write!(self.out, " unit")?;
308 writeln!(self.out)?;
309 }
310 NewTypeStruct(format) if Self::is_cyclic(name, format.as_ref()) => {
311 let mut map = BTreeMap::new();
312 map.insert(
313 0,
314 Named {
315 name: String::new(),
316 value: VariantFormat::NewType(format.clone()),
317 },
318 );
319 self.output_enum(&name.to_camel_case(), &map, true)?;
320 }
321 NewTypeStruct(format) => {
322 write!(self.out, " ")?;
323 self.output_format(format.as_ref(), true)?;
324 writeln!(self.out)?;
325 }
326 TupleStruct(formats) => {
327 write!(self.out, " ")?;
328 self.output_tuple(formats, true)?;
329 writeln!(self.out)?;
330 }
331 Struct(fields) => {
332 write!(self.out, " ")?;
333 self.output_record(fields)?;
334 writeln!(self.out)?;
335 }
336 Enum(variants) => {
337 self.output_enum(&name.to_camel_case(), variants, false)?;
338 }
339 }
340
341 if last && self.generator.config.serialization {
342 writeln!(self.out, "[@@deriving serde]")?;
343 }
344 Ok(())
345 }
346}
347
/// Installs generated OCaml modules and the OCaml runtime sources as a dune project.
pub struct Installer {
    /// Root directory; generated modules, runtime directories and the
    /// `dune-project` file are all written below it.
    install_dir: PathBuf,
}
351
352impl Installer {
353 pub fn new(install_dir: PathBuf) -> Self {
354 Installer { install_dir }
355 }
356
357 fn install_runtime(
358 &self,
359 source_dir: include_dir::Dir,
360 path: &str,
361 ) -> std::result::Result<(), Box<dyn std::error::Error>> {
362 let dir_path = self.install_dir.join(path);
363 std::fs::create_dir_all(&dir_path)?;
364 for entry in source_dir.files() {
365 let mut file = std::fs::File::create(dir_path.join(entry.path()))?;
366 file.write_all(entry.contents())?;
367 }
368 Ok(())
369 }
370}
371
372impl crate::SourceInstaller for Installer {
373 type Error = Box<dyn std::error::Error>;
374
375 fn install_module(
376 &self,
377 config: &CodeGeneratorConfig,
378 registry: &Registry,
379 ) -> std::result::Result<(), Self::Error> {
380 let dir_path = self.install_dir.join(&config.module_name);
381 std::fs::create_dir_all(&dir_path)?;
382 let dune_project_source_path = self.install_dir.join("dune-project");
383 let mut dune_project_file = std::fs::File::create(dune_project_source_path)?;
384 writeln!(dune_project_file, "(lang dune 3.0)")?;
385 let name = config.module_name.to_snake_case();
386
387 if config.package_manifest {
388 let dune_source_path = dir_path.join("dune");
389 let mut dune_file = std::fs::File::create(dune_source_path)?;
390 let mut runtime_str = "";
391 if config.encodings.len() == 1 {
392 for enc in config.encodings.iter() {
393 match enc {
394 Encoding::Bcs => runtime_str = "\n(libraries bcs_runtime)",
395 Encoding::Bincode => runtime_str = "\n(libraries bincode_runtime)",
396 }
397 }
398 }
399 writeln!(
400 dune_file,
401 "(env (_ (flags (:standard -w -30-42 -warn-error -a))))\n\n\
402 (library\n (name {0})\n (modules {0})\n (preprocess (pps ppx)){1})",
403 name, runtime_str
404 )?;
405 }
406
407 let source_path = dir_path.join(format!("{}.ml", name));
408 let mut file = std::fs::File::create(source_path)?;
409 let generator = CodeGenerator::new(config);
410 generator.output(&mut file, registry)?;
411 Ok(())
412 }
413
414 fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
415 self.install_runtime(include_directory!("runtime/ocaml/common"), "common")?;
416 self.install_runtime(include_directory!("runtime/ocaml/virtual"), "virtual")?;
417 self.install_runtime(include_directory!("runtime/ocaml/ppx"), "ppx")?;
418 self.install_runtime(include_directory!("runtime/ocaml/serde"), "serde")
419 }
420
421 fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
422 self.install_runtime(include_directory!("runtime/ocaml/common"), "common")?;
423 self.install_runtime(include_directory!("runtime/ocaml/virtual"), "virtual")?;
424 self.install_runtime(include_directory!("runtime/ocaml/ppx"), "ppx")?;
425 self.install_runtime(include_directory!("runtime/ocaml/serde"), "serde")?;
426 self.install_runtime(include_directory!("runtime/ocaml/bincode"), "bincode")
427 }
428
429 fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
430 self.install_runtime(include_directory!("runtime/ocaml/common"), "common")?;
431 self.install_runtime(include_directory!("runtime/ocaml/virtual"), "virtual")?;
432 self.install_runtime(include_directory!("runtime/ocaml/ppx"), "ppx")?;
433 self.install_runtime(include_directory!("runtime/ocaml/serde"), "serde")?;
434 self.install_runtime(include_directory!("runtime/ocaml/bcs"), "bcs")
435 }
436}