tokitai-dl 0.1.2

Vendored, additive bridge between tokitai-operator (verified DL kernels) and god-graph (DL model analysis). The crate's earlier 100% non-invasive stance was deliberately abandoned in 0.1.1 — see ADR-0004 for the rationale.
Documentation
//! Bridge between god-graph's [`ModelSwitch`] (Safetensors -> GodGraph
//! loader) and tokitai's [`TensorStore<i64>`].
//!
//! god-graph's `transformer::optimization::switch::ModelSwitch` is the
//! canonical HuggingFace-Safetensors loader. It produces a
//! `Graph<OperatorType, WeightTensor>` where every node carries the
//! weight tensor that was on disk. tokitai's
//! `backend::TensorStore<T>` is a typed `BTreeMap<usize, Tensor<T>>`
//! keyed by an integer id and indexed by tokitai's kernel planner/executor.
//!
//! [`load_model_as_tensor_store`]
//! walks the loaded god-graph, copies the per-tensor data into a
//! tokitai `TensorStore<i64>` (converting F32/F64/BF16 to i64 via
//! the existing `to_f64_vec` -> `as i64` chain, which truncates toward
//! zero and saturates out-of-range values), and emits a per-node
//! traceability sample (`WitnessSample<i64>`) via
//! [`witness_bridge::witness_of_edit_op`](crate::witness_bridge::witness_of_edit_op)
//! using a synthetic `NodeEdit` for each node index (the tensor name only
//! appears in the witness label). The result is
//! a `(TensorStore<i64>, Vec<WitnessSample<i64>>)` pair that the rest
//! of the tokitai-dl pipeline can consume.
//! [`load_model_as_tensor_store_with_report`] returns the same data
//! plus conversion diagnostics for lossy i64 casts.
//! The report is an accounting aid, not a proof that the loaded model
//! or the quantised weights preserve floating-point inference behavior.
//!
//! This module is gated on both the `llm` and `operator` features. The
//! `llm` feature transitively activates the `graph` feature (so god-graph is
//! present), the `tensor` and `transformer` god-graph features (so the
//! `ModelSwitch` loader and `WeightTensor` storage are available), and the
//! `safetensors` god-graph feature (the actual file format reader). The
//! `operator` feature is required for `TensorStore<i64>`, `Tensor<i64>`, and
//! `WitnessSample<i64>`.

use std::path::Path;

use god_graph::graph::traits::GraphQuery;
use god_graph::tensor::differentiable::{EditOperation, NodeEditOp};
use god_graph::transformer::optimization::switch::{ModelSwitch, OperatorType, WeightTensor};
use tokitai_operator::backend::TensorStore;
use tokitai_operator::domain::DomainId;
use tokitai_operator::object::Tensor;
use tokitai_operator::object::shape::Shape;
use tokitai_operator::verify::witnesses::WitnessSample;

use crate::BridgeError;

/// Diagnostic report for loading a Safetensors model into a
/// `TensorStore<i64>`.
///
/// This report is emitted by
/// [`load_model_as_tensor_store_with_report`]. The legacy
/// [`load_model_as_tensor_store`] function intentionally discards it
/// to preserve the original API. The counts are diagnostics for the
/// adapter's lossy `f64 as i64` conversion; they are not a validation report
/// for model quality, numerical equivalence, or inference correctness.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TensorStoreLoadReport {
    /// Number of graph nodes copied into the returned tensor store.
    pub tensors_loaded: usize,
    /// Number of nodes that carried a god-graph [`WeightTensor`] and
    /// therefore went through f64-to-i64 conversion.
    pub tensors_converted: usize,
    /// Per-tensor warnings for lossy or saturating conversions.
    pub warnings: Vec<TensorConversionWarning>,
    /// Total finite in-range values whose fractional part was removed
    /// by the `as i64` cast.
    pub fractional_truncations: usize,
    /// Total non-finite values (`NaN`, `+inf`, `-inf`) encountered
    /// before the `as i64` cast.
    pub non_finite_values: usize,
    /// Total finite values outside the i64 representable range.
    pub out_of_range_values: usize,
}

impl TensorStoreLoadReport {
    fn record_warning(&mut self, warning: TensorConversionWarning) {
        self.fractional_truncations += warning.fractional_truncations;
        self.non_finite_values += warning.non_finite_values;
        self.out_of_range_values += warning.out_of_range_values;
        self.warnings.push(warning);
    }
}

/// Per-tensor diagnostic for lossy conversion into `i64`.
///
/// A warning is produced for each loaded tensor in which at least one
/// element's `as i64` cast was not exact: a fractional truncation, a
/// non-finite value, or an out-of-range value. Tensors that convert
/// exactly produce no warning.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorConversionWarning {
    /// Tensor name from the Safetensors file / god-graph
    /// [`WeightTensor`], copied from the weight's `name`.
    pub tensor_name: String,
    /// god-graph node index used as the `TensorStore<i64>` key.
    pub node_index: usize,
    /// Source dtype reported by god-graph, rendered with its `Debug`
    /// formatting (`{:?}`) for diagnostics.
    pub dtype: String,
    /// Number of scalar elements examined, i.e. the length of the
    /// tensor's `f64`-promoted data.
    pub elements: usize,
    /// Finite in-range values whose fractional part was removed (the
    /// cast truncates toward zero).
    pub fractional_truncations: usize,
    /// Non-finite values (`NaN`, `+inf`, `-inf`) encountered.
    pub non_finite_values: usize,
    /// Finite values outside the i64 representable range
    /// (`[-2^63, 2^63)`); the cast saturated these.
    pub out_of_range_values: usize,
    /// First problematic element (lowest offset) of any warning kind,
    /// if any, for quick debugging.
    pub first_problem: Option<ConversionWarningSample>,
}

/// Example element that triggered a conversion warning.
///
/// Stores the value before and after the `as i64` cast side by side, so a
/// caller can see exactly what the cast did to it.
#[derive(Debug, Clone, PartialEq)]
pub struct ConversionWarningSample {
    /// Offset of the element within the tensor's flattened `f64` data
    /// (presumably row-major storage order — confirm against god-graph's
    /// `to_f64_vec`).
    pub element_index: usize,
    /// Original value after promotion to f64.
    pub original: f64,
    /// Value produced by Rust's `as i64` cast. The cast truncates toward
    /// zero, saturates at `i64::MIN` / `i64::MAX`, and yields `0` for `NaN`.
    pub converted: i64,
    /// Reason this element was included in the warning.
    pub reason: ConversionWarningKind,
}

/// Reason a single value was flagged during i64 conversion.
///
/// Each value is assigned at most one kind. Classification checks
/// non-finiteness first, then the representable range, then the
/// fractional part. An out-of-range value is therefore never also counted
/// as a fractional truncation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConversionWarningKind {
    /// The value was finite and in range, but had a fractional part
    /// that was truncated toward zero.
    FractionalTruncation,
    /// The value was `NaN`, `+inf`, or `-inf`. The cast maps `NaN` to `0`
    /// and saturates the infinities.
    NonFinite,
    /// The finite value was outside the i64 representable range
    /// (`[-2^63, 2^63)`), and the cast saturated to `i64::MIN` or
    /// `i64::MAX`.
    OutOfRange,
}

/// Successful output from [`load_model_as_tensor_store_with_report`].
///
/// The tuple contains, in order:
///
/// 1. the loaded `TensorStore<i64>`, keyed by god-graph node index;
/// 2. one traceability sample per graph node, in node-iteration order;
/// 3. the [`TensorStoreLoadReport`] with i64 conversion diagnostics.
pub type TensorStoreLoadWithReport = (
    TensorStore<i64>,
    Vec<WitnessSample<i64>>,
    TensorStoreLoadReport,
);

/// Load a model from a Safetensors file into a tokitai
/// [`TensorStore<i64>`] and emit one witness sample per graph node.
///
/// This compatibility API discards the conversion diagnostics
/// returned by [`load_model_as_tensor_store_with_report`]. Prefer the
/// report-returning variant when loading non-integer model weights,
/// because this function gives no signal when values are truncated,
/// saturated, or non-finite.
///
/// The returned `TensorStore<i64>` is keyed by the god-graph node
/// index (a small `usize`), so downstream code can correlate
/// tensor ids back to `OperatorType` payloads if it still has a
/// reference to the source graph. Each entry's `Tensor<i64>` has:
///
/// - `meta.domain = DomainId::new("integer")`. Every weight value is
///   promoted to `f64` and cast with Rust's `as i64`, which truncates
///   toward zero, saturates out-of-range values at `i64::MIN` /
///   `i64::MAX`, and maps `NaN` to `0`. Precision-sensitive callers
///   should consult the original [`WeightTensor`] directly.
/// - `meta.shape` rebuilt from the safetensors shape (all-`Static`).
/// - `meta.representation` = the default dense-CPU representation.
///
/// The returned `Vec<WitnessSample<i64>>` has one traceability sample per
/// graph node, in node-iteration order. The label is
/// `loaded_tensor:<index>:<name>`. The `observed` / `expected` values are
/// copied unchanged from the witness that
/// [`witness_bridge::witness_of_edit_op`](crate::witness_bridge::witness_of_edit_op)
/// builds for a synthetic `EditOperation::NodeEdit(<index>, NodeEditOp::Add)`.
/// They are therefore derived from the node index alone, not from the
/// tensor name or its weights. Whether `is_satisfied()` holds is decided by
/// `witness_of_edit_op`. In either case these samples are traceability
/// markers, not evidence that the model weights or an optimization edit
/// are correct.
///
/// # Errors
///
/// Propagates every error from [`load_model_as_tensor_store_with_report`]
/// unchanged. This includes [`BridgeError::Upstream`] when god-graph fails
/// to open or parse the file (e.g. missing path, malformed safetensors
/// header, unsupported dtype).
///
/// # Feature gating
///
/// Requires both the `llm` and `operator` features. The `llm` feature
/// transitively pulls in `graph`, `god-graph/tensor`,
/// `god-graph/transformer`, and `god-graph/safetensors`; `operator` supplies
/// `TensorStore<i64>` and `WitnessSample<i64>`.
///
/// # Example
///
/// ```no_run
/// # #[cfg(all(feature = "llm", feature = "operator"))]
/// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
/// use std::path::Path;
/// let (store, witnesses) =
///     tokitai_dl::llm_bridge::load_model_as_tensor_store(Path::new("model.safetensors"))?;
/// // TensorStore is keyed by `usize`; use `contains` to check membership.
/// let tensors_loaded = (0..witnesses.len()).filter(|i| store.contains(*i)).count();
/// println!("loaded {} tensors and {} witnesses", tensors_loaded, witnesses.len());
/// # Ok(()) }
/// ```
pub fn load_model_as_tensor_store(
    path: &Path,
) -> Result<(TensorStore<i64>, Vec<WitnessSample<i64>>), BridgeError> {
    // Delegate to the reporting variant and drop the diagnostics so the
    // original two-element return type is preserved.
    let (store, witnesses, _report) = load_model_as_tensor_store_with_report(path)?;
    Ok((store, witnesses))
}

/// Load a model from a Safetensors file and include i64 conversion
/// diagnostics.
///
/// The returned [`TensorStoreLoadReport`] records three kinds of values:
/// finite fractional values truncated toward zero, non-finite values, and
/// finite values that saturate because they are outside the i64
/// representable range. The tensor store and witness outputs are
/// identical to [`load_model_as_tensor_store`].
///
/// Every graph node yields exactly one store entry (keyed by its node
/// index) and one witness. Nodes that own a self-loop edge carry a
/// [`WeightTensor`] and are converted element by element. Nodes without
/// one are stored as a placeholder with no data, an empty dims list, and
/// the name `node_<index>`.
///
/// # Errors
///
/// - [`BridgeError::Upstream`] when god-graph fails to open or parse the
///   file (e.g. missing path, malformed safetensors header, unsupported
///   dtype).
/// - [`BridgeError::Upstream`] when a loaded [`WeightTensor`] holds a
///   different number of values than its shape describes, or when the
///   shape's element count overflows `usize`. Such a tensor is rejected
///   rather than passed to the unchecked `Tensor::dense_cpu` constructor.
/// - Any error returned by
///   [`witness_bridge::witness_of_edit_op`](crate::witness_bridge::witness_of_edit_op),
///   propagated with `?`.
pub fn load_model_as_tensor_store_with_report(
    path: &Path,
) -> Result<TensorStoreLoadWithReport, BridgeError> {
    #[cfg(feature = "tracing")]
    tracing::trace!(
        "llm_bridge::load_model_as_tensor_store_with_report: entry (path={})",
        path.display()
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::trace!(
        "llm_bridge::load_model_as_tensor_store_with_report: entry (path={})",
        path.display()
    );
    // 1. Delegate the actual file parsing to god-graph. The
    //    return type is `Graph<OperatorType, WeightTensor>` and is
    //    documented to be one node per safetensors tensor (with a
    //    self-loop edge carrying the WeightTensor).
    let graph = ModelSwitch::load_from_safetensors(path).map_err(|e| {
        BridgeError::Upstream(format!(
            "ModelSwitch::load_from_safetensors({}) failed: {e}",
            path.display()
        ))
    })?;

    // 2. Build a node_index -> weight map in a single pass over the
    //    edges. god-graph's loader stores each tensor as a self-loop
    //    (node_i -> node_i) whose payload is the original
    //    `WeightTensor`. If a node had several self-loops, the last
    //    one visited would win.
    let mut weight_by_node: std::collections::HashMap<usize, &WeightTensor> =
        std::collections::HashMap::new();
    for edge in graph.edges() {
        if edge.source() == edge.target() {
            weight_by_node.insert(edge.source().index(), edge.data());
        }
    }

    // 3. Walk every node, build a tokitai `Tensor<i64>` from the
    //    weight (or a placeholder if no self-loop is present), insert it
    //    into the TensorStore, and emit a witness sample.
    let mut store = TensorStore::<i64>::new();
    let mut witnesses: Vec<WitnessSample<i64>> = Vec::new();
    let mut report = TensorStoreLoadReport::default();

    for node in graph.nodes() {
        let node_index = node.index().index();
        let (tensor_name, i64_data, static_shape) = match weight_by_node.get(&node_index) {
            Some(w) => {
                let (data, warning) = weight_to_i64_vec_with_warning(w, node_index);
                report.tensors_converted += 1;
                if let Some(warning) = warning {
                    report.record_warning(warning);
                }
                let shape = w.shape().to_vec();
                // `Tensor::dense_cpu` below does not validate its inputs,
                // so verify here that the flat data exactly fills the
                // declared shape. `checked_mul` stops an absurd shape from
                // wrapping around to a product that happens to match.
                let expected_len = shape
                    .iter()
                    .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
                if expected_len != Some(data.len()) {
                    return Err(BridgeError::Upstream(format!(
                        "tensor `{}` (node {node_index}) has shape {shape:?} but {} data values",
                        w.name,
                        data.len()
                    )));
                }
                (w.name.clone(), data, shape)
            }
            // No self-loop edge: keep a placeholder entry so store keys
            // and witnesses still cover every node.
            // NOTE(review): confirm how tokitai's `Shape` interprets an
            // empty dims list (scalar vs. zero elements) before relying on
            // these placeholder tensors; they carry no data.
            None => (format!("node_{node_index}"), Vec::new(), Vec::new()),
        };
        report.tensors_loaded += 1;

        // Build a tokitai `Tensor<i64>` with the unchecked `dense_cpu`
        // constructor. Weight-bearing entries were length-checked above,
        // and placeholders carry no data. The `Vec<usize>` shape converts
        // into a `Shape` via its `From` impl.
        let shape = Shape::from(static_shape);
        let tensor = Tensor::<i64>::dense_cpu(DomainId::new("integer"), shape, i64_data);
        store.insert(node_index, tensor);

        // Emit a per-node witness via the existing `witness_of_edit_op`
        // bridge, using a synthetic `NodeEdit` so the wiring matches the
        // rest of the pipeline. Only the label mentions the tensor name;
        // `observed` / `expected` come from the synthetic edit, i.e. from
        // the node index.
        let edit = EditOperation::NodeEdit(node_index, NodeEditOp::Add);
        let witness = crate::witness_bridge::witness_of_edit_op(&edit)?;
        let witness = WitnessSample::new(
            format!("loaded_tensor:{node_index}:{tensor_name}"),
            witness.observed,
            witness.expected,
        );
        witnesses.push(witness);
    }

    #[cfg(feature = "tracing")]
    tracing::debug!(
        "llm_bridge::load_model_as_tensor_store_with_report: exit ({} tensors, {} witnesses, {} warnings)",
        report.tensors_loaded,
        witnesses.len(),
        report.warnings.len()
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::debug!(
        "llm_bridge::load_model_as_tensor_store_with_report: exit ({} tensors, {} witnesses, {} warnings)",
        report.tensors_loaded,
        witnesses.len(),
        report.warnings.len()
    );

    Ok((store, witnesses, report))
}

/// Convert one god-graph [`WeightTensor`] into `i64` values and tally
/// every element whose `as i64` cast is not exact.
///
/// The weight data is promoted with `to_f64_vec()`, and each element is
/// then cast with Rust's saturating `as i64`. The returned vector always
/// has exactly one entry per source element. The second tuple field is
/// `Some` only when [`conversion_warning_kind`] flagged at least one
/// element. In that case it holds per-kind counts and the first offending
/// element.
fn weight_to_i64_vec_with_warning(
    w: &WeightTensor,
    node_index: usize,
) -> (Vec<i64>, Option<TensorConversionWarning>) {
    let promoted = w.data.to_f64_vec();
    let mut diagnostic = TensorConversionWarning {
        tensor_name: w.name.clone(),
        node_index,
        dtype: format!("{:?}", w.dtype()),
        elements: promoted.len(),
        fractional_truncations: 0,
        non_finite_values: 0,
        out_of_range_values: 0,
        first_problem: None,
    };

    let values: Vec<i64> = promoted
        .into_iter()
        .enumerate()
        .map(|(element_index, original)| {
            let cast = original as i64;
            if let Some(reason) = conversion_warning_kind(original) {
                let tally = match reason {
                    ConversionWarningKind::FractionalTruncation => {
                        &mut diagnostic.fractional_truncations
                    }
                    ConversionWarningKind::NonFinite => &mut diagnostic.non_finite_values,
                    ConversionWarningKind::OutOfRange => &mut diagnostic.out_of_range_values,
                };
                *tally += 1;
                // Only the earliest offending element is kept as a sample.
                if diagnostic.first_problem.is_none() {
                    diagnostic.first_problem = Some(ConversionWarningSample {
                        element_index,
                        original,
                        converted: cast,
                        reason,
                    });
                }
            }
            cast
        })
        .collect();

    // A sample is recorded exactly when some tally was incremented, so
    // its presence is equivalent to "any counter is non-zero".
    let any_lossy = diagnostic.first_problem.is_some();
    (values, any_lossy.then_some(diagnostic))
}

/// Classify how Rust's saturating `f64 as i64` cast treats `value`.
///
/// Returns `None` when the cast is exact: `value` is finite, lies in
/// `[-2^63, 2^63)`, and has no fractional part. Otherwise it returns the
/// first matching reason, checked in this order: non-finite, out of
/// range, fractional truncation.
fn conversion_warning_kind(value: f64) -> Option<ConversionWarningKind> {
    // -2^63 is exactly representable as f64. Its negation, 2^63, is the
    // smallest f64 strictly above i64::MAX, so the exact range is
    // half-open.
    let lower_inclusive = i64::MIN as f64;
    let upper_exclusive = -lower_inclusive;

    if value.is_nan() || value.is_infinite() {
        Some(ConversionWarningKind::NonFinite)
    } else if !(lower_inclusive..upper_exclusive).contains(&value) {
        Some(ConversionWarningKind::OutOfRange)
    } else if value.fract() != 0.0 {
        Some(ConversionWarningKind::FractionalTruncation)
    } else {
        None
    }
}

// Reference to OperatorType so the import is not flagged as unused
// when the trait methods on `graph.nodes()` are the only consumer.
const _: fn() = || {
    let _: fn(OperatorType) -> OperatorType = std::convert::identity;
};