Skip to main content

tl_gpu/
tensor.rs

1// GpuTensor — GPU-resident tensor with f32 storage
2
3use crate::device::GpuDevice;
4use std::sync::Arc;
5use tl_ai::TlTensor;
6use wgpu;
7
/// Element data type of a GPU tensor.
///
/// `F64` is declared, but the `GpuTensor` constructors in this module only produce
/// `F32`, and its byte-size arithmetic assumes 4-byte elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit IEEE-754 float.
    F32,
    /// 64-bit IEEE-754 float.
    F64,
}
14
15impl std::fmt::Display for DType {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        match self {
18            DType::F32 => write!(f, "f32"),
19            DType::F64 => write!(f, "f64"),
20        }
21    }
22}
23
24/// A tensor stored on the GPU as an f32 buffer.
25pub struct GpuTensor {
26    pub buffer: wgpu::Buffer,
27    pub shape: Vec<usize>,
28    pub dtype: DType,
29    pub numel: usize,
30    pub device: Arc<GpuDevice>,
31}
32
33impl GpuTensor {
34    /// Upload a CPU TlTensor (f64) to GPU as f32.
35    pub fn from_cpu(tensor: &TlTensor, device: Arc<GpuDevice>) -> Self {
36        let data_f32: Vec<f32> = tensor.data.iter().map(|&v| v as f32).collect();
37        Self::from_f32(&data_f32, tensor.data.shape().to_vec(), device)
38    }
39
40    /// Create a GpuTensor from f32 data.
41    pub fn from_f32(data: &[f32], shape: Vec<usize>, device: Arc<GpuDevice>) -> Self {
42        let bytes = bytemuck::cast_slice(data);
43        let buffer = device
44            .device
45            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
46                label: Some("gpu_tensor_data"),
47                contents: bytes,
48                usage: wgpu::BufferUsages::STORAGE
49                    | wgpu::BufferUsages::COPY_SRC
50                    | wgpu::BufferUsages::COPY_DST,
51            });
52
53        let numel = data.len();
54        GpuTensor {
55            buffer,
56            shape,
57            dtype: DType::F32,
58            numel,
59            device,
60        }
61    }
62
63    /// Download GPU tensor to CPU as TlTensor (f64).
64    pub fn to_cpu(&self) -> Result<TlTensor, String> {
65        let f32_data = self.read_f32()?;
66        let f64_data: Vec<f64> = f32_data.iter().map(|&v| v as f64).collect();
67        let shape = ndarray::IxDyn(&self.shape);
68        let array = ndarray::ArrayD::from_shape_vec(shape, f64_data)
69            .map_err(|e| format!("Shape mismatch: {e}"))?;
70        Ok(TlTensor {
71            data: array,
72            name: None,
73        })
74    }
75
76    /// Read raw f32 data from the GPU buffer.
77    pub fn read_f32(&self) -> Result<Vec<f32>, String> {
78        let size = (self.numel * std::mem::size_of::<f32>()) as u64;
79        let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
80            label: Some("staging_read"),
81            size,
82            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
83            mapped_at_creation: false,
84        });
85
86        let mut encoder =
87            self.device
88                .device
89                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
90                    label: Some("readback"),
91                });
92        encoder.copy_buffer_to_buffer(&self.buffer, 0, &staging, 0, size);
93        self.device.queue.submit(std::iter::once(encoder.finish()));
94
95        let slice = staging.slice(..);
96        let (tx, rx) = std::sync::mpsc::channel();
97        slice.map_async(wgpu::MapMode::Read, move |result| {
98            let _ = tx.send(result);
99        });
100        self.device.device.poll(wgpu::Maintain::Wait);
101        rx.recv()
102            .map_err(|e| format!("GPU readback channel error: {e}"))?
103            .map_err(|e| format!("GPU readback error: {e}"))?;
104
105        let data = slice.get_mapped_range();
106        let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
107        drop(data);
108        staging.unmap();
109
110        Ok(result)
111    }
112
113    /// Get the total byte size of the buffer.
114    pub fn byte_size(&self) -> u64 {
115        (self.numel * std::mem::size_of::<f32>()) as u64
116    }
117}
118
119impl Clone for GpuTensor {
120    fn clone(&self) -> Self {
121        let size = self.byte_size();
122        let new_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
123            label: Some("gpu_tensor_clone"),
124            size,
125            usage: wgpu::BufferUsages::STORAGE
126                | wgpu::BufferUsages::COPY_SRC
127                | wgpu::BufferUsages::COPY_DST,
128            mapped_at_creation: false,
129        });
130
131        let mut encoder =
132            self.device
133                .device
134                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
135                    label: Some("clone"),
136                });
137        encoder.copy_buffer_to_buffer(&self.buffer, 0, &new_buffer, 0, size);
138        self.device.queue.submit(std::iter::once(encoder.finish()));
139
140        GpuTensor {
141            buffer: new_buffer,
142            shape: self.shape.clone(),
143            dtype: self.dtype,
144            numel: self.numel,
145            device: self.device.clone(),
146        }
147    }
148}
149
150impl std::fmt::Debug for GpuTensor {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        write!(
153            f,
154            "GpuTensor(shape={:?}, dtype={}, device={})",
155            self.shape, self.dtype, self.device.adapter_name
156        )
157    }
158}
159
160impl std::fmt::Display for GpuTensor {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        write!(
163            f,
164            "<gpu_tensor shape={:?} dtype={}>",
165            self.shape, self.dtype
166        )
167    }
168}
169
// Bring the `wgpu::util::DeviceExt` trait into scope; it provides `create_buffer_init` (used by `GpuTensor::from_f32`).
171use wgpu::util::DeviceExt;
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert `actual` has the same shape as `expected` and matching elements within
    /// f32 precision.
    fn assert_close(expected: &TlTensor, actual: &TlTensor) {
        // Compare shapes first. The `zip` below stops at the shorter side and would
        // otherwise silently ignore missing or extra elements.
        assert_eq!(actual.data.shape(), expected.data.shape(), "shape mismatch");
        // f32 precision: within 1e-6
        for (a, b) in expected.data.iter().zip(actual.data.iter()) {
            assert!((a - b).abs() < 1e-6, "mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        // Skip when `GpuDevice::get` returns `None` (presumably no usable adapter here).
        let Some(device) = GpuDevice::get() else {
            return;
        };

        let cpu_tensor = TlTensor {
            data: ndarray::arr1(&[1.0, 2.0, 3.0, 4.0]).into_dyn(),
            name: None,
        };

        let gpu = GpuTensor::from_cpu(&cpu_tensor, device);
        let back = gpu.to_cpu().unwrap();
        assert_close(&cpu_tensor, &back);
    }

    #[test]
    fn test_roundtrip_2d_and_clone() {
        // Skip when `GpuDevice::get` returns `None` (presumably no usable adapter here).
        let Some(device) = GpuDevice::get() else {
            return;
        };

        let cpu_tensor = TlTensor {
            data: ndarray::arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn(),
            name: None,
        };

        let gpu = GpuTensor::from_cpu(&cpu_tensor, device);
        assert_eq!(gpu.shape, vec![2, 3]);
        assert_eq!(gpu.numel, 6);
        assert_eq!(gpu.byte_size(), 24);

        // The clone owns a separate buffer but must read back identical contents.
        let copy = gpu.clone();
        assert_close(&cpu_tensor, &gpu.to_cpu().unwrap());
        assert_close(&cpu_tensor, &copy.to_cpu().unwrap());
    }
}