use crate::{Database, DbResultExt};
use roboticus_core::{Result, SurvivalTier};
use rusqlite::OptionalExtension;
/// In-memory form of the single `treasury_state` row, as produced by
/// [`get_treasury_state`].
#[derive(Debug, Clone)]
pub struct TreasuryStateRow {
/// Value of the `usdc_balance` column.
pub usdc_balance: f64,
/// Value of the `native_balance` column.
pub native_balance: f64,
/// Tier decoded from the textual `survival_tier` column; text that matches
/// no known tier name is decoded as `SurvivalTier::Dead`.
pub survival_tier: SurvivalTier,
/// Value of the `updated_at` column, kept as the text SQLite returned.
/// [`upsert_treasury_state`] fills it with `datetime('now')`.
pub updated_at: String,
}
/// Stores the current treasury snapshot in the `treasury_state` table.
///
/// The table is addressed through the fixed id `'singleton'`: the row is
/// inserted when it does not exist yet, and otherwise its balances, tier and
/// `updated_at` are overwritten with the new values. `updated_at` is computed
/// by SQLite via `datetime('now')` when the statement runs.
///
/// `tier` is persisted as text using its `Debug` rendering.
///
/// # Errors
///
/// Returns the error produced by `db_err()` when the statement fails.
pub fn upsert_treasury_state(
    db: &Database,
    usdc_balance: f64,
    native_balance: f64,
    tier: SurvivalTier,
) -> Result<()> {
    // Continuation lines are deliberately flush-left so the statement text
    // sent to SQLite carries no extra leading whitespace.
    const UPSERT_SQL: &str = "INSERT INTO treasury_state (id, usdc_balance, native_balance, survival_tier, updated_at)
VALUES ('singleton', ?1, ?2, ?3, datetime('now'))
ON CONFLICT(id) DO UPDATE SET
usdc_balance = excluded.usdc_balance,
native_balance = excluded.native_balance,
survival_tier = excluded.survival_tier,
updated_at = excluded.updated_at";

    let conn = db.conn();
    // The stored label is whatever `Debug` prints for the tier.
    let tier_label = format!("{tier:?}");
    conn.execute(
        UPSERT_SQL,
        rusqlite::params![usdc_balance, native_balance, tier_label],
    )
    .db_err()?;
    Ok(())
}
/// Loads the `'singleton'` row of `treasury_state`, if one has been written.
///
/// Returns `Ok(None)` when the query yields no row. The stored tier text is
/// decoded with `parse_survival_tier`, so text that matches no known tier
/// name comes back as `SurvivalTier::Dead` instead of an error.
///
/// # Errors
///
/// Returns the error produced by `db_err()` when the query fails or a column
/// cannot be converted to the expected Rust type.
pub fn get_treasury_state(db: &Database) -> Result<Option<TreasuryStateRow>> {
    const SELECT_SQL: &str = "SELECT usdc_balance, native_balance, survival_tier, updated_at
FROM treasury_state WHERE id = 'singleton'";

    /// Converts one result row (columns in `SELECT_SQL` order) into a
    /// `TreasuryStateRow`.
    fn row_to_state(row: &rusqlite::Row) -> rusqlite::Result<TreasuryStateRow> {
        // The tier column is read and decoded first, then the remaining
        // columns in positional order.
        let tier_text: String = row.get(2)?;
        let survival_tier = parse_survival_tier(&tier_text);
        let usdc_balance = row.get(0)?;
        let native_balance = row.get(1)?;
        let updated_at = row.get(3)?;
        Ok(TreasuryStateRow {
            usdc_balance,
            native_balance,
            survival_tier,
            updated_at,
        })
    }

    let conn = db.conn();
    let maybe_state = conn
        .query_row(SELECT_SQL, [], row_to_state)
        .optional()
        .db_err()?;
    Ok(maybe_state)
}
/// Decodes the text stored in the `survival_tier` column into a
/// [`SurvivalTier`].
///
/// Matching is exact and case-sensitive. Any text that is not one of the
/// names listed below decodes to `SurvivalTier::Dead`; no error is reported.
fn parse_survival_tier(s: &str) -> SurvivalTier {
    match s {
        "High" => SurvivalTier::High,
        "Normal" => SurvivalTier::Normal,
        "LowCompute" => SurvivalTier::LowCompute,
        "Critical" => SurvivalTier::Critical,
        "Dead" => SurvivalTier::Dead,
        // Unrecognized text also maps to `Dead`.
        _ => SurvivalTier::Dead,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh `:memory:` database and runs the schema initializer on it.
    fn test_db() -> Database {
        let db = Database::new(":memory:").unwrap();
        crate::schema::initialize_db(&db).unwrap();
        db
    }

    /// True when `actual` is within `f64::EPSILON` of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn upsert_and_read_treasury_state() {
        let db = test_db();

        // Nothing has been written yet, so the read must come back empty.
        let before = get_treasury_state(&db).unwrap();
        assert!(before.is_none());

        upsert_treasury_state(&db, 42.5, 0.01, SurvivalTier::High).unwrap();

        let stored = get_treasury_state(&db).unwrap().unwrap();
        assert!(approx_eq(stored.usdc_balance, 42.5));
        assert!(approx_eq(stored.native_balance, 0.01));
        assert!(matches!(stored.survival_tier, SurvivalTier::High));
    }

    #[test]
    fn upsert_overwrites_previous() {
        let db = test_db();

        // The second write must replace the first rather than add a row.
        upsert_treasury_state(&db, 10.0, 0.5, SurvivalTier::Normal).unwrap();
        upsert_treasury_state(&db, 2.0, 0.01, SurvivalTier::Critical).unwrap();

        let stored = get_treasury_state(&db).unwrap().unwrap();
        assert!(approx_eq(stored.usdc_balance, 2.0));
        assert!(matches!(stored.survival_tier, SurvivalTier::Critical));
    }
}