open Ppxlib
open Ast_builder.Default
(* Generated code for one type: [ser] is the serializer expression and
   [de] the deserializer expression. *)
type expr = { ser : expression; de : expression }

(* Result of generating code for a type. [is_recursive] is set when the
   code refers to a type of the declaration group being derived (this is
   what makes [str_gen] emit a recursive binding); [exprs] holds the code. *)
type ret = { is_recursive : bool; exprs : expr }
(* [lid ~loc name] is the located long identifier for [name]. *)
let lid ~loc name = Located.lident ~loc name

(* [efun ~loc pat body] builds [fun pat -> body] (no label, no default). *)
let efun ~loc pat body = pexp_fun ~loc Nolabel None pat body

(* [elet ~loc name rhs body] builds the non-recursive
   [let name = rhs in body]; partially applied, it awaits [body]. *)
let elet ~loc name rhs =
  let binding = value_binding ~loc ~pat:(pvar ~loc name) ~expr:rhs in
  pexp_let ~loc Nonrecursive [ binding ]

(* Fully qualified names of the runtime serializer / deserializer
   functions called [fn]. *)
let ser_name fn = "Serde.Serialize." ^ fn
let de_name fn = "Serde.Deserialize." ^ fn

(* [ret ~loc fn] is the non-recursive result that uses the runtime
   functions named [fn] on both sides. *)
let ret ~loc fn =
  let exprs = { ser = evar ~loc (ser_name fn); de = evar ~loc (de_name fn) } in
  { exprs; is_recursive = false }

(* [check_depth ~loc e] builds [Serde.check_depth e]. *)
let check_depth ~loc e =
  let checker = evar ~loc "Serde.check_depth" in
  eapply ~loc checker [ e ]
(* [tuple ~loc l] combines the generated code [l] of the components of a
   tuple type into code for the whole tuple.

   - [ser] is [fun (t0, t1, ...) -> Serde.Serialize.concat [ser0 t0; ...]].
   - [de] is a function of the input [b] that decodes the components in
     order, records the largest component depth in a local [depth] ref and
     returns a record whose Serde.r field is the tuple and whose
     Serde.depth field is that maximum (the tuple adds no level itself).

   The result is recursive as soon as one component is. *)
let tuple ~loc l =
  let is_recursive = List.exists (fun r -> r.is_recursive) l in
  let var i = "t" ^ string_of_int i in
  let ptuple = ppat_tuple ~loc (List.mapi (fun i _ -> pvar ~loc (var i)) l) in
  let ser =
    efun ~loc ptuple @@
    eapply ~loc (evar ~loc (ser_name "concat")) [
      elist ~loc @@ List.mapi (fun i r -> eapply ~loc r.exprs.ser [ evar ~loc (var i) ]) l
    ] in
  (* Code decoding one component from [b] and folding its depth into the
     local [depth] ref; it evaluates to the component value. *)
  let de_component r =
    elet ~loc "x" (eapply ~loc r.exprs.de [ evar ~loc "b" ]) @@
    pexp_sequence ~loc
      (eapply ~loc (evar ~loc ":=") [
        evar ~loc "depth";
        eapply ~loc (evar ~loc "max") [
          eapply ~loc (evar ~loc "!") [ evar ~loc "depth" ];
          pexp_field ~loc (evar ~loc "x") (lid ~loc "Serde.depth") ] ])
      (pexp_field ~loc (evar ~loc "x") (lid ~loc "Serde.r")) in
  (* Generate [let t0 = ... in let t1 = ... in (t0, t1, ...)]. A single
     [let ... and ...] leaves the order in which the components are read
     from [b] unspecified; nesting fixes it to left-to-right. The names
     t0, t1, ... are never referenced by the component code, so the
     nesting cannot shadow anything it uses. *)
  let r_de =
    List.fold_right
      (fun (i, r) body -> elet ~loc (var i) (de_component r) body)
      (List.mapi (fun i r -> (i, r)) l)
      (pexp_tuple ~loc (List.mapi (fun i _ -> evar ~loc (var i)) l)) in
  let de =
    efun ~loc (pvar ~loc "b") @@
    elet ~loc "depth" (eapply ~loc (evar ~loc "ref") [ eint ~loc 0 ]) @@
    elet ~loc "r" r_de @@
    pexp_record ~loc [
      lid ~loc "Serde.depth", eapply ~loc (evar ~loc "!") [ evar ~loc "depth" ];
      lid ~loc "Serde.r", evar ~loc "r" ] None in
  { is_recursive; exprs = { ser; de } }
(* [is_struct attrs] holds when one of [attrs] is named struct; the
   attribute payload is not inspected. *)
let is_struct l =
  List.exists
    (fun attr -> match attr.attr_name.txt with "struct" -> true | _ -> false)
    l
(* [incr_depth ~loc e] builds code that binds the result of [e] to r,
   copies it with its Serde.depth field increased by one, and passes the
   copy to Serde.check_depth. *)
let incr_depth ~loc e =
  let r = evar ~loc "r" in
  let bumped =
    eapply ~loc (evar ~loc "+")
      [ pexp_field ~loc r (lid ~loc "Serde.depth"); eint ~loc 1 ]
  in
  let updated = pexp_record ~loc [ (lid ~loc "Serde.depth", bumped) ] (Some r) in
  elet ~loc "r" e (check_depth ~loc updated)
(* [core ~names c] generates the serializer/deserializer pair for the core
   type [c]. [names] lists the type names of the declaration group being
   derived; [base] uses it to flag references to them as recursive.

   - A type variable v maps to the function parameters _v_ser and _v_de
     (bound by [add_params]).
   - A type constructor is resolved by [base], which also receives the
     attributes of [c].
   - A tuple is handled component-wise by [tuple].
   When [c] carries a struct attribute, both generated functions are
   additionally wrapped with [incr_depth].

   Any other kind of type (arrow, object, polymorphic variant, alias, ...)
   is rejected with a compile error located at [c]. *)
let rec core ~names c =
  let loc = c.ptyp_loc in
  let r =
    match c.ptyp_desc with
    | Ptyp_var v ->
      { is_recursive = false;
        exprs = { ser = evar ~loc ("_" ^ v ^ "_ser"); de = evar ~loc ("_" ^ v ^ "_de") } }
    | Ptyp_constr ({ txt; _ }, args) ->
      let id = Longident.name txt in
      base ~attrs:c.ptyp_attributes ~loc ~names ~id args
    | Ptyp_tuple l ->
      let l = List.map (core ~names) l in
      tuple ~loc l
    | _ ->
      (* Report the unsupported type to the user as a located error, like
         [ptype] and [add_params] do, instead of crashing the ppx with an
         assertion failure. *)
      Location.raise_errorf ~loc "type not handled"
  in
  if is_struct c.ptyp_attributes then begin
    let ser = efun ~loc (pvar ~loc "x") @@
      incr_depth ~loc (eapply ~loc r.exprs.ser [ evar ~loc "x" ]) in
    let de = efun ~loc (pvar ~loc "b") @@
      incr_depth ~loc (eapply ~loc r.exprs.de [ evar ~loc "b" ]) in
    { r with exprs = { ser; de } }
  end else r
(* [base ~attrs ~loc ~names ~id args] generates the serializer/deserializer
   pair for the type constructor named [id] (the Longident.name of the
   constructor) applied to [args]. [attrs] are the attributes of the
   enclosing core type. They are consulted for a float32 attribute and for
   an integer-payload attribute that gives a fixed length.

   NOTE(review): in this copy the type-name string literal to the left of
   each comma in the patterns below is missing, and so is the attribute name
   in the length lookup. The match therefore does not parse as written. The
   intended names cannot be recovered from this file; restore them from
   version control before building. *)
and base ?(attrs=[]) ~loc ~names ~id args = match id, args with
(* Scalar types: each maps, through [ret], to the Serde.Serialize and
   Serde.Deserialize functions of the given name. *)
| , [] | , [] -> ret ~loc "bool"
| , [] | , [] -> ret ~loc "string"
| , [] | , [] -> ret ~loc "bytes"
(* Floats default to float64; a float32 attribute selects float32. *)
| , [] | , [] ->
if List.exists (fun a -> a.attr_name.txt = "float32") attrs then ret ~loc "float32"
else ret ~loc "float64"
| , [] | , [] -> ret ~loc "char"
| , [] -> ret ~loc "unit"
| , [] | , [] | , [] -> ret ~loc "uint8"
| , [] | , [] | , [] -> ret ~loc "uint16"
| , [] | , [] | , [] -> ret ~loc "uint32"
| , [] | , [] | , [] -> ret ~loc "uint64"
| , [] | , [] | , [] -> ret ~loc "uint128"
| , [] | , [] | , [] -> ret ~loc "int8"
| , [] | , [] | , [] -> ret ~loc "int16"
| , [] | , [] | , [] -> ret ~loc "int32"
| , [] | , [] | , [] -> ret ~loc "int64"
| , [] | , [] | , [] -> ret ~loc "int128"
(* One-parameter containers: the element code is wrapped with the runtime
   option combinator, or with the variable combinator. *)
| , [ c ] | , [ c ] ->
let r = core ~names c in
{ r with exprs = {
ser = eapply ~loc (evar ~loc (ser_name "option")) [ r.exprs.ser ];
de = eapply ~loc (evar ~loc (de_name "option")) [ r.exprs.de ] } }
| , [ c ] | , [ c ] ->
let r = core ~names c in
{ r with exprs = {
ser = eapply ~loc (evar ~loc (ser_name "variable")) [ r.exprs.ser ];
de = eapply ~loc (evar ~loc (de_name "variable")) [ r.exprs.de ] } }
(* A container that may carry an attribute with an integer payload. When
   that attribute is present, the container uses the fixed combinator; the
   length is passed to the deserializer only. Without it, the container
   falls back to the variable combinator. *)
| , [ c ] | , [ c ] ->
let r = core ~names c in
let length = List.find_map (fun a -> match a.attr_name.txt, a.attr_payload with
| , PStr [{pstr_desc=Pstr_eval ({pexp_desc=Pexp_constant (Pconst_integer (s, _)); _}, _); _}] ->
Some (int_of_string s)
| _ -> None) attrs in
begin match length with
| None ->
{ r with exprs = {
ser = eapply ~loc (evar ~loc (ser_name "variable")) [ r.exprs.ser ];
de = eapply ~loc (evar ~loc (de_name "variable")) [ r.exprs.de ] } }
| Some length ->
{ r with exprs = {
ser = eapply ~loc (evar ~loc (ser_name "fixed")) [ r.exprs.ser ];
de = eapply ~loc (evar ~loc (de_name "fixed")) [ r.exprs.de; eint ~loc length ] } }
end
(* Two-parameter map-like types. The map deserializer also receives the
   key serializer (rk.exprs.ser), presumably for key ordering or
   validation. NOTE(review): confirm this against the Serde runtime. *)
| , [k; v] | , [ k; v ] | , [k; v] | , [k; v] ->
let rk = core ~names k in
let rv = core ~names v in
let is_recursive = rk.is_recursive || rv.is_recursive in
{ is_recursive; exprs = {
ser = eapply ~loc (evar ~loc (ser_name "map")) [ rk.exprs.ser; rv.exprs.ser ];
de = eapply ~loc (evar ~loc (de_name "map")) [ rk.exprs.ser; rk.exprs.de; rv.exprs.de ];
} }
(* Any other constructor is treated as a type [id] with derived functions
   id_ser and id_de, applied to the generated code of each argument. The
   result is recursive if [id] belongs to the group being derived or if any
   argument is. *)
| _ ->
let l = List.map (core ~names) args in
let is_recursive = List.mem id names || List.exists (fun r -> r.is_recursive) l in
let ser = eapply ~loc (evar ~loc (id ^ "_ser")) (List.map (fun r -> r.exprs.ser) l) in
let de = eapply ~loc (evar ~loc (id ^ "_de")) (List.map (fun r -> r.exprs.de) l) in
{ is_recursive; exprs = { ser; de } }
(* [record ~names ~loc ?constructor l] generates code for the record
   fields [l] and returns it together with an optional pattern.

   - Without [constructor] (a standalone record type), [ser] is a function
     that destructures the record, concatenates the field encodings and
     goes through [incr_depth]. [de] is a function of the input b whose
     result goes through Serde.check_depth; that result has a depth one more
     than the deepest field. The returned pattern is None.
   - With [constructor = Some id] (the inline record of constructor [id]),
     [ser] and [de] are bare expressions. [ser] expects the fields to be
     bound by name, by the returned pattern (Some p). [de] expects the input
     to be bound to b, and wraps the decoded record in [id]. The depth is
     incremented but not checked here. *)
let record ~names ~loc ?constructor l =
  let lr = List.map (fun pld -> core ~names pld.pld_type) l in
  let is_recursive = List.exists (fun r -> r.is_recursive) lr in
  let p =
    ppat_record ~loc
      (List.map (fun pld -> Located.lident ~loc pld.pld_name.txt, pvar ~loc pld.pld_name.txt) l)
      Closed in
  let fields = List.map (fun pld -> evar ~loc pld.pld_name.txt) l in
  let r_ser = eapply ~loc (evar ~loc (ser_name "concat")) [
    elist ~loc @@ List.map2 (fun (pld, f) r ->
      eapply ~loc:pld.pld_loc r.exprs.ser [ f ]) (List.combine l fields) lr ] in
  let ser = match constructor with
    | None -> efun ~loc p @@ incr_depth ~loc r_ser
    | Some _ -> r_ser in
  let r_de =
    (* Decoded field values are bound to fresh names f0, f1, ... rather
       than to the field names, so a field called b, depth or r cannot
       shadow the input or the depth counter used by later fields. *)
    let var i = "f" ^ string_of_int i in
    (* Code decoding one field from b and folding its depth into the local
       depth ref; it evaluates to the field value. *)
    let de_field r =
      elet ~loc "r" (eapply ~loc r.exprs.de [ evar ~loc "b" ]) @@
      pexp_sequence ~loc
        (eapply ~loc (evar ~loc ":=") [
          evar ~loc "depth";
          eapply ~loc (evar ~loc "max") [
            eapply ~loc (evar ~loc "!") [ evar ~loc "depth" ];
            pexp_field ~loc (evar ~loc "r") (lid ~loc "Serde.depth") ] ])
        (pexp_field ~loc (evar ~loc "r") (lid ~loc "Serde.r")) in
    let re =
      pexp_record ~loc (List.mapi (fun i pld ->
        lid ~loc pld.pld_name.txt, evar ~loc (var i)) l) None in
    let body =
      Option.fold ~none:re
        ~some:(fun id -> pexp_construct ~loc (Located.lident ~loc id) (Some re))
        constructor in
    (* One nested let per field. A single [let ... and ...] leaves the order
       in which fields are read from b unspecified; nesting fixes it to
       declaration order. *)
    List.fold_right (fun (i, r) acc -> elet ~loc (var i) (de_field r) acc)
      (List.mapi (fun i r -> (i, r)) lr) body in
  let de e =
    elet ~loc "depth" (eapply ~loc (evar ~loc "ref") [ eint ~loc 0 ]) @@
    elet ~loc "r" r_de e in
  (* The record adds one level on top of its deepest field. *)
  let de_end =
    pexp_record ~loc [
      Located.lident ~loc "Serde.depth",
      eapply ~loc (evar ~loc "+") [
        eapply ~loc (evar ~loc "!") [ evar ~loc "depth" ]; eint ~loc 1 ];
      Located.lident ~loc "Serde.r", evar ~loc "r" ] None in
  let de = match constructor with
    | None ->
      efun ~loc (pvar ~loc "b") @@
      de @@ check_depth ~loc @@ de_end
    | Some _ ->
      de de_end in
  { is_recursive; exprs = { ser; de } },
  (if Option.is_some constructor then Some p else None)
(* [constructor ~loc ~names ~id args] generates code for the arguments of
   variant constructor [id]. The result is None for a constant constructor.
   Otherwise it is the generated code paired with the field pattern that
   [record] returns; that pattern is None for non-record arguments.
   Several positional arguments are treated as one tuple. *)
let constructor ~loc ~names ~id args =
  match args with
  | Pcstr_record fields -> Some (record ~names ~loc ~constructor:id fields)
  | Pcstr_tuple [] -> None
  | Pcstr_tuple [ single ] -> Some (core ~names single, None)
  | Pcstr_tuple many -> Some (core ~names (ptyp_tuple ~loc many), None)
(* [is_cyclic l] finds the first constructor in [l] that carries a cyclic
   attribute and has exactly one positional argument. It returns the
   constructor name and the argument type. *)
let is_cyclic l =
  let has_cyclic_attr pcd =
    List.exists (fun a -> a.attr_name.txt = "cyclic") pcd.pcd_attributes in
  List.find_map
    (fun pcd ->
       match has_cyclic_attr pcd, pcd.pcd_args with
       | true, Pcstr_tuple [ x ] -> Some (pcd.pcd_name.txt, x)
       | _ -> None)
    l
(* [variant ~names ~loc l] generates the serializer/deserializer pair for
   a variant type whose constructors are [l].

   If [is_cyclic] finds a cyclic constructor with a single argument, only
   that constructor is handled. The serializer unwraps it, and the
   deserializer always rebuilds it with one added to the depth.
   NOTE(review): the generated [let C x = x] pattern is not exhaustive, so
   any other constructor of such a type would raise Match_failure at
   runtime. Presumably cyclic is meant for single-constructor wrappers;
   confirm.

   Otherwise each constructor is encoded as its index (variant_index)
   followed by the encoding of its arguments, if any. *)
let variant ~names ~loc l =
match is_cyclic l with
| Some (name, c) ->
let r = core ~names c in
let nlid = lid ~loc name in
let ser = efun ~loc (pvar ~loc "x") @@
pexp_let ~loc Nonrecursive [
value_binding ~loc ~pat:(ppat_construct ~loc nlid (Some (pvar ~loc "x"))) ~expr:(evar ~loc "x") ] @@
incr_depth ~loc (eapply ~loc r.exprs.ser [ evar ~loc "x" ]) in
let de = efun ~loc (pvar ~loc "b") @@
elet ~loc "r" (eapply ~loc r.exprs.de [ evar ~loc "b" ]) @@
check_depth ~loc @@
pexp_record ~loc [
lid ~loc "Serde.r",
pexp_construct ~loc nlid (Some (pexp_field ~loc (evar ~loc "r") (lid ~loc "Serde.r")));
lid ~loc "Serde.depth",
eapply ~loc (evar ~loc "+") [
pexp_field ~loc (evar ~loc "r") (lid ~loc "Serde.depth");
eint ~loc 1 ]
] None in
{r with exprs = {ser; de}}
| None ->
(* Per-constructor code paired with the constructor index; None for
   constant constructors. *)
let lr = List.mapi (fun i pcd ->
i, constructor ~names ~loc:pcd.pcd_loc ~id:pcd.pcd_name.txt pcd.pcd_args ) l in
let is_recursive = List.exists (fun (_, r) -> Option.fold ~none:false ~some:(fun (r, _) -> r.is_recursive) r) lr in
(* Serializer: match on the constructor, emit its index, then the
   argument code. Record constructors bind their fields by name through
   the pattern returned by [record]; other constructors bind their
   argument to x. The whole match is wrapped with [incr_depth]. *)
let ser =
efun ~loc (pvar ~loc "x") @@
incr_depth ~loc (
pexp_match ~loc (evar ~loc "x") @@ List.map2 (fun pcd (i, r) ->
let loc = pcd.pcd_loc in
let c = lid ~loc pcd.pcd_name.txt in
let p = Option.map (fun (_, p) -> Option.value ~default:(pvar ~loc "x") p) r in
case ~guard:None ~lhs:(ppat_construct ~loc c p)
~rhs:(eapply ~loc (evar ~loc (ser_name "concat")) [
elist ~loc @@
eapply ~loc (evar ~loc @@ ser_name "variant_index") [ eint ~loc i] ::
Option.fold ~none:[] ~some:(fun (r, o) ->
Option.fold ~none:[ eapply ~loc r.exprs.ser [evar ~loc "x"] ]
~some:(fun _ -> [ r.exprs.ser ]) o
) r ])) l lr) in
(* One deserializer closure per constructor, in declaration order, so
   the decoded tag can index into the list. A constant constructor reports
   a depth of 1. A constructor with positional arguments adds one to the
   depth of its argument. A record constructor reuses the code from
   [record], which already adds one. *)
let l_expr = elist ~loc @@ List.map2 (fun pcd (_, r) ->
let loc = pcd.pcd_loc in
let p = Option.fold ~none:(ppat_any ~loc) ~some:(fun _ -> pvar ~loc "b") r in
let e = match r with
| None ->
pexp_record ~loc [
lid ~loc "Serde.r",
pexp_construct ~loc (lid ~loc pcd.pcd_name.txt) None;
lid ~loc "Serde.depth", eint ~loc 1 ] None
| Some (r, None) ->
elet ~loc "r" (eapply ~loc r.exprs.de [ evar ~loc "b" ]) @@
pexp_record ~loc [
Located.lident ~loc "Serde.r", pexp_construct ~loc (Located.lident ~loc pcd.pcd_name.txt)
(Some (pexp_field ~loc (evar ~loc "r") (Located.lident ~loc "Serde.r")));
Located.lident ~loc "Serde.depth",
eapply ~loc (evar ~loc "+") [
pexp_field ~loc (evar ~loc "r") (Located.lident ~loc "Serde.depth");
eint ~loc 1 ]
] None
| Some (r, Some _) -> r.exprs.de in
efun ~loc p e) l lr in
(* Deserializer: read the tag, pick the matching closure and check the
   resulting depth; an out-of-range tag fails with a Failure exception.
   NOTE(review): List.nth_opt raises Invalid_argument rather than
   returning None for a negative index. Whether variant_index can return
   a negative tag is not visible here; confirm. The closure list is also
   rebuilt on every call. *)
let de = efun ~loc (pvar ~loc "b") @@
elet ~loc "tag" (eapply ~loc (evar ~loc @@ de_name "variant_index") [ evar ~loc "b" ]) @@
elet ~loc "l" l_expr @@
pexp_match ~loc (eapply ~loc (evar ~loc "List.nth_opt") [ evar ~loc "l"; evar ~loc "tag" ]) [
case ~guard:None ~lhs:(ppat_construct ~loc (lid ~loc "None") None)
~rhs:(eapply ~loc (evar ~loc "failwith") [estring ~loc "no case matched"]);
case ~guard:None ~lhs:(ppat_construct ~loc (lid ~loc "Some") (Some (pvar ~loc "de")))
~rhs:(check_depth ~loc (eapply ~loc (evar ~loc "de") [ evar ~loc "b" ]));
] in
{is_recursive; exprs = { ser; de } }
(* [ptype ~names t] generates code for the type declaration [t]. A variant
   or record kind takes precedence over a manifest; an abstract type needs
   a manifest. Anything else (an abstract type without manifest, an open
   type) is rejected with an error located at the declaration. *)
let ptype ~names t =
  let loc = t.ptype_loc in
  let unsupported () = Location.raise_errorf ~loc "type not handled" in
  match t.ptype_kind with
  | Ptype_variant constructors -> variant ~names ~loc constructors
  | Ptype_record fields -> record ~names ~loc fields |> fst
  | Ptype_abstract ->
    begin match t.ptype_manifest with
    | Some manifest -> core ~names manifest
    | None -> unsupported ()
    end
  | _ -> unsupported ()
(* [add_params ~loc r params] abstracts the generated code over the type
   parameters. For each parameter v, in order, the serializer gets a
   leading [fun _v_ser ->] and the deserializer a leading [fun _v_de ->].
   Parameters are checked left to right; the first one that is not a plain
   type variable is reported as an error at its location. *)
let rec add_params ~loc:_ r params =
  match params with
  | [] -> (r.exprs.ser, r.exprs.de)
  | (param, _) :: rest ->
    let loc = param.ptyp_loc in
    begin match param.ptyp_desc with
    | Ptyp_var v ->
      let ser_body, de_body = add_params ~loc r rest in
      let abstract suffix body = efun ~loc (pvar ~loc ("_" ^ v ^ suffix)) body in
      (abstract "_ser" ser_body, abstract "_de" de_body)
    | _ -> Location.raise_errorf ~loc "param not handled"
    end
(* [str_gen ~loc ~path (rec_flag, decls) debug] is the deriver entry point.
   It emits one structure item that binds name_ser and name_de for every
   declaration in the group. The binding is recursive only if some
   generated code refers to a type of the group; the rec flag of the source
   declaration is ignored. With [debug], the generated code is printed on
   standard output. *)
let str_gen ~loc ~path:_ (_rec_flag, l) debug =
  let names = List.map (fun t -> t.ptype_name.txt) l in
  let results = List.map (ptype ~names) l in
  let rec_flag =
    match List.exists (fun r -> r.is_recursive) results with
    | true -> Recursive
    | false -> Nonrecursive
  in
  let bindings_for decl r =
    let loc = decl.ptype_loc and name = decl.ptype_name.txt in
    let ser, de = add_params ~loc r decl.ptype_params in
    [ value_binding ~loc ~pat:(pvar ~loc (name ^ "_ser")) ~expr:ser;
      value_binding ~loc ~pat:(pvar ~loc (name ^ "_de")) ~expr:de ]
  in
  let bindings = List.concat (List.map2 bindings_for l results) in
  let str = [ pstr_value ~loc rec_flag bindings ] in
  if debug then Format.printf "%s@." (Pprintast.string_of_structure str);
  str
(* Register the serde deriver for type declarations. It accepts a single
   optional debug flag. *)
let () =
  let debug_flag = Deriving.Args.(empty +> flag "debug") in
  Deriving.add "serde" ~str_type_decl:(Deriving.Generator.make debug_flag str_gen)
  |> Deriving.ignore