Skip to main content

agent_first_psql/
cli.rs

1use std::io::Write;
2
3use crate::types::{QueryOptions, SessionConfig};
4use agent_first_data::{cli_parse_log_filters, cli_parse_output, OutputFormat};
5use clap::{CommandFactory, Parser, ValueEnum};
6use serde_json::{json, Value};
7use std::collections::BTreeMap;
8
/// Result of argument parsing: which runtime the caller should start.
pub enum Mode {
    /// Execute one query described by argv (canonical CLI or `--mode psql` translation).
    Cli(CliRequest),
    /// Start the `--mode pipe` runtime with the parsed session/output/log settings.
    Pipe(PipeInit),
}
13
/// Initial settings for `--mode pipe`, built by `parse_args`.
///
/// Query flags such as `--sql` or `--param` are not carried here; in pipe mode they are only
/// recorded in `startup_args`.
pub struct PipeInit {
    /// Rendering format chosen with `--output`.
    pub output: OutputFormat,
    /// Connection settings taken from the command line.
    pub session: SessionConfig,
    /// `--log` categories as returned by `cli_parse_log_filters`.
    pub log: Vec<String>,
    /// The process argv exactly as received.
    pub startup_argv: Vec<String>,
    /// JSON snapshot of the parsed command-line arguments.
    pub startup_args: Value,
    /// JSON snapshot of a fixed set of `AFPSQL_*` / `PG*` environment variables (`null` when unset).
    pub startup_env: Value,
    /// True when a `--log` value contains `startup`, `all`, or `*` (case-insensitive).
    pub startup_requested: bool,
}
23
/// A fully parsed single-query request, produced by the canonical CLI and by `--mode psql`.
pub struct CliRequest {
    /// SQL text, given inline or read from a file.
    pub sql: String,
    /// Positional bind values; element 0 is `$1`, element 1 is `$2`, and so on.
    pub params: Vec<Value>,
    /// Streaming, timeout, read-only and inline-size options (all defaults in `--mode psql`).
    pub options: QueryOptions,
    /// Connection settings taken from the command line.
    pub session: SessionConfig,
    /// Rendering format chosen with `--output`.
    pub output: OutputFormat,
    /// `--log` categories as returned by `cli_parse_log_filters`.
    pub log: Vec<String>,
    /// The process argv exactly as received.
    pub startup_argv: Vec<String>,
    /// JSON snapshot of the parsed command-line arguments.
    pub startup_args: Value,
    /// JSON snapshot of a fixed set of `AFPSQL_*` / `PG*` environment variables (`null` when unset).
    pub startup_env: Value,
    /// True when a `--log` value contains `startup`, `all`, or `*` (case-insensitive).
    pub startup_requested: bool,
    /// `--dry-run` flag; always false in `--mode psql`.
    pub dry_run: bool,
}
37
// Values accepted by `--mode`. Plain `//` comments are used on purpose: clap's `ValueEnum`
// derive turns `///` doc comments on variants into possible-value help, which would change
// the rendered `--help` output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum RuntimeMode {
    // Canonical agent-first CLI (the clap default).
    Cli,
    // Pipe runtime; `parse_args` returns `Mode::Pipe` for it.
    Pipe,
    // psql argument translation. `parse_args` normally diverts `--mode psql` to
    // `parse_psql_mode` before clap parsing runs.
    #[value(name = "psql")]
    Psql,
}
45
46#[doc = r#"Agent-First PostgreSQL client.
47
48### Interface Policy
49
50- default mode is canonical agent-first CLI
51- `--mode psql` is argument translation only; runtime output stays JSONL
52- stdout carries protocol events; stderr is not a protocol channel
53
54### Query Sources and Parameters
55
56- use `--sql` for inline SQL or `--sql-file` for a file
57- use repeatable `--param N=value` for positional binds
58- placeholder count is validated from prepared-statement metadata, not by SQL text scanning
59
60### Connection Sources
61
62- `--dsn-secret` for a PostgreSQL URI
63- `--conninfo-secret` for libpq-style conninfo
64- or discrete `--host`, `--port`, `--user`, `--dbname`, `--password-secret`
65- agent-first environment fallbacks: `AFPSQL_*`
66- PostgreSQL environment fallbacks: `PGHOST`, `PGPORT`, `PGUSER`, `PGDATABASE`
67
68### Result Shaping
69
70- default mode buffers a bounded inline result
71- use `--stream-rows` for large result sets, with `--batch-rows` and `--batch-bytes` to tune chunk size
72- `--output json|yaml|plain` changes rendering only, not the runtime schema
73
74### Examples
75
76```text
77afpsql --sql "select now() as now_rfc3339"
78afpsql --sql-file ./query.sql
79afpsql --sql "select * from users where id = $1" --param 1=123
80afpsql --dsn-secret "postgresql://app:secret@127.0.0.1:5432/appdb" --sql "select 1"
81afpsql --mode psql -h 127.0.0.1 -p 5432 -U app -d appdb -c "select 1"
82afpsql --sql "select * from big_table" --stream-rows --batch-rows 1000
83afpsql --mode pipe
84```
85
86### Exit Codes
87
88- `0`: query completed successfully
89- `1`: SQL error or runtime error
90- `2`: invalid CLI arguments
91"#]
92#[derive(Parser)]
93#[command(name = "afpsql", version, verbatim_doc_comment)]
94pub struct AfdCli {
95    /// Inline SQL string to execute.
96    #[arg(long, help_heading = "Query")]
97    sql: Option<String>,
98    /// Read SQL from a file.
99    #[arg(long = "sql-file", help_heading = "Query")]
100    sql_file: Option<String>,
101    /// Positional bind parameter in `N=value` form. Repeat for additional parameters.
102    #[arg(long = "param", help_heading = "Query")]
103    param: Vec<String>,
104    /// Stream large result sets as `result_rows` batches instead of a single inline result.
105    #[arg(long = "stream-rows", help_heading = "Query")]
106    stream_rows: bool,
107    /// Maximum rows per streamed batch.
108    #[arg(long = "batch-rows", help_heading = "Query")]
109    batch_rows: Option<usize>,
110    /// Soft byte target per streamed batch.
111    #[arg(long = "batch-bytes", help_heading = "Query")]
112    batch_bytes: Option<usize>,
113    /// Per-query statement timeout in milliseconds.
114    #[arg(long = "statement-timeout-ms", help_heading = "Query")]
115    statement_timeout_ms: Option<u64>,
116    /// Per-query lock timeout in milliseconds.
117    #[arg(long = "lock-timeout-ms", help_heading = "Query")]
118    lock_timeout_ms: Option<u64>,
119    /// Maximum inline rows before returning `result_too_large`.
120    #[arg(long = "inline-max-rows", help_heading = "Query")]
121    inline_max_rows: Option<usize>,
122    /// Maximum inline payload bytes before returning `result_too_large`.
123    #[arg(long = "inline-max-bytes", help_heading = "Query")]
124    inline_max_bytes: Option<usize>,
125    /// Force the query to run in a read-only transaction.
126    #[arg(long = "read-only", help_heading = "Query")]
127    read_only: bool,
128    /// Preview the query without executing it
129    #[arg(long, help_heading = "Query")]
130    dry_run: bool,
131
132    /// PostgreSQL DSN URI. Redacted in structured output.
133    #[arg(long = "dsn-secret", help_heading = "Connection")]
134    dsn_secret: Option<String>,
135    /// libpq-style conninfo string. Redacted in structured output.
136    #[arg(long = "conninfo-secret", help_heading = "Connection")]
137    conninfo_secret: Option<String>,
138    /// PostgreSQL host.
139    #[arg(long, help_heading = "Connection")]
140    host: Option<String>,
141    /// PostgreSQL port.
142    #[arg(long, help_heading = "Connection")]
143    port: Option<u16>,
144    /// PostgreSQL user name.
145    #[arg(long, help_heading = "Connection")]
146    user: Option<String>,
147    /// PostgreSQL database name.
148    #[arg(long, help_heading = "Connection")]
149    dbname: Option<String>,
150    /// PostgreSQL password. Redacted in structured output.
151    #[arg(long = "password-secret", help_heading = "Connection")]
152    password_secret: Option<String>,
153
154    /// Output format: json (default), yaml, or plain.
155    #[arg(long, default_value = "json", help_heading = "Runtime")]
156    output: String,
157    /// Diagnostic log categories.
158    #[arg(long = "log", value_delimiter = ',', help_heading = "Runtime")]
159    log: Vec<String>,
160    /// Runtime mode: canonical cli, pipe, or `psql` translation mode.
161    #[arg(long, value_enum, default_value_t = RuntimeMode::Cli, help_heading = "Runtime")]
162    mode: RuntimeMode,
163}
164
165pub fn parse_args() -> Result<Mode, String> {
166    let raw: Vec<String> = std::env::args().collect();
167    if is_psql_mode_requested(&raw) {
168        return parse_psql_mode(&raw);
169    }
170    let startup_requested = startup_requested_from_raw(&raw);
171
172    // --help: recursive plain-text help (all subcommands expanded)
173    if raw.iter().any(|a| a == "--help" || a == "-h") {
174        let _ = writeln!(
175            std::io::stdout(),
176            "{}",
177            agent_first_data::cli_render_help(&AfdCli::command(), &[])
178        );
179        std::process::exit(0);
180    }
181    // --help-markdown: Markdown for doc generation
182    if raw.iter().any(|a| a == "--help-markdown") {
183        let _ = writeln!(
184            std::io::stdout(),
185            "{}",
186            agent_first_data::cli_render_help_markdown(&AfdCli::command(), &[])
187        );
188        std::process::exit(0);
189    }
190
191    let cli = match AfdCli::try_parse_from(&raw) {
192        Ok(c) => c,
193        Err(e) => {
194            use clap::error::ErrorKind;
195            if matches!(e.kind(), ErrorKind::DisplayVersion) {
196                let _ = writeln!(std::io::stdout(), "{e}");
197                std::process::exit(0);
198            }
199            return Err(e.to_string());
200        }
201    };
202    let output = parse_output(&cli.output)?;
203    let log = parse_log_categories(&cli.log);
204    let session = SessionConfig {
205        dsn_secret: cli.dsn_secret,
206        conninfo_secret: cli.conninfo_secret,
207        host: cli.host,
208        port: cli.port,
209        user: cli.user,
210        dbname: cli.dbname,
211        password_secret: cli.password_secret,
212    };
213    let mode_name = match cli.mode {
214        RuntimeMode::Cli => "cli",
215        RuntimeMode::Pipe => "pipe",
216        RuntimeMode::Psql => "psql",
217    };
218    let startup_args = json!({
219        "mode": mode_name,
220        "sql": &cli.sql,
221        "sql_file": &cli.sql_file,
222        "param": &cli.param,
223        "stream_rows": cli.stream_rows,
224        "batch_rows": cli.batch_rows,
225        "batch_bytes": cli.batch_bytes,
226        "statement_timeout_ms": cli.statement_timeout_ms,
227        "lock_timeout_ms": cli.lock_timeout_ms,
228        "inline_max_rows": cli.inline_max_rows,
229        "inline_max_bytes": cli.inline_max_bytes,
230        "read_only": cli.read_only,
231        "dsn_secret": &session.dsn_secret,
232        "conninfo_secret": &session.conninfo_secret,
233        "host": &session.host,
234        "port": session.port,
235        "user": &session.user,
236        "dbname": &session.dbname,
237        "password_secret": &session.password_secret,
238        "output": output_name(output),
239        "log": &log,
240    });
241    let startup_env = startup_env_snapshot();
242
243    match cli.mode {
244        RuntimeMode::Pipe => {
245            return Ok(Mode::Pipe(PipeInit {
246                output,
247                session,
248                log: log.clone(),
249                startup_argv: raw,
250                startup_args,
251                startup_env,
252                startup_requested,
253            }));
254        }
255        RuntimeMode::Cli | RuntimeMode::Psql => {}
256    }
257
258    let sql = load_sql(cli.sql, cli.sql_file)?;
259    let params = parse_params(&cli.param)?;
260
261    let options = QueryOptions {
262        stream_rows: cli.stream_rows,
263        batch_rows: cli.batch_rows,
264        batch_bytes: cli.batch_bytes,
265        statement_timeout_ms: cli.statement_timeout_ms,
266        lock_timeout_ms: cli.lock_timeout_ms,
267        read_only: if cli.read_only { Some(true) } else { None },
268        inline_max_rows: cli.inline_max_rows,
269        inline_max_bytes: cli.inline_max_bytes,
270    };
271
272    Ok(Mode::Cli(CliRequest {
273        sql,
274        params,
275        options,
276        session,
277        output,
278        log,
279        startup_argv: raw,
280        startup_args,
281        startup_env,
282        startup_requested,
283        dry_run: cli.dry_run,
284    }))
285}
286
287fn parse_psql_mode(raw: &[String]) -> Result<Mode, String> {
288    let startup_requested = startup_requested_from_raw(raw);
289    let mut sql: Option<String> = None;
290    let mut sql_file: Option<String> = None;
291    let mut host: Option<String> = None;
292    let mut port: Option<u16> = None;
293    let mut user: Option<String> = None;
294    let mut dbname: Option<String> = None;
295    let mut dsn_secret: Option<String> = None;
296    let mut conninfo_secret: Option<String> = None;
297    let mut params_kv: Vec<String> = vec![];
298    let mut output = OutputFormat::Json;
299    let mut log_entries: Vec<String> = vec![];
300
301    let mut i = 1usize;
302    while i < raw.len() {
303        match raw[i].as_str() {
304            "--mode" => {
305                i += 1;
306                let v = raw.get(i).ok_or("--mode requires value")?;
307                if v != "psql" {
308                    return Err(format!("unsupported psql-mode argument: --mode {v}; only --mode psql is allowed with psql translation"));
309                }
310                i += 1;
311            }
312            other if other.starts_with("--mode=") => {
313                let v = other.trim_start_matches("--mode=");
314                if v != "psql" {
315                    return Err(format!("unsupported psql-mode argument: {other}; only --mode=psql is allowed with psql translation"));
316                }
317                i += 1;
318            }
319            "-c" => {
320                i += 1;
321                let v = raw.get(i).ok_or("-c requires SQL")?;
322                sql = Some(v.clone());
323                i += 1;
324            }
325            "-f" => {
326                i += 1;
327                let v = raw.get(i).ok_or("-f requires file path")?;
328                sql_file = Some(v.clone());
329                i += 1;
330            }
331            "-h" => {
332                i += 1;
333                host = Some(raw.get(i).ok_or("-h requires value")?.clone());
334                i += 1;
335            }
336            "-p" => {
337                i += 1;
338                port = Some(
339                    raw.get(i)
340                        .ok_or("-p requires value")?
341                        .parse()
342                        .map_err(|_| "invalid -p port")?,
343                );
344                i += 1;
345            }
346            "-U" => {
347                i += 1;
348                user = Some(raw.get(i).ok_or("-U requires value")?.clone());
349                i += 1;
350            }
351            "-d" => {
352                i += 1;
353                dbname = Some(raw.get(i).ok_or("-d requires value")?.clone());
354                i += 1;
355            }
356            "--dsn-secret" => {
357                i += 1;
358                dsn_secret = Some(raw.get(i).ok_or("--dsn-secret requires value")?.clone());
359                i += 1;
360            }
361            "--conninfo-secret" => {
362                i += 1;
363                conninfo_secret = Some(
364                    raw.get(i)
365                        .ok_or("--conninfo-secret requires value")?
366                        .clone(),
367                );
368                i += 1;
369            }
370            "-v" => {
371                i += 1;
372                params_kv.push(raw.get(i).ok_or("-v requires N=value")?.clone());
373                i += 1;
374            }
375            "--output" => {
376                i += 1;
377                output = parse_output(raw.get(i).ok_or("--output requires value")?)?;
378                i += 1;
379            }
380            "--log" => {
381                i += 1;
382                let values = raw.get(i).ok_or("--log requires value")?;
383                for part in values.split(',') {
384                    let trimmed = part.trim();
385                    if !trimmed.is_empty() {
386                        log_entries.push(trimmed.to_string());
387                    }
388                }
389                i += 1;
390            }
391            other if other.starts_with("postgresql://") || other.starts_with("postgres://") => {
392                dsn_secret = Some(other.to_string());
393                i += 1;
394            }
395            unsupported => {
396                return Err(format!(
397                    "unsupported psql-mode argument: {unsupported}; only --mode psql, -c/-f/-h/-p/-U/-d/-v/--dsn-secret/--conninfo-secret/--output/--log are supported"
398                ));
399            }
400        }
401    }
402
403    let session = SessionConfig {
404        dsn_secret,
405        conninfo_secret,
406        host,
407        port,
408        user,
409        dbname,
410        password_secret: None,
411    };
412
413    let startup_sql = sql.clone();
414    let startup_sql_file = sql_file.clone();
415    let sql = load_sql(sql, sql_file)?;
416    let params = parse_params(&params_kv)?;
417    let startup_args = psql_startup_args(
418        "psql",
419        startup_sql.or_else(|| Some(sql.clone())),
420        startup_sql_file,
421        &params_kv,
422        &session,
423        output,
424        &log_entries,
425    );
426    Ok(Mode::Cli(CliRequest {
427        sql,
428        params,
429        options: QueryOptions::default(),
430        session,
431        output,
432        log: parse_log_categories(&log_entries),
433        startup_argv: raw.to_vec(),
434        startup_args,
435        startup_env: startup_env_snapshot(),
436        startup_requested,
437        dry_run: false,
438    }))
439}
440
/// Decide from raw argv (program name at index 0 is ignored) whether psql translation mode is
/// being requested.
///
/// The first `--mode` flag followed by a value is decisive: true only if that value is `psql`.
/// An `--mode=psql` token seen before any bare `--mode` also selects psql mode.
fn is_psql_mode_requested(raw: &[String]) -> bool {
    for (pos, token) in raw.iter().enumerate().skip(1) {
        match token.as_str() {
            "--mode" => return matches!(raw.get(pos + 1).map(String::as_str), Some("psql")),
            "--mode=psql" => return true,
            _ => {}
        }
    }
    false
}
458
/// Resolve the SQL text from exactly one source: inline `--sql` text or a `--sql-file` path.
///
/// # Errors
/// Returns an error if both sources or neither are given, or if the file cannot be read as
/// UTF-8 text.
fn load_sql(sql: Option<String>, sql_file: Option<String>) -> Result<String, String> {
    if sql.is_some() && sql_file.is_some() {
        return Err("--sql and --sql-file are mutually exclusive".to_string());
    }
    if let Some(inline) = sql {
        return Ok(inline);
    }
    match sql_file {
        Some(path) => std::fs::read_to_string(path)
            .map_err(|err| format!("read --sql-file failed: {err}")),
        None => Err("one of --sql or --sql-file is required".to_string()),
    }
}
469
470fn parse_output(v: &str) -> Result<OutputFormat, String> {
471    cli_parse_output(v)
472}
473
474fn parse_log_categories(entries: &[String]) -> Vec<String> {
475    cli_parse_log_filters(entries)
476}
477
/// Report whether raw argv asks for the `startup` diagnostic category.
///
/// Looks at `--log VALUE` and `--log=VALUE` (program name at index 0 is skipped). Each value is
/// a comma-separated list. `startup`, `all`, or `*` in it (trimmed, case-insensitive) counts as
/// a request. The token following a bare `--log` is consumed as its value and never inspected
/// as a flag itself.
fn startup_requested_from_raw(raw: &[String]) -> bool {
    let enables_startup = |list: &str| {
        list.split(',').any(|part| {
            let category = part.trim().to_ascii_lowercase();
            matches!(category.as_str(), "startup" | "all" | "*")
        })
    };
    let mut tokens = raw.iter().skip(1);
    while let Some(token) = tokens.next() {
        if token == "--log" {
            if let Some(list) = tokens.next() {
                if enables_startup(list) {
                    return true;
                }
            }
        } else if let Some(list) = token.strip_prefix("--log=") {
            if enables_startup(list) {
                return true;
            }
        }
    }
    false
}
505
506fn output_name(output: OutputFormat) -> &'static str {
507    match output {
508        OutputFormat::Json => "json",
509        OutputFormat::Yaml => "yaml",
510        OutputFormat::Plain => "plain",
511    }
512}
513
514fn startup_env_snapshot() -> Value {
515    json!({
516        "AFPSQL_DSN_SECRET": std::env::var("AFPSQL_DSN_SECRET").ok(),
517        "AFPSQL_CONNINFO_SECRET": std::env::var("AFPSQL_CONNINFO_SECRET").ok(),
518        "AFPSQL_HOST": std::env::var("AFPSQL_HOST").ok(),
519        "AFPSQL_PORT": std::env::var("AFPSQL_PORT").ok(),
520        "AFPSQL_USER": std::env::var("AFPSQL_USER").ok(),
521        "AFPSQL_DBNAME": std::env::var("AFPSQL_DBNAME").ok(),
522        "AFPSQL_PASSWORD_SECRET": std::env::var("AFPSQL_PASSWORD_SECRET").ok(),
523        "PGHOST": std::env::var("PGHOST").ok(),
524        "PGPORT": std::env::var("PGPORT").ok(),
525        "PGUSER": std::env::var("PGUSER").ok(),
526        "PGDATABASE": std::env::var("PGDATABASE").ok(),
527    })
528}
529
530fn psql_startup_args(
531    mode: &str,
532    sql: Option<String>,
533    sql_file: Option<String>,
534    params_kv: &[String],
535    session: &SessionConfig,
536    output: OutputFormat,
537    log_entries: &[String],
538) -> Value {
539    json!({
540        "mode": mode,
541        "sql": sql,
542        "sql_file": sql_file,
543        "param": params_kv,
544        "dsn_secret": session.dsn_secret,
545        "conninfo_secret": session.conninfo_secret,
546        "host": session.host,
547        "port": session.port,
548        "user": session.user,
549        "dbname": session.dbname,
550        "password_secret": session.password_secret,
551        "output": output_name(output),
552        "log": parse_log_categories(log_entries),
553    })
554}
555
556pub fn parse_params(entries: &[String]) -> Result<Vec<Value>, String> {
557    let mut by_index: BTreeMap<usize, Value> = BTreeMap::new();
558    for entry in entries {
559        let (idx, raw) = split_index_value(entry)?;
560        if idx == 0 {
561            return Err("param index must start at 1".to_string());
562        }
563        by_index.insert(idx, parse_param_value(raw));
564    }
565    if by_index.is_empty() {
566        return Ok(vec![]);
567    }
568    let max = by_index.keys().max().copied().unwrap_or(0);
569    let mut out = Vec::with_capacity(max);
570    for i in 1..=max {
571        let v = by_index
572            .remove(&i)
573            .ok_or_else(|| format!("missing parameter index {i}"))?;
574        out.push(v);
575    }
576    Ok(out)
577}
578
/// Split a `--param` entry at its first `=` into `(index, raw_value)`.
///
/// Only the first `=` separates the index, so the value itself may contain `=`.
///
/// # Errors
/// Returns an error when there is no `=` at all, or when the text before it does not parse as
/// `usize`.
fn split_index_value(entry: &str) -> Result<(usize, &str), String> {
    match entry.split_once('=') {
        None => Err(format!("invalid param '{entry}', expected N=value")),
        Some((index_text, value)) => index_text
            .parse::<usize>()
            .map(|index| (index, value))
            .map_err(|_| format!("invalid param index in '{entry}'")),
    }
}
590
591fn parse_param_value(v: &str) -> Value {
592    if v == "null" {
593        return Value::Null;
594    }
595    if v == "true" {
596        return Value::Bool(true);
597    }
598    if v == "false" {
599        return Value::Bool(false);
600    }
601    if let Ok(i) = v.parse::<i64>() {
602        return Value::Number(i.into());
603    }
604    if let Ok(f) = v.parse::<f64>() {
605        if let Some(n) = serde_json::Number::from_f64(f) {
606            return Value::Number(n);
607        }
608    }
609    Value::String(v.to_string())
610}
611
// Unit tests for this module live outside `src/`: `#[path]` points the `tests` module at
// `tests/support/unit_cli.rs`, and `#[cfg(test)]` compiles it only in test builds.
#[cfg(test)]
#[path = "../tests/support/unit_cli.rs"]
mod tests;