#![allow(clippy::redundant_closure_call)]
#![allow(clippy::useless_conversion)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use half::{bf16, f16};

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};

mod utils;
use utils::wrap_err;

mod shape;
use shape::{PyShape, PyShapeWithHole};

#[cfg(feature = "onnx")]
mod onnx;

#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A candle tensor exposed to Python.
struct PyTensor(Tensor);

impl std::ops::Deref for PyTensor {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[pyclass(name = "DType")]
/// A candle dtype exposed to Python.
struct PyDType(DType);

#[pymethods]
impl PyDType {
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }
}

impl PyDType {
    fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
        use std::str::FromStr;
        if let Ok(dtype) = ob.extract::<String>(py) {
            let dtype = DType::from_str(&dtype)
                .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
            Ok(Self(dtype))
        } else {
            ob.extract(py)
        }
    }
}

// CUDA and Metal devices are created lazily on first use and cached, so every
// tensor placed on "cuda" or "metal" shares the same underlying device.
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
    Cpu,
    Cuda,
    Metal,
}

impl PyDevice {
    fn from_device(device: &Device) -> Self {
        match device {
            Device::Cpu => Self::Cpu,
            Device::Cuda(_) => Self::Cuda,
            Device::Metal(_) => Self::Metal,
        }
    }

    fn as_device(&self) -> PyResult<Device> {
        match self {
            Self::Cpu => Ok(Device::Cpu),
            Self::Cuda => {
                let mut device = CUDA_DEVICE.lock().unwrap();
                if let Some(device) = device.as_ref() {
                    return Ok(device.clone());
                };
                let d = Device::new_cuda(0).map_err(wrap_err)?;
                *device = Some(d.clone());
                Ok(d)
            }
            Self::Metal => {
                let mut device = METAL_DEVICE.lock().unwrap();
                if let Some(device) = device.as_ref() {
                    return Ok(device.clone());
                };
                let d = Device::new_metal(0).map_err(wrap_err)?;
                *device = Some(d.clone());
                Ok(d)
            }
        }
    }
}

impl<'source> FromPyObject<'source> for PyDevice {
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        let device: String = ob.extract()?;
        let device = match device.as_str() {
            "cpu" => PyDevice::Cpu,
            "cuda" => PyDevice::Cuda,
            // Accept "metal" as well; `to_object` below produces it, so the
            // string form must round-trip.
            "metal" => PyDevice::Metal,
            _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
        };
        Ok(device)
    }
}

impl ToPyObject for PyDevice {
    fn to_object(&self, py: Python<'_>) -> PyObject {
        let str = match self {
            PyDevice::Cpu => "cpu",
            PyDevice::Cuda => "cuda",
            PyDevice::Metal => "metal",
        };
        str.to_object(py)
    }
}

trait PyWithDType: WithDType {
    fn to_py(&self, py: Python<'_>) -> PyObject;
}

macro_rules! pydtype {
    ($ty:ty, $conv:expr) => {
        impl PyWithDType for $ty {
            fn to_py(&self, py: Python<'_>) -> PyObject {
                $conv(*self).to_object(py)
            }
        }
    };
}

pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
// Python has no native half-precision float, so f16/bf16 values are widened
// to f32 before crossing the FFI boundary.
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);
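
// For reference, `pydtype!(f16, f32::from)` expands to roughly the following
// (a sketch of the expansion, not part of the original source):
//
//     impl PyWithDType for f16 {
//         fn to_py(&self, py: Python<'_>) -> PyObject {
//             f32::from(*self).to_object(py)
//         }
//     }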

/// Resolves a possibly negative `index` along dimension `dim` of `t`, mapping
/// e.g. `-1` to the last valid position.
fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> {
    let dim = t.dim(dim)?;
    if 0 <= index {
        let index = index as usize;
        if dim <= index {
            ::candle::bail!("index {index} is too large for tensor dimension {dim}")
        }
        Ok(index)
    } else {
        if (dim as i64) < -index {
            ::candle::bail!("index {index} is too low for tensor dimension {dim}")
        }
        Ok((dim as i64 + index) as usize)
    }
}

/// Resolves a possibly negative dimension index `dim` against the rank of
/// `t`, mapping e.g. `-1` to the last dimension.
fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
    let rank = t.rank();
    if 0 <= dim {
        let dim = dim as usize;
        if rank <= dim {
            ::candle::bail!("dimension index {dim} is too large for tensor rank {rank}")
        }
        Ok(dim)
    } else {
        if (rank as i64) < -dim {
            ::candle::bail!("dimension index {dim} is too low for tensor rank {rank}")
        }
        Ok((rank as i64 + dim) as usize)
    }
}
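
// A small sanity check for the two helpers above; a sketch that assumes the
// CPU backend is available, not part of the original bindings.
#[cfg(test)]
mod index_helper_tests {
    use super::*;

    #[test]
    fn negative_indices_resolve_from_the_end() -> ::candle::Result<()> {
        let t = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
        // `-1` on a dimension of size 4 resolves to position 3.
        assert_eq!(actual_index(&t, 0, -1)?, 3);
        // `-2` on a rank-2 tensor resolves to dimension 0.
        assert_eq!(actual_dim(&t, -2)?, 0);
        // Out-of-range indices are rejected.
        assert!(actual_index(&t, 0, 4).is_err());
        Ok(())
    }
}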

/// Dispatches `f` on the concrete element type matching `t.dtype()`.
trait MapDType {
    type Output;
    fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output>;

    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
        match t.dtype() {
            DType::U8 => self.f::<u8>(t),
            DType::U32 => self.f::<u32>(t),
            DType::I64 => self.f::<i64>(t),
            DType::BF16 => self.f::<bf16>(t),
            DType::F16 => self.f::<f16>(t),
            DType::F32 => self.f::<f32>(t),
            DType::F64 => self.f::<f64>(t),
        }
    }
}

enum Indexer {
    // A single position along a dimension, e.g. `t[0]`.
    Index(usize),
    // A `start..stop` range along a dimension, e.g. `t[1:3]`.
    Slice(usize, usize),
    // `...`, which skips over leading dimensions.
    Ellipsis,
    // `None`, which inserts a new dimension of size one.
    Expand,
    // An integer tensor or list used as indices, e.g. `t[[0, 2]]`.
    IndexSelect(Tensor),
}

#[derive(Debug)]
struct TorchTensor(PyObject);

impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        // A torch tensor is accepted by going through its `numpy()` view.
        let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
        Ok(TorchTensor(numpy_value))
    }
}

#[pymethods]
impl PyTensor {
    #[new]
    #[pyo3(text_signature = "(self, data:_ArrayLike)")]
    /// Creates a new tensor from a Python value, which can be a scalar or a
    /// (nested) list of scalars; torch tensors are converted via `numpy()`.
    fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
        use Device::Cpu;
        let tensor = if let Ok(vs) = data.extract::<u32>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<i64>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<f32>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
            let len = vs.len();
            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
            let len = vs.len();
            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
            let len = vs.len();
            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
            Tensor::new(vs, &Cpu).map_err(wrap_err)?
        } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
            return PyTensor::new(py, numpy);
        } else {
            let ty = data.bind(py).get_type();
            Err(PyTypeError::new_err(format!(
                "incorrect type {ty} for tensor"
            )))?
        };
        Ok(Self(tensor))
    }
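
    // A hedged usage sketch from the Python side (assumes the built extension
    // module is importable as `candle`):
    //
    //     import candle
    //     t = candle.Tensor([[1.0, 2.0], [3.0, 4.0]])
    //     t.shape   # (2, 2)
    //     t.dtype   # F32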

    /// Returns the tensor's data as a Python scalar or (nested) list of
    /// scalars, depending on its rank.
    fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
        struct M<'a>(Python<'a>);
        impl MapDType for M<'_> {
            type Output = PyObject;
            fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
                match t.rank() {
                    0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
                    1 => {
                        let v = t.to_vec1::<T>().map_err(wrap_err)?;
                        let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
                        Ok(v.to_object(self.0))
                    }
                    2 => {
                        let v = t.to_vec2::<T>().map_err(wrap_err)?;
                        let v = v
                            .iter()
                            .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
                            .collect::<Vec<Vec<_>>>();
                        Ok(v.to_object(self.0))
                    }
                    3 => {
                        let v = t.to_vec3::<T>().map_err(wrap_err)?;
                        let v = v
                            .iter()
                            .map(|v| {
                                v.iter()
                                    .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
                                    .collect()
                            })
                            .collect::<Vec<Vec<Vec<_>>>>();
                        Ok(v.to_object(self.0))
                    }
                    n => Err(PyTypeError::new_err(format!(
                        "TODO: conversion to PyObject is not handled for rank {n}"
                    )))?,
                }
            }
        }
        M(py).map(self)
    }

    /// Converts candle's tensor to a pytorch tensor (requires `torch` to be
    /// importable).
    fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
        let candle_values = self.values(py)?;
        let torch_tensor: PyObject = py
            .import_bound("torch")?
            .getattr("tensor")?
            .call1((candle_values,))?
            .extract()?;
        Ok(torch_tensor)
    }

    /// Gets the tensor's shape as a tuple of ints.
    #[getter]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new_bound(py, self.0.dims()).to_object(py)
    }

    /// Gets the number of elements in the tensor.
    #[getter]
    fn nelement(&self) -> usize {
        self.0.elem_count()
    }

    /// Gets the tensor's strides as a tuple of ints.
    #[getter]
    fn stride(&self, py: Python<'_>) -> PyObject {
        PyTuple::new_bound(py, self.0.stride()).to_object(py)
    }

    /// Gets the tensor's dtype.
    #[getter]
    fn dtype(&self) -> PyDType {
        PyDType(self.0.dtype())
    }

    /// Gets the tensor's device.
    #[getter]
    fn device(&self, py: Python<'_>) -> PyObject {
        PyDevice::from_device(self.0.device()).to_object(py)
    }

    /// Gets the tensor's rank (number of dimensions).
    #[getter]
    fn rank(&self) -> usize {
        self.0.rank()
    }

    fn __repr__(&self) -> String {
        format!("{}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Performs the `abs` operation on the tensor.
    fn abs(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
    }

    /// Performs the `sin` operation on the tensor.
    fn sin(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
    }

    /// Performs the `cos` operation on the tensor.
    fn cos(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
    }

    /// Performs the `log` operation on the tensor.
    fn log(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.log().map_err(wrap_err)?))
    }

    /// Squares the tensor element-wise.
    fn sqr(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
    }

    /// Takes the element-wise square root of the tensor.
    fn sqrt(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
    }

    /// Computes the element-wise reciprocal of the tensor.
    fn recip(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
    }

    /// Performs the `exp` operation on the tensor.
    fn exp(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
    }

    /// Raises the tensor to the power `p` element-wise.
    #[pyo3(text_signature = "(self, p:float)")]
    fn powf(&self, p: f64) -> PyResult<Self> {
        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
    }

    /// Selects values along dimension `dim` at the positions given by the
    /// 1-dimensional int tensor `rhs`; the other dimensions keep their size.
    #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")]
    fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
    }

    /// Gathers values along the dimension specified by `dim`.
    fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?))
    }

    /// Performs a matrix multiplication with `rhs`.
    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    fn matmul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
    }

    /// Adds `rhs` to the tensor with broadcasting.
    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
    }

    /// Subtracts `rhs` from the tensor with broadcasting.
    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
    }

    /// Multiplies the tensor by `rhs` with broadcasting.
    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
    }

    /// Divides the tensor by `rhs` with broadcasting.
    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
    }

    /// Returns a tensor with values from `on_true` where `self` is non-zero
    /// and values from `on_false` elsewhere.
    #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")]
    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
        Ok(PyTensor(
            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
        ))
    }

    /// Indexes the tensor, supporting ints, slices, `None`, `...`, int lists,
    /// and 1-dimensional int tensors.
    fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
        let mut indexers: Vec<Indexer> = vec![];
        let dims = self.0.shape().dims();

        fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult<usize> {
            // Resolve negative indices against the size of the current dimension.
            let actual_index = if index < 0 {
                dims[current_dim] as isize + index
            } else {
                index
            };

            if actual_index < 0 || actual_index >= dims[current_dim] as isize {
                return Err(PyValueError::new_err(format!(
                    "index out of range for dimension '{i}' with indexer '{value}'",
                    i = current_dim,
                    value = index
                )));
            }
            Ok(actual_index as usize)
        }

        fn extract_indexer(
            py_indexer: &Bound<PyAny>,
            current_dim: usize,
            dims: &[usize],
            index_argument_count: usize,
        ) -> PyResult<(Indexer, usize)> {
            if let Ok(index) = py_indexer.extract() {
                // A single index, e.g. `t[0]` or `t[-1]`.
                Ok((
                    Indexer::Index(to_absolute_index(index, current_dim, dims)?),
                    current_dim + 1,
                ))
            } else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
                // A slice, e.g. `t[0:1]` or `t[0:-1]`.
                let index = slice.indices(dims[current_dim] as isize)?;
                Ok((
                    Indexer::Slice(index.start as usize, index.stop as usize),
                    current_dim + 1,
                ))
            } else if let Ok(tensor) = py_indexer.extract::<PyTensor>() {
                // A tensor of indices, e.g. `t[inds]`.
                let t = tensor.0;
                if t.rank() != 1 {
                    return Err(PyTypeError::new_err(
                        "multi-dimensional tensor indexing is not supported",
                    ));
                }
                Ok((Indexer::IndexSelect(t), current_dim + 1))
            } else if let Ok(list) = py_indexer.downcast::<pyo3::types::PyList>() {
                // A list of indices, e.g. `t[[0, 1]]`.
                let mut indexes = vec![];
                for item in list.iter() {
                    let index = item.extract::<i64>()?;
                    indexes.push(index);
                }
                Ok((
                    Indexer::IndexSelect(
                        Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?,
                    ),
                    current_dim + 1,
                ))
            } else if py_indexer.is(&py_indexer.py().Ellipsis()) {
                // Ellipsis, e.g. `t[..., 0]`.
                if current_dim > 0 {
                    return Err(PyTypeError::new_err(
                        "Ellipsis ('...') can only be used at the start of an indexing operation",
                    ));
                }
                Ok((Indexer::Ellipsis, dims.len() - (index_argument_count - 1)))
            } else if py_indexer.is_none() {
                // `None`, e.g. `t[None]`, which inserts a new dimension.
                Ok((Indexer::Expand, current_dim))
            } else {
                Err(PyTypeError::new_err(format!(
                    "unsupported indexer {}",
                    py_indexer
                )))
            }
        }

        if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
            let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();

            if not_none_count > dims.len() {
                return Err(PyValueError::new_err("provided too many indices"));
            }

            let mut current_dim = 0;
            for item in tuple.iter() {
                let (indexer, new_current_dim) =
                    extract_indexer(&item, current_dim, dims, not_none_count)?;
                current_dim = new_current_dim;
                indexers.push(indexer);
            }
        } else {
            let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
            indexers.push(indexer);
        }

        let mut x = self.0.clone();
        let mut current_dim = 0;
        // Apply the indexers one dimension at a time.
        for indexer in indexers.iter() {
            x = match indexer {
                Indexer::Index(n) => x
                    .narrow(current_dim, *n, 1)
                    .map_err(wrap_err)?
                    .squeeze(current_dim)
                    .map_err(wrap_err)?,
                Indexer::Slice(start, stop) => {
                    let out = x
                        .narrow(current_dim, *start, stop.saturating_sub(*start))
                        .map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
                Indexer::Ellipsis => {
                    // Ellipsis is only supported in the first position, so skip
                    // over all dimensions the remaining indexers do not cover.
                    current_dim += dims.len() - (indexers.len() - 1);
                    x
                }
                Indexer::Expand => {
                    let out = x.unsqueeze(current_dim).map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
                Indexer::IndexSelect(indexes) => {
                    let out = x
                        .index_select(
                            &indexes.to_device(x.device()).map_err(wrap_err)?,
                            current_dim,
                        )
                        .map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
            }
        }

        Ok(Self(x))
    }
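
    // Hedged Python-side indexing sketches (assuming a rank-2 `t`):
    //
    //     t[0]        # first row
    //     t[0:2]      # first two rows
    //     t[..., 0]   # first column
    //     t[None]     # adds a leading dimension of size one
    //     t[[0, 2]]   # rows 0 and 2 via index-select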

    fn __add__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 + rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for add"))?
        };
        Ok(Self(tensor))
    }

    fn __radd__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        self.__add__(rhs)
    }

    fn __mul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 * rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for mul"))?
        };
        Ok(Self(tensor))
    }

    fn __rmul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        self.__mul__(rhs)
    }

    fn __sub__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 - rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for sub"))?
        };
        Ok(Self(tensor))
    }

    fn __truediv__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 / rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for div"))?
        };
        Ok(Self(tensor))
    }

    fn __richcmp__(&self, rhs: &Bound<PyAny>, op: CompareOp) -> PyResult<Self> {
        // Comparisons are element-wise and produce a u8 tensor of 0s and 1s.
        let compare = |lhs: &Tensor, rhs: &Tensor| {
            let t = match op {
                CompareOp::Eq => lhs.eq(rhs),
                CompareOp::Ne => lhs.ne(rhs),
                CompareOp::Lt => lhs.lt(rhs),
                CompareOp::Le => lhs.le(rhs),
                CompareOp::Gt => lhs.gt(rhs),
                CompareOp::Ge => lhs.ge(rhs),
            };
            Ok(PyTensor(t.map_err(wrap_err)?))
        };
        if let Ok(rhs) = rhs.extract::<PyTensor>() {
            if self.0.shape() == rhs.0.shape() {
                compare(&self.0, &rhs.0)
            } else {
                // Broadcast both sides to a common shape before comparing.
                let broadcast_shape = self
                    .0
                    .shape()
                    .broadcast_shape_binary_op(rhs.0.shape(), "cmp")
                    .map_err(wrap_err)?;
                let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
                let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;

                compare(&broadcasted_lhs, &broadcasted_rhs)
            }
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            // Compare against a scalar by materializing it as a broadcast tensor.
            let scalar_tensor = Tensor::new(rhs, self.0.device())
                .map_err(wrap_err)?
                .to_dtype(self.0.dtype())
                .map_err(wrap_err)?
                .broadcast_as(self.0.shape())
                .map_err(wrap_err)?;

            compare(&self.0, &scalar_tensor)
        } else {
            return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
        }
    }
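
    // A hedged comparison sketch from Python (element-wise, u8 result):
    //
    //     m = t > 0.0     # 1 where the entry is positive, else 0
    //     eq = t == u     # broadcasts `u` against `t` if the shapes differ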

    fn __hash__(&self) -> u64 {
        // Hash the address of the wrapped tensor rather than its contents.
        let mut hasher = DefaultHasher::new();
        let pointer = &self.0 as *const Tensor;
        let address = pointer as usize;
        address.hash(&mut hasher);
        hasher.finish()
    }

    /// Reshapes the tensor; the shape may contain a single `-1` hole that is
    /// inferred from the element count.
    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
    fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
        Ok(PyTensor(
            self.0
                .reshape(shape.to_absolute(&self.0)?)
                .map_err(wrap_err)?,
        ))
    }
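
    // A hedged reshape sketch from Python (the `-1` hole is inferred):
    //
    //     t = candle.rand(4, 6)
    //     t.reshape(2, -1).shape   # (2, 12)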

    /// Broadcasts the tensor to the given shape.
    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
    fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
        Ok(PyTensor(
            self.0
                .broadcast_as(shape.to_absolute(&self.0)?)
                .map_err(wrap_err)?,
        ))
    }

    /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
    fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
        Ok(PyTensor(
            self.0
                .broadcast_left(shape.to_absolute(&self.0)?)
                .map_err(wrap_err)?,
        ))
    }

    /// Removes the dimension `dim` if its size is one.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn squeeze(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
    }

    /// Inserts a new dimension of size one at position `dim`.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
    }

    /// Gets the value at the given index along the first dimension.
    #[pyo3(text_signature = "(self, index:int)")]
    fn get(&self, index: i64) -> PyResult<Self> {
        let index = actual_index(self, 0, index).map_err(wrap_err)?;
        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
    }

    /// Swaps the dimensions `dim1` and `dim2`.
    #[pyo3(text_signature = "(self, dim1:int, dim2:int)")]
    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
    }

    /// Returns a slice of length `len` along `dim` starting at `start`.
    #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")]
    fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        let start = actual_index(self, dim, start).map_err(wrap_err)?;
        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
    }

    /// Returns the indices of the maximum values along `dim`, keeping that dimension.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
    }

    /// Returns the indices of the minimum values along `dim`, keeping that dimension.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
    }

    /// Returns the maximum values along `dim`, keeping that dimension.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
    }

    /// Returns the minimum values along `dim`, keeping that dimension.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
    }

    /// Sums over the given dimension(s), keeping them with size one.
    #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")]
    fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
        let dims = if let Ok(dim) = dims.extract::<usize>(py) {
            vec![dim]
        } else {
            dims.extract::<Vec<usize>>(py)?
        };
        Ok(PyTensor(
            self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
        ))
    }

    /// Sums all elements, returning a scalar tensor.
    fn sum_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
    }

    /// Averages all elements, returning a scalar tensor.
    fn mean_all(&self) -> PyResult<Self> {
        let elements = self.0.elem_count();
        let sum = self.0.sum_all().map_err(wrap_err)?;
        let mean = (sum / elements as f64).map_err(wrap_err)?;
        Ok(PyTensor(mean))
    }

    /// Flattens the dimensions from `dim` onwards into a single dimension.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn flatten_from(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
    }

    /// Flattens the dimensions up to and including `dim` into a single dimension.
    #[pyo3(text_signature = "(self, dim:int)")]
    fn flatten_to(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
    }

    /// Flattens the tensor into a 1-dimensional tensor.
    fn flatten_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
    }

    /// Transposes the last two dimensions of the tensor.
    fn t(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
    }

    /// Returns a contiguous (C-order) copy of the tensor if needed.
    fn contiguous(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
    }

    /// Returns true if the tensor is contiguous in C (row-major) order.
    fn is_contiguous(&self) -> bool {
        self.0.is_contiguous()
    }

    /// Returns true if the tensor is contiguous in Fortran (column-major) order.
    fn is_fortran_contiguous(&self) -> bool {
        self.0.is_fortran_contiguous()
    }

    /// Detaches the tensor from the computation graph.
    fn detach(&self) -> Self {
        PyTensor(self.0.detach())
    }

    /// Returns a copy of the tensor.
    fn copy(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
    }

    /// Moves the tensor to a device and/or casts it to a dtype; accepts a
    /// device, a dtype, or another tensor (whose device and dtype are copied),
    /// positionally or as keyword arguments.
    #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
    fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
        let mut device: Option<PyDevice> = None;
        let mut dtype: Option<PyDType> = None;
        let mut other: Option<PyTensor> = None;

        fn handle_duplicates<T>(
            opt: &mut Option<T>,
            extraction_result: PyResult<T>,
            err_msg: &'static str,
        ) -> PyResult<()> {
            if let Ok(successful_extraction) = extraction_result {
                if opt.is_some() {
                    return Err(PyValueError::new_err(err_msg));
                }
                *opt = Some(successful_extraction);
            }
            Ok(())
        }

        // Handle positional arguments.
        for arg in args.iter() {
            if arg.extract::<PyDevice>().is_ok() {
                handle_duplicates(
                    &mut device,
                    arg.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            } else if arg.extract::<PyDType>().is_ok() {
                handle_duplicates(
                    &mut dtype,
                    arg.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            } else if arg.extract::<PyTensor>().is_ok() {
                handle_duplicates(
                    &mut other,
                    arg.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            } else {
                return Err(PyTypeError::new_err(format!(
                    "unsupported argument type `{:#?}`",
                    arg.get_type().name()
                )));
            }
        }

        // Handle keyword arguments.
        if let Some(kwargs) = kwargs {
            if let Ok(Some(any)) = kwargs.get_item("dtype") {
                handle_duplicates(
                    &mut dtype,
                    any.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            }
            if let Ok(Some(any)) = kwargs.get_item("device") {
                handle_duplicates(
                    &mut device,
                    any.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            }
            if let Ok(Some(any)) = kwargs.get_item("other") {
                handle_duplicates(
                    &mut other,
                    any.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            }
        }

        if let Some(other) = other {
            if device.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a device",
                ));
            }
            if dtype.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a dtype",
                ));
            }
            dtype = Some(other.dtype());
            device = Some(PyDevice::from_device(other.0.device()));
        }

        let result = match (device, dtype) {
            (Some(device), Some(dtype)) => self
                .0
                .to_device(&device.as_device()?)
                .map_err(wrap_err)?
                .to_dtype(dtype.0)
                .map_err(wrap_err)?,
            (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
            (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
            (None, None) => return Err(PyTypeError::new_err("no valid dtype or device specified")),
        };

        Ok(PyTensor(result))
    }
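
    // Hedged `.to` usage sketch from Python:
    //
    //     t.to("cpu")               # move to a device
    //     t.to(candle.f16)          # cast to a dtype
    //     t.to("cuda", candle.f32)  # both at once
    //     t.to(other=u)             # match `u`'s device and dtype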

    /// Casts the tensor to the given dtype.
    #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
    fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
        let dtype = PyDType::from_pyobject(dtype, py)?;
        Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
    }

    /// Moves the tensor to the given device.
    #[pyo3(text_signature = "(self, device:Union[str,Device])")]
    fn to_device(&self, device: PyDevice) -> PyResult<Self> {
        let device = device.as_device()?;
        Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
    }

    /// Quantizes the tensor into a `QTensor` using the given GGML dtype,
    /// e.g. "q4_0" or "q8k".
    #[pyo3(text_signature = "(self, quantized_dtype:str)")]
    fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
        use ::candle::quantized;
        let res = match quantized_dtype.to_lowercase().as_str() {
            "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
            "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
            "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
            "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
            "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
            "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
            "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
            "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
            "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
            "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
            "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
            "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
            "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
            "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
            dt => {
                return Err(PyErr::new::<PyValueError, _>(format!(
                    "unknown quantized-dtype {dt}"
                )))
            }
        };
        Ok(PyQTensor(Arc::new(res.map_err(wrap_err)?)))
    }
}

/// Concatenates the tensors across one dimension.
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
    if tensors.is_empty() {
        return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
    }
    let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?;
    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
    let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Stacks the tensors along a new dimension.
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
    let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Creates a new tensor from a Python value; see `Tensor.__init__`.
#[pyfunction]
#[pyo3(text_signature = "(data:_ArrayLike)")]
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
    PyTensor::new(py, data)
}

/// Creates a tensor with values sampled uniformly from `[0, 1)`.
#[pyfunction]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Creates a tensor with values sampled from a standard normal distribution.
#[pyfunction]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Creates a tensor filled with ones.
#[pyfunction]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
fn ones(
    py: Python<'_>,
    shape: PyShape,
    dtype: Option<PyObject>,
    device: Option<PyDevice>,
) -> PyResult<PyTensor> {
    let dtype = match dtype {
        None => DType::F32,
        Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
    };
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Creates a tensor filled with zeros.
#[pyfunction]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
fn zeros(
    py: Python<'_>,
    shape: PyShape,
    dtype: Option<PyObject>,
    device: Option<PyDevice>,
) -> PyResult<PyTensor> {
    let dtype = match dtype {
        None => DType::F32,
        Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
    };
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}
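
// Hedged factory-function sketches from Python:
//
//     a = candle.zeros(2, 3)              # f32 on cpu by default
//     b = candle.ones(2, 3, dtype="i64")
//     c = candle.randn(2, 3)
//     d = candle.cat([a, c], 0)           # shape (4, 3)
//     e = candle.stack([a, c], 0)         # shape (2, 2, 3)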

#[derive(Debug, Clone)]
#[pyclass(name = "QTensor")]
/// A quantized candle tensor.
struct PyQTensor(Arc<QTensor>);

impl std::ops::Deref for PyQTensor {
    type Target = QTensor;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

#[pymethods]
impl PyQTensor {
    /// Gets the GGML dtype of the tensor, e.g. "Q4_0".
    #[getter]
    fn ggml_dtype(&self) -> String {
        format!("{:?}", self.0.dtype())
    }

    /// Gets the tensor's rank.
    #[getter]
    fn rank(&self) -> usize {
        self.0.rank()
    }

    /// Gets the tensor's shape.
    #[getter]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new_bound(py, self.0.shape().dims()).to_object(py)
    }

    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Dequantizes the tensor back into a regular tensor.
    fn dequantize(&self) -> PyResult<PyTensor> {
        let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
        Ok(PyTensor(tensor))
    }

    /// Performs `lhs @ self.T` using the quantized weights.
    #[pyo3(text_signature = "(self, lhs:Tensor)")]
    fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()).map_err(wrap_err)?;
        let res = qmatmul.forward(lhs).map_err(wrap_err)?;
        Ok(PyTensor(res))
    }
}
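
// A hedged quantization round trip from Python (assumes a matmul-compatible
// shape such as (out_features, in_features)):
//
//     w = candle.randn(16, 32)
//     q = w.quantize("q4_0")
//     x = candle.randn(4, 32)
//     y = q.matmul_t(x)         # shape (4, 16), computed on quantized weights
//     w_approx = q.dequantize()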

/// Loads the tensors stored in a safetensors file as a dict from names to tensors.
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
    let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
    let res = res
        .into_iter()
        .map(|(key, value)| (key, PyTensor(value).into_py(py)))
        .collect::<Vec<_>>();
    Ok(res.into_py_dict_bound(py).to_object(py))
}

/// Saves a dict of tensors to a safetensors file.
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
fn save_safetensors(
    path: &str,
    tensors: std::collections::HashMap<String, PyTensor>,
) -> PyResult<()> {
    let tensors = tensors
        .into_iter()
        .map(|(s, t)| (s, t.0))
        .collect::<std::collections::HashMap<_, _>>();
    ::candle::safetensors::save(&tensors, path).map_err(wrap_err)
}
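
// Hedged safetensors round trip from Python (these helpers are registered in
// the `candle.utils` submodule below):
//
//     from candle.utils import load_safetensors, save_safetensors
//     save_safetensors("weights.safetensors", {"w": w})
//     tensors = load_safetensors("weights.safetensors")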

/// Loads a GGML file, returning the tensors, the hyperparameters, and the vocabulary.
#[pyfunction]
#[pyo3(signature = (path, device = None))]
fn load_ggml(
    path: &str,
    device: Option<PyDevice>,
    py: Python<'_>,
) -> PyResult<(PyObject, PyObject, PyObject)> {
    let mut file = std::fs::File::open(path)?;
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let ggml =
        ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
    let tensors = ggml
        .tensors
        .into_iter()
        .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
    let tensors = tensors.into_py_dict_bound(py).to_object(py);
    let hparams = [
        ("n_vocab", ggml.hparams.n_vocab),
        ("n_embd", ggml.hparams.n_embd),
        ("n_mult", ggml.hparams.n_mult),
        ("n_head", ggml.hparams.n_head),
        ("n_layer", ggml.hparams.n_layer),
        ("n_rot", ggml.hparams.n_rot),
        ("ftype", ggml.hparams.ftype),
    ];
    let hparams = hparams.into_py_dict_bound(py).to_object(py);
    let vocab = ggml
        .vocab
        .token_score_pairs
        .iter()
        .map(|(bytes, _)| String::from_utf8_lossy(bytes.as_slice()).to_string())
        .collect::<Vec<String>>()
        .to_object(py);
    Ok((tensors, hparams, vocab))
}

/// Loads a GGUF file, returning the tensors and the metadata.
#[pyfunction]
#[pyo3(signature = (path, device = None))]
fn load_gguf(
    path: &str,
    device: Option<PyDevice>,
    py: Python<'_>,
) -> PyResult<(PyObject, PyObject)> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    use ::candle::quantized::gguf_file;
    fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
        let v: PyObject = match v {
            gguf_file::Value::U8(x) => x.into_py(py),
            gguf_file::Value::I8(x) => x.into_py(py),
            gguf_file::Value::U16(x) => x.into_py(py),
            gguf_file::Value::I16(x) => x.into_py(py),
            gguf_file::Value::U32(x) => x.into_py(py),
            gguf_file::Value::I32(x) => x.into_py(py),
            gguf_file::Value::U64(x) => x.into_py(py),
            gguf_file::Value::I64(x) => x.into_py(py),
            gguf_file::Value::F32(x) => x.into_py(py),
            gguf_file::Value::F64(x) => x.into_py(py),
            gguf_file::Value::Bool(x) => x.into_py(py),
            gguf_file::Value::String(x) => x.into_py(py),
            gguf_file::Value::Array(x) => {
                let list = pyo3::types::PyList::empty_bound(py);
                for elem in x.iter() {
                    list.append(gguf_value_to_pyobject(elem, py)?)?;
                }
                list.into()
            }
        };
        Ok(v)
    }
    let mut file = std::fs::File::open(path)?;
    let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?;
    let tensors = gguf
        .tensor_infos
        .keys()
        .map(|key| {
            let qtensor = gguf.tensor(&mut file, key, &device)?;
            Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
        })
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
    let tensors = tensors.into_py_dict_bound(py).to_object(py);
    let metadata = gguf
        .metadata
        .iter()
        .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
        .collect::<PyResult<Vec<_>>>()?
        .into_py_dict_bound(py)
        .to_object(py);
    Ok((tensors, metadata))
}
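
// Hedged GGUF loading sketch from Python (the tensor key is model-specific):
//
//     from candle.utils import load_gguf
//     tensors, metadata = load_gguf("model.gguf")
//     q = tensors["token_embd.weight"]   # a QTensor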

/// Saves quantized tensors and metadata to a GGUF file.
#[pyfunction]
#[pyo3(signature = (path, tensors, metadata))]
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
    use ::candle::quantized::gguf_file;

    fn pyobject_to_gguf_value(v: &Bound<PyAny>, py: Python<'_>) -> PyResult<gguf_file::Value> {
        let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
            gguf_file::Value::U8(x)
        } else if let Ok(x) = v.extract::<i8>() {
            gguf_file::Value::I8(x)
        } else if let Ok(x) = v.extract::<u16>() {
            gguf_file::Value::U16(x)
        } else if let Ok(x) = v.extract::<i16>() {
            gguf_file::Value::I16(x)
        } else if let Ok(x) = v.extract::<u32>() {
            gguf_file::Value::U32(x)
        } else if let Ok(x) = v.extract::<i32>() {
            gguf_file::Value::I32(x)
        } else if let Ok(x) = v.extract::<u64>() {
            gguf_file::Value::U64(x)
        } else if let Ok(x) = v.extract::<i64>() {
            gguf_file::Value::I64(x)
        } else if let Ok(x) = v.extract::<f32>() {
            gguf_file::Value::F32(x)
        } else if let Ok(x) = v.extract::<f64>() {
            gguf_file::Value::F64(x)
        } else if let Ok(x) = v.extract::<bool>() {
            gguf_file::Value::Bool(x)
        } else if let Ok(x) = v.extract::<String>() {
            gguf_file::Value::String(x)
        } else if let Ok(x) = v.extract::<Vec<PyObject>>() {
            let x = x
                .into_iter()
                .map(|f| pyobject_to_gguf_value(f.bind(py), py))
                .collect::<PyResult<Vec<_>>>()?;
            gguf_file::Value::Array(x)
        } else {
            return Err(PyErr::new::<PyValueError, _>(format!(
                "unsupported type {:?}",
                v
            )));
        };
        Ok(v)
    }
    let tensors = tensors
        .downcast_bound::<PyDict>(py)
        .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
        .iter()
        .map(|(key, value)| {
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                value.extract::<PyQTensor>()?.0,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;

    let metadata = metadata
        .downcast_bound::<PyDict>(py)
        .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
        .iter()
        .map(|(key, value)| {
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                pyobject_to_gguf_value(&value.as_borrowed(), py)?,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;

    let converted_metadata: Vec<_> = metadata
        .iter()
        .map(|(name, value)| (name.as_str(), value))
        .collect();

    let converted_tensors: Vec<_> = tensors
        .iter()
        .map(|(name, tensor)| (name.as_str(), tensor.as_ref()))
        .collect();

    let mut file = std::fs::File::create(path)?;

    gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err)
}

/// Returns true if the candle CUDA backend is available.
#[pyfunction]
fn cuda_is_available() -> bool {
    ::candle::utils::cuda_is_available()
}

/// Returns true if candle was compiled with Accelerate support.
#[pyfunction]
fn has_accelerate() -> bool {
    ::candle::utils::has_accelerate()
}

/// Returns true if candle was compiled with MKL support.
#[pyfunction]
fn has_mkl() -> bool {
    ::candle::utils::has_mkl()
}

/// Returns the number of threads used by candle.
#[pyfunction]
fn get_num_threads() -> usize {
    ::candle::utils::get_num_threads()
}

fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
    m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
    m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
    m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
    m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
    m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
    m.add_function(wrap_pyfunction!(save_gguf, m)?)?;
    m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
    m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
    Ok(())
}

/// Applies the softmax function to the tensor across the given dimension.
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
    let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
    let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
    Ok(PyTensor(sm))
}

/// Applies 2d average pooling with the given kernel size and stride.
#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
    let tensor = tensor
        .avg_pool2d_with_stride(ksize, stride)
        .map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Applies 2d max pooling with the given kernel size and stride.
#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
    let tensor = tensor
        .max_pool2d_with_stride(ksize, stride)
        .map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

/// Applies the SiLU (a.k.a. swish) activation to the tensor.
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
    Ok(PyTensor(s))
}

/// Applies the GELU activation (erf variant) to the tensor.
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = tensor.0.gelu_erf().map_err(wrap_err)?;
    Ok(PyTensor(s))
}

/// Applies the ReLU activation to the tensor.
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = tensor.0.relu().map_err(wrap_err)?;
    Ok(PyTensor(s))
}

/// Applies the tanh activation to the tensor.
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = tensor.0.tanh().map_err(wrap_err)?;
    Ok(PyTensor(s))
}

fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(silu, m)?)?;
    m.add_function(wrap_pyfunction!(softmax, m)?)?;
    m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
    m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
    m.add_function(wrap_pyfunction!(gelu, m)?)?;
    m.add_function(wrap_pyfunction!(relu, m)?)?;
    m.add_function(wrap_pyfunction!(tanh, m)?)?;
    Ok(())
}
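
// Hedged usage of the functional submodule from Python:
//
//     from candle import functional as F
//     probs = F.softmax(logits, -1)
//     act = F.relu(t)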

#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    use onnx::{PyONNXModel, PyONNXTensorDescriptor};
    m.add_class::<PyONNXModel>()?;
    m.add_class::<PyONNXTensorDescriptor>()?;
    Ok(())
}

#[pymodule]
fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let utils = PyModule::new_bound(py, "utils")?;
    candle_utils(py, &utils)?;
    m.add_submodule(&utils)?;
    let nn = PyModule::new_bound(py, "functional")?;
    candle_functional_m(py, &nn)?;
    m.add_submodule(&nn)?;
    #[cfg(feature = "onnx")]
    {
        let onnx = PyModule::new_bound(py, "onnx")?;
        candle_onnx_m(py, &onnx)?;
        m.add_submodule(&onnx)?;
    }
    m.add_class::<PyTensor>()?;
    m.add_class::<PyQTensor>()?;
    m.add_class::<PyDType>()?;
    m.add("u8", PyDType(DType::U8))?;
    m.add("u32", PyDType(DType::U32))?;
    m.add("i64", PyDType(DType::I64))?;
    m.add("bf16", PyDType(DType::BF16))?;
    m.add("f16", PyDType(DType::F16))?;
    m.add("f32", PyDType(DType::F32))?;
    m.add("f64", PyDType(DType::F64))?;
    m.add_function(wrap_pyfunction!(cat, m)?)?;
    m.add_function(wrap_pyfunction!(ones, m)?)?;
    m.add_function(wrap_pyfunction!(rand, m)?)?;
    m.add_function(wrap_pyfunction!(randn, m)?)?;
    m.add_function(wrap_pyfunction!(tensor, m)?)?;
    m.add_function(wrap_pyfunction!(stack, m)?)?;
    m.add_function(wrap_pyfunction!(zeros, m)?)?;
    Ok(())
}