use super::raw_data::titanic_raw::load_titanic_raw_data;
use ndarray::{Array1, Array2};
use std::sync::OnceLock;
/// Shape of the cached dataset, in order: string-column headers,
/// numeric-column headers, string feature table, numeric feature table.
type TitanicTables = (
    Array1<&'static str>,
    Array1<&'static str>,
    Array2<String>,
    Array2<f64>,
);

/// Process-wide cache of the parsed Titanic dataset.
///
/// Starts empty and is filled at most once; readers obtain it through
/// `get_or_init`, so the raw data is parsed a single time per process.
static TITANIC_DATA: OnceLock<TitanicTables> = OnceLock::new();
/// Splits one CSV record into whitespace-trimmed fields.
///
/// Commas inside a double-quoted section do not split the record, and the
/// delimiting quotes themselves are not part of the field. A doubled quote
/// (`""`) inside a quoted section is an escaped literal quote and is kept as
/// a single `"` in the resulting field.
fn split_titanic_csv_record(line: &str) -> Vec<String> {
    let mut fields = Vec::new();
    let mut field = String::new();
    let mut in_quotes = false;
    let mut chars = line.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            // Escaped quote inside a quoted section: consume the second quote
            // and keep one literal quote instead of toggling the state twice.
            '"' if in_quotes && chars.peek() == Some(&'"') => {
                chars.next();
                field.push('"');
            }
            '"' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                fields.push(field.trim().to_string());
                field.clear();
            }
            _ => field.push(ch),
        }
    }
    // The last field has no trailing comma, so flush it explicitly.
    fields.push(field.trim().to_string());
    fields
}

/// Parses the embedded raw Titanic CSV into header and feature arrays.
///
/// The first line is the header; every following non-blank line is a record.
/// A column is classified as numeric when every non-empty cell in it parses
/// as `f64`. The `Sex` column is additionally treated as numeric when its
/// cells are `male` / `female`. All remaining columns are string columns.
///
/// Returns, in order:
/// 1. headers of the string columns,
/// 2. headers of the numeric columns,
/// 3. string features, shape `(rows, string columns)`,
/// 4. numeric features, shape `(rows, numeric columns)`.
///
/// Numeric encoding notes:
/// * empty, missing, or unparsable numeric cells are stored as `0.0`;
/// * in a numeric `Sex` column, `male` is `1.0` and every other value is `0.0`.
///
/// Cells beyond the header width are ignored; records shorter than the header
/// are padded with empty strings / `0.0`.
///
/// # Panics
///
/// Panics if the raw data is empty or contains no data records.
fn load_titanic_internal() -> (
    Array1<&'static str>,
    Array1<&'static str>,
    Array2<String>,
    Array2<f64>,
) {
    let raw_data = load_titanic_raw_data();
    let mut lines = raw_data.trim().lines();
    let Some(header_line) = lines.next() else {
        panic!("No data found");
    };
    // Trim each header once here so later comparisons and outputs agree.
    let all_headers: Vec<&str> = header_line.trim().split(',').map(str::trim).collect();

    let all_rows: Vec<Vec<String>> = lines
        .filter(|line| !line.trim().is_empty())
        .map(split_titanic_csv_record)
        .collect();
    if all_rows.is_empty() {
        panic!("No data rows found");
    }

    // A column stays numeric until a non-empty cell fails to parse as f64.
    let mut is_numeric = vec![true; all_headers.len()];
    for row in &all_rows {
        // `zip` stops at the shorter side, so surplus cells past the header
        // width are ignored and short records simply contribute fewer cells.
        for ((value, header), numeric) in row.iter().zip(&all_headers).zip(is_numeric.iter_mut()) {
            if value.is_empty() {
                continue;
            }
            // `male` / `female` are encoded numerically later on, so they
            // must not demote the `Sex` column to a string column.
            if *header == "Sex" && (value == "male" || value == "female") {
                continue;
            }
            if value.parse::<f64>().is_err() {
                *numeric = false;
            }
        }
    }

    let mut string_headers = Vec::new();
    let mut numeric_headers = Vec::new();
    let mut string_indices = Vec::new();
    let mut numeric_indices = Vec::new();
    for (idx, (&header, &numeric)) in all_headers.iter().zip(&is_numeric).enumerate() {
        if numeric {
            numeric_headers.push(header);
            numeric_indices.push(idx);
        } else {
            string_headers.push(header);
            string_indices.push(idx);
        }
    }

    let row_count = all_rows.len();
    let mut string_features = Vec::with_capacity(row_count * string_indices.len());
    let mut numeric_features = Vec::with_capacity(row_count * numeric_indices.len());
    for row in &all_rows {
        for &idx in &string_indices {
            // Short records are padded with an empty string.
            string_features.push(row.get(idx).cloned().unwrap_or_default());
        }
        for &idx in &numeric_indices {
            let value = match row.get(idx) {
                Some(cell) if all_headers[idx] == "Sex" => match cell.as_str() {
                    "male" => 1.0,
                    // "female" and anything else (including empty) map to 0.0.
                    _ => 0.0,
                },
                // Empty cells fail to parse and fall back to 0.0.
                Some(cell) => cell.parse::<f64>().unwrap_or(0.0),
                // Short records are padded with 0.0.
                None => 0.0,
            };
            numeric_features.push(value);
        }
    }

    let string_features_array =
        Array2::from_shape_vec((row_count, string_indices.len()), string_features)
            .expect("exactly one string cell is pushed per row and string column");
    let numeric_features_array =
        Array2::from_shape_vec((row_count, numeric_indices.len()), numeric_features)
            .expect("exactly one numeric cell is pushed per row and numeric column");

    (
        Array1::from_vec(string_headers),
        Array1::from_vec(numeric_headers),
        string_features_array,
        numeric_features_array,
    )
}
/// Returns shared references to the cached Titanic dataset.
///
/// The dataset is parsed on the first call and reused afterwards. The tuple
/// holds, in order: string-column headers, numeric-column headers, the string
/// feature table, and the numeric feature table.
pub fn load_titanic() -> (
    &'static Array1<&'static str>,
    &'static Array1<&'static str>,
    &'static Array2<String>,
    &'static Array2<f64>,
) {
    let cached = TITANIC_DATA.get_or_init(load_titanic_internal);
    (&cached.0, &cached.1, &cached.2, &cached.3)
}
/// Returns independent, owned copies of the Titanic dataset arrays.
///
/// Each array is cloned from the shared data returned by `load_titanic`, so
/// callers may mutate the result freely. The tuple order matches
/// `load_titanic`: string headers, numeric headers, string features,
/// numeric features.
pub fn load_titanic_owned() -> (
    Array1<&'static str>,
    Array1<&'static str>,
    Array2<String>,
    Array2<f64>,
) {
    let shared = load_titanic();
    (
        shared.0.clone(),
        shared.1.clone(),
        shared.2.clone(),
        shared.3.clone(),
    )
}