use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use std::path::Path;
use crate::api::partial;
use crate::types::{CountMethod, DataType, RowCountEstimate, SchemaResult};
/// Python-facing wrapper around the crate's [`SchemaResult`].
///
/// The `#[pyclass(name = "SchemaResult")]` attribute exposes this type to Python
/// under the class name `SchemaResult` rather than `PySchemaResult`.
#[pyclass(name = "SchemaResult")]
pub struct PySchemaResult {
// The wrapped result; the getters and `__repr__` on this type read from it.
inner: SchemaResult,
}
impl PySchemaResult {
pub fn from_inner(inner: SchemaResult) -> Self {
Self { inner }
}
}
#[pymethods]
impl PySchemaResult {
    /// One dictionary per column, in the order the wrapped result stores them.
    ///
    /// Every dictionary holds exactly two string entries: `"name"` (the column
    /// name) and `"data_type"` (one of `"integer"`, `"float"`, `"string"`,
    /// `"date"` or `"boolean"`).
    #[getter]
    fn columns(&self) -> Vec<std::collections::HashMap<String, String>> {
        let mut described = Vec::with_capacity(self.inner.columns.len());
        for column in self.inner.columns.iter() {
            // No wildcard arm: adding a `DataType` variant forces this mapping
            // to be updated before the crate compiles again.
            let type_label = match column.data_type {
                DataType::Integer => "integer",
                DataType::Float => "float",
                DataType::String => "string",
                DataType::Date => "date",
                DataType::Boolean => "boolean",
            };
            described.push(std::collections::HashMap::from([
                ("name".to_string(), column.name.clone()),
                ("data_type".to_string(), type_label.to_string()),
            ]));
        }
        described
    }

    /// The `rows_sampled` value recorded in the wrapped result.
    #[getter]
    fn rows_sampled(&self) -> usize {
        self.inner.rows_sampled
    }

    /// The `inference_time_ms` value recorded in the wrapped result.
    #[getter]
    fn inference_time_ms(&self) -> u128 {
        self.inner.inference_time_ms
    }

    /// The `schema_stable` flag recorded in the wrapped result.
    #[getter]
    fn schema_stable(&self) -> bool {
        self.inner.schema_stable
    }

    /// Number of columns in the wrapped result.
    #[getter]
    fn num_columns(&self) -> usize {
        self.inner.columns.len()
    }

    /// Column names only, in the same order as `columns`.
    #[getter]
    fn column_names(&self) -> Vec<String> {
        let mut names = Vec::with_capacity(self.inner.columns.len());
        names.extend(self.inner.columns.iter().map(|column| column.name.clone()));
        names
    }

    /// One-line summary: column count, rows sampled, stability flag and timing.
    fn __repr__(&self) -> String {
        let result = &self.inner;
        let column_count = result.columns.len();
        format!(
            "SchemaResult(columns={}, rows_sampled={}, stable={}, time={}ms)",
            column_count, result.rows_sampled, result.schema_stable, result.inference_time_ms,
        )
    }
}
/// Python-facing wrapper around the crate's [`RowCountEstimate`].
///
/// The `#[pyclass(name = "RowCountEstimate")]` attribute exposes this type to
/// Python under the class name `RowCountEstimate` rather than `PyRowCountEstimate`.
#[pyclass(name = "RowCountEstimate")]
pub struct PyRowCountEstimate {
// The wrapped estimate; the getters and `__repr__` on this type read from it.
inner: RowCountEstimate,
}
impl PyRowCountEstimate {
pub fn from_inner(inner: RowCountEstimate) -> Self {
Self { inner }
}
}
#[pymethods]
impl PyRowCountEstimate {
    /// The `count` value recorded in the wrapped estimate.
    #[getter]
    fn count(&self) -> u64 {
        self.inner.count
    }

    /// The `exact` flag recorded in the wrapped estimate.
    #[getter]
    fn exact(&self) -> bool {
        self.inner.exact
    }

    /// String label for the wrapped estimate's [`CountMethod`].
    ///
    /// Returns one of `"parquet_metadata"`, `"full_scan"`, `"sampling"` or
    /// `"stream_full_scan"`.
    #[getter]
    fn method(&self) -> &str {
        // No wildcard arm: adding a `CountMethod` variant forces this mapping
        // to be updated before the crate compiles again.
        let label = match self.inner.method {
            CountMethod::FullScan => "full_scan",
            CountMethod::StreamFullScan => "stream_full_scan",
            CountMethod::Sampling => "sampling",
            CountMethod::ParquetMetadata => "parquet_metadata",
        };
        label
    }

    /// The `count_time_ms` value recorded in the wrapped estimate.
    #[getter]
    fn count_time_ms(&self) -> u128 {
        self.inner.count_time_ms
    }

    /// One-line summary: count, exactness flag, method label and timing.
    fn __repr__(&self) -> String {
        let estimate = &self.inner;
        // Reuse the getter so the label shown here matches the `method` property.
        let method_label = self.method();
        format!(
            "RowCountEstimate(count={}, exact={}, method='{}', time={}ms)",
            estimate.count, estimate.exact, method_label, estimate.count_time_ms,
        )
    }
}
#[pyfunction]
pub fn infer_schema(path: &str) -> PyResult<PySchemaResult> {
let result = partial::infer_schema(Path::new(path))
.map_err(|e| PyRuntimeError::new_err(format!("Schema inference failed: {}", e)))?;
Ok(PySchemaResult { inner: result })
}
#[pyfunction]
pub fn quick_row_count(path: &str) -> PyResult<PyRowCountEstimate> {
let result = partial::quick_row_count(Path::new(path))
.map_err(|e| PyRuntimeError::new_err(format!("Row count failed: {}", e)))?;
Ok(PyRowCountEstimate { inner: result })
}