use candle_core::Tensor;
use ferrum_interfaces::TensorLike;
use ferrum_types::{DataType, Device, FerrumError, Result};
use std::any::Any;
/// Newtype around a `candle_core::Tensor`.
///
/// This file implements `ferrum_interfaces::TensorLike` for the wrapper so a
/// Candle tensor can be handed out as a `ferrum_interfaces::TensorRef`.
#[derive(Debug, Clone)]
pub struct CandleTensorWrapper {
// The wrapped Candle tensor; every trait method delegates to it.
tensor: Tensor,
}
impl CandleTensorWrapper {
    /// Wraps an existing Candle tensor.
    pub fn new(tensor: Tensor) -> Self {
        Self { tensor }
    }

    /// Borrows the wrapped Candle tensor.
    pub fn inner(&self) -> &Tensor {
        &self.tensor
    }

    /// Consumes the wrapper and returns the wrapped Candle tensor.
    pub fn into_inner(self) -> Tensor {
        self.tensor
    }

    /// Tries to recover the Candle tensor behind a type-erased `TensorRef`.
    ///
    /// Returns `Some(tensor)` (a clone of the wrapped tensor) when the
    /// reference is backed by a `CandleTensorWrapper`, and `None` when it is
    /// backed by any other `TensorLike` implementation.
    ///
    /// The lookup goes through `TensorLike::as_any`, which this wrapper
    /// implements by returning itself, so the downcast succeeds exactly for
    /// tensors produced by this type.
    pub fn from_tensorref(tensor_ref: &ferrum_interfaces::TensorRef) -> Option<Tensor> {
        tensor_ref
            .as_any()
            .downcast_ref::<CandleTensorWrapper>()
            .map(|wrapper| wrapper.tensor.clone())
    }
}
impl TensorLike for CandleTensorWrapper {
/// Exposes the wrapper as `&dyn Any` so callers can downcast back to
/// `CandleTensorWrapper`.
fn as_any(&self) -> &dyn Any {
    self as &dyn Any
}
fn shape(&self) -> &[usize] {
self.tensor.dims()
}
/// Translates the Candle element type into the ferrum `DataType`.
///
/// Only `F16` and `BF16` get a dedicated mapping; `F32` and every other
/// Candle dtype are reported as `DataType::FP32`.
///
/// NOTE(review): the fallback means non-float tensors are also reported as
/// FP32. Confirm whether callers rely on that.
fn dtype(&self) -> DataType {
    let native = self.tensor.dtype();
    if native == candle_core::DType::F16 {
        DataType::FP16
    } else if native == candle_core::DType::BF16 {
        DataType::BF16
    } else {
        // F32 and all remaining dtypes share the FP32 answer.
        DataType::FP32
    }
}
/// Reports where the wrapped tensor lives, translated to the ferrum `Device` enum.
///
/// NOTE(review): the CUDA payload is ignored, so every CUDA tensor is reported
/// as `Device::CUDA(0)` regardless of which GPU it is actually on.
fn device(&self) -> Device {
match self.tensor.device() {
candle_core::Device::Cpu => Device::CPU,
// The Candle CUDA device handle is discarded; ordinal 0 is always reported.
candle_core::Device::Cuda(_) => Device::CUDA(0),
candle_core::Device::Metal(_) => {
// On macOS/iOS builds a Metal tensor is reported as `Device::Metal`.
// The explicit `return` lets this cfg-gated statement coexist with the
// cfg-gated tail expression below.
#[cfg(any(target_os = "macos", target_os = "ios"))]
return Device::Metal;
// On every other target the Metal arm reports `Device::CPU`.
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
Device::CPU
}
}
}
fn is_contiguous(&self) -> bool {
self.tensor.is_contiguous()
}
/// Returns the sub-tensor covering `start[d]..end[d]` on every dimension `d`.
///
/// `start` and `end` must each have one entry per tensor dimension. `end` is
/// exclusive: the slice length on a dimension is `end[d] - start[d]`.
///
/// # Errors
///
/// Returns a model error when the index slices do not match the tensor rank,
/// when `end[d] < start[d]`, when `end[d]` exceeds the dimension size, or when
/// Candle's `narrow` fails.
fn view(&self, start: &[usize], end: &[usize]) -> Result<ferrum_interfaces::TensorRef> {
    let rank = self.tensor.dims().len();
    if start.len() != end.len() || start.len() != rank {
        return Err(FerrumError::model(format!(
            "Invalid view dimensions: start={:?}, end={:?}, shape={:?}",
            start,
            end,
            self.tensor.dims()
        )));
    }

    let mut sliced = self.tensor.clone();
    for (axis, (&lo, &hi)) in start.iter().zip(end).enumerate() {
        if hi < lo {
            return Err(FerrumError::model(format!(
                "Invalid view range on dim {}: {}..{}",
                axis, lo, hi
            )));
        }

        let extent = match sliced.dims().get(axis) {
            Some(&size) => size,
            None => return Err(FerrumError::model("View dimension out of bounds")),
        };
        if hi > extent {
            return Err(FerrumError::model(format!(
                "View end out of bounds on dim {}: {} > {}",
                axis, hi, extent
            )));
        }

        // Skip the narrow when the requested range already spans the whole axis.
        let span = hi - lo;
        let covers_whole_axis = lo == 0 && span == extent;
        if !covers_whole_axis {
            sliced = match sliced.narrow(axis, lo, span) {
                Ok(narrowed) => narrowed,
                Err(e) => {
                    return Err(FerrumError::model(format!(
                        "View narrow failed on dim {}: {}",
                        axis, e
                    )))
                }
            };
        }
    }

    Ok(std::sync::Arc::new(CandleTensorWrapper::new(sliced)))
}
/// Reshapes the wrapped tensor to `shape` and returns the result as a new
/// wrapped tensor.
///
/// # Errors
///
/// Returns a model error carrying Candle's message when the reshape fails.
fn reshape(&self, shape: &[usize]) -> Result<ferrum_interfaces::TensorRef> {
    let candidate = match self.tensor.reshape(shape) {
        Ok(tensor) => tensor,
        Err(e) => return Err(FerrumError::model(format!("Reshape failed: {}", e))),
    };
    Ok(std::sync::Arc::new(CandleTensorWrapper::new(candidate)))
}
/// Returns a wrapped tensor that lives on the CPU.
///
/// A tensor that is already on the CPU is returned as a clone of this
/// wrapper; otherwise the tensor is moved with Candle's `to_device`.
///
/// # Errors
///
/// Returns a model error when the device transfer fails.
fn to_cpu(&self) -> Result<ferrum_interfaces::TensorRef> {
    let host = match self.tensor.device() {
        candle_core::Device::Cpu => self.clone(),
        _ => {
            let moved = self
                .tensor
                .to_device(&candle_core::Device::Cpu)
                .map_err(|e| FerrumError::model(format!("to_cpu failed: {}", e)))?;
            CandleTensorWrapper::new(moved)
        }
    };
    Ok(std::sync::Arc::new(host))
}
fn to_device(&self, device: &Device) -> Result<ferrum_interfaces::TensorRef> {
let candle_device = match device {
Device::CPU => candle_core::Device::Cpu,
Device::CUDA(id) => candle_core::Device::new_cuda(*id)
.map_err(|e| FerrumError::device(format!("CUDA device error: {}", e)))?,
#[cfg(any(target_os = "macos", target_os = "ios"))]
Device::Metal => candle_core::Device::new_metal(0)
.map_err(|e| FerrumError::device(format!("Metal device error: {}", e)))?,
Device::ROCm(_) => {
return Err(FerrumError::device("ROCm not supported yet"));
}
};
let device_tensor = self
.tensor
.to_device(&candle_device)
.map_err(|e| FerrumError::model(format!("to_device failed: {}", e)))?;
Ok(std::sync::Arc::new(CandleTensorWrapper::new(device_tensor)))
}
/// Converts the wrapped tensor to the requested element type.
///
/// Only `FP32`, `FP16` and `BF16` are accepted.
///
/// # Errors
///
/// Returns a model error for any other `DataType`, or when Candle's
/// conversion fails.
fn to_dtype(&self, dtype: DataType) -> Result<ferrum_interfaces::TensorRef> {
    let target = match &dtype {
        DataType::BF16 => candle_core::DType::BF16,
        DataType::FP16 => candle_core::DType::F16,
        DataType::FP32 => candle_core::DType::F32,
        unsupported => {
            return Err(FerrumError::model(format!(
                "Unsupported dtype: {:?}",
                unsupported
            )));
        }
    };

    let cast = match self.tensor.to_dtype(target) {
        Ok(tensor) => tensor,
        Err(e) => return Err(FerrumError::model(format!("to_dtype failed: {}", e))),
    };
    Ok(std::sync::Arc::new(CandleTensorWrapper::new(cast)))
}
/// Reads one row of `f32` values out of the tensor, casting to F32 first when
/// the tensor has a different dtype.
///
/// Which row is returned depends on the rank:
/// - rank 1: every element;
/// - rank 2: the first entry along dim 0;
/// - rank 3: within the first entry along dim 0, the last entry along dim 1;
/// - rank 4: dim 2 is squeezed, then the rank-3 rule is applied.
///
/// An empty leading dimension yields an empty vector rather than an error.
///
/// NOTE(review): the error messages elsewhere in this impl suggest a
/// `[batch, seq, vocab]` layout (first batch, last token). Confirm with callers.
///
/// # Errors
///
/// Returns a model error when the cast, the squeeze, or the host read-back
/// fails, or when the rank is not 1 through 4.
fn to_vec_f32(&self) -> Result<Vec<f32>> {
    let as_f32 = match self.tensor.dtype() {
        candle_core::DType::F32 => self.tensor.clone(),
        _ => self
            .tensor
            .to_dtype(candle_core::DType::F32)
            .map_err(|e| FerrumError::model(format!("Cast to f32 failed: {}", e)))?,
    };

    // Shared tail of the rank-3 and rank-4 paths: first outer entry, last inner row.
    let last_row_of_first = |nested: Vec<Vec<Vec<f32>>>| -> Vec<f32> {
        nested
            .into_iter()
            .next()
            .and_then(|rows| rows.into_iter().last())
            .unwrap_or_default()
    };

    let rank = as_f32.dims().len();
    match rank {
        1 => match as_f32.to_vec1::<f32>() {
            Ok(values) => Ok(values),
            Err(e) => Err(FerrumError::model(format!("to_vec1 failed: {}", e))),
        },
        2 => {
            let rows = as_f32
                .to_vec2::<f32>()
                .map_err(|e| FerrumError::model(format!("to_vec2 failed: {}", e)))?;
            Ok(rows.into_iter().next().unwrap_or_default())
        }
        3 => {
            let nested = as_f32
                .to_vec3::<f32>()
                .map_err(|e| FerrumError::model(format!("to_vec3 failed: {}", e)))?;
            Ok(last_row_of_first(nested))
        }
        4 => {
            let collapsed = as_f32
                .squeeze(2)
                .map_err(|e| FerrumError::model(format!("Squeeze dim 2 failed: {}", e)))?;
            let nested = collapsed
                .to_vec3::<f32>()
                .map_err(|e| FerrumError::model(format!("to_vec3 (from 4D) failed: {}", e)))?;
            Ok(last_row_of_first(nested))
        }
        _ => Err(FerrumError::model(format!(
            "Unsupported dims: {:?}",
            self.tensor.dims()
        ))),
    }
}
/// Reads `u32` values out of the tensor without any dtype conversion.
///
/// A rank-1 tensor is returned in full; for a rank-2 tensor only the first
/// entry along dim 0 is returned (an empty vector if there is none).
///
/// # Errors
///
/// Returns a model error when Candle's read-back fails or when the rank is
/// neither 1 nor 2.
fn to_vec_u32(&self) -> Result<Vec<u32>> {
    let rank = self.tensor.dims().len();

    if rank == 1 {
        return self
            .tensor
            .to_vec1::<u32>()
            .map_err(|e| FerrumError::model(format!("to_vec1<u32> failed: {}", e)));
    }

    if rank == 2 {
        let rows = match self.tensor.to_vec2::<u32>() {
            Ok(rows) => rows,
            Err(e) => return Err(FerrumError::model(format!("to_vec2<u32> failed: {}", e))),
        };
        return Ok(rows.into_iter().next().unwrap_or_default());
    }

    Err(FerrumError::model(format!(
        "Unsupported dims for token extraction: {:?}",
        self.tensor.dims()
    )))
}
/// Returns the index of the largest element along the last dimension of one
/// selected row of the tensor.
///
/// The row is picked by rank:
/// - rank 1: the tensor itself;
/// - rank 2: index 0 along dim 0;
/// - rank 3: index 0 along dim 0, last index along dim 1;
/// - rank 4: index 0 along dim 0, last index along dim 1, index 0 along dim 2.
///
/// The argmax result is moved to the CPU before the scalar is read back.
///
/// # Errors
///
/// Returns a model error when indexing, the argmax, the device transfer, or
/// the scalar read-back fails, or when the rank is not 1 through 4.
fn argmax_last_dim_u32(&self) -> Result<u32> {
    use candle_core::{IndexOp, D};

    let shape = self.tensor.dims();
    let row = match shape {
        [_] => self.tensor.clone(),
        [_, _] => self
            .tensor
            .i(0)
            .map_err(|e| FerrumError::model(format!("Index batch failed: {}", e)))?,
        [_, seq_len, _] => {
            // saturating_sub keeps the index at 0 when dim 1 is empty.
            let last = seq_len.saturating_sub(1);
            self.tensor
                .i((0, last))
                .map_err(|e| FerrumError::model(format!("Index last token failed: {}", e)))?
        }
        [_, seq_len, _, _] => {
            let last = seq_len.saturating_sub(1);
            self.tensor.i((0, last, 0)).map_err(|e| {
                FerrumError::model(format!("Index last token (4D) failed: {}", e))
            })?
        }
        _ => {
            return Err(FerrumError::model(format!(
                "argmax_last_dim_u32 unsupported dims: {:?}",
                shape
            )));
        }
    };

    let winner = row
        .argmax(D::Minus1)
        .map_err(|e| FerrumError::model(format!("Argmax failed: {}", e)))?;
    let winner_on_host = winner
        .to_device(&candle_core::Device::Cpu)
        .map_err(|e| FerrumError::model(format!("Argmax to CPU failed: {}", e)))?;
    winner_on_host
        .to_vec0::<u32>()
        .map_err(|e| FerrumError::model(format!("Argmax readback failed: {}", e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Slicing `[0, 1, 0]..[1, 2, 3]` out of a `(1, 2, 3)` tensor must keep the
    /// rank and leave only the second row along dim 1.
    #[test]
    fn view_extracts_last_sequence_slice() {
        let cpu = candle_core::Device::Cpu;
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let source = Tensor::from_vec(values, (1, 2, 3), &cpu).expect("create tensor");

        let wrapped = CandleTensorWrapper::new(source);
        let last_row = wrapped.view(&[0, 1, 0], &[1, 2, 3]).expect("slice view");

        assert_eq!(last_row.shape(), &[1, 1, 3]);
        let extracted = last_row.to_vec_f32().expect("to_vec_f32");
        assert_eq!(extracted, vec![4.0, 5.0, 6.0]);
    }
}