#![allow(missing_docs)]
use crate::error::{IoError, Result};
use serde_json::Value;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
pub trait DataSource: Send {
fn next_batch(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>>;
fn reset(&mut self) -> Result<()> {
Err(IoError::Other("reset not supported by this source".to_string()))
}
fn name(&self) -> &str {
"unknown"
}
fn estimated_len(&self) -> Option<usize> {
None
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileSourceFormat {
Csv,
JsonLines,
Json,
Auto,
}
pub struct FileSource {
path: PathBuf,
format: FileSourceFormat,
reader: Option<FileSourceReader>,
exhausted: bool,
csv_headers: Option<Vec<String>>,
json_records: Option<std::collections::VecDeque<Value>>,
}
enum FileSourceReader {
Lines(BufReader<File>),
}
impl FileSource {
pub fn new(path: impl AsRef<Path>) -> Self {
Self::with_format(path, FileSourceFormat::Auto)
}
pub fn with_format(path: impl AsRef<Path>, format: FileSourceFormat) -> Self {
Self {
path: path.as_ref().to_path_buf(),
format,
reader: None,
exhausted: false,
csv_headers: None,
json_records: None,
}
}
fn resolve_format(&self) -> FileSourceFormat {
if self.format != FileSourceFormat::Auto {
return self.format;
}
match self
.path
.extension()
.and_then(|e| e.to_str())
.map(|e| e.to_ascii_lowercase())
.as_deref()
{
Some("csv") => FileSourceFormat::Csv,
Some("jsonl") | Some("ndjson") => FileSourceFormat::JsonLines,
Some("json") => FileSourceFormat::Json,
_ => FileSourceFormat::JsonLines,
}
}
fn open(&mut self) -> Result<()> {
let fmt = self.resolve_format();
match fmt {
FileSourceFormat::Json => {
let file = File::open(&self.path).map_err(IoError::Io)?;
let value: Value = serde_json::from_reader(BufReader::new(file))
.map_err(|e| IoError::ParseError(e.to_string()))?;
let records: Vec<Value> = match value {
Value::Array(arr) => arr,
other => vec![other],
};
self.json_records = Some(records.into());
}
_ => {
let file = File::open(&self.path).map_err(IoError::Io)?;
let reader = BufReader::new(file);
self.reader = Some(FileSourceReader::Lines(reader));
}
}
Ok(())
}
fn next_batch_csv(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>> {
let reader = match &mut self.reader {
Some(FileSourceReader::Lines(r)) => r,
None => return Ok(None),
};
if self.csv_headers.is_none() {
let mut line = String::new();
if reader.read_line(&mut line).map_err(IoError::Io)? == 0 {
return Ok(None);
}
let headers: Vec<String> = line.trim_end_matches(['\n', '\r'])
.split(',')
.map(|h| h.trim().trim_matches('"').to_string())
.collect();
self.csv_headers = Some(headers);
}
let headers = self.csv_headers.clone().unwrap_or_default();
let mut records = Vec::with_capacity(batch_size);
let mut line = String::new();
while records.len() < batch_size {
line.clear();
let n = reader.read_line(&mut line).map_err(IoError::Io)?;
if n == 0 {
break;
}
let row = line.trim_end_matches(['\n', '\r']);
if row.is_empty() {
continue;
}
let fields: Vec<&str> = row.split(',').collect();
let obj: serde_json::Map<String, Value> = headers
.iter()
.enumerate()
.map(|(i, h)| {
let raw = fields.get(i).copied().unwrap_or("").trim();
let val = raw.trim_matches('"');
let json_val: Value = if let Ok(n) = val.parse::<i64>() {
Value::Number(n.into())
} else if let Ok(f) = val.parse::<f64>() {
Value::Number(
serde_json::Number::from_f64(f)
.unwrap_or(serde_json::Number::from(0)),
)
} else if val == "true" {
Value::Bool(true)
} else if val == "false" {
Value::Bool(false)
} else {
Value::String(val.to_string())
};
(h.clone(), json_val)
})
.collect();
records.push(Value::Object(obj));
}
if records.is_empty() {
Ok(None)
} else {
Ok(Some(records))
}
}
fn next_batch_jsonl(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>> {
let reader = match &mut self.reader {
Some(FileSourceReader::Lines(r)) => r,
None => return Ok(None),
};
let mut records = Vec::with_capacity(batch_size);
let mut line = String::new();
while records.len() < batch_size {
line.clear();
let n = reader.read_line(&mut line).map_err(IoError::Io)?;
if n == 0 {
break;
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let val: Value = serde_json::from_str(trimmed)
.map_err(|e| IoError::ParseError(format!("JSONL parse: {e}")))?;
records.push(val);
}
if records.is_empty() {
Ok(None)
} else {
Ok(Some(records))
}
}
}
impl DataSource for FileSource {
fn next_batch(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>> {
if self.exhausted {
return Ok(None);
}
if self.reader.is_none() && self.json_records.is_none() {
self.open()?;
}
let result = if let Some(queue) = &mut self.json_records {
if queue.is_empty() {
Ok(None)
} else {
let batch: Vec<Value> = queue.drain(..batch_size.min(queue.len())).collect();
Ok(Some(batch))
}
} else {
match self.resolve_format() {
FileSourceFormat::Csv => self.next_batch_csv(batch_size),
_ => self.next_batch_jsonl(batch_size),
}
};
if matches!(result, Ok(None)) {
self.exhausted = true;
}
result
}
fn reset(&mut self) -> Result<()> {
self.reader = None;
self.json_records = None;
self.exhausted = false;
self.csv_headers = None;
Ok(())
}
fn name(&self) -> &str {
"file_source"
}
}
pub struct DatabaseSource {
connection_string: String,
query: String,
rows: Option<std::collections::VecDeque<Value>>,
exhausted: bool,
}
impl DatabaseSource {
pub fn new(
connection_string: impl Into<String>,
query: impl Into<String>,
) -> Self {
Self {
connection_string: connection_string.into(),
query: query.into(),
rows: None,
exhausted: false,
}
}
#[cfg(feature = "sqlite")]
fn load_rows(&mut self) -> Result<()> {
use rusqlite::{Connection, types::ValueRef};
let conn = Connection::open(&self.connection_string)
.map_err(|e| IoError::DatabaseError(e.to_string()))?;
let mut stmt = conn
.prepare(&self.query)
.map_err(|e| IoError::DatabaseError(e.to_string()))?;
let column_names: Vec<String> = stmt
.column_names()
.into_iter()
.map(|s| s.to_string())
.collect();
let mut rows_vec = Vec::new();
let mut rows = stmt
.query([])
.map_err(|e| IoError::DatabaseError(e.to_string()))?;
while let Some(row) = rows
.next()
.map_err(|e| IoError::DatabaseError(e.to_string()))?
{
let mut obj = serde_json::Map::new();
for (i, col) in column_names.iter().enumerate() {
let val = match row.get_ref(i).map_err(|e| IoError::DatabaseError(e.to_string()))? {
ValueRef::Null => Value::Null,
ValueRef::Integer(n) => Value::Number(n.into()),
ValueRef::Real(f) => Value::Number(
serde_json::Number::from_f64(f).unwrap_or(serde_json::Number::from(0)),
),
ValueRef::Text(t) => Value::String(
std::str::from_utf8(t)
.unwrap_or("")
.to_string(),
),
ValueRef::Blob(b) => {
Value::String(format!("<blob {} bytes>", b.len()))
}
};
obj.insert(col.clone(), val);
}
rows_vec.push(Value::Object(obj));
}
self.rows = Some(rows_vec.into());
Ok(())
}
#[cfg(not(feature = "sqlite"))]
fn load_rows(&mut self) -> Result<()> {
Err(IoError::Other(
"DatabaseSource requires the `sqlite` feature".to_string(),
))
}
}
impl DataSource for DatabaseSource {
fn next_batch(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>> {
if self.exhausted {
return Ok(None);
}
if self.rows.is_none() {
self.load_rows()?;
}
let queue = self.rows.as_mut().expect("rows must be populated after load_rows");
if queue.is_empty() {
self.exhausted = true;
return Ok(None);
}
let batch: Vec<Value> = queue.drain(..batch_size.min(queue.len())).collect();
Ok(Some(batch))
}
fn reset(&mut self) -> Result<()> {
self.rows = None;
self.exhausted = false;
Ok(())
}
fn name(&self) -> &str {
"database_source"
}
}
pub struct StreamSource {
receiver: crossbeam_channel::Receiver<Value>,
timeout: std::time::Duration,
label: String,
}
pub type StreamSender = crossbeam_channel::Sender<Value>;
impl StreamSource {
pub fn new_unbounded() -> (Self, StreamSender) {
let (tx, rx) = crossbeam_channel::unbounded();
let source = Self {
receiver: rx,
timeout: std::time::Duration::from_millis(50),
label: "stream_source".to_string(),
};
(source, tx)
}
pub fn new_bounded(capacity: usize) -> (Self, StreamSender) {
let (tx, rx) = crossbeam_channel::bounded(capacity);
let source = Self {
receiver: rx,
timeout: std::time::Duration::from_millis(50),
label: "stream_source_bounded".to_string(),
};
(source, tx)
}
pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.label = name.into();
self
}
}
impl DataSource for StreamSource {
fn next_batch(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>> {
use crossbeam_channel::RecvTimeoutError;
let mut batch = Vec::with_capacity(batch_size);
match self.receiver.recv_timeout(self.timeout) {
Ok(msg) => batch.push(msg),
Err(RecvTimeoutError::Disconnected) => return Ok(None),
Err(RecvTimeoutError::Timeout) => return Ok(Some(batch)), }
while batch.len() < batch_size {
match self.receiver.try_recv() {
Ok(msg) => batch.push(msg),
Err(_) => break,
}
}
Ok(Some(batch))
}
fn name(&self) -> &str {
&self.label
}
}
pub struct GeneratorSource<F>
where
F: FnMut() -> Option<Value> + Send,
{
generator: F,
label: String,
exhausted: bool,
}
impl<F> GeneratorSource<F>
where
F: FnMut() -> Option<Value> + Send,
{
pub fn new(generator: F) -> Self {
Self {
generator,
label: "generator_source".to_string(),
exhausted: false,
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.label = name.into();
self
}
}
impl<F> DataSource for GeneratorSource<F>
where
F: FnMut() -> Option<Value> + Send,
{
fn next_batch(&mut self, batch_size: usize) -> Result<Option<Vec<Value>>> {
if self.exhausted {
return Ok(None);
}
let mut batch = Vec::with_capacity(batch_size);
while batch.len() < batch_size {
match (self.generator)() {
Some(v) => batch.push(v),
None => {
self.exhausted = true;
break;
}
}
}
if batch.is_empty() {
Ok(None)
} else {
Ok(Some(batch))
}
}
fn reset(&mut self) -> Result<()> {
Err(IoError::Other(
"GeneratorSource does not support reset".to_string(),
))
}
fn name(&self) -> &str {
&self.label
}
}
pub fn drain_source(
source: &mut dyn DataSource,
batch_size: usize,
) -> Result<Vec<Value>> {
let mut all = Vec::new();
loop {
match source.next_batch(batch_size)? {
Some(batch) => all.extend(batch),
None => break,
}
}
Ok(all)
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::io::Write;
fn temp_path(suffix: &str) -> PathBuf {
let mut p = std::env::temp_dir();
p.push(format!("scirs2_io_test_{suffix}_{}", std::process::id()));
p
}
#[test]
fn test_generator_source() {
let mut counter = 0u64;
let mut src = GeneratorSource::new(move || {
if counter < 5 {
let v = json!(counter);
counter += 1;
Some(v)
} else {
None
}
});
let batch = src.next_batch(3).unwrap().unwrap();
assert_eq!(batch.len(), 3);
let rest = drain_source(&mut src, 10).unwrap();
assert_eq!(rest.len(), 2);
assert!(src.next_batch(10).unwrap().is_none());
}
#[test]
fn test_stream_source_basic() {
let (mut src, tx) = StreamSource::new_unbounded();
tx.send(json!({"id": 1})).unwrap();
tx.send(json!({"id": 2})).unwrap();
drop(tx);
let all = drain_source(&mut src, 10).unwrap();
let total: usize = all.len()
+ drain_source(&mut src, 10)
.unwrap_or_default()
.len();
assert!(total >= 2 || all.len() >= 2);
}
#[test]
fn test_file_source_jsonl() {
let path = temp_path("jsonl");
{
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, r#"{{"a":1,"b":"x"}}"#).unwrap();
writeln!(f, r#"{{"a":2,"b":"y"}}"#).unwrap();
writeln!(f, r#"{{"a":3,"b":"z"}}"#).unwrap();
}
let mut src = FileSource::with_format(&path, FileSourceFormat::JsonLines);
let all = drain_source(&mut src, 10).unwrap();
assert_eq!(all.len(), 3);
assert_eq!(all[0]["a"], json!(1));
std::fs::remove_file(&path).ok();
}
#[test]
fn test_file_source_csv() {
let path = temp_path("csv");
{
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, "name,score").unwrap();
writeln!(f, "alice,90").unwrap();
writeln!(f, "bob,85").unwrap();
}
let mut src = FileSource::with_format(&path, FileSourceFormat::Csv);
let all = drain_source(&mut src, 10).unwrap();
assert_eq!(all.len(), 2);
assert_eq!(all[0]["name"], json!("alice"));
assert_eq!(all[0]["score"], json!(90));
std::fs::remove_file(&path).ok();
}
#[test]
fn test_file_source_json_array() {
let path = temp_path("json");
{
let mut f = std::fs::File::create(&path).unwrap();
write!(f, r#"[{{"x":1}},{{"x":2}},{{"x":3}}]"#).unwrap();
}
let mut src = FileSource::with_format(&path, FileSourceFormat::Json);
let b1 = src.next_batch(2).unwrap().unwrap();
assert_eq!(b1.len(), 2);
let b2 = src.next_batch(2).unwrap().unwrap();
assert_eq!(b2.len(), 1);
assert!(src.next_batch(2).unwrap().is_none());
std::fs::remove_file(&path).ok();
}
#[test]
fn test_file_source_batching() {
let path = temp_path("jsonl2");
{
let mut f = std::fs::File::create(&path).unwrap();
for i in 0..10 {
writeln!(f, r#"{{"i":{i}}}"#).unwrap();
}
}
let mut src = FileSource::with_format(&path, FileSourceFormat::JsonLines);
let b1 = src.next_batch(4).unwrap().unwrap();
assert_eq!(b1.len(), 4);
let b2 = src.next_batch(4).unwrap().unwrap();
assert_eq!(b2.len(), 4);
let b3 = src.next_batch(4).unwrap().unwrap();
assert_eq!(b3.len(), 2);
assert!(src.next_batch(4).unwrap().is_none());
std::fs::remove_file(&path).ok();
}
}