use std::path::Path;
use rusqlite::{Connection, OpenFlags};
use crate::dataset_bridge::local_dataset::LocalSample;
use crate::dataset_bridge::{FEATURE_DIM, LABEL_DIM};
use crate::error::{Error, Result};
/// Error raised when a JSON-encoded column cannot be turned into a fixed-size array.
///
/// Wraps a human-readable message so the failure can be boxed into a
/// `rusqlite` conversion error.
#[derive(Debug)]
struct ParseError(String);

impl std::fmt::Display for ParseError {
    /// Writes the wrapped message verbatim; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message.as_str())
    }
}

// No source error is carried, so the default trait methods are sufficient.
impl std::error::Error for ParseError {}
/// Length of [`DecisionRow::categorical`]; also the element count required of the
/// `categorical_json` column when a decision row is mapped.
pub const CATEGORICAL_DIMS: usize = 74;
/// Length of [`DecisionRow::numerical`]; also the element count required of the
/// `numerical_json` column when a decision row is mapped.
pub const NUMERICAL_DIMS: usize = 22;
/// Width of the one-hot outcome-kind segment written at the start of a sample's labels.
/// Outcome kinds are clamped into `0..OUTCOME_KIND_DIMS` before encoding.
pub const OUTCOME_KIND_DIMS: usize = 12;
/// Length of [`OutcomeRow::aux_metrics`]; also the element count required of the
/// `aux_metrics_json` column when an outcome row is mapped.
pub const AUX_METRIC_DIMS: usize = 8;
/// One row read from the `quality_decisions` table.
///
/// Scalar fields are copied straight from the matching SQL columns; the two
/// array fields are decoded from JSON text columns.
#[derive(Debug, Clone)]
pub struct DecisionRow {
/// Key used to pair this decision with an [`OutcomeRow`] carrying the same id.
pub decision_group_id: String,
/// Raw `workflow_kind` column value (meaning is defined outside this file).
pub workflow_kind: i64,
/// Raw `stage` column value (meaning is defined outside this file).
pub stage: i64,
/// Raw `reason` column value (meaning is defined outside this file).
pub reason: i64,
/// Raw `host_key` column value.
pub host_key: String,
/// Raw `recorded_at` column value — presumably a timestamp; unit not visible here, TODO confirm.
pub recorded_at: i64,
/// Flags decoded from `categorical_json`; every element is normalised to 0 or 1 on parse.
pub categorical: [u8; CATEGORICAL_DIMS],
/// Values decoded from `numerical_json`, saturated to the `i16` range on parse.
pub numerical: [i16; NUMERICAL_DIMS],
}
/// One row read from the `quality_outcomes` table.
///
/// Scalar fields are copied straight from the matching SQL columns;
/// `aux_metrics` is decoded from a JSON text column.
#[derive(Debug, Clone)]
pub struct OutcomeRow {
/// Key used to pair this outcome with a [`DecisionRow`] carrying the same id.
pub decision_group_id: String,
/// Raw `outcome_kind` column value; clamped into the one-hot range when labels are built.
pub outcome_kind: i64,
/// Raw `observed_at` column value — presumably a timestamp; unit not visible here, TODO confirm.
pub observed_at: i64,
/// Raw `metric_value` column value (not used when building samples in this file).
pub metric_value: f64,
/// Raw `confidence` column value (not used when building samples in this file).
pub confidence: f64,
/// Values decoded from `aux_metrics_json`, narrowed from `f64` to `f32`.
pub aux_metrics: [f32; AUX_METRIC_DIMS],
}
/// Reader over a pair of SQLite databases: one holding decisions, one holding outcomes.
///
/// Instances created through `SqliteDatasetReader::open` hold read-only connections.
pub struct SqliteDatasetReader {
/// Connection to the database queried for `quality_decisions` rows.
pub decisions: Connection,
/// Connection to the database queried for `quality_outcomes` rows.
pub outcomes: Connection,
}
impl SqliteDatasetReader {
pub fn open(decisions_path: &Path, outcomes_path: &Path) -> Result<Self> {
let decisions =
Connection::open_with_flags(decisions_path, OpenFlags::SQLITE_OPEN_READ_ONLY).map_err(
|e| Error::backend(format!("open decisions db {decisions_path:?}: {e}")),
)?;
let outcomes = Connection::open_with_flags(outcomes_path, OpenFlags::SQLITE_OPEN_READ_ONLY)
.map_err(|e| Error::backend(format!("open outcomes db {outcomes_path:?}: {e}")))?;
Ok(Self {
decisions,
outcomes,
})
}
pub fn decisions_tables(&self) -> Result<Vec<String>> {
list_tables(&self.decisions)
}
pub fn outcomes_tables(&self) -> Result<Vec<String>> {
list_tables(&self.outcomes)
}
pub fn iter_decisions(&self) -> Result<impl Iterator<Item = Result<DecisionRow>> + '_> {
let mut stmt = self
.decisions
.prepare(SELECT_DECISION_SQL)
.map_err(|e| Error::backend(format!("prepare decision stmt: {e}")))?;
let mut out: Vec<Result<DecisionRow>> = Vec::new();
let rows = stmt
.query_map([], map_decision_row)
.map_err(|e| Error::backend(format!("query decisions: {e}")))?;
for r in rows {
out.push(r.map_err(|e| Error::backend(format!("map decision row: {e}"))));
}
Ok(out.into_iter())
}
pub fn iter_outcomes(&self) -> Result<impl Iterator<Item = Result<OutcomeRow>> + '_> {
let mut stmt = self
.outcomes
.prepare(SELECT_OUTCOME_SQL)
.map_err(|e| Error::backend(format!("prepare outcome stmt: {e}")))?;
let mut out: Vec<Result<OutcomeRow>> = Vec::new();
let rows = stmt
.query_map([], map_outcome_row)
.map_err(|e| Error::backend(format!("query outcomes: {e}")))?;
for r in rows {
out.push(r.map_err(|e| Error::backend(format!("map outcome row: {e}"))));
}
Ok(out.into_iter())
}
pub fn iter_pairs(&self) -> Result<impl Iterator<Item = Result<LocalSample>> + '_> {
let mut outcomes_by_group: std::collections::HashMap<String, OutcomeRow> =
std::collections::HashMap::new();
for row in self.iter_outcomes()? {
let row = row?;
outcomes_by_group
.entry(row.decision_group_id.clone())
.or_insert(row);
}
let iter = self.iter_decisions()?;
Ok(iter.map(move |row| {
let row = row?;
let outcome = outcomes_by_group
.get(&row.decision_group_id)
.ok_or_else(|| {
Error::backend(format!(
"decision group {} has no outcome (this should not happen — caller should have filtered)",
row.decision_group_id
))
})?;
Ok(build_local_sample(&row, outcome))
}))
}
}
fn list_tables(conn: &Connection) -> Result<Vec<String>> {
let mut stmt = conn
.prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name ASC")
.map_err(|e| Error::backend(format!("list tables: {e}")))?;
let rows = stmt
.query_map([], |row| row.get::<_, String>(0))
.map_err(|e| Error::backend(format!("query tables: {e}")))?;
let mut out = Vec::new();
for r in rows {
out.push(r.map_err(|e| Error::backend(format!("map table row: {e}")))?);
}
Ok(out)
}
/// Flattens a decision/outcome pair into a [`LocalSample`].
///
/// Features are the categorical flags (as 0.0 / 1.0) followed by the
/// numerical values clamped to `0..=1000`. Labels are a one-hot encoding of
/// the outcome kind followed by the auxiliary metrics clamped to `0.0..=1.0`.
fn build_local_sample(d: &DecisionRow, o: &OutcomeRow) -> LocalSample {
    let mut features: Vec<f32> = Vec::with_capacity(FEATURE_DIM);
    features.extend(
        d.categorical
            .iter()
            .map(|&flag| if flag == 0 { 0.0_f32 } else { 1.0 }),
    );
    features.extend(d.numerical.iter().map(|&raw| raw.clamp(0, 1000) as f32));
    debug_assert_eq!(features.len(), FEATURE_DIM);

    // Out-of-range kinds are folded onto the nearest valid slot instead of being rejected.
    let hot_slot = o.outcome_kind.clamp(0, (OUTCOME_KIND_DIMS as i64) - 1) as usize;
    let mut labels: Vec<f32> = Vec::with_capacity(LABEL_DIM);
    labels.extend((0..OUTCOME_KIND_DIMS).map(|slot| if slot == hot_slot { 1.0_f32 } else { 0.0 }));
    labels.extend(o.aux_metrics.iter().map(|&metric| metric.clamp(0.0, 1.0)));
    debug_assert_eq!(labels.len(), LABEL_DIM);

    LocalSample { features, labels }
}
/// Query behind `iter_decisions`. The column order is load-bearing:
/// `map_decision_row` reads the columns by position (0 through 7).
const SELECT_DECISION_SQL: &str = "SELECT decision_group_id, workflow_kind, stage, reason, host_key, recorded_at, categorical_json, numerical_json FROM quality_decisions";
/// Query behind `iter_outcomes`. The column order is load-bearing:
/// `map_outcome_row` reads the columns by position (0 through 5).
const SELECT_OUTCOME_SQL: &str = "SELECT decision_group_id, outcome_kind, observed_at, metric_value, confidence, aux_metrics_json FROM quality_outcomes";
fn map_decision_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<DecisionRow> {
let decision_group_id: String = row.get(0)?;
let workflow_kind: i64 = row.get(1)?;
let stage: i64 = row.get(2)?;
let reason: i64 = row.get(3)?;
let host_key: String = row.get(4)?;
let recorded_at: i64 = row.get(5)?;
let cat_json: String = row.get(6)?;
let num_json: String = row.get(7)?;
let categorical: [u8; CATEGORICAL_DIMS] =
parse_u8_array(&cat_json, CATEGORICAL_DIMS).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
})?;
let numerical: [i16; NUMERICAL_DIMS] =
parse_i16_array(&num_json, NUMERICAL_DIMS).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
})?;
Ok(DecisionRow {
decision_group_id,
workflow_kind,
stage,
reason,
host_key,
recorded_at,
categorical,
numerical,
})
}
fn map_outcome_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<OutcomeRow> {
let decision_group_id: String = row.get(0)?;
let outcome_kind: i64 = row.get(1)?;
let observed_at: i64 = row.get(2)?;
let metric_value: f64 = row.get(3)?;
let confidence: f64 = row.get(4)?;
let aux_json: String = row.get(5)?;
let aux_metrics: [f32; AUX_METRIC_DIMS] =
parse_f32_array(&aux_json, AUX_METRIC_DIMS).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(e))
})?;
Ok(OutcomeRow {
decision_group_id,
outcome_kind,
observed_at,
metric_value,
confidence,
aux_metrics,
})
}
/// Parses a JSON array of integers into a fixed 74-slot flag array.
///
/// Each element is normalised: zero stays `0`, any other value becomes `1`.
/// When `n` is smaller than the output size, the remaining slots stay `0`.
///
/// # Errors
/// Returns [`ParseError`] when `s` is not a JSON array of integers, when the
/// array does not contain exactly `n` elements, or when the array is longer
/// than the 74-slot output.
fn parse_u8_array(s: &str, n: usize) -> std::result::Result<[u8; 74], ParseError> {
    let values: Vec<i64> =
        serde_json::from_str(s).map_err(|e| ParseError(format!("parse u8 array: {e}")))?;
    if values.len() != n {
        return Err(ParseError(format!(
            "u8 array length mismatch: got {}, expected {}",
            values.len(),
            n
        )));
    }
    let mut out = [0u8; 74];
    // `n` is a runtime value but the output size is fixed: reject oversized
    // input explicitly rather than panicking on an out-of-bounds write.
    if values.len() > out.len() {
        return Err(ParseError(format!(
            "u8 array length {} exceeds output capacity {}",
            values.len(),
            out.len()
        )));
    }
    for (slot, value) in out.iter_mut().zip(&values) {
        *slot = u8::from(*value != 0);
    }
    Ok(out)
}
/// Parses a JSON array of integers into a fixed 22-slot `i16` array.
///
/// Each element is saturated to the `i16` range rather than wrapped. When `n`
/// is smaller than the output size, the remaining slots stay `0`.
///
/// # Errors
/// Returns [`ParseError`] when `s` is not a JSON array of integers, when the
/// array does not contain exactly `n` elements, or when the array is longer
/// than the 22-slot output.
fn parse_i16_array(s: &str, n: usize) -> std::result::Result<[i16; 22], ParseError> {
    let values: Vec<i64> =
        serde_json::from_str(s).map_err(|e| ParseError(format!("parse i16 array: {e}")))?;
    if values.len() != n {
        return Err(ParseError(format!(
            "i16 array length mismatch: got {}, expected {}",
            values.len(),
            n
        )));
    }
    let mut out = [0i16; 22];
    // `n` is a runtime value but the output size is fixed: reject oversized
    // input explicitly rather than panicking on an out-of-bounds write.
    if values.len() > out.len() {
        return Err(ParseError(format!(
            "i16 array length {} exceeds output capacity {}",
            values.len(),
            out.len()
        )));
    }
    for (slot, value) in out.iter_mut().zip(&values) {
        // The clamp guarantees the narrowing cast below is lossless.
        let saturated = (*value).clamp(i64::from(i16::MIN), i64::from(i16::MAX));
        *slot = saturated as i16;
    }
    Ok(out)
}
/// Parses a JSON array of numbers into a fixed 8-slot `f32` array.
///
/// Each element is read as `f64` and narrowed to `f32`. When `n` is smaller
/// than the output size, the remaining slots stay `0.0`.
///
/// # Errors
/// Returns [`ParseError`] when `s` is not a JSON array of numbers, when the
/// array does not contain exactly `n` elements, or when the array is longer
/// than the 8-slot output.
fn parse_f32_array(s: &str, n: usize) -> std::result::Result<[f32; 8], ParseError> {
    let values: Vec<f64> =
        serde_json::from_str(s).map_err(|e| ParseError(format!("parse f32 array: {e}")))?;
    if values.len() != n {
        return Err(ParseError(format!(
            "f32 array length mismatch: got {}, expected {}",
            values.len(),
            n
        )));
    }
    let mut out = [0.0f32; 8];
    // `n` is a runtime value but the output size is fixed: reject oversized
    // input explicitly rather than panicking on an out-of-bounds write.
    if values.len() > out.len() {
        return Err(ParseError(format!(
            "f32 array length {} exceeds output capacity {}",
            values.len(),
            out.len()
        )));
    }
    for (slot, value) in out.iter_mut().zip(&values) {
        *slot = *value as f32;
    }
    Ok(out)
}