use bytes::BytesMut;
use databend_driver::{Client, Connection};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use bytes::Buf;
use clap::{command, Parser};
#[derive(Debug, Clone, Parser, PartialEq)]
#[command(name = "ttc")]
struct Config {
#[clap(short = 'P', default_value = "9902", env = "TTC_PORT", long)]
port: u16,
#[clap(
long,
env = "DATABEND_DSN",
hide_env_values = true,
default_value = "databend://default:@127.0.0.1:8000"
)]
databend_dsn: String,
}
/// Reply payload that is JSON-encoded and sent back to the client for every command.
///
/// Field order is significant: it determines the key order of the serialized JSON.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
struct Response {
    /// Result rows; each row is a list of column values. Empty when the query failed.
    /// NOTE(review): `None` entries presumably represent SQL NULL — confirm against
    /// `databend_driver`'s `raw_row`.
    values: Vec<Vec<Option<String>>>,
    /// Error text when the query failed; `None` on success.
    error: Option<String>,
}
/// Starts the TTC server.
///
/// Parses [`Config`] and opens one probe connection to `databend_dsn`, printing the
/// server version so that an unusable DSN is reported at startup. It then listens
/// on `127.0.0.1:<port>` and hands each accepted socket to [`process`] on its own
/// tokio task. Per-connection errors are printed to stderr and do not stop the server.
///
/// # Errors
/// Returns an error if the DSN probe fails, the listener cannot bind, or `accept` fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = Config::parse();
    {
        // NOTE(review): this prints the full DSN, including any password, even though
        // `hide_env_values` keeps it out of `--help`; consider redacting it.
        println!(
            "Start to check databend dsn: {} is valid",
            config.databend_dsn
        );
        // Probe failures are network errors, so propagate them with `?` instead of
        // panicking; the probe connection is dropped at the end of this scope.
        let client = Client::new(config.databend_dsn.clone());
        let conn = client.get_conn().await?;
        println!("Databend version: {}", conn.version().await?);
    }
    // Only the loopback interface is bound.
    let addr = format!("127.0.0.1:{}", config.port);
    let listener = TcpListener::bind(&addr).await?;
    println!("Rust TTC Server running on {addr}");
    println!("Ready to accept connections");
    loop {
        // NOTE(review): any `accept` error ends the server; transient errors (e.g. fd
        // exhaustion) could be logged and retried instead.
        let (socket, _) = listener.accept().await?;
        let config = config.clone();
        tokio::spawn(async move {
            if let Err(e) = process(socket, &config).await {
                eprintln!("Error processing connection: {:?}", e);
            }
        });
    }
}
/// Serves one client socket until the peer closes it.
///
/// A dedicated databend connection is opened for this socket. The function then
/// repeatedly reads from the socket into a buffer. Every complete frame buffered so
/// far is executed in order before the next read, and each gets a reply. Returns
/// `Ok(())` once a read yields 0 bytes; an incomplete trailing frame is dropped.
///
/// # Errors
/// Propagates failures from opening the databend connection, reading the socket,
/// or executing/replying to a command.
async fn process(mut socket: TcpStream, config: &Config) -> Result<(), Box<dyn std::error::Error>> {
    let mut pending = BytesMut::with_capacity(1024);
    let databend = Client::new(config.databend_dsn.clone());
    let mut conn = databend.get_conn().await?;
    loop {
        let read = socket.read_buf(&mut pending).await?;
        if read == 0 {
            // Peer closed the connection.
            break Ok(());
        }
        // Drain every complete frame currently buffered; consume its bytes only after
        // the reply has been written.
        while let Some((frame, consumed)) = decode_frame(&pending) {
            execute_command(&frame, &mut socket, &mut conn).await?;
            pending.advance(consumed);
        }
    }
}
/// Tries to decode one length-prefixed frame from the front of `buf`.
///
/// A frame is a 4-byte big-endian `u32` payload length followed by that many payload
/// bytes. Returns a copy of the payload together with the total number of bytes the
/// frame occupies (`4 + payload length`), so the caller can advance past it. Returns
/// `None` while `buf` does not yet hold a complete frame; bytes after the first frame
/// are ignored.
///
/// Takes `&[u8]` so it works for any byte buffer (a `&BytesMut` coerces via `Deref`).
fn decode_frame(buf: &[u8]) -> Option<(Vec<u8>, usize)> {
    let header = buf.get(..4)?;
    // Widening cast on 32/64-bit targets.
    let len = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
    // `checked_add` stops `4 + len` from wrapping where `usize` is 32 bits. Wrapping
    // would produce a reversed range and a slicing panic.
    let end = len.checked_add(4)?;
    let payload = buf.get(4..end)?;
    Some((payload.to_vec(), end))
}
/// Runs one SQL command on `conn` and writes the JSON-encoded result to `socket`.
///
/// `command` is decoded as UTF-8, with invalid sequences replaced by U+FFFD, and
/// passed to `query_raw_all`. The reply is a single frame: a 4-byte big-endian length
/// followed by a JSON [`Response`].
///
/// - On success, `values` holds each row's `raw_row` and `error` is `None`.
/// - On failure, `values` is empty and `error` holds the error text.
///
/// Query failures are reported to the client and are not returned as `Err`.
///
/// # Errors
/// Returns an error if JSON serialization fails, the encoded response is too long for
/// a `u32` length prefix, or writing to the socket fails.
async fn execute_command(
    command: &[u8],
    socket: &mut TcpStream,
    conn: &mut Connection,
) -> Result<(), Box<dyn std::error::Error>> {
    let sql = String::from_utf8_lossy(command);
    let response = match conn.query_raw_all(&sql).await {
        Ok(rows) => Response {
            values: rows.into_iter().map(|row| row.raw_row).collect(),
            error: None,
        },
        Err(err) => Response {
            values: vec![],
            error: Some(err.to_string()),
        },
    };
    let payload = serde_json::to_vec(&response)?;
    // An `as u32` cast would silently truncate the prefix and desynchronize the client's
    // framing, so oversized replies are rejected instead.
    let len = u32::try_from(payload.len())
        .map_err(|_| "response too large for a 4-byte length prefix")?;
    // Send the header and payload in a single write.
    let mut frame = Vec::with_capacity(4 + payload.len());
    frame.extend_from_slice(&len.to_be_bytes());
    frame.extend_from_slice(&payload);
    socket.write_all(&frame).await?;
    Ok(())
}