use std::error::Error;
use crate::lib::errors::execute_error::ExecuteError;
use crate::lib::executor::predule::Executor;
use crate::lib::logger::predule::Logger;
use crate::lib::pgwire::predule::Connection;
use crate::lib::server::channel::ChannelResponse;
use crate::lib::server::predule::{ChannelRequest, ServerOption, SharedState};
use futures::future::join_all;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use super::client::ClientInfo;
/// Database server front end: accepts client TCP connections and forwards the
/// statements they submit to an executor (see [`Server::run`]).
pub struct Server {
/// Listener settings; `run` binds to `option.host`:`option.port`.
pub option: ServerOption,
}
impl Server {
pub fn new(option: ServerOption) -> Self {
Self { option }
}
pub async fn run(&self) -> Result<(), Box<dyn Error>> {
let (request_sender, mut request_receiver) = mpsc::channel::<ChannelRequest>(1000);
let background_task = tokio::spawn(async move {
while let Some(request) = request_receiver.recv().await {
tokio::spawn(async move {
let executor = Executor::new();
let result = executor.process_query(request.statement).await;
match result {
Ok(result) => {
if let Err(_response) = request
.response_sender
.send(ChannelResponse { result: Ok(result) })
{
Logger::error("channel send failed");
}
}
Err(error) => {
let error = error.to_string();
if let Err(_response) = request.response_sender.send(ChannelResponse {
result: Err(ExecuteError::boxed(ExecuteError::boxed(error))),
}) {
Logger::error("channel send failed");
}
}
}
});
}
});
let listener =
TcpListener::bind((self.option.host.to_owned(), self.option.port as u16)).await?;
let connection_task = tokio::spawn(async move {
loop {
let accepted = listener.accept().await;
let (stream, address) = match accepted {
Ok((stream, address)) => (stream, address),
Err(error) => {
Logger::error(format!("socket error {:?}", error));
continue;
}
};
let client_info = ClientInfo {
ip: address.ip(),
connection_id: uuid::Uuid::new_v4().to_string(),
database: "None".into(),
};
let shared_state = SharedState {
sender: request_sender.clone(),
client_info,
};
tokio::spawn(async move {
let mut conn = Connection::new(shared_state);
if let Err(error) = conn.run(stream).await {
Logger::error(format!("connection error {:?}", error));
}
});
}
});
Logger::info(format!(
"Server is running on {}:{}",
self.option.host, self.option.port
));
join_all(vec![connection_task, background_task]).await;
Ok(())
}
}