1use super::utils::{path_to_cstring, ptr_to_string};
3use super::{device::Device, kind::Kind};
4use crate::{nn::Path, TchError, Tensor};
5use libc::{c_int, c_void};
6use std::borrow::Borrow;
7use std::convert::TryFrom;
8use torch_sys::*;
9
/// A dynamically typed value exchanged with a [`CModule`], e.g. the inputs and
/// the output of [`CModule::forward_is`] and [`CModule::method_is`].
///
/// Values are converted to and from their C representation by `to_c` and
/// `from_c`; container variants are converted element by element.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum IValue {
    /// The absence of a value.
    None,
    /// A single tensor.
    Tensor(crate::Tensor),
    /// A 64-bit floating point number.
    Double(f64),
    /// A 64-bit signed integer.
    Int(i64),
    /// A boolean.
    Bool(bool),
    /// A tuple of (possibly heterogeneous) values.
    Tuple(Vec<IValue>),
    /// A list of integers.
    IntList(Vec<i64>),
    /// A list of floating point numbers.
    DoubleList(Vec<f64>),
    /// A list of booleans.
    BoolList(Vec<bool>),
    /// A string.
    String(String),
    /// A list of strings.
    // NOTE(review): `from_c` has no branch producing this variant, so a value
    // sent as `StringList` may come back as a different variant — confirm.
    StringList(Vec<String>),
    /// A list of tensors.
    TensorList(Vec<crate::Tensor>),
    /// A list of arbitrary values.
    GenericList(Vec<IValue>),
    /// A dictionary, represented as an ordered vector of `(key, value)` pairs.
    GenericDict(Vec<(IValue, IValue)>),
    /// An opaque object handle; see [`Object`].
    Object(Object),
    /// A device.
    // NOTE(review): `from_c` has no branch producing this variant — confirm
    // whether device values can be returned from a module.
    Device(Device),
}
34
35impl IValue {
36 fn type_str(self) -> &'static str {
37 match self {
38 IValue::None => "None",
39 IValue::Tensor(_) => "Tensor",
40 IValue::Double(_) => "Double",
41 IValue::Int(_) => "Int",
42 IValue::Bool(_) => "Bool",
43 IValue::Tuple(_) => "Tuple",
44 IValue::IntList(_) => "IntList",
45 IValue::DoubleList(_) => "DoubleList",
46 IValue::BoolList(_) => "BoolList",
47 IValue::String(_) => "String",
48 IValue::StringList(_) => "StringList",
49 IValue::TensorList(_) => "TensorList",
50 IValue::GenericList(_) => "GenericList",
51 IValue::GenericDict(_) => "GenericDict",
52 IValue::Object(_) => "Object",
53 IValue::Device(_) => "Device",
54 }
55 }
56}
57
58impl From<()> for IValue {
59 fn from((): ()) -> Self {
60 IValue::None
61 }
62}
63
64impl<T1: Into<IValue>, T2: Into<IValue>> From<(T1, T2)> for IValue {
65 fn from((p1, p2): (T1, T2)) -> Self {
66 IValue::Tuple(vec![p1.into(), p2.into()])
67 }
68}
69
70impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>> From<(T1, T2, T3)> for IValue {
71 fn from((p1, p2, p3): (T1, T2, T3)) -> Self {
72 IValue::Tuple(vec![p1.into(), p2.into(), p3.into()])
73 }
74}
75
76impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>, T4: Into<IValue>> From<(T1, T2, T3, T4)>
77 for IValue
78{
79 fn from((p1, p2, p3, p4): (T1, T2, T3, T4)) -> Self {
80 IValue::Tuple(vec![p1.into(), p2.into(), p3.into(), p4.into()])
81 }
82}
83
84impl<T1, T2, T1E, T2E> TryFrom<IValue> for (T1, T2)
85where
86 T1: TryFrom<IValue, Error = T1E>,
87 TchError: From<T1E>,
88 T2: TryFrom<IValue, Error = T2E>,
89 TchError: From<T2E>,
90{
91 type Error = TchError;
92 fn try_from(value: IValue) -> Result<Self, TchError> {
93 match value {
94 IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
95 if vec.len() == 2 {
96 let t2 = T2::try_from(vec.pop().unwrap())?;
97 let t1 = T1::try_from(vec.pop().unwrap())?;
98 Ok((t1, t2))
99 } else {
100 Err(TchError::Kind(format!(
101 "unable to unpack ivalue, expected a tuple of len 2 got {}",
102 vec.len()
103 )))
104 }
105 }
106 _ => Err(TchError::Kind(format!(
107 "unable to unpack ivalue, expected a tuple got {}",
108 value.type_str()
109 ))),
110 }
111 }
112}
113
114impl<T1, T2, T3, T1E, T2E, T3E> TryFrom<IValue> for (T1, T2, T3)
115where
116 T1: TryFrom<IValue, Error = T1E>,
117 TchError: From<T1E>,
118 T2: TryFrom<IValue, Error = T2E>,
119 TchError: From<T2E>,
120 T3: TryFrom<IValue, Error = T3E>,
121 TchError: From<T3E>,
122{
123 type Error = TchError;
124 fn try_from(value: IValue) -> Result<Self, TchError> {
125 match value {
126 IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
127 if vec.len() == 3 {
128 let t3 = T3::try_from(vec.pop().unwrap())?;
129 let t2 = T2::try_from(vec.pop().unwrap())?;
130 let t1 = T1::try_from(vec.pop().unwrap())?;
131 Ok((t1, t2, t3))
132 } else {
133 Err(TchError::Kind(format!(
134 "unable to unpack ivalue, expected a tuple of len 3 got {}",
135 vec.len()
136 )))
137 }
138 }
139 _ => Err(TchError::Kind(format!(
140 "unable to unpack ivalue, expected a tuple got {}",
141 value.type_str()
142 ))),
143 }
144 }
145}
146
147impl<T1, T2, T3, T4, T1E, T2E, T3E, T4E> TryFrom<IValue> for (T1, T2, T3, T4)
148where
149 T1: TryFrom<IValue, Error = T1E>,
150 TchError: From<T1E>,
151 T2: TryFrom<IValue, Error = T2E>,
152 TchError: From<T2E>,
153 T3: TryFrom<IValue, Error = T3E>,
154 TchError: From<T3E>,
155 T4: TryFrom<IValue, Error = T4E>,
156 TchError: From<T4E>,
157{
158 type Error = TchError;
159 fn try_from(value: IValue) -> Result<Self, TchError> {
160 match value {
161 IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
162 if vec.len() == 4 {
163 let t4 = T4::try_from(vec.pop().unwrap())?;
164 let t3 = T3::try_from(vec.pop().unwrap())?;
165 let t2 = T2::try_from(vec.pop().unwrap())?;
166 let t1 = T1::try_from(vec.pop().unwrap())?;
167 Ok((t1, t2, t3, t4))
168 } else {
169 Err(TchError::Kind(format!(
170 "unable to unpack ivalue, expected a tuple of len 4 got {}",
171 vec.len()
172 )))
173 }
174 }
175 _ => Err(TchError::Kind(format!(
176 "unable to unpack ivalue, expected a tuple got {}",
177 value.type_str()
178 ))),
179 }
180 }
181}
182
/// Generates the conversions between a Rust type and one `IValue` variant:
/// - `From<$type_> for IValue`, wrapping the value in `IValue::$cons`;
/// - `TryFrom<IValue> for $type_`, accepting only `IValue::$cons`;
/// - `TryFrom<IValue> for Option<$type_>`, which additionally maps
///   `IValue::None` to `None`.
macro_rules! impl_from {
    ($type_:ty, $cons:ident) => {
        impl From<$type_> for IValue {
            fn from(v: $type_) -> Self {
                Self::$cons(v)
            }
        }

        impl TryFrom<IValue> for $type_ {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<$type_, TchError> {
                if let IValue::$cons(inner) = value {
                    return Ok(inner);
                }
                Err(TchError::Kind(format!(
                    "unable to unpack ivalue, expected {} got {}",
                    std::stringify!($cons),
                    value.type_str()
                )))
            }
        }

        impl TryFrom<IValue> for Option<$type_> {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<Self, TchError> {
                match value {
                    IValue::$cons(inner) => Ok(Some(inner)),
                    IValue::None => Ok(None),
                    other => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} or None got {}",
                        std::stringify!($cons),
                        other.type_str()
                    ))),
                }
            }
        }
    };
}
224
225impl_from!(i64, Int);
226impl_from!(f64, Double);
227impl_from!(bool, Bool);
228impl_from!(String, String);
229impl_from!(Tensor, Tensor);
230impl_from!(Vec<i64>, IntList);
231impl_from!(Vec<f64>, DoubleList);
232impl_from!(Vec<bool>, BoolList);
233impl_from!(Vec<String>, StringList);
234impl_from!(Vec<crate::Tensor>, TensorList);
235impl_from!(Vec<IValue>, GenericList);
236impl_from!(Vec<(IValue, IValue)>, GenericDict);
237impl_from!(Object, Object);
238impl_from!(Device, Device);
239
240impl From<&str> for IValue {
241 fn from(s: &str) -> Self {
242 IValue::String(s.to_string())
243 }
244}
245
246impl IValue {
247 #![allow(unused_unsafe)]
248 pub(super) fn to_c(&self) -> Result<*mut CIValue, TchError> {
249 let c = unsafe_torch_err!(match self {
250 IValue::Tensor(tensor) => ati_tensor(tensor.c_tensor),
251 IValue::Int(i) => ati_int(*i),
252 IValue::None => ati_none(),
253 IValue::Double(f) => ati_double(*f),
254 IValue::Bool(b) => ati_bool(i32::from(*b)),
255 IValue::Tuple(v) => {
256 let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
257 let tuple = ati_tuple(v.as_ptr(), v.len() as c_int);
258 for x in v {
259 ati_free(x);
260 }
261
262 tuple
263 }
264 IValue::GenericList(v) => {
265 let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
266 let list = ati_generic_list(v.as_ptr(), v.len() as c_int);
267 for x in v {
268 ati_free(x);
269 }
270 list
271 }
272 IValue::IntList(v) => ati_int_list(v.as_ptr(), v.len() as c_int),
273 IValue::DoubleList(v) => ati_double_list(v.as_ptr(), v.len() as c_int),
274 IValue::BoolList(v) => {
275 let v: Vec<libc::c_char> = v.iter().map(|&b| libc::c_char::from(b)).collect();
276 ati_bool_list(v.as_ptr(), v.len() as c_int)
277 }
278 IValue::TensorList(v) => {
279 let v = v.iter().map(|t| t.c_tensor).collect::<Vec<_>>();
280 ati_tensor_list(v.as_ptr(), v.len() as c_int)
281 }
282 IValue::String(string) => {
283 let c_str = std::ffi::CString::new(string.as_str())?;
284 ati_string(c_str.as_ptr())
285 }
286 IValue::StringList(strings) => {
287 let mut v = vec![];
288 for s in strings {
289 v.push(std::ffi::CString::new(s.as_str())?);
290 }
291 let v_ptr: Vec<_> = v.iter().map(|s| s.as_ptr()).collect();
292 ati_string_list(v_ptr.as_ptr(), v.len() as c_int)
293 }
294 IValue::GenericDict(dict) => {
295 let v = dict
296 .iter()
297 .flat_map(|(k, v)| vec![Self::to_c(k), Self::to_c(v)])
298 .collect::<Result<Vec<_>, TchError>>()?;
299 let dict = ati_generic_dict(v.as_ptr(), dict.len() as c_int);
300 for x in v {
301 ati_free(x);
302 }
303 dict
304 }
305 IValue::Object(Object { c_ivalue }) => {
306 unsafe_torch_err!(ati_clone(*c_ivalue))
308 }
309 IValue::Device(device) => {
310 ati_device(device.c_int())
311 }
312 });
313 Ok(c)
314 }
315
316 pub(super) fn from_c(c_ivalue: *mut CIValue) -> Result<Self, TchError> {
318 let mut free = true;
319 let tag = unsafe_torch_err!(ati_tag(c_ivalue));
320 let v = match tag {
321 0 => IValue::None,
322 1 => {
323 let c_tensor = unsafe_torch_err!(ati_to_tensor(c_ivalue));
324 IValue::Tensor(crate::Tensor { c_tensor })
325 }
326 2 => IValue::Double(unsafe_torch_err!(ati_to_double(c_ivalue))),
327 3 => IValue::Int(unsafe_torch_err!(ati_to_int(c_ivalue))),
328 4 => {
329 let b = unsafe_torch_err!(ati_to_bool(c_ivalue));
330 if b < 0 {
331 return Err(TchError::Kind(format!("unexpected bool value {b}")));
332 }
333 IValue::Bool(b != 0)
334 }
335 5 => {
336 let len = unsafe_torch_err!(ati_tuple_length(c_ivalue));
337 let mut c_ivalues: Vec<_> =
338 (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
339 unsafe_torch_err!(ati_to_tuple(c_ivalue, c_ivalues.as_mut_ptr(), len));
340 let vec: Result<Vec<_>, _> =
341 c_ivalues.iter().map(|&c_ivalue| Self::from_c(c_ivalue)).collect();
342 IValue::Tuple(vec?)
343 }
344 6 => {
345 let len = unsafe_torch_err!(ati_length(c_ivalue));
346 let mut c_array = vec![0i64; len as usize];
347 unsafe_torch_err!(ati_to_int_list(c_ivalue, c_array.as_mut_ptr(), len));
348 IValue::IntList(c_array)
349 }
350 7 => {
351 let len = unsafe_torch_err!(ati_length(c_ivalue));
352 let mut c_array = vec![0f64; len as usize];
353 unsafe_torch_err!(ati_to_double_list(c_ivalue, c_array.as_mut_ptr(), len));
354 IValue::DoubleList(c_array)
355 }
356 8 => {
357 let len = unsafe_torch_err!(ati_length(c_ivalue));
358 let mut c_array = vec![0_i8; len as usize];
359 let c_array_ptr = c_array.as_mut_ptr() as *mut libc::c_char;
360 unsafe_torch_err!(ati_to_bool_list(c_ivalue, c_array_ptr, len));
361 IValue::BoolList(c_array.iter().map(|&x| x != 0).collect())
362 }
363 9 => {
364 let ptr = unsafe_torch_err!(ati_to_string(c_ivalue));
365 let string = match unsafe { ptr_to_string(ptr) } {
366 None => return Err(TchError::Kind("nullptr representation".to_string())),
367 Some(s) => s,
368 };
369 IValue::String(string)
370 }
371 10 => {
372 let len = unsafe_torch_err!(ati_length(c_ivalue));
373 let mut c_tensors: Vec<_> =
374 (0..len).map(|_| std::ptr::null_mut::<C_tensor>()).collect();
375 unsafe_torch_err!(ati_to_tensor_list(c_ivalue, c_tensors.as_mut_ptr(), len));
376 let vec: Vec<_> = c_tensors.iter().map(|&c_tensor| Tensor { c_tensor }).collect();
377 IValue::TensorList(vec)
378 }
379 12 => {
380 let len = unsafe_torch_err!(ati_length(c_ivalue));
381 let mut c_ivalues: Vec<_> =
382 (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
383 unsafe_torch_err!(ati_to_generic_list(c_ivalue, c_ivalues.as_mut_ptr(), len));
384 let vec: Result<Vec<_>, _> =
385 c_ivalues.iter().map(|&c_ivalue| Self::from_c(c_ivalue)).collect();
386 IValue::GenericList(vec?)
387 }
388 13 => {
389 let len = unsafe_torch_err!(ati_length(c_ivalue));
390 let mut c_ivalues: Vec<_> =
391 (0..2 * len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
392 unsafe_torch_err!(ati_to_generic_dict(c_ivalue, c_ivalues.as_mut_ptr(), len));
393 let mut res: Vec<(IValue, IValue)> = vec![];
394 for i in 0..(len as usize) {
395 let key = Self::from_c(c_ivalues[2 * i])?;
396 let value = Self::from_c(c_ivalues[2 * i + 1])?;
397 res.push((key, value))
398 }
399 IValue::GenericDict(res)
400 }
401 14 => {
402 free = false;
403 IValue::Object(Object { c_ivalue })
404 }
405 _ => return Err(TchError::Kind(format!("unhandled tag {tag}"))),
406 };
407 if free {
408 unsafe_torch_err!(ati_free(c_ivalue));
409 }
410 Ok(v)
411 }
412}
413
/// A module loaded from a file or an in-memory buffer (see [`CModule::load`],
/// [`CModule::load_data`]) or created by tracing (see
/// [`CModule::create_by_tracing`]).
///
/// Owns the C-side module pointer, which is released with `atm_free` when the
/// `CModule` is dropped.
#[derive(Debug)]
pub struct CModule {
    pub(super) c_module: *mut CModule_,
}

// SAFETY (unverified): these impls assert that the C-side module may be moved
// to and shared between threads. Nothing in this file enforces that.
// NOTE(review): confirm against the thread-safety guarantees of the wrapped
// libtorch module type.
unsafe impl Send for CModule {}

unsafe impl Sync for CModule {}

impl Drop for CModule {
    /// Frees the owned C module via `atm_free`.
    // NOTE(review): errors are handled by `unsafe_torch!`, which presumably
    // panics on failure — confirm.
    fn drop(&mut self) {
        unsafe_torch!(atm_free(self.c_module))
    }
}
432
433impl CModule {
434 pub fn load<T: AsRef<std::path::Path>>(path: T) -> Result<CModule, TchError> {
436 let path = path_to_cstring(path)?;
437 let c_module = unsafe_torch_err!(atm_load(path.as_ptr()));
438 Ok(CModule { c_module })
439 }
440
441 pub fn load_on_device<T: AsRef<std::path::Path>>(
446 path: T,
447 device: Device,
448 ) -> Result<CModule, TchError> {
449 let path = path_to_cstring(path)?;
450 let c_module = unsafe_torch_err!(atm_load_on_device(path.as_ptr(), device.c_int()));
451 Ok(CModule { c_module })
452 }
453
454 pub fn load_data<T: std::io::Read>(f: &mut T) -> Result<CModule, TchError> {
456 let mut buffer = Vec::new();
457 f.read_to_end(&mut buffer)?;
458 let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
459 let c_module = unsafe_torch_err!(atm_load_str(buffer_ptr, buffer.len()));
460 Ok(CModule { c_module })
461 }
462
463 pub fn load_data_on_device<T: std::io::Read>(
468 f: &mut T,
469 device: Device,
470 ) -> Result<CModule, TchError> {
471 let mut buffer = Vec::new();
472 f.read_to_end(&mut buffer)?;
473 let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
474 let c_module =
475 unsafe_torch_err!(atm_load_str_on_device(buffer_ptr, buffer.len(), device.c_int()));
476 Ok(CModule { c_module })
477 }
478
479 pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
482 let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
483 let c_tensor =
484 unsafe_torch_err!(atm_forward(self.c_module, ts.as_ptr(), ts.len() as c_int));
485 Ok(Tensor { c_tensor })
486 }
487
488 pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
491 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
492 let c_ivalue =
493 unsafe_torch_err!(atm_forward_(self.c_module, ts.as_ptr(), ts.len() as c_int));
494 for x in ts {
495 unsafe { ati_free(x) }
496 }
497 IValue::from_c(c_ivalue)
498 }
499
500 pub fn method_ts<T: Borrow<Tensor>>(
502 &self,
503 method_name: &str,
504 ts: &[T],
505 ) -> Result<Tensor, TchError> {
506 let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
507 let method_name = std::ffi::CString::new(method_name)?;
508 let c_tensor = unsafe_torch_err!(atm_method(
509 self.c_module,
510 method_name.as_ptr(),
511 ts.as_ptr(),
512 ts.len() as c_int
513 ));
514 Ok(Tensor { c_tensor })
515 }
516
517 pub fn method_is<T: Borrow<IValue>>(
519 &self,
520 method_name: &str,
521 ts: &[T],
522 ) -> Result<IValue, TchError> {
523 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
524 let method_name = std::ffi::CString::new(method_name)?;
525 let c_ivalue = unsafe_torch_err!(atm_method_(
526 self.c_module,
527 method_name.as_ptr(),
528 ts.as_ptr(),
529 ts.len() as c_int
530 ));
531 for x in ts {
532 unsafe { ati_free(x) }
533 }
534 IValue::from_c(c_ivalue)
535 }
536
537 pub fn create_class_is<T: Borrow<IValue>>(
539 &self,
540 clz_name: &str,
541 ts: &[T],
542 ) -> Result<IValue, TchError> {
543 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
544 let clz_name = std::ffi::CString::new(clz_name)?;
545 let c_ivalue = unsafe_torch_err!(atm_create_class_(
546 self.c_module,
547 clz_name.as_ptr(),
548 ts.as_ptr(),
549 ts.len() as c_int
550 ));
551 for x in ts {
552 unsafe { ati_free(x) }
553 }
554 IValue::from_c(c_ivalue)
555 }
556
557 pub fn f_set_eval(&mut self) -> Result<(), TchError> {
559 unsafe_torch_err!(atm_eval(self.c_module));
560 Ok(())
561 }
562
563 pub fn set_eval(&mut self) {
565 self.f_set_eval().unwrap();
566 }
567
568 pub fn f_set_train(&mut self) -> Result<(), TchError> {
570 unsafe_torch_err!(atm_train(self.c_module));
571 Ok(())
572 }
573
574 pub fn set_train(&mut self) {
576 self.f_set_train().unwrap();
577 }
578
579 pub fn to(&mut self, device: Device, kind: Kind, non_blocking: bool) {
581 unsafe_torch!(atm_to(self.c_module, device.c_int(), kind.c_int(), non_blocking));
582 }
583
584 pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
586 let path = path_to_cstring(path)?;
587 unsafe_torch_err!(atm_save(self.c_module, path.as_ptr()));
588 Ok(())
589 }
590
591 pub fn named_parameters(&self) -> Result<Vec<(String, Tensor)>, TchError> {
593 let mut v: Vec<(String, Tensor)> = vec![];
594 unsafe_torch_err!(atm_named_parameters(
595 self.c_module,
596 &mut v as *mut _ as *mut c_void,
597 super::tensor::add_callback
598 ));
599 Ok(v)
600 }
601
602 pub fn create_by_tracing<F>(
605 modl_name: &str,
606 fn_name: &str,
607 inputs: &[Tensor],
608 closure: &mut F,
609 ) -> Result<CModule, TchError>
610 where
611 F: FnMut(&[Tensor]) -> Vec<Tensor>,
612 {
613 let modl_name = std::ffi::CString::new(modl_name)?;
614 let fn_name = std::ffi::CString::new(fn_name)?;
615 let c_inputs = inputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
616 let c_module = unsafe_torch_err!(atm_create_for_tracing(
617 modl_name.as_ptr(),
618 c_inputs.as_ptr(),
619 c_inputs.len() as c_int
620 ));
621 let outputs = closure(inputs);
622 let c_outputs = outputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
623 unsafe_torch_err!(atm_end_tracing(
624 c_module,
625 fn_name.as_ptr(),
626 c_outputs.as_ptr(),
627 c_outputs.len() as c_int,
628 ));
629 Ok(CModule { c_module })
630 }
631}
632
/// A [`CModule`] whose named parameters are registered with an
/// [`nn::Path`](crate::nn::Path) when it is loaded
/// (see [`TrainableCModule::load`] and [`TrainableCModule::load_data`]).
#[derive(Debug)]
pub struct TrainableCModule {
    // The wrapped module; all execution methods delegate to it.
    pub(crate) inner: CModule,
}
641
642impl TrainableCModule {
643 pub fn load<T: AsRef<std::path::Path>>(module_path: T, path: Path) -> Result<Self, TchError> {
648 let inner = CModule::load_on_device(module_path, path.device())?;
649 for (name, tensor) in inner.named_parameters()? {
650 let requires_grad = tensor.requires_grad();
651 let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
652 }
653 Ok(TrainableCModule { inner })
654 }
655
656 pub fn load_data<T: std::io::Read>(data: &mut T, path: Path) -> Result<Self, TchError> {
661 let inner = CModule::load_data_on_device(data, path.device())?;
662 for (name, tensor) in inner.named_parameters()? {
663 let requires_grad = tensor.requires_grad();
664 let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
665 }
666 Ok(TrainableCModule { inner })
667 }
668
669 pub fn save<T: AsRef<std::path::Path>>(&self, module_path: T) -> Result<(), TchError> {
670 self.inner.save(module_path)
671 }
672
673 pub fn f_set_train(&mut self) -> Result<(), TchError> {
675 self.inner.f_set_train()
676 }
677
678 pub fn set_train(&mut self) {
680 self.inner.set_train()
681 }
682
683 pub fn f_set_eval(&mut self) -> Result<(), TchError> {
685 self.inner.f_set_eval()
686 }
687
688 pub fn set_eval(&mut self) {
690 self.inner.set_eval()
691 }
692
693 pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
695 self.inner.forward_ts(ts)
696 }
697
698 pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
700 self.inner.forward_is(ts)
701 }
702
703 pub fn method_ts<T: Borrow<Tensor>>(
705 &self,
706 method_name: &str,
707 ts: &[T],
708 ) -> Result<Tensor, TchError> {
709 self.inner.method_ts(method_name, ts)
710 }
711
712 pub fn method_is<T: Borrow<IValue>>(
714 &self,
715 method_name: &str,
716 ts: &[T],
717 ) -> Result<IValue, TchError> {
718 self.inner.method_is(method_name, ts)
719 }
720}
721
722pub fn f_get_profiling_mode() -> Result<bool, TchError> {
724 Ok(unsafe_torch_err!(atm_get_profiling_mode()) != 0)
725}
726
727pub fn get_profiling_mode() -> bool {
729 f_get_profiling_mode().unwrap()
730}
731
732pub fn f_set_profiling_mode(b: bool) -> Result<(), TchError> {
734 unsafe_torch_err!(atm_set_profiling_mode(b as c_int));
735 Ok(())
736}
737
738pub fn set_profiling_mode(b: bool) {
740 f_set_profiling_mode(b).unwrap()
741}
742
743pub fn f_fuser_cuda_set_enabled(enabled: bool) -> Result<(), TchError> {
744 unsafe_torch_err!(atm_fuser_cuda_set_enabled(enabled));
745 Ok(())
746}
747
748pub fn fuser_cuda_set_enabled(enabled: bool) {
749 f_fuser_cuda_set_enabled(enabled).unwrap()
750}
751
752pub fn f_fuser_cuda_is_enabled() -> Result<bool, TchError> {
753 let b = unsafe_torch_err!(atm_fuser_cuda_is_enabled());
754 Ok(b)
755}
756
757pub fn fuser_cuda_is_enabled() -> bool {
758 f_fuser_cuda_is_enabled().unwrap()
759}
760
761pub fn f_set_tensor_expr_fuser_enabled(b: bool) -> Result<(), TchError> {
762 unsafe_torch_err!(atm_set_tensor_expr_fuser_enabled(b as c_int));
763 Ok(())
764}
765
766pub fn set_tensor_expr_fuser_enabled(b: bool) {
767 f_set_tensor_expr_fuser_enabled(b).unwrap()
768}
769
770pub fn f_get_tensor_expr_fuser_enabled() -> Result<bool, TchError> {
771 Ok(unsafe_torch_err!(atm_get_tensor_expr_fuser_enabled()))
772}
773
774pub fn get_tensor_expr_fuser_enabled() -> bool {
775 f_get_tensor_expr_fuser_enabled().unwrap()
776}
777
778pub fn f_set_graph_executor_optimize(b: bool) -> Result<(), TchError> {
786 unsafe_torch_err!(at_set_graph_executor_optimize(b));
787 Ok(())
788}
789
790pub fn set_graph_executor_optimize(b: bool) {
798 f_set_graph_executor_optimize(b).unwrap();
799}
800
/// An opaque object value, produced by `IValue::from_c` for C ivalues with
/// tag 14. It owns its C ivalue, which is freed when the `Object` is dropped.
///
/// Methods and attributes are reached through [`Object::method_is`] and
/// [`Object::getattr`].
// `PartialEq` compares the wrapped pointers (identity, not structural equality);
// `Eq` is intentionally not derived, hence the clippy allow.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, PartialEq)]
pub struct Object {
    c_ivalue: *mut CIValue,
}
806
807impl Object {
808 pub fn method_is<T: Borrow<IValue>>(
811 &self,
812 method_name: &str,
813 ts: &[T],
814 ) -> Result<IValue, TchError> {
815 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
816 let method_name = std::ffi::CString::new(method_name)?;
817 let c_ivalue = unsafe_torch_err!(ati_object_method_(
818 self.c_ivalue,
819 method_name.as_ptr(),
820 ts.as_ptr(),
821 ts.len() as c_int
822 ));
823 for x in ts {
824 unsafe { ati_free(x) }
825 }
826 IValue::from_c(c_ivalue)
827 }
828
829 pub fn getattr(&self, attr_name: &str) -> Result<IValue, TchError> {
831 let property_name = std::ffi::CString::new(attr_name)?;
832 let c_ivalue =
833 unsafe_torch_err!(ati_object_getattr_(self.c_ivalue, property_name.as_ptr()));
834 if c_ivalue.is_null() {
835 return Err(TchError::Torch(format!(
836 "Object.getattr(\"{attr_name}\") returned CIValue nullptr"
837 )));
838 }
839 IValue::from_c(c_ivalue)
840 }
841}
842
843impl Drop for Object {
844 fn drop(&mut self) {
845 unsafe_torch!(ati_free(self.c_ivalue))
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::IValue;
    use std::convert::TryFrom;
    use std::f64::consts;

    /// Converts `t` into an `IValue`, sends it through the C representation and
    /// back, and checks that the value is unchanged.
    fn round_trip<T: Into<IValue>>(t: T) {
        let ivalue: IValue = t.into();
        let ivalue2 = IValue::from_c(ivalue.to_c().unwrap()).unwrap();
        assert_eq!(ivalue, ivalue2);
    }

    #[test]
    fn ivalue_round_trip() {
        round_trip(());
        round_trip(true);
        round_trip(false);
        round_trip(-1);
        round_trip(42);
        round_trip(15);
        round_trip("".to_string());
        round_trip("foobar".to_string());
        round_trip((42, consts::PI));
        round_trip(vec![42, 1337]);
        round_trip(vec![consts::E, consts::PI, 299792458.00001]);
        round_trip((vec![true, false, true, true], vec![consts::E, consts::PI, 299792458.00001]));
        round_trip(vec![IValue::from(42), IValue::from("foobar")]);
        round_trip(vec![
            (IValue::from(42), IValue::from("foobar")),
            (IValue::from("foo"), IValue::from("bar")),
        ]);
        // Larger and nested tuples.
        round_trip((42i64, consts::PI, true));
        round_trip((1i64, "foo".to_string(), (2i64, false), vec![3i64, 4]));
    }

    #[test]
    fn ivalue_unpack() {
        let (i, d) = <(i64, f64)>::try_from(IValue::from((42i64, consts::PI))).unwrap();
        assert_eq!(i, 42);
        assert_eq!(d, consts::PI);

        let (b, s, n) = <(bool, String, i64)>::try_from(IValue::GenericList(vec![
            IValue::from(true),
            IValue::from("x"),
            IValue::from(7i64),
        ]))
        .unwrap();
        assert!(b);
        assert_eq!(s, "x");
        assert_eq!(n, 7);

        // Wrong arity, non-container input and wrong element type are errors.
        assert!(<(i64, i64)>::try_from(IValue::from((1i64, 2i64, 3i64))).is_err());
        assert!(<(i64, i64)>::try_from(IValue::from(1i64)).is_err());
        assert!(i64::try_from(IValue::from(1.5f64)).is_err());

        assert_eq!(Option::<i64>::try_from(IValue::None).unwrap(), None);
        assert_eq!(Option::<i64>::try_from(IValue::from(3i64)).unwrap(), Some(3));
        assert!(Option::<i64>::try_from(IValue::from(true)).is_err());
    }
}