use std::fs::File;
use std::path::PathBuf;
use anyhow::{anyhow, Context, Result};
use clap::{Parser, Subcommand};
use comfy_table::{Cell, ContentArrangement, Table};
use polars::prelude::*;
use rand::Rng;
// JSON-specific imports
use polars::prelude::JsonReader;
use polars::prelude::JsonWriter;
use polars::prelude::JsonFormat;
{{#if (eq visualization "yes")}}
// Visualization imports
use plotters::prelude::*;
use plotters::style::Color;
{{else}}
// Import only what's needed for non-visualization mode
use plotters::style::Color;
{{/if}}
// Top-level command-line interface definition; argument parsing is generated
// by clap's `Parser` derive and invoked from `main` via `Cli::parse()`.
// `author`, `version` and `about` are given without values, so clap is left to
// supply them; `long_about = None` sets no separate long description.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
// The subcommand selected by the user (see `Commands`).
#[command(subcommand)]
command: Commands,
}
// Subcommands understood by the CLI. The `///` lines below are doc comments on
// clap-derived items, which clap uses as user-facing help text, so they are
// part of the program's output rather than ordinary comments.
#[derive(Subcommand)]
enum Commands {
/// Analyze a data file
Analyze {
/// Path to the data file
#[arg(short, long)]
file: PathBuf,
/// Format of the data file (auto-detected if not specified)
// The empty-string default is the "not specified" marker: `main` falls back
// to the file extension when this is empty.
// NOTE(review): short flag is presumably `-m` because a derived `-f` would
// clash with `file` — confirm.
#[arg(short = 'm', long, default_value = "")]
format: String,
/// Show statistics
#[arg(short, long)]
statistics: bool,
/// Group by column
// Only acted on by `main` when `agg_column` is supplied as well.
#[arg(short, long)]
group_by: Option<String>,
/// Aggregate column
// Only acted on by `main` when `group_by` is supplied as well.
#[arg(short, long)]
agg_column: Option<String>,
/// Aggregation function (sum, mean, min, max, count)
// NOTE(review): short flag is presumably `-u` because a derived `-a` would
// clash with `agg_column` — confirm.
#[arg(short = 'u', long, default_value = "sum")]
agg_func: String,
/// Filter expression (e.g., "age > 30")
// NOTE(review): short flag is presumably `-e` because a derived `-f` would
// clash with `file` — confirm.
#[arg(short = 'e', long)]
filter: Option<String>,
/// Limit number of rows
// Applied by `main` after the filter, via `DataFrame::head`.
#[arg(short, long)]
limit: Option<usize>,
},
/// Generate sample data
Generate {
/// Number of rows to generate
#[arg(short, long, default_value_t = 100)]
rows: usize,
/// Output file path
#[arg(short, long)]
output: PathBuf,
},
}
/// Program entry point: parses the command line and dispatches to the selected
/// subcommand.
///
/// `Analyze` loads the file, then applies (in this order) the optional filter,
/// the optional row limit, prints the data, prints statistics if requested and
/// finally runs the group-by aggregation if both a group column and an
/// aggregate column were given. `Generate` writes a sample data file.
///
/// # Errors
///
/// Returns an error when the file format cannot be determined (no `--format`
/// and no file extension) or when any of the called steps fails.
fn main() -> Result<()> {
    let cli = Cli::parse();
    match &cli.command {
        Commands::Analyze {
            file,
            format,
            statistics,
            group_by,
            agg_column,
            agg_func,
            filter,
            limit,
        } => {
            // Use the explicit format when given, otherwise the file extension.
            // Both are lowercased so that e.g. `data.JSON` is detected the same
            // way as `--format JSON`.
            let format_str = if !format.is_empty() {
                format.to_lowercase()
            } else {
                match file.extension() {
                    Some(ext) => ext.to_string_lossy().to_lowercase(),
                    None => {
                        return Err(anyhow!(
                            "Could not determine file format. Please specify with --format"
                        ));
                    }
                }
            };

            // Load the data
            let df = load_data(file, &format_str)?;

            // Apply filter if specified
            let df = match filter {
                Some(filter_expr) => apply_filter(&df, filter_expr)?,
                None => df,
            };

            // Apply limit if specified
            let df = match limit {
                Some(limit_val) => df.head(Some(*limit_val)),
                None => df,
            };

            // Print the data
            print_dataframe(&df)?;

            // Show statistics if requested
            if *statistics {
                print_statistics(&df)?;
            }

            // Group by if requested. Aggregation needs both columns; tell the
            // user when only one was supplied instead of silently doing nothing.
            match (group_by, agg_column) {
                (Some(group_col), Some(agg_col)) => {
                    group_and_aggregate(&df, group_col, agg_col, agg_func)?;
                }
                (Some(_), None) | (None, Some(_)) => {
                    eprintln!(
                        "Warning: --group-by and --agg-column must be given together; skipping aggregation"
                    );
                }
                (None, None) => {}
            }
{{#if (eq visualization "yes")}}
            // Create a simple visualization
            if *statistics && df.height() > 0 {
                if let Some(numeric_col) = find_numeric_column(&df) {
                    println!("\nCreating visualization for column: {}", numeric_col);
                    create_histogram(&df, &numeric_col)?;
                }
            }
{{/if}}
        }
        Commands::Generate { rows, output } => {
            generate_sample_data(*rows, output)?;
        }
    }
    Ok(())
}
/// Loads the file at `path` into a `DataFrame`.
///
/// `format` selects the parser and is matched case-insensitively, so `"json"`,
/// `"JSON"` and `"Json"` are all accepted. Only JSON is supported.
///
/// # Errors
///
/// Returns an error if `format` is not JSON, if the file cannot be opened, or
/// if its contents cannot be parsed.
fn load_data(path: &PathBuf, format: &str) -> Result<DataFrame> {
    match format.to_ascii_lowercase().as_str() {
        "json" => {
            let file = File::open(path)
                .with_context(|| format!("Failed to open JSON file: {}", path.display()))?;
            JsonReader::new(file)
                .finish()
                .with_context(|| format!("Failed to parse JSON file: {}", path.display()))
        }
        // Report the format exactly as the caller supplied it.
        _ => Err(anyhow!(
            "Unsupported format: {}. This template only supports JSON files.",
            format
        )),
    }
}
/// Returns the rows of `df` that satisfy a simple `column operator value`
/// expression, e.g. `age > 30` or `department == Sales`.
///
/// The expression must consist of exactly three whitespace-separated tokens.
/// Supported operators are `>`, `>=`, `<`, `<=`, `==` (or `=`) and `!=`.
///
/// The value is interpreted as follows:
/// * wrapped in matching single or double quotes: a string, quotes removed
///   (use this to compare against text that looks like a number);
/// * otherwise, if it parses as a floating-point number: a number;
/// * otherwise: a string, used as written.
///
/// # Errors
///
/// Returns an error if the expression does not have three tokens, if the
/// operator is not supported, or if evaluating the filter fails.
fn apply_filter(df: &DataFrame, filter_expr: &str) -> Result<DataFrame> {
    let parts: Vec<&str> = filter_expr.split_whitespace().collect();
    if parts.len() != 3 {
        return Err(anyhow!(
            "Invalid filter expression. Expected format: 'column operator value'"
        ));
    }
    let (column, operator, value) = (parts[0], parts[1], parts[2]);

    // Build the right-hand side literal. A value that is not a number is
    // compared as text rather than being silently replaced by 0.0.
    let is_quoted = value.len() >= 2
        && ((value.starts_with('"') && value.ends_with('"'))
            || (value.starts_with('\'') && value.ends_with('\'')));
    let rhs = if is_quoted {
        // Both quote characters are one byte long, so this slice is on char boundaries.
        lit(value[1..value.len() - 1].to_string())
    } else {
        match value.parse::<f64>() {
            Ok(number) => lit(number),
            Err(_) => lit(value.to_string()),
        }
    };

    let lhs = col(column);
    let predicate = match operator {
        ">" => lhs.gt(rhs),
        ">=" => lhs.gt_eq(rhs),
        "<" => lhs.lt(rhs),
        "<=" => lhs.lt_eq(rhs),
        "==" | "=" => lhs.eq(rhs),
        "!=" => lhs.neq(rhs),
        _ => return Err(anyhow!("Unsupported operator: {}", operator)),
    };

    // `lazy()` takes the frame by value, hence the clone of the borrowed input.
    df.clone()
        .lazy()
        .filter(predicate)
        .collect()
        .with_context(|| format!("Failed to apply filter '{}'", filter_expr))
}
/// Prints the shape of `df` followed by a table of its contents.
///
/// At most the first 20 rows are rendered; when the frame is longer, a trailing
/// line reports how many rows were left out. Always returns `Ok(())`.
fn print_dataframe(df: &DataFrame) -> Result<()> {
    let total_rows = df.height();
    println!("\nDataFrame Shape: {} rows × {} columns", total_rows, df.width());

    let column_names = df.get_column_names();

    let mut table = Table::new();
    table.set_content_arrangement(ContentArrangement::Dynamic);

    // Header: one cell per column name.
    let mut header: Vec<Cell> = Vec::with_capacity(column_names.len());
    for name in &column_names {
        header.push(Cell::new(name.to_string()));
    }
    table.set_header(header);

    // Body: only the first 20 rows are displayed.
    let shown_rows = total_rows.min(20);
    for row_idx in 0..shown_rows {
        let mut cells: Vec<Cell> = Vec::with_capacity(column_names.len());
        for name in &column_names {
            // The name comes from this frame and `row_idx` is below its height.
            let column = df.column(name).unwrap();
            cells.push(Cell::new(column.get(row_idx).unwrap().to_string()));
        }
        table.add_row(cells);
    }

    println!("{}", table);
    if total_rows > shown_rows {
        println!("... and {} more rows", total_rows - shown_rows);
    }
    Ok(())
}
fn print_statistics(df: &DataFrame) -> Result<()> {
println!("\nStatistics:");
let mut table = Table::new();
table.set_content_arrangement(ContentArrangement::Dynamic);
// Add header row
let mut header_cells = vec![Cell::new("Statistic")];
// Find numeric columns
let numeric_cols: Vec<String> = df.get_column_names().iter()
.filter(|&name| {
let col = df.column(name).unwrap();
matches!(col.dtype(), DataType::Int32 | DataType::Int64 | DataType::Float32 | DataType::Float64)
})
.map(|s| s.to_string())
.collect();
for col in &numeric_cols {
header_cells.push(Cell::new(col.clone()));
}
table.set_header(header_cells);
// Add statistics rows
let stats = ["Count", "Mean", "Std", "Min", "25%", "50%", "75%", "Max"];
for stat in stats {
let mut row = vec![Cell::new(stat)];
for col_name in &numeric_cols {
let col = df.column(col_name).unwrap();
let value = match stat {
"Count" => format!("{}", col.len()),
"Mean" => {
// Use the appropriate method based on data type
let mean = match col.dtype() {
DataType::Int32 => col.i32().unwrap().mean(),
DataType::Int64 => col.i64().unwrap().mean(),
DataType::Float32 => col.f32().unwrap().mean(),
DataType::Float64 => col.f64().unwrap().mean(),
_ => None,
};
format!("{:.2}", mean.unwrap_or(f64::NAN))
},
"Std" => {
// Use the appropriate method based on data type
let std = match col.dtype() {
DataType::Int32 => col.i32().unwrap().std(0),
DataType::Int64 => col.i64().unwrap().std(0),
DataType::Float32 => col.f32().unwrap().std(0),
DataType::Float64 => col.f64().unwrap().std(0),
_ => None,
};
format!("{:.2}", std.unwrap_or(f64::NAN))
},
"Min" => {
let min = match col.dtype() {
DataType::Int32 => col.i32().unwrap().min().map(|v| v as f64),
DataType::Int64 => col.i64().unwrap().min().map(|v| v as f64),
DataType::Float32 => col.f32().unwrap().min().map(|v| v as f64),
DataType::Float64 => col.f64().unwrap().min(),
_ => None,
};
format!("{:.2}", min.unwrap_or(f64::NAN))
},
"25%" => {
// For percentiles, we'll use a simple approach instead of quantile
let mut values = Vec::new();
match col.dtype() {
DataType::Int32 => {
let arr = col.i32().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Int64 => {
let arr = col.i64().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Float32 => {
let arr = col.f32().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Float64 => {
let arr = col.f64().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v);
}
}
},
_ => {}
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let idx = (values.len() as f64 * 0.25) as usize;
let q25 = if values.is_empty() { 0.0 } else { values[idx.min(values.len() - 1)] };
format!("{:.2}", q25)
},
"50%" => {
// For percentiles, we'll use a simple approach instead of quantile
let mut values = Vec::new();
match col.dtype() {
DataType::Int32 => {
let arr = col.i32().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Int64 => {
let arr = col.i64().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Float32 => {
let arr = col.f32().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Float64 => {
let arr = col.f64().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v);
}
}
},
_ => {}
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let idx = (values.len() as f64 * 0.50) as usize;
let q50 = if values.is_empty() { 0.0 } else { values[idx.min(values.len() - 1)] };
format!("{:.2}", q50)
},
"75%" => {
// For percentiles, we'll use a simple approach instead of quantile
let mut values = Vec::new();
match col.dtype() {
DataType::Int32 => {
let arr = col.i32().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Int64 => {
let arr = col.i64().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Float32 => {
let arr = col.f32().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v as f64);
}
}
},
DataType::Float64 => {
let arr = col.f64().unwrap();
for i in 0..arr.len() {
if let Some(v) = arr.get(i) {
values.push(v);
}
}
},
_ => {}
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let idx = (values.len() as f64 * 0.75) as usize;
let q75 = if values.is_empty() { 0.0 } else { values[idx.min(values.len() - 1)] };
format!("{:.2}", q75)
},
"Max" => {
let max = match col.dtype() {
DataType::Int32 => col.i32().unwrap().max().map(|v| v as f64),
DataType::Int64 => col.i64().unwrap().max().map(|v| v as f64),
DataType::Float32 => col.f32().unwrap().max().map(|v| v as f64),
DataType::Float64 => col.f64().unwrap().max(),
_ => None,
};
format!("{:.2}", max.unwrap_or(f64::NAN))
},
_ => "".to_string(),
};
row.push(Cell::new(value));
}
table.add_row(row);
}
println!("{}", table);
Ok(())
}
/// Groups `df` by `group_col`, aggregates `agg_col` with `agg_func` and prints
/// the result.
///
/// `agg_func` is matched case-insensitively and must be one of `sum`, `mean`,
/// `min`, `max` or `count`. The result column is named `<agg_func>_<agg_col>`,
/// using `agg_func` as supplied. The stable group-by variant is used so the
/// printed order does not change between runs on the same input.
///
/// # Errors
///
/// Returns an error if either column is missing from `df`, if `agg_func` is
/// not supported, or if the aggregation or printing fails.
fn group_and_aggregate(df: &DataFrame, group_col: &str, agg_col: &str, agg_func: &str) -> Result<()> {
    println!(
        "\nGrouping by '{}' and aggregating '{}' with function '{}':",
        group_col, agg_col, agg_func
    );

    // Check that both columns exist (group column first).
    let col_names = df.get_column_names();
    for required in [group_col, agg_col] {
        if !col_names.iter().any(|name| name.as_str() == required) {
            return Err(anyhow!("Column '{}' not found in DataFrame", required));
        }
    }

    // Create the aggregation expression
    let target = col(agg_col);
    let agg_expr = match agg_func.to_lowercase().as_str() {
        "sum" => target.sum(),
        "mean" => target.mean(),
        "min" => target.min(),
        "max" => target.max(),
        "count" => target.count(),
        _ => return Err(anyhow!("Unsupported aggregation function: {}", agg_func)),
    };

    // `lazy()` takes the frame by value, hence the clone of the borrowed input.
    let result = df
        .clone()
        .lazy()
        .group_by_stable([col(group_col)])
        .agg([agg_expr.alias(&format!("{}_{}", agg_func, agg_col))])
        .collect()
        .with_context(|| {
            format!(
                "Failed to aggregate '{}' grouped by '{}'",
                agg_col, group_col
            )
        })?;

    print_dataframe(&result)
}
{{#if (eq visualization "yes")}}
/// Returns the name of the first column of `df` whose dtype is `Int32`,
/// `Int64`, `Float32` or `Float64`, or `None` when there is no such column.
fn find_numeric_column(df: &DataFrame) -> Option<String> {
    df.get_column_names()
        .into_iter()
        .find(|name| {
            // The name comes from this frame, so the lookup cannot fail.
            let dtype = df.column(name).unwrap().dtype();
            matches!(
                dtype,
                DataType::Int32 | DataType::Int64 | DataType::Float32 | DataType::Float64
            )
        })
        .map(|name| name.to_string())
}
fn create_histogram(df: &DataFrame, column: &str) -> Result<()> {
// Get the numeric column
let series = df.column(column)?;
// Create a temporary file for the plot
let filename = "histogram.png";
// Create the plot
let root = BitMapBackend::new(filename, (800, 600)).into_drawing_area();
root.fill(&WHITE)?;
// Get min and max values
let min_val = match series.dtype() {
DataType::Int32 => series.i32().unwrap().min().map(|v| v as f64).unwrap_or(0.0),
DataType::Int64 => series.i64().unwrap().min().map(|v| v as f64).unwrap_or(0.0),
DataType::Float32 => series.f32().unwrap().min().map(|v| v as f64).unwrap_or(0.0),
DataType::Float64 => series.f64().unwrap().min().unwrap_or(0.0),
_ => 0.0,
};
let max_val = match series.dtype() {
DataType::Int32 => series.i32().unwrap().max().map(|v| v as f64).unwrap_or(100.0),
DataType::Int64 => series.i64().unwrap().max().map(|v| v as f64).unwrap_or(100.0),
DataType::Float32 => series.f32().unwrap().max().map(|v| v as f64).unwrap_or(100.0),
DataType::Float64 => series.f64().unwrap().max().unwrap_or(100.0),
_ => 100.0,
};
// Create the chart
let mut chart = ChartBuilder::on(&root)
.caption(format!("Histogram of {}", column), ("sans-serif", 30))
.margin(10)
.x_label_area_size(30)
.y_label_area_size(30)
.build_cartesian_2d(min_val..max_val, 0f64..30f64)?;
chart.configure_mesh().draw()?;
// Calculate histogram bins
let bin_width = (max_val - min_val) / 20.0;
let mut bins = vec![0; 20];
for i in 0..series.len() {
let val = series.get(i).unwrap();
let f = match val {
AnyValue::Float64(f) => f,
AnyValue::Float32(f) => f as f64,
AnyValue::Int64(i) => i as f64,
AnyValue::Int32(i) => i as f64,
_ => continue,
};
let bin_idx = ((f - min_val) / bin_width).floor() as usize;
if bin_idx < bins.len() {
bins[bin_idx] += 1;
}
}
// Draw the histogram
chart.draw_series(
bins.iter().enumerate().map(|(i, &count)| {
let x0 = min_val + bin_width * i as f64;
let x1 = min_val + bin_width * (i + 1) as f64;
Rectangle::new(
[(x0, 0.0), (x1, count as f64)],
BLUE.filled(),
)
}),
)?;
// Save the plot
root.present()?;
println!("Histogram saved to {}", filename);
Ok(())
}
{{/if}}
fn generate_sample_data(rows: usize, output: &PathBuf) -> Result<()> {
println!("Generating {} rows of sample data to {}", rows, output.display());
// Create directory if it doesn't exist
if let Some(parent) = output.parent() {
std::fs::create_dir_all(parent)?;
}
// Generate random data
let mut rng = rand::thread_rng();
// Create columns
let mut id_vec = Vec::with_capacity(rows);
let mut name_vec = Vec::with_capacity(rows);
let mut age_vec = Vec::with_capacity(rows);
let mut salary_vec = Vec::with_capacity(rows);
let mut department_vec = Vec::with_capacity(rows);
let departments = ["Engineering", "Marketing", "Sales", "HR", "Finance"];
for i in 0..rows {
id_vec.push(i as i32);
// Generate a random name
let names = ["Alice", "Bob", "Charlie", "David", "Emma", "Frank", "Grace", "Hannah"];
let surnames = ["Smith", "Johnson", "Williams", "Jones", "Brown", "Davis", "Miller", "Wilson"];
let name = format!("{} {}",
names[rng.gen_range(0..names.len())],
surnames[rng.gen_range(0..surnames.len())]
);
name_vec.push(name);
// Generate a random age between 20 and 65
age_vec.push(rng.gen_range(20..65));
// Generate a random salary between 30000 and 150000
salary_vec.push(rng.gen_range(30000..150000));
// Assign a random department
department_vec.push(departments[rng.gen_range(0..departments.len())].to_string());
}
// Create a DataFrame
let df = DataFrame::new(vec![
Series::new("id".into(), id_vec).into(),
Series::new("name".into(), name_vec).into(),
Series::new("age".into(), age_vec).into(),
Series::new("salary".into(), salary_vec).into(),
Series::new("department".into(), department_vec).into(),
])?;
// Write to JSON
let mut file = File::create(output)?;
let mut df_mut = df.clone();
JsonWriter::new(&mut file)
.with_json_format(JsonFormat::Json)
.finish(&mut df_mut)?;
println!("Sample data generated successfully!");
println!("\nPreview of generated data:");
print_dataframe(&df)?;
Ok(())
}