use ndarray::prelude::*;
use tangram_table::{
	NumberTableColumn, TableColumn, TableColumnView, TableValue, TextTableColumnView,
};
use tangram_text::{Tokenizer, WordEmbeddingModel};
#[derive(Clone, Debug)]
pub struct WordEmbeddingFeatureGroup {
		pub source_column_name: String,
		pub tokenizer: Tokenizer,
		pub model: WordEmbeddingModel,
}
impl WordEmbeddingFeatureGroup {
	pub fn compute_table(
		&self,
		column: tangram_table::TableColumnView,
		progress: &impl Fn(u64),
	) -> Vec<TableColumn> {
		match column {
			TableColumnView::Unknown(_) => unimplemented!(),
			TableColumnView::Number(_) => unimplemented!(),
			TableColumnView::Enum(_) => unimplemented!(),
			TableColumnView::Text(column) => {
				self.compute_table_for_text_column(column, &|| progress(1))
			}
		}
	}
	pub fn compute_array_f32(
		&self,
		features: ArrayViewMut2<f32>,
		column: tangram_table::TableColumnView,
		progress: &impl Fn(),
	) {
		match column {
			TableColumnView::Unknown(_) => unimplemented!(),
			TableColumnView::Number(_) => unimplemented!(),
			TableColumnView::Enum(_) => unimplemented!(),
			TableColumnView::Text(column) => {
				self.compute_array_f32_for_text_column(features, column, progress)
			}
		}
	}
	pub fn compute_array_value(
		&self,
		features: ArrayViewMut2<tangram_table::TableValue>,
		column: tangram_table::TableColumnView,
		progress: &impl Fn(),
	) {
		match column {
			TableColumnView::Unknown(_) => unimplemented!(),
			TableColumnView::Number(_) => unimplemented!(),
			TableColumnView::Enum(_) => unimplemented!(),
			TableColumnView::Text(column) => {
				self.compute_array_value_for_text_column(features, column, progress)
			}
		}
	}
	fn compute_array_f32_for_text_column(
		&self,
		mut features: ArrayViewMut2<f32>,
		column: TextTableColumnView,
		progress: &impl Fn(),
	) {
				features.fill(0.0);
				for (example_index, value) in column.iter().enumerate() {
						let mut count = 0;
			for token in self.tokenizer.tokenize(value) {
								if let Some(embedding) = self.model.get(token.as_ref()) {
					count += 1;
					for (index, value) in embedding.iter().enumerate() {
						*features.get_mut([example_index, index]).unwrap() += value;
					}
				}
			}
						if count > 0 {
				for feature_column_value in features.row_mut(example_index).iter_mut() {
					*feature_column_value /= count as f32;
				}
			}
			progress();
		}
	}
	fn compute_array_value_for_text_column(
		&self,
		mut features: ArrayViewMut2<TableValue>,
		column: TextTableColumnView,
		progress: &impl Fn(),
	) {
				for feature in features.iter_mut() {
			*feature = TableValue::Number(0.0);
		}
				for (example_index, value) in column.iter().enumerate() {
						let mut count = 0;
			for token in self.tokenizer.tokenize(value) {
								if let Some(embedding) = self.model.get(token.as_ref()) {
					count += 1;
					for (index, value) in embedding.iter().enumerate() {
						*features
							.get_mut([example_index, index])
							.unwrap()
							.as_number_mut()
							.unwrap() += value;
					}
				}
			}
			if count > 0 {
								for feature_column_value in features.row_mut(example_index).iter_mut() {
					*feature_column_value.as_number_mut().unwrap() /= count as f32;
				}
			}
			progress();
		}
	}
	fn compute_table_for_text_column(
		&self,
		column: tangram_table::TextTableColumnView,
		progress: &impl Fn(),
	) -> Vec<TableColumn> {
		let mut feature_columns = vec![vec![0.0; column.len()]; self.model.size];
		for (example_index, value) in column.iter().enumerate() {
						let tokenizer = self.tokenizer.tokenize(value);
			let mut count = 0;
			for token in tokenizer {
								if let Some(embedding) = self.model.get(token.as_ref()) {
					count += 1;
					for (index, value) in embedding.iter().enumerate() {
						feature_columns[index][example_index] += value;
					}
				}
			}
			if count > 0 {
								for feature_column in feature_columns.iter_mut() {
					feature_column[example_index] /= count as f32;
				}
			}
			progress();
		}
		feature_columns
			.into_iter()
			.map(|feature_column| TableColumn::Number(NumberTableColumn::new(None, feature_column)))
			.collect()
	}
}