db-cores 0.1.0

Database core utilities
Documentation

use crate::common::ColumnBaseInfo;
use crate::to_json::PgRowParse;
use base64::{engine::general_purpose, Engine};
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
use rust_decimal::Decimal;
use serde_json::{json, Value as JsonValue};
use sqlx::postgres::{ PgRow, PgTypeInfo};
use sqlx::{Column, Row, TypeInfo};
use crate::utlis::decode_auto;

// ========================================
// 主接口:转换为 JSON
// ========================================
pub fn to_json(results: Vec<PgRow>) -> anyhow::Result<(Vec<JsonValue>, Vec<ColumnBaseInfo>)> {
    if results.is_empty() {
        return Ok((vec![], vec![]));
    }
    let first_row = &results[0];
    let PgRowParse { methods, columns } = determine_parsing_methods(first_row)?;

    let mut data = Vec::with_capacity(results.len());
    for row in results {
        let mut row_data = json!({});
        for col in row.columns() {
            let col_index = col.ordinal();
            let col_name = col.name();
            let value = parse_value(&row, col_index, &methods);
            row_data[col_name] = value;
        }
        data.push(row_data);
    }

    Ok((data, columns))
}

// ========================================
// 类型映射与解析策略
// ========================================

// 为每一列确定解析函数和元信息

pub fn determine_parsing_methods(row: &PgRow) -> anyhow::Result<PgRowParse> {
    let columns = row.columns();
    let mut methods = Vec::with_capacity(columns.len());
    let mut new_columns: Vec<ColumnBaseInfo> = Vec::with_capacity(columns.len());

    for col in columns {
        let col_index = col.ordinal();
        let col_name = col.name();
        let type_info = col.type_info();
        let field_type = detect_pg_type(type_info);

        let method = match field_type {
            "text" => parse_text_value,
            "int" => parse_integer_value,
            "float" | "numeric" => parse_decimal_value,
            "bool" => parse_bool_value,
            "date" => parse_date_value,
            "timestamp" => parse_datetime_value,
            "timestamptz" => parse_utc_value,
            "time" => parse_time_value,
            "jsonb" => parse_json_value,
            "bytea" => parse_bytea_value,
            "uuid" => parse_text_value,
            "array" => parse_array,
            "other" => parse_text_value,
            // 扩展类型
            "geometry" | "geography" | "hstore" => parse_text_value,
            _ => parse_text_value,
        };

        new_columns.push(ColumnBaseInfo {
            name: col_name.to_string(),
            r#type: field_type.to_string(),
            index: col_index as u64,
        });
        methods.push(method);
    }

    Ok(PgRowParse {
        methods,
        columns: new_columns,
    })
}

fn parse_value(
    row: &PgRow,
    col_index: usize,
    parsing_methods: &[fn(&PgRow, usize) -> JsonValue],
) -> JsonValue {
    let method = parsing_methods[col_index];
    method(row, col_index)
}

// ========================================
// PostgreSQL 类型检测
// ========================================

/// Maps a PostgreSQL column type to the coarse category used to pick a value parser.
///
/// Array types map to `"array"`. Any kind other than simple or array (enum, composite,
/// range, ...) maps to `"other"`.
///
/// Simple types are matched by lower-cased name. Possible results are `"int"`, `"float"`,
/// `"numeric"`, `"bool"`, `"text"`, `"date"`, `"timestamp"`, `"timestamptz"` and `"time"`.
/// The remaining results are `"jsonb"` (json and jsonb), `"bytea"`, `"uuid"`, `"geometry"`
/// (geometry and geography) and `"hstore"`. Unrecognised simple types map to `"text"`.
pub fn detect_pg_type(type_info: &PgTypeInfo) -> &'static str {
    use sqlx::postgres::PgTypeKind;

    let lowered = match type_info.kind() {
        PgTypeKind::Simple => type_info.name().to_lowercase(),
        PgTypeKind::Array(_) => return "array",
        // Other complex kinds (enum, composite, range, ...) are later read as strings.
        _ => return "other",
    };

    match lowered.as_str() {
        "int2" | "smallint" | "int4" | "integer" | "int8" | "bigint" => "int",
        "float4" | "real" | "float8" | "double precision" => "float",
        "numeric" | "decimal" => "numeric",
        "bool" | "boolean" => "bool",
        "date" => "date",
        "timestamp" | "timestamp without time zone" => "timestamp",
        "timestamptz" | "timestamp with time zone" => "timestamptz",
        "time" | "timetz" | "time without time zone" => "time",
        "jsonb" | "json" => "jsonb",
        "bytea" => "bytea",
        "uuid" => "uuid",
        // Extension types.
        "geometry" | "geography" => "geometry",
        "hstore" => "hstore",
        "text" | "varchar" | "char" | "bpchar" | "citext" | "name" => "text",
        "interval" | "money" | "inet" | "cidr" | "macaddr" | "xml" => "text",
        // Any other simple type is treated as text.
        _ => "text",
    }
}

// ========================================
// 各类型解析实现
// ========================================
/// Reads a column as a JSON string. SQL NULL or any decode error yields `Null`.
///
/// NOTE(review): sqlx only decodes text-like SQL types into `String`. For non-text columns
/// routed here (e.g. uuid, interval, inet, enums), `try_get` presumably fails and this
/// returns `Null`. Confirm, and add dedicated decoders if those types matter.
fn parse_text_value(row: &PgRow, col_index: usize) -> JsonValue {
    row.try_get::<Option<String>, _>(col_index)
        .ok()
        .flatten()
        .map_or(JsonValue::Null, JsonValue::String)
}

/// Reads an integer column (INT2, INT4 or INT8) as a JSON number.
///
/// sqlx checks SQL/Rust type compatibility strictly: `i64` decodes only INT8. INT4 and INT2
/// columns are therefore retried as `i32` and `i16`; without this they always came back as
/// `Null`. SQL NULL, or a value none of the decoders accepts, yields `Null`.
fn parse_integer_value(row: &PgRow, col_index: usize) -> JsonValue {
    if let Ok(value) = row.try_get::<Option<i64>, _>(col_index) {
        return value.map_or(JsonValue::Null, JsonValue::from);
    }
    if let Ok(value) = row.try_get::<Option<i32>, _>(col_index) {
        return value.map_or(JsonValue::Null, JsonValue::from);
    }
    match row.try_get::<Option<i16>, _>(col_index) {
        Ok(Some(value)) => JsonValue::from(value),
        _ => JsonValue::Null,
    }
}

/// Reads a double-precision column as a JSON number.
///
/// SQL NULL, a decode error, or a non-finite value (NaN/±inf, which JSON cannot represent)
/// yields `Null`. Not referenced by `determine_parsing_methods`.
fn _parse_real_value(row: &PgRow, col_index: usize) -> JsonValue {
    if let Ok(Some(number)) = row.try_get::<Option<f64>, _>(col_index) {
        JsonValue::from(number)
    } else {
        JsonValue::Null
    }
}

/// Reads a BOOL column as a JSON boolean. SQL NULL or a decode error yields `Null`.
fn parse_bool_value(row: &PgRow, col_index: usize) -> JsonValue {
    row.try_get::<Option<bool>, _>(col_index)
        .ok()
        .flatten()
        .map_or(JsonValue::Null, JsonValue::Bool)
}

/// Reads a DATE column as a `YYYY-MM-DD` JSON string. SQL NULL or a decode error yields `Null`.
fn parse_date_value(row: &PgRow, col_index: usize) -> JsonValue {
    if let Ok(Some(date)) = row.try_get::<Option<NaiveDate>, _>(col_index) {
        let formatted = date.format("%Y-%m-%d").to_string();
        JsonValue::String(formatted)
    } else {
        JsonValue::Null
    }
}

/// Reads a TIMESTAMP (without time zone) column as a `YYYY-MM-DD HH:MM:SS` JSON string.
///
/// The format drops fractional seconds. SQL NULL or a decode error yields `Null`.
fn parse_datetime_value(row: &PgRow, col_index: usize) -> JsonValue {
    let decoded = row.try_get::<Option<NaiveDateTime>, _>(col_index);
    decoded
        .ok()
        .flatten()
        .map(|stamp| JsonValue::String(stamp.format("%Y-%m-%d %H:%M:%S").to_string()))
        .unwrap_or(JsonValue::Null)
}

/// Reads a TIMESTAMPTZ column as an RFC 3339 (ISO 8601) string in UTC.
///
/// SQL NULL and any decode error (bad index, type mismatch, ...) both yield `Null`.
fn parse_utc_value(row: &PgRow, col_index: usize) -> JsonValue {
    match row.try_get::<Option<DateTime<Utc>>, _>(col_index) {
        // Serialise as an RFC 3339 string rather than a numeric timestamp.
        Ok(Some(instant)) => JsonValue::String(instant.to_rfc3339()),
        Ok(None) | Err(_) => JsonValue::Null,
    }
}

fn parse_time_value(row: &PgRow, col_index: usize) -> JsonValue {
    if let Ok(Some(t)) = row.try_get::<Option<String>, _>(col_index) {
        json!(t)
    } else {
        JsonValue::Null
    }
}

/// Reads a JSON/JSONB column and returns the stored document unchanged.
///
/// SQL NULL or a decode error yields `Null`.
fn parse_json_value(row: &PgRow, col_index: usize) -> JsonValue {
    let document = row.try_get::<Option<JsonValue>, _>(col_index).ok().flatten();
    document.unwrap_or(JsonValue::Null)
}

/// Reads a NUMERIC or floating-point column as a JSON string.
///
/// Values are rendered as strings, not JSON numbers, so NUMERIC keeps its exact precision.
/// sqlx's type check lets `Decimal` decode only NUMERIC columns. `determine_parsing_methods`
/// also routes FLOAT4/FLOAT8 here, so those are retried as `f64` and `f32`; previously they
/// always came back as `Null`.
///
/// SQL NULL yields `Null`, as does a value no decoder accepts (e.g. a NUMERIC outside
/// `Decimal`'s range, or NUMERIC NaN).
fn parse_decimal_value(row: &PgRow, col_index: usize) -> JsonValue {
    if let Ok(value) = row.try_get::<Option<Decimal>, _>(col_index) {
        return value.map_or(JsonValue::Null, |d| JsonValue::String(d.to_string()));
    }
    if let Ok(value) = row.try_get::<Option<f64>, _>(col_index) {
        return value.map_or(JsonValue::Null, |f| JsonValue::String(f.to_string()));
    }
    match row.try_get::<Option<f32>, _>(col_index) {
        Ok(Some(f)) => JsonValue::String(f.to_string()),
        _ => JsonValue::Null,
    }
}

/// Reads a BYTEA column as a JSON string.
///
/// If `blob_is_text` classifies the bytes as text, they are converted with `decode_auto`.
/// Otherwise they are base64-encoded with the standard (padded) alphabet. SQL NULL or a
/// decode error yields `Null`.
fn parse_bytea_value(row: &PgRow, col_index: usize) -> JsonValue {
    let bytes = match row.try_get::<Option<Vec<u8>>, _>(col_index) {
        Ok(Some(bytes)) => bytes,
        _ => return JsonValue::Null,
    };

    let rendered = if blob_is_text(&bytes) {
        decode_auto(&bytes)
    } else {
        general_purpose::STANDARD.encode(bytes)
    };
    JsonValue::String(rendered)
}

fn parse_array(row: &PgRow, col_index: usize) -> JsonValue {
    match row.try_get::<Option<String>, _>(col_index) {
        Ok(Some(d)) => parse_postgres_array(&d), // 保持字符串形式避免精度丢失
        _ => JsonValue::Null,
    }
}

// ========================================
// 数组解析(PostgreSQL 格式)
// ========================================
/// Parses PostgreSQL's text representation of an array into a JSON array.
///
/// Examples of accepted input are `{1,2,NULL}`, `{"a b","c\"d"}` and `{{1,2},{3,4}}`.
///
/// - Blank input, `null` (any case) and `{}` yield an empty JSON array.
/// - Input not wrapped in `{ … }` (e.g. `[1:2]={1,2}`) is returned verbatim as a JSON string.
/// - Unquoted elements: empty, `NULL` or `null` become `null`; integers and finite floats
///   become numbers; anything else becomes a string.
/// - Quoted elements are always strings, with backslash escapes removed and inner whitespace
///   kept. A quoted `"NULL"` is the literal text `NULL`, not SQL NULL.
/// - Nested `{ … }` elements are parsed recursively into nested JSON arrays.
fn parse_postgres_array(input: &str) -> JsonValue {
    let s = input.trim();
    // Nothing to parse.
    if s.is_empty() || s.eq_ignore_ascii_case("null") || s == "{}" {
        return JsonValue::Array(Vec::new());
    }
    // Only `{ ... }` literals are understood; anything else is passed through unchanged.
    if !s.starts_with('{') || !s.ends_with('}') {
        return JsonValue::String(s.to_owned());
    }
    let content = &s[1..s.len() - 1];

    // Fast path: with no quoting and no nesting, a plain split on ',' is exact.
    if !content.contains(|c: char| c == '"' || c == '{') {
        return JsonValue::Array(
            content
                .split(',')
                .map(|item| parse_array_element_fast(item.trim()))
                .collect(),
        );
    }

    // Slow path: a small state machine that respects quotes, backslash escapes and nesting.
    let mut items = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    // Whether the top-level element being collected was written as a quoted string.
    let mut quoted = false;
    // Brace depth inside the current element; > 0 while copying a nested sub-array.
    let mut depth = 0usize;
    let mut chars = content.chars();

    while let Some(c) = chars.next() {
        if depth > 0 {
            // Copy nested sub-arrays verbatim (quotes and escapes included) so the recursive
            // call can parse them. Quotes and braces are tracked only to find where they end.
            current.push(c);
            match c {
                '\\' => {
                    if let Some(escaped) = chars.next() {
                        current.push(escaped);
                    }
                }
                '"' => in_quotes = !in_quotes,
                '{' if !in_quotes => depth += 1,
                '}' if !in_quotes => depth -= 1,
                _ => {}
            }
            continue;
        }

        match c {
            // A backslash escapes the next character, inside or outside quotes.
            '\\' => {
                if let Some(escaped) = chars.next() {
                    current.push(escaped);
                }
            }
            '"' => {
                // Drop whitespace that preceded the opening quote.
                if !quoted && current.trim().is_empty() {
                    current.clear();
                }
                in_quotes = !in_quotes;
                quoted = true;
            }
            _ if in_quotes => current.push(c),
            ',' => {
                items.push(finish_array_element(&current, quoted));
                current.clear();
                quoted = false;
            }
            '{' => {
                depth = 1;
                current.push(c);
            }
            // Whitespace after a closing quote is not part of the element.
            c if quoted && c.is_whitespace() => {}
            _ => current.push(c),
        }
    }
    // `content` is non-empty here, so there is always a final element to emit.
    items.push(finish_array_element(&current, quoted));

    JsonValue::Array(items)
}

/// Converts one collected top-level array element to JSON.
///
/// Quoted text is kept verbatim as a string, an unquoted `{…}` is parsed recursively, and
/// anything else goes through [`parse_array_element_owned`].
fn finish_array_element(raw: &str, quoted: bool) -> JsonValue {
    if quoted {
        return JsonValue::String(raw.to_owned());
    }
    let trimmed = raw.trim();
    if trimmed.starts_with('{') {
        parse_postgres_array(trimmed)
    } else {
        parse_array_element_owned(trimmed)
    }
}

/// Converts one already-trimmed, unquoted array element to JSON.
///
/// Empty, `NULL` or `null` become `null`. An integer that fits in `i64` or a finite float
/// becomes a number. Anything else, including `NaN`/`inf` which JSON cannot represent,
/// becomes the text as a string.
#[inline]
fn parse_array_element_fast(trimmed: &str) -> JsonValue {
    if matches!(trimmed, "" | "NULL" | "null") {
        return JsonValue::Null;
    }
    if let Ok(integer) = trimmed.parse::<i64>() {
        return JsonValue::from(integer);
    }
    trimmed
        .parse::<f64>()
        .ok()
        .and_then(serde_json::Number::from_f64)
        .map_or_else(|| JsonValue::String(trimmed.to_owned()), JsonValue::Number)
}

// 慢速解析:字符串已拥有,需拷贝
#[inline]
#[allow(dead_code)]
fn parse_array_element_owned(s: &str) -> JsonValue {
    let trimmed = s.trim();
    match trimmed {
        "" | "NULL" | "null" => JsonValue::Null,
        _ => {
            if let Ok(n) = trimmed.parse::<i64>() {
                JsonValue::Number(n.into())
            } else if let Ok(n) = trimmed.parse::<f64>() {
                JsonValue::Number(serde_json::Number::from_f64(n).unwrap_or_else(|| 0.into()))
            } else {
                JsonValue::String(trimmed.to_owned())
            }
        }
    }
}

// ========================================
// 工具函数
// ========================================
/// Heuristically decides whether a binary blob holds text.
///
/// Bytes are sampled at a stride of `len / 1024` (minimum 1), so every byte is inspected for
/// blobs under 2 KiB. The blob is text when fewer than 20% of the sampled bytes are neither
/// ASCII-graphic nor ASCII whitespace. Empty input is not text.
///
/// Every byte >= 0x80 counts as non-printable, so mostly non-ASCII text (e.g. CJK encoded as
/// UTF-8) is classified as binary.
fn blob_is_text(data: &[u8]) -> bool {
    const NON_TEXT_THRESHOLD: f32 = 0.2;
    const SAMPLE_SIZE: usize = 1024;

    if data.is_empty() {
        return false;
    }

    let step = (data.len() / SAMPLE_SIZE).max(1);
    let samples = data.iter().step_by(step);
    // The ratio is measured against the whole sample, so a few non-printable bytes at the
    // start cannot reject otherwise-printable text. Bailing out early is still exact: once
    // the count alone reaches the threshold, more samples can only keep it there.
    let sample_count = samples.len() as f32;
    let mut non_printables = 0usize;
    for &byte in samples {
        if !byte.is_ascii_graphic() && !byte.is_ascii_whitespace() {
            non_printables += 1;
            if non_printables as f32 / sample_count >= NON_TEXT_THRESHOLD {
                return false;
            }
        }
    }
    true
}