use axum::{
extract::{Query, State},
Json,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::api::models::{ApiError, ApiResponse};
use crate::api::server::AppState;
/// Request body for single-table schema inference from JSON samples.
#[derive(Debug, Deserialize)]
pub struct InferSchemaRequest {
/// Sample rows; only the JSON objects among them contribute columns.
pub samples: Vec<serde_json::Value>,
/// Name for the generated table; the handler substitutes a default name when omitted.
pub table_name: Option<String>,
/// Inference switches; the handler falls back to built-in defaults when omitted.
pub options: Option<InferenceOptions>,
}
/// Switches that tune how SQL types and constraints are inferred from samples.
///
/// Every field may be omitted from the JSON body; omitted fields take the serde default
/// declared on that field.
#[derive(Debug, Deserialize, Clone)]
pub struct InferenceOptions {
/// Nullability-detection switch. Defaults to `true`.
/// NOTE(review): confirm how the inference handlers honor this flag.
#[serde(default = "default_true")]
pub detect_nullable: bool,
/// Uniqueness-detection switch. Defaults to `true`.
#[serde(default = "default_true")]
pub detect_unique: bool,
/// Primary-key detection switch. Defaults to `true`.
/// NOTE(review): confirm how the inference handlers honor this flag.
#[serde(default = "default_true")]
pub detect_primary_key: bool,
/// Foreign-key detection switch. Defaults to `false` (`bool::default()`).
/// NOTE(review): confirm which handlers actually consult this flag.
#[serde(default)]
pub detect_foreign_keys: bool,
/// Index-suggestion switch. Defaults to `true`.
/// NOTE(review): confirm which handlers actually consult this flag.
#[serde(default = "default_true")]
pub suggest_indexes: bool,
/// When `true`, integer sample values may be typed `INTEGER` instead of `BIGINT`.
/// Defaults to `true`.
#[serde(default = "default_true")]
pub prefer_narrow_types: bool,
/// Upper bound for inferred `VARCHAR(n)` lengths; strings longer than this are typed
/// `TEXT`. Defaults to 255.
#[serde(default = "default_max_varchar")]
pub max_varchar_length: usize,
/// When `true`, arrays of numbers may be typed `VECTOR(n)` instead of `JSON`.
/// Defaults to `true`.
#[serde(default = "default_true")]
pub detect_vectors: bool,
/// When `true`, JSON object values are typed `JSONB`; otherwise `JSON`.
/// Defaults to `true`.
#[serde(default = "default_true")]
pub detect_json: bool,
}
/// Serde default helper returning `true`; shared by several request types in this module.
fn default_true() -> bool {
true
}
/// Serde default for `InferenceOptions::max_varchar_length`.
fn default_max_varchar() -> usize {
255
}
/// Schema produced for one table by the inference and template handlers.
///
/// `Serialize` is derived, so field declaration order is the key order of the JSON response.
#[derive(Debug, Serialize)]
pub struct InferredSchema {
/// Table name used in `ddl`.
pub table_name: String,
/// Per-column inference results.
pub columns: Vec<InferredColumn>,
/// Primary-key column names, when a key was chosen.
pub primary_key: Option<Vec<String>>,
/// Suggested indexes (may be empty).
pub indexes: Vec<SuggestedIndex>,
/// Inferred table constraints (may be empty).
pub constraints: Vec<InferredConstraint>,
/// `CREATE TABLE` statement, or a placeholder SQL comment from unimplemented handlers.
pub ddl: String,
/// Heuristic confidence score assigned by the handler.
pub confidence: f32,
/// Human-readable caveats about the result.
pub warnings: Vec<String>,
}
/// Inference result for a single column.
#[derive(Debug, Serialize)]
pub struct InferredColumn {
/// Column name (taken from the sample object keys during inference).
pub name: String,
/// SQL type as rendered in the DDL, e.g. `INTEGER` or `VARCHAR(20)`.
pub sql_type: String,
/// Whether the column is allowed to hold NULL.
pub nullable: bool,
/// Whether the column was judged unique.
pub unique: bool,
/// Default-value expression, if any.
pub default: Option<String>,
/// Heuristic confidence score for this column.
pub confidence: f32,
/// Other candidate SQL types, if the handler reports any.
pub alternatives: Vec<String>,
/// Recognised value pattern, if the handler reports one.
pub detected_pattern: Option<String>,
/// Summary statistics, when the handler computes them.
pub statistics: Option<ColumnStatistics>,
}
/// Summary statistics for one column's sample values.
#[derive(Debug, Serialize)]
pub struct ColumnStatistics {
/// Count of null (or absent) values.
pub null_count: usize,
/// Count of distinct non-null values.
pub distinct_count: usize,
/// Smallest observed value, when computed.
pub min: Option<serde_json::Value>,
/// Largest observed value, when computed.
pub max: Option<serde_json::Value>,
/// Mean length of observed values, when computed.
pub avg_length: Option<f32>,
/// Longest observed value length, when computed.
pub max_length: Option<usize>,
}
/// Index recommended for an inferred table.
#[derive(Debug, Serialize)]
pub struct SuggestedIndex {
/// Index name.
pub name: String,
/// Indexed columns, in key order.
pub columns: Vec<String>,
/// Index kind as free-form text.
pub index_type: String,
/// Why the index is suggested.
pub reason: String,
}
/// Table constraint attached to an inferred schema.
#[derive(Debug, Serialize)]
pub struct InferredConstraint {
/// Constraint kind as free-form text.
pub constraint_type: String,
/// Columns the constraint applies to.
pub columns: Vec<String>,
/// Constraint expression (e.g. for CHECK constraints), if any.
pub expression: Option<String>,
/// Referenced table/columns for foreign-key constraints.
pub references: Option<ForeignKeyRef>,
}
/// Target of a foreign-key constraint.
#[derive(Debug, Serialize)]
pub struct ForeignKeyRef {
/// Referenced table name.
pub table: String,
/// Referenced column names.
pub columns: Vec<String>,
}
/// Request body for inferring several tables in one call.
#[derive(Debug, Deserialize)]
pub struct BatchInferRequest {
/// Tables to infer, each with its own samples.
pub tables: Vec<TableSamples>,
/// Requests cross-table relationship detection. Defaults to `false`.
#[serde(default)]
pub detect_relationships: bool,
/// Inference switches shared by every table; built-in defaults are used when omitted.
pub options: Option<InferenceOptions>,
}
/// Samples for one table in a batch request.
#[derive(Debug, Deserialize)]
pub struct TableSamples {
/// Table name used in the generated DDL.
pub name: String,
/// Sample rows; only JSON objects contribute columns.
pub samples: Vec<serde_json::Value>,
}
/// Response of batch inference.
#[derive(Debug, Serialize)]
pub struct BatchInferResponse {
/// One schema per requested table, in request order.
pub schemas: Vec<InferredSchema>,
/// Relationships detected between the tables (empty unless requested).
pub relationships: Vec<DetectedRelationship>,
/// All per-table DDL statements joined into one script.
pub combined_ddl: String,
}
/// Relationship detected between two inferred tables.
#[derive(Debug, Serialize)]
pub struct DetectedRelationship {
/// Table holding the referencing column.
pub from_table: String,
/// Referencing column.
pub from_column: String,
/// Referenced table.
pub to_table: String,
/// Referenced column.
pub to_column: String,
/// Relationship kind as free-form text.
pub relationship_type: String,
/// Heuristic confidence score.
pub confidence: f32,
}
/// Request body for inferring a schema from uploaded file contents.
#[derive(Debug, Deserialize)]
pub struct InferFromFileRequest {
/// Name of the file format; echoed back in the handler's warning.
pub format: String,
/// File contents as a string.
/// NOTE(review): not yet read by `infer_from_file`.
pub content: String,
/// CSV parsing options, presumably only relevant for CSV input — confirm.
pub csv_options: Option<CsvOptions>,
/// Name for the generated table; the handler substitutes a default name when omitted.
pub table_name: Option<String>,
/// Inference switches.
pub options: Option<InferenceOptions>,
}
/// CSV parsing options; omitted fields take the serde defaults declared below.
#[derive(Debug, Deserialize, Clone)]
pub struct CsvOptions {
/// Field delimiter. Defaults to `,`.
#[serde(default = "default_comma")]
pub delimiter: char,
/// Whether the first row is a header. Defaults to `true`.
#[serde(default = "default_true")]
pub has_header: bool,
/// Quote character. Defaults to `"`.
#[serde(default = "default_quote")]
pub quote: char,
/// Number of leading rows to skip, if any.
pub skip_rows: Option<usize>,
}
/// Serde default for `CsvOptions::delimiter`.
fn default_comma() -> char {
','
}
/// Serde default for `CsvOptions::quote`.
fn default_quote() -> char {
'"'
}
/// Request body for schema optimization.
#[derive(Debug, Deserialize)]
pub struct OptimizeSchemaRequest {
/// DDL to optimize; the current handler returns it unchanged.
pub current_ddl: String,
/// Representative queries, if supplied.
pub sample_queries: Option<Vec<String>>,
/// Optimization goals as free-form strings, if supplied.
pub goals: Option<Vec<String>>,
/// Per-table statistics; presumably keyed by table name — confirm with API clients.
pub statistics: Option<HashMap<String, TableStatistics>>,
}
/// Optional size/cardinality figures for one table.
#[derive(Debug, Deserialize)]
pub struct TableStatistics {
/// Row count, if known.
pub row_count: Option<usize>,
/// Average row size, if known (unit not specified by this type).
pub avg_row_size: Option<usize>,
/// Distinct-value counts, presumably keyed by column name — confirm.
pub column_cardinality: Option<HashMap<String, usize>>,
}
/// Response of schema optimization.
#[derive(Debug, Serialize)]
pub struct OptimizationResponse {
/// Optimized DDL.
pub optimized_ddl: String,
/// Individual changes applied.
pub changes: Vec<SchemaChange>,
/// SQL that migrates the current schema to the optimized one.
pub migration_sql: String,
/// Estimated effect of the changes.
pub impact: OptimizationImpact,
}
/// One recommended schema change.
#[derive(Debug, Serialize)]
pub struct SchemaChange {
/// Kind of change as free-form text.
pub change_type: String,
/// Human-readable description.
pub description: String,
/// Affected objects.
pub affected: Vec<String>,
/// Why the change is recommended.
pub reason: String,
/// Risk level as free-form text.
pub risk: String,
}
/// Estimated effect of an optimization.
#[derive(Debug, Serialize)]
pub struct OptimizationImpact {
/// Query-performance estimate, as display text.
pub query_improvement: Option<String>,
/// Storage-size estimate, as display text.
pub storage_change: Option<String>,
/// Risks the caller should review.
pub risks: Vec<String>,
}
/// Request body for comparing two schemas.
#[derive(Debug, Deserialize)]
pub struct CompareSchemaRequest {
/// Source DDL text.
pub source: String,
/// Target DDL text.
pub target: String,
/// Whether to include migration SQL in the response. Defaults to `true`.
#[serde(default = "default_true")]
pub generate_migration: bool,
}
/// Response of schema comparison.
#[derive(Debug, Serialize)]
pub struct SchemaComparisonResponse {
/// Detected differences.
pub differences: Vec<SchemaDifference>,
/// Source-to-target migration; `None` when `generate_migration` was false.
pub forward_migration: Option<String>,
/// Target-to-source migration; `None` when `generate_migration` was false.
pub backward_migration: Option<String>,
/// Whether the target is considered compatible with the source.
pub is_compatible: bool,
}
/// One difference between two schemas.
#[derive(Debug, Serialize)]
pub struct SchemaDifference {
/// Kind of difference as free-form text.
pub diff_type: String,
/// Affected schema object.
pub object: String,
/// Object state in the source schema, if present.
pub source_state: Option<String>,
/// Object state in the target schema, if present.
pub target_state: Option<String>,
/// Whether the difference may break existing clients.
pub breaking: bool,
}
pub async fn infer_schema(
State(_state): State<AppState>,
Json(req): Json<InferSchemaRequest>,
) -> Result<Json<ApiResponse<InferredSchema>>, ApiError> {
if req.samples.is_empty() {
return Err(ApiError::bad_request("At least one sample is required"));
}
let options = req.options.unwrap_or(InferenceOptions {
detect_nullable: true,
detect_unique: true,
detect_primary_key: true,
detect_foreign_keys: false,
suggest_indexes: true,
prefer_narrow_types: true,
max_varchar_length: 255,
detect_vectors: true,
detect_json: true,
});
let table_name = req.table_name.unwrap_or_else(|| "inferred_table".to_string());
let mut columns = Vec::new();
let mut column_types: std::collections::HashMap<String, Vec<String>> = std::collections::HashMap::new();
let mut nullable_columns = std::collections::HashSet::new();
for sample in &req.samples {
if let serde_json::Value::Object(obj) = sample {
for (key, value) in obj {
let col_types = column_types.entry(key.clone()).or_insert_with(Vec::new);
let inferred_type = match value {
serde_json::Value::Null => {
nullable_columns.insert(key.clone());
"NULL".to_string()
}
serde_json::Value::Bool(_) => "BOOLEAN".to_string(),
serde_json::Value::Number(n) => {
if n.is_i64() {
if options.prefer_narrow_types { "INTEGER" } else { "BIGINT" }.to_string()
} else {
"NUMERIC".to_string()
}
}
serde_json::Value::String(s) => {
if s.len() > options.max_varchar_length {
"TEXT".to_string()
} else {
format!("VARCHAR({})", std::cmp::min(s.len() * 2, options.max_varchar_length))
}
}
serde_json::Value::Array(arr) => {
if options.detect_vectors && arr.iter().all(|v| matches!(v, serde_json::Value::Number(_))) {
format!("VECTOR({})", arr.len())
} else {
"JSON".to_string()
}
}
serde_json::Value::Object(_) => {
if options.detect_json {
"JSONB".to_string()
} else {
"JSON".to_string()
}
}
};
if inferred_type != "NULL" {
col_types.push(inferred_type);
}
}
}
}
for (name, types) in column_types {
let sql_type = if types.is_empty() {
"TEXT".to_string()
} else {
types.into_iter().next().unwrap_or_else(|| "TEXT".to_string())
};
let is_nullable = nullable_columns.contains(&name);
columns.push(InferredColumn {
name: name.clone(),
sql_type,
nullable: is_nullable,
unique: options.detect_unique && req.samples.iter()
.filter_map(|s| {
if let serde_json::Value::Object(obj) = s {
obj.get(&name)
} else {
None
}
})
.count() == req.samples.len(), default: None,
confidence: if is_nullable { 0.8 } else { 0.95 },
alternatives: vec![],
detected_pattern: None,
statistics: None,
});
}
let column_defs: Vec<String> = columns.iter()
.map(|c| format!("{} {} NOT NULL", c.name, c.sql_type))
.collect();
let ddl = format!(
"CREATE TABLE {} (\n {},\n PRIMARY KEY (id)\n);",
table_name,
column_defs.join(",\n ")
);
let schema = InferredSchema {
table_name,
columns,
primary_key: Some(vec!["id".to_string()]),
indexes: vec![],
constraints: vec![],
ddl,
confidence: 0.85,
warnings: vec!["Add 'id' column explicitly if needed for primary key".to_string()],
};
Ok(Json(ApiResponse::success(schema)))
}
pub async fn batch_infer_schema(
State(_state): State<AppState>,
Json(req): Json<BatchInferRequest>,
) -> Result<Json<ApiResponse<BatchInferResponse>>, ApiError> {
if req.tables.is_empty() {
return Err(ApiError::bad_request("At least one table is required"));
}
let options = req.options.clone().unwrap_or(InferenceOptions {
detect_nullable: true,
detect_unique: true,
detect_primary_key: true,
detect_foreign_keys: false,
suggest_indexes: true,
prefer_narrow_types: true,
max_varchar_length: 255,
detect_vectors: true,
detect_json: true,
});
let mut schemas = Vec::new();
let mut all_ddl = Vec::new();
for table in &req.tables {
let mut columns = Vec::new();
let mut column_types: std::collections::HashMap<String, Vec<String>> = std::collections::HashMap::new();
let mut nullable_columns = std::collections::HashSet::new();
for sample in &table.samples {
if let serde_json::Value::Object(obj) = sample {
for (key, value) in obj {
let col_types = column_types.entry(key.clone()).or_insert_with(Vec::new);
let inferred_type = match value {
serde_json::Value::Null => {
nullable_columns.insert(key.clone());
"NULL".to_string()
}
serde_json::Value::Bool(_) => "BOOLEAN".to_string(),
serde_json::Value::Number(n) => {
if n.is_i64() {
if options.prefer_narrow_types { "INTEGER" } else { "BIGINT" }.to_string()
} else {
"NUMERIC".to_string()
}
}
serde_json::Value::String(s) => {
if s.len() > options.max_varchar_length {
"TEXT".to_string()
} else {
format!("VARCHAR({})", std::cmp::min(s.len() * 2, options.max_varchar_length))
}
}
serde_json::Value::Array(arr) => {
if options.detect_vectors && arr.iter().all(|v| matches!(v, serde_json::Value::Number(_))) {
format!("VECTOR({})", arr.len())
} else {
"JSON".to_string()
}
}
serde_json::Value::Object(_) => {
if options.detect_json { "JSONB".to_string() } else { "JSON".to_string() }
}
};
if inferred_type != "NULL" {
col_types.push(inferred_type);
}
}
}
}
for (name, types) in column_types {
let sql_type = types.into_iter().next().unwrap_or_else(|| "TEXT".to_string());
let is_nullable = nullable_columns.contains(&name);
columns.push(InferredColumn {
name,
sql_type,
nullable: is_nullable,
unique: false,
default: None,
confidence: 0.85,
alternatives: vec![],
detected_pattern: None,
statistics: None,
});
}
let column_defs: Vec<String> = columns.iter()
.map(|c| format!("{} {}", c.name, c.sql_type))
.collect();
let ddl = format!(
"CREATE TABLE {} (\n {}\n);",
table.name,
column_defs.join(",\n ")
);
all_ddl.push(ddl.clone());
schemas.push(InferredSchema {
table_name: table.name.clone(),
columns,
primary_key: None,
indexes: vec![],
constraints: vec![],
ddl,
confidence: 0.85,
warnings: vec![],
});
}
let relationships = if req.detect_relationships {
vec![]
} else {
vec![]
};
let response = BatchInferResponse {
schemas,
relationships,
combined_ddl: format!("{}\n", all_ddl.join("\n\n")),
};
Ok(Json(ApiResponse::success(response)))
}
pub async fn infer_from_file(
State(_state): State<AppState>,
Json(req): Json<InferFromFileRequest>,
) -> Result<Json<ApiResponse<InferredSchema>>, ApiError> {
let table_name = req.table_name.unwrap_or_else(|| "imported_table".to_string());
let schema = InferredSchema {
table_name: table_name.clone(),
columns: vec![],
primary_key: None,
indexes: vec![],
constraints: vec![],
ddl: format!("CREATE TABLE {} (id BIGINT PRIMARY KEY);", table_name),
confidence: 0.5,
warnings: vec![format!("File inference from {} is not yet implemented", req.format)],
};
Ok(Json(ApiResponse::success(schema)))
}
/// Placeholder schema optimizer.
///
/// Returns `current_ddl` unchanged, with no recommended changes and a risk note saying
/// optimization is not implemented.
pub async fn optimize_schema(
    State(_state): State<AppState>,
    Json(req): Json<OptimizeSchemaRequest>,
) -> Result<Json<ApiResponse<OptimizationResponse>>, ApiError> {
    // No analysis is performed yet, so `goals`, `statistics` and `sample_queries` are
    // deliberately left unread rather than unpacked into unused locals.
    let response = OptimizationResponse {
        optimized_ddl: req.current_ddl,
        changes: vec![],
        migration_sql: "-- No optimizations recommended".to_string(),
        impact: OptimizationImpact {
            query_improvement: Some("0%".to_string()),
            storage_change: Some("0 bytes".to_string()),
            risks: vec!["Schema optimization not yet implemented".to_string()],
        },
    };
    Ok(Json(ApiResponse::success(response)))
}
pub async fn compare_schemas(
State(_state): State<AppState>,
Json(req): Json<CompareSchemaRequest>,
) -> Result<Json<ApiResponse<SchemaComparisonResponse>>, ApiError> {
let response = SchemaComparisonResponse {
differences: vec![],
forward_migration: if req.generate_migration {
Some("-- No changes detected".to_string())
} else {
None
},
backward_migration: if req.generate_migration {
Some("-- No changes detected".to_string())
} else {
None
},
is_compatible: true,
};
Ok(Json(ApiResponse::success(response)))
}
/// Request body for generating a schema from a prose description.
#[derive(Debug, Deserialize)]
pub struct NaturalLanguageSchemaRequest {
/// Free-text description of the desired schema.
pub description: String,
/// Requested output format. Defaults to `"sql"`.
#[serde(default = "default_sql")]
pub format: String,
/// Whether sample rows are requested. Defaults to `false`.
#[serde(default)]
pub include_samples: bool,
}
/// Serde default for `NaturalLanguageSchemaRequest::format`.
fn default_sql() -> String {
"sql".to_string()
}
pub async fn generate_from_description(
State(_state): State<AppState>,
Json(req): Json<NaturalLanguageSchemaRequest>,
) -> Result<Json<ApiResponse<NaturalLanguageSchemaResponse>>, ApiError> {
let response = NaturalLanguageSchemaResponse {
schema: "-- Natural language schema generation not yet implemented".to_string(),
explanation: format!("Received description: {}", req.description),
samples: None,
suggestions: vec!["Use the schema inference endpoints with sample data instead".to_string()],
};
Ok(Json(ApiResponse::success(response)))
}
/// Response of natural-language schema generation.
#[derive(Debug, Serialize)]
pub struct NaturalLanguageSchemaResponse {
/// Generated schema text (currently a placeholder SQL comment).
pub schema: String,
/// Explanation of the generated schema.
pub explanation: String,
/// Sample rows, when produced.
pub samples: Option<Vec<serde_json::Value>>,
/// Follow-up suggestions for the caller.
pub suggestions: Vec<String>,
}
/// Request body for DDL validation.
#[derive(Debug, Deserialize)]
pub struct ValidateSchemaRequest {
/// DDL text to validate.
pub ddl: String,
/// Names of rules to apply.
/// NOTE(review): not consulted by `validate_schema` in this file.
pub rules: Option<Vec<String>>,
}
/// Result of DDL validation.
#[derive(Debug, Serialize)]
pub struct SchemaValidationResponse {
/// Whether the DDL passed validation.
pub valid: bool,
/// Problems that make the DDL invalid.
pub errors: Vec<ValidationError>,
/// Non-fatal findings.
pub warnings: Vec<ValidationWarning>,
/// Free-form advice.
pub suggestions: Vec<String>,
}
/// A validation problem that makes the DDL invalid.
#[derive(Debug, Serialize)]
pub struct ValidationError {
/// Machine-readable error code.
pub code: String,
/// Human-readable description.
pub message: String,
/// Where in the DDL the problem was found, if known.
pub location: Option<String>,
}
/// A non-fatal validation finding.
#[derive(Debug, Serialize)]
pub struct ValidationWarning {
/// Machine-readable warning code.
pub code: String,
/// Human-readable description.
pub message: String,
/// Where in the DDL the finding applies, if known.
pub location: Option<String>,
}
pub async fn validate_schema(
State(_state): State<AppState>,
Json(_req): Json<ValidateSchemaRequest>,
) -> Result<Json<ApiResponse<SchemaValidationResponse>>, ApiError> {
let response = SchemaValidationResponse {
valid: true,
errors: vec![],
warnings: vec![],
suggestions: vec!["Schema validation not yet fully implemented".to_string()],
};
Ok(Json(ApiResponse::success(response)))
}
/// Lists available schema templates.
///
/// No templates are registered yet, so the result is always an empty list; query
/// parameters are accepted but ignored.
pub async fn list_templates(
    State(_state): State<AppState>,
    Query(_params): Query<HashMap<String, String>>,
) -> Result<Json<ApiResponse<Vec<SchemaTemplate>>>, ApiError> {
    Ok(Json(ApiResponse::success(Vec::<SchemaTemplate>::new())))
}
/// A reusable schema template as listed by `list_templates`.
#[derive(Debug, Serialize)]
pub struct SchemaTemplate {
/// Identifier used to instantiate the template.
pub id: String,
/// Display name.
pub name: String,
/// Human-readable description.
pub description: String,
/// Grouping category as free-form text.
pub category: String,
/// Template DDL.
pub ddl: String,
/// Parameters the template accepts.
pub parameters: Vec<TemplateParameter>,
}
/// One parameter accepted by a schema template.
#[derive(Debug, Serialize)]
pub struct TemplateParameter {
/// Parameter name.
pub name: String,
/// Human-readable description.
pub description: String,
/// Parameter type as free-form text.
pub param_type: String,
/// Default value, if any.
pub default: Option<String>,
/// Whether the caller must supply a value.
pub required: bool,
}
/// Request body for instantiating a template.
#[derive(Debug, Deserialize)]
pub struct InstantiateTemplateRequest {
/// Identifier of the template to instantiate.
pub template_id: String,
/// Parameter values, presumably keyed by `TemplateParameter::name` — confirm.
pub parameters: HashMap<String, serde_json::Value>,
}
/// Placeholder for template instantiation.
///
/// No templates exist yet, so this always returns a stub `template_instance` schema whose
/// DDL is a SQL comment, with a warning naming the requested template as not found.
/// `parameters` are ignored.
pub async fn instantiate_template(
    State(_state): State<AppState>,
    Json(req): Json<InstantiateTemplateRequest>,
) -> Result<Json<ApiResponse<InferredSchema>>, ApiError> {
    let not_found = format!("Template {} not found", req.template_id);
    Ok(Json(ApiResponse::success(InferredSchema {
        table_name: String::from("template_instance"),
        columns: Vec::new(),
        primary_key: None,
        indexes: Vec::new(),
        constraints: Vec::new(),
        ddl: String::from("-- Template instantiation not yet implemented"),
        confidence: 0.5,
        warnings: vec![not_found],
    })))
}