Skip to main content

teaql_provider_postgres/
lib.rs

1#![allow(warnings)]
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::pin::Pin;
5
6use chrono::{DateTime, NaiveDate, Utc};
7use rust_decimal::Decimal;
8use std::sync::Arc;
9use teaql_core::{
10    BinaryOp, DataType, EntityDescriptor, Expr, InsertCommand, PropertyDescriptor, Record,
11    SelectQuery, UpdateCommand, Value,
12};
13use teaql_runtime::{GraphNode, InternalIdGenerator, RuntimeError, SchemaProvider, UserContext};
14use teaql_sql::{
15    CompiledQuery, DatabaseKind, SqlCompileError, SqlDialect, SqlTransport, quote_identifier_if_needed,
16};
17use tokio::sync::Mutex;
18use deadpool_postgres::Pool;
19
/// Counter table used by [`PgIdSpaceGenerator::new`] and created by
/// [`PgMutationExecutor::ensure_schema`] when no custom table name is configured.
pub const DEFAULT_ID_SPACE_TABLE: &str = "teaql_id_space";

/// PostgreSQL flavour of [`SqlDialect`]: `$n` placeholders, double-quoted identifiers,
/// Postgres column types and array-based large `IN` lists.
#[derive(Clone, Copy, Debug, Default)]
pub struct PostgresDialect;
24
25impl SqlDialect for PostgresDialect {
26    fn kind(&self) -> DatabaseKind {
27        DatabaseKind::PostgreSql
28    }
29
30    fn quote_ident(&self, ident: &str) -> String {
31        quote_ident(ident)
32    }
33
34    fn placeholder(&self, index: usize) -> String {
35        format!("${index}")
36    }
37
38    fn schema_setup_sqls(&self) -> &'static [&'static str] {
39        &[CREATE_SOUNDEX_FUNCTION]
40    }
41
42    fn schema_type_sql(
43        &self,
44        data_type: DataType,
45        _property: &PropertyDescriptor,
46    ) -> Result<&'static str, SqlCompileError> {
47        match data_type {
48            DataType::Bool => Ok("BOOLEAN"),
49            DataType::I64 | DataType::U64 => Ok("BIGINT"),
50            DataType::F64 => Ok("DOUBLE PRECISION"),
51            DataType::Decimal => Ok("NUMERIC"),
52            DataType::Text => Ok("TEXT"),
53            DataType::Json => Ok("JSONB"),
54            DataType::Date => Ok("DATE"),
55            DataType::Timestamp => Ok("TIMESTAMPTZ"),
56        }
57    }
58
59    fn compile_in(
60        &self,
61        entity: &EntityDescriptor,
62        left: &Expr,
63        op: BinaryOp,
64        right: &Expr,
65        params: &mut Vec<Value>,
66    ) -> Result<String, SqlCompileError> {
67        match op {
68            BinaryOp::InLarge | BinaryOp::NotInLarge => {
69                let Expr::Value(Value::List(values)) = right else {
70                    let lhs = self.compile_expr(entity, left, params)?;
71                    let rhs = self.compile_expr(entity, right, params)?;
72                    let operator = match op {
73                        BinaryOp::InLarge => "= ANY",
74                        BinaryOp::NotInLarge => "<> ALL",
75                        _ => unreachable!(),
76                    };
77                    return Ok(format!("({lhs} {operator} ({rhs}))"));
78                };
79                if values.is_empty() {
80                    return Err(SqlCompileError::EmptyInList);
81                }
82                let lhs = self.compile_expr(entity, left, params)?;
83                params.push(Value::List(values.clone()));
84                let placeholder = self.placeholder(params.len());
85                let operator = match op {
86                    BinaryOp::InLarge => "= ANY",
87                    BinaryOp::NotInLarge => "<> ALL",
88                    _ => unreachable!(),
89                };
90                Ok(format!("({lhs} {operator}({placeholder}))"))
91            }
92            _ => {
93                let lhs = self.compile_expr(entity, left, params)?;
94                let operator = match op {
95                    BinaryOp::In => "IN",
96                    BinaryOp::NotIn => "NOT IN",
97                    _ => unreachable!(),
98                };
99                match right {
100                    Expr::Value(Value::List(values)) => {
101                        if values.is_empty() {
102                            return Err(SqlCompileError::EmptyInList);
103                        }
104                        let mut placeholders = Vec::with_capacity(values.len());
105                        for value in values {
106                            params.push(value.clone());
107                            placeholders.push(self.placeholder(params.len()));
108                        }
109                        Ok(format!("({lhs} {operator} ({}))", placeholders.join(", ")))
110                    }
111                    _ => {
112                        let rhs = self.compile_expr(entity, right, params)?;
113                        Ok(format!("({lhs} {operator} ({rhs}))"))
114                    }
115                }
116            }
117        }
118    }
119}
120
/// PL/pgSQL definition of `soundex(text) RETURNS text`, returned by
/// `PostgresDialect::schema_setup_sqls` and executed by `PgMutationExecutor::ensure_schema`.
///
/// Algorithm as written below:
/// * the input is upper-cased after every character outside `A-Z`/`a-z` is removed;
///   if nothing is left the result is `0000`;
/// * the first letter is kept verbatim and its digit code seeds `previous_code`;
/// * each following letter maps to `BFPV`=1, `CGJKQSXZ`=2, `DT`=3, `L`=4, `MN`=5, `R`=6,
///   everything else (vowels, `H`, `W`, `Y`) = 0; a non-zero code different from
///   `previous_code` is appended;
/// * `previous_code` is updated after EVERY letter, including `0`-coded ones, so a code
///   repeated across an `H`/`W` is emitted twice (e.g. `Ashcraft` → `A226`, whereas classic
///   American Soundex yields `A261`);
/// * output stops at four characters and is right-padded with `0`.
///
/// `STRICT` makes the function return NULL for NULL input; `CREATE OR REPLACE` lets the
/// statement be re-run on every schema check.
///
/// NOTE(review): if the `fuzzystrmatch` extension (which also defines `soundex(text)`) is
/// installed in the same schema, this statement may conflict with it — confirm.
const CREATE_SOUNDEX_FUNCTION: &str = r#"
CREATE OR REPLACE FUNCTION soundex(input text)
RETURNS text
LANGUAGE plpgsql
IMMUTABLE
STRICT
AS $$
DECLARE
    normalized text := upper(regexp_replace(input, '[^A-Za-z]', '', 'g'));
    first_char text;
    output text;
    previous_code text;
    code text;
    ch text;
    i integer;
BEGIN
    IF normalized = '' THEN
        RETURN '0000';
    END IF;

    first_char := substr(normalized, 1, 1);
    output := first_char;
    previous_code := CASE
        WHEN first_char IN ('B', 'F', 'P', 'V') THEN '1'
        WHEN first_char IN ('C', 'G', 'J', 'K', 'Q', 'S', 'X', 'Z') THEN '2'
        WHEN first_char IN ('D', 'T') THEN '3'
        WHEN first_char = 'L' THEN '4'
        WHEN first_char IN ('M', 'N') THEN '5'
        WHEN first_char = 'R' THEN '6'
        ELSE '0'
    END;

    FOR i IN 2..char_length(normalized) LOOP
        ch := substr(normalized, i, 1);
        code := CASE
            WHEN ch IN ('B', 'F', 'P', 'V') THEN '1'
            WHEN ch IN ('C', 'G', 'J', 'K', 'Q', 'S', 'X', 'Z') THEN '2'
            WHEN ch IN ('D', 'T') THEN '3'
            WHEN ch = 'L' THEN '4'
            WHEN ch IN ('M', 'N') THEN '5'
            WHEN ch = 'R' THEN '6'
            ELSE '0'
        END;

        IF code <> '0' AND code <> previous_code THEN
            output := output || code;
            IF char_length(output) = 4 THEN
                RETURN output;
            END IF;
        END IF;
        previous_code := code;
    END LOOP;

    RETURN rpad(output, 4, '0');
END;
$$
"#;
178
179#[derive(Debug)]
180pub enum MutationExecutorError {
181    Driver(tokio_postgres::Error),
182    Pool(String),
183    SqlCompile(SqlCompileError),
184    UnsupportedValue(&'static str),
185    UnsupportedColumnType(String),
186    Bind(String),
187}
188
189impl std::fmt::Display for MutationExecutorError {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        match self {
192            Self::Driver(err) => err.fmt(f),
193            Self::Pool(err) => write!(f, "postgres pool error: {err}"),
194            Self::SqlCompile(err) => err.fmt(f),
195            Self::UnsupportedValue(kind) => {
196                write!(
197                    f,
198                    "unsupported bind value for mutation executor: {kind}"
199                )
200            }
201            Self::UnsupportedColumnType(kind) => {
202                write!(
203                    f,
204                    "unsupported column type for record decoding: {kind}"
205                )
206            }
207            Self::Bind(message) => write!(f, "bind error: {message}"),
208        }
209    }
210}
211
212impl std::error::Error for MutationExecutorError {}
213
214impl From<tokio_postgres::Error> for MutationExecutorError {
215    fn from(value: tokio_postgres::Error) -> Self {
216        Self::Driver(value)
217    }
218}
219
220impl From<SqlCompileError> for MutationExecutorError {
221    fn from(value: SqlCompileError) -> Self {
222        Self::SqlCompile(value)
223    }
224}
225
226#[derive(Clone)]
227pub struct PgMutationExecutor {
228    pool: Pool,
229}
230
231impl SqlTransport for PgMutationExecutor {
232    type Error = MutationExecutorError;
233
234    async fn fetch_all_sql(&self, query: &CompiledQuery) -> Result<Vec<Record>, Self::Error> {
235        self.fetch_all(query).await
236    }
237
238    async fn execute_sql(&self, query: &CompiledQuery) -> Result<u64, Self::Error> {
239        self.execute(query).await
240    }
241}
242
243impl teaql_sql::SqlTransaction for PgMutationExecutor {
244    type Error = MutationExecutorError;
245
246    async fn commit_sql(self) -> Result<(), Self::Error> {
247        Err(MutationExecutorError::Bind("Transactions not supported yet".to_string()))
248    }
249
250    async fn rollback_sql(self) -> Result<(), Self::Error> {
251        Err(MutationExecutorError::Bind("Transactions not supported yet".to_string()))
252    }
253}
254
255impl teaql_sql::SqlTransactionTransport for PgMutationExecutor {
256    type Tx<'a> = Self where Self: 'a;
257
258    async fn begin_sql(&self) -> Result<Self::Tx<'_>, Self::Error> {
259        Err(MutationExecutorError::Bind("Transactions not supported yet".to_string()))
260    }
261}
262
263impl PgMutationExecutor {
264    pub fn new(pool: Pool) -> Self {
265        Self { pool }
266    }
267
268    pub fn pool(&self) -> Pool {
269        self.pool.clone()
270    }
271
272    pub async fn ensure_schema(
273        &self,
274        dialect: &PostgresDialect,
275        entities: &[&EntityDescriptor],
276    ) -> Result<(), MutationExecutorError> {
277        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
278        for sql in dialect.schema_setup_sqls() {
279            client.execute(*sql, &[]).await?;
280        }
281        self.ensure_id_space_table(DEFAULT_ID_SPACE_TABLE).await?;
282
283        for entity in entities {
284            if !self.table_exists(&entity.table_name).await? {
285                let sql = dialect.compile_create_table(entity)?;
286                client.execute(&sql, &[]).await?;
287                continue;
288            }
289
290            let existing_columns = self.table_columns(&entity.table_name).await?;
291            for property in &entity.properties {
292                let bare_column = strip_identifier_quotes(&property.column_name).to_lowercase();
293                if existing_columns.contains(&bare_column) {
294                    continue;
295                }
296                let sql = dialect.compile_add_column(entity, property)?;
297                client.execute(&sql, &[]).await?;
298            }
299        }
300        Ok(())
301    }
302
303    pub async fn ensure_id_space_table(
304        &self,
305        table_name: &str,
306    ) -> Result<(), MutationExecutorError> {
307        let sql = format!(
308            "CREATE TABLE IF NOT EXISTS {} (type_name VARCHAR(100) PRIMARY KEY, current_level BIGINT NOT NULL)",
309            quote_ident(table_name)
310        );
311        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
312        client.execute(&sql, &[]).await?;
313        Ok(())
314    }
315
316    pub async fn execute(&self, query: &CompiledQuery) -> Result<u64, MutationExecutorError> {
317        let mut args = PgArgs { values: Vec::new() };
318        for value in &query.params {
319            bind_pg(&mut args, value)?;
320        }
321        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
322        let result = client.execute(&query.sql, &args.as_refs()).await?;
323        Ok(result)
324    }
325
326    pub async fn fetch_all(
327        &self,
328        query: &CompiledQuery,
329    ) -> Result<Vec<Record>, MutationExecutorError> {
330        let mut args = PgArgs { values: Vec::new() };
331        for value in &query.params {
332            bind_pg(&mut args, value)?;
333        }
334        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
335        let rows = client.query(&query.sql, &args.as_refs()).await?;
336        rows.iter().map(decode_pg_row).collect()
337    }
338
339    async fn table_exists(&self, table_name: &str) -> Result<bool, MutationExecutorError> {
340        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
341        let row = client.query_one(
342            "SELECT COUNT(1)
343             FROM information_schema.tables
344             WHERE table_schema = current_schema()
345               AND table_name = $1",
346            &[&table_name],
347        ).await?;
348        let exists: i64 = row.try_get(0)?;
349        Ok(exists > 0)
350    }
351
352    async fn table_columns(
353        &self,
354        table_name: &str,
355    ) -> Result<std::collections::BTreeSet<String>, MutationExecutorError> {
356        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
357        let rows = client.query(
358            "SELECT column_name
359             FROM information_schema.columns
360             WHERE table_schema = current_schema()
361               AND table_name = $1",
362            &[&table_name],
363        ).await?;
364        let mut columns = std::collections::BTreeSet::new();
365        for row in rows {
366            let name: String = row.try_get("column_name")?;
367            columns.insert(name.to_lowercase());
368        }
369        Ok(columns)
370    }
371}
372
373async fn ensure_initial_graphs_postgres(
374    executor: &PgMutationExecutor,
375    dialect: &PostgresDialect,
376    ctx: &UserContext,
377) -> Result<(), MutationExecutorError> {
378    for graph in ctx.initial_graphs() {
379        let entity = ctx.entity(&graph.entity).ok_or_else(|| {
380            MutationExecutorError::Bind(format!("missing entity: {}", graph.entity))
381        })?;
382        if initial_graph_exists_postgres(executor, dialect, entity, graph).await? {
383            if let Some(query) = compile_initial_graph_update(dialect, entity, graph)? {
384                executor.execute(&query).await?;
385            }
386        } else {
387            let query = compile_initial_graph_insert(dialect, entity, graph)?;
388            executor.execute(&query).await?;
389        }
390    }
391    Ok(())
392}
393
394async fn initial_graph_exists_postgres(
395    executor: &PgMutationExecutor,
396    dialect: &PostgresDialect,
397    entity: &EntityDescriptor,
398    graph: &GraphNode,
399) -> Result<bool, MutationExecutorError> {
400    let Some(id) = graph.values.get("id") else {
401        return Ok(false);
402    };
403    let query = dialect.compile_select(
404        entity,
405        &SelectQuery::new(&graph.entity)
406            .project("id")
407            .filter(Expr::eq("id", id.clone()))
408            .limit(1),
409    )?;
410    Ok(!executor.fetch_all(&query).await?.is_empty())
411}
412
413fn compile_initial_graph_insert(
414    dialect: &impl SqlDialect,
415    entity: &EntityDescriptor,
416    graph: &GraphNode,
417) -> Result<CompiledQuery, MutationExecutorError> {
418    let mut command = InsertCommand::new(&graph.entity);
419    for (field, value) in &graph.values {
420        command = command.value(field.clone(), value.clone());
421    }
422    dialect.compile_insert(entity, &command).map_err(Into::into)
423}
424
425fn compile_initial_graph_update(
426    dialect: &impl SqlDialect,
427    entity: &EntityDescriptor,
428    graph: &crate::GraphNode,
429) -> Result<Option<CompiledQuery>, MutationExecutorError> {
430    let Some(id) = graph.values.get("id") else {
431        return Ok(None);
432    };
433    let mut command = UpdateCommand::new(&graph.entity, id.clone());
434    for (field, value) in &graph.values {
435        if field == "id" {
436            continue;
437        }
438        command = command.value(field.clone(), value.clone());
439    }
440    match dialect.compile_update(entity, &command) {
441        Ok(query) => Ok(Some(query)),
442        Err(SqlCompileError::EmptyMutation(_)) => Ok(None),
443        Err(err) => Err(err.into()),
444    }
445}
446
447pub trait PostgresSchemaExt {
448    fn ensure_postgres_schema(
449        &self,
450    ) -> Pin<Box<dyn Future<Output = Result<(), MutationExecutorError>> + '_>>;
451}
452
453pub async fn ensure_postgres_schema_for(ctx: &UserContext) -> Result<(), MutationExecutorError> {
454    let dialect = ctx.get_resource::<PostgresDialect>().ok_or_else(|| {
455        MutationExecutorError::Bind("missing typed resource: PostgresDialect".to_owned())
456    })?;
457    let executor = ctx.get_resource::<PgMutationExecutor>().ok_or_else(|| {
458        MutationExecutorError::Bind("missing typed resource: PgMutationExecutor".to_owned())
459    })?;
460
461    let entities = ctx.all_entities();
462
463    executor.ensure_schema(dialect, &entities).await?;
464    ensure_initial_graphs_postgres(executor, dialect, ctx).await
465}
466
467impl PostgresSchemaExt for UserContext {
468    fn ensure_postgres_schema(
469        &self,
470    ) -> Pin<Box<dyn Future<Output = Result<(), MutationExecutorError>> + '_>> {
471        Box::pin(ensure_postgres_schema_for(self))
472    }
473}
474
475#[derive(Debug, Default, Clone, Copy)]
476pub struct PostgresSchemaProvider;
477
478impl SchemaProvider for PostgresSchemaProvider {
479    fn ensure_schema<'a>(
480        &'a self,
481        ctx: &'a UserContext,
482    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>> {
483        Box::pin(async move {
484            ensure_postgres_schema_for(ctx)
485                .await
486                .map_err(|err| RuntimeError::Schema(err.to_string()))
487        })
488    }
489}
490
491pub trait PostgresProviderExt {
492    fn use_postgres_provider(&mut self, executor: PgMutationExecutor) -> &mut Self;
493}
494
495impl PostgresProviderExt for UserContext {
496    fn use_postgres_provider(&mut self, executor: PgMutationExecutor) -> &mut Self {
497        self.insert_resource(PostgresDialect);
498        self.insert_resource(executor);
499        self.set_schema_provider(PostgresSchemaProvider);
500        self
501    }
502}
503
504#[derive(Clone)]
505pub struct PgIdSpaceGenerator {
506    pool: Pool,
507    table_name: String,
508}
509
510impl PgIdSpaceGenerator {
511    pub fn new(pool: Pool) -> Self {
512        Self {
513            pool,
514            table_name: DEFAULT_ID_SPACE_TABLE.to_owned(),
515        }
516    }
517
518    pub fn from_executor(executor: PgMutationExecutor) -> Self {
519        Self::new(executor.pool())
520    }
521
522    pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
523        self.table_name = table_name.into();
524        self
525    }
526
527    pub async fn ensure_table(&self) -> Result<(), MutationExecutorError> {
528        PgMutationExecutor::new(self.pool.clone())
529            .ensure_id_space_table(&self.table_name)
530            .await
531    }
532
533    pub async fn next_id(&self, entity: &str) -> Result<u64, MutationExecutorError> {
534        self.ensure_table().await?;
535        let update_sql = format!(
536            "UPDATE {} SET current_level = current_level + 1 WHERE type_name = $1 RETURNING current_level",
537            quote_ident(&self.table_name)
538        );
539        let client = self.pool.get().await.map_err(|e| MutationExecutorError::Pool(e.to_string()))?;
540        let row = client.query_opt(&update_sql, &[&entity]).await?;
541        
542        let id = match row {
543            Some(r) => {
544                let level: i64 = r.try_get(0)?;
545                level
546            },
547            None => {
548                let insert_sql = format!(
549                    "INSERT INTO {} (type_name, current_level) VALUES ($1, 1) RETURNING current_level",
550                    quote_ident(&self.table_name)
551                );
552                let insert_res = client.query_one(&insert_sql, &[&entity]).await;
553                match insert_res {
554                    Ok(r) => {
555                        let level: i64 = r.try_get(0)?;
556                        level
557                    },
558                    Err(_) => {
559                        let row = client.query_one(&update_sql, &[&entity]).await?;
560                        let level: i64 = row.try_get(0)?;
561                        level
562                    }
563                }
564            }
565        };
566
567        u64::try_from(id).map_err(|_| {
568            MutationExecutorError::Bind(format!("generated id {id} cannot be represented as u64"))
569        })
570    }
571}
572
573impl InternalIdGenerator for PgIdSpaceGenerator {
574    fn generate_id(&self, entity: &str) -> Result<u64, RuntimeError> {
575        let generator = self.clone();
576        let entity = entity.to_owned();
577        block_on_id_generation(async move { generator.next_id(&entity).await })
578    }
579}
580
581fn block_on_id_generation<F>(future: F) -> Result<u64, RuntimeError>
582where
583    F: Future<Output = Result<u64, MutationExecutorError>> + Send + 'static,
584{
585    let result = if tokio::runtime::Handle::try_current().is_ok() {
586        let handle = tokio::runtime::Handle::current();
587        tokio::task::block_in_place(|| handle.block_on(future))
588    } else {
589        tokio::runtime::Builder::new_current_thread()
590            .enable_all()
591            .build()
592            .map_err(|err| RuntimeError::IdGeneration(err.to_string()))?
593            .block_on(future)
594    };
595    result.map_err(|err| RuntimeError::IdGeneration(err.to_string()))
596}
597
598fn quote_ident(ident: &str) -> String {
599    quote_identifier_if_needed(ident, '"')
600}
601
/// Removes one layer of wrapping identifier quotes (`"…"`, `` `…` `` or `[…]`) so bare
/// names from `information_schema` can be compared with possibly-quoted descriptor names.
///
/// Only a matching opening/closing pair is removed; anything else (including a lone quote
/// character or mismatched delimiters) is returned unchanged. Embedded escaped quotes are
/// not unescaped.
fn strip_identifier_quotes(ident: &str) -> &str {
    const DELIMITERS: [(char, char); 3] = [('"', '"'), ('`', '`'), ('[', ']')];
    DELIMITERS
        .iter()
        .find_map(|&(open, close)| ident.strip_prefix(open)?.strip_suffix(close))
        .unwrap_or(ident)
}
618
619fn try_parse_datetime_from_str(s: &str) -> Option<chrono::DateTime<chrono::Utc>> {
620    if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(s) {
621        return Some(dt.with_timezone(&chrono::Utc));
622    }
623    if let Ok(ndt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
624        return Some(chrono::DateTime::from_naive_utc_and_offset(ndt, chrono::Utc));
625    }
626    if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
627        let ndt = nd.and_hms_opt(0, 0, 0)?;
628        return Some(chrono::DateTime::from_naive_utc_and_offset(ndt, chrono::Utc));
629    }
630    None
631}
632
633struct PgArgs {
634    values: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>>,
635}
636impl PgArgs {
637    fn add<T: tokio_postgres::types::ToSql + Sync + Send + 'static>(&mut self, v: T) {
638        self.values.push(Box::new(v));
639    }
640    fn as_refs(&self) -> Vec<&(dyn tokio_postgres::types::ToSql + Sync)> {
641        self.values.iter().map(|b| b.as_ref() as _).collect()
642    }
643}
644
645fn bind_pg(args: &mut PgArgs, value: &Value) -> Result<(), MutationExecutorError> {
646    match value {
647        Value::Null => {
648            args.add(Option::<i32>::None);
649        }
650        Value::Bool(v) => args.add(*v),
651        Value::I64(v) => args.add(*v),
652        Value::U64(v) => {
653            let v = i64::try_from(*v).map_err(|_| {
654                MutationExecutorError::Bind(format!("u64 value {v} exceeds i64 range"))
655            })?;
656            args.add(v);
657        }
658        Value::F64(v) => args.add(*v),
659        Value::Decimal(v) => args.add(*v),
660        Value::Text(v) => {
661            if let Some(dt) = try_parse_datetime_from_str(v) {
662                args.add(dt);
663            } else {
664                args.add(v.clone());
665            }
666        }
667        Value::Json(v) => {
668            let j_val: serde_json::Value = serde_json::to_value(v).map_err(|e| MutationExecutorError::Bind(e.to_string()))?;
669            args.add(j_val);
670        }
671        Value::Date(v) => args.add(*v),
672        Value::Timestamp(v) => args.add(*v),
673        Value::Object(_) => return Err(MutationExecutorError::UnsupportedValue("object")),
674        Value::List(values) => bind_pg_list(args, values)?,
675    }
676    Ok(())
677}
678
679fn bind_pg_list(args: &mut PgArgs, values: &[Value]) -> Result<(), MutationExecutorError> {
680    let Some(first) = values.first() else {
681        return Err(MutationExecutorError::UnsupportedValue("empty list"));
682    };
683    match first {
684        Value::Bool(_) => {
685            let values = values
686                .iter()
687                .map(|value| match value {
688                    Value::Bool(value) => Ok(*value),
689                    _ => Err(MutationExecutorError::UnsupportedValue("mixed bool list")),
690                })
691                .collect::<Result<Vec<_>, _>>()?;
692            args.add(values);
693        }
694        Value::I64(_) => {
695            let values = values
696                .iter()
697                .map(|value| match value {
698                    Value::I64(value) => Ok(*value),
699                    _ => Err(MutationExecutorError::UnsupportedValue("mixed i64 list")),
700                })
701                .collect::<Result<Vec<_>, _>>()?;
702            args.add(values);
703        }
704        Value::U64(_) => {
705            let values = values
706                .iter()
707                .map(|value| match value {
708                    Value::U64(value) => i64::try_from(*value).map_err(|_| {
709                        MutationExecutorError::Bind(format!("u64 value {value} exceeds i64 range"))
710                    }),
711                    _ => Err(MutationExecutorError::UnsupportedValue("mixed u64 list")),
712                })
713                .collect::<Result<Vec<_>, _>>()?;
714            args.add(values);
715        }
716        Value::F64(_) => {
717            let values = values
718                .iter()
719                .map(|value| match value {
720                    Value::F64(value) => Ok(*value),
721                    _ => Err(MutationExecutorError::UnsupportedValue("mixed f64 list")),
722                })
723                .collect::<Result<Vec<_>, _>>()?;
724            args.add(values);
725        }
726        Value::Decimal(_) => {
727            let values = values
728                .iter()
729                .map(|value| match value {
730                    Value::Decimal(value) => Ok(*value),
731                    _ => Err(MutationExecutorError::UnsupportedValue(
732                        "mixed decimal list",
733                    )),
734                })
735                .collect::<Result<Vec<_>, _>>()?;
736            args.add(values);
737        }
738        Value::Text(_) => {
739            let values = values
740                .iter()
741                .map(|value| match value {
742                    Value::Text(value) => Ok(value.clone()),
743                    _ => Err(MutationExecutorError::UnsupportedValue("mixed text list")),
744                })
745                .collect::<Result<Vec<_>, _>>()?;
746            args.add(values);
747        }
748        Value::Date(_) => {
749            let values = values
750                .iter()
751                .map(|value| match value {
752                    Value::Date(value) => Ok(*value),
753                    _ => Err(MutationExecutorError::UnsupportedValue("mixed date list")),
754                })
755                .collect::<Result<Vec<_>, _>>()?;
756            args.add(values);
757        }
758        Value::Timestamp(_) => {
759            let values = values
760                .iter()
761                .map(|value| match value {
762                    Value::Timestamp(value) => Ok(*value),
763                    _ => Err(MutationExecutorError::UnsupportedValue(
764                        "mixed timestamp list",
765                    )),
766                })
767                .collect::<Result<Vec<_>, _>>()?;
768            args.add(values);
769        }
770        Value::Null => return Err(MutationExecutorError::UnsupportedValue("null list")),
771        Value::Json(_) => return Err(MutationExecutorError::UnsupportedValue("json list")),
772        Value::Object(_) => return Err(MutationExecutorError::UnsupportedValue("object list")),
773        Value::List(_) => return Err(MutationExecutorError::UnsupportedValue("nested list")),
774    }
775    Ok(())
776}
777
778fn decode_pg_row(row: &tokio_postgres::Row) -> Result<Record, MutationExecutorError> {
779    let mut record = BTreeMap::new();
780    for (index, column) in row.columns().iter().enumerate() {
781        let name = column.name().to_owned();
782        let type_name = column.type_().name().to_ascii_uppercase();
783        
784        let value = match type_name.as_str() {
785            "BOOL" | "BOOLEAN" => {
786                let v: Option<bool> = row.try_get(index)?;
787                match v {
788                    Some(v) => Value::Bool(v),
789                    None => Value::Null,
790                }
791            }
792            "INT2" => {
793                let v: Option<i16> = row.try_get(index)?;
794                match v {
795                    Some(v) => Value::I64(v as i64),
796                    None => Value::Null,
797                }
798            }
799            "INT4" => {
800                let v: Option<i32> = row.try_get(index)?;
801                match v {
802                    Some(v) => Value::I64(v as i64),
803                    None => Value::Null,
804                }
805            }
806            "INT8" => {
807                let v: Option<i64> = row.try_get(index)?;
808                match v {
809                    Some(v) => Value::I64(v),
810                    None => Value::Null,
811                }
812            }
813            "FLOAT4" => {
814                let v: Option<f32> = row.try_get(index)?;
815                match v {
816                    Some(v) => Value::F64(v as f64),
817                    None => Value::Null,
818                }
819            }
820            "FLOAT8" => {
821                let v: Option<f64> = row.try_get(index)?;
822                match v {
823                    Some(v) => Value::F64(v),
824                    None => Value::Null,
825                }
826            }
827            "NUMERIC" => {
828                let v: Option<Decimal> = row.try_get(index)?;
829                match v {
830                    Some(v) => Value::Decimal(v),
831                    None => Value::Null,
832                }
833            }
834            "JSON" | "JSONB" => {
835                let v: Option<serde_json::Value> = row.try_get(index)?;
836                match v {
837                    Some(j) => Value::Json(j.into()),
838                    None => Value::Null,
839                }
840            }
841            "DATE" => {
842                let v: Option<NaiveDate> = row.try_get(index)?;
843                match v {
844                    Some(v) => Value::Date(v),
845                    None => Value::Null,
846                }
847            }
848            "TIMESTAMP" | "TIMESTAMPTZ" => {
849                let v: Option<DateTime<Utc>> = row.try_get(index)?;
850                match v {
851                    Some(v) => Value::Timestamp(v),
852                    None => Value::Null,
853                }
854            }
855            "TEXT" | "VARCHAR" | "BPCHAR" | "NAME" | "UUID" => {
856                let v: Option<String> = row.try_get(index)?;
857                match v {
858                    Some(v) => Value::Text(v),
859                    None => Value::Null,
860                }
861            }
862            other => {
863                return Err(MutationExecutorError::UnsupportedColumnType(
864                    other.to_owned(),
865                ));
866            }
867        };
868        record.insert(name, value);
869    }
870    Ok(record)
871}
872
#[cfg(test)]
mod tests {
    //! Unit tests for [`PostgresDialect`] SQL generation.
    //!
    //! These drive the dialect directly, with no database connection involved:
    //! mutation statements must use PostgreSQL's `$n` positional placeholders,
    //! DDL must map descriptor data types to PostgreSQL column types, and
    //! `Expr::in_large` must compile to an `= ANY($n)` array comparison.

    use super::*;
    use teaql_core::{DeleteCommand, RecoverCommand};

    /// Builds the `Order` descriptor shared by every test. It is stored in the
    /// `orders` table and has three properties:
    /// - a non-null `U64` id;
    /// - a non-null `I64` version column;
    /// - a nullable `name` text column.
    fn entity() -> EntityDescriptor {
        let id = PropertyDescriptor::new("id", DataType::U64)
            .column_name("id")
            .id()
            .not_null();
        let version = PropertyDescriptor::new("version", DataType::I64)
            .column_name("version")
            .version()
            .not_null();
        let name = PropertyDescriptor::new("name", DataType::Text).column_name("name");

        EntityDescriptor::new("Order")
            .table_name("orders")
            .property(id)
            .property(version)
            .property(name)
    }

    /// INSERT, UPDATE, delete and recover must all number their placeholders
    /// sequentially starting at `$1`.
    #[test]
    fn postgres_dialect_compiles_mutations_with_numbered_placeholders() {
        let order = entity();
        let dialect = PostgresDialect;

        // INSERT lists only the columns that were given values.
        let insert_cmd = InsertCommand::new("Order")
            .value("id", 1_u64)
            .value("name", "A");
        let insert = dialect.compile_insert(&order, &insert_cmd).unwrap();
        assert_eq!(
            insert.sql,
            "INSERT INTO orders (id, name) VALUES ($1, $2)"
        );

        // UPDATE sets the changed values and `version`, then guards on id and
        // the expected version.
        let update_cmd = UpdateCommand::new("Order", 1_u64)
            .expected_version(3)
            .value("name", "B");
        let update = dialect.compile_update(&order, &update_cmd).unwrap();
        assert_eq!(
            update.sql,
            "UPDATE orders SET name = $1, version = $2 WHERE id = $3 AND version = $4"
        );

        // Delete and recover both compile to the same version-guarded UPDATE of
        // the version column rather than to a SQL `DELETE`.
        let expected_version_update =
            "UPDATE orders SET version = $1 WHERE id = $2 AND version = $3";

        let delete_cmd = DeleteCommand::new("Order", 1_u64).expected_version(3);
        let delete = dialect.compile_delete(&order, &delete_cmd).unwrap();
        assert_eq!(delete.sql, expected_version_update);

        let recover_cmd = RecoverCommand::new("Order", 1_u64, -4);
        let recover = dialect.compile_recover(&order, &recover_cmd).unwrap();
        assert_eq!(recover.sql, expected_version_update);
    }

    /// Checks three things:
    /// - CREATE TABLE uses PostgreSQL column types;
    /// - the schema setup installs the `soundex` function;
    /// - `in_large` binds its list as a single array parameter.
    #[test]
    fn postgres_dialect_compiles_schema_and_large_in_array_binds() {
        let order = entity();
        let dialect = PostgresDialect;

        let ddl = dialect.compile_create_table(&order).unwrap();
        assert_eq!(
            ddl,
            "CREATE TABLE IF NOT EXISTS orders (id BIGINT PRIMARY KEY NOT NULL, version BIGINT NOT NULL, name TEXT)"
        );

        let installs_soundex = dialect
            .schema_setup_sqls()
            .iter()
            .any(|sql| sql.contains("CREATE OR REPLACE FUNCTION soundex"));
        assert!(installs_soundex);

        // The whole id list travels as one `Value::List` bound to `$1`.
        let ids = vec![Value::from(1_u64), Value::from(2_u64)];
        let select = SelectQuery::new("Order")
            .filter(Expr::in_large("id", ids))
            .order_asc("id");
        let query = dialect.compile_select(&order, &select).unwrap();
        assert_eq!(
            query.sql,
            "SELECT id, version, name FROM orders WHERE (id = ANY($1)) ORDER BY id ASC"
        );
        assert_eq!(
            query.params,
            vec![Value::List(vec![Value::U64(1), Value::U64(2)])]
        );
    }
}