Skip to main content

sql_fun_server/
lib.rs

1#![allow(dead_code)]
2
3mod pid_file;
4
5use clap::Parser;
6use pid_file::PidFile;
7use sql_fun_server_api::{
8    AnalyzeQueryArgs, AnalyzeQueryResponse, Collection, DescribeTableArgs, DescribeTableResponse,
9    InitializeSchemaArgs, ListObjectArgs, ObjectSummary, ReadCollectionArgs, ReleaseCollectionArgs,
10    SchemaContext, SqlFunServerApi, SqlFunServerApiError,
11};
12use std::{net::SocketAddr, sync::LazyLock};
13use tokio::{
14    io::{AsyncBufReadExt, AsyncWriteExt},
15    net::TcpStream,
16};
17
18#[derive(clap::Parser, Clone)]
19pub struct Args {
20    #[arg(short, long, env = "USER")]
21    user: String,
22}
23
24/// Run the server event loop until interrupted.
25///
26/// # Errors
27///
28/// Returns [`ServerError`] when binding or accepting sockets fails.
29pub async fn service() -> Result<(), ServerError> {
30    let args = Args::parse();
31
32    let listner = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
33
34    let local_addr = listner.local_addr()?;
35    let port = local_addr.port();
36    let Some(_pid_file) = PidFile::try_register_primary_server(&args, port)? else {
37        return Ok(());
38    };
39
40    loop {
41        eprintln!("Server started at {local_addr}");
42
43        let (accepted_socker, remote_addr) = listner.accept().await?;
44
45        tokio::task::spawn(async move { serve_client(accepted_socker, remote_addr).await });
46    }
47}
48
/// Field-less implementation of the server API.
pub struct Api {}

/// Process-wide [`Api`] instance, constructed on first access.
static API_IMPL: LazyLock<Api> = LazyLock::new(|| Api {});

impl Api {
    /// Returns the shared, process-wide [`Api`] instance.
    #[must_use]
    pub fn api_impl() -> &'static Self {
        LazyLock::force(&API_IMPL)
    }
}
58
59#[async_trait::async_trait]
60impl sql_fun_server_api::SqlFunServerApi for Api {
61    /// Initialize schema-aware query analysis context
62    async fn initialize_schema_context(
63        &self,
64        _args: InitializeSchemaArgs,
65    ) -> Result<SchemaContext, SqlFunServerApiError> {
66        todo!()
67    }
68
69    /// Read a collection
70    ///
71    /// When server returns multiple items, server keeps scrollable collection,
72    /// [`read_collection`] reads items in collection.
73    ///
74    /// # Return value
75    ///
76    /// returns [Collection<T>], T depends collection item type
77    ///
78    async fn read_collection(
79        &self,
80        _args: ReadCollectionArgs,
81    ) -> Result<serde_json::Value, SqlFunServerApiError> {
82        todo!()
83    }
84
85    /// Release a collection
86    ///
87    /// When server returns collection, server keeps scrollable collection,
88    /// [`release_collection`] release a server side collection.
89    ///
90    async fn release_collection(
91        &self,
92        _args: ReleaseCollectionArgs,
93    ) -> Result<(), SqlFunServerApiError> {
94        todo!()
95    }
96
97    /// analyze query
98    async fn analyze_query(
99        &self,
100        _args: AnalyzeQueryArgs,
101    ) -> Result<AnalyzeQueryResponse, SqlFunServerApiError> {
102        todo!()
103    }
104
105    /// list schema objects
106    async fn list_objects(
107        &self,
108        _args: ListObjectArgs,
109    ) -> Result<Collection<ObjectSummary>, SqlFunServerApiError> {
110        todo!()
111    }
112
113    /// describe a table
114    async fn describe_table(
115        &self,
116        _args: DescribeTableArgs,
117    ) -> Result<DescribeTableResponse, SqlFunServerApiError> {
118        todo!()
119    }
120}
121
122/// process a request
123///
124/// # Errors
125///
126/// [`JsonRpcError`]
127pub async fn process_request(
128    request_message: &JsonRpcRequest,
129) -> Result<serde_json::Value, JsonRpcError> {
130    let method = request_message.method.as_str();
131
132    match method {
133        "initialize_schema_context" => {
134            let Some(ref params) = request_message.params else {
135                Err(JsonRpcError::parameter_required(method))?
136            };
137            let initialize_arg: InitializeSchemaArgs = serde_json::from_value(params.clone())?;
138            let result = Api::api_impl()
139                .initialize_schema_context(initialize_arg)
140                .await?;
141            let result_value = serde_json::to_value(result)?;
142            Ok(result_value)
143        }
144
145        _ => Err(JsonRpcError::unknown_method(method)),
146    }
147}
148
149#[derive(serde::Deserialize)]
150pub struct JsonRpcRequest {
151    #[serde(rename = "jsonrpc")]
152    json_rpc_version: String,
153    id: serde_json::Value,
154    method: String,
155    params: Option<serde_json::Value>,
156}
157
158#[derive(serde::Serialize)]
159struct JsonRpcResponse {
160    #[serde(rename = "jsonrpc")]
161    json_rpc_version: String,
162    id: Option<serde_json::Value>,
163    #[serde(skip_serializing_if = "Option::is_none")]
164    result: Option<serde_json::Value>,
165    #[serde(skip_serializing_if = "Option::is_none")]
166    error: Option<JsonRpcError>,
167}
168
169impl JsonRpcResponse {
170    pub fn success(request: &JsonRpcRequest, value: serde_json::Value) -> Self {
171        Self {
172            json_rpc_version: String::from("2.0"),
173            id: Some(request.id.clone()),
174            result: Some(value),
175            error: None,
176        }
177    }
178
179    pub fn error(request: &JsonRpcRequest, error: &JsonRpcError) -> Self {
180        Self {
181            json_rpc_version: String::from("2.0"),
182            id: Some(request.id.clone()),
183            result: None,
184            error: Some(error.clone()),
185        }
186    }
187}
188
189#[derive(serde::Serialize, Clone)]
190pub struct JsonRpcError {
191    code: i32,
192    message: String,
193    #[serde(skip_serializing_if = "Option::is_none")]
194    data: Option<serde_json::Value>,
195}
196
197impl From<serde_json::Error> for JsonRpcError {
198    fn from(_value: serde_json::Error) -> Self {
199        Self::json_deserialize_error()
200    }
201}
202
203impl From<SqlFunServerApiError> for JsonRpcError {
204    fn from(value: SqlFunServerApiError) -> Self {
205        Self::server_api_execution_error(value)
206    }
207}
208
209impl JsonRpcError {
210    /// error code for parameter missing
211    // TODO: Fixme value
212    const CODE_PARAMETER_REQUIRED: i32 = 0;
213
214    /// error code for unknown method
215    // TODO: Fixme value
216    const CODE_UNKNOWN_METHOD: i32 = 0;
217
218    /// error code for json deserialize error
219    // TODO: Fixme value
220    const CODE_JSON_DESERIALIZE_ERROR: i32 = 0;
221
222    /// error code for internal server error
223    // TODO: Fixme value
224    const CODE_INTERNAL_SERVER_ERROR: i32 = 0;
225
226    #[must_use]
227    pub fn parameter_required(method: &str) -> Self {
228        Self {
229            code: Self::CODE_PARAMETER_REQUIRED,
230            message: format!("required params for method : {method}"),
231            data: None,
232        }
233    }
234
235    #[must_use]
236    pub fn unknown_method(method: &str) -> Self {
237        Self {
238            code: Self::CODE_UNKNOWN_METHOD,
239            message: format!("unknown method : {method}"),
240            data: None,
241        }
242    }
243
244    #[must_use]
245    pub fn json_deserialize_error() -> Self {
246        Self {
247            code: Self::CODE_JSON_DESERIALIZE_ERROR,
248            message: "unexpected params".to_string(),
249            data: None,
250        }
251    }
252
253    #[must_use]
254    pub fn server_api_execution_error(_err: SqlFunServerApiError) -> Self {
255        Self {
256            code: Self::CODE_INTERNAL_SERVER_ERROR,
257            message: "internal server error".to_string(),
258            data: None,
259        }
260    }
261}
262
263async fn serve_client(stream: TcpStream, addr: SocketAddr) -> Result<(), ServerError> {
264    let (read_half, mut write_half) = stream.into_split();
265    let mut buf_read = tokio::io::BufReader::new(read_half);
266
267    loop {
268        let mut line = String::new();
269        buf_read.read_line(&mut line).await?;
270        tracing::info!("{addr}: recieved {line}");
271        let request_message: JsonRpcRequest = serde_json::from_str(&line)?;
272        let response = process_request(&request_message).await;
273        let response = match response {
274            Ok(result) => JsonRpcResponse::success(&request_message, result),
275            Err(error) => JsonRpcResponse::error(&request_message, &error),
276        };
277        let mut response_json = serde_json::to_string(&response)?;
278        response_json.push_str("\n");
279
280        tracing::info!("{addr}: sending {response_json}");
281        write_half.write_all(response_json.as_bytes()).await?;
282    }
283}
284
285#[derive(thiserror::Error, Debug)]
286pub enum ServerError {
287    #[error(transparent)]
288    Io(#[from] std::io::Error),
289
290    #[error(transparent)]
291    Serde(#[from] serde_json::Error),
292
293    #[error("logic bug {0}")]
294    Bug(String),
295}
296
297struct ReadableBuffer {
298    buffer_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
299    return_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
300    current_buffer: Option<Vec<u8>>,
301    inner_bufpos: usize,
302}
303
304impl ReadableBuffer {
305    fn fill_buff(&mut self) {
306        if self.current_buffer.is_none() {
307            self.inner_bufpos = 0;
308            self.current_buffer = self.buffer_rx.blocking_recv();
309        }
310    }
311
312    fn return_buff(&mut self) -> std::io::Result<()> {
313        let Some(bufflen) = self.current_buffer.as_ref().map(Vec::len) else {
314            return Ok(());
315        };
316        if bufflen == self.inner_bufpos {
317            let mut buffer = self.current_buffer.take().unwrap();
318            buffer.clear();
319            if self.return_tx.blocking_send(buffer).is_err() {
320                return Err(std::io::Error::other("sender closed"));
321            }
322            self.inner_bufpos = 0;
323        }
324        Ok(())
325    }
326}
327
328impl std::io::Read for ReadableBuffer {
329    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
330        self.fill_buff();
331
332        let Some(current_buff) = &self.current_buffer else {
333            return Ok(0);
334        };
335        let source_slice = &current_buff[self.inner_bufpos..];
336        let len = buf.len().min(source_slice.len());
337
338        buf[..len].copy_from_slice(&source_slice[..len]);
339        self.inner_bufpos += len;
340        self.return_buff()?;
341        Ok(len)
342    }
343}
344
345struct PipeBuffer {
346    buffer_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
347    return_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
348}
349
350impl PipeBuffer {
351    pub async fn pipe_read_until<A: AsyncBufReadExt + std::marker::Unpin>(
352        &mut self,
353        source: &mut A,
354        byte: u8,
355    ) -> Result<usize, std::io::Error> {
356        let Some(mut buff) = self.return_rx.recv().await else {
357            return Err(std::io::Error::other("receiver closed"))?;
358        };
359        buff.clear();
360        let size = source.read_until(byte, &mut buff).await?;
361        if size != 0 {
362            self.buffer_tx
363                .send(buff)
364                .await
365                .map_err(|_| std::io::Error::other("No receiver"))?;
366        }
367        Ok(size)
368    }
369}