1use std::collections::HashMap;
4use serde_json::Value;
5use crate::error::{DataForgeError, Result};
6use super::{CustomGenerator, CustomGeneratorRegistry, ParamType};
7
8pub struct UserAPI {
10 registry: CustomGeneratorRegistry,
11 templates: HashMap<String, Template>,
12}
13
14impl UserAPI {
15 pub fn new() -> Self {
17 Self {
18 registry: CustomGeneratorRegistry::new(),
19 templates: HashMap::new(),
20 }
21 }
22
23 pub fn register_generator<G>(&mut self, generator: G) -> Result<()>
25 where
26 G: CustomGenerator + 'static,
27 {
28 self.registry.register(generator)
29 }
30
31 pub fn generate(&self, generator_name: &str, params: Option<&HashMap<String, Value>>) -> Result<Value> {
33 self.registry.generate(generator_name, params)
34 }
35
36 pub fn generate_batch(&self, generator_name: &str, count: usize, params: Option<&HashMap<String, Value>>) -> Result<Vec<Value>> {
38 let mut results = Vec::with_capacity(count);
39
40 for _ in 0..count {
41 let value = self.registry.generate(generator_name, params)?;
42 results.push(value);
43 }
44
45 Ok(results)
46 }
47
48 pub fn create_template(&mut self, name: String, template: Template) -> Result<()> {
50 if self.templates.contains_key(&name) {
51 return Err(DataForgeError::validation(&format!("Template '{}' already exists", name)));
52 }
53
54 template.validate(&self.registry)?;
56
57 self.templates.insert(name, template);
58 Ok(())
59 }
60
61 pub fn generate_from_template(&self, template_name: &str, params: Option<&HashMap<String, Value>>) -> Result<Value> {
63 let template = self.templates.get(template_name)
64 .ok_or_else(|| DataForgeError::validation(&format!("Template '{}' not found", template_name)))?;
65
66 template.generate(&self.registry, params)
67 }
68
69 pub fn generate_batch_from_template(&self, template_name: &str, count: usize, params: Option<&HashMap<String, Value>>) -> Result<Vec<Value>> {
71 let mut results = Vec::with_capacity(count);
72
73 for _ in 0..count {
74 let value = self.generate_from_template(template_name, params)?;
75 results.push(value);
76 }
77
78 Ok(results)
79 }
80
81 pub fn list_generators(&self) -> Vec<super::GeneratorInfo> {
83 self.registry.list_generators()
84 }
85
86 pub fn list_templates(&self) -> Vec<String> {
88 self.templates.keys().cloned().collect()
89 }
90
91 pub fn get_template(&self, name: &str) -> Option<&Template> {
93 self.templates.get(name)
94 }
95
96 pub fn remove_template(&mut self, name: &str) -> bool {
98 self.templates.remove(name).is_some()
99 }
100
101 pub fn validate_generator_params(&self, generator_name: &str, params: &HashMap<String, Value>) -> Result<()> {
103 let generator = self.registry.get(generator_name)
104 .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
105
106 generator.validate_params(params)
107 }
108
109 pub fn get_generator_param_schema(&self, generator_name: &str) -> Result<HashMap<String, ParamType>> {
111 let generator = self.registry.get(generator_name)
112 .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
113
114 Ok(generator.param_schema())
115 }
116}
117
118impl Default for UserAPI {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct Template {
127 pub name: String,
128 pub description: String,
129 pub fields: HashMap<String, FieldTemplate>,
130}
131
132#[derive(Debug, Clone)]
134pub struct FieldTemplate {
135 pub generator: String,
136 pub params: Option<HashMap<String, Value>>,
137 pub nullable: bool,
138 pub null_probability: f64, }
140
141impl Template {
142 pub fn new(name: String, description: String) -> Self {
144 Self {
145 name,
146 description,
147 fields: HashMap::new(),
148 }
149 }
150
151 pub fn add_field(&mut self, field_name: String, field_template: FieldTemplate) {
153 self.fields.insert(field_name, field_template);
154 }
155
156 pub fn validate(&self, registry: &CustomGeneratorRegistry) -> Result<()> {
158 for (field_name, field_template) in &self.fields {
159 if registry.get(&field_template.generator).is_none() {
161 return Err(DataForgeError::validation(&format!(
162 "Generator '{}' for field '{}' not found",
163 field_template.generator, field_name
164 )));
165 }
166
167 if let Some(params) = &field_template.params {
169 let generator = registry.get(&field_template.generator).unwrap();
170 generator.validate_params(params)?;
171 }
172
173 if field_template.null_probability < 0.0 || field_template.null_probability > 1.0 {
175 return Err(DataForgeError::validation(&format!(
176 "Invalid null probability {} for field '{}'",
177 field_template.null_probability, field_name
178 )));
179 }
180 }
181
182 Ok(())
183 }
184
185 pub fn generate(&self, registry: &CustomGeneratorRegistry, global_params: Option<&HashMap<String, Value>>) -> Result<Value> {
187 use rand::Rng;
188 let mut rng = rand::thread_rng();
189 let mut result = serde_json::Map::new();
190
191 for (field_name, field_template) in &self.fields {
192 if field_template.nullable && rng.gen::<f64>() < field_template.null_probability {
194 result.insert(field_name.clone(), Value::Null);
195 continue;
196 }
197
198 let params = match (&field_template.params, global_params) {
200 (Some(field_params), Some(global_params)) => {
201 let mut merged = global_params.clone();
202 merged.extend(field_params.clone());
203 Some(merged)
204 },
205 (Some(field_params), None) => Some(field_params.clone()),
206 (None, Some(global_params)) => Some(global_params.clone()),
207 (None, None) => None,
208 };
209
210 let value = registry.generate(&field_template.generator, params.as_ref())?;
212 result.insert(field_name.clone(), value);
213 }
214
215 Ok(Value::Object(result))
216 }
217}
218
219pub struct TemplateBuilder {
221 template: Template,
222}
223
224impl TemplateBuilder {
225 pub fn new(name: String, description: String) -> Self {
227 Self {
228 template: Template::new(name, description),
229 }
230 }
231
232 pub fn field(mut self, name: &str, generator: &str) -> Self {
234 let field_template = FieldTemplate {
235 generator: generator.to_string(),
236 params: None,
237 nullable: false,
238 null_probability: 0.0,
239 };
240 self.template.add_field(name.to_string(), field_template);
241 self
242 }
243
244 pub fn field_with_params(mut self, name: &str, generator: &str, params: HashMap<String, Value>) -> Self {
246 let field_template = FieldTemplate {
247 generator: generator.to_string(),
248 params: Some(params),
249 nullable: false,
250 null_probability: 0.0,
251 };
252 self.template.add_field(name.to_string(), field_template);
253 self
254 }
255
256 pub fn nullable_field(mut self, name: &str, generator: &str, null_probability: f64) -> Self {
258 let field_template = FieldTemplate {
259 generator: generator.to_string(),
260 params: None,
261 nullable: true,
262 null_probability,
263 };
264 self.template.add_field(name.to_string(), field_template);
265 self
266 }
267
268 pub fn build(self) -> Template {
270 self.template
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::customization::RandomStringGenerator;

    /// Builds an API with the random string generator already registered.
    fn api_with_random_string() -> UserAPI {
        let mut api = UserAPI::new();
        api.register_generator(RandomStringGenerator::new())
            .expect("registering the random string generator should succeed");
        api
    }

    /// Unwraps a generated value as a JSON object, failing the test loudly
    /// instead of letting later assertions be skipped.
    fn expect_object(value: Value) -> serde_json::Map<String, Value> {
        match value {
            Value::Object(fields) => fields,
            other => panic!("expected a JSON object, got {:?}", other),
        }
    }

    #[test]
    fn test_user_api() {
        let api = api_with_random_string();

        let result = api.generate("random_string", None);
        assert!(result.is_ok(), "single generation failed: {:?}", result.err());

        let batch = api
            .generate_batch("random_string", 5, None)
            .expect("batch generation should succeed");
        assert_eq!(batch.len(), 5);
    }

    #[test]
    fn test_template_creation() {
        let mut api = api_with_random_string();

        let template = TemplateBuilder::new(
            "user_template".to_string(),
            "User data template".to_string(),
        )
        .field("name", "random_string")
        .field("email", "random_string")
        .build();

        assert!(api.create_template("user".to_string(), template).is_ok());

        let generated = api
            .generate_from_template("user", None)
            .expect("generation from a registered template should succeed");
        let record = expect_object(generated);

        assert!(record.contains_key("name"));
        assert!(record.contains_key("email"));
    }

    #[test]
    fn test_template_with_params() {
        let mut api = api_with_random_string();

        let mut params = HashMap::new();
        params.insert("length".to_string(), Value::Number(serde_json::Number::from(10)));

        let template = TemplateBuilder::new(
            "fixed_length_template".to_string(),
            "Fixed length string template".to_string(),
        )
        .field_with_params("code", "random_string", params)
        .build();

        assert!(api.create_template("fixed_length".to_string(), template).is_ok());

        let generated = api
            .generate_from_template("fixed_length", None)
            .expect("generation from a registered template should succeed");
        let record = expect_object(generated);

        // Require the field to be a string so the length check cannot be
        // skipped silently.
        let code = record
            .get("code")
            .and_then(Value::as_str)
            .expect("`code` should be generated as a JSON string");
        assert_eq!(code.len(), 10);
    }

    #[test]
    fn test_nullable_field() {
        let mut api = api_with_random_string();

        let template = TemplateBuilder::new(
            "nullable_template".to_string(),
            "Template with nullable field".to_string(),
        )
        .field("required_field", "random_string")
        .nullable_field("optional_field", "random_string", 0.5)
        .build();

        assert!(api.create_template("nullable".to_string(), template).is_ok());

        let total_tests = 100;
        let null_count = (0..total_tests)
            .filter(|_| {
                let record = expect_object(api.generate_from_template("nullable", None).unwrap());
                matches!(record.get("optional_field"), Some(Value::Null))
            })
            .count();

        // With p = 0.5 over 100 draws the count sits near 50; the wide band
        // keeps the test stable while still catching "never" and "always".
        assert!(
            null_count > 20 && null_count < 80,
            "unexpected number of nulls: {} out of {}",
            null_count,
            total_tests
        );
    }
}