tally-cli 1.2.0

A persistent, inter-process counter for all your shell scripts
Documentation
use sqlite::{ConnectionThreadSafe, State};

#[derive(Debug)]
pub struct Counter {
    pub name: String,
    pub count: i64,
    pub step: i64,
    pub template: String,
}

impl Counter {
    pub fn new(name: &str) -> Counter {
        Counter {
            name: name.to_string(),
            count: 0,
            step: 1,
            template: String::from("{}"),
        }
    }

    pub fn set_default(&self, conn: &ConnectionThreadSafe) -> sqlite::Result<()> {
        conn.execute("DELETE FROM default_counter;")?;
        let mut stmt = conn.prepare(
            "INSERT INTO default_counter (name, timestamp) VALUES (?, CURRENT_TIMESTAMP);",
        )?;
        stmt.bind((1, self.name.as_str()))?;
        stmt.next()?;
        Ok(())
    }

    pub fn get_default(conn: &ConnectionThreadSafe) -> sqlite::Result<Option<String>> {
        let mut stmt =
            conn.prepare("SELECT name FROM default_counter ORDER BY timestamp DESC LIMIT 1;")?;
        if let State::Row = stmt.next()? {
            Ok(Some(stmt.read::<String, usize>(0)?))
        } else {
            Ok(None)
        }
    }

    pub fn insert(&self, conn: &ConnectionThreadSafe) -> sqlite::Result<()> {
        let mut stmt =
            conn.prepare("INSERT INTO counters (name, count, step, template) VALUES (?, ?, ?, ?)")?;
        stmt.bind((1, self.name.as_str()))?;
        stmt.bind((2, self.count))?;
        stmt.bind((3, self.step))?;
        stmt.bind((4, self.template.as_str()))?;
        stmt.next()?;

        Ok(())
    }

    pub fn delete(conn: &ConnectionThreadSafe, name: &str) -> sqlite::Result<()> {
        let mut stmt = conn.prepare("DELETE FROM counters WHERE name = ?")?;
        stmt.bind((1, name))?;
        stmt.next()?;
        Ok(())
    }

    pub fn get(conn: &ConnectionThreadSafe, name: &str) -> sqlite::Result<Option<Counter>> {
        let mut stmt =
            conn.prepare("SELECT name, count, step, template FROM counters WHERE name = ?")?;
        stmt.bind((1, name))?;

        if let State::Row = stmt.next()? {
            Ok(Some(Counter {
                name: stmt.read::<String, _>(0)?.to_string(),
                count: stmt.read::<i64, _>(1)?,
                step: stmt.read::<i64, _>(2)?,
                template: stmt.read::<String, _>(3)?.to_string(),
            }))
        } else {
            Ok(None)
        }
    }

    pub fn get_all(conn: &ConnectionThreadSafe) -> sqlite::Result<Vec<Counter>> {
        let mut stmt = conn.prepare("SELECT name, count, step, template FROM counters")?;
        let mut counters = Vec::new();

        while let State::Row = stmt.next()? {
            counters.push(Counter {
                name: stmt.read::<String, usize>(0)?,
                count: stmt.read::<i64, usize>(1)?,
                step: stmt.read::<i64, usize>(2)?,
                template: stmt.read::<String, usize>(3)?,
            });
        }

        Ok(counters)
    }

    pub fn update(&self, conn: &ConnectionThreadSafe) -> sqlite::Result<()> {
        let mut stmt =
            conn.prepare("UPDATE counters SET count = ?, step = ?, template = ? WHERE name = ?")?;
        stmt.bind((1, self.count))?;
        stmt.bind((2, self.step))?;
        stmt.bind((3, self.template.as_str()))?;
        stmt.bind((4, self.name.as_str()))?;
        stmt.next()?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::Connection;
    use tempfile::TempDir;

    fn fresh_db() -> (TempDir, Connection) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        let conn = Connection::new(&path.to_string_lossy()).unwrap();
        (dir, conn)
    }

    #[test]
    fn new_has_sensible_defaults() {
        let c = Counter::new("foo");
        assert_eq!(c.name, "foo");
        assert_eq!(c.count, 0);
        assert_eq!(c.step, 1);
        assert_eq!(c.template, "{}");
    }

    #[test]
    fn insert_and_get_round_trip() {
        let (_dir, conn) = fresh_db();
        let c = Counter {
            name: "foo".into(),
            count: 7,
            step: 2,
            template: "x-{}".into(),
        };
        c.insert(conn.get()).unwrap();

        let loaded = Counter::get(conn.get(), "foo").unwrap().unwrap();
        assert_eq!(loaded.name, "foo");
        assert_eq!(loaded.count, 7);
        assert_eq!(loaded.step, 2);
        assert_eq!(loaded.template, "x-{}");
    }

    #[test]
    fn get_missing_returns_none() {
        let (_dir, conn) = fresh_db();
        assert!(Counter::get(conn.get(), "nope").unwrap().is_none());
    }

    #[test]
    fn update_persists_changes() {
        let (_dir, conn) = fresh_db();
        let mut c = Counter::new("foo");
        c.insert(conn.get()).unwrap();
        c.count = 42;
        c.step = 5;
        c.update(conn.get()).unwrap();

        let loaded = Counter::get(conn.get(), "foo").unwrap().unwrap();
        assert_eq!(loaded.count, 42);
        assert_eq!(loaded.step, 5);
    }

    #[test]
    fn delete_removes_row() {
        let (_dir, conn) = fresh_db();
        Counter::new("foo").insert(conn.get()).unwrap();
        Counter::delete(conn.get(), "foo").unwrap();
        assert!(Counter::get(conn.get(), "foo").unwrap().is_none());
    }

    #[test]
    fn get_all_returns_every_counter() {
        let (_dir, conn) = fresh_db();
        Counter::new("a").insert(conn.get()).unwrap();
        Counter::new("b").insert(conn.get()).unwrap();
        let mut names: Vec<String> = Counter::get_all(conn.get())
            .unwrap()
            .into_iter()
            .map(|c| c.name)
            .collect();
        names.sort();
        assert_eq!(names, vec!["a", "b", "tally"]);
    }

    #[test]
    fn set_default_keeps_default_counter_single_row() {
        let (_dir, conn) = fresh_db();
        Counter::new("a").insert(conn.get()).unwrap();
        Counter::new("b").insert(conn.get()).unwrap();

        Counter::new("a").set_default(conn.get()).unwrap();
        Counter::new("b").set_default(conn.get()).unwrap();
        Counter::new("a").set_default(conn.get()).unwrap();

        assert_eq!(
            Counter::get_default(conn.get()).unwrap().as_deref(),
            Some("a")
        );

        let mut stmt = conn
            .get()
            .prepare("SELECT COUNT(*) FROM default_counter;")
            .unwrap();
        stmt.next().unwrap();
        let count: i64 = stmt.read(0).unwrap();
        assert_eq!(count, 1);
    }
}