use super::device::CudaDevice;
use crate::histogram::Histogram;
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, LaunchConfig, PushKernelArg};
use cudarc::nvrtc::compile_ptx;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// Cheap fingerprint of a host buffer, used to decide whether a cached
/// device-side copy of that buffer is still valid.
///
/// NOTE(review): only the length plus up to three sampled 1 KiB windows
/// (head, middle, tail) are hashed, so buffers that differ outside those
/// windows collide. Acceptable for cache invalidation, not for integrity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct CacheKey(u64);

impl CacheKey {
    /// Derives a key from `data` by hashing its length and sampled windows.
    pub(crate) fn from_slice(data: &[u8]) -> Self {
        // Size of each sampled window in bytes.
        const WINDOW: usize = 1024;
        let mut state = DefaultHasher::new();
        let len = data.len();
        len.hash(&mut state);
        if len > 0 {
            let window = WINDOW.min(len);
            // Head window always participates once the buffer is non-empty.
            data[..window].hash(&mut state);
            // Middle window only when it cannot overlap the head window.
            if len > window * 2 {
                let mid = len / 2;
                let span = window.min(len - mid);
                data[mid..mid + span].hash(&mut state);
            }
            // Tail window whenever the buffer exceeds a single window.
            if len > window {
                data[len - window..].hash(&mut state);
            }
        }
        Self(state.finish())
    }
}
const HISTOGRAM_KERNEL_SOURCE: &str = include_str!("kernels/histogram.cu");
const THREADS_PER_BLOCK: u32 = 256;
/// Contiguous window into a device-resident row-index buffer describing
/// which entries belong to one tree node (consumed by `build_histograms_gpu`).
#[derive(Debug, Clone, Copy)]
pub struct NodeRange {
/// Offset of the node's first entry in the index buffer.
pub start: u32,
/// Number of consecutive entries belonging to the node.
pub count: u32,
}
/// Host-side driver for the CUDA histogram kernels.
///
/// Owns the lazily-compiled module and kernel handles, device-side caches of
/// the most recently uploaded inputs (invalidated via sampled `CacheKey`s),
/// and reusable device output buffers for the three launch paths
/// (single-node, batched, per-era).
pub struct HistogramKernel {
device: Arc<CudaDevice>,
// Lazily compiled module and resolved entry points; all `None` until the
// first build (populated by `ensure_initialized`).
module: Option<Arc<CudaModule>>,
build_histogram_fn: Option<CudaFunction>,
build_histogram_batched_fn: Option<CudaFunction>,
build_histogram_era_fn: Option<CudaFunction>,
zero_histograms_fn: Option<CudaFunction>,
// Device copies of the most recent host inputs, keyed by sampled hashes.
cached_bins: Option<CudaSlice<u8>>,
cached_bins_key: Option<CacheKey>,
cached_gradients: Option<CudaSlice<f32>>,
cached_hessians: Option<CudaSlice<f32>>,
cached_grad_hess_key: Option<CacheKey>,
cached_era_indices: Option<CudaSlice<u16>>,
cached_era_indices_key: Option<CacheKey>,
// Output buffers for the single-node path; reallocated whenever the exact
// required length changes (`cached_output_len` tracks it).
cached_grad_hist: Option<CudaSlice<f32>>,
cached_hess_hist: Option<CudaSlice<f32>>,
cached_count_hist: Option<CudaSlice<u32>>,
cached_output_len: usize,
// Output buffers shared by the batched and GPU-resident paths; grown but
// never shrunk (`cached_batched_output_len` is the current capacity).
cached_batched_grad_hist: Option<CudaSlice<f32>>,
cached_batched_hess_hist: Option<CudaSlice<f32>>,
cached_batched_count_hist: Option<CudaSlice<u32>>,
cached_batched_output_len: usize,
// Output buffers for the per-era path; also grow-only.
cached_era_grad_hist: Option<CudaSlice<f32>>,
cached_era_hess_hist: Option<CudaSlice<f32>>,
cached_era_count_hist: Option<CudaSlice<u32>>,
cached_era_output_len: usize,
}
impl HistogramKernel {
/// Creates an uninitialized kernel wrapper bound to `device`.
///
/// No CUDA work happens here: the module is compiled and entry points are
/// resolved lazily on the first histogram build, and every cache starts
/// out empty.
pub fn new(device: Arc<CudaDevice>) -> Self {
    Self {
        // Lazily-resolved module and kernel entry points.
        module: None,
        build_histogram_fn: None,
        build_histogram_batched_fn: None,
        build_histogram_era_fn: None,
        zero_histograms_fn: None,
        // Device-side input caches (bins, gradients/hessians, era ids).
        cached_bins: None,
        cached_bins_key: None,
        cached_gradients: None,
        cached_hessians: None,
        cached_grad_hess_key: None,
        cached_era_indices: None,
        cached_era_indices_key: None,
        // Reusable output buffers for the single-node path.
        cached_grad_hist: None,
        cached_hess_hist: None,
        cached_count_hist: None,
        cached_output_len: 0,
        // Reusable output buffers for the batched / GPU-resident path.
        cached_batched_grad_hist: None,
        cached_batched_hess_hist: None,
        cached_batched_count_hist: None,
        cached_batched_output_len: 0,
        // Reusable output buffers for the per-era path.
        cached_era_grad_hist: None,
        cached_era_hess_hist: None,
        cached_era_count_hist: None,
        cached_era_output_len: 0,
        device,
    }
}
/// Compiles the CUDA module and resolves all kernel entry points on first
/// call; subsequent calls are no-ops.
///
/// # Panics
/// Panics if the embedded kernel source fails to compile to PTX.
fn ensure_initialized(&mut self) {
    if self.module.is_none() {
        let ptx = compile_ptx(HISTOGRAM_KERNEL_SOURCE).expect("Failed to compile CUDA kernel");
        let module = self.device.load_module(ptx);
        // Resolve every entry point through one helper closure.
        let lookup = |name: &str| Some(CudaDevice::load_function(&module, name));
        self.build_histogram_fn = lookup("build_histogram");
        self.build_histogram_batched_fn = lookup("build_histogram_batched");
        self.build_histogram_era_fn = lookup("build_histogram_era");
        self.zero_histograms_fn = lookup("zero_histograms");
        self.module = Some(module);
    }
}
/// Returns a shared handle to the underlying CUDA device.
pub fn device(&self) -> &Arc<CudaDevice> {
&self.device
}
/// Uploads `bins` to the device unless a buffer with the same sampled
/// `CacheKey` is already resident.
pub fn ensure_bins_cached(&mut self, bins: &[u8]) {
    let key = CacheKey::from_slice(bins);
    let stale = self.cached_bins.is_none() || self.cached_bins_key != Some(key);
    if stale {
        self.cached_bins_key = Some(key);
        self.cached_bins = Some(self.device.htod_copy(bins));
    }
}
/// Uploads gradients and hessians unless the cached device copies already
/// match per the sampled key.
///
/// NOTE(review): the key samples only the length plus the first and middle
/// gradient values, so gradients must change at those sample points (or in
/// length) between rounds for the cache to refresh — confirm callers rely
/// on wholesale replacement of the gradient arrays.
pub fn ensure_grad_hess_cached(&mut self, gradients: &[f32], hessians: &[f32]) {
    let mut hasher = DefaultHasher::new();
    gradients.len().hash(&mut hasher);
    match gradients {
        [] => {}
        [only] => only.to_bits().hash(&mut hasher),
        _ => {
            gradients[0].to_bits().hash(&mut hasher);
            gradients[gradients.len() / 2].to_bits().hash(&mut hasher);
        }
    }
    let key = CacheKey(hasher.finish());
    let stale = self.cached_grad_hess_key != Some(key)
        || self.cached_gradients.is_none()
        || self.cached_hessians.is_none();
    if stale {
        self.cached_gradients = Some(self.device.htod_copy(gradients));
        self.cached_hessians = Some(self.device.htod_copy(hessians));
        self.cached_grad_hess_key = Some(key);
    }
}
/// Borrows the device-resident bin buffer, if one has been uploaded (via
/// `ensure_bins_cached` or any histogram build that uploads bins).
pub fn cached_bins(&self) -> Option<&CudaSlice<u8>> {
self.cached_bins.as_ref()
}
/// Builds one 256-bin gradient/hessian/count histogram per feature on the
/// GPU, accumulating only the rows named by `row_indices`.
///
/// * `bins` — binned feature matrix (one `u8` bin per value).
/// * `grad_hess` — per-row (gradient, hessian) pairs.
/// * `row_indices` — rows to accumulate; empty input yields empty histograms.
/// * `num_rows` / `num_features` — matrix dimensions passed to the kernel.
///
/// Uploads of `bins` and the gradient/hessian arrays are skipped when the
/// sampled cache keys match the previously uploaded data.
///
/// # Panics
/// Panics if kernel compilation or a launch fails.
pub fn build_histograms(
    &mut self,
    bins: &[u8],
    grad_hess: &[(f32, f32)],
    row_indices: &[usize],
    num_rows: usize,
    num_features: usize,
) -> Vec<Histogram> {
    self.ensure_initialized();
    if row_indices.is_empty() {
        return (0..num_features).map(|_| Histogram::new()).collect();
    }
    let num_indices = row_indices.len();
    let gradients: Vec<f32> = grad_hess.iter().map(|(g, _)| *g).collect();
    let hessians: Vec<f32> = grad_hess.iter().map(|(_, h)| *h).collect();
    let indices_u32: Vec<u32> = row_indices.iter().map(|&i| i as u32).collect();
    // Delegate to the shared upload helpers rather than duplicating the
    // sampled cache-key computation inline; the helpers derive identical
    // keys, keeping invalidation consistent across all entry points.
    self.ensure_bins_cached(bins);
    self.ensure_grad_hess_cached(&gradients, &hessians);
    let d_indices = self.device.htod_copy(&indices_u32);
    let output_bins = num_features * 256;
    // Reallocate the output buffers only when the required size changes.
    if self.cached_output_len != output_bins
        || self.cached_grad_hist.is_none()
        || self.cached_hess_hist.is_none()
        || self.cached_count_hist.is_none()
    {
        self.cached_grad_hist = Some(self.device.alloc_zeros(output_bins));
        self.cached_hess_hist = Some(self.device.alloc_zeros(output_bins));
        self.cached_count_hist = Some(self.device.alloc_zeros(output_bins));
        self.cached_output_len = output_bins;
    }
    let stream = self.device.stream();
    // Reused buffers may hold stale sums: zero them before accumulation.
    let zero_blocks = ((output_bins + 255) / 256) as u32;
    let zero_config = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (zero_blocks, 1, 1),
        shared_mem_bytes: 0,
    };
    unsafe {
        let d_grad_hist = self.cached_grad_hist.as_mut().unwrap();
        let d_hess_hist = self.cached_hess_hist.as_mut().unwrap();
        let d_count_hist = self.cached_count_hist.as_mut().unwrap();
        stream
            .launch_builder(self.zero_histograms_fn.as_ref().unwrap())
            .arg(d_grad_hist)
            .arg(d_hess_hist)
            .arg(d_count_hist)
            .arg(&(output_bins as u32))
            .launch(zero_config)
            .expect("Failed to launch zero_histograms kernel");
    }
    // Grid: x = feature, y = row tile (THREADS_PER_BLOCK rows per tile).
    let rows_per_tile = THREADS_PER_BLOCK;
    let row_tiles = ((num_indices as u32) + rows_per_tile - 1) / rows_per_tile;
    // 256 bins x (f32 grad + f32 hess + u32 count) staged in shared memory.
    let shared_mem_bytes = 256 * (4 + 4 + 4);
    let config = LaunchConfig {
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        grid_dim: (num_features as u32, row_tiles, 1),
        shared_mem_bytes,
    };
    unsafe {
        let d_bins = self.cached_bins.as_ref().unwrap();
        let d_gradients = self.cached_gradients.as_ref().unwrap();
        let d_hessians = self.cached_hessians.as_ref().unwrap();
        let d_grad_hist = self.cached_grad_hist.as_mut().unwrap();
        let d_hess_hist = self.cached_hess_hist.as_mut().unwrap();
        let d_count_hist = self.cached_count_hist.as_mut().unwrap();
        stream
            .launch_builder(self.build_histogram_fn.as_ref().unwrap())
            .arg(d_bins)
            .arg(d_gradients)
            .arg(d_hessians)
            .arg(&d_indices)
            .arg(d_grad_hist)
            .arg(d_hess_hist)
            .arg(d_count_hist)
            .arg(&(num_rows as u32))
            .arg(&(num_features as u32))
            .arg(&(num_indices as u32))
            .launch(config)
            .expect("Failed to launch build_histogram kernel");
    }
    // Block until the kernels finish before reading results back.
    self.device.synchronize();
    let grad_hist = self
        .device
        .dtoh_copy(self.cached_grad_hist.as_ref().unwrap());
    let hess_hist = self
        .device
        .dtoh_copy(self.cached_hess_hist.as_ref().unwrap());
    let count_hist = self
        .device
        .dtoh_copy(self.cached_count_hist.as_ref().unwrap());
    // Repack the flat [feature][bin] device layout into Histogram values.
    (0..num_features)
        .map(|f| {
            let base = f * 256;
            let mut grads = [0.0f32; 256];
            let mut hess = [0.0f32; 256];
            let mut counts = [0u32; 256];
            for bin in 0..256 {
                grads[bin] = grad_hist[base + bin];
                hess[bin] = hess_hist[base + bin];
                counts[bin] = count_hist[base + bin];
            }
            Histogram::from_raw_arrays(&grads, &hess, &counts)
        })
        .collect()
}
/// Builds histograms for several nodes (index batches) in a single kernel
/// launch, returning `batches.len()` vectors of per-feature histograms.
///
/// Fixes applied (review):
/// * the gradient cache key previously sampled only the length and first
///   element here, while every other entry point also samples the middle
///   element — identical data therefore produced different keys, and
///   alternating single/batched calls re-uploaded gradients every time.
///   Delegating to `ensure_grad_hess_cached` makes the key consistent.
/// * `num_features == 0` no longer divides by zero in the single-batch
///   shortcut; it now returns empty per-batch vectors.
pub fn build_histograms_batched(
    &mut self,
    bins: &[u8],
    grad_hess: &[(f32, f32)],
    batches: &[&[usize]],
    _num_rows: usize,
    num_features: usize,
) -> Vec<Vec<Histogram>> {
    self.ensure_initialized();
    let num_batches = batches.len();
    if num_batches == 0 {
        return Vec::new();
    }
    if num_features == 0 {
        // Degenerate call: nothing to histogram, and the single-batch
        // shortcut below would otherwise divide by zero.
        return (0..num_batches).map(|_| Vec::new()).collect();
    }
    if num_batches == 1 {
        // NOTE(review): num_rows is reconstructed from the bin matrix here
        // rather than taken from `_num_rows` — presumably equivalent for
        // the expected layout; confirm against callers.
        return vec![self.build_histograms(
            bins,
            grad_hess,
            batches[0],
            bins.len() / num_features,
            num_features,
        )];
    }
    // Flatten all batches into one index buffer plus per-node start/count
    // tables for the batched kernel.
    let mut all_indices: Vec<u32> = Vec::new();
    let mut node_starts: Vec<u32> = Vec::with_capacity(num_batches);
    let mut node_counts: Vec<u32> = Vec::with_capacity(num_batches);
    let mut max_count = 0usize;
    for batch in batches {
        node_starts.push(all_indices.len() as u32);
        node_counts.push(batch.len() as u32);
        max_count = max_count.max(batch.len());
        all_indices.extend(batch.iter().map(|&i| i as u32));
    }
    if all_indices.is_empty() {
        return (0..num_batches)
            .map(|_| (0..num_features).map(|_| Histogram::new()).collect())
            .collect();
    }
    let gradients: Vec<f32> = grad_hess.iter().map(|(g, _)| *g).collect();
    let hessians: Vec<f32> = grad_hess.iter().map(|(_, h)| *h).collect();
    // Shared upload helpers keep the sampled cache keys identical across
    // all entry points (see method docs above).
    self.ensure_bins_cached(bins);
    self.ensure_grad_hess_cached(&gradients, &hessians);
    let d_indices = self.device.htod_copy(&all_indices);
    let d_node_starts = self.device.htod_copy(&node_starts);
    let d_node_counts = self.device.htod_copy(&node_counts);
    let output_size = num_batches * num_features * 256;
    // Grow-only reuse: a larger cached buffer is kept and only its first
    // `output_size` entries are zeroed, written and read back.
    if self.cached_batched_output_len < output_size
        || self.cached_batched_grad_hist.is_none()
        || self.cached_batched_hess_hist.is_none()
        || self.cached_batched_count_hist.is_none()
    {
        self.cached_batched_grad_hist = Some(self.device.alloc_zeros(output_size));
        self.cached_batched_hess_hist = Some(self.device.alloc_zeros(output_size));
        self.cached_batched_count_hist = Some(self.device.alloc_zeros(output_size));
        self.cached_batched_output_len = output_size;
    }
    let stream = self.device.stream();
    // Clear the active region of the (possibly reused) output buffers.
    let zero_blocks = ((output_size + 255) / 256) as u32;
    let zero_config = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (zero_blocks, 1, 1),
        shared_mem_bytes: 0,
    };
    unsafe {
        let d_grad_hist = self.cached_batched_grad_hist.as_mut().unwrap();
        let d_hess_hist = self.cached_batched_hess_hist.as_mut().unwrap();
        let d_count_hist = self.cached_batched_count_hist.as_mut().unwrap();
        stream
            .launch_builder(self.zero_histograms_fn.as_ref().unwrap())
            .arg(d_grad_hist)
            .arg(d_hess_hist)
            .arg(d_count_hist)
            .arg(&(output_size as u32))
            .launch(zero_config)
            .expect("Failed to launch zero_histograms kernel");
    }
    // Grid: x = feature, y = batch * tiles_per_batch + row tile.
    let tiles_per_batch = ((max_count as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    // 256 bins x (f32 grad + f32 hess + u32 count) staged in shared memory.
    let shared_mem_bytes = 256 * (4 + 4 + 4);
    let config = LaunchConfig {
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        grid_dim: (
            num_features as u32,
            (num_batches as u32) * tiles_per_batch,
            1,
        ),
        shared_mem_bytes,
    };
    unsafe {
        let d_bins = self.cached_bins.as_ref().unwrap();
        let d_gradients = self.cached_gradients.as_ref().unwrap();
        let d_hessians = self.cached_hessians.as_ref().unwrap();
        let d_grad_hist = self.cached_batched_grad_hist.as_mut().unwrap();
        let d_hess_hist = self.cached_batched_hess_hist.as_mut().unwrap();
        let d_count_hist = self.cached_batched_count_hist.as_mut().unwrap();
        stream
            .launch_builder(self.build_histogram_batched_fn.as_ref().unwrap())
            .arg(d_bins)
            .arg(d_gradients)
            .arg(d_hessians)
            .arg(&d_indices)
            .arg(&d_node_starts)
            .arg(&d_node_counts)
            .arg(d_grad_hist)
            .arg(d_hess_hist)
            .arg(d_count_hist)
            .arg(&(num_features as u32))
            .arg(&(num_batches as u32))
            .arg(&tiles_per_batch)
            .launch(config)
            .expect("Failed to launch build_histogram_batched kernel");
    }
    // Wait for completion, then read back; the cached buffer may be larger
    // than `output_size`, but only the first `output_size` entries are used.
    self.device.synchronize();
    let grad_hist = self
        .device
        .dtoh_copy(self.cached_batched_grad_hist.as_ref().unwrap());
    let hess_hist = self
        .device
        .dtoh_copy(self.cached_batched_hess_hist.as_ref().unwrap());
    let count_hist = self
        .device
        .dtoh_copy(self.cached_batched_count_hist.as_ref().unwrap());
    // Repack the flat [batch][feature][bin] layout into Histogram values.
    (0..num_batches)
        .map(|batch_idx| {
            (0..num_features)
                .map(|f| {
                    let base = (batch_idx * num_features + f) * 256;
                    let mut grads = [0.0f32; 256];
                    let mut hess = [0.0f32; 256];
                    let mut counts = [0u32; 256];
                    for bin in 0..256 {
                        grads[bin] = grad_hist[base + bin];
                        hess[bin] = hess_hist[base + bin];
                        counts[bin] = count_hist[base + bin];
                    }
                    Histogram::from_raw_arrays(&grads, &hess, &counts)
                })
                .collect()
        })
        .collect()
}
/// Builds per-node, per-feature histograms from fully device-resident
/// inputs: row indices already live on the GPU in `d_indices`, and each
/// `NodeRange` selects a start/count window of that buffer for one node.
///
/// # Panics
/// Panics if `ensure_bins_cached` / `ensure_grad_hess_cached` were not
/// called beforehand (the input caches must already be populated), or if a
/// kernel launch fails.
pub fn build_histograms_gpu(
&mut self,
d_indices: &CudaSlice<u32>,
node_ranges: &[NodeRange],
num_features: usize,
) -> Vec<Vec<Histogram>> {
self.ensure_initialized();
if node_ranges.is_empty() {
return Vec::new();
}
let num_nodes = node_ranges.len();
// The widest node determines how many row tiles every node gets.
let max_node_count = node_ranges
.iter()
.map(|n| n.count as usize)
.max()
.unwrap_or(0);
// All nodes empty: skip the launches and return empty histograms.
if max_node_count == 0 {
return (0..num_nodes)
.map(|_| (0..num_features).map(|_| Histogram::new()).collect())
.collect();
}
// Per-node window tables uploaded for the batched kernel.
let node_starts: Vec<u32> = node_ranges.iter().map(|n| n.start).collect();
let node_counts: Vec<u32> = node_ranges.iter().map(|n| n.count).collect();
let d_node_starts = self.device.htod_copy(&node_starts);
let d_node_counts = self.device.htod_copy(&node_counts);
let output_size = num_nodes * num_features * 256;
// Grow-only reuse of the batched output buffers: a larger cached buffer
// is kept as-is and only its first `output_size` entries are used.
if self.cached_batched_output_len < output_size
|| self.cached_batched_grad_hist.is_none()
|| self.cached_batched_hess_hist.is_none()
|| self.cached_batched_count_hist.is_none()
{
self.cached_batched_grad_hist = Some(self.device.alloc_zeros(output_size));
self.cached_batched_hess_hist = Some(self.device.alloc_zeros(output_size));
self.cached_batched_count_hist = Some(self.device.alloc_zeros(output_size));
self.cached_batched_output_len = output_size;
}
let stream = self.device.stream();
// Reused buffers may hold stale sums from a previous call: zero the
// active region before accumulating.
let zero_blocks = ((output_size + 255) / 256) as u32;
let zero_config = LaunchConfig {
block_dim: (256, 1, 1),
grid_dim: (zero_blocks, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
let d_grad_hist = self.cached_batched_grad_hist.as_mut().unwrap();
let d_hess_hist = self.cached_batched_hess_hist.as_mut().unwrap();
let d_count_hist = self.cached_batched_count_hist.as_mut().unwrap();
stream
.launch_builder(self.zero_histograms_fn.as_ref().unwrap())
.arg(d_grad_hist)
.arg(d_hess_hist)
.arg(d_count_hist)
.arg(&(output_size as u32))
.launch(zero_config)
.expect("Failed to launch zero_histograms kernel");
}
// Grid: x = feature, y = node * tiles_per_node + row tile. The kernel
// presumably decodes node and tile from blockIdx.y using the
// `tiles_per_node` argument — confirm against kernels/histogram.cu.
let tiles_per_node = ((max_node_count as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
// 256 bins x (f32 grad + f32 hess + u32 count) staged in shared memory.
let shared_mem_bytes = 256 * (4 + 4 + 4);
let config = LaunchConfig {
block_dim: (THREADS_PER_BLOCK, 1, 1),
grid_dim: (num_features as u32, (num_nodes as u32) * tiles_per_node, 1),
shared_mem_bytes,
};
unsafe {
// Inputs must have been uploaded by the caller via the ensure_* helpers.
let d_bins = self
.cached_bins
.as_ref()
.expect("bins must be cached first");
let d_gradients = self
.cached_gradients
.as_ref()
.expect("gradients must be cached first");
let d_hessians = self
.cached_hessians
.as_ref()
.expect("hessians must be cached first");
let d_grad_hist = self.cached_batched_grad_hist.as_mut().unwrap();
let d_hess_hist = self.cached_batched_hess_hist.as_mut().unwrap();
let d_count_hist = self.cached_batched_count_hist.as_mut().unwrap();
stream
.launch_builder(self.build_histogram_batched_fn.as_ref().unwrap())
.arg(d_bins)
.arg(d_gradients)
.arg(d_hessians)
.arg(d_indices)
.arg(&d_node_starts)
.arg(&d_node_counts)
.arg(d_grad_hist)
.arg(d_hess_hist)
.arg(d_count_hist)
.arg(&(num_features as u32))
.arg(&(num_nodes as u32))
.arg(&tiles_per_node)
.launch(config)
.expect("Failed to launch build_histogram_batched kernel");
}
// Wait for completion before reading back; the cached buffer may exceed
// `output_size`, but only the first `output_size` entries are read below.
self.device.synchronize();
let grad_hist = self
.device
.dtoh_copy(self.cached_batched_grad_hist.as_ref().unwrap());
let hess_hist = self
.device
.dtoh_copy(self.cached_batched_hess_hist.as_ref().unwrap());
let count_hist = self
.device
.dtoh_copy(self.cached_batched_count_hist.as_ref().unwrap());
// Repack the flat [node][feature][bin] layout into Histogram values.
(0..num_nodes)
.map(|node_idx| {
(0..num_features)
.map(|f| {
let base = (node_idx * num_features + f) * 256;
let mut grads = [0.0f32; 256];
let mut hess = [0.0f32; 256];
let mut counts = [0u32; 256];
for bin in 0..256 {
grads[bin] = grad_hist[base + bin];
hess[bin] = hess_hist[base + bin];
counts[bin] = count_hist[base + bin];
}
Histogram::from_raw_arrays(&grads, &hess, &counts)
})
.collect()
})
.collect()
}
/// Builds one histogram per (era, feature) pair in a single launch.
///
/// `era_indices` maps each row to its era id. Bins and gradients/hessians
/// are uploaded through the shared sampled-key helpers (removing the
/// duplicated inline cache-key code), and era ids get their own sampled
/// cache key, so repeated calls with unchanged data skip all uploads.
///
/// # Panics
/// Panics if kernel compilation or a launch fails.
pub fn build_era_histograms(
    &mut self,
    bins: &[u8],
    grad_hess: &[(f32, f32)],
    row_indices: &[usize],
    era_indices: &[u16],
    num_rows: usize,
    num_features: usize,
    num_eras: usize,
) -> Vec<Vec<Histogram>> {
    self.ensure_initialized();
    if row_indices.is_empty() || num_eras == 0 {
        return (0..num_eras)
            .map(|_| (0..num_features).map(|_| Histogram::new()).collect())
            .collect();
    }
    let num_indices = row_indices.len();
    let gradients: Vec<f32> = grad_hess.iter().map(|(g, _)| *g).collect();
    let hessians: Vec<f32> = grad_hess.iter().map(|(_, h)| *h).collect();
    let indices_u32: Vec<u32> = row_indices.iter().map(|&i| i as u32).collect();
    // Reuse the shared helpers instead of duplicating the sampled
    // cache-key derivation inline (identical keys, consistent invalidation).
    self.ensure_bins_cached(bins);
    self.ensure_grad_hess_cached(&gradients, &hessians);
    // Era ids get their own sampled key; reinterpret the u16s as bytes.
    let era_key = CacheKey::from_slice(bytemuck::cast_slice(era_indices));
    if self.cached_era_indices_key != Some(era_key) || self.cached_era_indices.is_none() {
        self.cached_era_indices = Some(self.device.htod_copy(era_indices));
        self.cached_era_indices_key = Some(era_key);
    }
    let d_indices = self.device.htod_copy(&indices_u32);
    let output_size = num_eras * num_features * 256;
    // Grow-only buffer reuse; only the first `output_size` entries are used.
    if self.cached_era_output_len < output_size
        || self.cached_era_grad_hist.is_none()
        || self.cached_era_hess_hist.is_none()
        || self.cached_era_count_hist.is_none()
    {
        self.cached_era_grad_hist = Some(self.device.alloc_zeros(output_size));
        self.cached_era_hess_hist = Some(self.device.alloc_zeros(output_size));
        self.cached_era_count_hist = Some(self.device.alloc_zeros(output_size));
        self.cached_era_output_len = output_size;
    }
    let stream = self.device.stream();
    // Clear the active region of the (possibly reused) output buffers.
    let zero_blocks = ((output_size + 255) / 256) as u32;
    let zero_config = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (zero_blocks, 1, 1),
        shared_mem_bytes: 0,
    };
    unsafe {
        let d_grad_hist = self.cached_era_grad_hist.as_mut().unwrap();
        let d_hess_hist = self.cached_era_hess_hist.as_mut().unwrap();
        let d_count_hist = self.cached_era_count_hist.as_mut().unwrap();
        stream
            .launch_builder(self.zero_histograms_fn.as_ref().unwrap())
            .arg(d_grad_hist)
            .arg(d_hess_hist)
            .arg(d_count_hist)
            .arg(&(output_size as u32))
            .launch(zero_config)
            .expect("Failed to launch zero_histograms kernel");
    }
    // Grid: one block per (feature, era) pair; no row tiling on this path.
    // 256 bins x (f32 grad + f32 hess + u32 count) staged in shared memory.
    let shared_mem_bytes = 256 * (4 + 4 + 4);
    let config = LaunchConfig {
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        grid_dim: (num_features as u32, num_eras as u32, 1),
        shared_mem_bytes,
    };
    unsafe {
        let d_bins = self.cached_bins.as_ref().unwrap();
        let d_gradients = self.cached_gradients.as_ref().unwrap();
        let d_hessians = self.cached_hessians.as_ref().unwrap();
        let d_era_indices = self.cached_era_indices.as_ref().unwrap();
        let d_grad_hist = self.cached_era_grad_hist.as_mut().unwrap();
        let d_hess_hist = self.cached_era_hess_hist.as_mut().unwrap();
        let d_count_hist = self.cached_era_count_hist.as_mut().unwrap();
        stream
            .launch_builder(self.build_histogram_era_fn.as_ref().unwrap())
            .arg(d_bins)
            .arg(d_gradients)
            .arg(d_hessians)
            .arg(&d_indices)
            .arg(d_era_indices)
            .arg(d_grad_hist)
            .arg(d_hess_hist)
            .arg(d_count_hist)
            .arg(&(num_rows as u32))
            .arg(&(num_features as u32))
            .arg(&(num_indices as u32))
            .arg(&(num_eras as u32))
            .launch(config)
            .expect("Failed to launch build_histogram_era kernel");
    }
    // Wait for completion, then read back; only `output_size` entries used.
    self.device.synchronize();
    let grad_hist = self
        .device
        .dtoh_copy(self.cached_era_grad_hist.as_ref().unwrap());
    let hess_hist = self
        .device
        .dtoh_copy(self.cached_era_hess_hist.as_ref().unwrap());
    let count_hist = self
        .device
        .dtoh_copy(self.cached_era_count_hist.as_ref().unwrap());
    // Repack the flat [era][feature][bin] layout into Histogram values.
    (0..num_eras)
        .map(|era_idx| {
            (0..num_features)
                .map(|f| {
                    let base = (era_idx * num_features + f) * 256;
                    let mut grads = [0.0f32; 256];
                    let mut hess = [0.0f32; 256];
                    let mut counts = [0u32; 256];
                    for bin in 0..256 {
                        grads[bin] = grad_hist[base + bin];
                        hess[bin] = hess_hist[base + bin];
                        counts[bin] = count_hist[base + bin];
                    }
                    Histogram::from_raw_arrays(&grads, &hess, &counts)
                })
                .collect()
        })
        .collect()
}
}