use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;

use arrow::array::{
    Array, ArrayRef, Float64Builder, Int64Array, Int64Builder, StringArray, StringBuilder,
    UInt64Builder,
};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use llkv_column_map::store::ROW_ID_COLUMN_NAME;
use llkv_result::{Error, Result as LlkvResult};
use llkv_storage::pager::Pager;
use simd_r_drive_entry_handle::EntryHandle;

use llkv_table::{ColMeta, Table, types::FieldId};

use crate::inference::normalize_numeric_like;
use crate::{CsvReadOptions, CsvReader};

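/// Converts a `row_id` column to `UInt64`. `UInt64` input passes through
/// unchanged; `Int64` input is validated (no nulls, no negative values) and
/// re-encoded. Any other type is rejected.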
fn convert_row_id(array: &ArrayRef) -> LlkvResult<ArrayRef> {
    match array.data_type() {
        DataType::UInt64 => Ok(Arc::clone(array)),
        DataType::Int64 => {
            let int_array = array
                .as_any()
                .downcast_ref::<Int64Array>()
                .ok_or_else(|| Error::InvalidArgumentError("row_id column is not Int64".into()))?;

            if int_array.null_count() > 0 {
                return Err(Error::InvalidArgumentError(
                    "row_id column cannot contain nulls".into(),
                ));
            }

            let mut builder = UInt64Builder::with_capacity(int_array.len());
            for i in 0..int_array.len() {
                let value = int_array.value(i);
                if value < 0 {
                    return Err(Error::InvalidArgumentError(
                        "row_id column must contain non-negative values".into(),
                    ));
                }
                builder.append_value(value as u64);
            }

            Ok(Arc::new(builder.finish()) as ArrayRef)
        }
        other => Err(Error::InvalidArgumentError(format!(
            "row_id column must be Int64 or UInt64, got {other:?}"
        ))),
    }
}

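/// Validates that `data_type` can be stored, prefixing any
/// `InvalidArgumentError` with the offending column's name.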
fn ensure_supported_type(data_type: &DataType, column: &str) -> LlkvResult<()> {
    llkv_column_map::ensure_supported_arrow_type(data_type).map_err(|err| match err {
        Error::InvalidArgumentError(msg) => {
            Error::InvalidArgumentError(format!("column '{column}': {msg}"))
        }
        other => other,
    })
}

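/// Collects the table's known columns from the catalog as a name -> field-id
/// map, skipping the reserved field id 0 and columns without a recorded name.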
fn existing_column_mapping<P>(table: &Table<P>) -> HashMap<String, FieldId>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    let logical_fields = table.store().user_field_ids_for_table(table.table_id());
    if logical_fields.is_empty() {
        return HashMap::new();
    }

    let mut field_ids: Vec<FieldId> = Vec::new();
    for lfid in logical_fields {
        let fid = lfid.field_id();
        if fid != 0 {
            field_ids.push(fid);
        }
    }

    if field_ids.is_empty() {
        return HashMap::new();
    }

    let metas = table.catalog().get_cols_meta(table.table_id(), &field_ids);
    let mut mapping = HashMap::with_capacity(metas.len());
    for (fid, meta_opt) in field_ids.into_iter().zip(metas.into_iter()) {
        if let Some(meta) = meta_opt
            && let Some(name) = meta.name
        {
            mapping.insert(name, fid);
        }
    }
    mapping
}

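/// Resolves a field id for every non-row-id column in `schema`. An explicit
/// mapping in `provided` takes precedence but must agree with the catalog;
/// otherwise a known column reuses its existing id, and a new column is
/// assigned the next unused id and registered in the catalog.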
fn infer_field_mapping<'a, P>(
    table: &Table<P>,
    schema: &'a Schema,
    provided: Option<&'a HashMap<String, FieldId>>,
) -> LlkvResult<HashMap<String, FieldId>>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    let mut mapping = HashMap::new();
    let mut existing = existing_column_mapping(table);
    let mut used_ids: HashSet<FieldId> = HashSet::default();
    // Allocate new field ids above the highest id already in use.
    let mut next_field_id: FieldId = existing.values().copied().max().unwrap_or(0);

    for field in schema.fields() {
        if field.name() == ROW_ID_COLUMN_NAME {
            continue;
        }

        ensure_supported_type(field.data_type(), field.name())?;

        let mut chosen: Option<FieldId> = None;
        let mut should_register_meta = false;

        if let Some(manual) = provided
            && let Some(&fid) = manual.get(field.name())
        {
            if let Some(&existing_fid) = existing.get(field.name()) {
                if existing_fid != fid {
                    return Err(Error::InvalidArgumentError(format!(
                        "column '{}' mapped to field_id {} but existing schema expects {}",
                        field.name(),
                        fid,
                        existing_fid
                    )));
                }
            } else {
                should_register_meta = true;
            }
            chosen = Some(fid);
        }

        if chosen.is_none()
            && let Some(&fid) = existing.get(field.name())
        {
            chosen = Some(fid);
        }

        if chosen.is_none() {
            next_field_id = next_field_id
                .checked_add(1)
                .ok_or_else(|| Error::Internal("field_id overflow when inferring schema".into()))?;
            let fid = next_field_id;
            should_register_meta = true;
            chosen = Some(fid);
        }

        let fid = chosen.unwrap();
        // Validate the chosen id before touching the catalog so a rejected
        // mapping leaves no partial metadata behind.
        if fid == 0 {
            return Err(Error::InvalidArgumentError(format!(
                "column '{}' cannot map to reserved field_id 0",
                field.name()
            )));
        }
        if !used_ids.insert(fid) {
            return Err(Error::InvalidArgumentError(format!(
                "field_id {} assigned to multiple columns during schema inference",
                fid
            )));
        }
        if should_register_meta {
            let meta = ColMeta {
                col_id: fid,
                name: Some(field.name().to_string()),
                flags: 0,
                default: None,
            };
            table.catalog().put_col_meta(table.table_id(), &meta);
            existing.insert(field.name().to_string(), fid);
        }

        mapping.insert(field.name().to_string(), fid);
    }

    Ok(mapping)
}

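/// Rebuilds `schema` with the row-id field forced to `UInt64` and every other
/// field tagged with its field id in the Arrow field metadata. Returns the
/// annotated schema together with the index of the row-id column.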
fn build_schema_with_metadata(
    schema: &Schema,
    field_mapping: &HashMap<String, FieldId>,
) -> LlkvResult<(Arc<Schema>, usize)> {
    let row_id_index = schema
        .fields()
        .iter()
        .position(|f| f.name() == ROW_ID_COLUMN_NAME)
        .ok_or_else(|| {
            Error::InvalidArgumentError(format!(
                "CSV schema must include a '{ROW_ID_COLUMN_NAME}' column"
            ))
        })?;

    let mut fields_with_metadata = Vec::with_capacity(schema.fields().len());
    for (idx, field) in schema.fields().iter().enumerate() {
        if idx == row_id_index {
            fields_with_metadata.push(Field::new(
                ROW_ID_COLUMN_NAME,
                DataType::UInt64,
                field.is_nullable(),
            ));
            continue;
        }

        ensure_supported_type(field.data_type(), field.name())?;

        let field_id = field_mapping.get(field.name()).ok_or_else(|| {
            Error::InvalidArgumentError(format!(
                "no field_id mapping provided for column '{}'",
                field.name()
            ))
        })?;

        let mut metadata = std::collections::HashMap::new();
        metadata.insert(
            llkv_table::constants::FIELD_ID_META_KEY.to_string(),
            field_id.to_string(),
        );

        fields_with_metadata.push(
            Field::new(field.name(), field.data_type().clone(), field.is_nullable())
                .with_metadata(metadata),
        );
    }

    Ok((Arc::new(Schema::new(fields_with_metadata)), row_id_index))
}

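/// Shared implementation behind the public append entry points: opens the
/// CSV, resolves field ids, then normalizes each batch (string width, null
/// tokens, inferred numeric types, row-id conversion) before appending it.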
fn append_csv_into_table_internal<P, C>(
    table: &Table<P>,
    csv_path: C,
    csv_options: &CsvReadOptions,
    field_mapping_override: Option<&HashMap<String, FieldId>>,
) -> LlkvResult<()>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
    C: AsRef<Path>,
{
    let csv_path_ref = csv_path.as_ref();
    let reader_builder = CsvReader::with_options(csv_options.clone());
    let session = reader_builder
        .open(csv_path_ref)
        .map_err(|err| Error::Internal(format!("failed to open CSV: {err}")))?;
    let target_schema = session.schema();
    let type_overrides = session.type_overrides().to_vec();

    let inferred_mapping =
        infer_field_mapping(table, target_schema.as_ref(), field_mapping_override)?;
    let (schema_with_metadata, row_id_index) =
        build_schema_with_metadata(&target_schema, &inferred_mapping)?;

    for batch_result in session {
        let batch = batch_result
            .map_err(|err| Error::Internal(format!("failed to read CSV batch: {err}")))?;

        if batch.num_rows() == 0 {
            continue;
        }

        let mut columns: Vec<ArrayRef> = batch.columns().to_vec();
        // Normalize LargeUtf8 to Utf8 so downstream handling sees a single
        // string representation.
        for col in columns.iter_mut() {
            if matches!(col.data_type(), DataType::LargeUtf8) {
                let casted =
                    arrow::compute::cast(col.as_ref(), &DataType::Utf8).map_err(|err| {
                        Error::Internal(format!("failed to cast LargeUtf8 column to Utf8: {err}"))
                    })?;
                *col = casted;
            }
        }
        // Turn case-insensitive matches of the configured null token into
        // actual nulls in string columns.
        if let Some(token) = &csv_options.null_token {
            let token_lower = token.to_lowercase();
            for col in columns.iter_mut() {
                if col.data_type() == &DataType::Utf8 {
                    let sarr = col
                        .as_any()
                        .downcast_ref::<StringArray>()
                        .expect("expected StringArray");
                    let mut builder = StringBuilder::with_capacity(sarr.len(), sarr.len() * 8);
                    for idx in 0..sarr.len() {
                        if sarr.is_null(idx) {
                            builder.append_null();
                            continue;
                        }
                        let v = sarr.value(idx);
                        if v.trim().to_lowercase() == token_lower {
                            builder.append_null();
                        } else {
                            builder.append_value(v);
                        }
                    }
                    *col = Arc::new(builder.finish());
                }
            }
        }

        // Apply inferred type overrides: numeric-like strings are parsed
        // leniently via normalize_numeric_like; everything else falls back to
        // Arrow's cast kernel.
        for (idx, target_type_opt) in type_overrides.iter().enumerate() {
            if idx == row_id_index {
                continue;
            }
            let Some(target_type) = target_type_opt else {
                continue;
            };

            if columns[idx].data_type() == target_type {
                continue;
            }

            match (columns[idx].data_type(), target_type) {
                (DataType::Utf8, DataType::Float64) => {
                    let sarr = columns[idx]
                        .as_any()
                        .downcast_ref::<StringArray>()
                        .ok_or_else(|| {
                            Error::Internal(format!(
                                "expected StringArray for column '{}' during Float64 conversion",
                                target_schema.field(idx).name()
                            ))
                        })?;

                    let mut builder = Float64Builder::with_capacity(sarr.len());
                    for row_idx in 0..sarr.len() {
                        if sarr.is_null(row_idx) {
                            builder.append_null();
                            continue;
                        }
                        let v = sarr.value(row_idx);
                        if let Some((cleaned, _)) = normalize_numeric_like(v) {
                            match cleaned.parse::<f64>() {
                                Ok(parsed) => builder.append_value(parsed),
                                Err(_) => {
                                    return Err(Error::InvalidArgumentError(format!(
                                        "failed to parse '{}' as Float64 in column '{}'",
                                        v,
                                        target_schema.field(idx).name()
                                    )));
                                }
                            }
                        } else {
                            builder.append_null();
                        }
                    }
                    columns[idx] = Arc::new(builder.finish());
                }
                (DataType::Utf8, DataType::Int64) => {
                    let sarr = columns[idx]
                        .as_any()
                        .downcast_ref::<StringArray>()
                        .ok_or_else(|| {
                            Error::Internal(format!(
                                "expected StringArray for column '{}' during Int64 conversion",
                                target_schema.field(idx).name()
                            ))
                        })?;

                    let mut builder = Int64Builder::with_capacity(sarr.len());
                    for row_idx in 0..sarr.len() {
                        if sarr.is_null(row_idx) {
                            builder.append_null();
                            continue;
                        }
                        let v = sarr.value(row_idx);
                        if let Some((cleaned, has_decimal)) = normalize_numeric_like(v) {
                            if has_decimal {
                                return Err(Error::InvalidArgumentError(format!(
                                    "value '{}' in column '{}' contains decimals but column inferred as Int64",
                                    v,
                                    target_schema.field(idx).name()
                                )));
                            }
                            match cleaned.parse::<i64>() {
                                Ok(parsed) => builder.append_value(parsed),
                                Err(_) => {
                                    return Err(Error::InvalidArgumentError(format!(
                                        "failed to parse '{}' as Int64 in column '{}'",
                                        v,
                                        target_schema.field(idx).name()
                                    )));
                                }
                            }
                        } else {
                            builder.append_null();
                        }
                    }
                    columns[idx] = Arc::new(builder.finish());
                }
                _ => {
                    let casted = arrow::compute::cast(columns[idx].as_ref(), target_type)
                        .map_err(|err| {
                            Error::Internal(format!(
                                "failed to cast column '{}' to {:?}: {err}",
                                target_schema.field(idx).name(),
                                target_type
                            ))
                        })?;
                    columns[idx] = casted;
                }
            }
        }

        let row_id_array = convert_row_id(&columns[row_id_index])?;
        columns[row_id_index] = row_id_array;

        let new_batch = RecordBatch::try_new(Arc::clone(&schema_with_metadata), columns)?;
        table.append(&new_batch)?;
    }

    // Backfill column names for any field ids whose catalog metadata is
    // missing or unnamed.
    for (col_name, fid) in inferred_mapping.iter() {
        let metas = table.get_cols_meta(&[*fid]);
        let need_put = match metas.first() {
            Some(Some(meta)) => meta.name.is_none(),
            _ => true,
        };
        if need_put {
            let meta = ColMeta {
                col_id: *fid,
                name: Some(col_name.clone()),
                flags: 0,
                default: None,
            };
            table.catalog().put_col_meta(table.table_id(), &meta);
        }
    }

    Ok(())
}

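/// Appends the contents of a CSV file to `table`, inferring the column to
/// field-id mapping from the table's existing catalog metadata.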
pub fn append_csv_into_table<P, C>(
    table: &Table<P>,
    csv_path: C,
    csv_options: &CsvReadOptions,
) -> LlkvResult<()>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
    C: AsRef<Path>,
{
    append_csv_into_table_internal(table, csv_path, csv_options, None)
}

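/// Appends the contents of a CSV file to `table` using an explicit column
/// name to field-id mapping; the mapping must agree with any field ids the
/// catalog already records for those columns.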
pub fn append_csv_into_table_with_mapping<P, C>(
    table: &Table<P>,
    csv_path: C,
    field_mapping: &HashMap<String, FieldId>,
    csv_options: &CsvReadOptions,
) -> LlkvResult<()>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
    C: AsRef<Path>,
{
    append_csv_into_table_internal(table, csv_path, csv_options, Some(field_mapping))
}