#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
use cobre_io::output::policy::{
PolicyBasisRecord, PolicyCutRecord, StageCutsPayload, StageStatesPayload,
};
use crate::cut::FutureCostFunction;
use crate::training::TrainingResult;
/// Builds one [`PolicyCutRecord`] for every populated cut slot, grouped by stage.
///
/// The outer `Vec` has one entry per pool in `fcf.pools`, in pool order. Each
/// inner `Vec` holds the records for slots `0..pool.populated_count` of that
/// pool. A record borrows its coefficient row from `pool.coefficients`: the
/// `pool.state_dimension`-long slice starting at `slot * pool.state_dimension`.
///
/// `cut_id` is computed as
/// `iteration_generated * forward_passes + forward_pass_index`.
/// `iteration` and `slot_index` are narrowed to `u32` with `as`.
///
/// # Panics
///
/// Panics if a pool's `metadata`, `intercepts`, `active` or `coefficients`
/// storage is too short for `populated_count` slots.
#[must_use]
pub fn build_stage_cut_records(fcf: &FutureCostFunction) -> Vec<Vec<PolicyCutRecord<'_>>> {
    fcf.pools
        .iter()
        .map(|pool| {
            // Per-pool constants, read once instead of once per slot.
            let dim = pool.state_dimension;
            let passes_per_iteration = u64::from(pool.forward_passes);

            (0..pool.populated_count)
                .map(|slot| {
                    let meta = &pool.metadata[slot];
                    PolicyCutRecord {
                        cut_id: meta.iteration_generated * passes_per_iteration
                            + u64::from(meta.forward_pass_index),
                        slot_index: slot as u32,
                        iteration: meta.iteration_generated as u32,
                        forward_pass_index: meta.forward_pass_index,
                        intercept: pool.intercepts[slot],
                        coefficients: &pool.coefficients[slot * dim..(slot + 1) * dim],
                        is_active: pool.active[slot],
                    }
                })
                .collect()
        })
        .collect()
}
/// Lists, for every stage, the `slot_index` of each record whose `is_active` flag is set.
///
/// The result has one inner `Vec` per element of `stage_records`. Within a
/// stage, the indices appear in the same order as the records.
#[must_use]
pub fn build_active_indices(stage_records: &[Vec<PolicyCutRecord<'_>>]) -> Vec<Vec<u32>> {
    let mut per_stage = Vec::with_capacity(stage_records.len());
    for records in stage_records {
        let active_slots: Vec<u32> = records
            .iter()
            .filter_map(|record| record.is_active.then_some(record.slot_index))
            .collect();
        per_stage.push(active_slots);
    }
    per_stage
}
/// Assembles one [`StageCutsPayload`] per pool in `fcf.pools`.
///
/// Cut records and active indices are borrowed from the caller-supplied
/// per-stage tables, selected by the pool's position. That position is also
/// used as `stage_id`.
///
/// `state_dimension` is taken from `fcf` itself. `capacity`,
/// `warm_start_count` and `populated_count` come from each pool. Numeric
/// fields are narrowed to `u32` with `as`.
///
/// # Panics
///
/// Panics if `stage_records` or `stage_active_indices` has fewer entries than
/// `fcf.pools`.
#[must_use]
pub fn build_stage_cuts_payloads<'a>(
    fcf: &FutureCostFunction,
    stage_records: &'a [Vec<PolicyCutRecord<'a>>],
    stage_active_indices: &'a [Vec<u32>],
) -> Vec<StageCutsPayload<'a>> {
    // Same value for every stage, so compute it once.
    let state_dimension = fcf.state_dimension as u32;

    fcf.pools
        .iter()
        .enumerate()
        .map(|(idx, pool)| {
            let cuts = &stage_records[idx];
            let active_cut_indices = &stage_active_indices[idx];
            StageCutsPayload {
                stage_id: idx as u32,
                state_dimension,
                capacity: pool.capacity as u32,
                warm_start_count: pool.warm_start_count,
                cuts,
                active_cut_indices,
                populated_count: pool.populated_count as u32,
            }
        })
        .collect()
}
/// Flattens the cached simplex bases into per-stage `u8` status vectors.
///
/// Returns `(column_statuses, row_statuses)`. Each vector has one entry per
/// element of `training_result.basis_cache`, so the two outputs line up
/// index-for-index with the cache. A stage with no cached basis (`None`) gets
/// empty vectors. Each status value is narrowed with `as u8`.
#[must_use]
pub fn convert_basis_cache(training_result: &TrainingResult) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
    // A single pass produces both the column and the row vector for each stage.
    let (col_statuses, row_statuses): (Vec<Vec<u8>>, Vec<Vec<u8>>) = training_result
        .basis_cache
        .iter()
        .map(|entry| match entry {
            Some(cached) => (
                cached
                    .basis
                    .col_status
                    .iter()
                    .map(|&status| status as u8)
                    .collect::<Vec<u8>>(),
                cached
                    .basis
                    .row_status
                    .iter()
                    .map(|&status| status as u8)
                    .collect::<Vec<u8>>(),
            ),
            None => (Vec::new(), Vec::new()),
        })
        .unzip();
    (col_statuses, row_statuses)
}
/// Emits a [`PolicyBasisRecord`] for every stage that has a cached basis.
///
/// Stages whose `training_result.basis_cache` entry is `None` are skipped.
/// `iteration` is `training_result.iterations` narrowed to `u32`.
/// `num_cut_rows` is the matching pool's `populated_count` clamped to its
/// `capacity`, or `0` when `fcf.pools` has no pool at that stage index.
///
/// `basis_col_u8` / `basis_row_u8` are indexed by stage. Presumably they come
/// from `convert_basis_cache` on the same `training_result` (TODO confirm at
/// call sites).
///
/// # Panics
///
/// Panics if `basis_col_u8` or `basis_row_u8` is shorter than the index of a
/// stage with a cached basis.
#[must_use]
pub fn build_stage_basis_records<'a>(
    fcf: &FutureCostFunction,
    training_result: &TrainingResult,
    basis_col_u8: &'a [Vec<u8>],
    basis_row_u8: &'a [Vec<u8>],
) -> Vec<PolicyBasisRecord<'a>> {
    // Every emitted record carries the same iteration count.
    let iteration = training_result.iterations as u32;

    let mut records = Vec::new();
    for (stage_idx, cached) in training_result.basis_cache.iter().enumerate() {
        if cached.is_none() {
            continue;
        }
        let num_cut_rows = match fcf.pools.get(stage_idx) {
            Some(pool) => pool.populated_count.min(pool.capacity) as u32,
            None => 0,
        };
        records.push(PolicyBasisRecord {
            stage_id: stage_idx as u32,
            iteration,
            column_status: &basis_col_u8[stage_idx],
            row_status: &basis_row_u8[stage_idx],
            num_cut_rows,
        });
    }
    records
}
/// Builds one [`StageStatesPayload`] per stage of the visited-states archive.
///
/// Returns an empty `Vec` when `archive` is `None`. Otherwise stage `t`
/// (for `t` in `0..archive.num_stages()`) yields a payload that borrows the
/// stage's state data. Its `state_dimension` and `count` are narrowed to
/// `u32` with `as`.
#[must_use]
pub fn build_stage_states_payloads(
    archive: Option<&crate::visited_states::VisitedStatesArchive>,
) -> Vec<StageStatesPayload<'_>> {
    archive.map_or_else(Vec::new, |archive| {
        (0..archive.num_stages())
            .map(|stage_idx| {
                let stage = archive.stage(stage_idx);
                StageStatesPayload {
                    stage_id: stage_idx as u32,
                    state_dimension: stage.state_dimension() as u32,
                    count: stage.count() as u32,
                    data: stage.states(),
                }
            })
            .collect()
    })
}