lbl/collections.rs

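//! helpers for standardizing, counting, and writing label collection data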
use crate::types::*;
use polars::prelude::*;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

/// convert the `address` column from raw binary to a `0x`-prefixed hex string
pub fn address_to_hex(data: DataFrame) -> Result<DataFrame, LblError> {
    let address_column = data.column("address")?;
    if let DataType::Binary = address_column.dtype() {
        // hex-encode the raw bytes
        let address = address_column.binary()?.hex_encode();
        let address = address.str()?;
        // build a constant "0x" prefix column and prepend it to every value
        let prefix: Series = vec!["0x".to_string(); address.len()].into_iter().collect();
        let prefix = prefix.with_name("address");
        let prefix = prefix.str()?;
        let data = data.clone().with_column(prefix.concat(address))?.clone();
        Ok(data)
    } else {
        // already a string column, nothing to do
        Ok(data)
    }
}

/// convert a `0x`-prefixed hex `address` column to raw binary
pub fn address_to_binary(data: DataFrame) -> Result<DataFrame, LblError> {
    let mut data = data.clone();
    let address_column = data.column("address")?;
    if let DataType::String = address_column.dtype() {
        let height = address_column.len();
        // slice off the leading "0x" of every address, then hex-decode the remainder
        let offset: Series = vec![2; height].iter().collect();
        let length: Series = vec![42; height].iter().collect();
        let trimmed_prefix = address_column.str()?.str_slice(&offset, &length)?;
        data = data.with_column(trimmed_prefix.hex_decode(true)?)?.clone();
        Ok(data)
    } else {
        // already binary, nothing to do
        Ok(data)
    }
}

/// map of standard collection columns to their dtypes
pub(crate) fn get_standard_columns() -> HashMap<String, DataType> {
    [
        ("address", DataType::Binary),
        ("collection", DataType::String),
        ("name", DataType::String),
        ("project", DataType::String),
        ("class", DataType::String),
        ("network", DataType::String),
        ("extra_data", DataType::String),
        ("added_by", DataType::String),
        ("date_added", DataType::String),
    ]
    .into_iter()
    .map(|(s, dt)| (s.to_string(), dt))
    .collect()
}

/// canonical column order for standardized collection files
pub(crate) fn get_standard_column_order() -> Vec<String> {
    vec![
        "address".to_string(),
        "collection".to_string(),
        "name".to_string(),
        "project".to_string(),
        "class".to_string(),
        "network".to_string(),
        "extra_data".to_string(),
        "added_by".to_string(),
        "date_added".to_string(),
    ]
}

/// standardize raw collection data into the standard column set, dtypes, and order
pub fn standardize_collection(
    df: DataFrame,
    metadata: &CollectionData,
) -> Result<DataFrame, LblError> {
    let mut df = df.clone();

    // fill in the collection column from metadata if it is missing
    if let (Err(_), Some(collection)) = (df.column("collection"), metadata.collection.clone()) {
        let column: Series = vec![collection.clone(); df.height()].into_iter().collect();
        df = df.with_column(column.with_name("collection"))?.clone();
    };

    let df_columns = df.get_column_names();
    let mut df = df.clone();
    let standard_columns = get_standard_columns();

    // add missing standard columns as null columns
    for (column, dtype) in standard_columns.iter() {
        if !df_columns.contains(&column.as_str()) {
            let series = create_null_column(column.clone(), dtype.clone(), df.height());
            df = df.with_column(series)?.clone();
        }
    }

    // remove columns that are not part of the standard set
    for column in df_columns {
        if !standard_columns.contains_key(column) {
            df = df.drop(column)?;
        }
    }

    // convert the address column from hex to binary
    let df = address_to_binary(df)?;

    // reorder columns (HashMap iteration order is unstable, so use the explicit order)
    let columns = get_standard_column_order();
    let df = df.select(columns)?;

    Ok(df)
}

/// create an all-null column with the given name, dtype, and length
fn create_null_column(name: String, dtype: DataType, len: usize) -> Series {
    match dtype {
        DataType::String => Series::new(
            name.as_str(),
            std::iter::repeat(None::<&str>)
                .take(len)
                .collect::<Vec<Option<&str>>>()
                .as_slice(),
        ),
        DataType::Binary => Series::new(
            name.as_str(),
            std::iter::repeat(None::<&[u8]>)
                .take(len)
                .collect::<Vec<Option<&[u8]>>>()
                .as_slice(),
        ),
        _ => unimplemented!("unsupported data type for null column: {:?}", dtype),
    }
}

/// get the total row count across paths
pub fn get_row_counts(paths: Vec<PathBuf>) -> Result<i64, LblError> {
    let paths_by_extension = crate::filesystem::paths_by_extension(paths);
    let mut row_count = 0;
    for (extension, extension_paths) in paths_by_extension.into_iter() {
        match extension.as_str() {
            "csv" => {
                let arc_slice: Arc<[PathBuf]> = Arc::from(extension_paths);
                // lazily scan the csv files and count rows of the address column
                let lf = LazyCsvReader::new_paths(arc_slice)
                    .finish()?
                    .select([col("address")])
                    .count()
                    .select([col("address").cast(DataType::Int64)])
                    .collect()?;
                row_count += lf.column("address")?.i64()?.get(0).unwrap_or(0);
            }
            "parquet" => {
                let arc_slice: Arc<[PathBuf]> = Arc::from(extension_paths);
                let opts = ScanArgsParquet::default();
                // lazily scan the parquet files and count rows of the address column
                let lf = LazyFrame::scan_parquet_files(arc_slice, opts)?
                    .select([col("address")])
                    .count()
                    .select([col("address").cast(DataType::Int64)])
                    .collect()?;
                row_count += lf.column("address")?.i64()?.get(0).unwrap_or(0);
            }
            other => {
                return Err(LblError::LblError(format!(
                    "unknown file extension: {}",
                    other
                )))
            }
        }
    }
    Ok(row_count)
}

/// write a collection file, choosing the format from the path extension
pub fn write_file(df: &DataFrame, path: &PathBuf) -> Result<(), LblError> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)
            .map_err(|e| LblError::LblError(format!("could not create directory: {:?}", e)))?;
    }

    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .unwrap_or("");
    let file = std::fs::File::create(path)
        .map_err(|e| LblError::LblError(format!("could not write file: {:?}", e)))?;
    let mut df = df.clone();

    match extension {
        "csv" => {
            // csv output stores addresses as 0x-prefixed hex strings
            let mut df = address_to_hex(df)?;
            CsvWriter::new(file).finish(&mut df)?
        }
        "parquet" => {
            // derive parquet row group sizing from the frame height
            let n_row_groups = match df.height() {
                0 => 1,
                height if height < 100 => 1,
                height if height < 6400 => height / 100,
                _ => 64,
            };
            ParquetWriter::new(file)
                .with_statistics(true)
                .with_compression(ParquetCompression::Zstd(None))
                .with_row_group_size(Some(n_row_groups))
                .finish(&mut df)?;
        }
        _ => {
            return Err(LblError::LblError(
                "must select either parquet or csv".to_string(),
            ))
        }
    };
    Ok(())
}
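
// a minimal round-trip sketch of the address conversions above; this is an
// illustrative test only, and it assumes `LblError` implements `Debug` and
// that the polars `df!` macro is available from the crate root
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn address_round_trip() {
        // a hypothetical lowercase 0x-prefixed address
        let address = "0x00000000000000000000000000000000000000ff";
        let df = polars::df!("address" => &[address]).unwrap();

        // string -> binary should produce a binary address column
        let binary = address_to_binary(df).unwrap();
        assert_eq!(binary.column("address").unwrap().dtype(), &DataType::Binary);

        // binary -> string should restore the original hex form
        let hex = address_to_hex(binary).unwrap();
        assert_eq!(hex.column("address").unwrap().str().unwrap().get(0), Some(address));
    }
}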