wgpu_burn_interop/
interop.rs1use core::panic;
2
3use burn::tensor::{Int, Tensor, TensorMetadata};
4
5use burn_cubecl::tensor::CubeTensor;
6use cubecl::wgpu::WgpuRuntime;
7use gloss_burn_multibackend::{backend::MultiBackend, tensor::MultiFloatTensor, tensor::MultiIntTensor};
8
9pub fn tensor_float2wgpu_buffer(
10 tensor: Tensor<MultiBackend, 2>,
11 usages: wgpu::BufferUsages,
12 device: &wgpu::Device,
13 queue: &wgpu::Queue,
14) -> wgpu::Buffer {
15 let cube_tensor = tensor.into_primitive().tensor();
17 let MultiFloatTensor::Wgpu(cube_tensor) = cube_tensor else {
18 panic!("Expected wgpu tensor got {:?}", cube_tensor.dtype())
19 };
20
21 cubewgpu_tensor2wgpu_buffer(cube_tensor, usages, device, queue)
22}
23
24pub fn tensor_int2wgpu_buffer(
25 tensor: Tensor<MultiBackend, 2, Int>,
26 usages: wgpu::BufferUsages,
27 device: &wgpu::Device,
28 queue: &wgpu::Queue,
29) -> wgpu::Buffer {
30 let cube_tensor = tensor.into_primitive();
32 let MultiIntTensor::Wgpu(cube_tensor) = cube_tensor else {
33 panic!("Expected wgpu tensor got {:?}", cube_tensor.dtype())
34 };
35
36 cubewgpu_tensor2wgpu_buffer(cube_tensor, usages, device, queue)
37}
38
39fn cubewgpu_tensor2wgpu_buffer(
40 tensor: CubeTensor<WgpuRuntime>,
41 usages: wgpu::BufferUsages,
42 device: &wgpu::Device,
43 queue: &wgpu::Queue,
44) -> wgpu::Buffer {
45 let client = tensor.client;
47 let binding = client.get_resource(tensor.handle.clone().binding());
48 let resource = binding.resource();
49
50 let buffer = resource.buffer();
52
53 let offset = resource.offset();
55 let size = resource.size();
56
57 client.flush();
60 let dst_buffer = device.create_buffer(&wgpu::BufferDescriptor {
64 label: Some("tensor2wgpu_buffer_dst"),
65 size,
66 usage: wgpu::BufferUsages::COPY_DST | usages,
67 mapped_at_creation: false,
68 });
69
70 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
72 label: Some("tensor2wgpu_buffer_copy_encoder"),
73 });
74
75 encoder.copy_buffer_to_buffer(buffer, offset, &dst_buffer, 0, size);
76
77 queue.submit(Some(encoder.finish()));
79
80 dst_buffer
81}