Skip to main content

kermit_ds/
relation.rs

1//! This module defines the `Relation` trait and file reading extensions.
2use {
3    arrow::array::AsArray,
4    kermit_iters::JoinIterable,
5    parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
6    std::{fmt, fs::File, path::Path},
7};
8
/// Error type for relation file operations (CSV and Parquet).
///
/// The library-backed variants wrap the underlying error unchanged. `From`
/// conversions exist for each wrapped type, so `?` can be used directly on
/// `csv`, `std::io`, `parquet` and `arrow` results.
#[derive(Debug)]
pub enum RelationError {
    /// A CSV library error.
    Csv(csv::Error),
    /// A filesystem I/O error.
    Io(std::io::Error),
    /// A Parquet library error.
    Parquet(parquet::errors::ParquetError),
    /// An Arrow conversion error.
    Arrow(arrow::error::ArrowError),
    /// A data value that could not be converted (e.g. non-integer in a CSV).
    /// The payload is a human-readable description of the offending value.
    InvalidData(String),
}
23
24impl fmt::Display for RelationError {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        match self {
27            | RelationError::Csv(e) => write!(f, "CSV error: {e}"),
28            | RelationError::Io(e) => write!(f, "I/O error: {e}"),
29            | RelationError::Parquet(e) => write!(f, "Parquet error: {e}"),
30            | RelationError::Arrow(e) => write!(f, "Arrow error: {e}"),
31            | RelationError::InvalidData(msg) => write!(f, "Invalid data: {msg}"),
32        }
33    }
34}
35
36impl std::error::Error for RelationError {
37    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
38        match self {
39            | RelationError::Csv(e) => Some(e),
40            | RelationError::Io(e) => Some(e),
41            | RelationError::Parquet(e) => Some(e),
42            | RelationError::Arrow(e) => Some(e),
43            | RelationError::InvalidData(_) => None,
44        }
45    }
46}
47
48impl From<csv::Error> for RelationError {
49    fn from(e: csv::Error) -> Self { RelationError::Csv(e) }
50}
51
52impl From<std::io::Error> for RelationError {
53    fn from(e: std::io::Error) -> Self { RelationError::Io(e) }
54}
55
56impl From<parquet::errors::ParquetError> for RelationError {
57    fn from(e: parquet::errors::ParquetError) -> Self { RelationError::Parquet(e) }
58}
59
60impl From<arrow::error::ArrowError> for RelationError {
61    fn from(e: arrow::error::ArrowError) -> Self { RelationError::Arrow(e) }
62}
63
/// Whether a relation's attributes are identified by name or by position.
///
/// The enum is a plain two-state tag, so it derives the usual value-type
/// traits: it can be printed with `{:?}`, copied freely, and compared with
/// `==` (e.g. in assertions).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// Attributes are accessed by column index only.
    Positional,
    /// Attributes have explicit string names.
    Named,
}
71
/// Metadata for a relation: its name, attribute names, and arity.
///
/// A header can be "positional" (no attribute names, only arity) or "named"
/// (with explicit column names). A "nameless" header has an empty `name` field,
/// used for intermediate/projected relations.
#[derive(Clone, Debug)]
pub struct RelationHeader {
    // Relation name; the empty string marks a nameless header.
    name: String,
    // Attribute (column) names; empty for positional headers.
    attrs: Vec<String>,
    // Number of attributes. The named constructors set this to `attrs.len()`;
    // the positional constructors take it from the caller.
    arity: usize,
}
83
84impl RelationHeader {
85    /// Creates a new `RelationHeader` with the specified name, attributes, and
86    /// arity.
87    pub fn new(name: impl Into<String>, attrs: Vec<String>) -> Self {
88        let arity = attrs.len();
89        RelationHeader {
90            name: name.into(),
91            attrs,
92            arity,
93        }
94    }
95
96    /// Creates a nameless header with the given attribute names. Arity is
97    /// inferred from the length of `attrs`.
98    pub fn new_nameless(attrs: Vec<String>) -> Self {
99        let arity = attrs.len();
100        RelationHeader {
101            name: String::new(),
102            attrs,
103            arity,
104        }
105    }
106
107    /// Creates a named header with positional (unnamed) attributes.
108    pub fn new_positional(name: impl Into<String>, arity: usize) -> Self {
109        RelationHeader {
110            name: name.into(),
111            attrs: vec![],
112            arity,
113        }
114    }
115
116    /// Creates a nameless header with positional attributes of the given arity.
117    pub fn new_nameless_positional(arity: usize) -> Self {
118        RelationHeader {
119            name: String::new(),
120            attrs: vec![],
121            arity,
122        }
123    }
124
125    pub fn is_nameless(&self) -> bool { self.name.is_empty() }
126
127    pub fn name(&self) -> &str { &self.name }
128
129    pub fn attrs(&self) -> &[String] { &self.attrs }
130
131    pub fn arity(&self) -> usize { self.arity }
132
133    pub fn model_type(&self) -> ModelType {
134        if self.attrs.is_empty() {
135            ModelType::Positional
136        } else {
137            ModelType::Named
138        }
139    }
140}
141
142impl From<usize> for RelationHeader {
143    fn from(value: usize) -> RelationHeader { RelationHeader::new_nameless_positional(value) }
144}
145
/// A relation that can produce a new relation containing only the specified
/// columns.
pub trait Projectable {
    /// Returns a new relation containing only the columns at the given indices.
    ///
    /// `self` is borrowed, so the receiver is left untouched.
    // NOTE(review): handling of out-of-range or repeated indices is not
    // specified here — confirm against the implementors.
    fn project(&self, columns: Vec<usize>) -> Self;
}
152
/// The `Relation` trait defines a relational data structure that can store and
/// retrieve tuples of `usize` keys, and participate in join operations.
///
/// Every implementor must also be `JoinIterable` and `Projectable`.
pub trait Relation: JoinIterable + Projectable {
    /// Returns the header (name, attributes, arity) of this relation.
    fn header(&self) -> &RelationHeader;

    /// Creates a new relation described by `header`.
    // NOTE(review): presumably the new relation starts without tuples —
    // confirm against the implementors.
    fn new(header: RelationHeader) -> Self;

    /// Creates a new relation described by `header` from the given tuples.
    // NOTE(review): presumably each tuple should have `header.arity()`
    // elements — confirm against the implementors.
    fn from_tuples(header: RelationHeader, tuples: Vec<Vec<usize>>) -> Self;

    /// Inserts a tuple into the relation, returning `true` if successful and
    /// `false` if otherwise.
    fn insert(&mut self, tuple: Vec<usize>) -> bool;

    /// Inserts multiple tuples into the relation, returning `true` if
    /// successful and `false` if otherwise.
    fn insert_all(&mut self, tuples: Vec<Vec<usize>>) -> bool;
}
173
/// Extension trait for `Relation` to add file reading capabilities.
pub trait RelationFileExt: Relation {
    /// Creates a new relation from a Parquet file with header.
    ///
    /// This method extracts column names from the Parquet schema and the
    /// relation name from the filename.
    ///
    /// # Errors
    /// Returns a [`RelationError`] when the file cannot be opened, when the
    /// Parquet/Arrow readers report a failure, or when a stored value cannot
    /// be converted to `usize`.
    fn from_parquet<P: AsRef<Path>>(filepath: P) -> Result<Self, RelationError>
    where
        Self: Sized;

    /// Creates a new relation from a CSV file.
    ///
    /// The first record is treated as the header row naming the attributes,
    /// and the relation name is taken from the filename.
    ///
    /// # Note
    /// * Each line represents a tuple, and each value in the line should be
    ///   parsable as `usize`.
    ///
    /// # Errors
    /// Returns a [`RelationError`] when the file cannot be opened, when the
    /// CSV reader reports a failure, or when a field does not parse as
    /// `usize`.
    fn from_csv<P: AsRef<Path>>(filepath: P) -> Result<Self, RelationError>
    where
        Self: Sized;
}
193
194/// Blanket implementation of `RelationFileExt` for any type that
195/// implements `Relation`.
196impl<R> RelationFileExt for R
197where
198    R: Relation,
199{
200    fn from_csv<P: AsRef<Path>>(filepath: P) -> Result<Self, RelationError> {
201        let path = filepath.as_ref();
202        let file = File::open(path)?;
203
204        let mut rdr = csv::ReaderBuilder::new()
205            .has_headers(true)
206            .delimiter(b',')
207            .double_quote(false)
208            .escape(Some(b'\\'))
209            .flexible(false)
210            .comment(Some(b'#'))
211            .from_reader(file);
212
213        // Extract column names from CSV header
214        let attrs: Vec<String> = rdr.headers()?.iter().map(|s| s.to_string()).collect();
215
216        // Extract relation name from filename (without extension)
217        let relation_name = path
218            .file_stem()
219            .and_then(|s| s.to_str())
220            .unwrap_or("")
221            .to_string();
222
223        // Create header from the CSV header with the extracted name
224        let header = RelationHeader::new(relation_name, attrs);
225
226        let mut tuples = Vec::new();
227        for (row_idx, result) in rdr.records().enumerate() {
228            let record = result?;
229            let mut tuple: Vec<usize> = Vec::with_capacity(record.len());
230            for (col_idx, field) in record.iter().enumerate() {
231                let value = field.parse::<usize>().map_err(|_| {
232                    RelationError::InvalidData(format!(
233                        "row {row_idx}, column {col_idx}: cannot parse {:?} as usize",
234                        field,
235                    ))
236                })?;
237                tuple.push(value);
238            }
239            tuples.push(tuple);
240        }
241        Ok(R::from_tuples(header, tuples))
242    }
243
244    fn from_parquet<P: AsRef<Path>>(filepath: P) -> Result<Self, RelationError> {
245        let path = filepath.as_ref();
246        let file = File::open(path)?;
247
248        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
249
250        // Extract schema to get column names
251        let schema = builder.schema();
252        let attrs: Vec<String> = schema
253            .fields()
254            .iter()
255            .map(|field| field.name().clone())
256            .collect();
257
258        // Extract relation name from filename (without extension)
259        let relation_name = path
260            .file_stem()
261            .and_then(|s| s.to_str())
262            .unwrap_or("")
263            .to_string();
264
265        // Create header from the parquet schema with the extracted name
266        let header = RelationHeader::new(relation_name, attrs);
267
268        // Build the reader
269        let reader = builder.build()?;
270
271        // Collect all tuples first for efficient construction
272        let mut tuples = Vec::new();
273
274        // Read all record batches and collect tuples
275        for batch_result in reader {
276            let batch = batch_result?;
277
278            let num_rows = batch.num_rows();
279            let num_cols = batch.num_columns();
280
281            // Convert columnar data to row format (tuples)
282            for row_idx in 0..num_rows {
283                let mut tuple: Vec<usize> = Vec::with_capacity(num_cols);
284
285                for col_idx in 0..num_cols {
286                    let column = batch.column(col_idx);
287                    let int_array = column.as_primitive::<arrow::datatypes::Int64Type>();
288
289                    if let Ok(value) = usize::try_from(int_array.value(row_idx)) {
290                        tuple.push(value);
291                    } else {
292                        return Err(RelationError::InvalidData(
293                            "failed to convert Parquet value to usize".into(),
294                        ));
295                    }
296                }
297
298                tuples.push(tuple);
299            }
300        }
301
302        // Use from_tuples for efficient construction (sorts before insertion)
303        Ok(R::from_tuples(header, tuples))
304    }
305}
306
#[cfg(test)]
mod tests {
    use {super::*, crate::ds::TreeTrie};

    // ── RelationError Display ──────────────────────────────────────────

    /// A wrapped CSV error is rendered with the "CSV error:" prefix.
    #[test]
    fn relation_error_display_csv() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let rendered = RelationError::from(csv::Error::from(io_err)).to_string();
        assert!(rendered.starts_with("CSV error:"), "got: {rendered}");
    }

    /// A wrapped I/O error is rendered with the "I/O error:" prefix.
    #[test]
    fn relation_error_display_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "gone");
        let rendered = RelationError::from(io_err).to_string();
        assert!(rendered.starts_with("I/O error:"));
    }

    /// `InvalidData` renders its message verbatim after the prefix.
    #[test]
    fn relation_error_display_invalid_data() {
        let rendered = RelationError::InvalidData("bad value".into()).to_string();
        assert_eq!(rendered, "Invalid data: bad value");
    }

    /// Wrapping variants expose a source; `InvalidData` does not.
    #[test]
    fn relation_error_source_delegates() {
        use std::error::Error;

        let wrapped = RelationError::Io(std::io::Error::other("inner"));
        assert!(wrapped.source().is_some());

        let plain = RelationError::InvalidData("no source".into());
        assert!(plain.source().is_none());
    }

    // ── from_csv error on invalid data ─────────────────────────────────

    /// A non-integer field is reported with its text, row and column.
    #[test]
    fn from_csv_rejects_non_integer_values() {
        let path = std::env::temp_dir().join("test_csv_bad_value.csv");
        std::fs::write(&path, "a,b\n1,2\n3,hello\n").unwrap();

        let result = TreeTrie::from_csv(&path);
        assert!(result.is_err(), "expected error for non-integer CSV value");

        let msg = result.unwrap_err().to_string();
        let expectations = [
            ("hello", "the bad value"),
            ("row 1", "the row"),
            ("column 1", "the column"),
        ];
        for (needle, what) in expectations {
            assert!(
                msg.contains(needle),
                "error should mention {what}, got: {msg}"
            );
        }

        // Best-effort cleanup; a leftover file does not fail the test.
        std::fs::remove_file(path).ok();
    }

    /// Opening a nonexistent CSV file surfaces as the `Io` variant.
    #[test]
    fn from_csv_missing_file_returns_error() {
        let result = TreeTrie::from_csv("/tmp/nonexistent_kermit_test_file.csv");
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(
            matches!(err, RelationError::Io(_)),
            "expected Io variant for missing file"
        );
    }

    // ── from_parquet error paths ───────────────────────────────────────

    /// Opening a nonexistent Parquet file surfaces as the `Io` variant.
    #[test]
    fn from_parquet_missing_file_returns_error() {
        let result = TreeTrie::from_parquet("/tmp/nonexistent_kermit_test_file.parquet");
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(
            matches!(err, RelationError::Io(_)),
            "expected Io variant for missing file"
        );
    }

    /// A file that is not Parquet at all surfaces as the `Parquet` variant.
    #[test]
    fn from_parquet_invalid_file_returns_error() {
        let path = std::env::temp_dir().join("test_bad_parquet.parquet");
        std::fs::write(&path, b"this is not a parquet file").unwrap();

        let result = TreeTrie::from_parquet(&path);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(
            matches!(err, RelationError::Parquet(_)),
            "expected Parquet variant for corrupt file"
        );

        // Best-effort cleanup; a leftover file does not fail the test.
        std::fs::remove_file(path).ok();
    }
}