//! JIT interface to run model trained/saved using PyTorch Python API.
use super::utils::{path_to_cstring, ptr_to_string};
use super::{device::Device, kind::Kind};
use crate::{nn::Path, TchError, Tensor};
use libc::{c_int, c_void};
use std::borrow::Borrow;
use std::convert::TryFrom;
use torch_sys::*;

/// Argument and output values for JIT models. These represent arbitrary values,
/// e.g. tensors, atomic values, pairs of values, etc.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum IValue {
    /// The unit value.
    None,
    /// A single tensor.
    Tensor(crate::Tensor),
    /// A double precision floating point value.
    Double(f64),
    /// A 64-bit signed integer.
    Int(i64),
    /// A boolean value.
    Bool(bool),
    /// A fixed-arity sequence of heterogeneous values.
    Tuple(Vec<IValue>),
    /// A list of integers.
    IntList(Vec<i64>),
    /// A list of double precision floats.
    DoubleList(Vec<f64>),
    /// A list of booleans.
    BoolList(Vec<bool>),
    /// A string.
    String(String),
    /// A list of strings.
    StringList(Vec<String>),
    /// A list of tensors.
    TensorList(Vec<crate::Tensor>),
    /// A list of arbitrary values.
    GenericList(Vec<IValue>),
    // We use a vec to represent dictionaries as f64 does not implement
    // Eq or Hash out of the box in rust. TODO: improve this?
    /// A dictionary, stored as a list of key/value pairs (see note above).
    GenericDict(Vec<(IValue, IValue)>),
    /// An opaque TorchScript object, e.g. an instance of a custom class.
    Object(Object),
    /// A torch device.
    Device(Device),
}
34
35impl IValue {
36    fn type_str(self) -> &'static str {
37        match self {
38            IValue::None => "None",
39            IValue::Tensor(_) => "Tensor",
40            IValue::Double(_) => "Double",
41            IValue::Int(_) => "Int",
42            IValue::Bool(_) => "Bool",
43            IValue::Tuple(_) => "Tuple",
44            IValue::IntList(_) => "IntList",
45            IValue::DoubleList(_) => "DoubleList",
46            IValue::BoolList(_) => "BoolList",
47            IValue::String(_) => "String",
48            IValue::StringList(_) => "StringList",
49            IValue::TensorList(_) => "TensorList",
50            IValue::GenericList(_) => "GenericList",
51            IValue::GenericDict(_) => "GenericDict",
52            IValue::Object(_) => "Object",
53            IValue::Device(_) => "Device",
54        }
55    }
56}
57
58impl From<()> for IValue {
59    fn from((): ()) -> Self {
60        IValue::None
61    }
62}
63
64impl<T1: Into<IValue>, T2: Into<IValue>> From<(T1, T2)> for IValue {
65    fn from((p1, p2): (T1, T2)) -> Self {
66        IValue::Tuple(vec![p1.into(), p2.into()])
67    }
68}
69
70impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>> From<(T1, T2, T3)> for IValue {
71    fn from((p1, p2, p3): (T1, T2, T3)) -> Self {
72        IValue::Tuple(vec![p1.into(), p2.into(), p3.into()])
73    }
74}
75
76impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>, T4: Into<IValue>> From<(T1, T2, T3, T4)>
77    for IValue
78{
79    fn from((p1, p2, p3, p4): (T1, T2, T3, T4)) -> Self {
80        IValue::Tuple(vec![p1.into(), p2.into(), p3.into(), p4.into()])
81    }
82}
83
84impl<T1, T2, T1E, T2E> TryFrom<IValue> for (T1, T2)
85where
86    T1: TryFrom<IValue, Error = T1E>,
87    TchError: From<T1E>,
88    T2: TryFrom<IValue, Error = T2E>,
89    TchError: From<T2E>,
90{
91    type Error = TchError;
92    fn try_from(value: IValue) -> Result<Self, TchError> {
93        match value {
94            IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
95                if vec.len() == 2 {
96                    let t2 = T2::try_from(vec.pop().unwrap())?;
97                    let t1 = T1::try_from(vec.pop().unwrap())?;
98                    Ok((t1, t2))
99                } else {
100                    Err(TchError::Kind(format!(
101                        "unable to unpack ivalue, expected a tuple of len 2 got {}",
102                        vec.len()
103                    )))
104                }
105            }
106            _ => Err(TchError::Kind(format!(
107                "unable to unpack ivalue, expected a tuple got {}",
108                value.type_str()
109            ))),
110        }
111    }
112}
113
114impl<T1, T2, T3, T1E, T2E, T3E> TryFrom<IValue> for (T1, T2, T3)
115where
116    T1: TryFrom<IValue, Error = T1E>,
117    TchError: From<T1E>,
118    T2: TryFrom<IValue, Error = T2E>,
119    TchError: From<T2E>,
120    T3: TryFrom<IValue, Error = T3E>,
121    TchError: From<T3E>,
122{
123    type Error = TchError;
124    fn try_from(value: IValue) -> Result<Self, TchError> {
125        match value {
126            IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
127                if vec.len() == 3 {
128                    let t3 = T3::try_from(vec.pop().unwrap())?;
129                    let t2 = T2::try_from(vec.pop().unwrap())?;
130                    let t1 = T1::try_from(vec.pop().unwrap())?;
131                    Ok((t1, t2, t3))
132                } else {
133                    Err(TchError::Kind(format!(
134                        "unable to unpack ivalue, expected a tuple of len 3 got {}",
135                        vec.len()
136                    )))
137                }
138            }
139            _ => Err(TchError::Kind(format!(
140                "unable to unpack ivalue, expected a tuple got {}",
141                value.type_str()
142            ))),
143        }
144    }
145}
146
147impl<T1, T2, T3, T4, T1E, T2E, T3E, T4E> TryFrom<IValue> for (T1, T2, T3, T4)
148where
149    T1: TryFrom<IValue, Error = T1E>,
150    TchError: From<T1E>,
151    T2: TryFrom<IValue, Error = T2E>,
152    TchError: From<T2E>,
153    T3: TryFrom<IValue, Error = T3E>,
154    TchError: From<T3E>,
155    T4: TryFrom<IValue, Error = T4E>,
156    TchError: From<T4E>,
157{
158    type Error = TchError;
159    fn try_from(value: IValue) -> Result<Self, TchError> {
160        match value {
161            IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
162                if vec.len() == 4 {
163                    let t4 = T4::try_from(vec.pop().unwrap())?;
164                    let t3 = T3::try_from(vec.pop().unwrap())?;
165                    let t2 = T2::try_from(vec.pop().unwrap())?;
166                    let t1 = T1::try_from(vec.pop().unwrap())?;
167                    Ok((t1, t2, t3, t4))
168                } else {
169                    Err(TchError::Kind(format!(
170                        "unable to unpack ivalue, expected a tuple of len 4 got {}",
171                        vec.len()
172                    )))
173                }
174            }
175            _ => Err(TchError::Kind(format!(
176                "unable to unpack ivalue, expected a tuple got {}",
177                value.type_str()
178            ))),
179        }
180    }
181}
182
// Generates, for a Rust type `$type_` and its matching IValue variant `$cons`:
// - `From<$type_> for IValue`: wraps the value in the `$cons` variant,
// - `TryFrom<IValue> for $type_`: unwraps it, erroring on any other variant,
// - `TryFrom<IValue> for Option<$type_>`: same, but maps `IValue::None` to `None`.
macro_rules! impl_from {
    ($type_:ty, $cons:ident) => {
        impl From<$type_> for IValue {
            fn from(v: $type_) -> Self {
                IValue::$cons(v)
            }
        }

        impl TryFrom<IValue> for $type_ {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<$type_, TchError> {
                match value {
                    IValue::$cons(t) => Ok(t),
                    _ => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} got {}",
                        std::stringify!($cons),
                        value.type_str()
                    ))),
                }
            }
        }

        // A generic trait for Option<T> would seem nicer but because
        // of E0119, this is currently hard to do.
        // See https://github.com/rust-lang/rust/issues/50133
        impl TryFrom<IValue> for Option<$type_> {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<Self, TchError> {
                match value {
                    IValue::None => Ok(None),
                    IValue::$cons(t) => Ok(Some(t)),
                    _ => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} or None got {}",
                        std::stringify!($cons),
                        value.type_str()
                    ))),
                }
            }
        }
    };
}
224
// Bidirectional conversions between plain Rust types and the corresponding
// IValue variants, generated by the `impl_from` macro.
impl_from!(i64, Int);
impl_from!(f64, Double);
impl_from!(bool, Bool);
impl_from!(String, String);
impl_from!(Tensor, Tensor);
impl_from!(Vec<i64>, IntList);
impl_from!(Vec<f64>, DoubleList);
impl_from!(Vec<bool>, BoolList);
impl_from!(Vec<String>, StringList);
impl_from!(Vec<crate::Tensor>, TensorList);
impl_from!(Vec<IValue>, GenericList);
impl_from!(Vec<(IValue, IValue)>, GenericDict);
impl_from!(Object, Object);
impl_from!(Device, Device);
239
240impl From<&str> for IValue {
241    fn from(s: &str) -> Self {
242        IValue::String(s.to_string())
243    }
244}
245
impl IValue {
    #![allow(unused_unsafe)]
    /// Converts this value into a newly allocated C ivalue pointer.
    ///
    /// The returned pointer must eventually be released with `ati_free`
    /// (or handed to a C function that takes ownership of it). For aggregates
    /// (tuples, lists, dicts) the temporary element ivalues are freed right
    /// after the aggregate is built — NOTE(review): this implies the C
    /// constructors copy their element pointers; confirm against torch-sys.
    pub(super) fn to_c(&self) -> Result<*mut CIValue, TchError> {
        let c = unsafe_torch_err!(match self {
            IValue::Tensor(tensor) => ati_tensor(tensor.c_tensor),
            IValue::Int(i) => ati_int(*i),
            IValue::None => ati_none(),
            IValue::Double(f) => ati_double(*f),
            // Booleans cross the FFI boundary as integers.
            IValue::Bool(b) => ati_bool(i32::from(*b)),
            IValue::Tuple(v) => {
                let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
                let tuple = ati_tuple(v.as_ptr(), v.len() as c_int);
                // Release the temporary element ivalues.
                for x in v {
                    ati_free(x);
                }

                tuple
            }
            IValue::GenericList(v) => {
                let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
                let list = ati_generic_list(v.as_ptr(), v.len() as c_int);
                for x in v {
                    ati_free(x);
                }
                list
            }
            IValue::IntList(v) => ati_int_list(v.as_ptr(), v.len() as c_int),
            IValue::DoubleList(v) => ati_double_list(v.as_ptr(), v.len() as c_int),
            IValue::BoolList(v) => {
                // Booleans are marshalled as a byte (c_char) array.
                let v: Vec<libc::c_char> = v.iter().map(|&b| libc::c_char::from(b)).collect();
                ati_bool_list(v.as_ptr(), v.len() as c_int)
            }
            IValue::TensorList(v) => {
                let v = v.iter().map(|t| t.c_tensor).collect::<Vec<_>>();
                ati_tensor_list(v.as_ptr(), v.len() as c_int)
            }
            IValue::String(string) => {
                // CString::new fails on interior NUL bytes; the error is
                // propagated as a TchError via `?`.
                let c_str = std::ffi::CString::new(string.as_str())?;
                ati_string(c_str.as_ptr())
            }
            IValue::StringList(strings) => {
                let mut v = vec![];
                for s in strings {
                    v.push(std::ffi::CString::new(s.as_str())?);
                }
                // `v` keeps the CStrings alive for the duration of the call.
                let v_ptr: Vec<_> = v.iter().map(|s| s.as_ptr()).collect();
                ati_string_list(v_ptr.as_ptr(), v.len() as c_int)
            }
            IValue::GenericDict(dict) => {
                // The dict is flattened to [k0, v0, k1, v1, ...]; the length
                // argument passed to C is the number of *pairs*.
                let v = dict
                    .iter()
                    .flat_map(|(k, v)| vec![Self::to_c(k), Self::to_c(v)])
                    .collect::<Result<Vec<_>, TchError>>()?;
                let dict = ati_generic_dict(v.as_ptr(), dict.len() as c_int);
                for x in v {
                    ati_free(x);
                }
                dict
            }
            IValue::Object(Object { c_ivalue }) => {
                // Clone the object if necessary before passing the pointer to the C++ side.
                unsafe_torch_err!(ati_clone(*c_ivalue))
            }
            IValue::Device(device) => {
                ati_device(device.c_int())
            }
        });
        Ok(c)
    }

    // This consumes the pointer and frees the associated memory (unless it is an Object).
    /// Rebuilds a Rust `IValue` from a C ivalue pointer, dispatching on the
    /// numeric tag reported by `ati_tag`.
    ///
    /// NOTE(review): tag 11 is absent from the mapping below — confirm it is
    /// intentionally unsupported on the torch-sys side.
    pub(super) fn from_c(c_ivalue: *mut CIValue) -> Result<Self, TchError> {
        // Everything except objects is copied out, so the input pointer is
        // freed before returning; an Object keeps ownership of the pointer
        // and releases it in its own Drop impl.
        let mut free = true;
        let tag = unsafe_torch_err!(ati_tag(c_ivalue));
        let v = match tag {
            0 => IValue::None,
            1 => {
                let c_tensor = unsafe_torch_err!(ati_to_tensor(c_ivalue));
                IValue::Tensor(crate::Tensor { c_tensor })
            }
            2 => IValue::Double(unsafe_torch_err!(ati_to_double(c_ivalue))),
            3 => IValue::Int(unsafe_torch_err!(ati_to_int(c_ivalue))),
            4 => {
                let b = unsafe_torch_err!(ati_to_bool(c_ivalue));
                // A negative value signals an invalid boolean on the C side.
                if b < 0 {
                    return Err(TchError::Kind(format!("unexpected bool value {b}")));
                }
                IValue::Bool(b != 0)
            }
            5 => {
                // Tuple: read the length, then fill a vector of raw pointers.
                let len = unsafe_torch_err!(ati_tuple_length(c_ivalue));
                let mut c_ivalues: Vec<_> =
                    (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
                unsafe_torch_err!(ati_to_tuple(c_ivalue, c_ivalues.as_mut_ptr(), len));
                let vec: Result<Vec<_>, _> =
                    c_ivalues.iter().map(|&c_ivalue| Self::from_c(c_ivalue)).collect();
                IValue::Tuple(vec?)
            }
            6 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_array = vec![0i64; len as usize];
                unsafe_torch_err!(ati_to_int_list(c_ivalue, c_array.as_mut_ptr(), len));
                IValue::IntList(c_array)
            }
            7 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_array = vec![0f64; len as usize];
                unsafe_torch_err!(ati_to_double_list(c_ivalue, c_array.as_mut_ptr(), len));
                IValue::DoubleList(c_array)
            }
            8 => {
                // Booleans come back as a byte array, non-zero meaning true.
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_array = vec![0_i8; len as usize];
                let c_array_ptr = c_array.as_mut_ptr() as *mut libc::c_char;
                unsafe_torch_err!(ati_to_bool_list(c_ivalue, c_array_ptr, len));
                IValue::BoolList(c_array.iter().map(|&x| x != 0).collect())
            }
            9 => {
                let ptr = unsafe_torch_err!(ati_to_string(c_ivalue));
                let string = match unsafe { ptr_to_string(ptr) } {
                    None => return Err(TchError::Kind("nullptr representation".to_string())),
                    Some(s) => s,
                };
                IValue::String(string)
            }
            10 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_tensors: Vec<_> =
                    (0..len).map(|_| std::ptr::null_mut::<C_tensor>()).collect();
                unsafe_torch_err!(ati_to_tensor_list(c_ivalue, c_tensors.as_mut_ptr(), len));
                let vec: Vec<_> = c_tensors.iter().map(|&c_tensor| Tensor { c_tensor }).collect();
                IValue::TensorList(vec)
            }
            12 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_ivalues: Vec<_> =
                    (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
                unsafe_torch_err!(ati_to_generic_list(c_ivalue, c_ivalues.as_mut_ptr(), len));
                let vec: Result<Vec<_>, _> =
                    c_ivalues.iter().map(|&c_ivalue| Self::from_c(c_ivalue)).collect();
                IValue::GenericList(vec?)
            }
            13 => {
                // Dicts arrive flattened as [k0, v0, k1, v1, ...] with `len`
                // counting the pairs, hence the 2 * len buffer.
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_ivalues: Vec<_> =
                    (0..2 * len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
                unsafe_torch_err!(ati_to_generic_dict(c_ivalue, c_ivalues.as_mut_ptr(), len));
                let mut res: Vec<(IValue, IValue)> = vec![];
                for i in 0..(len as usize) {
                    let key = Self::from_c(c_ivalues[2 * i])?;
                    let value = Self::from_c(c_ivalues[2 * i + 1])?;
                    res.push((key, value))
                }
                IValue::GenericDict(res)
            }
            14 => {
                // The Object wrapper takes ownership of the pointer, so it
                // must not be freed here.
                free = false;
                IValue::Object(Object { c_ivalue })
            }
            _ => return Err(TchError::Kind(format!("unhandled tag {tag}"))),
        };
        if free {
            unsafe_torch_err!(ati_free(c_ivalue));
        }
        Ok(v)
    }
}
413
/// A jit PyTorch module.
///
/// These modules can be created via the
/// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
#[derive(Debug)]
pub struct CModule {
    // Raw pointer to the underlying C++ module; owned by this struct and
    // released in `Drop` via `atm_free`.
    pub(super) c_module: *mut CModule_,
}
422
// SAFETY: `CModule` owns its raw module pointer exclusively.
// NOTE(review): sending/sharing the module across threads relies on the
// libtorch side tolerating concurrent use of a jit module — confirm against
// upstream libtorch threading guarantees.
unsafe impl Send for CModule {}

unsafe impl Sync for CModule {}
426
impl Drop for CModule {
    /// Releases the underlying C++ module when the wrapper goes out of scope.
    fn drop(&mut self) {
        unsafe_torch!(atm_free(self.c_module))
    }
}
432
impl CModule {
    /// Loads a PyTorch saved JIT model from a file.
    pub fn load<T: AsRef<std::path::Path>>(path: T) -> Result<CModule, TchError> {
        let path = path_to_cstring(path)?;
        let c_module = unsafe_torch_err!(atm_load(path.as_ptr()));
        Ok(CModule { c_module })
    }

    /// Loads a PyTorch saved JIT model from a file onto the given device.
    ///
    /// This function loads the model directly on the specified device,
    /// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
    pub fn load_on_device<T: AsRef<std::path::Path>>(
        path: T,
        device: Device,
    ) -> Result<CModule, TchError> {
        let path = path_to_cstring(path)?;
        let c_module = unsafe_torch_err!(atm_load_on_device(path.as_ptr(), device.c_int()));
        Ok(CModule { c_module })
    }

    /// Loads a PyTorch saved JIT model from a read instance.
    pub fn load_data<T: std::io::Read>(f: &mut T) -> Result<CModule, TchError> {
        // The whole stream is buffered in memory before being handed to C.
        let mut buffer = Vec::new();
        f.read_to_end(&mut buffer)?;
        let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
        let c_module = unsafe_torch_err!(atm_load_str(buffer_ptr, buffer.len()));
        Ok(CModule { c_module })
    }

    /// Loads a PyTorch saved JIT model from a read instance.
    ///
    /// This function loads the model directly on the specified device,
    /// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
    pub fn load_data_on_device<T: std::io::Read>(
        f: &mut T,
        device: Device,
    ) -> Result<CModule, TchError> {
        let mut buffer = Vec::new();
        f.read_to_end(&mut buffer)?;
        let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
        let c_module =
            unsafe_torch_err!(atm_load_str_on_device(buffer_ptr, buffer.len(), device.c_int()));
        Ok(CModule { c_module })
    }

    /// Performs the forward pass for a model on some specified tensor inputs. This is equivalent
    /// to calling method_ts with the 'forward' method name, and returns a single tensor.
    pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
        let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
        let c_tensor =
            unsafe_torch_err!(atm_forward(self.c_module, ts.as_ptr(), ts.len() as c_int));
        Ok(Tensor { c_tensor })
    }

    /// Performs the forward pass for a model on some specified ivalue inputs. This is equivalent
    /// to calling method_is with the 'forward' method name, and returns an arbitrary ivalue.
    pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let c_ivalue =
            unsafe_torch_err!(atm_forward_(self.c_module, ts.as_ptr(), ts.len() as c_int));
        // Release the temporary C ivalues built from the inputs.
        // NOTE(review): if the call above returns early with an error, these
        // temporaries leak — consider freeing them before the error check.
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Runs a specified entry point for a model on some given tensor inputs.
    pub fn method_ts<T: Borrow<Tensor>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<Tensor, TchError> {
        let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
        let method_name = std::ffi::CString::new(method_name)?;
        let c_tensor = unsafe_torch_err!(atm_method(
            self.c_module,
            method_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        Ok(Tensor { c_tensor })
    }

    /// Runs a specified entry point for a model on some given ivalue inputs.
    pub fn method_is<T: Borrow<IValue>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let method_name = std::ffi::CString::new(method_name)?;
        let c_ivalue = unsafe_torch_err!(atm_method_(
            self.c_module,
            method_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        // Release the temporary C ivalues (same caveat as `forward_is`).
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Create a specified custom JIT class object with the given class name, eg: `__torch__.foo.Bar`
    pub fn create_class_is<T: Borrow<IValue>>(
        &self,
        clz_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let clz_name = std::ffi::CString::new(clz_name)?;
        let c_ivalue = unsafe_torch_err!(atm_create_class_(
            self.c_module,
            clz_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        // Release the temporary C ivalues (same caveat as `forward_is`).
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Switches the module to evaluation mode.
    pub fn f_set_eval(&mut self) -> Result<(), TchError> {
        unsafe_torch_err!(atm_eval(self.c_module));
        Ok(())
    }

    /// Switches the module to evaluation mode.
    ///
    /// Panics if the underlying torch call fails.
    pub fn set_eval(&mut self) {
        self.f_set_eval().unwrap();
    }

    /// Switches the module to training mode.
    pub fn f_set_train(&mut self) -> Result<(), TchError> {
        unsafe_torch_err!(atm_train(self.c_module));
        Ok(())
    }

    /// Switches the module to training mode.
    ///
    /// Panics if the underlying torch call fails.
    pub fn set_train(&mut self) {
        self.f_set_train().unwrap();
    }

    /// Moves the module to a different device and converts the kind.
    pub fn to(&mut self, device: Device, kind: Kind, non_blocking: bool) {
        unsafe_torch!(atm_to(self.c_module, device.c_int(), kind.c_int(), non_blocking));
    }

    /// Saves a module to a given path.
    pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        unsafe_torch_err!(atm_save(self.c_module, path.as_ptr()));
        Ok(())
    }

    /// Loads some named tensors from a module
    pub fn named_parameters(&self) -> Result<Vec<(String, Tensor)>, TchError> {
        let mut v: Vec<(String, Tensor)> = vec![];
        // The C side invokes `add_callback` for each parameter; the callback
        // presumably pushes (name, tensor) pairs into `v` through the c_void
        // context pointer — see `super::tensor::add_callback`.
        unsafe_torch_err!(atm_named_parameters(
            self.c_module,
            &mut v as *mut _ as *mut c_void,
            super::tensor::add_callback
        ));
        Ok(v)
    }

    /// Create a new module by tracing the application of the specified function on
    /// the given inputs.
    pub fn create_by_tracing<F>(
        modl_name: &str,
        fn_name: &str,
        inputs: &[Tensor],
        closure: &mut F,
    ) -> Result<CModule, TchError>
    where
        F: FnMut(&[Tensor]) -> Vec<Tensor>,
    {
        let modl_name = std::ffi::CString::new(modl_name)?;
        let fn_name = std::ffi::CString::new(fn_name)?;
        let c_inputs = inputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
        // Start a trace on the inputs, run the closure while tracing is
        // active, then finalize the trace with the produced outputs.
        let c_module = unsafe_torch_err!(atm_create_for_tracing(
            modl_name.as_ptr(),
            c_inputs.as_ptr(),
            c_inputs.len() as c_int
        ));
        let outputs = closure(inputs);
        let c_outputs = outputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
        unsafe_torch_err!(atm_end_tracing(
            c_module,
            fn_name.as_ptr(),
            c_outputs.as_ptr(),
            c_outputs.len() as c_int,
        ));
        Ok(CModule { c_module })
    }
}
632
/// The trainable version of a jit PyTorch module.
///
/// These modules can be created via the
/// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
#[derive(Debug)]
pub struct TrainableCModule {
    // The wrapped plain JIT module; all methods delegate to it.
    pub(crate) inner: CModule,
}
641
impl TrainableCModule {
    /// Loads a PyTorch saved JIT module from a file.
    ///
    /// This function also adds the tensors from the JIT module to the VarStore path
    /// passed as argument so that the module can be trained.
    pub fn load<T: AsRef<std::path::Path>>(module_path: T, path: Path) -> Result<Self, TchError> {
        let inner = CModule::load_on_device(module_path, path.device())?;
        for (name, tensor) in inner.named_parameters()? {
            let requires_grad = tensor.requires_grad();
            // Replace '.' by '_' in the parameter name (presumably because
            // VarStore paths treat '.' specially — TODO confirm).
            let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
        }
        Ok(TrainableCModule { inner })
    }

    /// Loads a PyTorch saved JIT model from a read instance.
    ///
    /// This function also adds the tensors from the JIT module to the VarStore path
    /// passed as argument so that the module can be trained.
    pub fn load_data<T: std::io::Read>(data: &mut T, path: Path) -> Result<Self, TchError> {
        let inner = CModule::load_data_on_device(data, path.device())?;
        for (name, tensor) in inner.named_parameters()? {
            let requires_grad = tensor.requires_grad();
            // Same name normalization as in `load` above.
            let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
        }
        Ok(TrainableCModule { inner })
    }

    /// Saves the module to the given file path.
    pub fn save<T: AsRef<std::path::Path>>(&self, module_path: T) -> Result<(), TchError> {
        self.inner.save(module_path)
    }

    /// Switches the module to training mode.
    pub fn f_set_train(&mut self) -> Result<(), TchError> {
        self.inner.f_set_train()
    }

    /// Switches the module to training mode.
    pub fn set_train(&mut self) {
        self.inner.set_train()
    }

    /// Switches the module to evaluation mode.
    pub fn f_set_eval(&mut self) -> Result<(), TchError> {
        self.inner.f_set_eval()
    }

    /// Switches the module to evaluation mode.
    pub fn set_eval(&mut self) {
        self.inner.set_eval()
    }

    /// Performs the forward pass for a model on some specified tensor inputs.
    pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
        self.inner.forward_ts(ts)
    }

    /// Performs the forward pass for a model on some specified ivalue inputs.
    pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
        self.inner.forward_is(ts)
    }

    /// Runs a specified entry point for a model on some given tensor inputs.
    pub fn method_ts<T: Borrow<Tensor>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<Tensor, TchError> {
        self.inner.method_ts(method_name, ts)
    }

    /// Runs a specified entry point for a model on some given ivalue inputs.
    pub fn method_is<T: Borrow<IValue>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        self.inner.method_is(method_name, ts)
    }
}
721
722/// Returns whether profiling mode is set or not.
723pub fn f_get_profiling_mode() -> Result<bool, TchError> {
724    Ok(unsafe_torch_err!(atm_get_profiling_mode()) != 0)
725}
726
727/// Returns whether profiling mode is set or not.
728pub fn get_profiling_mode() -> bool {
729    f_get_profiling_mode().unwrap()
730}
731
732/// Activates or deactivates the profiling mode.
733pub fn f_set_profiling_mode(b: bool) -> Result<(), TchError> {
734    unsafe_torch_err!(atm_set_profiling_mode(b as c_int));
735    Ok(())
736}
737
738/// Activates or deactivates the profiling mode.
739pub fn set_profiling_mode(b: bool) {
740    f_set_profiling_mode(b).unwrap()
741}
742
743pub fn f_fuser_cuda_set_enabled(enabled: bool) -> Result<(), TchError> {
744    unsafe_torch_err!(atm_fuser_cuda_set_enabled(enabled));
745    Ok(())
746}
747
748pub fn fuser_cuda_set_enabled(enabled: bool) {
749    f_fuser_cuda_set_enabled(enabled).unwrap()
750}
751
752pub fn f_fuser_cuda_is_enabled() -> Result<bool, TchError> {
753    let b = unsafe_torch_err!(atm_fuser_cuda_is_enabled());
754    Ok(b)
755}
756
757pub fn fuser_cuda_is_enabled() -> bool {
758    f_fuser_cuda_is_enabled().unwrap()
759}
760
761pub fn f_set_tensor_expr_fuser_enabled(b: bool) -> Result<(), TchError> {
762    unsafe_torch_err!(atm_set_tensor_expr_fuser_enabled(b as c_int));
763    Ok(())
764}
765
766pub fn set_tensor_expr_fuser_enabled(b: bool) {
767    f_set_tensor_expr_fuser_enabled(b).unwrap()
768}
769
770pub fn f_get_tensor_expr_fuser_enabled() -> Result<bool, TchError> {
771    Ok(unsafe_torch_err!(atm_get_tensor_expr_fuser_enabled()))
772}
773
774pub fn get_tensor_expr_fuser_enabled() -> bool {
775    f_get_tensor_expr_fuser_enabled().unwrap()
776}
777
778/// Enables or disables the graph executor optimizer for the current thread.
779///
780/// # Arguments
781///
782/// * `b` - A boolean that if true enables the graph executor optimizer for the current thread.
783///
784/// This function returns an error if it is not possible to enable or disable the graph executor optimizer.
785pub fn f_set_graph_executor_optimize(b: bool) -> Result<(), TchError> {
786    unsafe_torch_err!(at_set_graph_executor_optimize(b));
787    Ok(())
788}
789
790/// Enables or disables the graph executor optimizer for the current thread.
791///
792/// # Arguments
793///
794/// * `b` - A boolean that if true enables the graph executor optimizer for the current thread.
795///
796/// This panics if it is not possible to enable or disable the graph executor optimizer.
797pub fn set_graph_executor_optimize(b: bool) {
798    f_set_graph_executor_optimize(b).unwrap();
799}
800
/// An opaque TorchScript object, e.g. an instance of a custom class.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, PartialEq)]
pub struct Object {
    // Owned C ivalue pointer, released in `Drop` via `ati_free`.
    c_ivalue: *mut CIValue,
}
806
impl Object {
    /// Applies the specified method to the object. The method takes as argument an arbitrary
    /// number of ivalues and returns an ivalue.
    pub fn method_is<T: Borrow<IValue>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let method_name = std::ffi::CString::new(method_name)?;
        let c_ivalue = unsafe_torch_err!(ati_object_method_(
            self.c_ivalue,
            method_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        // Release the temporary argument ivalues now that the call is done.
        // NOTE(review): an early error return above leaks them — consider
        // freeing before the error check.
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Retrieves the specified attribute from an object as an ivalue.
    pub fn getattr(&self, attr_name: &str) -> Result<IValue, TchError> {
        let property_name = std::ffi::CString::new(attr_name)?;
        let c_ivalue =
            unsafe_torch_err!(ati_object_getattr_(self.c_ivalue, property_name.as_ptr()));
        // Guard against a null result before handing the pointer to from_c.
        if c_ivalue.is_null() {
            return Err(TchError::Torch(format!(
                "Object.getattr(\"{attr_name}\") returned CIValue nullptr"
            )));
        }
        IValue::from_c(c_ivalue)
    }
}
842
impl Drop for Object {
    /// Releases the underlying C ivalue when the wrapper goes out of scope.
    fn drop(&mut self) {
        unsafe_torch!(ati_free(self.c_ivalue))
    }
}
848
#[cfg(test)]
mod tests {
    use super::IValue;
    use std::f64::consts;

    /// Converts `t` to an `IValue`, pushes it through the C representation and
    /// back, and checks that the reconstructed value equals the original.
    fn round_trip<T: Into<IValue>>(t: T) {
        let original: IValue = t.into();
        let c_repr = original.to_c().unwrap();
        let recovered = IValue::from_c(c_repr).unwrap();
        assert_eq!(original, recovered);
    }

    #[test]
    fn ivalue_round_trip() {
        // Unit and booleans.
        round_trip(());
        round_trip(true);
        round_trip(false);
        // Integers and strings.
        round_trip(-1);
        round_trip(42);
        round_trip(15);
        round_trip("".to_string());
        round_trip("foobar".to_string());
        // Tuples, lists and dictionaries.
        round_trip((42, consts::PI));
        round_trip(vec![42, 1337]);
        round_trip(vec![consts::E, consts::PI, 299792458.00001]);
        round_trip((vec![true, false, true, true], vec![consts::E, consts::PI, 299792458.00001]));
        round_trip(vec![IValue::from(42), IValue::from("foobar")]);
        round_trip(vec![
            (IValue::from(42), IValue::from("foobar")),
            (IValue::from("foo"), IValue::from("bar")),
        ]);
    }
}
879}