pgwire 0.38.3

PostgreSQL wire protocol implemented as a library
Documentation
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
use futures::{Sink, SinkExt, stream};
use pgwire::api::auth::StartupHandler;
use tokio::net::TcpListener;

use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::query::SimpleQueryHandler;
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::NoticeResponse;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use pgwire::tokio::process_socket;

/// Stateless handler used by this example server.
///
/// It carries no data; the `NoopStartupHandler` and `SimpleQueryHandler` impls in this
/// file give it its behavior. The derives make the public type printable and trivially
/// constructible/copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct DummyProcessor;

#[async_trait]
impl NoopStartupHandler for DummyProcessor {
    /// Runs once the startup handshake has completed.
    ///
    /// Prints the client's socket address to stdout, then sends the client a single
    /// `NoticeResponse` listing the statements this example understands. Any error from
    /// the sink is converted into a `PgWireError` and returned.
    async fn post_startup<C>(
        &self,
        client: &mut C,
        _message: PgWireFrontendMessage,
    ) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let peer = client.socket_addr();
        println!("Connected: {}", peer);

        // Severity "NOTICE" with SQLSTATE 01000 (generic warning class): informational
        // only, it does not fail anything on the client side.
        let greeting = ErrorInfo::new(
            "NOTICE".to_owned(),
            "01000".to_owned(),
            "Supported queries in this example:\n- BEGIN;\n- ROLLBACK;\n- COMMIT;\n- SELECT 1;"
                .to_owned(),
        );
        let notice = NoticeResponse::from(greeting);

        client
            .send(PgWireBackendMessage::NoticeResponse(notice))
            .await?;
        Ok(())
    }
}

/// Canonicalizes a simple-protocol query string for matching.
///
/// Trims surrounding whitespace, drops trailing `;` characters, collapses internal
/// whitespace runs to one space and upper-cases ASCII letters, so that e.g.
/// `" select   1 ;"` becomes `"SELECT 1"`.
fn normalize_statement(query: &str) -> String {
    query
        .trim()
        .trim_end_matches(';')
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_ascii_uppercase()
}

#[async_trait]
impl SimpleQueryHandler for DummyProcessor {
    /// Answers the handful of statements this example supports.
    ///
    /// * `BEGIN` produces a transaction-start response.
    /// * `ROLLBACK` and `COMMIT` produce transaction-end responses.
    /// * `SELECT 1` produces one INT4 column with a single row containing `1`.
    ///
    /// Matching goes through [`normalize_statement`], so case, extra whitespace and
    /// the trailing semicolon do not matter. Anything else is answered with an
    /// `ERROR`-severity response. `FATAL` would tell the client that its session is
    /// being terminated, yet the connection is kept open.
    ///
    /// # Errors
    /// Propagates a failure to encode the `SELECT 1` row.
    async fn do_query<C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let statement = normalize_statement(query);
        let resp = match statement.as_str() {
            "BEGIN" => Response::TransactionStart(Tag::new("BEGIN")),
            "ROLLBACK" => Response::TransactionEnd(Tag::new("ROLLBACK")),
            "COMMIT" => Response::TransactionEnd(Tag::new("COMMIT")),
            "SELECT 1" => {
                let field =
                    FieldInfo::new("SELECT 1".into(), None, None, Type::INT4, FieldFormat::Text);
                let schema = Arc::new(vec![field]);

                // The encoder and the response share the same schema; one refcount
                // bump is enough.
                let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
                encoder.encode_field(&Some(1))?;
                let rows = stream::iter(vec![Ok(encoder.take_row())]);

                Response::Query(QueryResponse::new(schema, rows))
            }
            _ => Response::Error(Box::new(ErrorInfo::new(
                "ERROR".to_string(),
                "38003".to_string(),
                "Unsupported statement.".to_string(),
            ))),
        };

        Ok(vec![resp])
    }
}

/// Supplies pgwire with the handlers for a connection; `main` passes a clone of the
/// `Arc` around this factory to `process_socket` for every accepted socket.
struct DummyProcessorFactory {
    // One processor instance, created in `main` and handed out (as `Arc` clones) for
    // both the startup and simple-query roles.
    handler: Arc<DummyProcessor>,
}

impl PgWireServerHandlers for DummyProcessorFactory {
    /// Returns a new reference to the shared processor for simple-query handling.
    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
        Arc::clone(&self.handler)
    }

    /// Returns a new reference to the same shared processor for the startup phase.
    fn startup_handler(&self) -> Arc<impl StartupHandler> {
        Arc::clone(&self.handler)
    }
}

/// Serves the example on `127.0.0.1:5432`, one spawned task per client connection.
///
/// # Panics
/// Panics only if the listening socket cannot be bound. Errors accepting or serving
/// an individual connection are logged to stderr, and the server keeps running.
#[tokio::main]
pub async fn main() {
    let factory = Arc::new(DummyProcessorFactory {
        handler: Arc::new(DummyProcessor),
    });

    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr)
        .await
        .unwrap_or_else(|e| panic!("failed to bind {server_addr}: {e}"));
    println!("Listening to {}", server_addr);
    loop {
        // accept() can fail transiently (e.g. the process ran out of file descriptors).
        // Dropping that one attempt is better than panicking and taking every open
        // connection down with it.
        let (socket, peer) = match listener.accept().await {
            Ok(conn) => conn,
            Err(e) => {
                eprintln!("Failed to accept connection: {e}");
                continue;
            }
        };
        let factory_ref = Arc::clone(&factory);
        tokio::spawn(async move {
            // The task is detached, so report failures here; otherwise they vanish.
            if let Err(e) = process_socket(socket, None, factory_ref).await {
                eprintln!("Connection {peer} ended with error: {e}");
            }
        });
    }
}