Skip to main content

yscv_model/dataset/
jsonl.rs

1use serde::Deserialize;
2use std::path::Path;
3
4use crate::ModelError;
5
6use super::helpers::{
7    adapter_sample_len, build_supervised_dataset_from_flat_values, load_dataset_text_file,
8    validate_adapter_sample_shape, validate_finite_values,
9};
10
/// Shape settings used when parsing or loading a supervised JSONL dataset.
///
/// The fields are private; instances are obtained through
/// `SupervisedJsonlConfig::new`, which validates both shapes.
#[derive(Clone, Debug, PartialEq)]
pub struct SupervisedJsonlConfig {
    /// Per-sample shape that every record's `input` values must fill.
    input_shape: Vec<usize>,
    /// Per-sample shape that every record's `target` values must fill.
    target_shape: Vec<usize>,
}
17
18impl SupervisedJsonlConfig {
19    pub fn new(input_shape: Vec<usize>, target_shape: Vec<usize>) -> Result<Self, ModelError> {
20        validate_adapter_sample_shape("input_shape", &input_shape)?;
21        validate_adapter_sample_shape("target_shape", &target_shape)?;
22        Ok(Self {
23            input_shape,
24            target_shape,
25        })
26    }
27
28    pub fn input_shape(&self) -> &[usize] {
29        &self.input_shape
30    }
31
32    pub fn target_shape(&self) -> &[usize] {
33        &self.target_shape
34    }
35}
36
37#[derive(Debug, Deserialize)]
38struct JsonlSupervisedRecord {
39    #[serde(alias = "inputs")]
40    #[serde(alias = "features")]
41    input: JsonlNumericField,
42    #[serde(alias = "targets")]
43    #[serde(alias = "label")]
44    #[serde(alias = "labels")]
45    target: JsonlNumericField,
46}
47
48#[derive(Debug, Deserialize)]
49#[serde(untagged)]
50enum JsonlNumericField {
51    Scalar(f32),
52    Vector(Vec<f32>),
53}
54
55fn jsonl_field_to_row(field: JsonlNumericField) -> Vec<f32> {
56    match field {
57        JsonlNumericField::Vector(values) => values,
58        JsonlNumericField::Scalar(value) => vec![value],
59    }
60}
61
62/// Parses supervised training samples from JSONL text into a `SupervisedDataset`.
63///
64/// Each non-empty line must be a JSON object with:
65/// - `input` (or alias `inputs`/`features`): flat sample values matching `config.input_shape`
66/// - `target` (or alias `targets`/`label`/`labels`): flat sample values matching `config.target_shape`
67pub fn parse_supervised_dataset_jsonl(
68    content: &str,
69    config: &SupervisedJsonlConfig,
70) -> Result<super::types::SupervisedDataset, ModelError> {
71    let input_row_len = adapter_sample_len("input_shape", config.input_shape())?;
72    let target_row_len = adapter_sample_len("target_shape", config.target_shape())?;
73
74    let mut input_values = Vec::new();
75    let mut target_values = Vec::new();
76    let mut sample_count = 0usize;
77
78    for (line_idx, raw_line) in content.lines().enumerate() {
79        let line_number = line_idx + 1;
80        let line = raw_line.trim();
81        if line.is_empty() || line.starts_with('#') {
82            continue;
83        }
84
85        let record: JsonlSupervisedRecord =
86            serde_json::from_str(line).map_err(|error| ModelError::DatasetJsonlParse {
87                line: line_number,
88                message: error.to_string(),
89            })?;
90
91        let input_row = jsonl_field_to_row(record.input);
92        if input_row.len() != input_row_len {
93            return Err(ModelError::InvalidDatasetRecordLength {
94                line: line_number,
95                field: "input",
96                expected: input_row_len,
97                got: input_row.len(),
98            });
99        }
100        let target_row = jsonl_field_to_row(record.target);
101        if target_row.len() != target_row_len {
102            return Err(ModelError::InvalidDatasetRecordLength {
103                line: line_number,
104                field: "target",
105                expected: target_row_len,
106                got: target_row.len(),
107            });
108        }
109
110        validate_finite_values(line_number, "input", &input_row)?;
111        validate_finite_values(line_number, "target", &target_row)?;
112
113        input_values.extend_from_slice(&input_row);
114        target_values.extend_from_slice(&target_row);
115        sample_count =
116            sample_count
117                .checked_add(1)
118                .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
119                    field: "sample_count",
120                    shape: vec![sample_count],
121                    message: "sample count overflow".to_string(),
122                })?;
123    }
124
125    if sample_count == 0 {
126        return Err(ModelError::EmptyDataset);
127    }
128
129    build_supervised_dataset_from_flat_values(
130        config.input_shape(),
131        config.target_shape(),
132        sample_count,
133        input_values,
134        target_values,
135    )
136}
137
138/// Loads supervised training samples from a JSONL file.
139pub fn load_supervised_dataset_jsonl_file<P: AsRef<Path>>(
140    path: P,
141    config: &SupervisedJsonlConfig,
142) -> Result<super::types::SupervisedDataset, ModelError> {
143    let content = load_dataset_text_file(path)?;
144    parse_supervised_dataset_jsonl(&content, config)
145}