use rustyline::error::ReadlineError;
use rustyline::DefaultEditor;
use sqlx::PgPool;
/// Interactive "rok tinker" REPL over a Postgres database.
///
/// Holds the connection pool that every command queries and the line editor
/// that reads commands and keeps their history.
pub struct TinkerSession {
/// Connection pool opened from `DATABASE_URL` by `TinkerSession::new`.
pool: PgPool,
/// Line editor used to read `rok> ` prompts and record history entries.
editor: DefaultEditor,
}
impl TinkerSession {
pub async fn new() -> anyhow::Result<Self> {
dotenvy::dotenv().ok();
let url =
std::env::var("DATABASE_URL").map_err(|_| anyhow::anyhow!("DATABASE_URL not set"))?;
let pool = PgPool::connect(&url)
.await
.map_err(|e| anyhow::anyhow!("Cannot connect to database: {e}"))?;
let editor = DefaultEditor::new()?;
Ok(Self { pool, editor })
}
pub async fn run(&mut self) -> anyhow::Result<()> {
println!(
"rok tinker {} — type 'help' for commands, Ctrl-D or 'quit' to exit.",
env!("CARGO_PKG_VERSION")
);
println!();
loop {
match self.editor.readline("rok> ") {
Ok(line) => {
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
let _ = self.editor.add_history_entry(&line);
if let Err(e) = self.dispatch(&line).await {
eprintln!(" error: {e}");
}
}
Err(ReadlineError::Eof | ReadlineError::Interrupted) => {
println!("Bye!");
break;
}
Err(e) => return Err(e.into()),
}
}
Ok(())
}
async fn dispatch(&self, input: &str) -> anyhow::Result<()> {
let parts: Vec<&str> = input.splitn(3, ' ').collect();
let cmd = parts[0].to_lowercase();
match cmd.as_str() {
"help" | "?" => print_help(),
"quit" | "exit" => std::process::exit(0),
"tables" => self.cmd_tables().await?,
"schema" => {
let table = parts
.get(1)
.ok_or_else(|| anyhow::anyhow!("usage: schema <table>"))?;
self.cmd_schema(table).await?;
}
"count" => {
let table = parts
.get(1)
.ok_or_else(|| anyhow::anyhow!("usage: count <table>"))?;
self.cmd_count(table).await?;
}
"all" => {
let table = parts
.get(1)
.ok_or_else(|| anyhow::anyhow!("usage: all <table> [limit]"))?;
let limit: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(10);
self.cmd_rows(table, None, "id", limit, false).await?;
}
"last" => {
let table = parts
.get(1)
.ok_or_else(|| anyhow::anyhow!("usage: last <table> [limit]"))?;
let limit: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(5);
self.cmd_rows(table, None, "id", limit, true).await?;
}
"find" => {
let table = parts
.get(1)
.ok_or_else(|| anyhow::anyhow!("usage: find <table> <id>"))?;
let id = parts
.get(2)
.ok_or_else(|| anyhow::anyhow!("usage: find <table> <id>"))?;
self.cmd_rows(table, Some(id), "id", 1, false).await?;
}
"sql" => {
let query = input.trim_start_matches("sql").trim();
self.cmd_sql(query).await?;
}
_ => {
let upper = input.to_uppercase();
if [
"SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "EXPLAIN",
]
.iter()
.any(|kw| upper.starts_with(kw))
{
self.cmd_sql(input).await?;
} else {
println!(" Unknown command '{cmd}'. Type 'help' for a list of commands.");
}
}
}
Ok(())
}
async fn cmd_tables(&self) -> anyhow::Result<()> {
let tables: Vec<String> = sqlx::query_scalar(
"SELECT tablename FROM pg_catalog.pg_tables \
WHERE schemaname = 'public' ORDER BY tablename",
)
.fetch_all(&self.pool)
.await?;
if tables.is_empty() {
println!(" (no tables)");
} else {
for t in &tables {
println!(" {t}");
}
println!(" ({} table(s))", tables.len());
}
Ok(())
}
async fn cmd_schema(&self, table: &str) -> anyhow::Result<()> {
let table = validate_ident(table)?;
let cols: Vec<(String, String, String)> = sqlx::query_as(
"SELECT column_name, data_type, is_nullable \
FROM information_schema.columns \
WHERE table_schema = 'public' AND table_name = $1 \
ORDER BY ordinal_position",
)
.bind(table)
.fetch_all(&self.pool)
.await?;
if cols.is_empty() {
println!(" Table '{table}' not found or has no columns.");
} else {
println!(" {:<30} {:<20} nullable", "column", "type");
println!(" {}", "-".repeat(65));
for (col, typ, nullable) in &cols {
println!(" {:<30} {:<20} {}", col, typ, nullable);
}
}
Ok(())
}
async fn cmd_count(&self, table: &str) -> anyhow::Result<()> {
let table = validate_ident(table)?;
let n: i64 = sqlx::query_scalar(&format!("SELECT COUNT(*) FROM \"{table}\""))
.fetch_one(&self.pool)
.await?;
println!(" {n}");
Ok(())
}
async fn cmd_rows(
&self,
table: &str,
where_id: Option<&str>,
id_col: &str,
limit: i64,
desc: bool,
) -> anyhow::Result<()> {
let table = validate_ident(table)?;
let id_col = validate_ident(id_col)?;
let order = if desc { "DESC" } else { "ASC" };
let rows: Vec<serde_json::Value> = if let Some(id) = where_id {
sqlx::query_scalar(&format!(
"SELECT row_to_json(t) FROM \"{table}\" t WHERE \"{id_col}\" = $1"
))
.bind(id)
.fetch_all(&self.pool)
.await?
} else {
sqlx::query_scalar(&format!(
"SELECT row_to_json(t) FROM (SELECT * FROM \"{table}\" ORDER BY \"{id_col}\" {order} LIMIT $1) t"
))
.bind(limit)
.fetch_all(&self.pool)
.await?
};
if rows.is_empty() {
println!(" (no rows)");
} else {
for row in &rows {
println!(" {}", serde_json::to_string_pretty(row)?);
}
println!(" ({} row(s))", rows.len());
}
Ok(())
}
async fn cmd_sql(&self, query: &str) -> anyhow::Result<()> {
if query.is_empty() {
anyhow::bail!("empty query");
}
let upper = query.trim().to_uppercase();
if upper.starts_with("SELECT") || upper.starts_with("EXPLAIN") {
let rows: Vec<serde_json::Value> =
sqlx::query_scalar(&format!("SELECT row_to_json(t) FROM ({query}) t"))
.fetch_all(&self.pool)
.await?;
if rows.is_empty() {
println!(" (no rows)");
} else {
for row in &rows {
println!(" {}", serde_json::to_string_pretty(row)?);
}
println!(" ({} row(s))", rows.len());
}
} else {
let result = sqlx::query(query).execute(&self.pool).await?;
println!(" {} row(s) affected", result.rows_affected());
}
Ok(())
}
}
/// Checks that `s` can be safely spliced into SQL as a double-quoted identifier.
///
/// Accepts a non-empty string made only of characters for which
/// [`char::is_alphanumeric`] holds, plus `_`. The input is returned unchanged
/// so callers can shadow their argument with the checked value.
///
/// # Errors
///
/// Returns an error quoting `s` when it is empty or contains any other
/// character.
fn validate_ident(s: &str) -> anyhow::Result<&str> {
    let allowed = |ch: char| ch == '_' || ch.is_alphanumeric();
    if s.is_empty() || !s.chars().all(allowed) {
        anyhow::bail!("invalid identifier '{s}' — only letters, digits, and underscores allowed");
    }
    Ok(s)
}
/// Prints the REPL command reference to stdout.
///
/// The descriptions follow what the dispatcher actually does: `all`, `last`
/// and `find` work on the table's `id` column. Statements starting with one
/// of the listed SQL keywords are run without the `sql` prefix.
fn print_help() {
    println!(
        r#"
  rok tinker commands
  ───────────────────
  tables                 list all tables in the public schema
  schema <table>         describe the columns of a table
  count <table>          count rows in a table
  all <table> [n]        show the first n rows, ordered by id (default 10)
  last <table> [n]       show the last n rows, ordered by id (default 5)
  find <table> <id>      find the row whose id column equals <id>
  sql <query>            execute a raw SQL statement
  <SQL statement>        SELECT, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER and
                         EXPLAIN are run directly, without the `sql` prefix
  help | ?               show this help
  quit | exit | Ctrl-D   exit the REPL
"#
    );
}