1use crate::{
5    common::uppercase_first_letter,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use heck::SnakeCase;
11use include_dir::include_dir as include_directory;
12use phf::phf_set;
13use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
14use std::{
15    collections::BTreeMap,
16    io::{Result, Write},
17    path::PathBuf,
18};
19
20pub struct CodeGenerator<'a> {
21    config: &'a CodeGeneratorConfig,
22    libraries: Vec<String>,
23}
24
25struct OCamlEmitter<'a, T> {
26    out: IndentedWriter<T>,
27    generator: &'a CodeGenerator<'a>,
28    current_namespace: Vec<String>,
29}
30
31impl<'a> CodeGenerator<'a> {
32    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
33        if config.c_style_enums {
34            panic!("OCaml does not support generating c-style enums");
35        }
36        Self {
37            config,
38            libraries: config
39                .external_definitions
40                .keys()
41                .map(|k| k.to_string())
42                .collect::<Vec<_>>(),
43        }
44    }
45
46    pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
47        let current_namespace = self
48            .config
49            .module_name
50            .split('.')
51            .map(String::from)
52            .collect();
53        let mut emitter = OCamlEmitter {
54            out: IndentedWriter::new(out, IndentConfig::Space(2)),
55            generator: self,
56            current_namespace,
57        };
58        emitter.output_preamble()?;
59        let n = registry.len();
60        for (i, (name, format)) in registry.iter().enumerate() {
61            let first = i == 0;
62            let last = i == n - 1;
63            emitter.output_container(name, format, first, last)?;
64        }
65        for (name, _) in registry.iter() {
66            emitter.output_custom_code(name)?;
67        }
68        Ok(())
69    }
70}
71
72static KEYWORDS: phf::Set<&str> = phf_set! {
73    "and", "as", "assert", "asr",
74    "begin", "class", "constraint",
75    "do", "done", "downto", "else",
76    "end", "exception", "external",
77    "false", "for", "fun", "function",
78    "functor", "if", "in", "include",
79    "inherit", "initializer", "land",
80    "lazy", "let", "lor", "lsl",
81    "lsr", "lxor", "match", "method",
82    "mod", "module", "mutable", "new",
83    "nonrec", "object", "of", "open",
84    "or", "private", "rec", "sig",
85    "struct", "then", "to", "true",
86    "try", "type", "val", "virtual",
87    "when", "while", "with", "bool",
88    "string", "bytes", "char", "unit",
89    "option", "float", "list",
90    "int32", "int64"
91};
92
93impl<'a, T> OCamlEmitter<'a, T>
94where
95    T: Write,
96{
97    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
98        let mut path = self.current_namespace.clone();
99        path.push(name.to_string());
100        if let Some(doc) = self.generator.config.comments.get(&path) {
101            writeln!(self.out, "(*")?;
102            self.out.indent();
103            write!(self.out, "{}", doc)?;
104            self.out.unindent();
105            writeln!(self.out, "*)")?;
106        }
107        Ok(())
108    }
109
110    fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
111        let mut path = self.current_namespace.clone();
112        path.push(name.to_string());
113        if let Some(code) = self.generator.config.custom_code.get(&path) {
114            write!(self.out, "\n{}", code)?;
115        }
116        Ok(())
117    }
118
119    fn output_preamble(&mut self) -> Result<()> {
120        for namespace in self.generator.libraries.iter() {
121            if !namespace.is_empty() {
122                writeln!(self.out, "open {}", uppercase_first_letter(namespace))?
123            }
124        }
125        Ok(())
126    }
127
128    fn safe_snake_case(&self, s: &str) -> String {
129        let s = s.to_snake_case();
130        if KEYWORDS.contains(&*s) {
131            s + "_"
132        } else {
133            s
134        }
135    }
136
137    fn output_format(&mut self, format: &Format, is_struct: bool) -> Result<()> {
138        use Format::*;
139        if is_struct {
140            write!(self.out, "(")?
141        }
142        match format {
143            Variable(_) => panic!("incorrect value"),
144            TypeName(s) => write!(self.out, "{}", self.safe_snake_case(s))?,
145            Unit => write!(self.out, "unit")?,
146            Bool => write!(self.out, "bool")?,
147            I8 => write!(self.out, "Stdint.int8")?,
148            I16 => write!(self.out, "Stdint.int16")?,
149            I32 => write!(self.out, "int32")?,
150            I64 => write!(self.out, "int64")?,
151            I128 => write!(self.out, "Stdint.int128")?,
152            U8 => write!(self.out, "Stdint.uint8")?,
153            U16 => write!(self.out, "Stdint.uint16")?,
154            U32 => write!(self.out, "Stdint.uint32")?,
155            U64 => write!(self.out, "Stdint.uint64")?,
156            U128 => write!(self.out, "Stdint.uint128")?,
157            F32 => write!(self.out, "(float [@float32])")?,
158            F64 => write!(self.out, "float")?,
159            Char => write!(self.out, "char")?,
160            Str => write!(self.out, "string")?,
161            Bytes => write!(self.out, "bytes")?,
162            Option(f) => {
163                self.output_format(f, false)?;
164                write!(self.out, " option")?
165            }
166            Seq(f) => {
167                self.output_format(f, false)?;
168                write!(self.out, " list")?
169            }
170            Map { key, value } => self.output_map(key, value)?,
171            Tuple(fs) => self.output_tuple(fs, false)?,
172            TupleArray { content, size } => {
173                write!(self.out, "(")?;
174                self.output_format(content, false)?;
175                write!(self.out, " array [@length {}])", size)?
176            }
177        }
178        if is_struct {
179            write!(self.out, " [@struct])")?
180        }
181        Ok(())
182    }
183
184    fn output_map(&mut self, key: &Format, value: &Format) -> Result<()> {
185        write!(self.out, "(")?;
186        self.output_format(key, false)?;
187        write!(self.out, ", ")?;
188        self.output_format(value, false)?;
189        write!(self.out, ") Serde.map")
190    }
191
192    fn output_tuple(&mut self, formats: &[Format], is_struct: bool) -> Result<()> {
193        if is_struct {
194            write!(self.out, "(")?
195        }
196        write!(self.out, "(")?;
197        let n = formats.len();
198        formats
199            .iter()
200            .enumerate()
201            .map(|(i, f)| {
202                self.output_format(f, false)?;
203                if i != n - 1 {
204                    write!(self.out, " * ")
205                } else {
206                    Ok(())
207                }
208            })
209            .collect::<Result<Vec<_>>>()?;
210        write!(self.out, ")")?;
211        if is_struct {
212            write!(self.out, " [@struct])")?
213        }
214        Ok(())
215    }
216
217    fn output_record(&mut self, formats: &[Named<Format>]) -> Result<()> {
218        writeln!(self.out, "{{")?;
219        self.out.indent();
220        formats
221            .iter()
222            .map(|f| {
223                self.output_comment(&f.name)?;
224                write!(self.out, "{}: ", self.safe_snake_case(&f.name))?;
225                self.output_format(&f.value, false)?;
226                writeln!(self.out, ";")
227            })
228            .collect::<Result<Vec<_>>>()?;
229        self.out.unindent();
230        write!(self.out, "}}")
231    }
232
233    fn output_variant(&mut self, format: &VariantFormat) -> Result<()> {
234        use VariantFormat::*;
235        match format {
236            Variable(_) => panic!("incorrect value"),
237            Unit => Ok(()),
238            NewType(f) => {
239                write!(self.out, " of ")?;
240                self.output_format(f, false)
241            }
242            Tuple(fields) if fields.is_empty() => Ok(()),
243            Tuple(fields) => {
244                write!(self.out, " of ")?;
245                self.output_tuple(fields, false)
246            }
247            Struct(fields) if fields.is_empty() => Ok(()),
248            Struct(fields) => {
249                write!(self.out, " of ")?;
250                self.output_record(fields)
251            }
252        }
253    }
254
255    fn output_enum(
256        &mut self,
257        name: &str,
258        formats: &BTreeMap<u32, Named<VariantFormat>>,
259        cyclic: bool,
260    ) -> Result<()> {
261        writeln!(self.out)?;
262        self.out.indent();
263        let c = if cyclic { " [@cyclic]" } else { "" };
264        formats
265            .iter()
266            .map(|(_, f)| {
267                self.output_comment(&f.name)?;
268                write!(self.out, "| {}_{}", name, f.name)?;
269                self.output_variant(&f.value)?;
270                writeln!(self.out, "{}", c)
271            })
272            .collect::<Result<Vec<_>>>()?;
273        self.out.unindent();
274        Ok(())
275    }
276
277    fn is_cyclic(name: &str, format: &Format) -> bool {
278        use Format::*;
279        match format {
280            TypeName(s) => name == s,
281            Option(f) => Self::is_cyclic(name, f),
282            Seq(f) => Self::is_cyclic(name, f),
283            Map { key, value } => Self::is_cyclic(name, key) || Self::is_cyclic(name, value),
284            Tuple(fs) => fs.iter().any(|f| Self::is_cyclic(name, f)),
285            TupleArray { content, size: _ } => Self::is_cyclic(name, content),
286            _ => false,
287        }
288    }
289
290    fn output_container(
291        &mut self,
292        name: &str,
293        format: &ContainerFormat,
294        first: bool,
295        last: bool,
296    ) -> Result<()> {
297        use ContainerFormat::*;
298        self.output_comment(name)?;
299        write!(
300            self.out,
301            "{} {} =",
302            if first { "type" } else { "\nand" },
303            self.safe_snake_case(name)
304        )?;
305        match format {
306            UnitStruct => {
307                write!(self.out, " unit")?;
308                writeln!(self.out)?;
309            }
310            NewTypeStruct(format) if Self::is_cyclic(name, format.as_ref()) => {
311                let mut map = BTreeMap::new();
312                map.insert(
313                    0,
314                    Named {
315                        name: String::new(),
316                        value: VariantFormat::NewType(format.clone()),
317                    },
318                );
319                self.output_enum(&name.to_camel_case(), &map, true)?;
320            }
321            NewTypeStruct(format) => {
322                write!(self.out, " ")?;
323                self.output_format(format.as_ref(), true)?;
324                writeln!(self.out)?;
325            }
326            TupleStruct(formats) => {
327                write!(self.out, " ")?;
328                self.output_tuple(formats, true)?;
329                writeln!(self.out)?;
330            }
331            Struct(fields) => {
332                write!(self.out, " ")?;
333                self.output_record(fields)?;
334                writeln!(self.out)?;
335            }
336            Enum(variants) => {
337                self.output_enum(&name.to_camel_case(), variants, false)?;
338            }
339        }
340
341        if last && self.generator.config.serialization {
342            writeln!(self.out, "[@@deriving serde]")?;
343        }
344        Ok(())
345    }
346}
347
/// Writes generated OCaml modules and the bundled OCaml runtime sources to disk.
pub struct Installer {
    /// Root directory under which every installed file is created.
    install_dir: PathBuf,
}
351
352impl Installer {
353    pub fn new(install_dir: PathBuf) -> Self {
354        Installer { install_dir }
355    }
356
357    fn install_runtime(
358        &self,
359        source_dir: include_dir::Dir,
360        path: &str,
361    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
362        let dir_path = self.install_dir.join(path);
363        std::fs::create_dir_all(&dir_path)?;
364        for entry in source_dir.files() {
365            let mut file = std::fs::File::create(dir_path.join(entry.path()))?;
366            file.write_all(entry.contents())?;
367        }
368        Ok(())
369    }
370}
371
372impl crate::SourceInstaller for Installer {
373    type Error = Box<dyn std::error::Error>;
374
375    fn install_module(
376        &self,
377        config: &CodeGeneratorConfig,
378        registry: &Registry,
379    ) -> std::result::Result<(), Self::Error> {
380        let dir_path = self.install_dir.join(&config.module_name);
381        std::fs::create_dir_all(&dir_path)?;
382        let dune_project_source_path = self.install_dir.join("dune-project");
383        let mut dune_project_file = std::fs::File::create(dune_project_source_path)?;
384        writeln!(dune_project_file, "(lang dune 3.0)")?;
385        let name = config.module_name.to_snake_case();
386
387        if config.package_manifest {
388            let dune_source_path = dir_path.join("dune");
389            let mut dune_file = std::fs::File::create(dune_source_path)?;
390            let mut runtime_str = "";
391            if config.encodings.len() == 1 {
392                for enc in config.encodings.iter() {
393                    match enc {
394                        Encoding::Bcs => runtime_str = "\n(libraries bcs_runtime)",
395                        Encoding::Bincode => runtime_str = "\n(libraries bincode_runtime)",
396                    }
397                }
398            }
399            writeln!(
400                dune_file,
401                "(env (_ (flags (:standard -w -30-42 -warn-error -a))))\n\n\
402                (library\n (name {0})\n (modules {0})\n (preprocess (pps ppx)){1})",
403                name, runtime_str
404            )?;
405        }
406
407        let source_path = dir_path.join(format!("{}.ml", name));
408        let mut file = std::fs::File::create(source_path)?;
409        let generator = CodeGenerator::new(config);
410        generator.output(&mut file, registry)?;
411        Ok(())
412    }
413
414    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
415        self.install_runtime(include_directory!("runtime/ocaml/common"), "common")?;
416        self.install_runtime(include_directory!("runtime/ocaml/virtual"), "virtual")?;
417        self.install_runtime(include_directory!("runtime/ocaml/ppx"), "ppx")?;
418        self.install_runtime(include_directory!("runtime/ocaml/serde"), "serde")
419    }
420
421    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
422        self.install_runtime(include_directory!("runtime/ocaml/common"), "common")?;
423        self.install_runtime(include_directory!("runtime/ocaml/virtual"), "virtual")?;
424        self.install_runtime(include_directory!("runtime/ocaml/ppx"), "ppx")?;
425        self.install_runtime(include_directory!("runtime/ocaml/serde"), "serde")?;
426        self.install_runtime(include_directory!("runtime/ocaml/bincode"), "bincode")
427    }
428
429    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
430        self.install_runtime(include_directory!("runtime/ocaml/common"), "common")?;
431        self.install_runtime(include_directory!("runtime/ocaml/virtual"), "virtual")?;
432        self.install_runtime(include_directory!("runtime/ocaml/ppx"), "ppx")?;
433        self.install_runtime(include_directory!("runtime/ocaml/serde"), "serde")?;
434        self.install_runtime(include_directory!("runtime/ocaml/bcs"), "bcs")
435    }
436}