tch 0.0.1

PyTorch wrappers for rust
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
(* Automatically generate the C++ -> C -> Rust bindings.
   This takes as input the Declarations.yaml file that gets generated when
   building PyTorch from source (it is read from data/Declarations.yaml).

   Run with: dune exec gen/gen.exe
 *)
open Base
open Stdio

(* Exact function names that [read_yaml] drops: no binding is generated for
   any declaration whose name is a member of this set. *)
let excluded_functions =
  [ "multi_margin_loss"
  ; "multi_margin_loss_out"
  ; "log_softmax_backward_data"
  ; "softmax_backward_data" ]
  |> Set.of_list (module String)

(* Functions whose generated Rust method name is prefixed with "g_" by
   [write_wrapper].
   NOTE(review): presumably this avoids clashing with hand-written methods of
   the same name on the Rust side -- confirm. *)
let prefixed_functions =
  ["add"; "add_"; "div"; "div_"; "mul"; "mul_"; "sub"; "sub_"; "nll_loss"]
  |> Set.of_list (module String)

(* A declaration whose name starts with any of these prefixes is skipped by
   [read_yaml]. *)
let excluded_prefixes : string list = ["_"; "thnn_"; "th_"]

(* A declaration whose name ends with any of these suffixes is skipped by
   [read_yaml]. *)
let excluded_suffixes : string list = ["_forward"; "_forward_out"]

(* [yaml_error yaml ~msg] always raises [Failure] with a message made of [msg]
   followed by the textual rendering of the offending [yaml] value. *)
let yaml_error yaml ~msg =
  let rendered = Yaml.to_string_exn yaml in
  Printf.failwithf "%s, %s" msg rendered ()

(* Read a boolean out of a YAML value. Both a native boolean and the strings
   "true" / "false" are accepted; anything else fails via [yaml_error]. *)
let extract_bool yaml =
  match yaml with
  | `Bool flag -> flag
  | `String "true" -> true
  | `String "false" -> false
  | _ -> yaml_error yaml ~msg:"expected bool"

(* Return the elements of a YAML sequence; fails via [yaml_error] when the
   value is not a sequence. *)
let extract_list yaml =
  match yaml with
  | `A items -> items
  | _ -> yaml_error yaml ~msg:"expected list"

(* Turn a YAML mapping into a string-keyed [Map]. Fails via [yaml_error] when
   the value is not a mapping; [Map.of_alist_exn] raises on duplicate keys. *)
let extract_map yaml =
  match yaml with
  | `O bindings -> Map.of_alist_exn (module String) bindings
  | _ -> yaml_error yaml ~msg:"expected map"

(* Return the contents of a YAML string scalar; fails via [yaml_error] for any
   other kind of value. *)
let extract_string yaml =
  match yaml with
  | `String text -> text
  | _ -> yaml_error yaml ~msg:"expected string"

(* Description of one function to bind, together with the helpers that render
   it as C++ glue code, as a Rust [extern "C"] declaration and as a Rust
   method signature / call. *)
module Func = struct
  (* Argument kinds that the generator knows how to pass across the
     C++ / C / Rust boundary. *)
  type arg_type =
    | Bool
    | Int64
    | Double
    | Tensor
    | TensorOption
    | IntList
    | TensorList
    | TensorOptions
    | Scalar
    | ScalarType
    | Device

  type arg =
    {arg_name: string; arg_type: arg_type; default_value: string option}

  type t =
    { name: string
    ; args: arg list
    ; returns: int (* number of tensors that are returned *)
    ; kind: [`function_ | `method_] }

  (* Map a type name (compared case-insensitively) to an [arg_type]. Tensor
     types become [TensorOption] when [is_nullable] is set. Returns [None]
     for type names that are not supported. *)
  let arg_type_of_string str ~is_nullable =
    let tensor_kind = if is_nullable then TensorOption else Tensor in
    match String.lowercase str with
    | "bool" -> Some Bool
    | "int64_t" -> Some Int64
    | "double" -> Some Double
    | "tensor" | "booltensor" | "indextensor" -> Some tensor_kind
    | "tensoroptions" -> Some TensorOptions
    | "intlist" -> Some IntList
    | "tensorlist" -> Some TensorList
    | "device" -> Some Device
    | "scalar" -> Some Scalar
    | "scalartype" -> Some ScalarType
    | _ -> None

  (* Comma-separated, typed parameter list for the C prototype. List-like
     arguments expand to a data pointer plus a length, and tensor options to
     a (kind, device) pair of ints. *)
  let c_typed_args_list t =
    let typed_param {arg_name; arg_type; _} =
      let simple c_type = Printf.sprintf "%s %s" c_type arg_name in
      match arg_type with
      | IntList ->
          Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
      | TensorList ->
          Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
      | TensorOptions ->
          Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
      | Bool | ScalarType | Device -> simple "int"
      | Int64 -> simple "int64_t"
      | Double -> simple "double"
      | Tensor | TensorOption -> simple "tensor"
      | Scalar -> simple "scalar"
    in
    List.map t.args ~f:typed_param |> String.concat ~sep:", "

  (* Comma-separated C++ expressions that convert the C parameters back to
     the values handed to the torch call. *)
  let c_args_list args =
    let call_expr {arg_name= n; arg_type; _} =
      match arg_type with
      | Scalar | Tensor -> "*" ^ n
      | TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" n n
      | Bool -> "(bool)" ^ n
      | IntList -> Printf.sprintf "torch::IntList(%s_data, %s_len)" n n
      | TensorList -> Printf.sprintf "of_carray_tensor(%s_data, %s_len)" n n
      | TensorOptions ->
          Printf.sprintf
            "at::device(at::DeviceType(%s_device)).dtype(at::ScalarType(%s_kind))"
            n n
      | ScalarType -> Printf.sprintf "torch::ScalarType(%s)" n
      | Device -> Printf.sprintf "torch::Device(torch::DeviceType(%s))" n
      | Int64 | Double -> n
    in
    List.map args ~f:call_expr |> String.concat ~sep:", "

  (* The C++ call expression: a [torch::] free function, or a method invoked
     on the first argument. Raises [Failure] for a method without any
     argument. *)
  let c_call t =
    match (t.kind, t.args) with
    | `function_, args ->
        Printf.sprintf "torch::%s(%s)" t.name (c_args_list args)
    | `method_, receiver :: rest ->
        Printf.sprintf "%s->%s(%s)" receiver.arg_name t.name
          (c_args_list rest)
    | `method_, [] ->
        Printf.failwithf "Method calls should have at least one argument %s"
          t.name ()

  (* Names that get replaced before being used as Rust identifiers. *)
  let replace_map =
    [("end", "end_"); ("to", "to_"); ("t", "tr"); ("where", "where_")]
    |> Map.of_alist_exn (module String)

  (* Rust identifier for [name]: apply [replace_map] if the exact name is
     listed there, then lowercase the result. *)
  let rust_name name =
    let renamed =
      match Map.find replace_map name with
      | Some replacement -> replacement
      | None -> name
    in
    String.lowercase renamed

  (* Parameter list of the Rust [extern "C"] declaration. Single-value
     parameters get a trailing underscore appended to their name. *)
  let c_rust_args_list t =
    let extern_param arg =
      let an = arg.arg_name in
      let single rust_type = Printf.sprintf "%s_: %s" an rust_type in
      match arg.arg_type with
      | Bool | ScalarType | Device -> single "c_int"
      | Int64 -> single "i64"
      | Double -> single "f64"
      | Tensor | TensorOption -> single "*mut C_tensor"
      | Scalar -> single "*mut C_scalar"
      | IntList -> Printf.sprintf "%s_data: *const i64, %s_len: c_int" an an
      | TensorList ->
          Printf.sprintf "%s_data: *const *mut C_tensor, %s_len: c_int" an an
      | TensorOptions ->
          Printf.sprintf "%s_kind: c_int, %s_device: c_int" an an
    in
    List.map t.args ~f:extern_param |> String.concat ~sep:", "

  let self_name = "self"

  let input_name = "input"

  (* True for a plain [Tensor] argument named [self_name]. *)
  let self_tensor {arg_name; arg_type; _} =
    match arg_type with
    | Tensor -> String.equal arg_name self_name
    | Bool | Int64 | Double | TensorOption | IntList | TensorList
     |TensorOptions | Scalar | ScalarType | Device ->
        false

  (* True for a plain [Tensor] argument named [input_name]. *)
  let input_tensor {arg_name; arg_type; _} =
    match arg_type with
    | Tensor -> String.equal arg_name input_name
    | Bool | Int64 | Double | TensorOption | IntList | TensorList
     |TensorOptions | Scalar | ScalarType | Device ->
        false

  (* Rust-side parameter list. Returns a pair: the name of the argument that
     becomes [&self] (if exactly one tensor is named "self", or failing that
     exactly one is named "input"), and the rendered parameter list. *)
  let rust_args_list t =
    let rust_type = function
      | Bool -> "bool"
      | Int64 -> "i64"
      | Double -> "f64"
      | Tensor -> "&Tensor"
      | TensorOption -> "Option<&Tensor>"
      | IntList -> "&[i64]"
      | TensorList -> "&[&Tensor]"
      | TensorOptions -> "(Kind, Device)"
      | Scalar -> "&Scalar"
      | ScalarType -> "Kind"
      | Device -> "Device"
    in
    let to_string args =
      List.map args ~f:(fun arg ->
          Printf.sprintf "%s: %s" (rust_name arg.arg_name)
            (rust_type arg.arg_type) )
      |> String.concat ~sep:", "
    in
    let as_method receiver others =
      (Some receiver.arg_name, Printf.sprintf "&self, %s" (to_string others))
    in
    match List.partition_tf t.args ~f:self_tensor with
    | [receiver], others -> as_method receiver others
    | _ -> (
      match List.partition_tf t.args ~f:input_tensor with
      | [receiver], others -> as_method receiver others
      | _ -> (None, to_string t.args) )

  (* Rust return-type suffix: nothing, a single [Tensor], or a tuple of
     [Tensor]s depending on [t.returns]. *)
  let rust_return_type t =
    match t.returns with
    | 0 -> ""
    | 1 -> " -> Tensor"
    | count ->
        let tuple =
          List.init count ~f:(fun _ -> "Tensor") |> String.concat ~sep:", "
        in
        Printf.sprintf " -> (%s)" tuple

  (* Rust expressions passed to the [extern "C"] function, one per C
     parameter. The argument whose name equals [self] (when given) is
     referred to as "self". *)
  let rust_binding_args t ~self =
    let is_receiver arg_name =
      match self with
      | Some self_arg -> String.equal self_arg arg_name
      | None -> false
    in
    let binding arg =
      let name =
        if is_receiver arg.arg_name then "self" else rust_name arg.arg_name
      in
      match arg.arg_type with
      | Tensor -> Printf.sprintf "%s.c_tensor" name
      | Scalar -> Printf.sprintf "%s.c_scalar" name
      | Bool -> Printf.sprintf "if %s { 1 } else { 0 }" name
      | ScalarType | Device -> Printf.sprintf "%s.c_int()" name
      | TensorOptions ->
          Printf.sprintf "%s.0.c_int(), %s.1.c_int()" name name
      | IntList -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
      | TensorList ->
          Printf.sprintf "ptr_list(%s).as_ptr(), %s.len() as i32" name name
      | TensorOption ->
          Printf.sprintf "%s.map_or(std::ptr::null_mut(), |t| t.c_tensor)"
            name
      | Int64 | Double -> name
    in
    List.map t.args ~f:binding |> String.concat ~sep:",\n                "
end

(* Raised inside [read_yaml] while converting a function's arguments when an
   argument has an unsupported type and no default value. [read_yaml] catches
   it and skips the whole function. *)
exception Not_a_simple_arg

(* Parse the declarations file [filename] and return the list of [Func.t]
   values for which bindings can be generated.

   A declaration is dropped when it is deprecated, when its name matches
   [excluded_prefixes], [excluded_suffixes] or [excluded_functions], when one
   of its return values is not a tensor, when it is neither a namespace
   function nor a Tensor method, or when one of its arguments has an
   unsupported type and no default value.

   Raises if the file cannot be parsed or if a mandatory key is missing from
   an entry ([Map.find_exn], [extract_*]). *)
let read_yaml filename =
  let funcs =
    (* Split the file to avoid Yaml.of_string_exn segfaulting. *)
    In_channel.with_file filename ~f:In_channel.input_lines
    |> List.group ~break:(fun _ line -> String.is_prefix line ~prefix:"-")
    |> List.concat_map ~f:(fun chunk ->
           String.concat chunk ~sep:"\n" |> Yaml.of_string_exn |> extract_list
       )
  in
  printf "Read %s, got %d functions.\n%!" filename (List.length funcs) ;
  (* True when a "returns" entry describes some flavour of tensor. *)
  let is_tensor_return entry =
    let fields = extract_map entry in
    match Map.find_exn fields "dynamic_type" |> extract_string with
    | "Tensor" | "BoolTensor" | "IndexTensor" -> true
    | _ -> false
  in
  (* True when the name is ruled out by one of the exclusion lists. *)
  let is_excluded name =
    List.exists excluded_prefixes ~f:(fun prefix ->
        String.is_prefix name ~prefix )
    || List.exists excluded_suffixes ~f:(fun suffix ->
           String.is_suffix name ~suffix )
    || Set.mem excluded_functions name
  in
  (* Convert one "arguments" entry. [None] means that the argument is left
     out of the binding (so the C++ default applies); [Not_a_simple_arg] is
     raised when the argument cannot be supported and cannot be left out. *)
  let parse_arg entry =
    let fields = extract_map entry in
    let arg_name = Map.find_exn fields "name" |> extract_string in
    let type_name = Map.find_exn fields "dynamic_type" |> extract_string in
    let is_nullable =
      match Map.find fields "is_nullable" with
      | Some flag -> extract_bool flag
      | None -> false
    in
    let default_value =
      Map.find fields "default" |> Option.map ~f:extract_string
    in
    let has_default = Option.is_some default_value in
    match Func.arg_type_of_string type_name ~is_nullable with
    | Some Func.Scalar when has_default && not is_nullable -> None
    | Some arg_type ->
        let arg_name =
          match (arg_name, arg_type) with
          | "self", Func.Scalar -> "self_scalar"
          | _ -> arg_name
        in
        Some {Func.arg_name; arg_type; default_value}
    | None -> if has_default then None else raise Not_a_simple_arg
  in
  List.filter_map funcs ~f:(fun yaml ->
      let map = extract_map yaml in
      let name = Map.find_exn map "name" |> extract_string in
      let deprecated = Map.find_exn map "deprecated" |> extract_bool in
      let method_of =
        Map.find_exn map "method_of"
        |> extract_list |> List.map ~f:extract_string
      in
      let arguments = Map.find_exn map "arguments" |> extract_list in
      let returns =
        let entries = Map.find_exn map "returns" |> extract_list in
        if List.for_all entries ~f:is_tensor_return then
          Some (List.length entries)
        else None
      in
      let kind =
        let declared_in target =
          List.exists method_of ~f:(String.equal target)
        in
        if declared_in "namespace" then Some `function_
        else if declared_in "Tensor" then Some `method_
        else None
      in
      if deprecated || is_excluded name then None
      else
        match (returns, kind) with
        | Some returns, Some kind -> (
          try
            let args = List.filter_map arguments ~f:parse_arg in
            Some {Func.name; args; returns; kind}
          with Not_a_simple_arg -> None )
        | _ -> None )

(* printf-style helper: format [s] with its arguments and write the result to
   [out_channel] followed by a newline. *)
let p out_channel s =
  let emit_line line =
    Out_channel.output_string out_channel line ;
    Out_channel.output_char out_channel '\n'
  in
  Printf.ksprintf emit_line s

(* Write the C++ glue code for [funcs] (a map from exported name to [Func.t]):
   the implementations go to [filename ^ ".cpp.h"] and the matching
   prototypes to [filename ^ ".h"].

   Each generated function is named [atg_<exported name>], takes an output
   array of tensors as first parameter and stores one newly allocated
   [torch::Tensor] per returned tensor in it. The body is wrapped in
   [PROTECT], which is expected to be a macro provided by the hand-written
   C++ side (not visible from this file).

   Two corner cases are handled explicitly so that the output stays valid C++:
   - a function without parameters does not get a dangling comma after the
     output parameter;
   - a function returning no tensor is called as a plain statement, as
     [auto x = f();] does not compile when [f] returns [void]. *)
let write_cpp funcs filename =
  Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
      Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
          let pc s = p out_cpp s in
          let ph s = p out_h s in
          pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
          pc "" ;
          ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
          ph "" ;
          Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
              let c_typed_args_list = Func.c_typed_args_list func in
              (* Prepend the output parameter, only adding the separator when
                 there is something to separate. *)
              let params out_param =
                if String.is_empty c_typed_args_list then out_param
                else out_param ^ ", " ^ c_typed_args_list
              in
              pc "void atg_%s(%s) {" exported_name (params "tensor *out__") ;
              pc "  PROTECT(" ;
              ( match func.Func.returns with
              | 0 ->
                  (* Nothing to store: just perform the call. *)
                  pc "    %s;" (Func.c_call func)
              | 1 ->
                  pc "    auto outputs__ = %s;" (Func.c_call func) ;
                  pc "    out__[0] = new torch::Tensor(outputs__);"
              | count ->
                  (* Several tensors come back as a tuple. *)
                  pc "    auto outputs__ = %s;" (Func.c_call func) ;
                  for i = 0 to count - 1 do
                    pc
                      "    out__[%d] = new \
                       torch::Tensor(std::get<%d>(outputs__));"
                      i i
                  done ) ;
              pc "  )" ;
              pc "}" ;
              pc "" ;
              ph "void atg_%s(%s);" exported_name (params "tensor *") ) ) )

(* Write the Rust wrapper for [funcs] (a map from exported name to [Func.t])
   to [filename]. The generated file contains:
   - an [extern "C"] block declaring every [atg_<exported name>] function;
   - a [ptr_list] helper turning a slice of tensors into raw pointers;
   - an [impl Tensor] block with one public function per entry, which calls
     the C function and wraps the returned pointers into [Tensor] values.

   Functions listed in [prefixed_functions] get a "g_" prefix on the Rust
   side. *)
let write_wrapper funcs filename =
  Out_channel.with_file filename ~f:(fun out_ml ->
      let pm s = p out_ml s in
      pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
      (* Inner attribute ("#!") so that the lint is silenced for the whole
         generated module; an outer attribute ("#") would only apply to the
         [use] item that follows it.
         NOTE(review): assumes the generated file is compiled as a module
         (mod declaration) rather than textually included -- confirm. *)
      pm "#![allow(clippy::all)]" ;
      pm "use crate::device::Device;" ;
      pm "use crate::kind::Kind;" ;
      pm "use crate::scalar::{C_scalar, Scalar};" ;
      pm "use libc::c_int;" ;
      pm "use super::c_wrapper::{C_tensor, Tensor};" ;
      pm "" ;
      pm "extern \"C\" {" ;
      Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
          pm "    fn atg_%s(out__: *mut *mut C_tensor, %s);" exported_name
            (Func.c_rust_args_list func) ) ;
      pm "}" ;
      pm "" ;
      pm "fn ptr_list(l: &[&Tensor]) -> Vec<*mut C_tensor> {" ;
      pm "    l.iter().map(|x| x.c_tensor).collect()" ;
      pm "}" ;
      pm "" ;
      pm "impl Tensor {" ;
      Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
          let rust_name = Func.rust_name exported_name in
          let rust_name =
            if Set.mem prefixed_functions func.Func.name then
              "g_" ^ rust_name
            else rust_name
          in
          (* Expression building the Rust return value out of the raw
             pointers filled in by the C function. *)
          let returns =
            match func.Func.returns with
            | 0 -> ""
            | 1 -> "Tensor { c_tensor: c_tensors[0] }"
            | n ->
                List.init n
                  ~f:(Printf.sprintf "Tensor { c_tensor: c_tensors[%d] }")
                |> String.concat ~sep:", " |> Printf.sprintf "(%s)"
          in
          pm "" ;
          pm "    pub fn %s(" rust_name ;
          let self, rust_args_list = Func.rust_args_list func in
          pm "        %s" rust_args_list ;
          pm "    )%s {" (Func.rust_return_type func) ;
          pm "        let mut c_tensors = [std::ptr::null_mut(); %d];"
            func.Func.returns ;
          pm "        unsafe_torch!({" ;
          pm "            atg_%s(c_tensors.as_mut_ptr()," exported_name ;
          pm "                %s" (Func.rust_binding_args func ~self) ;
          pm "            ) });" ;
          pm "        %s" returns ;
          pm "    }" ) ;
      pm "}" )

(* Hand-written descriptions of Tensor methods that [run] adds in front of the
   functions read from the declarations file. Each one is a method taking the
   tensor itself as first argument and returning a single tensor. *)
let methods =
  let open Func in
  let arg arg_name arg_type = {arg_name; arg_type; default_value= None} in
  let tensor_method name extra_args =
    {name; args= arg "self" Tensor :: extra_args; returns= 1; kind= `method_}
  in
  [ tensor_method "grad" []
  ; tensor_method "set_requires_grad" [arg "r" Bool]
  ; tensor_method "toType" [arg "scalar_type" ScalarType]
  ; tensor_method "to" [arg "device" Device] ]

(* Generator driver: read the declarations from [yaml_filename], add the
   hand-written [methods], give every function a unique exported name and
   write the C++ glue ([cpp_filename]) and the Rust wrapper
   ([wrapper_filename]). *)
let run ~yaml_filename ~cpp_filename ~wrapper_filename =
  let funcs = methods @ read_yaml yaml_filename in
  printf "Generating code for %d functions.\n%!" (List.length funcs) ;
  (* Overloads share a lowercased name: order them by number of arguments,
     keep the bare name for the first one and append the position to the
     others. *)
  let unique_names (name, overloads) =
    match overloads with
    | [] -> assert false
    | [func] -> [(name, func)]
    | _ ->
        let by_arity (f1 : Func.t) (f2 : Func.t) =
          Int.compare (List.length f1.args) (List.length f2.args)
        in
        List.sort overloads ~compare:by_arity
        |> List.mapi ~f:(fun i func ->
               let exported =
                 if i = 0 then name else Printf.sprintf "%s%d" name i
               in
               (exported, func) )
  in
  (* Generate some unique names for overloaded functions. *)
  let funcs =
    List.map funcs ~f:(fun func -> (String.lowercase func.name, func))
    |> Map.of_alist_multi (module String)
    |> Map.to_alist
    |> List.concat_map ~f:unique_names
    |> Map.of_alist_exn (module String)
  in
  write_cpp funcs cpp_filename ;
  write_wrapper funcs wrapper_filename

(* Entry point: run the generator with the file locations used by this
   repository (paths are relative to the current working directory). *)
let () =
  let yaml_filename = "data/Declarations.yaml" in
  let cpp_filename = "libtch/torch_api_generated" in
  let wrapper_filename = "src/tensor/c_wrapper_generated.rs" in
  run ~yaml_filename ~cpp_filename ~wrapper_filename