candle/lib.rs

#![allow(clippy::redundant_closure_call)]
#![allow(clippy::useless_conversion)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use half::{bf16, f16};

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};

mod utils;
use utils::wrap_err;

mod shape;
use shape::{PyShape, PyShapeWithHole};

#[cfg(feature = "onnx")]
mod onnx;

#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A `candle` tensor.
struct PyTensor(Tensor);

impl std::ops::Deref for PyTensor {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[pyclass(name = "DType")]
/// A `candle` dtype.
struct PyDType(DType);

#[pymethods]
impl PyDType {
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }
}

impl PyDType {
    fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
        use std::str::FromStr;
        if let Ok(dtype) = ob.extract::<String>(py) {
            let dtype = DType::from_str(&dtype)
                .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
            Ok(Self(dtype))
        } else {
            ob.extract(py)
        }
    }
}
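// Illustrative example: a dtype can be given either as a string that is
// resolved through `DType::from_str` ("u8", "f16", "f32", ...) or as an
// existing `DType` object, which the `else` branch extracts directly.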

static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
    Cpu,
    Cuda,
    Metal,
}

impl PyDevice {
    fn from_device(device: &Device) -> Self {
        match device {
            Device::Cpu => Self::Cpu,
            Device::Cuda(_) => Self::Cuda,
            Device::Metal(_) => Self::Metal,
        }
    }

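    // Lazily create the requested backend device on first use and cache it in
    // the matching global mutex so later lookups reuse the same handle.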
    fn as_device(&self) -> PyResult<Device> {
        match self {
            Self::Cpu => Ok(Device::Cpu),
            Self::Cuda => {
                let mut device = CUDA_DEVICE.lock().unwrap();
                if let Some(device) = device.as_ref() {
                    return Ok(device.clone());
                };
                let d = Device::new_cuda(0).map_err(wrap_err)?;
                *device = Some(d.clone());
                Ok(d)
            }
            Self::Metal => {
                let mut device = METAL_DEVICE.lock().unwrap();
                if let Some(device) = device.as_ref() {
                    return Ok(device.clone());
                };
                let d = Device::new_metal(0).map_err(wrap_err)?;
                *device = Some(d.clone());
                Ok(d)
            }
        }
    }
}

impl<'source> FromPyObject<'source> for PyDevice {
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        let device: String = ob.extract()?;
        let device = match device.as_str() {
            "cpu" => PyDevice::Cpu,
            "cuda" => PyDevice::Cuda,
            "metal" => PyDevice::Metal,
            _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
        };
        Ok(device)
    }
}

impl ToPyObject for PyDevice {
    fn to_object(&self, py: Python<'_>) -> PyObject {
        let str = match self {
            PyDevice::Cpu => "cpu",
            PyDevice::Cuda => "cuda",
            PyDevice::Metal => "metal",
        };
        str.to_object(py)
    }
}

trait PyWithDType: WithDType {
    fn to_py(&self, py: Python<'_>) -> PyObject;
}

macro_rules! pydtype {
    ($ty:ty, $conv:expr) => {
        impl PyWithDType for $ty {
            fn to_py(&self, py: Python<'_>) -> PyObject {
                $conv(*self).to_object(py)
            }
        }
    };
}

pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
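// Python has no native half-precision float type, so f16/bf16 values are
// widened to f32 before being handed to Python.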
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);

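// Resolves a possibly negative index along dimension `dim`, e.g. for a
// dimension of size 4, an index of -1 resolves to 3.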
fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> {
    let dim = t.dim(dim)?;
    if 0 <= index {
        let index = index as usize;
        if dim <= index {
            ::candle::bail!("index {index} is too large for tensor dimension {dim}")
        }
        Ok(index)
    } else {
        if (dim as i64) < -index {
            ::candle::bail!("index {index} is too low for tensor dimension {dim}")
        }
        Ok((dim as i64 + index) as usize)
    }
}

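// Resolves a possibly negative dimension index against the tensor's rank,
// using the same convention as `actual_index` above.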
fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
    let rank = t.rank();
    if 0 <= dim {
        let dim = dim as usize;
        if rank <= dim {
            ::candle::bail!("dimension index {dim} is too large for tensor rank {rank}")
        }
        Ok(dim)
    } else {
        if (rank as i64) < -dim {
            ::candle::bail!("dimension index {dim} is too low for tensor rank {rank}")
        }
        Ok((rank as i64 + dim) as usize)
    }
}

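// Dispatches a dtype-generic operation: implementors provide a single generic
// `f`, and `map` picks the right instantiation from the tensor's runtime
// dtype (see the `values` method below for a typical implementation).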
// TODO: Something similar to this should probably be a part of candle core.
trait MapDType {
    type Output;
    fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output>;

    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
        match t.dtype() {
            DType::U8 => self.f::<u8>(t),
            DType::U32 => self.f::<u32>(t),
            DType::I64 => self.f::<i64>(t),
            DType::BF16 => self.f::<bf16>(t),
            DType::F16 => self.f::<f16>(t),
            DType::F32 => self.f::<f32>(t),
            DType::F64 => self.f::<f64>(t),
        }
    }
}

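// The ways a single axis can be indexed from Python: `Index` picks one
// element, `Slice` a contiguous range, `Ellipsis` stands for all remaining
// leading dimensions, `Expand` inserts a new axis (`None` in Python), and
// `IndexSelect` gathers arbitrary indices from a 1d tensor.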
enum Indexer {
    Index(usize),
    Slice(usize, usize),
    Ellipsis,
    Expand,
    IndexSelect(Tensor),
}

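// Wrapper used to accept `torch.Tensor` arguments: extraction calls the torch
// tensor's `.numpy()` method and keeps the resulting numpy array.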
#[derive(Debug)]
struct TorchTensor(PyObject);

impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
        Ok(TorchTensor(numpy_value))
    }
}

#[pymethods]
impl PyTensor {
    #[new]
    #[pyo3(text_signature = "(self, data:_ArrayLike)")]
    // TODO: Handle arbitrary input dtype and shape.
    /// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
    fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
        use Device::Cpu;
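        // Try increasingly nested extractions: scalars first, then 1d, 2d and
        // 3d vectors, finally falling back to the torch-tensor conversion.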
        let tensor = if let Ok(vs) = data.extract::<u32>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<i64>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<f32>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
            let len = vs.len();
            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
            let len = vs.len();
            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
            let len = vs.len();
            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
            return PyTensor::new(py, numpy);
        } else {
            let ty = data.bind(py).get_type();
            Err(PyTypeError::new_err(format!(
                "incorrect type {ty} for tensor"
            )))?
        };
        Ok(Self(tensor))
    }

    /// Gets the tensor's data as a Python scalar or array-like object.
    /// &RETURNS&: _ArrayLike
    fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
        struct M<'a>(Python<'a>);
        impl MapDType for M<'_> {
            type Output = PyObject;
            fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
                match t.rank() {
                    0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
                    1 => {
                        let v = t.to_vec1::<T>().map_err(wrap_err)?;
                        let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
                        Ok(v.to_object(self.0))
                    }
                    2 => {
                        let v = t.to_vec2::<T>().map_err(wrap_err)?;
                        let v = v
                            .iter()
                            .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
                            .collect::<Vec<Vec<_>>>();
                        Ok(v.to_object(self.0))
                    }
                    3 => {
                        let v = t.to_vec3::<T>().map_err(wrap_err)?;
                        let v = v
                            .iter()
                            .map(|v| {
                                v.iter()
                                    .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
                                    .collect()
                            })
                            .collect::<Vec<Vec<Vec<_>>>>();
                        Ok(v.to_object(self.0))
                    }
                    n => Err(PyTypeError::new_err(format!(
                        "TODO: conversion to PyObject is not handled for rank {n}"
                    )))?,
                }
            }
        }
        // TODO: Handle arbitrary shapes.
        M(py).map(self)
    }

    /// Converts candle's tensor to a PyTorch tensor.
    /// &RETURNS&: torch.Tensor
    fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
        let candle_values = self.values(py)?;
        let torch_tensor: PyObject = py
            .import_bound("torch")?
            .getattr("tensor")?
            .call1((candle_values,))?
            .extract()?;
        Ok(torch_tensor)
    }

    #[getter]
    /// Gets the tensor's shape.
    /// &RETURNS&: Tuple[int]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new_bound(py, self.0.dims()).to_object(py)
    }

    #[getter]
    /// Gets the tensor's element count.
    /// &RETURNS&: int
    fn nelement(&self) -> usize {
        self.0.elem_count()
    }

    #[getter]
    /// Gets the tensor's strides.
    /// &RETURNS&: Tuple[int]
    fn stride(&self, py: Python<'_>) -> PyObject {
        PyTuple::new_bound(py, self.0.stride()).to_object(py)
    }

    #[getter]
    /// Gets the tensor's dtype.
    /// &RETURNS&: DType
    fn dtype(&self) -> PyDType {
        PyDType(self.0.dtype())
    }

    #[getter]
    /// Gets the tensor's device.
    /// &RETURNS&: Device
    fn device(&self, py: Python<'_>) -> PyObject {
        PyDevice::from_device(self.0.device()).to_object(py)
    }

    #[getter]
    /// Gets the tensor's rank.
    /// &RETURNS&: int
    fn rank(&self) -> usize {
        self.0.rank()
    }

    fn __repr__(&self) -> String {
        format!("{}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Performs the `abs` operation on the tensor.
    /// &RETURNS&: Tensor
    fn abs(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
    }

    /// Performs the `sin` operation on the tensor.
    /// &RETURNS&: Tensor
    fn sin(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
    }

    /// Performs the `cos` operation on the tensor.
    /// &RETURNS&: Tensor
    fn cos(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
    }

    /// Performs the `log` operation on the tensor.
    /// &RETURNS&: Tensor
    fn log(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.log().map_err(wrap_err)?))
    }

    /// Squares the tensor.
    /// &RETURNS&: Tensor
    fn sqr(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
    }

    /// Calculates the square root of the tensor.
    /// &RETURNS&: Tensor
    fn sqrt(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
    }

    /// Gets the `recip` (element-wise reciprocal) of the tensor.
    /// &RETURNS&: Tensor
    fn recip(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
    }

    /// Performs the `exp` operation on the tensor.
    /// &RETURNS&: Tensor
    fn exp(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, p:float)")]
    /// Performs the `pow` operation on the tensor with the given exponent.
    /// &RETURNS&: Tensor
    fn powf(&self, p: f64) -> PyResult<Self> {
        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")]
    /// Selects values from the input tensor at the target indexes along the specified dimension.
    ///
    /// The `indexes` argument is an int tensor with a single dimension.
    /// The output has the same number of dimensions as the `self` input. The target dimension of
    /// the output has the same length as `indexes`, with values taken from `self` using
    /// the index from `indexes`. Other dimensions have the same number of elements as the input
    /// tensor.
    /// &RETURNS&: Tensor
    fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
    }

    /// Gathers values along an axis specified by dim.
    /// &RETURNS&: Tensor
    fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Performs a matrix multiplication between the two tensors.
    /// &RETURNS&: Tensor
    fn matmul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")]
    /// Returns a tensor with the same shape as the input tensor, the values are taken from
    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
    /// input tensor is equal to zero.
    /// &RETURNS&: Tensor
    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
        Ok(PyTensor(
            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
        ))
    }

    /// Index a tensor.
    /// &RETURNS&: Tensor
    fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
        let mut indexers: Vec<Indexer> = vec![];
        let dims = self.0.shape().dims();

        fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult<usize> {
            // Convert a relative index to an absolute one, e.g. for a
            // dimension of size 4, tensor[-1] resolves to tensor[3].
            let actual_index = if index < 0 {
                dims[current_dim] as isize + index
            } else {
                index
            };

            // Check that the index is in range
            if actual_index < 0 || actual_index >= dims[current_dim] as isize {
                return Err(PyValueError::new_err(format!(
                    "index out of range for dimension '{i}' with indexer '{value}'",
                    i = current_dim,
                    value = index
                )));
            }
            Ok(actual_index as usize)
        }

        fn extract_indexer(
            py_indexer: &Bound<PyAny>,
            current_dim: usize,
            dims: &[usize],
            index_argument_count: usize,
        ) -> PyResult<(Indexer, usize)> {
            if let Ok(index) = py_indexer.extract() {
                // Handle a single index e.g. tensor[0] or tensor[-1]
                Ok((
                    Indexer::Index(to_absolute_index(index, current_dim, dims)?),
                    current_dim + 1,
                ))
            } else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
                // Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
                let index = slice.indices(dims[current_dim] as isize)?;
                Ok((
                    Indexer::Slice(index.start as usize, index.stop as usize),
                    current_dim + 1,
                ))
            } else if let Ok(tensor) = py_indexer.extract::<PyTensor>() {
                // Handle a tensor as indices e.g. tensor[tensor([0,1])]
                let t = tensor.0;
                if t.rank() != 1 {
                    return Err(PyTypeError::new_err(
                        "multi-dimensional tensor indexing is not supported",
                    ));
                }
                Ok((Indexer::IndexSelect(t), current_dim + 1))
            } else if let Ok(list) = py_indexer.downcast::<pyo3::types::PyList>() {
                // Handle a list of indices e.g. tensor[[0,1]]
                let mut indexes = vec![];
                for item in list.iter() {
                    let index = item.extract::<i64>()?;
                    indexes.push(index);
                }
                Ok((
                    Indexer::IndexSelect(
                        Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?,
                    ),
                    current_dim + 1,
                ))
            } else if py_indexer.is(&py_indexer.py().Ellipsis()) {
                // Handle '...' e.g. tensor[..., 0]
                if current_dim > 0 {
                    return Err(PyTypeError::new_err(
                        "Ellipsis ('...') can only be used at the start of an indexing operation",
                    ));
                }
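                // Jump ahead so that the indexers following '...' apply to
                // the trailing dimensions of the tensor.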
                Ok((Indexer::Ellipsis, dims.len() - (index_argument_count - 1)))
            } else if py_indexer.is_none() {
                // Handle None e.g. tensor[None, 0]
                Ok((Indexer::Expand, current_dim))
            } else {
                Err(PyTypeError::new_err(format!(
                    "unsupported indexer {}",
                    py_indexer
                )))
            }
        }

        if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
            let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();

            if not_none_count > dims.len() {
                return Err(PyValueError::new_err("provided too many indices"));
            }

            let mut current_dim = 0;
            for item in tuple.iter() {
                let (indexer, new_current_dim) =
                    extract_indexer(&item, current_dim, dims, not_none_count)?;
                current_dim = new_current_dim;
                indexers.push(indexer);
            }
        } else {
            let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
            indexers.push(indexer);
        }

        let mut x = self.0.clone();
        let mut current_dim = 0;
        // Apply the indexers
        for indexer in indexers.iter() {
            x = match indexer {
                Indexer::Index(n) => x
                    .narrow(current_dim, *n, 1)
                    .map_err(wrap_err)?
                    .squeeze(current_dim)
                    .map_err(wrap_err)?,
                Indexer::Slice(start, stop) => {
                    let out = x
                        .narrow(current_dim, *start, stop.saturating_sub(*start))
                        .map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
                Indexer::Ellipsis => {
                    // Ellipsis is a special case, it means that all remaining dimensions should be
                    // selected => advance the current_dim to the last dimension we have indexers for
                    current_dim += dims.len() - (indexers.len() - 1);
                    x
                }
                Indexer::Expand => {
                    // Expand is a special case, it means that a new dimension should be added => unsqueeze and advance the current_dim
                    let out = x.unsqueeze(current_dim).map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
                Indexer::IndexSelect(indexes) => {
                    let out = x
                        .index_select(
                            &indexes.to_device(x.device()).map_err(wrap_err)?,
                            current_dim,
                        )
                        .map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
            }
        }

        Ok(Self(x))
    }

    /// Add two tensors.
    /// &RETURNS&: Tensor
    fn __add__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 + rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for add"))?
        };
        Ok(Self(tensor))
    }

    fn __radd__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        self.__add__(rhs)
    }

    /// Multiply two tensors.
    /// &RETURNS&: Tensor
    fn __mul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 * rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for mul"))?
        };
        Ok(Self(tensor))
    }

    fn __rmul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        self.__mul__(rhs)
    }

    /// Subtract two tensors.
    /// &RETURNS&: Tensor
    fn __sub__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 - rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for sub"))?
        };
        Ok(Self(tensor))
    }

    /// Divide two tensors.
    /// &RETURNS&: Tensor
    fn __truediv__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 / rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for div"))?
        };
        Ok(Self(tensor))
    }

    /// Rich-compare two tensors.
    /// &RETURNS&: Tensor
    fn __richcmp__(&self, rhs: &Bound<PyAny>, op: CompareOp) -> PyResult<Self> {
        let compare = |lhs: &Tensor, rhs: &Tensor| {
            let t = match op {
                CompareOp::Eq => lhs.eq(rhs),
                CompareOp::Ne => lhs.ne(rhs),
                CompareOp::Lt => lhs.lt(rhs),
                CompareOp::Le => lhs.le(rhs),
                CompareOp::Gt => lhs.gt(rhs),
                CompareOp::Ge => lhs.ge(rhs),
            };
            Ok(PyTensor(t.map_err(wrap_err)?))
        };
        if let Ok(rhs) = rhs.extract::<PyTensor>() {
            if self.0.shape() == rhs.0.shape() {
                compare(&self.0, &rhs.0)
            } else {
                // We broadcast manually here because `candle.cmp` does not support automatic broadcasting
                let broadcast_shape = self
                    .0
                    .shape()
                    .broadcast_shape_binary_op(rhs.0.shape(), "cmp")
                    .map_err(wrap_err)?;
                let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
                let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;

                compare(&broadcasted_lhs, &broadcasted_rhs)
            }
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            let scalar_tensor = Tensor::new(rhs, self.0.device())
                .map_err(wrap_err)?
                .to_dtype(self.0.dtype())
                .map_err(wrap_err)?
                .broadcast_as(self.0.shape())
                .map_err(wrap_err)?;

            compare(&self.0, &scalar_tensor)
        } else {
            return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
        }
    }

    fn __hash__(&self) -> u64 {
        // We override __richcmp__, so pyo3 requires __hash__ to be overridden
        // as well; we simply hash the address of the tensor.
        let mut hasher = DefaultHasher::new();
        let pointer = &self.0 as *const Tensor;
        let address = pointer as usize;
        address.hash(&mut hasher);
        hasher.finish()
    }

    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
    /// Reshapes the tensor to the given shape.
    /// &RETURNS&: Tensor
    fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
        Ok(PyTensor(
            self.0
                .reshape(shape.to_absolute(&self.0)?)
                .map_err(wrap_err)?,
        ))
    }

    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
    /// Broadcasts the tensor to the given shape.
    /// &RETURNS&: Tensor
    fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
        Ok(PyTensor(
            self.0
                .broadcast_as(shape.to_absolute(&self.0)?)
                .map_err(wrap_err)?,
        ))
    }

    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
    /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
    /// &RETURNS&: Tensor
    fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
        Ok(PyTensor(
            self.0
                .broadcast_left(shape.to_absolute(&self.0)?)
                .map_err(wrap_err)?,
        ))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Creates a new tensor with the specified dimension removed if its size was one.
    /// &RETURNS&: Tensor
    fn squeeze(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    /// &RETURNS&: Tensor
    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, index:int)")]
    /// Gets the value at the specified index.
    /// &RETURNS&: Tensor
    fn get(&self, index: i64) -> PyResult<Self> {
        let index = actual_index(self, 0, index).map_err(wrap_err)?;
        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim1:int, dim2:int)")]
    /// Returns a tensor that is a transposed version of the input, with the given dimensions swapped.
    /// &RETURNS&: Tensor
    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")]
    /// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
    /// ranges from `start` to `start + len`.
    /// &RETURNS&: Tensor
    fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        let start = actual_index(self, dim, start).map_err(wrap_err)?;
        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Returns the indices of the maximum value(s) across the selected dimension.
    /// &RETURNS&: Tensor
    fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Returns the indices of the minimum value(s) across the selected dimension.
    /// &RETURNS&: Tensor
    fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Gathers the maximum value across the selected dimension.
    /// &RETURNS&: Tensor
    fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Gathers the minimum value across the selected dimension.
    /// &RETURNS&: Tensor
    fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")]
    /// Returns the sum over the given dimension or dimensions, keeping each reduced
    /// dimension with a size of one.
    /// &RETURNS&: Tensor
    fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
        let dims = if let Ok(dim) = dims.extract::<usize>(py) {
            vec![dim]
        } else {
            dims.extract::<Vec<usize>>(py)?
        };
        Ok(PyTensor(
            self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
        ))
    }

    /// Returns the sum of the tensor.
    /// &RETURNS&: Tensor
    fn sum_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
    }

    /// Returns the mean of the tensor.
    /// &RETURNS&: Tensor
    fn mean_all(&self) -> PyResult<Self> {
        let elements = self.0.elem_count();
        let sum = self.0.sum_all().map_err(wrap_err)?;
        let mean = (sum / elements as f64).map_err(wrap_err)?;
        Ok(PyTensor(mean))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
    /// &RETURNS&: Tensor
    fn flatten_from(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
    /// &RETURNS&: Tensor
    fn flatten_to(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
    }

    /// Flattens the tensor into a 1D tensor.
    /// &RETURNS&: Tensor
    fn flatten_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
    }

    /// Transposes the tensor.
    /// &RETURNS&: Tensor
    fn t(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
    }

    /// Makes the tensor contiguous in memory.
    /// &RETURNS&: Tensor
    fn contiguous(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
    }

    /// Returns true if the tensor is contiguous in C order.
    /// &RETURNS&: bool
    fn is_contiguous(&self) -> bool {
        self.0.is_contiguous()
    }

    /// Returns true if the tensor is contiguous in Fortran order.
    /// &RETURNS&: bool
    fn is_fortran_contiguous(&self) -> bool {
        self.0.is_fortran_contiguous()
    }

    /// Detach the tensor from the computation graph.
    /// &RETURNS&: Tensor
    fn detach(&self) -> Self {
        PyTensor(self.0.detach())
    }

    /// Returns a copy of the tensor.
    /// &RETURNS&: Tensor
    fn copy(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
    }

    #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
    /// Performs Tensor dtype and/or device conversion.
    /// &RETURNS&: Tensor
    fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
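        // Arguments may be given positionally or as keywords; each of
        // `device`, `dtype` and `other` may be specified at most once.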
        let mut device: Option<PyDevice> = None;
        let mut dtype: Option<PyDType> = None;
        let mut other: Option<PyTensor> = None;

        fn handle_duplicates<T>(
            opt: &mut Option<T>,
            extraction_result: PyResult<T>,
            err_msg: &'static str,
        ) -> PyResult<()> {
            if let Ok(successful_extraction) = extraction_result {
                if opt.is_some() {
                    return Err(PyValueError::new_err(err_msg));
                }
                *opt = Some(successful_extraction);
            }
            Ok(())
        }

        // Handle args.
        for arg in args.iter() {
            if arg.extract::<PyDevice>().is_ok() {
                handle_duplicates(
                    &mut device,
                    arg.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            } else if arg.extract::<PyDType>().is_ok() {
                handle_duplicates(
                    &mut dtype,
                    arg.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            } else if arg.extract::<PyTensor>().is_ok() {
                handle_duplicates(
                    &mut other,
                    arg.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            } else {
                return Err(PyTypeError::new_err(format!(
                    "unsupported argument type `{:#?}`",
                    arg.get_type().name()
                )));
            }
        }

        if let Some(kwargs) = kwargs {
            if let Ok(Some(any)) = kwargs.get_item("dtype") {
                handle_duplicates(
                    &mut dtype,
                    any.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            }
            if let Ok(Some(any)) = kwargs.get_item("device") {
                handle_duplicates(
                    &mut device,
                    any.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            }
            if let Ok(Some(any)) = kwargs.get_item("other") {
                handle_duplicates(
                    &mut other,
                    any.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            }
        }

        if let Some(other) = other {
            if device.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a device",
                ));
            }
            if dtype.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a dtype",
                ));
            }
            dtype = Some(other.dtype());
            device = Some(PyDevice::from_device(other.0.device()));
        }

        let result = match (device, dtype) {
            (Some(device), Some(dtype)) => self
                .0
                .to_device(&device.as_device()?)
                .map_err(wrap_err)?
                .to_dtype(dtype.0)
                .map_err(wrap_err)?,
            (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
            (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
            (None, None) => return Err(PyTypeError::new_err("No valid dtype or device specified")),
        };

        Ok(PyTensor(result))
    }

    #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
    /// Convert the tensor to a new dtype.
    /// &RETURNS&: Tensor
    fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
        let dtype = PyDType::from_pyobject(dtype, py)?;
        Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, device:Union[str,Device])")]
    /// Move the tensor to a new device.
    /// &RETURNS&: Tensor
    fn to_device(&self, device: PyDevice) -> PyResult<Self> {
        let device = device.as_device()?;
        Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, quantized_dtype:str)")]
    /// Quantize the tensor.
    /// &RETURNS&: QTensor
    fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
        use ::candle::quantized;
        let res = match quantized_dtype.to_lowercase().as_str() {
            "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
            "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
            "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
            "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
            "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
            "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
            "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
            "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
            "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
            "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
            "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
            "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
            "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
            "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
            dt => {
                return Err(PyErr::new::<PyValueError, _>(format!(
                    "unknown quantized-dtype {dt}"
                )))
            }
        };
        Ok(PyQTensor(Arc::new(res.map_err(wrap_err)?)))
    }
}

#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
/// Concatenate the tensors across one axis.
/// &RETURNS&: Tensor
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
    if tensors.is_empty() {
        return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
    }
    let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?;
    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
    let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
/// Stack the tensors along a new axis.
/// &RETURNS&: Tensor
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
    let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(text_signature = "(data:_ArrayLike)")]
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
/// &RETURNS&: Tensor
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
    PyTensor::new(py, data)
}

#[pyfunction]
#[pyo3(signature = (*shape, device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (*shape, device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor
fn ones(
    py: Python<'_>,
    shape: PyShape,
    dtype: Option<PyObject>,
    device: Option<PyDevice>,
) -> PyResult<PyTensor> {
    let dtype = match dtype {
        None => DType::F32,
        Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
    };
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor
fn zeros(
    py: Python<'_>,
    shape: PyShape,
    dtype: Option<PyObject>,
    device: Option<PyDevice>,
) -> PyResult<PyTensor> {
    let dtype = match dtype {
        None => DType::F32,
        Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
    };
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[derive(Debug, Clone)]
#[pyclass(name = "QTensor")]
/// A quantized tensor.
struct PyQTensor(Arc<QTensor>);

impl std::ops::Deref for PyQTensor {
    type Target = QTensor;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

#[pymethods]
impl PyQTensor {
    #[getter]
    /// Gets the tensor's quantized dtype.
    /// &RETURNS&: str
    fn ggml_dtype(&self) -> String {
        format!("{:?}", self.0.dtype())
    }

    #[getter]
    /// Gets the rank of the tensor.
    /// &RETURNS&: int
    fn rank(&self) -> usize {
        self.0.rank()
    }

    #[getter]
    /// Gets the shape of the tensor.
    /// &RETURNS&: Tuple[int]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new_bound(py, self.0.shape().dims()).to_object(py)
    }

    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Dequantizes the tensor.
    /// &RETURNS&: Tensor
    fn dequantize(&self) -> PyResult<PyTensor> {
        let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
        Ok(PyTensor(tensor))
    }

    #[pyo3(text_signature = "(self, lhs:Tensor)")]
    /// Performs a quantized matrix multiplication, with the quantized tensor as the right-hand side.
    /// &RETURNS&: Tensor
    fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()).map_err(wrap_err)?;
        let res = qmatmul.forward(lhs).map_err(wrap_err)?;
        Ok(PyTensor(res))
    }
}

#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
/// &RETURNS&: Dict[str,Tensor]
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
    let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
    let res = res
        .into_iter()
        .map(|(key, value)| (key, PyTensor(value).into_py(py)))
        .collect::<Vec<_>>();
    Ok(res.into_py_dict_bound(py).to_object(py))
}

#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
/// Saves a dictionary of tensors to a safetensors file.
/// &RETURNS&: None
fn save_safetensors(
    path: &str,
    tensors: std::collections::HashMap<String, PyTensor>,
) -> PyResult<()> {
    let tensors = tensors
        .into_iter()
        .map(|(s, t)| (s, t.0))
        .collect::<std::collections::HashMap<_, _>>();
    ::candle::safetensors::save(&tensors, path).map_err(wrap_err)
}

#[pyfunction]
#[pyo3(signature = (path, device = None))]
/// Loads a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
fn load_ggml(
    path: &str,
    device: Option<PyDevice>,
    py: Python<'_>,
) -> PyResult<(PyObject, PyObject, PyObject)> {
    let mut file = std::fs::File::open(path)?;
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let ggml =
        ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
    let tensors = ggml
        .tensors
        .into_iter()
        .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
    let tensors = tensors.into_py_dict_bound(py).to_object(py);
    let hparams = [
        ("n_vocab", ggml.hparams.n_vocab),
        ("n_embd", ggml.hparams.n_embd),
        ("n_mult", ggml.hparams.n_mult),
        ("n_head", ggml.hparams.n_head),
        ("n_layer", ggml.hparams.n_layer),
        ("n_rot", ggml.hparams.n_rot),
        ("ftype", ggml.hparams.ftype),
    ];
    let hparams = hparams.into_py_dict_bound(py).to_object(py);
    let vocab = ggml
        .vocab
        .token_score_pairs
        .iter()
        .map(|(bytes, _)| String::from_utf8_lossy(bytes.as_slice()).to_string())
        .collect::<Vec<String>>()
        .to_object(py);
    Ok((tensors, hparams, vocab))
}

#[pyfunction]
#[pyo3(signature = (path, device = None))]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
fn load_gguf(
    path: &str,
    device: Option<PyDevice>,
    py: Python<'_>,
) -> PyResult<(PyObject, PyObject)> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    use ::candle::quantized::gguf_file;
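    // Recursively convert GGUF metadata values, including nested arrays, into
    // Python objects.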
    fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
        let v: PyObject = match v {
            gguf_file::Value::U8(x) => x.into_py(py),
            gguf_file::Value::I8(x) => x.into_py(py),
            gguf_file::Value::U16(x) => x.into_py(py),
            gguf_file::Value::I16(x) => x.into_py(py),
            gguf_file::Value::U32(x) => x.into_py(py),
            gguf_file::Value::I32(x) => x.into_py(py),
            gguf_file::Value::U64(x) => x.into_py(py),
            gguf_file::Value::I64(x) => x.into_py(py),
            gguf_file::Value::F32(x) => x.into_py(py),
            gguf_file::Value::F64(x) => x.into_py(py),
            gguf_file::Value::Bool(x) => x.into_py(py),
            gguf_file::Value::String(x) => x.into_py(py),
            gguf_file::Value::Array(x) => {
                let list = pyo3::types::PyList::empty_bound(py);
                for elem in x.iter() {
                    list.append(gguf_value_to_pyobject(elem, py)?)?;
                }
                list.into()
            }
        };
        Ok(v)
    }
    let mut file = std::fs::File::open(path)?;
    let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?;
    let tensors = gguf
        .tensor_infos
        .keys()
        .map(|key| {
            let qtensor = gguf.tensor(&mut file, key, &device)?;
            Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
        })
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
    let tensors = tensors.into_py_dict_bound(py).to_object(py);
    let metadata = gguf
        .metadata
        .iter()
        .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
        .collect::<PyResult<Vec<_>>>()?
        .into_py_dict_bound(py)
        .to_object(py);
    Ok((tensors, metadata))
}

#[pyfunction]
#[pyo3(signature = (path, tensors, metadata))]
/// Saves quantized tensors and metadata to a GGUF file.
/// &RETURNS&: None
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
    use ::candle::quantized::gguf_file;

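    // Inverse of the conversion in `load_gguf`: map Python scalars, strings,
    // bools and lists back to `gguf_file::Value`s.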
1393    fn pyobject_to_gguf_value(v: &Bound<PyAny>, py: Python<'_>) -> PyResult<gguf_file::Value> {
1394        let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
1395            gguf_file::Value::U8(x)
1396        } else if let Ok(x) = v.extract::<i8>() {
1397            gguf_file::Value::I8(x)
1398        } else if let Ok(x) = v.extract::<u16>() {
1399            gguf_file::Value::U16(x)
1400        } else if let Ok(x) = v.extract::<i16>() {
1401            gguf_file::Value::I16(x)
1402        } else if let Ok(x) = v.extract::<u32>() {
1403            gguf_file::Value::U32(x)
1404        } else if let Ok(x) = v.extract::<i32>() {
1405            gguf_file::Value::I32(x)
1406        } else if let Ok(x) = v.extract::<u64>() {
1407            gguf_file::Value::U64(x)
1408        } else if let Ok(x) = v.extract::<i64>() {
1409            gguf_file::Value::I64(x)
1410        } else if let Ok(x) = v.extract::<f32>() {
1411            gguf_file::Value::F32(x)
1412        } else if let Ok(x) = v.extract::<f64>() {
1413            gguf_file::Value::F64(x)
1414        } else if let Ok(x) = v.extract::<bool>() {
1415            gguf_file::Value::Bool(x)
1416        } else if let Ok(x) = v.extract::<String>() {
1417            gguf_file::Value::String(x)
1418        } else if let Ok(x) = v.extract::<Vec<PyObject>>() {
1419            let x = x
1420                .into_iter()
1421                .map(|f| pyobject_to_gguf_value(f.bind(py), py))
1422                .collect::<PyResult<Vec<_>>>()?;
1423            gguf_file::Value::Array(x)
1424        } else {
1425            return Err(PyErr::new::<PyValueError, _>(format!(
1426                "unsupported type {:?}",
1427                v
1428            )));
1429        };
1430        Ok(v)
1431    }
    // Collect the (name, QTensor) pairs from the Python dict.
    let tensors = tensors
        .downcast_bound::<PyDict>(py)
        .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
        .iter()
        .map(|(key, value)| {
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                value.extract::<PyQTensor>()?.0,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;

    // Collect the (name, value) metadata pairs, converting each value to its
    // GGUF representation.
    let metadata = metadata
        .downcast_bound::<PyDict>(py)
        .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
        .iter()
        .map(|(key, value)| {
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                pyobject_to_gguf_value(&value.as_borrowed(), py)?,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;

    // Re-borrow the owned pairs as the borrowed slices `gguf_file::write` expects.
    let converted_metadata: Vec<_> = metadata
        .iter()
        .map(|(name, value)| (name.as_str(), value))
        .collect();

    let converted_tensors: Vec<_> = tensors
        .iter()
        .map(|(name, tensor)| (name.as_str(), tensor.as_ref()))
        .collect();

    let mut file = std::fs::File::create(path)?;

    gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err)
}
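
// A minimal sketch of a load/modify/save round trip from Python; the file
// names and the metadata key are hypothetical:
//
//     import candle
//     tensors, metadata = candle.utils.load_gguf("model.gguf")
//     metadata["general.name"] = "roundtrip-demo"
//     candle.utils.save_gguf("model-copy.gguf", tensors, metadata)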

#[pyfunction]
/// Returns true if the 'cuda' backend is available.
/// &RETURNS&: bool
fn cuda_is_available() -> bool {
    ::candle::utils::cuda_is_available()
}

#[pyfunction]
/// Returns true if candle was compiled with 'accelerate' support.
/// &RETURNS&: bool
fn has_accelerate() -> bool {
    ::candle::utils::has_accelerate()
}

#[pyfunction]
/// Returns true if candle was compiled with MKL support.
/// &RETURNS&: bool
fn has_mkl() -> bool {
    ::candle::utils::has_mkl()
}

#[pyfunction]
/// Returns the number of threads used by candle.
/// &RETURNS&: int
fn get_num_threads() -> usize {
    ::candle::utils::get_num_threads()
}

fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
    m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
    m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
    m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
    m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
    m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
    m.add_function(wrap_pyfunction!(save_gguf, m)?)?;
    m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
    m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
    Ok(())
}
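
// For reference, a short sketch of querying these helpers from Python once
// the `utils` submodule is registered below:
//
//     import candle
//     print(candle.utils.cuda_is_available())
//     print(candle.utils.has_accelerate(), candle.utils.has_mkl())
//     print(candle.utils.get_num_threads())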

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
/// Applies the Softmax function to a given tensor.
/// &RETURNS&: Tensor
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
    let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
    let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
    Ok(PyTensor(sm))
}
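
// A short sketch of `softmax` from Python; `dim` is a signed integer, so a
// negative index (resolved through `actual_dim` above) counts from the end:
//
//     import candle
//     t = candle.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
//     probs = candle.functional.softmax(t, -1)  # each row sums to 1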

#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
/// Applies the 2d avg-pool function to a given tensor.
/// &RETURNS&: Tensor
fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
    let tensor = tensor
        .avg_pool2d_with_stride(ksize, stride)
        .map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
/// Applies the 2d max-pool function to a given tensor.
/// &RETURNS&: Tensor
fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
    let tensor = tensor
        .max_pool2d_with_stride(ksize, stride)
        .map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}
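
// A sketch of the pooling helpers from Python; `stride` is keyword-only and
// defaults to 1, per the signatures above. The (N, C, H, W) input shape is
// hypothetical:
//
//     import candle
//     img = candle.randn((1, 3, 8, 8))
//     a = candle.functional.avg_pool2d(img, 2, stride=2)  # -> (1, 3, 4, 4)
//     m = candle.functional.max_pool2d(img, 2, stride=2)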

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
/// &RETURNS&: Tensor
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
    Ok(PyTensor(s))
}

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor,
/// using the exact (erf-based) formulation.
/// &RETURNS&: Tensor
fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = tensor.0.gelu_erf().map_err(wrap_err)?;
    Ok(PyTensor(s))
}

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Rectified Linear Unit (ReLU) function to a given tensor.
/// &RETURNS&: Tensor
fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = tensor.0.relu().map_err(wrap_err)?;
    Ok(PyTensor(s))
}

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the tanh function to a given tensor.
/// &RETURNS&: Tensor
fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = tensor.0.tanh().map_err(wrap_err)?;
    Ok(PyTensor(s))
}
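
// A sketch of the element-wise activations from Python; note that `gelu`
// maps to the erf-based GELU rather than the tanh approximation:
//
//     import candle
//     x = candle.tensor([-1.0, 0.0, 1.0])
//     for f in (candle.functional.relu, candle.functional.silu,
//               candle.functional.gelu, candle.functional.tanh):
//         print(f(x))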

fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(silu, m)?)?;
    m.add_function(wrap_pyfunction!(softmax, m)?)?;
    m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
    m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
    m.add_function(wrap_pyfunction!(gelu, m)?)?;
    m.add_function(wrap_pyfunction!(relu, m)?)?;
    m.add_function(wrap_pyfunction!(tanh, m)?)?;
    Ok(())
}

#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    use onnx::{PyONNXModel, PyONNXTensorDescriptor};
    m.add_class::<PyONNXModel>()?;
    m.add_class::<PyONNXTensorDescriptor>()?;
    Ok(())
}

#[pymodule]
fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let utils = PyModule::new_bound(py, "utils")?;
    candle_utils(py, &utils)?;
    m.add_submodule(&utils)?;
    let nn = PyModule::new_bound(py, "functional")?;
    candle_functional_m(py, &nn)?;
    m.add_submodule(&nn)?;
    #[cfg(feature = "onnx")]
    {
        let onnx = PyModule::new_bound(py, "onnx")?;
        candle_onnx_m(py, &onnx)?;
        m.add_submodule(&onnx)?;
    }
    m.add_class::<PyTensor>()?;
    m.add_class::<PyQTensor>()?;
    m.add_class::<PyDType>()?;
    // Expose the dtypes as module-level constants, e.g. `candle.f32`.
    m.add("u8", PyDType(DType::U8))?;
    m.add("u32", PyDType(DType::U32))?;
    m.add("i64", PyDType(DType::I64))?;
    m.add("bf16", PyDType(DType::BF16))?;
    m.add("f16", PyDType(DType::F16))?;
    m.add("f32", PyDType(DType::F32))?;
    m.add("f64", PyDType(DType::F64))?;
    m.add_function(wrap_pyfunction!(cat, m)?)?;
    m.add_function(wrap_pyfunction!(ones, m)?)?;
    m.add_function(wrap_pyfunction!(rand, m)?)?;
    m.add_function(wrap_pyfunction!(randn, m)?)?;
    m.add_function(wrap_pyfunction!(tensor, m)?)?;
    m.add_function(wrap_pyfunction!(stack, m)?)?;
    m.add_function(wrap_pyfunction!(zeros, m)?)?;
    Ok(())
}
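
// A short sketch of the resulting top-level Python surface; the exact keyword
// arguments accepted by `ones` are an assumption here:
//
//     import candle
//     t = candle.ones((2, 2), dtype=candle.f32)
//     print(candle.cat([t, t], 0).shape)  # (4, 2)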