1use std::fmt;
2use std::fmt::Write;
3use std::time::{Duration, Instant};
4
5use crate::api::{from_json_seed, ClientApi, Connection, SqlStmtResult, StmtStats};
6use crate::common_args;
7use crate::config::Config;
8use crate::util::{database_identity, get_auth_header, ResponseExt, UNSTABLE_WARNING};
9use anyhow::Context;
10use clap::{Arg, ArgAction, ArgMatches};
11use reqwest::RequestBuilder;
12use spacetimedb_lib::de::serde::SeedWrapper;
13use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace};
14
15pub fn cli() -> clap::Command {
16 clap::Command::new("sql")
17 .about(format!("Runs a SQL query on the database. {UNSTABLE_WARNING}"))
18 .arg(
19 Arg::new("database")
20 .required(true)
21 .help("The name or identity of the database you would like to query"),
22 )
23 .arg(
24 Arg::new("query")
25 .action(ArgAction::Set)
26 .required(true)
27 .conflicts_with("interactive")
28 .help("The SQL query to execute"),
29 )
30 .arg(
31 Arg::new("interactive")
32 .long("interactive")
33 .action(ArgAction::SetTrue)
34 .conflicts_with("query")
35 .help("Instead of using a query, run an interactive command prompt for `SQL` expressions"),
36 )
37 .arg(common_args::anonymous())
38 .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
39 .arg(common_args::yes())
40}
41
42pub(crate) async fn parse_req(mut config: Config, args: &ArgMatches) -> Result<Connection, anyhow::Error> {
43 let server = args.get_one::<String>("server").map(|s| s.as_ref());
44 let force = args.get_flag("force");
45 let database_name_or_identity = args.get_one::<String>("database").unwrap();
46 let anon_identity = args.get_flag("anon_identity");
47
48 Ok(Connection {
49 host: config.get_host_url(server)?,
50 auth_header: get_auth_header(&mut config, anon_identity, server, !force).await?,
51 database_identity: database_identity(&config, database_name_or_identity, server).await?,
52 database: database_name_or_identity.to_string(),
53 })
54}
55
56struct StmtResult {
57 table: tabled::Table,
58 stats: Option<StmtStats>,
59}
60
61impl fmt::Display for StmtResult {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 let has_table = !self.table.is_empty();
64 if has_table {
65 write!(f, "{}", self.table)?;
66 }
67
68 if let Some(stats) = &self.stats {
69 if has_table {
70 writeln!(f)?;
71 }
72 let txt = if stats.total_rows == 1 { "row" } else { "rows" };
73
74 let result = format!("({} {txt})", stats.total_rows);
75 let mut info = Vec::new();
76 if stats.rows_inserted != 0 {
77 info.push(format!("inserted: {}", stats.rows_inserted));
78 }
79 if stats.rows_deleted != 0 {
80 info.push(format!("deleted: {}", stats.rows_deleted));
81 }
82 if stats.rows_updated != 0 {
83 info.push(format!("updated: {}", stats.rows_updated));
84 }
85 info.push(format!(
86 "server: {:.2?}",
87 std::time::Duration::from_micros(stats.total_duration_micros)
88 ));
89
90 if !info.is_empty() {
91 write!(f, "{result} [{info}]", info = info.join(", "))?;
92 } else {
93 write!(f, "{result}")?;
94 };
95 };
96 Ok(())
97 }
98}
99
100fn print_stmt_result(
101 stmt_results: &[SqlStmtResult],
102 with_stats: Option<Duration>,
103 f: &mut String,
104) -> anyhow::Result<()> {
105 let if_empty: Option<anyhow::Result<StmtResult>> = stmt_results.is_empty().then_some(anyhow::Ok(StmtResult {
106 stats: with_stats.is_some().then_some(StmtStats::default()),
107 table: tabled::Table::new([""]),
108 }));
109 let total = stmt_results.len();
110 for (pos, result) in if_empty
111 .into_iter()
112 .chain(stmt_results.iter().map(|stmt_result| {
113 let (stats, table) = stmt_result_to_table(stmt_result)?;
114
115 anyhow::Ok(StmtResult {
116 stats: with_stats.is_some().then_some(stats),
117 table,
118 })
119 }))
120 .enumerate()
121 {
122 let result = result?;
123 f.write_str(&format!("{result}"))?;
124 if pos + 1 < total {
125 f.write_char('\n')?;
126 f.write_char('\n')?;
127 }
128 }
129
130 if let Some(with_stats) = with_stats {
131 f.write_char('\n')?;
132 f.write_str(&format!("Roundtrip time: {with_stats:.2?}"))?;
133 f.write_char('\n')?;
134 }
135 Ok(())
136}
137
138pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool) -> Result<(), anyhow::Error> {
139 let now = Instant::now();
140
141 let json = builder
142 .body(sql.to_owned())
143 .send()
144 .await?
145 .ensure_content_type("application/json")
146 .await?
147 .text()
148 .await?;
149
150 let stmt_result_json: Vec<SqlStmtResult> = serde_json::from_str(&json).context("malformed sql response")?;
151
152 let mut out = String::new();
153 print_stmt_result(&stmt_result_json, with_stats.then_some(now.elapsed()), &mut out)?;
154 println!("{out}");
155
156 Ok(())
157}
158
159fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> {
160 let stats = StmtStats::from(stmt_result);
161 let SqlStmtResult { schema, rows, .. } = stmt_result;
162 let ty = Typespace::EMPTY.with_type(schema);
163
164 let table = build_table(
165 schema,
166 rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))),
167 )?;
168
169 Ok((stats, table))
170}
171
172pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
173 eprintln!("{UNSTABLE_WARNING}\n");
174 let interactive = args.get_one::<bool>("interactive").unwrap_or(&false);
175 if *interactive {
176 let con = parse_req(config, args).await?;
177
178 crate::repl::exec(con).await?;
179 } else {
180 let query = args.get_one::<String>("query").unwrap();
181
182 let con = parse_req(config, args).await?;
183 let api = ClientApi::new(con);
184
185 run_sql(api.sql(), query, false).await?;
186 }
187 Ok(())
188}
189
190fn build_table<E>(
192 schema: &ProductType,
193 rows: impl Iterator<Item = Result<ProductValue, E>>,
194) -> Result<tabled::Table, E> {
195 let mut builder = tabled::builder::Builder::default();
196 builder.set_header(
197 schema
198 .elements
199 .iter()
200 .enumerate()
201 .map(|(i, e)| e.name.clone().unwrap_or_else(|| format!("column {i}").into())),
202 );
203
204 let ty = Typespace::EMPTY.with_type(schema);
205 for row in rows {
206 let row = row?;
207 builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| {
208 let ty = satn::PsqlType {
209 tuple: ty.ty(),
210 field: &ty.ty().elements[idx],
211 idx,
212 };
213
214 satn::PsqlWrapper { ty, value }.to_string()
215 }));
216 }
217
218 let mut table = builder.build();
219 table.with(tabled::settings::Style::psql());
220
221 Ok(table)
222}
223
// Rendering tests for `print_stmt_result` (row-count / stats summaries, multi-statement
// separation) and `build_table` (psql cell formatting of special SATS types and nested
// products). Rendered output is compared line-by-line with trailing whitespace trimmed.
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use serde_json::value::RawValue;
    use spacetimedb_client_api_messages::http::SqlStmtStats;
    use spacetimedb_lib::error::ResultTest;
    use spacetimedb_lib::sats::time_duration::TimeDuration;
    use spacetimedb_lib::sats::timestamp::Timestamp;
    use spacetimedb_lib::sats::{product, GroundSpacetimeType, ProductType};
    use spacetimedb_lib::{AlgebraicType, AlgebraicValue, ConnectionId, Identity};

    /// Serializes `row` as a JSON array and wraps it in a `RawValue`, the element type of
    /// `SqlStmtResult::rows`.
    fn make_row(row: &[AlgebraicValue]) -> Result<Box<RawValue>, serde_json::Error> {
        let json = serde_json::json!(row);
        RawValue::from_string(json.to_string())
    }

    /// Renders `result` with `print_stmt_result`, trims trailing whitespace from every line,
    /// asserts the result equals `expect`, and returns the normalized output.
    fn check_outputs(
        result: &[SqlStmtResult],
        duration: Option<Duration>,
        expect: &str,
    ) -> Result<String, anyhow::Error> {
        let mut out = String::new();
        print_stmt_result(result, duration, &mut out)?;

        // `lines()` also drops the final trailing newline written after the roundtrip line.
        let out = out.lines().map(|line| line.trim_end()).join("\n");
        assert_eq!(out, expect,);

        Ok(out)
    }

    /// Single-statement variant of `check_outputs`: builds one `SqlStmtResult` from `schema`,
    /// `rows` and `stats` with a fixed server time of 1000µs (rendered as `1.00ms`).
    fn check_output(
        schema: ProductType,
        rows: Vec<&RawValue>,
        stats: SqlStmtStats,
        duration: Option<Duration>,
        expect: &str,
    ) -> Result<String, anyhow::Error> {
        let table = SqlStmtResult {
            schema: schema.clone(),
            rows,
            total_duration_micros: 1000,
            stats: stats.clone(),
        };

        let mut out = String::new();
        print_stmt_result(&[table], duration, &mut out)?;

        let out = out.lines().map(|line| line.trim_end()).join("\n");
        assert_eq!(out, expect,);

        Ok(out)
    }

    /// One statement, with and without stats. With stats, only non-zero DML counters appear
    /// in the summary and the server time is always shown.
    #[test]
    fn test_output() -> Result<(), anyhow::Error> {
        let duration = Duration::from_micros(1000);
        let schema = ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::I64)]);
        let row = make_row(&[AlgebraicValue::I32(1), AlgebraicValue::I64(2)])?;
        // No stats requested: only the table is printed, even though counters are non-zero.
        check_output(
            schema.clone(),
            vec![&row],
            SqlStmtStats {
                rows_inserted: 1,
                rows_deleted: 1,
                rows_updated: 1,
            },
            None,
            r#" a | b
---+---
 1 | 2"#,
        )?;

        // Stats requested: every non-zero counter is listed, then the roundtrip line.
        check_output(
            schema.clone(),
            vec![&row],
            SqlStmtStats {
                rows_inserted: 1,
                rows_deleted: 1,
                rows_updated: 1,
            },
            Some(duration),
            r#" a | b
---+---
 1 | 2
(1 row) [inserted: 1, deleted: 1, updated: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        // All counters zero: only the server time appears in the brackets.
        check_output(
            schema.clone(),
            vec![&row],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 0,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
 1 | 2
(1 row) [server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        // No rows: the header is still printed and the count uses the plural "rows".
        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 0,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        // Each DML counter on its own.
        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 1,
                rows_deleted: 0,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [inserted: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 1,
                rows_updated: 0,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [deleted: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        check_output(
            schema.clone(),
            vec![],
            SqlStmtStats {
                rows_inserted: 0,
                rows_deleted: 0,
                rows_updated: 1,
            },
            Some(duration),
            r#" a | b
---+---
(0 rows) [updated: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        Ok(())
    }

    /// Two statements: results are separated by a blank line and the roundtrip time is
    /// printed once, after the last result.
    #[test]
    fn test_multiple_output() -> Result<(), anyhow::Error> {
        let duration = Duration::from_micros(1000);
        let schema = ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::I64)]);
        let row = make_row(&[AlgebraicValue::I32(1), AlgebraicValue::I64(2)])?;

        check_outputs(
            &[
                SqlStmtResult {
                    schema: schema.clone(),
                    rows: vec![&row],
                    total_duration_micros: 1000,
                    stats: SqlStmtStats {
                        rows_inserted: 1,
                        rows_deleted: 1,
                        rows_updated: 1,
                    },
                },
                SqlStmtResult {
                    schema: schema.clone(),
                    rows: vec![&row],
                    total_duration_micros: 1000,
                    stats: SqlStmtStats {
                        rows_inserted: 1,
                        rows_deleted: 1,
                        rows_updated: 1,
                    },
                },
            ],
            Some(duration),
            r#" a | b
---+---
 1 | 2
(1 row) [inserted: 1, deleted: 1, updated: 1, server: 1.00ms]

 a | b
---+---
 1 | 2
(1 row) [inserted: 1, deleted: 1, updated: 1, server: 1.00ms]
Roundtrip time: 1.00ms"#,
        )?;

        Ok(())
    }

    /// Builds a table from `rows` via `build_table` and compares its trimmed rendering with
    /// `expected`. A leading newline is inserted into the rendering because the `expected`
    /// raw strings start with a newline so the table can begin on its own source line.
    fn expect_psql_table(ty: &ProductType, rows: Vec<ProductValue>, expected: &str) {
        let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string();
        let mut table = table.split('\n').map(|x| x.trim_end()).join("\n");
        table.insert(0, '\n');
        assert_eq!(expected, table);
    }

    /// psql formatting of special SATS types (U256, identity, connection id, timestamp,
    /// duration, bytes) and of nested products, plus `column {i}` headers for unnamed fields.
    #[test]
    fn output_special_types() -> ResultTest<()> {
        // Unnamed fields: headers fall back to `column {i}`.
        let kind: ProductType = [
            AlgebraicType::String,
            AlgebraicType::U256,
            Identity::get_type(),
            ConnectionId::get_type(),
            Timestamp::get_type(),
            TimeDuration::get_type(),
        ]
        .into();
        let value = product![
            "a",
            Identity::ZERO.to_u256(),
            Identity::ZERO,
            ConnectionId::ZERO,
            Timestamp::UNIX_EPOCH,
            TimeDuration::ZERO
        ];

        expect_psql_table(
            &kind,
            vec![value],
            r#"
 column 0 | column 1 | column 2 | column 3 | column 4 | column 5
----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------
 "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
        );

        // Named fields, including a byte array.
        let kind: ProductType = [
            ("bool", AlgebraicType::Bool),
            ("str", AlgebraicType::String),
            ("bytes", AlgebraicType::bytes()),
            ("identity", Identity::get_type()),
            ("connection_id", ConnectionId::get_type()),
            ("timestamp", Timestamp::get_type()),
            ("duration", TimeDuration::get_type()),
        ]
        .into();

        let value = product![
            true,
            "This is spacetimedb".to_string(),
            AlgebraicValue::Bytes([1, 2, 3, 4, 5, 6, 7].into()),
            Identity::ZERO,
            ConnectionId::ZERO,
            Timestamp::UNIX_EPOCH,
            TimeDuration::ZERO
        ];

        expect_psql_table(
            &kind,
            vec![value.clone()],
            r#"
 bool | str | bytes | identity | connection_id | timestamp | duration
------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+-----------
 true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#,
        );

        // The previous product nested as a single unnamed column.
        let kind: ProductType = [(None, AlgebraicType::product(kind))].into();

        let value = product![value.clone()];

        expect_psql_table(
            &kind,
            vec![value.clone()],
            r#"
 column 0
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#,
        );

        // One more level of nesting under a named column; the unnamed inner field prints as `0`.
        let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into();

        let value = product![value];

        expect_psql_table(
            &kind,
            vec![value],
            r#"
 tuple
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 (0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#,
        );

        Ok(())
    }
}