use std::path::Path;
use god_graph::graph::traits::GraphQuery;
use god_graph::tensor::differentiable::{EditOperation, NodeEditOp};
use god_graph::transformer::optimization::switch::{ModelSwitch, OperatorType, WeightTensor};
use tokitai_operator::backend::TensorStore;
use tokitai_operator::domain::DomainId;
use tokitai_operator::object::Tensor;
use tokitai_operator::object::shape::Shape;
use tokitai_operator::verify::witnesses::WitnessSample;
use crate::BridgeError;
/// Summary of one [`load_model_as_tensor_store_with_report`] run.
///
/// The three aggregate counters (`fractional_truncations`, `non_finite_values`,
/// `out_of_range_values`) are the sums of the matching per-tensor counters of
/// every entry in `warnings`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TensorStoreLoadReport {
    /// Number of graph nodes that received an entry in the tensor store,
    /// including nodes without a weight (those get an empty tensor).
    pub tensors_loaded: usize,
    /// Number of nodes whose weight tensor was converted to `i64` values.
    pub tensors_converted: usize,
    /// One entry per converted tensor that had at least one lossy element.
    pub warnings: Vec<TensorConversionWarning>,
    /// Total elements, across all tensors, whose fractional part was dropped.
    pub fractional_truncations: usize,
    /// Total NaN or infinite elements across all tensors.
    pub non_finite_values: usize,
    /// Total finite elements outside the `i64` range across all tensors.
    pub out_of_range_values: usize,
}
impl TensorStoreLoadReport {
fn record_warning(&mut self, warning: TensorConversionWarning) {
self.fractional_truncations += warning.fractional_truncations;
self.non_finite_values += warning.non_finite_values;
self.out_of_range_values += warning.out_of_range_values;
self.warnings.push(warning);
}
}
/// Per-tensor record of lossy `f64` -> `i64` conversions.
///
/// Produced only for tensors where at least one counter is non-zero.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorConversionWarning {
    /// The `name` of the source weight tensor.
    pub tensor_name: String,
    /// Index of the graph node the tensor belongs to.
    pub node_index: usize,
    /// `Debug` rendering of the source tensor's dtype.
    pub dtype: String,
    /// Total number of elements in the tensor (lossy or not).
    pub elements: usize,
    /// Elements that were finite and in range but not whole numbers.
    pub fractional_truncations: usize,
    /// Elements that were NaN or infinite.
    pub non_finite_values: usize,
    /// Elements that were finite but outside the `i64` range.
    pub out_of_range_values: usize,
    /// The lowest-index element that triggered any of the above, if any.
    pub first_problem: Option<ConversionWarningSample>,
}
/// A single example element from a lossy tensor conversion.
#[derive(Debug, Clone, PartialEq)]
pub struct ConversionWarningSample {
    /// Position of the element in the tensor's flattened `f64` data.
    pub element_index: usize,
    /// The element's value before conversion.
    pub original: f64,
    /// The value produced by `original as i64` (Rust's saturating float cast:
    /// out-of-range values clamp to `i64::MIN`/`i64::MAX`, NaN becomes 0).
    pub converted: i64,
    /// Why this element was flagged.
    pub reason: ConversionWarningKind,
}
/// Classification of why an `f64` element could not be converted to `i64`
/// without loss.
///
/// Derives `Hash` in addition to `Eq` so the kind can be used directly as a
/// `HashMap`/`HashSet` key (e.g. when tallying warnings by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConversionWarningKind {
    /// Finite, within the `i64` range, but not a whole number; the cast
    /// drops the fractional part (rounds toward zero).
    FractionalTruncation,
    /// NaN or positive/negative infinity.
    NonFinite,
    /// Finite, but below `-2^63` or at/above `2^63`.
    OutOfRange,
}
/// Successful result of [`load_model_as_tensor_store_with_report`]:
/// the `i64` tensor store, one witness per loaded node, and the conversion
/// report.
pub type TensorStoreLoadWithReport = (
    TensorStore<i64>,
    Vec<WitnessSample<i64>>,
    TensorStoreLoadReport,
);
/// Loads a safetensors model into an `i64` [`TensorStore`] together with one
/// witness per graph node.
///
/// This is [`load_model_as_tensor_store_with_report`] with the conversion
/// report dropped.
///
/// # Errors
///
/// Returns whatever error [`load_model_as_tensor_store_with_report`] returns.
pub fn load_model_as_tensor_store(
    path: &Path,
) -> Result<(TensorStore<i64>, Vec<WitnessSample<i64>>), BridgeError> {
    load_model_as_tensor_store_with_report(path)
        .map(|(store, witnesses, _report)| (store, witnesses))
}
/// Loads the safetensors model at `path` via [`ModelSwitch`] and turns every
/// graph node into an `i64` tensor in a [`TensorStore`].
///
/// A node's weight is read from a self-loop edge (an edge whose source and
/// target are that node). If a node has several self-loops, the one yielded
/// last by `graph.edges()` is used. Weight values are cast with `as i64`, and
/// any lossy elements are tallied in the returned [`TensorStoreLoadReport`].
/// Nodes without a self-loop get an empty tensor, an empty shape and the name
/// `node_{index}`.
///
/// For every node a witness is built from
/// `crate::witness_bridge::witness_of_edit_op` on a `NodeEdit(index, Add)`
/// and relabelled `loaded_tensor:{index}:{name}`.
///
/// # Errors
///
/// * [`BridgeError::Upstream`] if `ModelSwitch::load_from_safetensors` fails.
/// * Any error returned by `crate::witness_bridge::witness_of_edit_op`; the
///   load stops at the first such node.
pub fn load_model_as_tensor_store_with_report(
    path: &Path,
) -> Result<TensorStoreLoadWithReport, BridgeError> {
    #[cfg(feature = "tracing")]
    tracing::trace!(
        "llm_bridge::load_model_as_tensor_store_with_report: entry (path={})",
        path.display()
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::trace!(
        "llm_bridge::load_model_as_tensor_store_with_report: entry (path={})",
        path.display()
    );

    let graph = ModelSwitch::load_from_safetensors(path).map_err(|e| {
        BridgeError::Upstream(format!(
            "ModelSwitch::load_from_safetensors({}) failed: {e}",
            path.display()
        ))
    })?;

    // Edge payloads are `WeightTensor`s; only self-loops count as the weight
    // of the node they sit on. A later self-loop on the same node replaces an
    // earlier one.
    let mut self_loop_weights = std::collections::HashMap::<usize, &WeightTensor>::new();
    for edge in graph.edges() {
        if edge.source() != edge.target() {
            continue;
        }
        self_loop_weights.insert(edge.source().index(), edge.data());
    }

    let mut store = TensorStore::<i64>::new();
    let mut witnesses = Vec::<WitnessSample<i64>>::new();
    let mut report = TensorStoreLoadReport::default();

    for node in graph.nodes() {
        let node_index = node.index().index();

        let (tensor_name, values, dims) =
            if let Some(weight) = self_loop_weights.get(&node_index) {
                let (values, warning) = weight_to_i64_vec_with_warning(weight, node_index);
                report.tensors_converted += 1;
                if let Some(warning) = warning {
                    report.record_warning(warning);
                }
                let dims = weight.shape().to_vec();
                (weight.name.clone(), values, dims)
            } else {
                // NOTE(review): an empty dims list (presumably rank-0) paired
                // with zero values; confirm `Tensor::dense_cpu` treats this as
                // intended for weightless nodes.
                (format!("node_{node_index}"), Vec::new(), Vec::new())
            };
        report.tensors_loaded += 1;

        let shape = Shape::from(dims);
        store.insert(
            node_index,
            Tensor::<i64>::dense_cpu(DomainId::new("integer"), shape, values),
        );

        let add_node = EditOperation::NodeEdit(node_index, NodeEditOp::Add);
        let base = crate::witness_bridge::witness_of_edit_op(&add_node)?;
        witnesses.push(WitnessSample::new(
            format!("loaded_tensor:{node_index}:{tensor_name}"),
            base.observed,
            base.expected,
        ));
    }

    #[cfg(feature = "tracing")]
    tracing::debug!(
        "llm_bridge::load_model_as_tensor_store_with_report: exit ({} tensors, {} witnesses, {} warnings)",
        report.tensors_loaded,
        witnesses.len(),
        report.warnings.len()
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::debug!(
        "llm_bridge::load_model_as_tensor_store_with_report: exit ({} tensors, {} witnesses, {} warnings)",
        report.tensors_loaded,
        witnesses.len(),
        report.warnings.len()
    );

    Ok((store, witnesses, report))
}
/// Casts every element of `w` to `i64` and records any lossy elements.
///
/// Values are first obtained as `f64` via `w.data.to_f64_vec()` and then cast
/// with `as i64` (saturating; NaN becomes 0). Each element is classified with
/// `conversion_warning_kind`. The returned warning is `Some` only if at least
/// one element was flagged. Its `first_problem` is the lowest-index flagged
/// element.
///
/// `node_index` is copied into the warning so callers can locate the tensor.
fn weight_to_i64_vec_with_warning(
    w: &WeightTensor,
    node_index: usize,
) -> (Vec<i64>, Option<TensorConversionWarning>) {
    let raw = w.data.to_f64_vec();
    let mut summary = TensorConversionWarning {
        tensor_name: w.name.clone(),
        node_index,
        dtype: format!("{:?}", w.dtype()),
        elements: raw.len(),
        fractional_truncations: 0,
        non_finite_values: 0,
        out_of_range_values: 0,
        first_problem: None,
    };

    let ints: Vec<i64> = raw
        .into_iter()
        .enumerate()
        .map(|(element_index, original)| {
            let converted = original as i64;
            if let Some(reason) = conversion_warning_kind(original) {
                let counter = match reason {
                    ConversionWarningKind::FractionalTruncation => {
                        &mut summary.fractional_truncations
                    }
                    ConversionWarningKind::NonFinite => &mut summary.non_finite_values,
                    ConversionWarningKind::OutOfRange => &mut summary.out_of_range_values,
                };
                *counter += 1;
                // Keep only the earliest offending element as the sample.
                if summary.first_problem.is_none() {
                    summary.first_problem = Some(ConversionWarningSample {
                        element_index,
                        original,
                        converted,
                        reason,
                    });
                }
            }
            converted
        })
        .collect();

    let clean = summary.fractional_truncations == 0
        && summary.non_finite_values == 0
        && summary.out_of_range_values == 0;
    if clean {
        (ints, None)
    } else {
        (ints, Some(summary))
    }
}
/// Classifies `value` by whether `value as i64` would lose information.
///
/// Returns `None` when the cast is exact. Checks run in priority order:
/// non-finite first, then out of range, then fractional.
fn conversion_warning_kind(value: f64) -> Option<ConversionWarningKind> {
    // `i64::MIN` (-2^63) is exactly representable in f64. Its negation, 2^63,
    // is the first f64 above `i64::MAX`, which itself is not representable.
    let lowest = i64::MIN as f64;
    let past_highest = -lowest;

    let kind = match value {
        v if !v.is_finite() => ConversionWarningKind::NonFinite,
        v if !(lowest..past_highest).contains(&v) => ConversionWarningKind::OutOfRange,
        v if v.trunc() != v => ConversionWarningKind::FractionalTruncation,
        _ => return None,
    };
    Some(kind)
}
const _: fn() = || {
let _: fn(OperatorType) -> OperatorType = std::convert::identity;
};