use anyhow::{anyhow, bail, Context, Result};
use clickhouse::{Client, Row};
use serde::Deserialize;
use crate::{
options::Options,
schema::{Column, SqlType, Table},
};
/// Builds a ClickHouse [`Client`] from the connection settings in `options`.
///
/// When `options.url` carries no explicit `http://` or `https://` scheme,
/// `http://` is prepended. The user and password are applied only when set.
fn make_client(options: &Options) -> Client {
    // Look for a full scheme rather than a bare "http" prefix, so that a host
    // whose name merely starts with "http" (e.g. `httpd-ch:8123`) still gets
    // the default scheme prepended.
    let has_scheme = options.url.starts_with("http://") || options.url.starts_with("https://");
    let url = if has_scheme {
        options.url.clone()
    } else {
        format!("http://{}", options.url)
    };

    let mut client = Client::default().with_url(url);
    if let Some(user) = &options.user {
        client = client.with_user(user);
    }
    if let Some(password) = &options.password {
        client = client.with_password(password);
    }
    client
}
/// One row of column metadata as read from `system.columns`.
///
/// The type is kept as the raw textual type name; it is turned into a
/// `SqlType` later by `parse_type`.
#[derive(Debug, Deserialize, Row)]
struct RawColumn {
/// Column name.
name: String,
/// Textual type of the column; deserialized from the field named `type`,
/// which cannot be used directly as a Rust identifier.
#[serde(rename = "type")]
type_: String,
/// Column comment, copied verbatim into the resulting `Column`.
comment: String,
}
/// Reads the column metadata of `options.database`.`options.table` from
/// `system.columns`.
///
/// # Errors
///
/// Propagates any error reported by the client while running the query or
/// decoding the rows.
async fn fetch_raw_columns(client: &Client, options: &Options) -> Result<Vec<RawColumn>> {
    // The two `?` placeholders are bound below, in order: database, then table.
    let sql = "
SELECT ?fields
FROM system.columns
WHERE database = ?
AND table = ?
";
    let rows = client
        .query(sql)
        .bind(&options.database)
        .bind(&options.table)
        .fetch_all::<RawColumn>()
        .await?;
    Ok(rows)
}
/// Converts the fetched raw columns into a [`Table`], dropping every column
/// whose name is listed in `options.ignore`.
///
/// # Errors
///
/// Stops at the first column that cannot be converted and reports which
/// column it was.
fn make_table(raw_columns: Vec<RawColumn>, options: &Options) -> Result<Table> {
    let columns = raw_columns
        .into_iter()
        .filter(|raw| !options.ignore.contains(&raw.name))
        .map(|raw| {
            // Build the message first: `raw` is moved into `make_column`.
            let reason = format!("failed to handle the `{}` column", raw.name);
            make_column(raw).context(reason)
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(Table { columns })
}
fn make_column(raw: RawColumn) -> Result<Column> {
let type_ = parse_type(&raw.type_)
.with_context(|| format!("failed to parse the `{}` type", raw.type_))?;
Ok(Column {
name: raw.name,
type_,
comment: raw.comment,
})
}
/// Parses a textual ClickHouse type name (e.g. `Nullable(UInt8)`) into a
/// [`SqlType`].
///
/// Two wrappers are stripped before matching:
/// * `SimpleAggregateFunction(func, T)` is treated as its value type `T`;
/// * `LowCardinality(T)` is treated as `T`.
///
/// Composite types (`Nullable`, `Array`, `Tuple`, `Map`) are parsed
/// recursively. Their arguments are split only on top-level commas, so nested
/// types that contain commas themselves (`Decimal(10, 2)`, `Map(..)`,
/// `Enum8(..)`) are handled correctly.
///
/// # Errors
///
/// Fails when the type is unknown or when the arguments of a parametrized
/// type are malformed.
pub fn parse_type(raw: &str) -> Result<SqlType> {
    let raw = raw.trim();

    let raw = if let Some(args) = extract_inner(raw, "SimpleAggregateFunction") {
        // The first argument is the aggregate function name; skip it.
        let mut parts = split_top_level_commas(args).into_iter().skip(1);
        let type1 = parts
            .next()
            .ok_or_else(|| anyhow!("single-arg SimpleAggregateFunction"))?;
        if parts.next().is_some() {
            bail!("more than 2 args aren't supported in SimpleAggregateFunction");
        }
        type1.trim()
    } else {
        raw
    };

    let raw = extract_inner(raw, "LowCardinality")
        .map(str::trim)
        .unwrap_or(raw);

    Ok(match raw {
        "UInt8" => SqlType::UInt8,
        "UInt16" => SqlType::UInt16,
        "UInt32" => SqlType::UInt32,
        "UInt64" => SqlType::UInt64,
        "UInt128" => SqlType::UInt128,
        "Int8" => SqlType::Int8,
        "Int16" => SqlType::Int16,
        "Int32" => SqlType::Int32,
        "Int64" => SqlType::Int64,
        "Int128" => SqlType::Int128,
        "Bool" => SqlType::Bool,
        "String" => SqlType::String,
        "Float32" => SqlType::Float32,
        "Float64" => SqlType::Float64,
        "Date" => SqlType::Date,
        "DateTime" => SqlType::DateTime(None),
        "IPv4" => SqlType::IPv4,
        "IPv6" => SqlType::IPv6,
        "UUID" => SqlType::UUID,
        _ => {
            if let Some(inner) = extract_inner(raw, "Nullable") {
                SqlType::Nullable(Box::new(parse_type(inner)?))
            } else if let Some(inner) = extract_inner(raw, "DateTime") {
                // The time zone argument is kept exactly as written.
                SqlType::DateTime(Some(inner.into()))
            } else if let Some(inner) = extract_inner(raw, "DateTime64") {
                // `DateTime64(precision)` or `DateTime64(precision, timezone)`.
                let (prec, tz) = inner
                    .split_once(',')
                    .map_or((inner, None), |(p, tz)| (p, Some(tz)));
                let prec = prec.trim().parse().context("invalid precision")?;
                SqlType::DateTime64(prec, tz.map(str::trim).map(Into::into))
            } else if let Some(inner) = extract_inner(raw, "Enum8") {
                SqlType::Enum8(parse_kv_list(inner).context("invalid enum")?)
            } else if let Some(inner) = extract_inner(raw, "Enum16") {
                SqlType::Enum16(parse_kv_list(inner).context("invalid enum")?)
            } else if let Some(inner) = extract_inner(raw, "Decimal") {
                let (prec, scale) = inner.split_once(',').context("invalid decimal")?;
                let prec = prec.trim().parse().context("invalid prec")?;
                let scale = scale.trim().parse().context("invalid scale")?;
                SqlType::Decimal(prec, scale)
            } else if let Some(inner) = extract_inner(raw, "FixedString") {
                SqlType::FixedString(inner.parse().context("invalid size")?)
            } else if let Some(inner) = extract_inner(raw, "Array") {
                SqlType::Array(Box::new(parse_type(inner)?))
            } else if let Some(inner) = extract_inner(raw, "Tuple") {
                SqlType::Tuple(
                    split_top_level_commas(inner)
                        .into_iter()
                        .map(parse_type)
                        .collect::<Result<Vec<_>>>()?,
                )
            } else if let Some(inner) = extract_inner(raw, "Map") {
                let (key, value) = match split_top_level_commas(inner).as_slice() {
                    &[key, value] => (key, value),
                    _ => bail!("invalid map"),
                };
                let key = parse_type(key).context("invalid key")?;
                let value = parse_type(value).context("invalid value")?;
                SqlType::Map(Box::new(key), Box::new(value))
            } else {
                bail!("unknown type");
            }
        }
    })
}

/// Splits `raw` on commas that are neither nested inside parentheses nor part
/// of a single-quoted literal (a backslash escapes the next character inside
/// quotes). The pieces are returned untrimmed; there is always at least one.
fn split_top_level_commas(raw: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut in_quotes = false;
    let mut escaped = false;
    let mut start = 0;

    for (idx, ch) in raw.char_indices() {
        if in_quotes {
            match ch {
                _ if escaped => escaped = false,
                '\\' => escaped = true,
                '\'' => in_quotes = false,
                _ => {}
            }
            continue;
        }
        match ch {
            '\'' => in_quotes = true,
            '(' => depth += 1,
            // Tolerate unbalanced input instead of underflowing.
            ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                parts.push(&raw[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    parts.push(&raw[start..]);
    parts
}
/// Returns the text between the parentheses of `wrapper(...)`.
///
/// Yields `Some(inner)` only when `raw` has exactly the shape
/// `<wrapper>(<inner>)`: it must start with `wrapper`, be followed
/// immediately by `(`, and end with `)`. Anything else, including an
/// unterminated `wrapper(` or a longer name such as `DateTime64(..)` when
/// `wrapper` is `DateTime`, yields `None`.
fn extract_inner<'a>(raw: &'a str, wrapper: &str) -> Option<&'a str> {
    // `strip_*` never slices out of bounds, so malformed input cannot panic.
    raw.strip_prefix(wrapper)?
        .strip_prefix('(')?
        .strip_suffix(')')
}
/// Parses the argument list of an enum type, e.g. `'a' = 1, 'b' = 2`, into
/// `(name, value)` pairs in their original order.
///
/// Commas and ` = ` sequences inside a single-quoted name do not split the
/// list. The name is returned without its surrounding quotes; escape
/// sequences inside it are left as written.
///
/// # Errors
///
/// Fails when a pair has no ` = ` separator, when the name is not wrapped in
/// single quotes, or when the value is not a valid integer.
fn parse_kv_list(raw: &str) -> Result<Vec<(String, i32)>> {
    // Cut the list at commas lying outside of quoted names.
    let mut pairs = Vec::new();
    let mut in_quotes = false;
    let mut escaped = false;
    let mut start = 0;
    for (idx, ch) in raw.char_indices() {
        match ch {
            _ if escaped => escaped = false,
            '\\' if in_quotes => escaped = true,
            '\'' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                pairs.push(&raw[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    pairs.push(&raw[start..]);

    pairs
        .into_iter()
        .map(|pair| {
            let pair = pair.trim();
            // Split at the last separator: the value never contains ` = `,
            // but a quoted name may.
            let (k, v) = pair
                .rsplit_once(" = ")
                .with_context(|| format!("invalid key-value pair `{}`", pair))?;
            // Require both quotes explicitly; this also covers an empty key
            // without any length arithmetic that could underflow.
            let k = k
                .strip_prefix('\'')
                .and_then(|k| k.strip_suffix('\''))
                .context("invalid variant key")?;
            let v = v.parse().context("invalid variant value")?;
            Ok((k.into(), v))
        })
        .collect()
}
/// Connects to ClickHouse, reads the column metadata of the configured table
/// and converts it into a [`Table`].
///
/// # Errors
///
/// Fails when the columns cannot be fetched or when any of them cannot be
/// converted.
pub async fn mine(options: &Options) -> Result<Table> {
    let client = make_client(options);
    let fetched = fetch_raw_columns(&client, options).await;
    let raw_columns = fetched.context("failed to fetch columns")?;
    make_table(raw_columns, options).context("failed to make the table")
}