use crate::api::{
DataStatistics, IndexRecommendation, QueryCost, QuerySimilarity, QueryStats, QueryStatsSummary,
};
use crate::AppState;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Response},
};
use serde::{Deserialize, Serialize};
/// Request body for the cost-prediction endpoint.
///
/// `sql` is parsed with `crate::api::select::parse_sql`; `data_statistics`
/// describes the data the query would run against.
#[derive(Serialize, Deserialize, Debug)]
pub struct PredictCostRequest {
    /// Raw SQL text to analyse.
    pub sql: String,
    /// Statistics about the target data, passed through to the cost model.
    pub data_statistics: DataStatistics,
}
/// Response body for the cost-prediction endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct PredictCostResponse {
    /// Cost estimate produced by `query_intelligence.predict_cost`.
    pub cost: QueryCost,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// Request body for the execution-strategy recommendation endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct RecommendStrategyRequest {
    /// Raw SQL text to analyse.
    pub sql: String,
    /// Statistics about the target data, used for both strategy and cost.
    pub data_statistics: DataStatistics,
}
/// Response body for the execution-strategy recommendation endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct RecommendStrategyResponse {
    /// Strategy returned by `query_intelligence.get_execution_strategy`.
    pub strategy: crate::api::ExecutionStrategy,
    /// Cost returned by `query_intelligence.predict_cost` for the same query.
    pub estimated_cost: QueryCost,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// Request body for the similar-query search endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct FindSimilarRequest {
    /// Raw SQL text whose look-alikes should be found.
    pub sql: String,
    /// Minimum similarity a stored query needs in order to be returned.
    /// Falls back to `default_similarity_threshold()` (0.7) when omitted.
    #[serde(default = "default_similarity_threshold")]
    pub threshold: f64,
}
/// Similarity threshold used when a `FindSimilarRequest` omits `threshold`.
///
/// Referenced by `#[serde(default = "...")]`, which requires a zero-argument
/// function returning the field type.
fn default_similarity_threshold() -> f64 {
    const DEFAULT_THRESHOLD: f64 = 0.7;
    DEFAULT_THRESHOLD
}
/// Response body for the similar-query search endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct FindSimilarResponse {
    /// Matches returned by `query_intelligence.find_similar_queries`.
    pub similar_queries: Vec<QuerySimilarity>,
    /// Number of entries in `similar_queries`.
    pub count: usize,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// Response body listing the recorded query statistics.
#[derive(Serialize, Deserialize, Debug)]
pub struct StatisticsResponse {
    /// Records returned by `query_intelligence.get_statistics`.
    pub statistics: Vec<QueryStats>,
    /// Number of entries in `statistics`.
    pub count: usize,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// Response body for the statistics summary endpoint.
///
/// The fields of `QueryStatsSummary` are flattened into the top-level JSON
/// object next to `timestamp` rather than nested under a `summary` key.
#[derive(Serialize, Deserialize, Debug)]
pub struct SummaryResponse {
    /// Summary returned by `query_intelligence.get_summary`.
    #[serde(flatten)]
    pub summary: QueryStatsSummary,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// JSON error payload returned by the handlers in this module.
///
/// `details` serializes as `null` when absent (no `skip_serializing_if`).
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
    /// Short, human-readable summary of what went wrong.
    pub error: String,
    /// Optional underlying cause, e.g. the parser's error message.
    pub details: Option<String>,
}
/// Renders every `ErrorResponse` as HTTP 400 Bad Request, with the struct
/// serialized as the JSON body.
impl IntoResponse for ErrorResponse {
    fn into_response(self) -> Response {
        let body = Json(self);
        (StatusCode::BAD_REQUEST, body).into_response()
    }
}
/// Handler: return every `QueryStats` record held by `query_intelligence`.
///
/// The response carries the records, their count, and an RFC 3339 UTC
/// timestamp. This handler has no error path; the `Result` only fits the
/// module's common handler signature.
pub async fn get_statistics(
    State(state): State<AppState>,
) -> Result<Json<StatisticsResponse>, ErrorResponse> {
    let stats = state.query_intelligence.get_statistics().await;
    let body = StatisticsResponse {
        // Count is taken before `stats` is moved into the next field.
        count: stats.len(),
        statistics: stats,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    Ok(Json(body))
}
/// Handler: return the aggregated statistics summary from `query_intelligence`.
///
/// The summary is flattened into the JSON body alongside an RFC 3339 UTC
/// timestamp. This handler has no error path.
pub async fn get_summary(
    State(state): State<AppState>,
) -> Result<Json<SummaryResponse>, ErrorResponse> {
    let response = SummaryResponse {
        summary: state.query_intelligence.get_summary().await,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    Ok(Json(response))
}
/// Handler: parse the submitted SQL and return its predicted cost.
///
/// # Errors
///
/// Returns an [`ErrorResponse`] (HTTP 400) when `request.sql` cannot be parsed;
/// the parser's message is placed in `details`.
pub async fn predict_cost(
    State(state): State<AppState>,
    Json(request): Json<PredictCostRequest>,
) -> Result<Json<PredictCostResponse>, ErrorResponse> {
    let query = match crate::api::select::parse_sql(&request.sql) {
        Ok(query) => query,
        Err(err) => {
            return Err(ErrorResponse {
                error: "Failed to parse SQL query".to_string(),
                details: Some(err.to_string()),
            });
        }
    };

    let predicted = state
        .query_intelligence
        .predict_cost(&query, &request.data_statistics)
        .await;

    Ok(Json(PredictCostResponse {
        cost: predicted,
        timestamp: chrono::Utc::now().to_rfc3339(),
    }))
}
/// Handler: parse the submitted SQL and recommend an execution strategy for it.
///
/// The chosen strategy is returned together with the predicted cost for the
/// same query and data statistics. The two lookups run one after the other:
/// strategy first, then cost.
///
/// # Errors
///
/// Returns an [`ErrorResponse`] (HTTP 400) when `request.sql` cannot be parsed.
pub async fn recommend_strategy(
    State(state): State<AppState>,
    Json(request): Json<RecommendStrategyRequest>,
) -> Result<Json<RecommendStrategyResponse>, ErrorResponse> {
    let query = match crate::api::select::parse_sql(&request.sql) {
        Ok(query) => query,
        Err(err) => {
            return Err(ErrorResponse {
                error: "Failed to parse SQL query".to_string(),
                details: Some(err.to_string()),
            });
        }
    };

    let intelligence = &state.query_intelligence;
    let stats = &request.data_statistics;

    let strategy = intelligence.get_execution_strategy(&query, stats).await;
    let estimated_cost = intelligence.predict_cost(&query, stats).await;

    Ok(Json(RecommendStrategyResponse {
        strategy,
        estimated_cost,
        timestamp: chrono::Utc::now().to_rfc3339(),
    }))
}
/// Handler: find previously recorded queries similar to the submitted SQL.
///
/// Validates `threshold` and parses `request.sql`. It then returns the matches
/// from `find_similar_queries`, together with their count and an RFC 3339 UTC
/// timestamp.
///
/// # Errors
///
/// Returns an [`ErrorResponse`] (HTTP 400) when:
/// - `threshold` lies outside `0.0..=1.0`;
/// - `request.sql` cannot be parsed.
pub async fn find_similar(
    State(state): State<AppState>,
    Json(request): Json<FindSimilarRequest>,
) -> Result<Json<FindSimilarResponse>, ErrorResponse> {
    // `threshold` is client-controlled. A negative value would let every
    // recorded query through, which can make the response arbitrarily large.
    // NOTE(review): the upper bound assumes similarity scores are normalized to
    // 0..=1 (the default is 0.7); confirm against `find_similar_queries`.
    if !(0.0..=1.0).contains(&request.threshold) {
        return Err(ErrorResponse {
            error: "Invalid similarity threshold".to_string(),
            details: Some(format!(
                "threshold must be between 0.0 and 1.0, got {}",
                request.threshold
            )),
        });
    }

    let parsed_query = crate::api::select::parse_sql(&request.sql).map_err(
        |e: crate::api::select::SelectError| ErrorResponse {
            error: "Failed to parse SQL query".to_string(),
            details: Some(e.to_string()),
        },
    )?;

    let similar_queries = state
        .query_intelligence
        .find_similar_queries(&parsed_query, request.threshold)
        .await;
    let count = similar_queries.len();

    Ok(Json(FindSimilarResponse {
        similar_queries,
        count,
        timestamp: chrono::Utc::now().to_rfc3339(),
    }))
}
/// Response body listing suggested indexes.
#[derive(Serialize, Deserialize, Debug)]
pub struct IndexRecommendationsResponse {
    /// Suggestions returned by `query_intelligence.get_index_recommendations`.
    pub recommendations: Vec<IndexRecommendation>,
    /// Number of entries in `recommendations`.
    pub count: usize,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// Handler: return the index suggestions computed by `query_intelligence`.
///
/// The response carries the suggestions, their count, and an RFC 3339 UTC
/// timestamp. This handler has no error path.
pub async fn get_index_recommendations(
    State(state): State<AppState>,
) -> Result<Json<IndexRecommendationsResponse>, ErrorResponse> {
    let suggestions = state.query_intelligence.get_index_recommendations().await;
    let body = IndexRecommendationsResponse {
        // Count is taken before `suggestions` is moved into the next field.
        count: suggestions.len(),
        recommendations: suggestions,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    Ok(Json(body))
}
/// Response body mapping complexity labels to query counts.
///
/// Backed by a `HashMap`, so the key order in the serialized JSON object is
/// not deterministic.
#[derive(Serialize, Deserialize, Debug)]
pub struct ComplexityDistributionResponse {
    /// Map returned by `query_intelligence.get_complexity_distribution`.
    pub distribution: std::collections::HashMap<String, usize>,
    /// RFC 3339 UTC timestamp of when the response was built.
    pub timestamp: String,
}
/// Handler: return the query-complexity distribution from `query_intelligence`.
///
/// This handler has no error path; the `Result` only fits the module's
/// common handler signature.
pub async fn get_complexity_distribution(
    State(state): State<AppState>,
) -> Result<Json<ComplexityDistributionResponse>, ErrorResponse> {
    let response = ComplexityDistributionResponse {
        distribution: state.query_intelligence.get_complexity_distribution().await,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    Ok(Json(response))
}
#[cfg(test)]
mod tests {
    //! Serialization round-trip checks for the request/response DTOs.
    //!
    //! Each test unwraps with `expect` rather than `if let Ok(..)`, so a
    //! serde failure fails the test instead of silently skipping the asserts.
    use super::*;

    #[test]
    fn test_default_similarity_threshold() {
        assert_eq!(default_similarity_threshold(), 0.7);
    }

    #[test]
    fn test_predict_cost_request_serialization() {
        let req = PredictCostRequest {
            sql: "SELECT * FROM s3object".to_string(),
            data_statistics: DataStatistics::default(),
        };
        let json = serde_json::to_string(&req).expect("PredictCostRequest should serialize");
        assert!(json.contains("SELECT"));
    }

    #[test]
    fn test_find_similar_request_default_threshold() {
        let json = r#"{"sql": "SELECT * FROM s3object"}"#;
        let req: FindSimilarRequest =
            serde_json::from_str(json).expect("FindSimilarRequest should deserialize");
        assert_eq!(req.sql, "SELECT * FROM s3object");
        assert_eq!(req.threshold, 0.7);
    }

    #[test]
    fn test_index_recommendations_response_serialization() {
        let response = IndexRecommendationsResponse {
            recommendations: vec![IndexRecommendation {
                column_name: "user_id".to_string(),
                index_type: crate::api::IndexType::BTree,
                reason: crate::api::IndexReason::FilterColumn,
                impact_score: 0.85,
                estimated_speedup: 5.0,
                query_count: 100,
                avg_selectivity: 0.02,
            }],
            count: 1,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
        };
        let json = serde_json::to_string(&response)
            .expect("IndexRecommendationsResponse should serialize");
        assert!(json.contains("user_id"));
        assert!(json.contains("FilterColumn"));
        assert!(json.contains("0.85"));
    }
}