yscv_model/dataset/
jsonl.rs1use serde::Deserialize;
2use std::path::Path;
3
4use crate::ModelError;
5
6use super::helpers::{
7 adapter_sample_len, build_supervised_dataset_from_flat_values, load_dataset_text_file,
8 validate_adapter_sample_shape, validate_finite_values,
9};
10
/// Shape configuration used when parsing supervised JSONL datasets.
///
/// The fields are private; values are built through
/// [`SupervisedJsonlConfig::new`], which validates both shapes first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SupervisedJsonlConfig {
    /// Shape describing the input values carried by one record.
    input_shape: Vec<usize>,
    /// Shape describing the target values carried by one record.
    target_shape: Vec<usize>,
}
17
18impl SupervisedJsonlConfig {
19 pub fn new(input_shape: Vec<usize>, target_shape: Vec<usize>) -> Result<Self, ModelError> {
20 validate_adapter_sample_shape("input_shape", &input_shape)?;
21 validate_adapter_sample_shape("target_shape", &target_shape)?;
22 Ok(Self {
23 input_shape,
24 target_shape,
25 })
26 }
27
28 pub fn input_shape(&self) -> &[usize] {
29 &self.input_shape
30 }
31
32 pub fn target_shape(&self) -> &[usize] {
33 &self.target_shape
34 }
35}
36
37#[derive(Debug, Deserialize)]
38struct JsonlSupervisedRecord {
39 #[serde(alias = "inputs")]
40 #[serde(alias = "features")]
41 input: JsonlNumericField,
42 #[serde(alias = "targets")]
43 #[serde(alias = "label")]
44 #[serde(alias = "labels")]
45 target: JsonlNumericField,
46}
47
/// A JSON value that is either a single number or a flat array of numbers.
///
/// Deserialized with `#[serde(untagged)]`, so the JSON carries no variant tag
/// and serde tries the variants in declaration order. Keep that order stable.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum JsonlNumericField {
    /// A bare JSON number.
    Scalar(f32),
    /// A JSON array of numbers.
    Vector(Vec<f32>),
}
54
55fn jsonl_field_to_row(field: JsonlNumericField) -> Vec<f32> {
56 match field {
57 JsonlNumericField::Vector(values) => values,
58 JsonlNumericField::Scalar(value) => vec![value],
59 }
60}
61
62pub fn parse_supervised_dataset_jsonl(
68 content: &str,
69 config: &SupervisedJsonlConfig,
70) -> Result<super::types::SupervisedDataset, ModelError> {
71 let input_row_len = adapter_sample_len("input_shape", config.input_shape())?;
72 let target_row_len = adapter_sample_len("target_shape", config.target_shape())?;
73
74 let mut input_values = Vec::new();
75 let mut target_values = Vec::new();
76 let mut sample_count = 0usize;
77
78 for (line_idx, raw_line) in content.lines().enumerate() {
79 let line_number = line_idx + 1;
80 let line = raw_line.trim();
81 if line.is_empty() || line.starts_with('#') {
82 continue;
83 }
84
85 let record: JsonlSupervisedRecord =
86 serde_json::from_str(line).map_err(|error| ModelError::DatasetJsonlParse {
87 line: line_number,
88 message: error.to_string(),
89 })?;
90
91 let input_row = jsonl_field_to_row(record.input);
92 if input_row.len() != input_row_len {
93 return Err(ModelError::InvalidDatasetRecordLength {
94 line: line_number,
95 field: "input",
96 expected: input_row_len,
97 got: input_row.len(),
98 });
99 }
100 let target_row = jsonl_field_to_row(record.target);
101 if target_row.len() != target_row_len {
102 return Err(ModelError::InvalidDatasetRecordLength {
103 line: line_number,
104 field: "target",
105 expected: target_row_len,
106 got: target_row.len(),
107 });
108 }
109
110 validate_finite_values(line_number, "input", &input_row)?;
111 validate_finite_values(line_number, "target", &target_row)?;
112
113 input_values.extend_from_slice(&input_row);
114 target_values.extend_from_slice(&target_row);
115 sample_count =
116 sample_count
117 .checked_add(1)
118 .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
119 field: "sample_count",
120 shape: vec![sample_count],
121 message: "sample count overflow".to_string(),
122 })?;
123 }
124
125 if sample_count == 0 {
126 return Err(ModelError::EmptyDataset);
127 }
128
129 build_supervised_dataset_from_flat_values(
130 config.input_shape(),
131 config.target_shape(),
132 sample_count,
133 input_values,
134 target_values,
135 )
136}
137
138pub fn load_supervised_dataset_jsonl_file<P: AsRef<Path>>(
140 path: P,
141 config: &SupervisedJsonlConfig,
142) -> Result<super::types::SupervisedDataset, ModelError> {
143 let content = load_dataset_text_file(path)?;
144 parse_supervised_dataset_jsonl(&content, config)
145}