open Base
open Stdio
let excluded_functions =
Set.of_list
(module String)
[ "multi_margin_loss"
; "multi_margin_loss_out"
; "log_softmax_backward_data"
; "softmax_backward_data"
; "clone"
; "copy_"
; "conv_transpose2d_backward_out"
; "conv_transpose3d_backward_out"
; "slow_conv_transpose2d_backward_out"
; "slow_conv_transpose3d_backward_out"
; "slow_conv3d_backward_out"
; "normal"
; "_cufft_set_plan_cache_max_size"
; "_cufft_clear_plan_cache"
; "backward"
; "_amp_non_finite_check_and_unscale_"
; "_cummin_helper"
; "_cummax_helper"
; "retain_grad"
; "_validate_sparse_coo_tensor_args"
; "_backward"
; "size"
; "stride"
; "_assert_async"
; "gradient"
; "linalg_vector_norm"
; "linalg_vector_norm_out"
; "linalg_matrix_norm"
; "linalg_matrix_norm_out"
]
let no_tensor_options =
Set.of_list
(module String)
[ "zeros_like"
; "empty_like"
; "full_like"
; "ones_like"
; "rand_like"
; "randint_like"
; "randn_like"
]
let prefixed_functions =
Set.of_list
(module String)
[ "add"
; "add_"
; "div"
; "div_"
; "mul"
; "mul_"
; "sub"
; "sub_"
; "nll_loss"
; "to_mkldnn"
]
let excluded_prefixes = [ "_thnn_"; "_th_"; "thnn_"; "th_"; "_foreach"; "_amp_foreach" ]
let excluded_suffixes = [ "_forward"; "_forward_out" ]
let yaml_error yaml ~msg = Printf.failwithf "%s, %s" msg (Yaml.to_string_exn yaml) ()
let extract_bool = function
| `Bool b -> b
| `String -> true
| `String -> false
| yaml -> yaml_error yaml ~msg:"expected bool"
let extract_list = function
| `A l -> l
| yaml -> yaml_error yaml ~msg:"expected list"
let extract_map = function
| `O map -> Map.of_alist_exn (module String) map
| yaml -> yaml_error yaml ~msg:"expected map"
let extract_string = function
| `String s -> s
| `Bool b -> if b then "y" else "n"
| `Float f -> Float.to_string f
| yaml -> yaml_error yaml ~msg:"expected string"
module Func = struct
type arg_type =
| Bool
| Int64
| Int64Option
| Double
| DoubleOption
| Tensor
| TensorOption
| IntList
| IntListOption
| DoubleList
| TensorOptList
| TensorList
| TensorOptions
| Scalar
| ScalarType
| Device
| String
type arg =
{ arg_name : string
; arg_type : arg_type
; default_value : string option
}
type t =
{ name : string
; operator_name : string
; overload_name : string
; args : arg list
; returns : [ `fixed of int | `dynamic | `bool | `int64_t | `double | `nothing ]
;
kind : [ `function_ | `method_ ]
}
let arg_type_of_string str ~is_nullable =
match String.lowercase str with
| -> Some Bool
| -> Some (if is_nullable then Int64Option else Int64)
| -> Some (if is_nullable then DoubleOption else Double)
| -> Some (if is_nullable then TensorOption else Tensor)
| -> Some TensorOptions
| -> Some (if is_nullable then IntListOption else IntList)
| -> Some DoubleList
| -> Some TensorOptList
| -> Some TensorList
| -> Some Device
| | -> Some Scalar
| -> Some ScalarType
| -> Some String
| _ -> None
let c_typed_args_list t =
List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
match arg_type with
| IntList | IntListOption ->
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
| TensorOptList | TensorList ->
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
| otherwise ->
let simple_type_cstring =
match otherwise with
| Bool -> "int"
| Int64 -> "int64_t"
| Double -> "double"
| Tensor -> "tensor"
| TensorOption -> "tensor"
| ScalarType -> "int"
| Device -> "int"
| Scalar -> "scalar"
| Int64Option
| DoubleOption
| String
| IntList
| IntListOption
| DoubleList
| TensorOptList
| TensorList
| TensorOptions -> assert false
in
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|> String.concat ~sep:", "
let c_args_list args =
List.map args ~f:(fun { arg_name; arg_type; _ } ->
match arg_type with
| Scalar | Tensor -> "*" ^ arg_name
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
| Bool -> "(bool)" ^ arg_name
| IntList ->
Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
| IntListOption ->
Printf.sprintf
"%s_data == nullptr ? c10::nullopt : \
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
arg_name
arg_name
arg_name
| DoubleList ->
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
| TensorOptList ->
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
| TensorList ->
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
| TensorOptions ->
Printf.sprintf
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
arg_name
arg_name
| Int64Option ->
Printf.sprintf
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
arg_name
arg_name
| DoubleOption ->
Printf.sprintf
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
arg_name
arg_name
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
| _ -> arg_name)
|> String.concat ~sep:", "
let c_call t =
match t.kind with
| `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
| `method_ ->
(match t.args with
| head :: tail ->
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
| [] ->
Printf.failwithf "Method calls should have at least one argument %s" t.name ())
let replace_map =
Map.of_alist_exn
(module String)
[ "t", "tr"
; "where", "where_"
; "view", "view_"
; "unsafe", "unsafe_"
; "to_device", "to_device_"
]
let rust_name name =
let name =
Map.find replace_map name
|> Option.value ~default:name
|> String.lowercase
|> String.substr_replace_all ~pattern:"__" ~with_:"_"
in
if String.is_prefix name ~prefix:"_" then "internal" ^ name else name
let c_rust_args_list t =
List.map t.args ~f:(fun arg ->
let an = arg.arg_name in
let single_param = Printf.sprintf "%s_: %s" an in
match arg.arg_type with
| Bool -> single_param "c_int"
| Int64 -> single_param "i64"
| Double -> single_param "f64"
| Tensor -> single_param "*mut C_tensor"
| TensorOption -> single_param "*mut C_tensor"
| Scalar -> single_param "*mut C_scalar"
| ScalarType -> single_param "c_int"
| Device -> single_param "c_int"
| String -> Printf.sprintf "%s_ptr: *const u8, %s_len: c_int" an an
| IntList | IntListOption ->
Printf.sprintf "%s_data: *const i64, %s_len: c_int" an an
| DoubleList -> Printf.sprintf "%s_data: *const f64, %s_len: c_int" an an
| TensorOptList ->
Printf.sprintf "%s_data: *const *mut C_tensor, %s_len: c_int" an an
| TensorList ->
Printf.sprintf "%s_data: *const *mut C_tensor, %s_len: c_int" an an
| Int64Option -> Printf.sprintf "%s_v: i64, %s_null: i8" an an
| DoubleOption -> Printf.sprintf "%s_v: f64, %s_null: i8" an an
| TensorOptions -> Printf.sprintf "%s_kind: c_int, %s_device: c_int" an an)
|> String.concat ~sep:", "
let self_name = "self"
let input_name = "input"
let self_tensor arg =
match arg.arg_type with
| Tensor -> String.( = ) arg.arg_name self_name
| _ -> false
let input_tensor arg =
match arg.arg_type with
| Tensor -> String.( = ) arg.arg_name input_name
| _ -> false
let type_parameters t =
let needs_scalar_parameter =
List.exists t.args ~f:(fun arg ->
match arg.arg_type with
| Scalar -> true
| _ -> false)
in
let needs_type_parameter =
List.exists t.args ~f:(fun arg ->
match arg.arg_type with
| TensorOptList | TensorList | TensorOption -> true
| _ -> false)
in
let needs_lifetime_parameter =
List.exists t.args ~f:(fun arg ->
match arg.arg_type with
| IntListOption -> true
| _ -> false)
in
let type_parameter = if needs_type_parameter then [ "T: Borrow<Tensor>" ] else [] in
let scalar_parameter = if needs_scalar_parameter then [ "S: Into<Scalar>" ] else [] in
let lifetime_parameter = if needs_lifetime_parameter then [ "'a" ] else [] in
let parameters = lifetime_parameter @ type_parameter @ scalar_parameter in
match parameters with
| [] -> ""
| p -> "<" ^ String.concat p ~sep:", " ^ ">"
let rust_args_list t =
match List.partition_tf t.args ~f:self_tensor with
| [ self ], args_list -> Some self, args_list
| _, _ ->
(match List.partition_tf t.args ~f:input_tensor with
| [ self ], args_list -> Some self, args_list
| _, _ -> None, t.args)
let rust_typed_args_list t =
let to_string args =
List.map args ~f:(fun arg ->
let rust_arg_type =
match arg.arg_type with
| Bool -> "bool"
| Int64 ->
if String.( = ) arg.arg_name "reduction" then "crate::Reduction" else "i64"
| Double -> "f64"
| Tensor -> "&Tensor"
| TensorOption -> "Option<T>"
| IntList -> "&[i64]"
| IntListOption -> "impl Into<Option<&'a [i64]>>"
| DoubleList -> "&[f64]"
| TensorOptList -> "&[Option<T>]"
| TensorList -> "&[T]"
| String -> "&str"
| TensorOptions -> "(Kind, Device)"
| Int64Option -> "impl Into<Option<i64>>"
| DoubleOption -> "impl Into<Option<f64>>"
| Scalar -> "S"
| ScalarType -> "Kind"
| Device -> "Device"
in
Printf.sprintf "%s: %s" (rust_name arg.arg_name) rust_arg_type)
|> String.concat ~sep:", "
in
let self_arg =
if String.is_suffix t.name ~suffix:"_" || String.( = ) t.name "set_data"
then "&mut self"
else "&self"
in
match List.partition_tf t.args ~f:self_tensor with
| [ self ], args_list ->
Some self.arg_name, Printf.sprintf "%s, %s" self_arg (to_string args_list)
| _, _ ->
(match List.partition_tf t.args ~f:input_tensor with
| [ self ], args_list ->
Some self.arg_name, Printf.sprintf "%s, %s" self_arg (to_string args_list)
| _, _ -> None, to_string t.args)
let rust_return_type t ~fallible =
let returns =
match t.returns with
| `nothing -> None
| `fixed 1 -> Some "Tensor"
| `fixed v ->
List.init v ~f:(fun _ -> "Tensor")
|> String.concat ~sep:", "
|> Printf.sprintf "(%s)"
|> Option.some
| `dynamic -> Some "Vec<Tensor>"
| `bool -> Some "bool"
| `int64_t -> Some "i64"
| `double -> Some "f64"
in
match returns with
| None -> if fallible then Printf.sprintf " -> Result<(), TchError>" else ""
| Some returns ->
if fallible
then Printf.sprintf " -> Result<%s, TchError>" returns
else Printf.sprintf " -> %s" returns
let rust_binding_args t ~self =
List.map t.args ~f:(fun arg ->
let name =
if Option.value_map self ~default:false ~f:(String.( = ) arg.arg_name)
then "self"
else rust_name arg.arg_name
in
match arg.arg_type with
| Tensor -> Printf.sprintf "%s.c_tensor" name
| Scalar -> Printf.sprintf "%s.into().c_scalar" name
| Bool -> Printf.sprintf "if %s { 1 } else { 0 }" name
| ScalarType -> Printf.sprintf "%s.c_int()" name
| Device -> Printf.sprintf "%s.c_int()" name
| TensorOptions -> Printf.sprintf "%s.0.c_int(), %s.1.c_int()" name name
| Int64Option -> Printf.sprintf "%s.unwrap_or(0i64), %s.is_none() as i8" name name
| DoubleOption ->
Printf.sprintf "%s.unwrap_or(std::f64::NAN), %s.is_none() as i8" name name
| String -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
| IntList -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
| IntListOption ->
Printf.sprintf
"%s.as_ref().map_or(std::ptr::null_mut(), |t| t.as_ptr()), \
%s.as_ref().map_or(-1, |t| t.len() as i32)"
name
name
| DoubleList -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
| TensorOptList ->
Printf.sprintf "ptr_list_opt(%s).as_ptr(), %s.len() as i32" name name
| TensorList -> Printf.sprintf "ptr_list(%s).as_ptr(), %s.len() as i32" name name
| TensorOption ->
Printf.sprintf
"%s.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor)"
name
| Int64 when String.( = ) name "reduction" -> "reduction.to_int()"
| _ -> name)
|> String.concat ~sep:",\n "
end
exception Not_a_simple_arg
let read_yaml filename =
let funcs =
In_channel.with_file filename ~f:In_channel.input_lines
|> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-')
|> List.concat_map ~f:(fun lines ->
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
in
printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
List.filter_map funcs ~f:(fun yaml ->
let map = extract_map yaml in
let name = Map.find_exn map "name" |> extract_string in
let operator_name = Map.find_exn map "operator_name" |> extract_string in
let overload_name = Map.find_exn map "overload_name" |> extract_string in
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
let method_of =
Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string
in
let arguments = Map.find_exn map "arguments" |> extract_list in
let returns =
let is_tensor returns =
let returns = extract_map returns in
let return_type = Map.find_exn returns "dynamic_type" |> extract_string in
String.( = ) return_type "at::Tensor"
in
let returns = Map.find_exn map "returns" |> extract_list in
if List.is_empty returns
then Some `nothing
else if List.for_all returns ~f:is_tensor
then Some (`fixed (List.length returns))
else (
match returns with
| [ returns ] ->
let return_type =
Map.find_exn (extract_map returns) "dynamic_type" |> extract_string
in
(match return_type with
| -> Some `bool
| -> Some `int64_t
| -> Some `double
| "at::TensorList" | "dynamic_type: const c10::List<c10::optional<Tensor>> &"
-> Some `dynamic
| _ -> None)
| [] | _ :: _ :: _ -> None)
in
let kind =
if List.exists method_of ~f:(String.( = ) "namespace")
then Some `function_
else if List.exists method_of ~f:(String.( = ) "Tensor")
then Some `method_
else None
in
if (not deprecated)
&& (not
(List.exists excluded_prefixes ~f:(fun prefix ->
String.is_prefix name ~prefix)))
&& (not
(List.exists excluded_suffixes ~f:(fun suffix ->
String.is_suffix name ~suffix)))
&& not (Set.mem excluded_functions name)
then
Option.both returns kind
|> Option.bind ~f:(fun (returns, kind) ->
try
let args =
List.filter_map arguments ~f:(fun arg ->
let arg = extract_map arg in
let arg_name = Map.find_exn arg "name" |> extract_string in
let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in
let is_nullable =
Map.find arg "is_nullable"
|> Option.value_map ~default:false ~f:extract_bool
in
let default_value =
Map.find arg "default" |> Option.map ~f:extract_string
in
match Func.arg_type_of_string arg_type ~is_nullable with
| Some Scalar when Option.is_some default_value && not is_nullable
-> None
| Some TensorOptions
when Option.is_some default_value
&& Set.mem no_tensor_options name -> None
| Some arg_type ->
let arg_name =
match arg_name, arg_type with
| , Scalar -> "self_scalar"
| _, _ -> arg_name
in
Some { Func.arg_name; arg_type; default_value }
| None ->
if Option.is_some default_value
then None
else raise Not_a_simple_arg)
in
Some { Func.name; operator_name; overload_name; args; returns; kind }
with
| Not_a_simple_arg -> None)
else None)
let p out_channel s =
Printf.ksprintf
(fun line ->
Out_channel.output_string out_channel line;
Out_channel.output_char out_channel '\n')
s
let write_cpp funcs filename =
Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
let pc s = p out_cpp s in
let ph s = p out_h s in
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
pc "";
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
ph "";
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
let c_typed_args_list = Func.c_typed_args_list func in
match func.returns with
| `nothing ->
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " %s;" (Func.c_call func);
pc " )";
pc "}";
pc "";
ph "void atg_%s(%s);" exported_name c_typed_args_list
| `fixed ntensors ->
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " auto outputs__ = %s;" (Func.c_call func);
if ntensors = 1
then pc " out__[0] = new torch::Tensor(outputs__);"
else
for i = 0 to ntensors - 1 do
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
done;
pc " )";
pc "}";
pc "";
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
| `dynamic ->
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " auto outputs__ = %s;" (Func.c_call func);
pc " int sz = outputs__.size();";
pc
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
sizeof(torch::Tensor*));";
pc " for (int i = 0; i < sz; ++i)";
pc " out__[i] = new torch::Tensor(outputs__[i]);";
pc " out__[sz] = nullptr;";
pc " return out__;";
pc " )";
pc " return nullptr;";
pc "}";
pc "";
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
| (`bool | `int64_t | `double) as returns ->
let c_type =
match returns with
| `bool -> "int"
| `int64_t -> "int64_t"
| `double -> "double"
in
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
pc " PROTECT(";
pc " return %s;" (Func.c_call func);
pc " )";
pc " return 0;";
pc "}";
pc "";
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
let write_fallible_wrapper funcs filename =
Out_channel.with_file filename ~f:(fun out_ml ->
let pm s = p out_ml s in
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
pm "#[allow(clippy::all)]";
pm "use torch_sys::*;";
pm "use torch_sys::c_generated::*;";
pm "use crate::{Device, Kind, Scalar, TchError, Tensor};";
pm "use std::convert::Into;";
pm "use std::borrow::Borrow;";
pm "";
pm "fn ptr_list_opt<T: Borrow<Tensor>>(l: &[Option<T>]) -> Vec<*mut C_tensor> {";
pm
" l.iter().map(|x| x.as_ref().map_or(std::ptr::null_mut(), |x| \
x.borrow().c_tensor)).collect()";
pm "}";
pm "";
pm "fn ptr_list<T: Borrow<Tensor>>(l: &[T]) -> Vec<*mut C_tensor> {";
pm " l.iter().map(|x| x.borrow().c_tensor).collect()";
pm "}";
pm "";
pm "impl Tensor {";
Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
let rust_name = Func.rust_name exported_name in
let self, rust_args_list = Func.rust_typed_args_list func in
pm "";
pm " pub fn f_%s%s(" rust_name (Func.type_parameters func);
pm " %s" rust_args_list;
pm " )%s {" (Func.rust_return_type func ~fallible:true);
List.iter func.args ~f:(fun arg ->
match arg.arg_type with
| DoubleOption | Int64Option | IntListOption ->
pm " let %s = %s.into();" arg.arg_name arg.arg_name
| _ -> ());
match func.returns with
| `dynamic ->
pm " let c_tensors = unsafe_torch_err!(";
pm " atg_%s(" exported_name;
pm " %s));" (Func.rust_binding_args func ~self);
pm " let mut r__ = vec![];";
pm " let mut i = 0;";
pm " loop {";
pm " let c__ = unsafe{*c_tensors.add(i)};";
pm " if c__.is_null() { break }";
pm " r__.push(Tensor {c_tensor: c__});";
pm " i += 1;";
pm " }";
pm " unsafe{libc::free(c_tensors as *mut libc::c_void)}";
pm " Ok(r__)";
pm " }"
| `nothing ->
pm " unsafe_torch_err!(";
pm " atg_%s(" exported_name;
pm " %s" (Func.rust_binding_args func ~self);
pm " ));";
pm " Ok(())";
pm " }"
| `fixed ntensors ->
pm " let mut c_tensors = [std::ptr::null_mut(); %d];" ntensors;
pm " unsafe_torch_err!(";
pm " atg_%s(c_tensors.as_mut_ptr()," exported_name;
pm " %s" (Func.rust_binding_args func ~self);
pm " ));";
let returns =
if ntensors = 1
then "Tensor { c_tensor: c_tensors[0] }"
else
List.init
ntensors
~f:(Printf.sprintf "Tensor { c_tensor: c_tensors[%d] }")
|> String.concat ~sep:", "
|> Printf.sprintf "(%s)"
in
pm " Ok(%s)" returns;
pm " }"
| (`bool | `int64_t | `double) as returns ->
let is_bool =
match returns with
| `bool -> true
| `int64_t | `double -> false
in
pm " let return_;";
pm " unsafe_torch_err!(";
pm " return_ = atg_%s(" exported_name;
pm " %s" (Func.rust_binding_args func ~self);
pm " ));";
let return_ = if is_bool then "return_ != 0" else "return_" in
pm " Ok(%s)" return_;
pm " }");
pm "}")
let write_wrapper funcs filename =
Out_channel.with_file filename ~f:(fun out_ml ->
let pm s = p out_ml s in
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
pm "#[allow(clippy::all)]";
pm "use crate::{Device, Kind, Scalar, Tensor};";
pm "use std::convert::Into;";
pm "use std::borrow::Borrow;";
pm "";
pm "impl Tensor {";
Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
let rust_name = Func.rust_name exported_name in
let rust_name, fallible_rust_name =
if Set.mem prefixed_functions func.name
then "g_" ^ rust_name, "f_" ^ rust_name
else rust_name, "f_" ^ rust_name
in
pm "";
pm " pub fn %s%s(" rust_name (Func.type_parameters func);
let _self, rust_args_list = Func.rust_typed_args_list func in
pm " %s" rust_args_list;
pm " )%s {" (Func.rust_return_type func ~fallible:false);
let self, rust_args_list = Func.rust_args_list func in
let self = if Option.is_some self then "self." else "Tensor::" in
let rust_args_list =
List.map rust_args_list ~f:(fun arg -> Func.rust_name arg.Func.arg_name)
|> String.concat ~sep:", "
in
pm " %s%s(%s).unwrap()" self fallible_rust_name rust_args_list;
pm " }");
pm "}")
let write_ffi funcs filename =
Out_channel.with_file filename ~f:(fun out_ml ->
let pm s = p out_ml s in
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
pm "#[allow(clippy::all)]";
pm "use crate::{C_scalar, C_tensor};";
pm "use libc::c_int;";
pm "";
pm "extern \"C\" {";
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
match func.Func.returns with
| `nothing ->
pm " pub fn atg_%s(%s);" exported_name (Func.c_rust_args_list func)
| `fixed _ ->
pm
" pub fn atg_%s(out__: *mut *mut C_tensor, %s);"
exported_name
(Func.c_rust_args_list func)
| `dynamic ->
pm
" pub fn atg_%s(%s) -> *mut *mut C_tensor;"
exported_name
(Func.c_rust_args_list func)
| (`bool | `int64_t | `double) as returns ->
let rust_type =
match returns with
| `bool -> "c_int"
| `int64_t -> "i64"
| `double -> "f64"
in
pm
" pub fn atg_%s(%s) -> %s;"
exported_name
(Func.c_rust_args_list func)
rust_type);
pm "}")
let methods =
let c name args =
{ Func.name
; operator_name = name
; overload_name = ""
; args
; returns = `fixed 1
; kind = `method_
}
in
let ca arg_name arg_type = { Func.arg_name; arg_type; default_value = None } in
[ c "grad" [ ca "self" Tensor ]
; c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ]
; c "toType" [ ca "self" Tensor; ca "scalar_type" ScalarType ]
; c "to" [ ca "self" Tensor; ca "device" Device ]
]
let run
~yaml_filename
~cpp_filename
~ffi_filename
~wrapper_filename
~fallible_wrapper_filename
=
let funcs = read_yaml yaml_filename in
let funcs = methods @ funcs in
printf "Generating code for %d functions.\n%!" (List.length funcs);
let funcs =
List.map funcs ~f:(fun func -> String.lowercase func.operator_name, func)
|> Map.of_alist_multi (module String)
|> Map.to_alist
|> List.concat_map ~f:(fun (name, funcs) ->
match funcs with
| [] -> assert false
| [ func ] -> [ name, func ]
| funcs ->
let has_empty_overload =
List.exists funcs ~f:(fun (func : Func.t) ->
String.is_empty func.overload_name)
in
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
match Int.compare (String.length f1.name) (String.length f2.name) with
| 0 -> Int.compare (List.length f1.args) (List.length f2.args)
| cmp -> cmp)
|> List.mapi ~f:(fun index (func : Func.t) ->
let operator_name = String.lowercase func.operator_name in
let overload_name = String.lowercase func.overload_name in
let name =
if String.is_empty overload_name
|| (index = 0 && not has_empty_overload)
then operator_name
else if String.is_suffix operator_name ~suffix:"_"
then operator_name ^ overload_name ^ "_"
else operator_name ^ "_" ^ overload_name
in
name, func))
|> Map.of_alist_exn (module String)
in
write_cpp funcs cpp_filename;
write_ffi funcs ffi_filename;
write_wrapper funcs wrapper_filename;
write_fallible_wrapper funcs fallible_wrapper_filename
let () =
run
~yaml_filename:"third_party/pytorch/Declarations-v1.10.0.yaml"
~cpp_filename:"torch-sys/libtch/torch_api_generated"
~ffi_filename:"torch-sys/src/c_generated.rs"
~wrapper_filename:"src/wrappers/tensor_generated.rs"
~fallible_wrapper_filename:"src/wrappers/tensor_fallible_generated.rs"