datagen_rs_progress_plugin/
lib.rs1#[cfg(feature = "plugin")]
2use datagen_rs::declare_plugin;
3use datagen_rs::generate::current_schema::CurrentSchemaRef;
4use datagen_rs::generate::generated_schema::GeneratedSchema;
5use datagen_rs::generate::generated_schema::IntoRandom;
6use datagen_rs::generate::schema_mapper::MapSchema;
7use datagen_rs::plugins::plugin::Plugin;
8#[cfg(feature = "plugin")]
9use datagen_rs::plugins::plugin::PluginConstructor;
10use datagen_rs::schema::any::Any;
11use datagen_rs::schema::any_of::AnyOf;
12use datagen_rs::schema::any_value::AnyValue;
13use datagen_rs::schema::array::{Array, ArrayLength};
14use datagen_rs::schema::object::Object;
15#[cfg(not(feature = "plugin"))]
16use datagen_rs::schema::schema_definition::Schema;
17use datagen_rs::util::traits::generate::TransformTrait;
18use datagen_rs::util::types::Result;
19use rand::prelude::SliceRandom;
20use rand::Rng;
21use serde_json::Value;
22use std::collections::{BTreeMap, HashMap, VecDeque};
23use std::fmt::Debug;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::{Arc, Mutex};
26
/// Key identifying a random array length range (`min..=max`).
///
/// The derived ordering compares `min` first and `max` second, which is what
/// allows this type to be used as a `BTreeMap` key.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct RandomArrayLength {
    min: u32,
    max: u32,
}

impl RandomArrayLength {
    /// Builds the key for the range bounded by `min` and `max`.
    fn new(min: u32, max: u32) -> Self {
        RandomArrayLength { min, max }
    }
}
38
/// Plugin that generates a schema while reporting progress through a callback.
///
/// `F` is called as `callback(current, total)` every time a tracked element
/// (array item or object property) has been generated.
pub struct ProgressPlugin<F: Fn(usize, usize)> {
    /// Expected number of elements; stored by `generate` from the result of `map_any`.
    total_elements: AtomicUsize,
    /// Number of elements reported so far; bumped by `increase_count`.
    progress: AtomicUsize,
    /// Array lengths rolled ahead of time by `add_array_len`, queued per
    /// `(min, max)` range and consumed front-first by `get_array_length`.
    arrays: Mutex<BTreeMap<RandomArrayLength, VecDeque<u32>>>,
    /// Progress callback, invoked with `(current, total)`.
    callback: F,
}
45
46impl<F: Fn(usize, usize)> Debug for ProgressPlugin<F> {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("ProgressPlugin")
49 .field("total_elements", &self.total_elements)
50 .field("progress", &self.progress)
51 .finish()
52 }
53}
54
/// Result of [`ProgressPlugin::with_schema`]: the rewritten schema together
/// with the plugin map it needs. Only available when the `plugin` feature is
/// disabled.
#[cfg(not(feature = "plugin"))]
pub struct PluginWithSchemaResult {
    /// Schema whose root value has been replaced by a call to the `"progress"` plugin.
    pub schema: Schema,
    /// Plugins by name; `with_schema` inserts the progress plugin under `"progress"`.
    pub plugins: HashMap<String, Box<dyn Plugin>>,
}
60
61impl<F: Fn(usize, usize)> ProgressPlugin<F> {
62 #[cfg(not(feature = "plugin"))]
63 pub fn with_schema(mut schema: Schema, callback: F) -> Result<PluginWithSchemaResult>
64 where
65 F: Fn(usize, usize) + 'static,
66 {
67 let progress: Box<dyn Plugin> = Box::new(ProgressPlugin::new(callback));
68
69 schema.value = AnyValue::Any(Any::Plugin(datagen_rs::schema::plugin::Plugin {
70 plugin_name: "progress".into(),
71 args: Some(serde_json::to_value(schema.value).map_err(|e| e.to_string())?),
72 transform: None,
73 }));
74
75 Ok(PluginWithSchemaResult {
76 schema,
77 plugins: vec![("progress".into(), progress)].into_iter().collect(),
78 })
79 }
80
81 #[cfg(not(feature = "plugin"))]
82 pub fn new(callback: F) -> Self {
83 Self {
84 total_elements: AtomicUsize::new(0),
85 progress: AtomicUsize::new(0),
86 arrays: Mutex::new(BTreeMap::new()),
87 callback,
88 }
89 }
90
91 fn increase_count(&self) {
92 let total = self.total_elements.load(Ordering::SeqCst);
93 let current = self.progress.fetch_add(1, Ordering::SeqCst) + 1;
94 (self.callback)(current, total);
95 }
96
97 fn convert_any_value(
98 &self,
99 schema: CurrentSchemaRef,
100 val: AnyValue,
101 ) -> Result<Arc<GeneratedSchema>> {
102 match val {
103 AnyValue::Any(any) => match any {
104 Any::Array(array) => self.convert_array(schema, *array),
105 Any::Object(object) => self.convert_object(schema, *object),
106 Any::AnyOf(any_of) => self.convert_any_of(schema, any_of),
107 rest => rest.into_random(schema),
108 },
109 rest => rest.into_random(schema),
110 }
111 }
112
113 fn get_array_length(&self, len: &ArrayLength) -> Result<u32> {
114 match len {
115 ArrayLength::Random { min, max } => self
116 .arrays
117 .lock()
118 .unwrap()
119 .get_mut(&RandomArrayLength::new(*min, *max))
120 .ok_or("Array length not found".to_string())?
121 .pop_front()
122 .ok_or("Array length not found".into()),
123 ArrayLength::Constant { value } => Ok(*value),
124 }
125 }
126
127 fn convert_array(
128 &self,
129 schema: CurrentSchemaRef,
130 array: Array,
131 ) -> Result<Arc<GeneratedSchema>> {
132 let len = self.get_array_length(&array.length)?;
133 schema.map_array(
134 len as _,
135 array.items,
136 array.transform,
137 true,
138 |cur, value| {
139 let res = self.convert_any_value(cur.clone(), value)?;
140 self.increase_count();
141 Ok(res)
142 },
143 )
144 }
145
146 fn convert_object(
147 &self,
148 schema: CurrentSchemaRef,
149 object: Object,
150 ) -> Result<Arc<GeneratedSchema>> {
151 schema.map_index_map(object.properties, object.transform, true, |cur, value| {
152 let res = self.convert_any_value(cur.clone(), value)?;
153 self.increase_count();
154 Ok(res)
155 })
156 }
157
158 fn convert_any_of(
159 &self,
160 schema: CurrentSchemaRef,
161 mut any_of: AnyOf,
162 ) -> Result<Arc<GeneratedSchema>> {
163 any_of.values.shuffle(&mut rand::thread_rng());
164 let mut num = any_of.num.unwrap_or(1);
165 match num.cmp(&0) {
166 core::cmp::Ordering::Equal => num = any_of.values.len() as i64,
167 core::cmp::Ordering::Less => {
168 num = rand::thread_rng().gen_range(0..any_of.values.len() as i64)
169 }
170 _ => {}
171 }
172
173 let values = any_of
174 .values
175 .drain(0..num as usize)
176 .map(|value| value.into_random(schema.clone()))
177 .collect::<Result<Vec<_>>>()?;
178
179 let mut res = if values.is_empty() {
180 Arc::new(GeneratedSchema::None)
181 } else if values.len() == 1 {
182 values[0].clone()
183 } else {
184 Arc::new(GeneratedSchema::Array(values))
185 };
186
187 if let Some(transform) = any_of.transform {
188 res = transform.transform(schema.clone(), res)?;
189 }
190
191 Ok(schema.finalize(res))
192 }
193
194 pub fn map_any(&self, val: &mut AnyValue) -> usize {
195 if let AnyValue::Any(any) = val {
196 match any {
197 Any::Array(array) => self.map_array(array.as_mut()),
198 Any::Object(object) => {
199 let mut len = 1;
200 for (_, value) in &mut object.properties {
201 len += self.map_any(value);
202 }
203
204 len
205 }
206 Any::AnyOf(any_of) => {
207 any_of.values.shuffle(&mut rand::thread_rng());
208 let mut num = any_of.num.unwrap_or(1);
209 match num.cmp(&0) {
210 core::cmp::Ordering::Equal => num = -1,
211 core::cmp::Ordering::Less => {
212 num = rand::thread_rng().gen_range(0..any_of.values.len() as i64)
213 }
214 _ => {}
215 }
216
217 if num >= 0 {
218 any_of.values.drain(num as usize..);
219 }
220
221 let mut len = 1;
222 for val in &mut any_of.values {
223 len += self.map_any(val);
224 }
225 len
226 }
227 _ => 1,
228 }
229 } else {
230 1
231 }
232 }
233
234 fn add_array_len(&self, len: &ArrayLength) -> u32 {
235 match len {
236 ArrayLength::Random { min, max } => {
237 let mut arrays = self.arrays.lock().unwrap();
238
239 let entry = arrays
240 .entry(RandomArrayLength::new(*min, *max))
241 .or_insert_with(VecDeque::new);
242 let mut rng = rand::thread_rng();
243 let res = rng.gen_range(*min..=*max);
244 entry.push_back(res);
245
246 res
247 }
248 ArrayLength::Constant { value } => *value,
249 }
250 }
251
252 fn map_array(&self, val: &mut Array) -> usize {
253 let len = self.add_array_len(&val.length);
254
255 let mut res = 1;
256 for _ in 0..len {
257 res += self.map_any(&mut val.items);
258 }
259
260 res
261 }
262}
263
264impl<F: Fn(usize, usize)> Plugin for ProgressPlugin<F> {
265 fn name(&self) -> String {
266 "progress".into()
267 }
268
269 fn generate(&self, schema: CurrentSchemaRef, args: Value) -> Result<Arc<GeneratedSchema>> {
270 let mut val: AnyValue = serde_json::from_value(args)?;
271
272 self.total_elements
273 .store(self.map_any(&mut val), Ordering::SeqCst);
274 self.convert_any_value(schema, val)
275 }
276}
277
#[cfg(feature = "plugin")]
impl PluginConstructor for ProgressPlugin<fn(usize, usize)> {
    /// Creates the plugin with a callback that prints `current / total` to
    /// standard output. The arguments are ignored.
    fn new(_args: Value) -> Result<Self> {
        /// Default progress reporter used by the standalone plugin build.
        fn print_progress(current: usize, total: usize) {
            println!("{current} / {total}");
        }

        let plugin = Self {
            total_elements: AtomicUsize::new(0),
            progress: AtomicUsize::new(0),
            arrays: Mutex::new(BTreeMap::new()),
            callback: print_progress,
        };

        Ok(plugin)
    }
}
291
// Only compiled with the `plugin` feature: registers the variant whose
// callback is a plain function pointer via datagen_rs' `declare_plugin!`.
// NOTE(review): presumably this emits the entry point used to load the crate
// as a standalone plugin; confirm against the macro's definition in datagen_rs.
#[cfg(feature = "plugin")]
declare_plugin!(ProgressPlugin<fn(usize, usize)>);