Skip to main content

ferrule_core/
load.rs

1use ferrule_sql::render_value;
2use ferrule_sql::value::{TypeHint, Value};
3use ferrule_sql::{Backend, Connection, SqlError};
4
/// Input formats accepted by `load_data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadFormat {
    /// CSV text whose first record is a header row.
    Csv,
    /// A JSON array of objects.
    Json,
}

impl LoadFormat {
    /// Parse a format name, ignoring ASCII case (`"csv"`, `"JSON"`, ...).
    ///
    /// Returns `None` for any other string; surrounding whitespace is not trimmed.
    pub fn parse(s: &str) -> Option<Self> {
        if s.eq_ignore_ascii_case("csv") {
            Some(Self::Csv)
        } else if s.eq_ignore_ascii_case("json") {
            Some(Self::Json)
        } else {
            None
        }
    }
}
21
22/// Options for a load operation.
23#[derive(Debug, Clone)]
24pub struct LoadOptions {
25    pub format: LoadFormat,
26    pub table: String,
27    pub create_table: bool,
28    pub batch_size: usize,
29}
30
31impl Default for LoadOptions {
32    fn default() -> Self {
33        Self {
34            format: LoadFormat::Csv,
35            table: String::new(),
36            create_table: false,
37            batch_size: 1000,
38        }
39    }
40}
41
42/// Load data from a reader (CSV or JSON) into a table.
43pub fn load_data(
44    conn: &mut dyn Connection,
45    data: &str,
46    backend: Backend,
47    opts: &LoadOptions,
48) -> Result<usize, SqlError> {
49    match opts.format {
50        LoadFormat::Csv => load_csv(conn, data, backend, opts),
51        LoadFormat::Json => load_json(conn, data, backend, opts),
52    }
53}
54
55fn load_csv(
56    conn: &mut dyn Connection,
57    data: &str,
58    backend: Backend,
59    opts: &LoadOptions,
60) -> Result<usize, SqlError> {
61    let mut rdr = csv::Reader::from_reader(data.as_bytes());
62    let headers: Vec<String> = rdr
63        .headers()
64        .map_err(|e| SqlError::QueryFailed(e.to_string()))?
65        .iter()
66        .map(|s| s.to_string())
67        .collect();
68    let quoted_table = quote_identifier(&opts.table);
69    let quoted_cols: Vec<String> = headers.iter().map(|h| quote_identifier(h)).collect();
70    let cols = quoted_cols.join(", ");
71
72    let mut total = 0usize;
73    let mut batch = Vec::new();
74    for result in rdr.records() {
75        let record = result.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
76        let values: Vec<String> = record
77            .iter()
78            .map(|s| render_value(&Value::String(s.to_string()), backend))
79            .collect();
80        batch.push(format!("({})", values.join(", ")));
81        if batch.len() >= opts.batch_size {
82            let sql = format!(
83                "INSERT INTO {quoted_table} ({cols}) VALUES {};",
84                batch.join(", ")
85            );
86            conn.execute(&sql)?;
87            total += batch.len();
88            batch.clear();
89        }
90    }
91    if !batch.is_empty() {
92        let sql = format!(
93            "INSERT INTO {quoted_table} ({cols}) VALUES {};",
94            batch.join(", ")
95        );
96        conn.execute(&sql)?;
97        total += batch.len();
98    }
99    Ok(total)
100}
101
102fn load_json(
103    conn: &mut dyn Connection,
104    data: &str,
105    backend: Backend,
106    opts: &LoadOptions,
107) -> Result<usize, SqlError> {
108    let arr: Vec<serde_json::Value> =
109        serde_json::from_str(data).map_err(|e| SqlError::QueryFailed(e.to_string()))?;
110    if arr.is_empty() {
111        return Ok(0);
112    }
113
114    // Infer columns from first object
115    let first = arr[0]
116        .as_object()
117        .ok_or_else(|| SqlError::QueryFailed("JSON array must contain objects".into()))?;
118    let columns: Vec<String> = first.keys().cloned().collect();
119    let quoted_table = quote_identifier(&opts.table);
120    let quoted_cols: Vec<String> = columns.iter().map(|c| quote_identifier(c)).collect();
121    let cols = quoted_cols.join(", ");
122
123    if opts.create_table {
124        let schema = infer_schema(&arr, backend);
125        let create = build_create_table(&opts.table, &schema, backend);
126        conn.execute(&create)?;
127    }
128
129    let mut total = 0usize;
130    let mut batch = Vec::new();
131    for obj in &arr {
132        if let Some(map) = obj.as_object() {
133            let values: Vec<String> = columns
134                .iter()
135                .map(|c| {
136                    let val = map.get(c).cloned().unwrap_or(serde_json::Value::Null);
137                    render_value(&json_to_value(&val), backend)
138                })
139                .collect();
140            batch.push(format!("({})", values.join(", ")));
141            if batch.len() >= opts.batch_size {
142                let sql = format!(
143                    "INSERT INTO {quoted_table} ({cols}) VALUES {};",
144                    batch.join(", ")
145                );
146                conn.execute(&sql)?;
147                total += batch.len();
148                batch.clear();
149            }
150        }
151    }
152    if !batch.is_empty() {
153        let sql = format!(
154            "INSERT INTO {quoted_table} ({cols}) VALUES {};",
155            batch.join(", ")
156        );
157        conn.execute(&sql)?;
158        total += batch.len();
159    }
160    Ok(total)
161}
162
163fn json_to_value(v: &serde_json::Value) -> Value {
164    match v {
165        serde_json::Value::Null => Value::Null,
166        serde_json::Value::Bool(b) => Value::Bool(*b),
167        serde_json::Value::Number(n) => {
168            if let Some(i) = n.as_i64() {
169                Value::Int64(i)
170            } else if let Some(f) = n.as_f64() {
171                if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
172                    Value::Int64(f as i64)
173                } else {
174                    Value::Float64(f)
175                }
176            } else {
177                Value::String(n.to_string())
178            }
179        }
180        serde_json::Value::String(s) => Value::String(s.clone()),
181        serde_json::Value::Array(a) => Value::Array(a.iter().map(json_to_value).collect()),
182        serde_json::Value::Object(_) => Value::String(v.to_string()),
183    }
184}
185
186/// Infer a schema from a slice of JSON objects.
187pub fn infer_schema(objects: &[serde_json::Value], backend: Backend) -> Vec<(String, TypeHint)> {
188    let mut schema = Vec::new();
189    if objects.is_empty() {
190        return schema;
191    }
192    if let Some(first) = objects[0].as_object() {
193        for (key, val) in first {
194            let hint = infer_json_type(val, backend);
195            schema.push((key.clone(), hint));
196        }
197    }
198    schema
199}
200
201#[cfg_attr(not(feature = "oracle"), allow(unused_variables))]
202fn infer_json_type(val: &serde_json::Value, backend: Backend) -> TypeHint {
203    match val {
204        serde_json::Value::Null => TypeHint::String,
205        serde_json::Value::Bool(_) => {
206            #[cfg(feature = "oracle")]
207            if matches!(backend, Backend::Oracle) {
208                return TypeHint::Int64;
209            }
210            TypeHint::Bool
211        }
212        serde_json::Value::Number(n) => {
213            if let Some(_i) = n.as_i64() {
214                TypeHint::Int64
215            } else {
216                TypeHint::Float64
217            }
218        }
219        serde_json::Value::String(_) => TypeHint::String,
220        serde_json::Value::Array(_) => TypeHint::Array,
221        serde_json::Value::Object(_) => TypeHint::Json,
222    }
223}
224
225fn build_create_table(table: &str, schema: &[(String, TypeHint)], backend: Backend) -> String {
226    let quoted_table = quote_identifier(table);
227    let cols: Vec<String> = schema
228        .iter()
229        .map(|(name, hint)| {
230            let quoted_name = quote_identifier(name);
231            let sql_type = type_hint_to_sql(hint, backend);
232            format!("{} {}", quoted_name, sql_type)
233        })
234        .collect();
235    format!("CREATE TABLE {quoted_table} ({});", cols.join(", "))
236}
237
238#[cfg_attr(not(feature = "oracle"), allow(unused_variables))]
239fn type_hint_to_sql(hint: &TypeHint, backend: Backend) -> &'static str {
240    match hint {
241        TypeHint::Int64 => "INTEGER",
242        TypeHint::Float64 | TypeHint::Decimal => "NUMERIC(18,6)",
243        TypeHint::Bool => {
244            #[cfg(feature = "oracle")]
245            if matches!(backend, Backend::Oracle) {
246                return "NUMBER(1)";
247            }
248            "BOOLEAN"
249        }
250        TypeHint::Json => {
251            #[cfg(feature = "oracle")]
252            if matches!(backend, Backend::Oracle) {
253                return "CLOB";
254            }
255            "TEXT"
256        }
257        TypeHint::String | TypeHint::Null | TypeHint::Uuid => {
258            #[cfg(feature = "oracle")]
259            if matches!(backend, Backend::Oracle) {
260                return "VARCHAR2(4000)";
261            }
262            "TEXT"
263        }
264        _ => {
265            #[cfg(feature = "oracle")]
266            if matches!(backend, Backend::Oracle) {
267                return "VARCHAR2(4000)";
268            }
269            "TEXT"
270        }
271    }
272}
273
/// Quote `id` as a SQL delimited identifier.
///
/// The name is wrapped in double quotes and every embedded double quote is
/// doubled, e.g. `a"b` becomes `"a""b"`.
fn quote_identifier(id: &str) -> String {
    let mut quoted = String::with_capacity(id.len() + 2);
    quoted.push('"');
    for ch in id.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}