#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
use cobre_io::output::policy::{PolicyBasisRecord, PolicyCutRecord, StageCutsPayload};
use crate::cut::FutureCostFunction;
use crate::training::TrainingResult;
/// Collects the active cuts of every stage pool in `fcf` as borrowed policy records.
///
/// The outer vector has one entry per pool, in pool order. Each inner vector lists the
/// slots `0..populated_count` whose `active` flag is set, in ascending slot order.
///
/// For each emitted record:
/// - `cut_id` is `iteration_generated * forward_passes + forward_pass_index`.
/// - `coefficients` borrows directly from the pool; nothing is copied.
/// - `is_active` is always `true`, because inactive slots are skipped.
/// - `domination_count` is taken from the slot metadata's `active_count`.
///
/// Narrowing `as` casts (slot index, iteration, count) are permitted by the file-level
/// clippy allow.
///
/// # Panics
///
/// Panics if a pool's `active`, `metadata`, `intercepts` or `coefficients` storage is
/// shorter than its `populated_count`.
#[must_use]
pub fn build_stage_cut_records(fcf: &FutureCostFunction) -> Vec<Vec<PolicyCutRecord<'_>>> {
    let mut per_stage = Vec::with_capacity(fcf.pools.len());
    for pool in fcf.pools.iter() {
        let passes_per_iteration = u64::from(pool.forward_passes);
        let mut stage_cuts = Vec::new();
        for slot in 0..pool.populated_count {
            if !pool.active[slot] {
                continue;
            }
            let meta = &pool.metadata[slot];
            let cut_id = meta.iteration_generated * passes_per_iteration
                + u64::from(meta.forward_pass_index);
            stage_cuts.push(PolicyCutRecord {
                cut_id,
                slot_index: slot as u32,
                iteration: meta.iteration_generated as u32,
                forward_pass_index: meta.forward_pass_index,
                intercept: pool.intercepts[slot],
                coefficients: &pool.coefficients[slot],
                is_active: true,
                domination_count: meta.active_count as u32,
            });
        }
        per_stage.push(stage_cuts);
    }
    per_stage
}
/// Extracts the `slot_index` of every record, stage by stage.
///
/// The output mirrors the shape of `stage_records`: one inner vector per stage, with
/// indices in the same order as the records.
#[must_use]
pub fn build_active_indices(stage_records: &[Vec<PolicyCutRecord<'_>>]) -> Vec<Vec<u32>> {
    let mut indices_per_stage = Vec::with_capacity(stage_records.len());
    for records in stage_records {
        let slots: Vec<u32> = records.iter().map(|record| record.slot_index).collect();
        indices_per_stage.push(slots);
    }
    indices_per_stage
}
/// Assembles one [`StageCutsPayload`] per pool in `fcf`, in pool order.
///
/// Each payload borrows that stage's entry from `stage_records` and
/// `stage_active_indices`. `stage_id` is the pool's position in `fcf.pools`.
/// `state_dimension` is the same for every payload. `capacity` and `populated_count`
/// are copied from the pool through `as u32` casts.
///
/// # Panics
///
/// Panics if `stage_records` or `stage_active_indices` has fewer entries than
/// `fcf.pools`.
#[must_use]
pub fn build_stage_cuts_payloads<'a>(
    fcf: &FutureCostFunction,
    stage_records: &'a [Vec<PolicyCutRecord<'a>>],
    stage_active_indices: &'a [Vec<u32>],
) -> Vec<StageCutsPayload<'a>> {
    let state_dimension = fcf.state_dimension as u32;
    let mut payloads = Vec::with_capacity(fcf.pools.len());
    for (stage_idx, pool) in fcf.pools.iter().enumerate() {
        let payload = StageCutsPayload {
            stage_id: stage_idx as u32,
            state_dimension,
            capacity: pool.capacity as u32,
            warm_start_count: pool.warm_start_count,
            cuts: &stage_records[stage_idx],
            active_cut_indices: &stage_active_indices[stage_idx],
            populated_count: pool.populated_count as u32,
        };
        payloads.push(payload);
    }
    payloads
}
/// Converts the cached LP basis of every stage into byte vectors.
///
/// Returns `(column_statuses, row_statuses)`, each with one entry per element of
/// `training_result.basis_cache`. Status values are narrowed with an `as u8` cast.
/// This assumes each status fits in a byte (TODO confirm against the basis type).
/// Stages without a cached basis (`None`) get an empty vector in both outputs.
#[must_use]
pub fn convert_basis_cache(training_result: &TrainingResult) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
    training_result
        .basis_cache
        .iter()
        .map(|entry| match entry {
            Some(basis) => (
                basis.col_status.iter().map(|&status| status as u8).collect(),
                basis.row_status.iter().map(|&status| status as u8).collect(),
            ),
            None => (Vec::new(), Vec::new()),
        })
        .unzip()
}
/// Builds one [`PolicyBasisRecord`] for each stage that has a cached basis.
///
/// Stages whose `basis_cache` entry is `None` are skipped. Each record:
/// - borrows its column/row status bytes from `basis_col_u8` / `basis_row_u8` at the
///   stage's index (presumably the output of `convert_basis_cache`; verify at caller);
/// - is stamped with `training_result.iterations` as its iteration;
/// - sets `num_cut_rows` to the stage pool's `populated_count` clamped to its
///   `capacity`, or `0` when `fcf` has no pool at that index.
///
/// # Panics
///
/// Panics if `basis_col_u8` or `basis_row_u8` has no entry for a stage that has a
/// cached basis.
#[must_use]
pub fn build_stage_basis_records<'a>(
    fcf: &FutureCostFunction,
    training_result: &TrainingResult,
    basis_col_u8: &'a [Vec<u8>],
    basis_row_u8: &'a [Vec<u8>],
) -> Vec<PolicyBasisRecord<'a>> {
    let iteration = training_result.iterations as u32;
    let mut records = Vec::new();
    for (stage_idx, cached) in training_result.basis_cache.iter().enumerate() {
        if cached.is_none() {
            continue;
        }
        let num_cut_rows = match fcf.pools.get(stage_idx) {
            Some(pool) => pool.populated_count.min(pool.capacity) as u32,
            None => 0,
        };
        records.push(PolicyBasisRecord {
            stage_id: stage_idx as u32,
            iteration,
            column_status: &basis_col_u8[stage_idx],
            row_status: &basis_row_u8[stage_idx],
            num_cut_rows,
        });
    }
    records
}