1use ferrule_sql::render_value;
2use ferrule_sql::value::{TypeHint, Value};
3use ferrule_sql::{Backend, Connection, SqlError};
4
/// Input formats understood by `load_data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadFormat {
    /// CSV text whose first row holds the column names.
    Csv,
    /// A JSON array of objects.
    Json,
}

impl LoadFormat {
    /// Parses a format name, ignoring ASCII case (`"csv"`, `"CSV"`, `"Json"`, ...).
    ///
    /// Returns `None` for any other string; surrounding whitespace is not trimmed.
    pub fn parse(s: &str) -> Option<Self> {
        const NAMES: [(&str, LoadFormat); 2] = [("csv", LoadFormat::Csv), ("json", LoadFormat::Json)];
        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, format)| format)
    }
}
21
22#[derive(Debug, Clone)]
24pub struct LoadOptions {
25 pub format: LoadFormat,
26 pub table: String,
27 pub create_table: bool,
28 pub batch_size: usize,
29}
30
31impl Default for LoadOptions {
32 fn default() -> Self {
33 Self {
34 format: LoadFormat::Csv,
35 table: String::new(),
36 create_table: false,
37 batch_size: 1000,
38 }
39 }
40}
41
42pub fn load_data(
44 conn: &mut dyn Connection,
45 data: &str,
46 backend: Backend,
47 opts: &LoadOptions,
48) -> Result<usize, SqlError> {
49 match opts.format {
50 LoadFormat::Csv => load_csv(conn, data, backend, opts),
51 LoadFormat::Json => load_json(conn, data, backend, opts),
52 }
53}
54
55fn load_csv(
56 conn: &mut dyn Connection,
57 data: &str,
58 backend: Backend,
59 opts: &LoadOptions,
60) -> Result<usize, SqlError> {
61 let mut rdr = csv::Reader::from_reader(data.as_bytes());
62 let headers: Vec<String> = rdr
63 .headers()
64 .map_err(|e| SqlError::QueryFailed(e.to_string()))?
65 .iter()
66 .map(|s| s.to_string())
67 .collect();
68 let quoted_table = quote_identifier(&opts.table);
69 let quoted_cols: Vec<String> = headers.iter().map(|h| quote_identifier(h)).collect();
70 let cols = quoted_cols.join(", ");
71
72 let mut total = 0usize;
73 let mut batch = Vec::new();
74 for result in rdr.records() {
75 let record = result.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
76 let values: Vec<String> = record
77 .iter()
78 .map(|s| render_value(&Value::String(s.to_string()), backend))
79 .collect();
80 batch.push(format!("({})", values.join(", ")));
81 if batch.len() >= opts.batch_size {
82 let sql = format!(
83 "INSERT INTO {quoted_table} ({cols}) VALUES {};",
84 batch.join(", ")
85 );
86 conn.execute(&sql)?;
87 total += batch.len();
88 batch.clear();
89 }
90 }
91 if !batch.is_empty() {
92 let sql = format!(
93 "INSERT INTO {quoted_table} ({cols}) VALUES {};",
94 batch.join(", ")
95 );
96 conn.execute(&sql)?;
97 total += batch.len();
98 }
99 Ok(total)
100}
101
102fn load_json(
103 conn: &mut dyn Connection,
104 data: &str,
105 backend: Backend,
106 opts: &LoadOptions,
107) -> Result<usize, SqlError> {
108 let arr: Vec<serde_json::Value> =
109 serde_json::from_str(data).map_err(|e| SqlError::QueryFailed(e.to_string()))?;
110 if arr.is_empty() {
111 return Ok(0);
112 }
113
114 let first = arr[0]
116 .as_object()
117 .ok_or_else(|| SqlError::QueryFailed("JSON array must contain objects".into()))?;
118 let columns: Vec<String> = first.keys().cloned().collect();
119 let quoted_table = quote_identifier(&opts.table);
120 let quoted_cols: Vec<String> = columns.iter().map(|c| quote_identifier(c)).collect();
121 let cols = quoted_cols.join(", ");
122
123 if opts.create_table {
124 let schema = infer_schema(&arr, backend);
125 let create = build_create_table(&opts.table, &schema, backend);
126 conn.execute(&create)?;
127 }
128
129 let mut total = 0usize;
130 let mut batch = Vec::new();
131 for obj in &arr {
132 if let Some(map) = obj.as_object() {
133 let values: Vec<String> = columns
134 .iter()
135 .map(|c| {
136 let val = map.get(c).cloned().unwrap_or(serde_json::Value::Null);
137 render_value(&json_to_value(&val), backend)
138 })
139 .collect();
140 batch.push(format!("({})", values.join(", ")));
141 if batch.len() >= opts.batch_size {
142 let sql = format!(
143 "INSERT INTO {quoted_table} ({cols}) VALUES {};",
144 batch.join(", ")
145 );
146 conn.execute(&sql)?;
147 total += batch.len();
148 batch.clear();
149 }
150 }
151 }
152 if !batch.is_empty() {
153 let sql = format!(
154 "INSERT INTO {quoted_table} ({cols}) VALUES {};",
155 batch.join(", ")
156 );
157 conn.execute(&sql)?;
158 total += batch.len();
159 }
160 Ok(total)
161}
162
163fn json_to_value(v: &serde_json::Value) -> Value {
164 match v {
165 serde_json::Value::Null => Value::Null,
166 serde_json::Value::Bool(b) => Value::Bool(*b),
167 serde_json::Value::Number(n) => {
168 if let Some(i) = n.as_i64() {
169 Value::Int64(i)
170 } else if let Some(f) = n.as_f64() {
171 if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
172 Value::Int64(f as i64)
173 } else {
174 Value::Float64(f)
175 }
176 } else {
177 Value::String(n.to_string())
178 }
179 }
180 serde_json::Value::String(s) => Value::String(s.clone()),
181 serde_json::Value::Array(a) => Value::Array(a.iter().map(json_to_value).collect()),
182 serde_json::Value::Object(_) => Value::String(v.to_string()),
183 }
184}
185
186pub fn infer_schema(objects: &[serde_json::Value], backend: Backend) -> Vec<(String, TypeHint)> {
188 let mut schema = Vec::new();
189 if objects.is_empty() {
190 return schema;
191 }
192 if let Some(first) = objects[0].as_object() {
193 for (key, val) in first {
194 let hint = infer_json_type(val, backend);
195 schema.push((key.clone(), hint));
196 }
197 }
198 schema
199}
200
201#[cfg_attr(not(feature = "oracle"), allow(unused_variables))]
202fn infer_json_type(val: &serde_json::Value, backend: Backend) -> TypeHint {
203 match val {
204 serde_json::Value::Null => TypeHint::String,
205 serde_json::Value::Bool(_) => {
206 #[cfg(feature = "oracle")]
207 if matches!(backend, Backend::Oracle) {
208 return TypeHint::Int64;
209 }
210 TypeHint::Bool
211 }
212 serde_json::Value::Number(n) => {
213 if let Some(_i) = n.as_i64() {
214 TypeHint::Int64
215 } else {
216 TypeHint::Float64
217 }
218 }
219 serde_json::Value::String(_) => TypeHint::String,
220 serde_json::Value::Array(_) => TypeHint::Array,
221 serde_json::Value::Object(_) => TypeHint::Json,
222 }
223}
224
225fn build_create_table(table: &str, schema: &[(String, TypeHint)], backend: Backend) -> String {
226 let quoted_table = quote_identifier(table);
227 let cols: Vec<String> = schema
228 .iter()
229 .map(|(name, hint)| {
230 let quoted_name = quote_identifier(name);
231 let sql_type = type_hint_to_sql(hint, backend);
232 format!("{} {}", quoted_name, sql_type)
233 })
234 .collect();
235 format!("CREATE TABLE {quoted_table} ({});", cols.join(", "))
236}
237
238#[cfg_attr(not(feature = "oracle"), allow(unused_variables))]
239fn type_hint_to_sql(hint: &TypeHint, backend: Backend) -> &'static str {
240 match hint {
241 TypeHint::Int64 => "INTEGER",
242 TypeHint::Float64 | TypeHint::Decimal => "NUMERIC(18,6)",
243 TypeHint::Bool => {
244 #[cfg(feature = "oracle")]
245 if matches!(backend, Backend::Oracle) {
246 return "NUMBER(1)";
247 }
248 "BOOLEAN"
249 }
250 TypeHint::Json => {
251 #[cfg(feature = "oracle")]
252 if matches!(backend, Backend::Oracle) {
253 return "CLOB";
254 }
255 "TEXT"
256 }
257 TypeHint::String | TypeHint::Null | TypeHint::Uuid => {
258 #[cfg(feature = "oracle")]
259 if matches!(backend, Backend::Oracle) {
260 return "VARCHAR2(4000)";
261 }
262 "TEXT"
263 }
264 _ => {
265 #[cfg(feature = "oracle")]
266 if matches!(backend, Backend::Oracle) {
267 return "VARCHAR2(4000)";
268 }
269 "TEXT"
270 }
271 }
272}
273
/// Wraps `id` in double quotes and doubles any embedded `"`, producing a single
/// double-quote-delimited SQL identifier.
fn quote_identifier(id: &str) -> String {
    let mut quoted = String::with_capacity(id.len() + 2);
    quoted.push('"');
    for ch in id.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}