use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use futures::{Stream, stream};
use rusqlite::types::ValueRef;
use rusqlite::{Connection, Rows, Statement, ToSql};
use tokio::net::TcpListener;
use pgwire::api::auth::md5pass::{Md5PasswordAuthStartupHandler, hash_md5_password};
use pgwire::api::auth::{
AuthSource, DefaultServerParameterProvider, LoginInfo, Password, StartupHandler,
};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse,
Response, Tag,
};
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::tokio::process_socket;
/// Query handler state: a SQLite connection plus the pgwire query parser.
///
/// Both fields are reference-counted so the backend can be shared between
/// handler roles and connections.
pub struct SqliteBackend {
/// SQLite connection guarded by a mutex; handlers lock it for the duration
/// of a single query or describe call.
conn: Arc<Mutex<Connection>>,
/// Parser handed to pgwire for the extended query protocol.
query_parser: Arc<NoopQueryParser>,
}
/// Stateless password source used for MD5 authentication.
///
/// Its `AuthSource` implementation answers every login with the same
/// hard-coded password, so it is a placeholder rather than a real credential
/// store.
#[derive(Debug)]
struct DummyAuthSource;
#[async_trait]
impl AuthSource for DummyAuthSource {
async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult<Password> {
println!("login info: {:?}", login_info);
let salt = vec![0, 0, 0, 0];
let password = "pencil";
let hash_password =
hash_md5_password(login_info.user().as_ref().unwrap(), password, salt.as_ref());
Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec()))
}
}
#[async_trait]
impl SimpleQueryHandler for SqliteBackend {
    /// Runs a simple-protocol query against the shared SQLite connection.
    ///
    /// A statement whose first keyword is `SELECT` (leading whitespace and
    /// ASCII case are ignored) is prepared and its rows are returned in text
    /// format. Anything else is executed directly and reported with an `OK`
    /// tag carrying the number of affected rows.
    ///
    /// # Errors
    ///
    /// SQLite failures are wrapped in `PgWireError::ApiError`; unsupported
    /// column types surface as the error produced by `row_desc_from_stmt`.
    ///
    /// # Panics
    ///
    /// Panics if the connection mutex is poisoned.
    async fn do_query<C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();

        // Compare only the leading keyword, without allocating an upper-cased
        // copy of the whole query, and tolerate leading whitespace.
        let is_select = query
            .trim_start()
            .get(..6)
            .is_some_and(|keyword| keyword.eq_ignore_ascii_case("SELECT"));

        if is_select {
            let mut stmt = conn
                .prepare(query)
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let header = Arc::new(row_desc_from_stmt(&stmt, &Format::UnifiedText)?);
            stmt.query(())
                .map(|rows| {
                    let s = encode_row_data(rows, header.clone());
                    vec![Response::Query(QueryResponse::new(header, s))]
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        } else {
            conn.execute(query, ())
                .map(|affected_rows| {
                    vec![Response::Execution(Tag::new("OK").with_rows(affected_rows))]
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        }
    }
}
/// Maps a SQLite declared column type name to a PostgreSQL wire type.
///
/// Matching is case-insensitive and ignores a parenthesized suffix, so
/// `varchar(20)` is treated like `VARCHAR`. Integer names map to `INT8` and
/// floating-point names to `FLOAT8`, matching the 64-bit values SQLite hands
/// back when rows are encoded.
///
/// # Errors
///
/// Returns a `PgWireError::UserError` with SQLSTATE `42846` for any type name
/// that is not recognized.
fn name_to_type(name: &str) -> PgWireResult<Type> {
    // Drop a length/precision suffix such as `(20)` before matching.
    let base = name.split('(').next().unwrap_or(name).trim();
    match base.to_uppercase().as_str() {
        "INT" | "INTEGER" | "BIGINT" => Ok(Type::INT8),
        "VARCHAR" => Ok(Type::VARCHAR),
        "TEXT" => Ok(Type::TEXT),
        "BINARY" | "BLOB" => Ok(Type::BYTEA),
        "FLOAT" | "REAL" | "DOUBLE" | "DOUBLE PRECISION" => Ok(Type::FLOAT8),
        _ => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
            "ERROR".to_owned(),
            "42846".to_owned(),
            format!("Unsupported data type: {name}"),
        )))),
    }
}
fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
stmt.columns()
.iter()
.enumerate()
.map(|(idx, col)| {
let field_type = col
.decl_type()
.map(name_to_type)
.unwrap_or(Ok(Type::UNKNOWN))?;
Ok(FieldInfo::new(
col.name().to_owned(),
None,
None,
field_type,
format.format_for(idx),
))
})
.collect()
}
fn encode_row_data(
mut rows: Rows,
schema: Arc<Vec<FieldInfo>>,
) -> impl Stream<Item = PgWireResult<DataRow>> + use<> {
let mut results = Vec::new();
let ncols = schema.len();
let mut encoder = DataRowEncoder::new(schema.clone());
while let Ok(Some(row)) = rows.next() {
for idx in 0..ncols {
let data = row.get_ref_unwrap::<usize>(idx);
match data {
ValueRef::Null => encoder.encode_field(&None::<i8>).unwrap(),
ValueRef::Integer(i) => {
encoder.encode_field(&i).unwrap();
}
ValueRef::Real(f) => {
encoder.encode_field(&f).unwrap();
}
ValueRef::Text(t) => {
encoder
.encode_field(&String::from_utf8_lossy(t).as_ref())
.unwrap();
}
ValueRef::Blob(b) => {
encoder.encode_field(&b).unwrap();
}
}
}
results.push(Ok(encoder.take_row()));
}
stream::iter(results)
}
/// Extracts the bound parameters of `portal` as boxed rusqlite values.
///
/// Each parameter is decoded according to its declared PostgreSQL type; a
/// parameter whose type slot is `None` is treated as `Type::UNKNOWN`.
///
/// # Panics
///
/// Panics if a parameter has no entry in the statement's type list, if
/// decoding a parameter fails, or if the parameter type is not one of
/// `BOOL`, `INT2`, `INT4`, `INT8`, `TEXT`, `VARCHAR`, `FLOAT4` or `FLOAT8`.
fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
    (0..portal.parameter_len())
        .map(|idx| -> Box<dyn ToSql> {
            let declared = portal
                .statement
                .parameter_types
                .get(idx)
                .unwrap()
                .as_ref()
                .unwrap_or(&Type::UNKNOWN);

            match declared {
                &Type::BOOL => {
                    Box::new(portal.parameter::<bool>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                &Type::INT2 => {
                    Box::new(portal.parameter::<i16>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                &Type::INT4 => {
                    Box::new(portal.parameter::<i32>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                &Type::INT8 => {
                    Box::new(portal.parameter::<i64>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                &Type::TEXT | &Type::VARCHAR => {
                    Box::new(portal.parameter::<String>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                &Type::FLOAT4 => {
                    Box::new(portal.parameter::<f32>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                &Type::FLOAT8 => {
                    Box::new(portal.parameter::<f64>(idx, declared).unwrap()) as Box<dyn ToSql>
                }
                _ => unimplemented!("parameter type not supported"),
            }
        })
        .collect()
}
#[async_trait]
impl ExtendedQueryHandler for SqliteBackend {
    type Statement = String;
    type QueryParser = NoopQueryParser;

    /// Returns the shared query parser.
    fn query_parser(&self) -> Arc<Self::QueryParser> {
        self.query_parser.clone()
    }

    /// Executes a bound portal against the shared SQLite connection.
    ///
    /// A statement whose first keyword is `SELECT` (leading whitespace and
    /// ASCII case are ignored) returns its rows using the portal's result
    /// column format. Anything else is executed and reported with an `OK` tag
    /// carrying the number of affected rows. `_max_rows` is ignored.
    ///
    /// # Errors
    ///
    /// SQLite failures are wrapped in `PgWireError::ApiError`.
    ///
    /// # Panics
    ///
    /// Panics if the connection mutex is poisoned, or if `get_params` cannot
    /// decode a parameter.
    async fn do_query<C>(
        &self,
        _client: &mut C,
        portal: &Portal<Self::Statement>,
        _max_rows: usize,
    ) -> PgWireResult<Response>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        let query = &portal.statement.statement;
        let mut stmt = conn
            .prepare_cached(query)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        let params = get_params(portal);
        let params_ref = params
            .iter()
            .map(|f| f.as_ref())
            .collect::<Vec<&dyn rusqlite::ToSql>>();

        // Compare only the leading keyword, without allocating an upper-cased
        // copy of the whole query, and tolerate leading whitespace.
        let is_select = query
            .trim_start()
            .get(..6)
            .is_some_and(|keyword| keyword.eq_ignore_ascii_case("SELECT"));

        if is_select {
            let header = Arc::new(row_desc_from_stmt(&stmt, &portal.result_column_format)?);
            stmt.query::<&[&dyn rusqlite::ToSql]>(params_ref.as_ref())
                .map(|rows| {
                    let s = encode_row_data(rows, header.clone());
                    Response::Query(QueryResponse::new(header, s))
                })
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        } else {
            stmt.execute::<&[&dyn rusqlite::ToSql]>(params_ref.as_ref())
                .map(|affected_rows| Response::Execution(Tag::new("OK").with_rows(affected_rows)))
                .map_err(|e| PgWireError::ApiError(Box::new(e)))
        }
    }

    /// Describes a prepared statement: its parameter types and result columns.
    ///
    /// Parameter types the client left unspecified are reported as
    /// `Type::UNKNOWN`. Result columns are described with
    /// `Format::UnifiedBinary`.
    ///
    /// # Panics
    ///
    /// Panics if the connection mutex is poisoned.
    async fn do_describe_statement<C>(
        &self,
        _client: &mut C,
        stmt: &StoredStatement<Self::Statement>,
    ) -> PgWireResult<DescribeStatementResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        let param_types = stmt
            .parameter_types
            .iter()
            .map(|t| t.clone().unwrap_or(Type::UNKNOWN))
            .collect();
        let stmt = conn
            .prepare_cached(&stmt.statement)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        row_desc_from_stmt(&stmt, &Format::UnifiedBinary)
            .map(|fields| DescribeStatementResponse::new(param_types, fields))
    }

    /// Describes the result columns of a bound portal, using the portal's own
    /// result column format.
    ///
    /// # Panics
    ///
    /// Panics if the connection mutex is poisoned.
    async fn do_describe_portal<C>(
        &self,
        _client: &mut C,
        portal: &Portal<Self::Statement>,
    ) -> PgWireResult<DescribePortalResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let conn = self.conn.lock().unwrap();
        let stmt = conn
            .prepare_cached(&portal.statement.statement)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        row_desc_from_stmt(&stmt, &portal.result_column_format).map(DescribePortalResponse::new)
    }
}
impl SqliteBackend {
    /// Creates a backend on top of a fresh in-memory SQLite database.
    ///
    /// # Panics
    ///
    /// Panics if the in-memory connection cannot be opened.
    fn new() -> SqliteBackend {
        let connection = Connection::open_in_memory().unwrap();
        let parser = NoopQueryParser::new();
        Self {
            conn: Arc::new(Mutex::new(connection)),
            query_parser: Arc::new(parser),
        }
    }
}
/// Supplies pgwire with the handlers used for a connection.
///
/// The same `SqliteBackend` instance is handed out for both the simple and
/// the extended query protocol.
struct SqliteBackendFactory {
/// Backend shared by every handler this factory returns.
handler: Arc<SqliteBackend>,
}
impl PgWireServerHandlers for SqliteBackendFactory {
    /// Returns the shared backend as the simple-protocol handler.
    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
        Arc::clone(&self.handler)
    }

    /// Returns the shared backend as the extended-protocol handler.
    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
        Arc::clone(&self.handler)
    }

    /// Builds an MD5 password startup handler backed by `DummyAuthSource`.
    ///
    /// The advertised server version is the SQLite library version rather
    /// than the provider's default.
    fn startup_handler(&self) -> Arc<impl StartupHandler> {
        let server_params = {
            let mut provider = DefaultServerParameterProvider::default();
            provider.server_version = rusqlite::version().to_owned();
            Arc::new(provider)
        };
        let auth_source = Arc::new(DummyAuthSource);
        Arc::new(Md5PasswordAuthStartupHandler::new(
            auth_source,
            server_params,
        ))
    }
}
/// Starts the server: binds `127.0.0.1:5432` and serves each incoming
/// connection on its own Tokio task, all sharing one `SqliteBackend`.
///
/// # Panics
///
/// Panics if the listening socket cannot be bound.
#[tokio::main]
pub async fn main() {
    let factory = Arc::new(SqliteBackendFactory {
        handler: Arc::new(SqliteBackend::new()),
    });
    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr).await.unwrap();
    println!("Listening to {}", server_addr);
    loop {
        // A failed accept affects only that connection attempt; log it and
        // keep serving instead of taking the whole server down.
        let (socket, _peer_addr) = match listener.accept().await {
            Ok(accepted) => accepted,
            Err(e) => {
                eprintln!("Failed to accept connection: {e}");
                continue;
            }
        };
        let factory_ref = factory.clone();
        tokio::spawn(async move { process_socket(socket, None, factory_ref).await });
    }
}