1#![allow(warnings)]
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::pin::Pin;
5
6use chrono::{DateTime, NaiveDate, Utc};
7use rust_decimal::Decimal;
8use std::sync::Arc;
9use teaql_core::{
10 BinaryOp, DataType, EntityDescriptor, Expr, InsertCommand, PropertyDescriptor, Record,
11 SelectQuery, UpdateCommand, Value,
12};
13use teaql_runtime::{GraphNode, InternalIdGenerator, RuntimeError, SchemaProvider, UserContext};
14use teaql_sql::{
15 CompiledQuery, DatabaseKind, SqlCompileError, SqlDialect, SqlTransport, quote_identifier_if_needed,
16};
17use tokio::sync::Mutex;
18use deadpool_postgres::Pool;
19
/// Default name of the table holding per-entity id counters.
///
/// Used by [`PgMutationExecutor::ensure_schema`] and as the initial table of
/// [`PgIdSpaceGenerator::new`].
pub const DEFAULT_ID_SPACE_TABLE: &str = "teaql_id_space";
21
/// Stateless [`SqlDialect`] for PostgreSQL.
///
/// It uses `$n` positional placeholders, maps TeaQL data types to native
/// PostgreSQL column types, and compiles "large" IN lists to array comparisons.
#[derive(Debug, Default, Clone, Copy)]
pub struct PostgresDialect;
24
25impl SqlDialect for PostgresDialect {
26 fn kind(&self) -> DatabaseKind {
27 DatabaseKind::PostgreSql
28 }
29
30 fn quote_ident(&self, ident: &str) -> String {
31 quote_ident(ident)
32 }
33
34 fn placeholder(&self, index: usize) -> String {
35 format!("${index}")
36 }
37
38 fn schema_setup_sqls(&self) -> &'static [&'static str] {
39 &[CREATE_SOUNDEX_FUNCTION]
40 }
41
42 fn schema_type_sql(
43 &self,
44 data_type: DataType,
45 _property: &PropertyDescriptor,
46 ) -> Result<&'static str, SqlCompileError> {
47 match data_type {
48 DataType::Bool => Ok("BOOLEAN"),
49 DataType::I64 | DataType::U64 => Ok("BIGINT"),
50 DataType::F64 => Ok("DOUBLE PRECISION"),
51 DataType::Decimal => Ok("NUMERIC"),
52 DataType::Text => Ok("TEXT"),
53 DataType::Json => Ok("JSONB"),
54 DataType::Date => Ok("DATE"),
55 DataType::Timestamp => Ok("TIMESTAMPTZ"),
56 }
57 }
58
59 fn compile_in(
60 &self,
61 entity: &EntityDescriptor,
62 left: &Expr,
63 op: BinaryOp,
64 right: &Expr,
65 params: &mut Vec<Value>,
66 ) -> Result<String, SqlCompileError> {
67 match op {
68 BinaryOp::InLarge | BinaryOp::NotInLarge => {
69 let Expr::Value(Value::List(values)) = right else {
70 let lhs = self.compile_expr(entity, left, params)?;
71 let rhs = self.compile_expr(entity, right, params)?;
72 let operator = match op {
73 BinaryOp::InLarge => "= ANY",
74 BinaryOp::NotInLarge => "<> ALL",
75 _ => unreachable!(),
76 };
77 return Ok(format!("({lhs} {operator} ({rhs}))"));
78 };
79 if values.is_empty() {
80 return Err(SqlCompileError::EmptyInList);
81 }
82 let lhs = self.compile_expr(entity, left, params)?;
83 params.push(Value::List(values.clone()));
84 let placeholder = self.placeholder(params.len());
85 let operator = match op {
86 BinaryOp::InLarge => "= ANY",
87 BinaryOp::NotInLarge => "<> ALL",
88 _ => unreachable!(),
89 };
90 Ok(format!("({lhs} {operator}({placeholder}))"))
91 }
92 _ => {
93 let lhs = self.compile_expr(entity, left, params)?;
94 let operator = match op {
95 BinaryOp::In => "IN",
96 BinaryOp::NotIn => "NOT IN",
97 _ => unreachable!(),
98 };
99 match right {
100 Expr::Value(Value::List(values)) => {
101 if values.is_empty() {
102 return Err(SqlCompileError::EmptyInList);
103 }
104 let mut placeholders = Vec::with_capacity(values.len());
105 for value in values {
106 params.push(value.clone());
107 placeholders.push(self.placeholder(params.len()));
108 }
109 Ok(format!("({lhs} {operator} ({}))", placeholders.join(", ")))
110 }
111 _ => {
112 let rhs = self.compile_expr(entity, right, params)?;
113 Ok(format!("({lhs} {operator} ({rhs}))"))
114 }
115 }
116 }
117 }
118 }
119}
120
/// PL/pgSQL `soundex(text)` function installed by `PostgresDialect::schema_setup_sqls`.
///
/// As written:
/// - Non-letters are removed and the remainder is upper-cased.
/// - An input with no letters yields `'0000'`.
/// - The first letter is kept, then digit codes are appended. A `'0'` code is
///   skipped, as is a code equal to the previous letter's code.
/// - Output stops at four characters, otherwise it is right-padded with `'0'`.
/// - `STRICT` makes the function return NULL for a NULL input.
///
/// NOTE(review): vowels and H/W/Y all map to `'0'` and reset `previous_code`. The
/// classic American Soundex rule that H/W do not separate equal codes is
/// therefore not applied; e.g. "Ashcraft" gives A226 rather than A261. Confirm
/// this matches the other TeaQL dialects before changing it.
///
/// NOTE(review): `CREATE OR REPLACE` presumably replaces, or conflicts with, a
/// `soundex` from the `fuzzystrmatch` extension in the same schema. Confirm.
const CREATE_SOUNDEX_FUNCTION: &str = r#"
CREATE OR REPLACE FUNCTION soundex(input text)
RETURNS text
LANGUAGE plpgsql
IMMUTABLE
STRICT
AS $$
DECLARE
 normalized text := upper(regexp_replace(input, '[^A-Za-z]', '', 'g'));
 first_char text;
 output text;
 previous_code text;
 code text;
 ch text;
 i integer;
BEGIN
 IF normalized = '' THEN
 RETURN '0000';
 END IF;

 first_char := substr(normalized, 1, 1);
 output := first_char;
 previous_code := CASE
 WHEN first_char IN ('B', 'F', 'P', 'V') THEN '1'
 WHEN first_char IN ('C', 'G', 'J', 'K', 'Q', 'S', 'X', 'Z') THEN '2'
 WHEN first_char IN ('D', 'T') THEN '3'
 WHEN first_char = 'L' THEN '4'
 WHEN first_char IN ('M', 'N') THEN '5'
 WHEN first_char = 'R' THEN '6'
 ELSE '0'
 END;

 FOR i IN 2..char_length(normalized) LOOP
 ch := substr(normalized, i, 1);
 code := CASE
 WHEN ch IN ('B', 'F', 'P', 'V') THEN '1'
 WHEN ch IN ('C', 'G', 'J', 'K', 'Q', 'S', 'X', 'Z') THEN '2'
 WHEN ch IN ('D', 'T') THEN '3'
 WHEN ch = 'L' THEN '4'
 WHEN ch IN ('M', 'N') THEN '5'
 WHEN ch = 'R' THEN '6'
 ELSE '0'
 END;

 IF code <> '0' AND code <> previous_code THEN
 output := output || code;
 IF char_length(output) = 4 THEN
 RETURN output;
 END IF;
 END IF;
 previous_code := code;
 END LOOP;

 RETURN rpad(output, 4, '0');
END;
$$
"#;
178
/// Errors produced by the PostgreSQL executor, the id generator and the schema helpers.
#[derive(Debug)]
pub enum MutationExecutorError {
    /// An error returned by `tokio_postgres`.
    Driver(tokio_postgres::Error),
    /// Checking a connection out of the pool failed; holds the pool error's text.
    Pool(String),
    /// The dialect could not compile a statement.
    SqlCompile(SqlCompileError),
    /// A bind value of this kind cannot be sent as a PostgreSQL parameter.
    UnsupportedValue(&'static str),
    /// A result column's PostgreSQL type (upper-cased name) is not handled by the row decoder.
    UnsupportedColumnType(String),
    /// General-purpose message. Used for out-of-range numbers, JSON encoding
    /// failures, missing entities or resources, and unsupported transaction calls.
    Bind(String),
}
188
189impl std::fmt::Display for MutationExecutorError {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 match self {
192 Self::Driver(err) => err.fmt(f),
193 Self::Pool(err) => write!(f, "postgres pool error: {err}"),
194 Self::SqlCompile(err) => err.fmt(f),
195 Self::UnsupportedValue(kind) => {
196 write!(
197 f,
198 "unsupported bind value for mutation executor: {kind}"
199 )
200 }
201 Self::UnsupportedColumnType(kind) => {
202 write!(
203 f,
204 "unsupported column type for record decoding: {kind}"
205 )
206 }
207 Self::Bind(message) => write!(f, "bind error: {message}"),
208 }
209 }
210}
211
212impl std::error::Error for MutationExecutorError {}
213
214impl From<tokio_postgres::Error> for MutationExecutorError {
215 fn from(value: tokio_postgres::Error) -> Self {
216 Self::Driver(value)
217 }
218}
219
220impl From<SqlCompileError> for MutationExecutorError {
221 fn from(value: SqlCompileError) -> Self {
222 Self::SqlCompile(value)
223 }
224}
225
/// Executes compiled TeaQL statements against PostgreSQL using a `deadpool_postgres` pool.
///
/// Each operation checks a connection out of the pool. `Clone` clones the `Pool`
/// handle.
#[derive(Clone)]
pub struct PgMutationExecutor {
    /// Connection pool shared by all operations of this executor.
    pool: Pool,
}
230
231impl SqlTransport for PgMutationExecutor {
232 type Error = MutationExecutorError;
233
234 async fn fetch_all_sql(&self, query: &CompiledQuery) -> Result<Vec<Record>, Self::Error> {
235 self.fetch_all(query).await
236 }
237
238 async fn execute_sql(&self, query: &CompiledQuery) -> Result<u64, Self::Error> {
239 self.execute(query).await
240 }
241}
242
243impl teaql_sql::SqlTransaction for PgMutationExecutor {
244 type Error = MutationExecutorError;
245
246 async fn commit_sql(self) -> Result<(), Self::Error> {
247 Err(MutationExecutorError::Bind("Transactions not supported yet".to_string()))
248 }
249
250 async fn rollback_sql(self) -> Result<(), Self::Error> {
251 Err(MutationExecutorError::Bind("Transactions not supported yet".to_string()))
252 }
253}
254
255impl teaql_sql::SqlTransactionTransport for PgMutationExecutor {
256 type Tx<'a> = Self where Self: 'a;
257
258 async fn begin_sql(&self) -> Result<Self::Tx<'_>, Self::Error> {
259 Err(MutationExecutorError::Bind("Transactions not supported yet".to_string()))
260 }
261}
262
263impl PgMutationExecutor {
264 pub fn new(pool: Pool) -> Self {
265 Self { pool }
266 }
267
268 pub fn pool(&self) -> Pool {
269 self.pool.clone()
270 }
271
272 pub async fn ensure_schema(
273 &self,
274 dialect: &PostgresDialect,
275 entities: &[&EntityDescriptor],
276 ) -> Result<(), MutationExecutorError> {
277 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
278 for sql in dialect.schema_setup_sqls() {
279 client.execute(*sql, &[]).await?;
280 }
281 self.ensure_id_space_table(DEFAULT_ID_SPACE_TABLE).await?;
282
283 for entity in entities {
284 if !self.table_exists(&entity.table_name).await? {
285 let sql = dialect.compile_create_table(entity)?;
286 client.execute(&sql, &[]).await?;
287 continue;
288 }
289
290 let existing_columns = self.table_columns(&entity.table_name).await?;
291 for property in &entity.properties {
292 let bare_column = strip_identifier_quotes(&property.column_name).to_lowercase();
293 if existing_columns.contains(&bare_column) {
294 continue;
295 }
296 let sql = dialect.compile_add_column(entity, property)?;
297 client.execute(&sql, &[]).await?;
298 }
299 }
300 Ok(())
301 }
302
303 pub async fn ensure_id_space_table(
304 &self,
305 table_name: &str,
306 ) -> Result<(), MutationExecutorError> {
307 let sql = format!(
308 "CREATE TABLE IF NOT EXISTS {} (type_name VARCHAR(100) PRIMARY KEY, current_level BIGINT NOT NULL)",
309 quote_ident(table_name)
310 );
311 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
312 client.execute(&sql, &[]).await?;
313 Ok(())
314 }
315
316 pub async fn execute(&self, query: &CompiledQuery) -> Result<u64, MutationExecutorError> {
317 let mut args = PgArgs { values: Vec::new() };
318 for value in &query.params {
319 bind_pg(&mut args, value)?;
320 }
321 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
322 let result = client.execute(&query.sql, &args.as_refs()).await?;
323 Ok(result)
324 }
325
326 pub async fn fetch_all(
327 &self,
328 query: &CompiledQuery,
329 ) -> Result<Vec<Record>, MutationExecutorError> {
330 let mut args = PgArgs { values: Vec::new() };
331 for value in &query.params {
332 bind_pg(&mut args, value)?;
333 }
334 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
335 let rows = client.query(&query.sql, &args.as_refs()).await?;
336 rows.iter().map(decode_pg_row).collect()
337 }
338
339 async fn table_exists(&self, table_name: &str) -> Result<bool, MutationExecutorError> {
340 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
341 let row = client.query_one(
342 "SELECT COUNT(1)
343 FROM information_schema.tables
344 WHERE table_schema = current_schema()
345 AND table_name = $1",
346 &[&table_name],
347 ).await?;
348 let exists: i64 = row.try_get(0)?;
349 Ok(exists > 0)
350 }
351
352 async fn table_columns(
353 &self,
354 table_name: &str,
355 ) -> Result<std::collections::BTreeSet<String>, MutationExecutorError> {
356 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
357 let rows = client.query(
358 "SELECT column_name
359 FROM information_schema.columns
360 WHERE table_schema = current_schema()
361 AND table_name = $1",
362 &[&table_name],
363 ).await?;
364 let mut columns = std::collections::BTreeSet::new();
365 for row in rows {
366 let name: String = row.try_get("column_name")?;
367 columns.insert(name.to_lowercase());
368 }
369 Ok(columns)
370 }
371}
372
373async fn ensure_initial_graphs_postgres(
374 executor: &PgMutationExecutor,
375 dialect: &PostgresDialect,
376 ctx: &UserContext,
377) -> Result<(), MutationExecutorError> {
378 for graph in ctx.initial_graphs() {
379 let entity = ctx.entity(&graph.entity).ok_or_else(|| {
380 MutationExecutorError::Bind(format!("missing entity: {}", graph.entity))
381 })?;
382 if initial_graph_exists_postgres(executor, dialect, entity, graph).await? {
383 if let Some(query) = compile_initial_graph_update(dialect, entity, graph)? {
384 executor.execute(&query).await?;
385 }
386 } else {
387 let query = compile_initial_graph_insert(dialect, entity, graph)?;
388 executor.execute(&query).await?;
389 }
390 }
391 Ok(())
392}
393
394async fn initial_graph_exists_postgres(
395 executor: &PgMutationExecutor,
396 dialect: &PostgresDialect,
397 entity: &EntityDescriptor,
398 graph: &GraphNode,
399) -> Result<bool, MutationExecutorError> {
400 let Some(id) = graph.values.get("id") else {
401 return Ok(false);
402 };
403 let query = dialect.compile_select(
404 entity,
405 &SelectQuery::new(&graph.entity)
406 .project("id")
407 .filter(Expr::eq("id", id.clone()))
408 .limit(1),
409 )?;
410 Ok(!executor.fetch_all(&query).await?.is_empty())
411}
412
413fn compile_initial_graph_insert(
414 dialect: &impl SqlDialect,
415 entity: &EntityDescriptor,
416 graph: &GraphNode,
417) -> Result<CompiledQuery, MutationExecutorError> {
418 let mut command = InsertCommand::new(&graph.entity);
419 for (field, value) in &graph.values {
420 command = command.value(field.clone(), value.clone());
421 }
422 dialect.compile_insert(entity, &command).map_err(Into::into)
423}
424
425fn compile_initial_graph_update(
426 dialect: &impl SqlDialect,
427 entity: &EntityDescriptor,
428 graph: &crate::GraphNode,
429) -> Result<Option<CompiledQuery>, MutationExecutorError> {
430 let Some(id) = graph.values.get("id") else {
431 return Ok(None);
432 };
433 let mut command = UpdateCommand::new(&graph.entity, id.clone());
434 for (field, value) in &graph.values {
435 if field == "id" {
436 continue;
437 }
438 command = command.value(field.clone(), value.clone());
439 }
440 match dialect.compile_update(entity, &command) {
441 Ok(query) => Ok(Some(query)),
442 Err(SqlCompileError::EmptyMutation(_)) => Ok(None),
443 Err(err) => Err(err.into()),
444 }
445}
446
/// Extension trait for running PostgreSQL schema setup from a context object.
pub trait PostgresSchemaExt {
    /// Returns a boxed future performing the schema setup. The `UserContext`
    /// implementation delegates to `ensure_postgres_schema_for`.
    fn ensure_postgres_schema(
        &self,
    ) -> Pin<Box<dyn Future<Output = Result<(), MutationExecutorError>> + '_>>;
}
452
453pub async fn ensure_postgres_schema_for(ctx: &UserContext) -> Result<(), MutationExecutorError> {
454 let dialect = ctx.get_resource::<PostgresDialect>().ok_or_else(|| {
455 MutationExecutorError::Bind("missing typed resource: PostgresDialect".to_owned())
456 })?;
457 let executor = ctx.get_resource::<PgMutationExecutor>().ok_or_else(|| {
458 MutationExecutorError::Bind("missing typed resource: PgMutationExecutor".to_owned())
459 })?;
460
461 let entities = ctx.all_entities();
462
463 executor.ensure_schema(dialect, &entities).await?;
464 ensure_initial_graphs_postgres(executor, dialect, ctx).await
465}
466
467impl PostgresSchemaExt for UserContext {
468 fn ensure_postgres_schema(
469 &self,
470 ) -> Pin<Box<dyn Future<Output = Result<(), MutationExecutorError>> + '_>> {
471 Box::pin(ensure_postgres_schema_for(self))
472 }
473}
474
/// [`SchemaProvider`] that runs the PostgreSQL schema setup for a context.
#[derive(Debug, Default, Clone, Copy)]
pub struct PostgresSchemaProvider;
477
478impl SchemaProvider for PostgresSchemaProvider {
479 fn ensure_schema<'a>(
480 &'a self,
481 ctx: &'a UserContext,
482 ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>> {
483 Box::pin(async move {
484 ensure_postgres_schema_for(ctx)
485 .await
486 .map_err(|err| RuntimeError::Schema(err.to_string()))
487 })
488 }
489}
490
/// Extension trait that wires the PostgreSQL backend into a context object.
pub trait PostgresProviderExt {
    /// Registers PostgreSQL support using `executor` and returns `self` for chaining.
    /// The `UserContext` implementation inserts the dialect and executor as typed
    /// resources and installs `PostgresSchemaProvider`.
    fn use_postgres_provider(&mut self, executor: PgMutationExecutor) -> &mut Self;
}
494
495impl PostgresProviderExt for UserContext {
496 fn use_postgres_provider(&mut self, executor: PgMutationExecutor) -> &mut Self {
497 self.insert_resource(PostgresDialect);
498 self.insert_resource(executor);
499 self.set_schema_provider(PostgresSchemaProvider);
500 self
501 }
502}
503
/// Hands out sequential ids per entity type from a counter table in PostgreSQL.
///
/// Each entity type name owns one `(type_name, current_level)` row in
/// `table_name`, which defaults to [`DEFAULT_ID_SPACE_TABLE`].
#[derive(Clone)]
pub struct PgIdSpaceGenerator {
    /// Pool used for every counter operation.
    pool: Pool,
    /// Counter table name; passed through `quote_ident` when used in SQL.
    table_name: String,
}
509
510impl PgIdSpaceGenerator {
511 pub fn new(pool: Pool) -> Self {
512 Self {
513 pool,
514 table_name: DEFAULT_ID_SPACE_TABLE.to_owned(),
515 }
516 }
517
518 pub fn from_executor(executor: PgMutationExecutor) -> Self {
519 Self::new(executor.pool())
520 }
521
522 pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
523 self.table_name = table_name.into();
524 self
525 }
526
527 pub async fn ensure_table(&self) -> Result<(), MutationExecutorError> {
528 PgMutationExecutor::new(self.pool.clone())
529 .ensure_id_space_table(&self.table_name)
530 .await
531 }
532
533 pub async fn next_id(&self, entity: &str) -> Result<u64, MutationExecutorError> {
534 self.ensure_table().await?;
535 let update_sql = format!(
536 "UPDATE {} SET current_level = current_level + 1 WHERE type_name = $1 RETURNING current_level",
537 quote_ident(&self.table_name)
538 );
539 let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
540 let row = client.query_opt(&update_sql, &[&entity]).await?;
541
542 let id = match row {
543 Some(r) => {
544 let level: i64 = r.try_get(0)?;
545 level
546 },
547 None => {
548 let insert_sql = format!(
549 "INSERT INTO {} (type_name, current_level) VALUES ($1, 1) RETURNING current_level",
550 quote_ident(&self.table_name)
551 );
552 let insert_res = client.query_one(&insert_sql, &[&entity]).await;
553 match insert_res {
554 Ok(r) => {
555 let level: i64 = r.try_get(0)?;
556 level
557 },
558 Err(_) => {
559 let row = client.query_one(&update_sql, &[&entity]).await?;
560 let level: i64 = row.try_get(0)?;
561 level
562 }
563 }
564 }
565 };
566
567 u64::try_from(id).map_err(|_| {
568 MutationExecutorError::Bind(format!("generated id {id} cannot be represented as u64"))
569 })
570 }
571}
572
573impl InternalIdGenerator for PgIdSpaceGenerator {
574 fn generate_id(&self, entity: &str) -> Result<u64, RuntimeError> {
575 let generator = self.clone();
576 let entity = entity.to_owned();
577 block_on_id_generation(async move { generator.next_id(&entity).await })
578 }
579}
580
581fn block_on_id_generation<F>(future: F) -> Result<u64, RuntimeError>
582where
583 F: Future<Output = Result<u64, MutationExecutorError>> + Send + 'static,
584{
585 let result = if tokio::runtime::Handle::try_current().is_ok() {
586 let handle = tokio::runtime::Handle::current();
587 tokio::task::block_in_place(|| handle.block_on(future))
588 } else {
589 tokio::runtime::Builder::new_current_thread()
590 .enable_all()
591 .build()
592 .map_err(|err| RuntimeError::IdGeneration(err.to_string()))?
593 .block_on(future)
594 };
595 result.map_err(|err| RuntimeError::IdGeneration(err.to_string()))
596}
597
598fn quote_ident(ident: &str) -> String {
599 quote_identifier_if_needed(ident, '"')
600}
601
/// Removes one matching pair of identifier quotes: `"x"`, `` `x` `` or `[x]`.
///
/// Anything else, including mismatched or single quote characters, is returned
/// unchanged. Inner escaped quotes are not unescaped.
fn strip_identifier_quotes(ident: &str) -> &str {
    const QUOTE_PAIRS: [(char, char); 3] = [('"', '"'), ('`', '`'), ('[', ']')];
    for (open, close) in QUOTE_PAIRS {
        // Stripping the prefix first guarantees the two quotes never overlap.
        let inner = ident
            .strip_prefix(open)
            .and_then(|rest| rest.strip_suffix(close));
        if let Some(inner) = inner {
            return inner;
        }
    }
    ident
}
618
619fn try_parse_datetime_from_str(s: &str) -> Option<chrono::DateTime<chrono::Utc>> {
620 if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(s) {
621 return Some(dt.with_timezone(&chrono::Utc));
622 }
623 if let Ok(ndt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
624 return Some(chrono::DateTime::from_naive_utc_and_offset(ndt, chrono::Utc));
625 }
626 if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
627 let ndt = nd.and_hms_opt(0, 0, 0)?;
628 return Some(chrono::DateTime::from_naive_utc_and_offset(ndt, chrono::Utc));
629 }
630 None
631}
632
633struct PgArgs {
634 values: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>>,
635}
636impl PgArgs {
637 fn add<T: tokio_postgres::types::ToSql + Sync + Send + 'static>(&mut self, v: T) {
638 self.values.push(Box::new(v));
639 }
640 fn as_refs(&self) -> Vec<&(dyn tokio_postgres::types::ToSql + Sync)> {
641 self.values.iter().map(|b| b.as_ref() as _).collect()
642 }
643}
644
645fn bind_pg(args: &mut PgArgs, value: &Value) -> Result<(), MutationExecutorError> {
646 match value {
647 Value::Null => {
648 args.add(Option::<i32>::None);
649 }
650 Value::Bool(v) => args.add(*v),
651 Value::I64(v) => args.add(*v),
652 Value::U64(v) => {
653 let v = i64::try_from(*v).map_err(|_| {
654 MutationExecutorError::Bind(format!("u64 value {v} exceeds i64 range"))
655 })?;
656 args.add(v);
657 }
658 Value::F64(v) => args.add(*v),
659 Value::Decimal(v) => args.add(*v),
660 Value::Text(v) => {
661 if let Some(dt) = try_parse_datetime_from_str(v) {
662 args.add(dt);
663 } else {
664 args.add(v.clone());
665 }
666 }
667 Value::Json(v) => {
668 let j_val: serde_json::Value = serde_json::to_value(v).map_err(|e| MutationExecutorError::Bind(e.to_string()))?;
669 args.add(j_val);
670 }
671 Value::Date(v) => args.add(*v),
672 Value::Timestamp(v) => args.add(*v),
673 Value::Object(_) => return Err(MutationExecutorError::UnsupportedValue("object")),
674 Value::List(values) => bind_pg_list(args, values)?,
675 }
676 Ok(())
677}
678
679fn bind_pg_list(args: &mut PgArgs, values: &[Value]) -> Result<(), MutationExecutorError> {
680 let Some(first) = values.first() else {
681 return Err(MutationExecutorError::UnsupportedValue("empty list"));
682 };
683 match first {
684 Value::Bool(_) => {
685 let values = values
686 .iter()
687 .map(|value| match value {
688 Value::Bool(value) => Ok(*value),
689 _ => Err(MutationExecutorError::UnsupportedValue("mixed bool list")),
690 })
691 .collect::<Result<Vec<_>, _>>()?;
692 args.add(values);
693 }
694 Value::I64(_) => {
695 let values = values
696 .iter()
697 .map(|value| match value {
698 Value::I64(value) => Ok(*value),
699 _ => Err(MutationExecutorError::UnsupportedValue("mixed i64 list")),
700 })
701 .collect::<Result<Vec<_>, _>>()?;
702 args.add(values);
703 }
704 Value::U64(_) => {
705 let values = values
706 .iter()
707 .map(|value| match value {
708 Value::U64(value) => i64::try_from(*value).map_err(|_| {
709 MutationExecutorError::Bind(format!("u64 value {value} exceeds i64 range"))
710 }),
711 _ => Err(MutationExecutorError::UnsupportedValue("mixed u64 list")),
712 })
713 .collect::<Result<Vec<_>, _>>()?;
714 args.add(values);
715 }
716 Value::F64(_) => {
717 let values = values
718 .iter()
719 .map(|value| match value {
720 Value::F64(value) => Ok(*value),
721 _ => Err(MutationExecutorError::UnsupportedValue("mixed f64 list")),
722 })
723 .collect::<Result<Vec<_>, _>>()?;
724 args.add(values);
725 }
726 Value::Decimal(_) => {
727 let values = values
728 .iter()
729 .map(|value| match value {
730 Value::Decimal(value) => Ok(*value),
731 _ => Err(MutationExecutorError::UnsupportedValue(
732 "mixed decimal list",
733 )),
734 })
735 .collect::<Result<Vec<_>, _>>()?;
736 args.add(values);
737 }
738 Value::Text(_) => {
739 let values = values
740 .iter()
741 .map(|value| match value {
742 Value::Text(value) => Ok(value.clone()),
743 _ => Err(MutationExecutorError::UnsupportedValue("mixed text list")),
744 })
745 .collect::<Result<Vec<_>, _>>()?;
746 args.add(values);
747 }
748 Value::Date(_) => {
749 let values = values
750 .iter()
751 .map(|value| match value {
752 Value::Date(value) => Ok(*value),
753 _ => Err(MutationExecutorError::UnsupportedValue("mixed date list")),
754 })
755 .collect::<Result<Vec<_>, _>>()?;
756 args.add(values);
757 }
758 Value::Timestamp(_) => {
759 let values = values
760 .iter()
761 .map(|value| match value {
762 Value::Timestamp(value) => Ok(*value),
763 _ => Err(MutationExecutorError::UnsupportedValue(
764 "mixed timestamp list",
765 )),
766 })
767 .collect::<Result<Vec<_>, _>>()?;
768 args.add(values);
769 }
770 Value::Null => return Err(MutationExecutorError::UnsupportedValue("null list")),
771 Value::Json(_) => return Err(MutationExecutorError::UnsupportedValue("json list")),
772 Value::Object(_) => return Err(MutationExecutorError::UnsupportedValue("object list")),
773 Value::List(_) => return Err(MutationExecutorError::UnsupportedValue("nested list")),
774 }
775 Ok(())
776}
777
778fn decode_pg_row(row: &tokio_postgres::Row) -> Result<Record, MutationExecutorError> {
779 let mut record = BTreeMap::new();
780 for (index, column) in row.columns().iter().enumerate() {
781 let name = column.name().to_owned();
782 let type_name = column.type_().name().to_ascii_uppercase();
783
784 let value = match type_name.as_str() {
785 "BOOL" | "BOOLEAN" => {
786 let v: Option<bool> = row.try_get(index)?;
787 match v {
788 Some(v) => Value::Bool(v),
789 None => Value::Null,
790 }
791 }
792 "INT2" => {
793 let v: Option<i16> = row.try_get(index)?;
794 match v {
795 Some(v) => Value::I64(v as i64),
796 None => Value::Null,
797 }
798 }
799 "INT4" => {
800 let v: Option<i32> = row.try_get(index)?;
801 match v {
802 Some(v) => Value::I64(v as i64),
803 None => Value::Null,
804 }
805 }
806 "INT8" => {
807 let v: Option<i64> = row.try_get(index)?;
808 match v {
809 Some(v) => Value::I64(v),
810 None => Value::Null,
811 }
812 }
813 "FLOAT4" => {
814 let v: Option<f32> = row.try_get(index)?;
815 match v {
816 Some(v) => Value::F64(v as f64),
817 None => Value::Null,
818 }
819 }
820 "FLOAT8" => {
821 let v: Option<f64> = row.try_get(index)?;
822 match v {
823 Some(v) => Value::F64(v),
824 None => Value::Null,
825 }
826 }
827 "NUMERIC" => {
828 let v: Option<Decimal> = row.try_get(index)?;
829 match v {
830 Some(v) => Value::Decimal(v),
831 None => Value::Null,
832 }
833 }
834 "JSON" | "JSONB" => {
835 let v: Option<serde_json::Value> = row.try_get(index)?;
836 match v {
837 Some(j) => Value::Json(j.into()),
838 None => Value::Null,
839 }
840 }
841 "DATE" => {
842 let v: Option<NaiveDate> = row.try_get(index)?;
843 match v {
844 Some(v) => Value::Date(v),
845 None => Value::Null,
846 }
847 }
848 "TIMESTAMP" | "TIMESTAMPTZ" => {
849 let v: Option<DateTime<Utc>> = row.try_get(index)?;
850 match v {
851 Some(v) => Value::Timestamp(v),
852 None => Value::Null,
853 }
854 }
855 "TEXT" | "VARCHAR" | "BPCHAR" | "NAME" | "UUID" => {
856 let v: Option<String> = row.try_get(index)?;
857 match v {
858 Some(v) => Value::Text(v),
859 None => Value::Null,
860 }
861 }
862 other => {
863 return Err(MutationExecutorError::UnsupportedColumnType(
864 other.to_owned(),
865 ));
866 }
867 };
868 record.insert(name, value);
869 }
870 Ok(record)
871}
872
// SQL-generation tests for `PostgresDialect`. They only compile statements and
// never open a database connection.
#[cfg(test)]
mod tests {
    use super::*;
    use teaql_core::{DeleteCommand, RecoverCommand};

    /// Fixture: entity `Order` on table `orders`, with an `id` key column, a
    /// `version` column and an optional `name` text column.
    fn entity() -> EntityDescriptor {
        EntityDescriptor::new("Order")
            .table_name("orders")
            .property(
                PropertyDescriptor::new("id", DataType::U64)
                    .column_name("id")
                    .id()
                    .not_null(),
            )
            .property(
                PropertyDescriptor::new("version", DataType::I64)
                    .column_name("version")
                    .version()
                    .not_null(),
            )
            .property(PropertyDescriptor::new("name", DataType::Text).column_name("name"))
    }

    // Every mutation must number its `$n` placeholders in the order they appear in the SQL.
    #[test]
    fn postgres_dialect_compiles_mutations_with_numbered_placeholders() {
        let insert = PostgresDialect
            .compile_insert(
                &entity(),
                &InsertCommand::new("Order")
                    .value("id", 1_u64)
                    .value("name", "A"),
            )
            .unwrap();
        assert_eq!(
            insert.sql,
            "INSERT INTO orders (id, name) VALUES ($1, $2)"
        );

        // The update writes a new `version` and also requires the expected one in WHERE.
        let update = PostgresDialect
            .compile_update(
                &entity(),
                &UpdateCommand::new("Order", 1_u64)
                    .expected_version(3)
                    .value("name", "B"),
            )
            .unwrap();
        assert_eq!(
            update.sql,
            "UPDATE orders SET name = $1, version = $2 WHERE id = $3 AND version = $4"
        );

        // Delete and recover both compile to version-guarded UPDATEs (see expected SQL).
        let delete = PostgresDialect
            .compile_delete(
                &entity(),
                &DeleteCommand::new("Order", 1_u64).expected_version(3),
            )
            .unwrap();
        let recover = PostgresDialect
            .compile_recover(&entity(), &RecoverCommand::new("Order", 1_u64, -4))
            .unwrap();
        assert_eq!(
            delete.sql,
            "UPDATE orders SET version = $1 WHERE id = $2 AND version = $3"
        );
        assert_eq!(
            recover.sql,
            "UPDATE orders SET version = $1 WHERE id = $2 AND version = $3"
        );
    }

    // Covers three things: CREATE TABLE output, the soundex setup statement, and an
    // `in_large` filter compiling to `= ANY($1)` with the whole list bound as ONE parameter.
    #[test]
    fn postgres_dialect_compiles_schema_and_large_in_array_binds() {
        let create = PostgresDialect.compile_create_table(&entity()).unwrap();
        assert_eq!(
            create,
            "CREATE TABLE IF NOT EXISTS orders (id BIGINT PRIMARY KEY NOT NULL, version BIGINT NOT NULL, name TEXT)"
        );
        assert!(
            PostgresDialect
                .schema_setup_sqls()
                .iter()
                .any(|sql| sql.contains("CREATE OR REPLACE FUNCTION soundex"))
        );

        let query = PostgresDialect
            .compile_select(
                &entity(),
                &SelectQuery::new("Order")
                    .filter(Expr::in_large(
                        "id",
                        vec![Value::from(1_u64), Value::from(2_u64)],
                    ))
                    .order_asc("id"),
            )
            .unwrap();
        assert_eq!(
            query.sql,
            "SELECT id, version, name FROM orders WHERE (id = ANY($1)) ORDER BY id ASC"
        );
        assert_eq!(
            query.params,
            vec![Value::List(vec![Value::U64(1), Value::U64(2)])]
        );
    }
}