pgwire 0.38.3

PostgreSQL wire protocol implemented as a library
Documentation
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use futures::stream;
use gluesql::prelude::*;
use pgwire::types::format::FormatOptions;
use tokio::net::TcpListener;

use pgwire::api::query::SimpleQueryHandler;
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::tokio::process_socket;

/// Query processor backed by an in-memory GlueSQL engine.
///
/// Implements [`SimpleQueryHandler`] so that SQL text received over the wire
/// is executed against the wrapped [`Glue`] instance.
pub struct GluesqlProcessor {
    /// The GlueSQL engine over [`MemoryStorage`]. It sits behind a mutex
    /// because `Glue::execute` is called on a `&mut` borrow obtained from a
    /// shared `&self`.
    glue: Arc<Mutex<Glue<MemoryStorage>>>,
}

#[async_trait]
impl SimpleQueryHandler for GluesqlProcessor {
    async fn do_query<C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        println!("{:?}", query);
        let mut glue = self.glue.lock().unwrap();
        futures::executor::block_on(glue.execute(query))
            .map_err(|err| PgWireError::ApiError(Box::new(err)))
            .and_then(|payloads| {
                payloads
                    .iter()
                    .map(|payload| match payload {
                        Payload::Select { labels, rows } => {
                            let fields = labels
                                .iter()
                                .map(|label| {
                                    FieldInfo::new(
                                        label.into(),
                                        None,
                                        None,
                                        Type::UNKNOWN,
                                        FieldFormat::Text,
                                    )
                                })
                                .collect::<Vec<_>>();
                            let fields = Arc::new(fields);
                            let format_options = FormatOptions::default();

                            let mut results = Vec::with_capacity(rows.len());
                            let mut encoder = DataRowEncoder::new(fields.clone());
                            for row in rows {
                                for field in row.iter() {
                                    match field {
                                        Value::Bool(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::BOOL,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::I8(v) => encoder.encode_field_with_type_and_format(
                                            v,
                                            &Type::CHAR,
                                            FieldFormat::Text,
                                            &format_options,
                                        )?,
                                        Value::I16(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::INT2,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::I32(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::INT4,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::I64(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::INT8,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::U8(v) => encoder.encode_field_with_type_and_format(
                                            &(*v as i8),
                                            &Type::CHAR,
                                            FieldFormat::Text,
                                            &format_options,
                                        )?,
                                        Value::F64(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::FLOAT8,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::Str(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::VARCHAR,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::Bytea(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::BYTEA,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::Date(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::DATE,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::Time(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::TIME,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        Value::Timestamp(v) => encoder
                                            .encode_field_with_type_and_format(
                                                v,
                                                &Type::TIMESTAMP,
                                                FieldFormat::Text,
                                                &format_options,
                                            )?,
                                        _ => unimplemented!(),
                                    }
                                }
                                results.push(Ok(encoder.take_row()));
                            }

                            Ok(Response::Query(QueryResponse::new(
                                fields,
                                stream::iter(results),
                            )))
                        }
                        Payload::Insert(rows) => Ok(Response::Execution(
                            Tag::new("INSERT").with_oid(0).with_rows(*rows),
                        )),
                        Payload::Delete(rows) => {
                            Ok(Response::Execution(Tag::new("DELETE").with_rows(*rows)))
                        }
                        Payload::Update(rows) => {
                            Ok(Response::Execution(Tag::new("UPDATE").with_rows(*rows)))
                        }
                        Payload::Create => Ok(Response::Execution(Tag::new("CREATE TABLE"))),
                        Payload::AlterTable => Ok(Response::Execution(Tag::new("ALTER TABLE"))),
                        Payload::DropTable(_) => Ok(Response::Execution(Tag::new("DROP TABLE"))),
                        Payload::CreateIndex => Ok(Response::Execution(Tag::new("CREATE INDEX"))),
                        Payload::DropIndex => Ok(Response::Execution(Tag::new("DROP INDEX"))),
                        _ => {
                            unimplemented!()
                        }
                    })
                    .collect::<Result<Vec<Response>, PgWireError>>()
            })
    }
}

/// Supplies protocol handlers to the pgwire server for each connection.
struct GluesqlHandlerFactory {
    /// The single query processor handed out (by `Arc` clone) as the simple
    /// query handler.
    processor: Arc<GluesqlProcessor>,
}

impl PgWireServerHandlers for GluesqlHandlerFactory {
    /// Returns a new handle to the factory's shared query processor.
    ///
    /// Only the reference count is bumped; every caller receives the same
    /// underlying processor.
    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
        Arc::clone(&self.processor)
    }
}

/// Starts a PostgreSQL-wire-compatible server on `127.0.0.1:5432` backed by an
/// in-memory GlueSQL database, serving each accepted connection on its own
/// Tokio task.
///
/// # Panics
///
/// Panics if the listening address cannot be bound, since the server cannot do
/// anything useful without its listener.
#[tokio::main]
pub async fn main() {
    let gluesql = GluesqlProcessor {
        glue: Arc::new(Mutex::new(Glue::new(MemoryStorage::default()))),
    };

    let factory = Arc::new(GluesqlHandlerFactory {
        processor: Arc::new(gluesql),
    });

    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr)
        .await
        .unwrap_or_else(|err| panic!("failed to bind {server_addr}: {err}"));
    println!("Listening to {}", server_addr);
    loop {
        // A failed `accept` (for example a client resetting during the
        // handshake, or temporary descriptor exhaustion) must not take down
        // the connections that are already being served.
        let (socket, _peer_addr) = match listener.accept().await {
            Ok(connection) => connection,
            Err(err) => {
                eprintln!("failed to accept connection: {err}");
                continue;
            }
        };
        let factory_ref = factory.clone();

        tokio::spawn(async move {
            // Surface per-connection failures instead of dropping them with
            // the detached task's result.
            if let Err(err) = process_socket(socket, None, factory_ref).await {
                eprintln!("connection error: {err}");
            }
        });
    }
}