datagen_rs_progress_plugin/
lib.rs

1#[cfg(feature = "plugin")]
2use datagen_rs::declare_plugin;
3use datagen_rs::generate::current_schema::CurrentSchemaRef;
4use datagen_rs::generate::generated_schema::GeneratedSchema;
5use datagen_rs::generate::generated_schema::IntoRandom;
6use datagen_rs::generate::schema_mapper::MapSchema;
7use datagen_rs::plugins::plugin::Plugin;
8#[cfg(feature = "plugin")]
9use datagen_rs::plugins::plugin::PluginConstructor;
10use datagen_rs::schema::any::Any;
11use datagen_rs::schema::any_of::AnyOf;
12use datagen_rs::schema::any_value::AnyValue;
13use datagen_rs::schema::array::{Array, ArrayLength};
14use datagen_rs::schema::object::Object;
15#[cfg(not(feature = "plugin"))]
16use datagen_rs::schema::schema_definition::Schema;
17use datagen_rs::util::traits::generate::TransformTrait;
18use datagen_rs::util::types::Result;
19use rand::prelude::SliceRandom;
20use rand::Rng;
21use serde_json::Value;
22use std::collections::{BTreeMap, HashMap, VecDeque};
23use std::fmt::Debug;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::{Arc, Mutex};
26
/// Key identifying a random array-length range by its `min` and `max` bounds.
///
/// It is used as the `BTreeMap` key under which pre-drawn array lengths are
/// queued. The derived ordering compares `min` first and `max` second, so the
/// field order matters.
#[derive(Eq, PartialEq, Ord, PartialOrd)]
struct RandomArrayLength {
    /// Lower bound of the range.
    min: u32,
    /// Upper bound of the range.
    max: u32,
}

impl RandomArrayLength {
    /// Builds a key from the two bounds of a random array-length range.
    fn new(min: u32, max: u32) -> Self {
        RandomArrayLength { min, max }
    }
}
38
39pub struct ProgressPlugin<F: Fn(usize, usize)> {
40    total_elements: AtomicUsize,
41    progress: AtomicUsize,
42    arrays: Mutex<BTreeMap<RandomArrayLength, VecDeque<u32>>>,
43    callback: F,
44}
45
46impl<F: Fn(usize, usize)> Debug for ProgressPlugin<F> {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("ProgressPlugin")
49            .field("total_elements", &self.total_elements)
50            .field("progress", &self.progress)
51            .finish()
52    }
53}
54
55#[cfg(not(feature = "plugin"))]
56pub struct PluginWithSchemaResult {
57    pub schema: Schema,
58    pub plugins: HashMap<String, Box<dyn Plugin>>,
59}
60
61impl<F: Fn(usize, usize)> ProgressPlugin<F> {
62    #[cfg(not(feature = "plugin"))]
63    pub fn with_schema(mut schema: Schema, callback: F) -> Result<PluginWithSchemaResult>
64    where
65        F: Fn(usize, usize) + 'static,
66    {
67        let progress: Box<dyn Plugin> = Box::new(ProgressPlugin::new(callback));
68
69        schema.value = AnyValue::Any(Any::Plugin(datagen_rs::schema::plugin::Plugin {
70            plugin_name: "progress".into(),
71            args: Some(serde_json::to_value(schema.value).map_err(|e| e.to_string())?),
72            transform: None,
73        }));
74
75        Ok(PluginWithSchemaResult {
76            schema,
77            plugins: vec![("progress".into(), progress)].into_iter().collect(),
78        })
79    }
80
81    #[cfg(not(feature = "plugin"))]
82    pub fn new(callback: F) -> Self {
83        Self {
84            total_elements: AtomicUsize::new(0),
85            progress: AtomicUsize::new(0),
86            arrays: Mutex::new(BTreeMap::new()),
87            callback,
88        }
89    }
90
91    fn increase_count(&self) {
92        let total = self.total_elements.load(Ordering::SeqCst);
93        let current = self.progress.fetch_add(1, Ordering::SeqCst) + 1;
94        (self.callback)(current, total);
95    }
96
97    fn convert_any_value(
98        &self,
99        schema: CurrentSchemaRef,
100        val: AnyValue,
101    ) -> Result<Arc<GeneratedSchema>> {
102        match val {
103            AnyValue::Any(any) => match any {
104                Any::Array(array) => self.convert_array(schema, *array),
105                Any::Object(object) => self.convert_object(schema, *object),
106                Any::AnyOf(any_of) => self.convert_any_of(schema, any_of),
107                rest => rest.into_random(schema),
108            },
109            rest => rest.into_random(schema),
110        }
111    }
112
113    fn get_array_length(&self, len: &ArrayLength) -> Result<u32> {
114        match len {
115            ArrayLength::Random { min, max } => self
116                .arrays
117                .lock()
118                .unwrap()
119                .get_mut(&RandomArrayLength::new(*min, *max))
120                .ok_or("Array length not found".to_string())?
121                .pop_front()
122                .ok_or("Array length not found".into()),
123            ArrayLength::Constant { value } => Ok(*value),
124        }
125    }
126
127    fn convert_array(
128        &self,
129        schema: CurrentSchemaRef,
130        array: Array,
131    ) -> Result<Arc<GeneratedSchema>> {
132        let len = self.get_array_length(&array.length)?;
133        schema.map_array(
134            len as _,
135            array.items,
136            array.transform,
137            true,
138            |cur, value| {
139                let res = self.convert_any_value(cur.clone(), value)?;
140                self.increase_count();
141                Ok(res)
142            },
143        )
144    }
145
146    fn convert_object(
147        &self,
148        schema: CurrentSchemaRef,
149        object: Object,
150    ) -> Result<Arc<GeneratedSchema>> {
151        schema.map_index_map(object.properties, object.transform, true, |cur, value| {
152            let res = self.convert_any_value(cur.clone(), value)?;
153            self.increase_count();
154            Ok(res)
155        })
156    }
157
158    fn convert_any_of(
159        &self,
160        schema: CurrentSchemaRef,
161        mut any_of: AnyOf,
162    ) -> Result<Arc<GeneratedSchema>> {
163        any_of.values.shuffle(&mut rand::thread_rng());
164        let mut num = any_of.num.unwrap_or(1);
165        match num.cmp(&0) {
166            core::cmp::Ordering::Equal => num = any_of.values.len() as i64,
167            core::cmp::Ordering::Less => {
168                num = rand::thread_rng().gen_range(0..any_of.values.len() as i64)
169            }
170            _ => {}
171        }
172
173        let values = any_of
174            .values
175            .drain(0..num as usize)
176            .map(|value| value.into_random(schema.clone()))
177            .collect::<Result<Vec<_>>>()?;
178
179        let mut res = if values.is_empty() {
180            Arc::new(GeneratedSchema::None)
181        } else if values.len() == 1 {
182            values[0].clone()
183        } else {
184            Arc::new(GeneratedSchema::Array(values))
185        };
186
187        if let Some(transform) = any_of.transform {
188            res = transform.transform(schema.clone(), res)?;
189        }
190
191        Ok(schema.finalize(res))
192    }
193
194    pub fn map_any(&self, val: &mut AnyValue) -> usize {
195        if let AnyValue::Any(any) = val {
196            match any {
197                Any::Array(array) => self.map_array(array.as_mut()),
198                Any::Object(object) => {
199                    let mut len = 1;
200                    for (_, value) in &mut object.properties {
201                        len += self.map_any(value);
202                    }
203
204                    len
205                }
206                Any::AnyOf(any_of) => {
207                    any_of.values.shuffle(&mut rand::thread_rng());
208                    let mut num = any_of.num.unwrap_or(1);
209                    match num.cmp(&0) {
210                        core::cmp::Ordering::Equal => num = -1,
211                        core::cmp::Ordering::Less => {
212                            num = rand::thread_rng().gen_range(0..any_of.values.len() as i64)
213                        }
214                        _ => {}
215                    }
216
217                    if num >= 0 {
218                        any_of.values.drain(num as usize..);
219                    }
220
221                    let mut len = 1;
222                    for val in &mut any_of.values {
223                        len += self.map_any(val);
224                    }
225                    len
226                }
227                _ => 1,
228            }
229        } else {
230            1
231        }
232    }
233
234    fn add_array_len(&self, len: &ArrayLength) -> u32 {
235        match len {
236            ArrayLength::Random { min, max } => {
237                let mut arrays = self.arrays.lock().unwrap();
238
239                let entry = arrays
240                    .entry(RandomArrayLength::new(*min, *max))
241                    .or_insert_with(VecDeque::new);
242                let mut rng = rand::thread_rng();
243                let res = rng.gen_range(*min..=*max);
244                entry.push_back(res);
245
246                res
247            }
248            ArrayLength::Constant { value } => *value,
249        }
250    }
251
252    fn map_array(&self, val: &mut Array) -> usize {
253        let len = self.add_array_len(&val.length);
254
255        let mut res = 1;
256        for _ in 0..len {
257            res += self.map_any(&mut val.items);
258        }
259
260        res
261    }
262}
263
264impl<F: Fn(usize, usize)> Plugin for ProgressPlugin<F> {
265    fn name(&self) -> String {
266        "progress".into()
267    }
268
269    fn generate(&self, schema: CurrentSchemaRef, args: Value) -> Result<Arc<GeneratedSchema>> {
270        let mut val: AnyValue = serde_json::from_value(args)?;
271
272        self.total_elements
273            .store(self.map_any(&mut val), Ordering::SeqCst);
274        self.convert_any_value(schema, val)
275    }
276}
277
/// Constructor used when the crate is built with the `plugin` feature: the
/// progress callback prints `current / total` to standard output.
#[cfg(feature = "plugin")]
impl PluginConstructor for ProgressPlugin<fn(usize, usize)> {
    /// Creates the plugin; the constructor arguments are ignored.
    fn new(_args: Value) -> Result<Self> {
        /// Prints one progress line.
        fn print_progress(current: usize, total: usize) {
            println!("{current} / {total}");
        }

        Ok(Self {
            total_elements: AtomicUsize::new(0),
            progress: AtomicUsize::new(0),
            arrays: Mutex::new(BTreeMap::new()),
            callback: print_progress,
        })
    }
}

// Declares the printing variant of the plugin via `datagen_rs::declare_plugin!`.
#[cfg(feature = "plugin")]
declare_plugin!(ProgressPlugin<fn(usize, usize)>);