tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! SQLite reader for the two `tokitai-search` quality ledgers.
//!
//! This is the lowest layer of the dataset bridge: it opens the two
//! SQLite files (decisions + outcomes) with `rusqlite`, and exposes
//! raw row types ([`DecisionRow`], [`OutcomeRow`]) and a paired
//! iterator ([`SqliteDatasetReader::iter_pairs`]) that performs the
//! inner join on `decision_group_id`.
//!
//! The reader is **read-only** and does not depend on
//! `tokitai-search-core` — it is a deliberately thin wrapper around
//! the SQLite schema so that the training loop in tokitai-operator
//! can ingest real `tokitai-search` data without pulling in the
//! `SearchApplication` crate.

use std::path::Path;

use rusqlite::{Connection, OpenFlags};

use crate::dataset_bridge::local_dataset::LocalSample;
use crate::dataset_bridge::{FEATURE_DIM, LABEL_DIM};
use crate::error::{Error, Result};

/// Minimal string-backed error produced by the JSON parse helpers.
///
/// `rusqlite::Error::FromSqlConversionFailure` carries its cause as a
/// `Box<dyn StdError + Send + Sync>`, and a bare `String` does not
/// implement `std::error::Error`, so the message is wrapped in this
/// newtype before it is boxed.
#[derive(Debug)]
struct ParseError(String);

impl std::error::Error for ParseError {}

impl std::fmt::Display for ParseError {
    /// Writes the wrapped message as-is.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        write!(formatter, "{message}")
    }
}

/// Length of the categorical one-hot segment of a decision row; it
/// forms the first 74 entries of the 96-dim feature vector.
///
/// NOTE(review): documented as matching
/// `tokitai-search::crates::training::CATEGORICAL_DIMS` — that crate is
/// not visible from this file, so confirm when either side changes.
pub const CATEGORICAL_DIMS: usize = 74;
/// Length of the numerical segment of a decision row; it forms the
/// last 22 entries of the 96-dim feature vector.
///
/// NOTE(review): documented as matching
/// `tokitai-search::crates::training::NUMERICAL_DIMS` — confirm
/// against that crate.
pub const NUMERICAL_DIMS: usize = 22;
/// Number of outcome kinds; the length of the one-hot segment that
/// opens the 20-dim label vector.
pub const OUTCOME_KIND_DIMS: usize = 12;
/// Number of aux metric values per outcome row; they follow the
/// outcome one-hot in the 20-dim label vector.
pub const AUX_METRIC_DIMS: usize = 8;

/// A single row from the `quality_decisions` ledger, in the simplified
/// column layout that the bridge understands. Only the columns named in
/// the bridge's decision SELECT statement are read.
///
/// The integer enum fields hold whatever value the column contained:
/// the row mapper neither clamps nor validates them.
#[derive(Debug, Clone)]
pub struct DecisionRow {
    /// Join key; paired with `OutcomeRow::decision_group_id`.
    pub decision_group_id: String,
    /// Workflow kind, as the raw integer from the `workflow_kind`
    /// column. NOTE(review): documented upstream as a 7-way enum
    /// (`0..7`), but nothing in the reader enforces that range.
    pub workflow_kind: i64,
    /// Decision stage, as the raw integer from the `stage` column.
    /// NOTE(review): documented upstream as a 9-way enum (`0..9`);
    /// not range-checked here.
    pub stage: i64,
    /// Decision reason, as the raw integer from the `reason` column.
    /// NOTE(review): documented upstream as a 14-way enum (`0..14`);
    /// not range-checked here.
    pub reason: i64,
    /// Host key, read as TEXT. Presumably the URL host — confirm
    /// against the producer.
    pub host_key: String,
    /// Recorded-at timestamp as stored in the `recorded_at` column.
    /// NOTE(review): presumably ms since epoch — confirm.
    pub recorded_at: i64,
    /// Categorical flags parsed from `categorical_json`. Each entry is
    /// `0` or `1`: the parser maps every non-zero input to `1`.
    pub categorical: [u8; CATEGORICAL_DIMS],
    /// Numerical features parsed from `numerical_json`, saturated to
    /// the `i16` range. NOTE(review): presumably milli-units — the
    /// sample builder later clamps each entry to `[0, 1000]`.
    pub numerical: [i16; NUMERICAL_DIMS],
}

/// A single row from the `quality_outcomes` ledger, in the simplified
/// column layout that the bridge understands.
///
/// All scalar fields hold the value exactly as read from SQLite; the
/// row mapper performs no rescaling or range checks.
#[derive(Debug, Clone)]
pub struct OutcomeRow {
    /// Join key; paired with `DecisionRow::decision_group_id`.
    pub decision_group_id: String,
    /// Outcome kind, as the raw integer from the `outcome_kind`
    /// column. It is clamped to `0..OUTCOME_KIND_DIMS` only when a
    /// training sample is built, not at read time.
    pub outcome_kind: i64,
    /// Observed-at timestamp as stored in the `observed_at` column.
    /// NOTE(review): presumably ms since epoch — confirm.
    pub observed_at: i64,
    /// Primary metric value as stored in the `metric_value` column.
    /// NOTE(review): presumably in `[0, 1]` — not checked here.
    pub metric_value: f64,
    /// Confidence as stored in the `confidence` column.
    /// NOTE(review): expected to be in `[0, 1]`, but the reader does
    /// not convert from milli-units or range-check it; verify what
    /// the producer writes.
    pub confidence: f64,
    /// Aux metric values parsed from `aux_metrics_json` and narrowed
    /// to `f32`. They are not rescaled at read time; the sample
    /// builder clips each one to `[0, 1]`. NOTE(review): if the
    /// producer stores milli-units, a division by 1000 is missing.
    pub aux_metrics: [f32; AUX_METRIC_DIMS],
}

/// The pair of SQLite connections the bridge reads from. `open`
/// creates both with `OpenFlags::SQLITE_OPEN_READ_ONLY`, so the reader
/// cannot modify either ledger through them.
///
/// NOTE(review): the fields are public, so callers can run arbitrary
/// (read-only) statements on either connection.
pub struct SqliteDatasetReader {
    /// Connection to the decisions ledger (table `quality_decisions`).
    pub decisions: Connection,
    /// Connection to the outcomes ledger (table `quality_outcomes`).
    pub outcomes: Connection,
}

impl SqliteDatasetReader {
    /// Open both SQLite files in read-only mode. Either path can be a
    /// file or a directory; if a directory is given, the bridge looks
    /// for `quality_decisions.db` and `quality_outcomes.db` inside.
    pub fn open(decisions_path: &Path, outcomes_path: &Path) -> Result<Self> {
        let decisions =
            Connection::open_with_flags(decisions_path, OpenFlags::SQLITE_OPEN_READ_ONLY).map_err(
                |e| Error::backend(format!("open decisions db {decisions_path:?}: {e}")),
            )?;
        let outcomes = Connection::open_with_flags(outcomes_path, OpenFlags::SQLITE_OPEN_READ_ONLY)
            .map_err(|e| Error::backend(format!("open outcomes db {outcomes_path:?}: {e}")))?;
        Ok(Self {
            decisions,
            outcomes,
        })
    }

    /// List the tables present in the decisions DB (used by the smoke
    /// test to verify the schema is what we expect).
    pub fn decisions_tables(&self) -> Result<Vec<String>> {
        list_tables(&self.decisions)
    }

    /// List the tables present in the outcomes DB.
    pub fn outcomes_tables(&self) -> Result<Vec<String>> {
        list_tables(&self.outcomes)
    }

    /// Yield [`DecisionRow`]s in row-order (no ordering beyond what the
    /// SQLite storage engine happens to return — typically insertion
    /// order, which is good enough for training).
    pub fn iter_decisions(&self) -> Result<impl Iterator<Item = Result<DecisionRow>> + '_> {
        let mut stmt = self
            .decisions
            .prepare(SELECT_DECISION_SQL)
            .map_err(|e| Error::backend(format!("prepare decision stmt: {e}")))?;
        // `Rows` borrows `stmt`, so we can't return the iterator
        // directly. Collect into a Vec of mapped results and return
        // `into_iter()` to release the borrow.
        let mut out: Vec<Result<DecisionRow>> = Vec::new();
        let rows = stmt
            .query_map([], map_decision_row)
            .map_err(|e| Error::backend(format!("query decisions: {e}")))?;
        for r in rows {
            out.push(r.map_err(|e| Error::backend(format!("map decision row: {e}"))));
        }
        Ok(out.into_iter())
    }

    /// Yield [`OutcomeRow`]s in row-order.
    pub fn iter_outcomes(&self) -> Result<impl Iterator<Item = Result<OutcomeRow>> + '_> {
        let mut stmt = self
            .outcomes
            .prepare(SELECT_OUTCOME_SQL)
            .map_err(|e| Error::backend(format!("prepare outcome stmt: {e}")))?;
        let mut out: Vec<Result<OutcomeRow>> = Vec::new();
        let rows = stmt
            .query_map([], map_outcome_row)
            .map_err(|e| Error::backend(format!("query outcomes: {e}")))?;
        for r in rows {
            out.push(r.map_err(|e| Error::backend(format!("map outcome row: {e}"))));
        }
        Ok(out.into_iter())
    }

    /// Build an iterator that yields a [`LocalSample`] for every
    /// decision row whose `decision_group_id` matches at least one
    /// outcome row. Performs the inner join in application code (rather
    /// than via a SQL `INNER JOIN`) so we can use SQLite's index on
    /// `decision_group_id` for the right-hand side lookup.
    ///
    /// **Join semantics:** inner join. A decision with no matching
    /// outcome is dropped. The first matching outcome per group wins
    /// (the tokitai-search MoE design specifies a 1:1 join).
    pub fn iter_pairs(&self) -> Result<impl Iterator<Item = Result<LocalSample>> + '_> {
        // Materialize outcomes into a hash map keyed by group id. The
        // expected cardinality in real tokitai-search data is small
        // (one outcome per group), so this is fine.
        let mut outcomes_by_group: std::collections::HashMap<String, OutcomeRow> =
            std::collections::HashMap::new();
        for row in self.iter_outcomes()? {
            let row = row?;
            // First-write-wins: keeps the join deterministic even if
            // the source DB has multiple outcomes per group.
            outcomes_by_group
                .entry(row.decision_group_id.clone())
                .or_insert(row);
        }

        let iter = self.iter_decisions()?;
        Ok(iter.map(move |row| {
            let row = row?;
            let outcome = outcomes_by_group
                .get(&row.decision_group_id)
                .ok_or_else(|| {
                    Error::backend(format!(
                        "decision group {} has no outcome (this should not happen — caller should have filtered)",
                        row.decision_group_id
                    ))
                })?;
            Ok(build_local_sample(&row, outcome))
        }))
    }
}

/// Read the list of user table names from a SQLite connection. Used by
/// the smoke tests to assert that the fixture DBs contain the expected
/// schema.
fn list_tables(conn: &Connection) -> Result<Vec<String>> {
    let mut stmt = conn
        .prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name ASC")
        .map_err(|e| Error::backend(format!("list tables: {e}")))?;
    let rows = stmt
        .query_map([], |row| row.get::<_, String>(0))
        .map_err(|e| Error::backend(format!("query tables: {e}")))?;
    let mut out = Vec::new();
    for r in rows {
        out.push(r.map_err(|e| Error::backend(format!("map table row: {e}")))?);
    }
    Ok(out)
}

/// Turn one joined (decision, outcome) pair into a [`LocalSample`].
///
/// * `features` (96 values): the 74 categorical flags as `0.0`/`1.0`,
///   followed by the 22 numerical values, each clamped to `[0, 1000]`
///   and converted to `f32`.
/// * `labels` (20 values): a 12-slot one-hot for `outcome_kind` (the
///   kind is clamped into the valid slot range first), followed by the
///   8 aux metrics, each clipped to `[0, 1]`.
fn build_local_sample(d: &DecisionRow, o: &OutcomeRow) -> LocalSample {
    let flags = d
        .categorical
        .iter()
        .map(|&bit| if bit != 0 { 1.0f32 } else { 0.0f32 });
    let scaled = d
        .numerical
        .iter()
        .map(|&milli| f32::from(milli.clamp(0, 1000)));
    let mut features: Vec<f32> = Vec::with_capacity(FEATURE_DIM);
    features.extend(flags.chain(scaled));
    debug_assert_eq!(features.len(), FEATURE_DIM);

    // Out-of-range kinds land on the nearest valid slot rather than
    // producing an all-zero one-hot.
    let hot_slot = o.outcome_kind.clamp(0, (OUTCOME_KIND_DIMS as i64) - 1) as usize;
    let one_hot = (0..OUTCOME_KIND_DIMS).map(|slot| if slot == hot_slot { 1.0f32 } else { 0.0f32 });
    let clipped_aux = o.aux_metrics.iter().map(|metric| metric.clamp(0.0, 1.0));
    let mut labels: Vec<f32> = Vec::with_capacity(LABEL_DIM);
    labels.extend(one_hot.chain(clipped_aux));
    debug_assert_eq!(labels.len(), LABEL_DIM);

    LocalSample { features, labels }
}

/// Query behind [`map_decision_row`]. The column order fixes the
/// indices that function reads: `decision_group_id` (0, the join key),
/// `workflow_kind` (1), `stage` (2), `reason` (3), `host_key` (4),
/// `recorded_at` (5), `categorical_json` (6), `numerical_json` (7).
/// The last two are read as TEXT holding JSON arrays and together
/// carry the 74 + 22 = 96 feature values.
///
/// NOTE(review): this is the simplified bridge/fixture layout; how the
/// real tokitai-search schema stores the feature values is not visible
/// from this file.
const SELECT_DECISION_SQL: &str = "SELECT decision_group_id, workflow_kind, stage, reason, host_key, recorded_at, categorical_json, numerical_json FROM quality_decisions";

/// Query behind [`map_outcome_row`]. Column indices:
/// `decision_group_id` (0, the join key), `outcome_kind` (1),
/// `observed_at` (2), `metric_value` (3), `confidence` (4),
/// `aux_metrics_json` (5). The last one is read as TEXT holding a JSON
/// array of the 8 aux metric values.
const SELECT_OUTCOME_SQL: &str = "SELECT decision_group_id, outcome_kind, observed_at, metric_value, confidence, aux_metrics_json FROM quality_outcomes";

/// Parse a `DecisionRow` from a SQLite row produced by
/// [`SELECT_DECISION_SQL`]. The categorical and numerical vectors are
/// stored as JSON arrays in TEXT columns for portability (the real
/// tokitai-search schema uses a packed BLOB; the bridge uses JSON so
/// the fixture DBs can be inspected with any SQLite browser).
fn map_decision_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<DecisionRow> {
    let decision_group_id: String = row.get(0)?;
    let workflow_kind: i64 = row.get(1)?;
    let stage: i64 = row.get(2)?;
    let reason: i64 = row.get(3)?;
    let host_key: String = row.get(4)?;
    let recorded_at: i64 = row.get(5)?;
    let cat_json: String = row.get(6)?;
    let num_json: String = row.get(7)?;

    let categorical: [u8; CATEGORICAL_DIMS] =
        parse_u8_array(&cat_json, CATEGORICAL_DIMS).map_err(|e| {
            rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
        })?;
    let numerical: [i16; NUMERICAL_DIMS] =
        parse_i16_array(&num_json, NUMERICAL_DIMS).map_err(|e| {
            rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
        })?;

    Ok(DecisionRow {
        decision_group_id,
        workflow_kind,
        stage,
        reason,
        host_key,
        recorded_at,
        categorical,
        numerical,
    })
}

/// Parse an `OutcomeRow` from a SQLite row produced by
/// [`SELECT_OUTCOME_SQL`]. The 8 aux metrics are stored as a JSON
/// array in `aux_metrics_json`.
fn map_outcome_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<OutcomeRow> {
    let decision_group_id: String = row.get(0)?;
    let outcome_kind: i64 = row.get(1)?;
    let observed_at: i64 = row.get(2)?;
    let metric_value: f64 = row.get(3)?;
    let confidence: f64 = row.get(4)?;
    let aux_json: String = row.get(5)?;

    let aux_metrics: [f32; AUX_METRIC_DIMS] =
        parse_f32_array(&aux_json, AUX_METRIC_DIMS).map_err(|e| {
            rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(e))
        })?;

    Ok(OutcomeRow {
        decision_group_id,
        outcome_kind,
        observed_at,
        metric_value,
        confidence,
        aux_metrics,
    })
}

/// Parse a JSON array of exactly `n` integers into the 74-slot
/// categorical buffer (74 is the value of `CATEGORICAL_DIMS`).
///
/// Every entry is normalized to a flag: `0` stays `0`, any other
/// value — negative ones included — becomes `1`. When `n` is smaller
/// than the buffer, the unused trailing slots stay `0`.
///
/// # Errors
///
/// Returns a `ParseError` when `s` is not a JSON array of integers,
/// when its length differs from `n`, or when `n` exceeds the buffer
/// size (previously this last case panicked on an out-of-bounds
/// write).
fn parse_u8_array(s: &str, n: usize) -> std::result::Result<[u8; 74], ParseError> {
    let v: Vec<i64> =
        serde_json::from_str(s).map_err(|e| ParseError(format!("parse u8 array: {e}")))?;
    if v.len() != n {
        return Err(ParseError(format!(
            "u8 array length mismatch: got {}, expected {}",
            v.len(),
            n
        )));
    }
    let mut out = [0u8; 74];
    if n > out.len() {
        return Err(ParseError(format!(
            "u8 array too long for output buffer: got {}, capacity {}",
            n,
            out.len()
        )));
    }
    for (slot, &x) in out.iter_mut().zip(&v) {
        *slot = u8::from(x != 0);
    }
    Ok(out)
}

/// Parse a JSON array of exactly `n` integers into the 22-slot
/// numerical buffer (22 is the value of `NUMERICAL_DIMS`).
///
/// Each entry is saturated to the `i16` range `[-32768, 32767]`. When
/// `n` is smaller than the buffer, the unused trailing slots stay `0`.
///
/// # Errors
///
/// Returns a `ParseError` when `s` is not a JSON array of integers,
/// when its length differs from `n`, or when `n` exceeds the buffer
/// size (previously this last case panicked on an out-of-bounds
/// write).
fn parse_i16_array(s: &str, n: usize) -> std::result::Result<[i16; 22], ParseError> {
    let v: Vec<i64> =
        serde_json::from_str(s).map_err(|e| ParseError(format!("parse i16 array: {e}")))?;
    if v.len() != n {
        return Err(ParseError(format!(
            "i16 array length mismatch: got {}, expected {}",
            v.len(),
            n
        )));
    }
    let mut out = [0i16; 22];
    if n > out.len() {
        return Err(ParseError(format!(
            "i16 array too long for output buffer: got {}, capacity {}",
            n,
            out.len()
        )));
    }
    for (slot, &x) in out.iter_mut().zip(&v) {
        // Saturate first; the narrowing cast below is then lossless.
        let saturated = x.clamp(i64::from(i16::MIN), i64::from(i16::MAX));
        *slot = saturated as i16;
    }
    Ok(out)
}

/// Parse a JSON array of exactly `n` numbers into the 8-slot aux
/// metric buffer (8 is the value of `AUX_METRIC_DIMS`).
///
/// Values are decoded as `f64` and narrowed to `f32` with no scaling
/// or range check, so some precision is lost. When `n` is smaller than
/// the buffer, the unused trailing slots stay `0.0`.
///
/// # Errors
///
/// Returns a `ParseError` when `s` is not a JSON array of numbers,
/// when its length differs from `n`, or when `n` exceeds the buffer
/// size (previously this last case panicked on an out-of-bounds
/// write).
fn parse_f32_array(s: &str, n: usize) -> std::result::Result<[f32; 8], ParseError> {
    let v: Vec<f64> =
        serde_json::from_str(s).map_err(|e| ParseError(format!("parse f32 array: {e}")))?;
    if v.len() != n {
        return Err(ParseError(format!(
            "f32 array length mismatch: got {}, expected {}",
            v.len(),
            n
        )));
    }
    let mut out = [0.0f32; 8];
    if n > out.len() {
        return Err(ParseError(format!(
            "f32 array too long for output buffer: got {}, capacity {}",
            n,
            out.len()
        )));
    }
    for (slot, &x) in out.iter_mut().zip(&v) {
        // Deliberate precision loss: f64 narrowed to f32.
        *slot = x as f32;
    }
    Ok(out)
}