// spacetimedb_cli/subcommands/sql.rs

1use std::fmt;
2use std::fmt::Write;
3use std::time::{Duration, Instant};
4
5use crate::api::{from_json_seed, ClientApi, Connection, SqlStmtResult, StmtStats};
6use crate::common_args;
7use crate::config::Config;
8use crate::util::{database_identity, get_auth_header, ResponseExt, UNSTABLE_WARNING};
9use anyhow::Context;
10use clap::{Arg, ArgAction, ArgMatches};
11use reqwest::RequestBuilder;
12use spacetimedb_lib::de::serde::SeedWrapper;
13use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace};
14
15pub fn cli() -> clap::Command {
16    clap::Command::new("sql")
17        .about(format!("Runs a SQL query on the database. {UNSTABLE_WARNING}"))
18        .arg(
19            Arg::new("database")
20                .required(true)
21                .help("The name or identity of the database you would like to query"),
22        )
23        .arg(
24            Arg::new("query")
25                .action(ArgAction::Set)
26                .required(true)
27                .conflicts_with("interactive")
28                .help("The SQL query to execute"),
29        )
30        .arg(
31            Arg::new("interactive")
32                .long("interactive")
33                .action(ArgAction::SetTrue)
34                .conflicts_with("query")
35                .help("Instead of using a query, run an interactive command prompt for `SQL` expressions"),
36        )
37        .arg(common_args::anonymous())
38        .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
39        .arg(common_args::yes())
40}
41
42pub(crate) async fn parse_req(mut config: Config, args: &ArgMatches) -> Result<Connection, anyhow::Error> {
43    let server = args.get_one::<String>("server").map(|s| s.as_ref());
44    let force = args.get_flag("force");
45    let database_name_or_identity = args.get_one::<String>("database").unwrap();
46    let anon_identity = args.get_flag("anon_identity");
47
48    Ok(Connection {
49        host: config.get_host_url(server)?,
50        auth_header: get_auth_header(&mut config, anon_identity, server, !force).await?,
51        database_identity: database_identity(&config, database_name_or_identity, server).await?,
52        database: database_name_or_identity.to_string(),
53    })
54}
55
56struct StmtResult {
57    table: tabled::Table,
58    stats: Option<StmtStats>,
59}
60
61impl fmt::Display for StmtResult {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        let has_table = !self.table.is_empty();
64        if has_table {
65            write!(f, "{}", self.table)?;
66        }
67
68        if let Some(stats) = &self.stats {
69            if has_table {
70                writeln!(f)?;
71            }
72            let txt = if stats.total_rows == 1 { "row" } else { "rows" };
73
74            let result = format!("({} {txt})", stats.total_rows);
75            let mut info = Vec::new();
76            if stats.rows_inserted != 0 {
77                info.push(format!("inserted: {}", stats.rows_inserted));
78            }
79            if stats.rows_deleted != 0 {
80                info.push(format!("deleted: {}", stats.rows_deleted));
81            }
82            if stats.rows_updated != 0 {
83                info.push(format!("updated: {}", stats.rows_updated));
84            }
85            info.push(format!(
86                "server: {:.2?}",
87                std::time::Duration::from_micros(stats.total_duration_micros)
88            ));
89
90            if !info.is_empty() {
91                write!(f, "{result} [{info}]", info = info.join(", "))?;
92            } else {
93                write!(f, "{result}")?;
94            };
95        };
96        Ok(())
97    }
98}
99
100fn print_stmt_result(
101    stmt_results: &[SqlStmtResult],
102    with_stats: Option<Duration>,
103    f: &mut String,
104) -> anyhow::Result<()> {
105    let if_empty: Option<anyhow::Result<StmtResult>> = stmt_results.is_empty().then_some(anyhow::Ok(StmtResult {
106        stats: with_stats.is_some().then_some(StmtStats::default()),
107        table: tabled::Table::new([""]),
108    }));
109    let total = stmt_results.len();
110    for (pos, result) in if_empty
111        .into_iter()
112        .chain(stmt_results.iter().map(|stmt_result| {
113            let (stats, table) = stmt_result_to_table(stmt_result)?;
114
115            anyhow::Ok(StmtResult {
116                stats: with_stats.is_some().then_some(stats),
117                table,
118            })
119        }))
120        .enumerate()
121    {
122        let result = result?;
123        f.write_str(&format!("{result}"))?;
124        if pos + 1 < total {
125            f.write_char('\n')?;
126            f.write_char('\n')?;
127        }
128    }
129
130    if let Some(with_stats) = with_stats {
131        f.write_char('\n')?;
132        f.write_str(&format!("Roundtrip time: {with_stats:.2?}"))?;
133        f.write_char('\n')?;
134    }
135    Ok(())
136}
137
138pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool) -> Result<(), anyhow::Error> {
139    let now = Instant::now();
140
141    let json = builder
142        .body(sql.to_owned())
143        .send()
144        .await?
145        .ensure_content_type("application/json")
146        .await?
147        .text()
148        .await?;
149
150    let stmt_result_json: Vec<SqlStmtResult> = serde_json::from_str(&json).context("malformed sql response")?;
151
152    let mut out = String::new();
153    print_stmt_result(&stmt_result_json, with_stats.then_some(now.elapsed()), &mut out)?;
154    println!("{out}");
155
156    Ok(())
157}
158
159fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> {
160    let stats = StmtStats::from(stmt_result);
161    let SqlStmtResult { schema, rows, .. } = stmt_result;
162    let ty = Typespace::EMPTY.with_type(schema);
163
164    let table = build_table(
165        schema,
166        rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))),
167    )?;
168
169    Ok((stats, table))
170}
171
172pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
173    eprintln!("{UNSTABLE_WARNING}\n");
174    let interactive = args.get_one::<bool>("interactive").unwrap_or(&false);
175    if *interactive {
176        let con = parse_req(config, args).await?;
177
178        crate::repl::exec(con).await?;
179    } else {
180        let query = args.get_one::<String>("query").unwrap();
181
182        let con = parse_req(config, args).await?;
183        let api = ClientApi::new(con);
184
185        run_sql(api.sql(), query, false).await?;
186    }
187    Ok(())
188}
189
190/// Generates a [`tabled::Table`] from a schema and rows, using the style of a psql table.
191fn build_table<E>(
192    schema: &ProductType,
193    rows: impl Iterator<Item = Result<ProductValue, E>>,
194) -> Result<tabled::Table, E> {
195    let mut builder = tabled::builder::Builder::default();
196    builder.set_header(
197        schema
198            .elements
199            .iter()
200            .enumerate()
201            .map(|(i, e)| e.name.clone().unwrap_or_else(|| format!("column {i}").into())),
202    );
203
204    let ty = Typespace::EMPTY.with_type(schema);
205    for row in rows {
206        let row = row?;
207        builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| {
208            let ty = satn::PsqlType {
209                tuple: ty.ty(),
210                field: &ty.ty().elements[idx],
211                idx,
212            };
213
214            satn::PsqlWrapper { ty, value }.to_string()
215        }));
216    }
217
218    let mut table = builder.build();
219    table.with(tabled::settings::Style::psql());
220
221    Ok(table)
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use serde_json::value::RawValue;
    use spacetimedb_client_api_messages::http::SqlStmtStats;
    use spacetimedb_lib::error::ResultTest;
    use spacetimedb_lib::sats::time_duration::TimeDuration;
    use spacetimedb_lib::sats::timestamp::Timestamp;
    use spacetimedb_lib::sats::{product, GroundSpacetimeType, ProductType};
    use spacetimedb_lib::{AlgebraicType, AlgebraicValue, ConnectionId, Identity};

    /// Encodes `row` as JSON and wraps it in a [`RawValue`], the row format of a SQL response.
    fn make_row(row: &[AlgebraicValue]) -> Result<Box<RawValue>, serde_json::Error> {
        let json = serde_json::json!(row);
        RawValue::from_string(json.to_string())
    }

    /// Renders `result` via [`print_stmt_result`] and asserts the output equals `expect`.
    ///
    /// Trailing whitespace is trimmed from every rendered line, so the `expect` literals
    /// don't have to reproduce the table's right-hand padding.
    fn check_outputs(
        result: &[SqlStmtResult],
        duration: Option<Duration>,
        expect: &str,
    ) -> Result<String, anyhow::Error> {
        let mut out = String::new();
        print_stmt_result(result, duration, &mut out)?;

        let out = out.lines().map(|line| line.trim_end()).join("\n");
        assert_eq!(out, expect);

        Ok(out)
    }

    /// Like [`check_outputs`], for a single statement result assembled from the given parts
    /// with a fixed server time of 1000µs.
    fn check_output(
        schema: ProductType,
        rows: Vec<&RawValue>,
        stats: SqlStmtStats,
        duration: Option<Duration>,
        expect: &str,
    ) -> Result<String, anyhow::Error> {
        let result = SqlStmtResult {
            schema,
            rows,
            total_duration_micros: 1000,
            stats,
        };
        check_outputs(&[result], duration, expect)
    }

    #[test]
    fn test_output() -> Result<(), anyhow::Error> {
        let duration = Duration::from_micros(1000);
        let schema = ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::I64)]);
        let row = make_row(&[AlgebraicValue::I32(1), AlgebraicValue::I64(2)])?;
        // Verify with and without stats
        check_output(
            schema.clone(),
            vec![&row],
            SqlStmtStats {
                rows_inserted: 1,
                rows_deleted: 1,
                rows_updated: 1,
            },
            None,
            r#" a | b
---+---
 1 | 2"#,
        )?;

        check_output(
            schema.clone(),
            vec![&row],
            SqlStmtStats {
                rows_inserted: 1,
                rows_deleted: 1,
                rows_updated: 1,
            },
            Some(duration),
            r#" a | b
---+---
 1 | 2
(1 row) [inserted: 1, deleted: 1, updated: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        // Only a query result
        check_output(
            schema.clone(),
            vec![&row],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 0,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
 1 | 2
(1 row) [server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        // Empty table
        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 0,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        // DML
        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 1,
                rows_deleted: 0,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [inserted: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 1,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [deleted: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 0,
                rows_updated: 1,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [updated: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        Ok(())
    }

    #[test]
    fn test_multiple_output() -> Result<(), anyhow::Error> {
        let duration = Duration::from_micros(1000);
        let schema = ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::I64)]);
        let row = make_row(&[AlgebraicValue::I32(1), AlgebraicValue::I64(2)])?;

        // Verify with and without stats
        check_outputs(
            &[
                SqlStmtResult {
                    schema: schema.clone(),
                    rows: vec![&row],
                    total_duration_micros: 1000,
                    stats: SqlStmtStats {
                        rows_inserted: 1,
                        rows_deleted: 1,
                        rows_updated: 1,
                    },
                },
                SqlStmtResult {
                    schema: schema.clone(),
                    rows: vec![&row],
                    total_duration_micros: 1000,
                    stats: SqlStmtStats {
                        rows_inserted: 1,
                        rows_deleted: 1,
                        rows_updated: 1,
                    },
                },
            ],
            Some(duration),
            r#" a | b
---+---
 1 | 2
(1 row) [inserted: 1, deleted: 1, updated: 1, server: 1.00ms]

 a | b
---+---
 1 | 2
(1 row) [inserted: 1, deleted: 1, updated: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        Ok(())
    }

    /// Asserts that [`build_table`] renders `rows` of type `ty` as `expected`.
    ///
    /// Trailing whitespace is trimmed from each rendered line, and a leading newline is
    /// prepended so `expected` can start its raw-string literal on a fresh line.
    fn expect_psql_table(ty: &ProductType, rows: Vec<ProductValue>, expected: &str) {
        let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string();
        let mut table = table.split('\n').map(|x| x.trim_end()).join("\n");
        table.insert(0, '\n');
        assert_eq!(expected, table);
    }

    // Verify the output of `sql` matches the inputs that return true for [`AlgebraicType::is_special()`]
    #[test]
    fn output_special_types() -> ResultTest<()> {
        // Check tuples
        let kind: ProductType = [
            AlgebraicType::String,
            AlgebraicType::U256,
            Identity::get_type(),
            ConnectionId::get_type(),
            Timestamp::get_type(),
            TimeDuration::get_type(),
        ]
        .into();
        let value = product![
            "a",
            Identity::ZERO.to_u256(),
            Identity::ZERO,
            ConnectionId::ZERO,
            Timestamp::UNIX_EPOCH,
            TimeDuration::ZERO
        ];

        expect_psql_table(
            &kind,
            vec![value],
            r#"
 column 0 | column 1 | column 2                                                           | column 3                           | column 4                  | column 5
----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------
 "a"      | 0        | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
        );

        // Check struct
        let kind: ProductType = [
            ("bool", AlgebraicType::Bool),
            ("str", AlgebraicType::String),
            ("bytes", AlgebraicType::bytes()),
            ("identity", Identity::get_type()),
            ("connection_id", ConnectionId::get_type()),
            ("timestamp", Timestamp::get_type()),
            ("duration", TimeDuration::get_type()),
        ]
        .into();

        let value = product![
            true,
            "This is spacetimedb".to_string(),
            AlgebraicValue::Bytes([1, 2, 3, 4, 5, 6, 7].into()),
            Identity::ZERO,
            ConnectionId::ZERO,
            Timestamp::UNIX_EPOCH,
            TimeDuration::ZERO
        ];

        expect_psql_table(
            &kind,
            vec![value.clone()],
            r#"
 bool | str                   | bytes            | identity                                                           | connection_id                      | timestamp                 | duration
------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------
 true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
        );

        // The struct above as it renders inside a single cell; reused by the nested checks below.
        let rendered_struct = r#"(bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#;

        // A psql column is as wide as its widest cell plus one space of padding on each side,
        // so the separator under a single long cell is `cell.len() + 2` dashes.

        // Check nested struct, tuple...
        let kind: ProductType = [(None, AlgebraicType::product(kind))].into();

        let value = product![value.clone()];

        expect_psql_table(
            &kind,
            vec![value.clone()],
            &format!(
                "\n column 0\n{}\n {rendered_struct}",
                "-".repeat(rendered_struct.len() + 2)
            ),
        );

        let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into();

        let value = product![value];

        let tuple_cell = format!("(0 = {rendered_struct})");
        expect_psql_table(
            &kind,
            vec![value],
            &format!("\n tuple\n{}\n {tuple_cell}", "-".repeat(tuple_cell.len() + 2)),
        );

        Ok(())
    }
}