//! rok-repl 0.3.5
//!
//! Interactive REPL (tinker) engine for the rok ecosystem.
use rustyline::error::ReadlineError;
use rustyline::DefaultEditor;
use sqlx::PgPool;

pub struct TinkerSession {
    pool: PgPool,
    editor: DefaultEditor,
}

impl TinkerSession {
    pub async fn new() -> anyhow::Result<Self> {
        dotenvy::dotenv().ok();

        let url =
            std::env::var("DATABASE_URL").map_err(|_| anyhow::anyhow!("DATABASE_URL not set"))?;

        let pool = PgPool::connect(&url)
            .await
            .map_err(|e| anyhow::anyhow!("Cannot connect to database: {e}"))?;

        let editor = DefaultEditor::new()?;

        Ok(Self { pool, editor })
    }

    pub async fn run(&mut self) -> anyhow::Result<()> {
        println!(
            "rok tinker {}  — type 'help' for commands, Ctrl-D or 'quit' to exit.",
            env!("CARGO_PKG_VERSION")
        );
        println!();

        loop {
            match self.editor.readline("rok> ") {
                Ok(line) => {
                    let line = line.trim().to_string();
                    if line.is_empty() {
                        continue;
                    }
                    let _ = self.editor.add_history_entry(&line);
                    if let Err(e) = self.dispatch(&line).await {
                        eprintln!("  error: {e}");
                    }
                }
                Err(ReadlineError::Eof | ReadlineError::Interrupted) => {
                    println!("Bye!");
                    break;
                }
                Err(e) => return Err(e.into()),
            }
        }
        Ok(())
    }

    async fn dispatch(&self, input: &str) -> anyhow::Result<()> {
        let parts: Vec<&str> = input.splitn(3, ' ').collect();
        let cmd = parts[0].to_lowercase();

        match cmd.as_str() {
            "help" | "?" => print_help(),
            "quit" | "exit" => std::process::exit(0),
            "tables" => self.cmd_tables().await?,
            "schema" => {
                let table = parts
                    .get(1)
                    .ok_or_else(|| anyhow::anyhow!("usage: schema <table>"))?;
                self.cmd_schema(table).await?;
            }
            "count" => {
                let table = parts
                    .get(1)
                    .ok_or_else(|| anyhow::anyhow!("usage: count <table>"))?;
                self.cmd_count(table).await?;
            }
            "all" => {
                let table = parts
                    .get(1)
                    .ok_or_else(|| anyhow::anyhow!("usage: all <table> [limit]"))?;
                let limit: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(10);
                self.cmd_rows(table, None, "id", limit, false).await?;
            }
            "last" => {
                let table = parts
                    .get(1)
                    .ok_or_else(|| anyhow::anyhow!("usage: last <table> [limit]"))?;
                let limit: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(5);
                self.cmd_rows(table, None, "id", limit, true).await?;
            }
            "find" => {
                let table = parts
                    .get(1)
                    .ok_or_else(|| anyhow::anyhow!("usage: find <table> <id>"))?;
                let id = parts
                    .get(2)
                    .ok_or_else(|| anyhow::anyhow!("usage: find <table> <id>"))?;
                self.cmd_rows(table, Some(id), "id", 1, false).await?;
            }
            "sql" => {
                let query = input.trim_start_matches("sql").trim();
                self.cmd_sql(query).await?;
            }
            _ => {
                // Accept raw SQL lines that start with a known keyword
                let upper = input.to_uppercase();
                if [
                    "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "EXPLAIN",
                ]
                .iter()
                .any(|kw| upper.starts_with(kw))
                {
                    self.cmd_sql(input).await?;
                } else {
                    println!("  Unknown command '{cmd}'. Type 'help' for a list of commands.");
                }
            }
        }
        Ok(())
    }

    // ── command implementations ─────────────────────────────────────────────

    async fn cmd_tables(&self) -> anyhow::Result<()> {
        let tables: Vec<String> = sqlx::query_scalar(
            "SELECT tablename FROM pg_catalog.pg_tables \
             WHERE schemaname = 'public' ORDER BY tablename",
        )
        .fetch_all(&self.pool)
        .await?;

        if tables.is_empty() {
            println!("  (no tables)");
        } else {
            for t in &tables {
                println!("  {t}");
            }
            println!("  ({} table(s))", tables.len());
        }
        Ok(())
    }

    async fn cmd_schema(&self, table: &str) -> anyhow::Result<()> {
        let table = validate_ident(table)?;

        let cols: Vec<(String, String, String)> = sqlx::query_as(
            "SELECT column_name, data_type, is_nullable \
             FROM information_schema.columns \
             WHERE table_schema = 'public' AND table_name = $1 \
             ORDER BY ordinal_position",
        )
        .bind(table)
        .fetch_all(&self.pool)
        .await?;

        if cols.is_empty() {
            println!("  Table '{table}' not found or has no columns.");
        } else {
            println!("  {:<30} {:<20} nullable", "column", "type");
            println!("  {}", "-".repeat(65));
            for (col, typ, nullable) in &cols {
                println!("  {:<30} {:<20} {}", col, typ, nullable);
            }
        }
        Ok(())
    }

    async fn cmd_count(&self, table: &str) -> anyhow::Result<()> {
        let table = validate_ident(table)?;
        let n: i64 = sqlx::query_scalar(&format!("SELECT COUNT(*) FROM \"{table}\""))
            .fetch_one(&self.pool)
            .await?;
        println!("  {n}");
        Ok(())
    }

    async fn cmd_rows(
        &self,
        table: &str,
        where_id: Option<&str>,
        id_col: &str,
        limit: i64,
        desc: bool,
    ) -> anyhow::Result<()> {
        let table = validate_ident(table)?;
        let id_col = validate_ident(id_col)?;
        let order = if desc { "DESC" } else { "ASC" };

        let rows: Vec<serde_json::Value> = if let Some(id) = where_id {
            sqlx::query_scalar(&format!(
                "SELECT row_to_json(t) FROM \"{table}\" t WHERE \"{id_col}\" = $1"
            ))
            .bind(id)
            .fetch_all(&self.pool)
            .await?
        } else {
            sqlx::query_scalar(&format!(
                "SELECT row_to_json(t) FROM (SELECT * FROM \"{table}\" ORDER BY \"{id_col}\" {order} LIMIT $1) t"
            ))
            .bind(limit)
            .fetch_all(&self.pool)
            .await?
        };

        if rows.is_empty() {
            println!("  (no rows)");
        } else {
            for row in &rows {
                println!("  {}", serde_json::to_string_pretty(row)?);
            }
            println!("  ({} row(s))", rows.len());
        }
        Ok(())
    }

    async fn cmd_sql(&self, query: &str) -> anyhow::Result<()> {
        if query.is_empty() {
            anyhow::bail!("empty query");
        }

        let upper = query.trim().to_uppercase();

        if upper.starts_with("SELECT") || upper.starts_with("EXPLAIN") {
            let rows: Vec<serde_json::Value> =
                sqlx::query_scalar(&format!("SELECT row_to_json(t) FROM ({query}) t"))
                    .fetch_all(&self.pool)
                    .await?;

            if rows.is_empty() {
                println!("  (no rows)");
            } else {
                for row in &rows {
                    println!("  {}", serde_json::to_string_pretty(row)?);
                }
                println!("  ({} row(s))", rows.len());
            }
        } else {
            let result = sqlx::query(query).execute(&self.pool).await?;
            println!("  {} row(s) affected", result.rows_affected());
        }
        Ok(())
    }
}

// ── helpers ─────────────────────────────────────────────────────────────────

/// Accepts only non-empty ASCII `[a-zA-Z0-9_]` identifiers, to prevent SQL injection.
///
/// Such a name contains no `"`, so callers can splice it between double quotes
/// without further escaping.
///
/// # Errors
///
/// Fails if `s` is empty or contains any other character.
fn validate_ident(s: &str) -> anyhow::Result<&str> {
    // `char::is_alphanumeric` would also admit non-ASCII letters and digits
    // (e.g. `é`, `²`), which contradicts the ASCII-only contract above.
    let valid = !s.is_empty() && s.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_');
    if valid {
        Ok(s)
    } else {
        anyhow::bail!(
            "invalid identifier '{s}' — only ASCII letters, digits, and underscores allowed"
        )
    }
}

/// Writes the REPL command reference to standard output.
fn print_help() {
    // Kept as a single literal so the column alignment is visible in the source.
    const HELP: &str = r#"
  rok tinker commands
  ───────────────────
  tables                  list all tables in the public schema
  schema  <table>         describe columns of a table
  count   <table>         count rows in a table
  all     <table> [n]     show first n rows (default 10)
  last    <table> [n]     show last  n rows (default  5)
  find    <table> <id>    find a row by primary key
  sql     <query>         execute a raw SQL statement
  <SELECT …>              raw SELECT is auto-detected
  help | ?                show this help
  quit | exit | Ctrl-D    exit the REPL
"#;
    println!("{HELP}");
}