#![allow(dead_code)]
mod pid_file;
use clap::Parser;
use pid_file::PidFile;
use sql_fun_server_api::{
AnalyzeQueryArgs, AnalyzeQueryResponse, Collection, DescribeTableArgs, DescribeTableResponse,
InitializeSchemaArgs, ListObjectArgs, ObjectSummary, ReadCollectionArgs, ReleaseCollectionArgs,
SchemaContext, SqlFunServerApi, SqlFunServerApiError,
};
use std::{net::SocketAddr, sync::LazyLock};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt},
net::TcpStream,
};
// Command-line arguments for the server binary; passed to
// `PidFile::try_register_primary_server` in `service`.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into `--help` text, which would change the CLI output.
#[derive(clap::Parser, Clone)]
pub struct Args {
    // `-u` / `--user`; falls back to the `USER` environment variable when the
    // flag is not given.
    #[arg(short, long, env = "USER")]
    user: String,
}
pub async fn service() -> Result<(), ServerError> {
let args = Args::parse();
let listner = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let local_addr = listner.local_addr()?;
let port = local_addr.port();
let Some(_pid_file) = PidFile::try_register_primary_server(&args, port)? else {
return Ok(());
};
loop {
eprintln!("Server started at {local_addr}");
let (accepted_socker, remote_addr) = listner.accept().await?;
tokio::task::spawn(async move { serve_client(accepted_socker, remote_addr).await });
}
}
/// Handle type through which the server's `SqlFunServerApi` implementation
/// is reached.
pub struct Api {}

/// Process-wide `Api` instance, initialised on first access.
static API_IMPL: LazyLock<Api> = LazyLock::new(|| Api {});

impl Api {
    /// Returns the shared, process-wide `Api` instance.
    ///
    /// Every call yields a reference to the same static value.
    #[must_use]
    pub fn api_impl() -> &'static Self {
        LazyLock::force(&API_IMPL)
    }
}
#[async_trait::async_trait]
impl sql_fun_server_api::SqlFunServerApi for Api {
async fn initialize_schema_context(
&self,
_args: InitializeSchemaArgs,
) -> Result<SchemaContext, SqlFunServerApiError> {
todo!()
}
async fn read_collection(
&self,
_args: ReadCollectionArgs,
) -> Result<serde_json::Value, SqlFunServerApiError> {
todo!()
}
async fn release_collection(
&self,
_args: ReleaseCollectionArgs,
) -> Result<(), SqlFunServerApiError> {
todo!()
}
async fn analyze_query(
&self,
_args: AnalyzeQueryArgs,
) -> Result<AnalyzeQueryResponse, SqlFunServerApiError> {
todo!()
}
async fn list_objects(
&self,
_args: ListObjectArgs,
) -> Result<Collection<ObjectSummary>, SqlFunServerApiError> {
todo!()
}
async fn describe_table(
&self,
_args: DescribeTableArgs,
) -> Result<DescribeTableResponse, SqlFunServerApiError> {
todo!()
}
}
pub async fn process_request(
request_message: &JsonRpcRequest,
) -> Result<serde_json::Value, JsonRpcError> {
let method = request_message.method.as_str();
match method {
"initialize_schema_context" => {
let Some(ref params) = request_message.params else {
Err(JsonRpcError::parameter_required(method))?
};
let initialize_arg: InitializeSchemaArgs = serde_json::from_value(params.clone())?;
let result = Api::api_impl()
.initialize_schema_context(initialize_arg)
.await?;
let result_value = serde_json::to_value(result)?;
Ok(result_value)
}
_ => Err(JsonRpcError::unknown_method(method)),
}
}
/// One JSON-RPC request as decoded from a single line of client input.
///
/// Serde's derive treats a missing `Option` field as `None`, so `params` may
/// be omitted by the client.
#[derive(serde::Deserialize)]
pub struct JsonRpcRequest {
    /// Name of the method to invoke; matched in `process_request`.
    method: String,
    /// Method arguments, if any were sent.
    params: Option<serde_json::Value>,
    /// Client-chosen request id, echoed back in the response.
    id: serde_json::Value,
    /// Protocol version, sent on the wire as `"jsonrpc"`. Not checked anywhere
    /// in this file.
    #[serde(rename = "jsonrpc")]
    json_rpc_version: String,
}
/// Response object written back to the client, serialized as one JSON line.
///
/// Fields serialize in declaration order. The constructors in this file set
/// exactly one of `result` / `error`; whichever is `None` is omitted from the
/// output.
#[derive(serde::Serialize)]
struct JsonRpcResponse {
    /// Protocol version, emitted as `"jsonrpc"`; the constructors set it to `"2.0"`.
    #[serde(rename = "jsonrpc")]
    json_rpc_version: String,
    /// Request id. `None` has no skip attribute, so it serializes as JSON `null`.
    id: Option<serde_json::Value>,
    /// Successful result payload; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<serde_json::Value>,
    /// Error payload; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcError>,
}
impl JsonRpcResponse {
pub fn success(request: &JsonRpcRequest, value: serde_json::Value) -> Self {
Self {
json_rpc_version: String::from("2.0"),
id: Some(request.id.clone()),
result: Some(value),
error: None,
}
}
pub fn error(request: &JsonRpcRequest, error: &JsonRpcError) -> Self {
Self {
json_rpc_version: String::from("2.0"),
id: Some(request.id.clone()),
result: None,
error: Some(error.clone()),
}
}
}
/// JSON-RPC error object placed in a response's `error` member.
///
/// Fields serialize in declaration order, so keep that order stable.
#[derive(serde::Serialize, Clone)]
pub struct JsonRpcError {
    /// Numeric error code.
    code: i32,
    /// Human-readable description of the failure.
    message: String,
    /// Optional extra detail; left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
impl From<serde_json::Error> for JsonRpcError {
fn from(_value: serde_json::Error) -> Self {
Self::json_deserialize_error()
}
}
impl From<SqlFunServerApiError> for JsonRpcError {
fn from(value: SqlFunServerApiError) -> Self {
Self::server_api_execution_error(value)
}
}
impl JsonRpcError {
    // Codes follow the JSON-RPC 2.0 specification (section 5.1) so clients can
    // tell failure kinds apart.
    /// `-32602 Invalid params`: the method requires `params` but none were sent.
    const CODE_PARAMETER_REQUIRED: i32 = -32602;
    /// `-32601 Method not found`.
    const CODE_UNKNOWN_METHOD: i32 = -32601;
    /// `-32602 Invalid params`: `params` did not decode into the argument type.
    const CODE_JSON_DESERIALIZE_ERROR: i32 = -32602;
    /// `-32603 Internal error`.
    const CODE_INTERNAL_SERVER_ERROR: i32 = -32603;

    /// Error for a request to `method` that arrived without `params`.
    #[must_use]
    pub fn parameter_required(method: &str) -> Self {
        Self {
            code: Self::CODE_PARAMETER_REQUIRED,
            message: format!("required params for method : {method}"),
            data: None,
        }
    }

    /// Error for a method name that is not routed.
    #[must_use]
    pub fn unknown_method(method: &str) -> Self {
        Self {
            code: Self::CODE_UNKNOWN_METHOD,
            message: format!("unknown method : {method}"),
            data: None,
        }
    }

    /// Error for `params` that could not be deserialized.
    #[must_use]
    pub fn json_deserialize_error() -> Self {
        Self {
            code: Self::CODE_JSON_DESERIALIZE_ERROR,
            message: "unexpected params".to_string(),
            data: None,
        }
    }

    /// Generic internal error for a failed server-API call.
    ///
    /// The underlying error is intentionally not exposed to the client.
    #[must_use]
    pub fn server_api_execution_error(_err: SqlFunServerApiError) -> Self {
        Self {
            code: Self::CODE_INTERNAL_SERVER_ERROR,
            message: "internal server error".to_string(),
            data: None,
        }
    }
}
async fn serve_client(stream: TcpStream, addr: SocketAddr) -> Result<(), ServerError> {
let (read_half, mut write_half) = stream.into_split();
let mut buf_read = tokio::io::BufReader::new(read_half);
loop {
let mut line = String::new();
buf_read.read_line(&mut line).await?;
tracing::info!("{addr}: recieved {line}");
let request_message: JsonRpcRequest = serde_json::from_str(&line)?;
let response = process_request(&request_message).await;
let response = match response {
Ok(result) => JsonRpcResponse::success(&request_message, result),
Err(error) => JsonRpcResponse::error(&request_message, &error),
};
let mut response_json = serde_json::to_string(&response)?;
response_json.push_str("\n");
tracing::info!("{addr}: sending {response_json}");
write_half.write_all(response_json.as_bytes()).await?;
}
}
#[derive(thiserror::Error, Debug)]
pub enum ServerError {
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Serde(#[from] serde_json::Error),
#[error("logic bug {0}")]
Bug(String),
}
/// Blocking `std::io::Read` adapter over byte buffers streamed through tokio
/// channels.
///
/// Filled buffers arrive on `buffer_rx`; once drained, they are cleared and
/// sent back on `return_tx`.
struct ReadableBuffer {
    /// Buffer currently being drained, if one is loaded.
    current_buffer: Option<Vec<u8>>,
    /// Read offset into `current_buffer`.
    inner_bufpos: usize,
    /// Incoming filled buffers.
    buffer_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
    /// Outgoing drained (cleared) buffers.
    return_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
}
impl ReadableBuffer {
    /// Ensures a buffer is loaded for reading.
    ///
    /// If none is held, this resets the read offset and blocks the current
    /// thread on `buffer_rx` until the next filled buffer arrives.
    /// `current_buffer` stays `None` once the channel is closed and empty.
    ///
    /// NOTE: tokio documents `blocking_recv` as panicking when called from
    /// inside an async execution context. Only call this from a blocking
    /// thread (e.g. one started with `spawn_blocking`).
    fn fill_buff(&mut self) {
        if self.current_buffer.is_none() {
            self.inner_bufpos = 0;
            self.current_buffer = self.buffer_rx.blocking_recv();
        }
    }

    /// Hands the current buffer back to the producer once it is fully consumed.
    ///
    /// The buffer is cleared before being sent on `return_tx`. If the producer
    /// has already dropped the return channel, it no longer needs the buffer,
    /// so the buffer is simply dropped. Reporting an error here would make
    /// `read` fail after it had already copied this buffer's bytes to the
    /// caller, and those bytes would be lost. Filled buffers still queued on
    /// `buffer_rx` remain readable after the producer is gone.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for interface stability.
    fn return_buff(&mut self) -> std::io::Result<()> {
        let fully_consumed = self
            .current_buffer
            .as_ref()
            .is_some_and(|buffer| buffer.len() == self.inner_bufpos);
        if !fully_consumed {
            return Ok(());
        }
        if let Some(mut buffer) = self.current_buffer.take() {
            buffer.clear();
            // Best effort: a closed return channel just means the producer is done.
            let _ = self.return_tx.blocking_send(buffer);
        }
        self.inner_bufpos = 0;
        Ok(())
    }
}
impl std::io::Read for ReadableBuffer {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.fill_buff();
let Some(current_buff) = &self.current_buffer else {
return Ok(0);
};
let source_slice = ¤t_buff[self.inner_bufpos..];
let len = buf.len().min(source_slice.len());
buf[..len].copy_from_slice(&source_slice[..len]);
self.inner_bufpos += len;
self.return_buff()?;
Ok(len)
}
}
/// Producer side of a buffer pipe. Presumably paired with `ReadableBuffer`,
/// since the channel types match; confirm at the construction site.
///
/// Filled buffers go out on `buffer_tx`; drained buffers come back on
/// `return_rx` to be refilled.
struct PipeBuffer {
    /// Drained buffers handed back by the reading side.
    return_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
    /// Filled buffers forwarded to the reading side.
    buffer_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
}
impl PipeBuffer {
    /// Reads from `source` up to and including `byte` (or EOF) into a recycled
    /// buffer, and forwards that buffer to the reading side.
    ///
    /// Each call first waits for a drained buffer on `return_rx`, so the
    /// producer is throttled by how quickly buffers are handed back. Returns
    /// the number of bytes read. `0` means `source` is at EOF and nothing was
    /// forwarded.
    ///
    /// NOTE: on EOF or on a read error, the recycled buffer is dropped rather
    /// than returned to the pool.
    ///
    /// # Errors
    /// - The return channel is closed (every returning sender was dropped).
    /// - `source.read_until` fails.
    /// - The forwarding channel has no receiver left.
    pub async fn pipe_read_until<A: AsyncBufReadExt + std::marker::Unpin>(
        &mut self,
        source: &mut A,
        byte: u8,
    ) -> Result<usize, std::io::Error> {
        let mut buff = self
            .return_rx
            .recv()
            .await
            .ok_or_else(|| std::io::Error::other("receiver closed"))?;
        buff.clear();
        let size = source.read_until(byte, &mut buff).await?;
        if size != 0 {
            self.buffer_tx
                .send(buff)
                .await
                .map_err(|_| std::io::Error::other("No receiver"))?;
        }
        Ok(size)
    }
}