dataforge/customization/api.rs

//! 用户API接口模块 — user-facing API: custom generator registration and template-based data generation.

use std::collections::hash_map::Entry;
use std::collections::HashMap;

use serde_json::Value;

use crate::error::{DataForgeError, Result};
use super::{CustomGenerator, CustomGeneratorRegistry, ParamType};

8/// 用户API接口
9pub struct UserAPI {
10    registry: CustomGeneratorRegistry,
11    templates: HashMap<String, Template>,
12}
13
14impl UserAPI {
15    /// 创建新的用户API
16    pub fn new() -> Self {
17        Self {
18            registry: CustomGeneratorRegistry::new(),
19            templates: HashMap::new(),
20        }
21    }
22
23    /// 注册自定义生成器
24    pub fn register_generator<G>(&mut self, generator: G) -> Result<()>
25    where
26        G: CustomGenerator + 'static,
27    {
28        self.registry.register(generator)
29    }
30
31    /// 生成数据
32    pub fn generate(&self, generator_name: &str, params: Option<&HashMap<String, Value>>) -> Result<Value> {
33        self.registry.generate(generator_name, params)
34    }
35
36    /// 批量生成数据
37    pub fn generate_batch(&self, generator_name: &str, count: usize, params: Option<&HashMap<String, Value>>) -> Result<Vec<Value>> {
38        let mut results = Vec::with_capacity(count);
39        
40        for _ in 0..count {
41            let value = self.registry.generate(generator_name, params)?;
42            results.push(value);
43        }
44        
45        Ok(results)
46    }
47
48    /// 创建数据模板
49    pub fn create_template(&mut self, name: String, template: Template) -> Result<()> {
50        if self.templates.contains_key(&name) {
51            return Err(DataForgeError::validation(&format!("Template '{}' already exists", name)));
52        }
53        
54        // 验证模板
55        template.validate(&self.registry)?;
56        
57        self.templates.insert(name, template);
58        Ok(())
59    }
60
61    /// 使用模板生成数据
62    pub fn generate_from_template(&self, template_name: &str, params: Option<&HashMap<String, Value>>) -> Result<Value> {
63        let template = self.templates.get(template_name)
64            .ok_or_else(|| DataForgeError::validation(&format!("Template '{}' not found", template_name)))?;
65        
66        template.generate(&self.registry, params)
67    }
68
69    /// 批量使用模板生成数据
70    pub fn generate_batch_from_template(&self, template_name: &str, count: usize, params: Option<&HashMap<String, Value>>) -> Result<Vec<Value>> {
71        let mut results = Vec::with_capacity(count);
72        
73        for _ in 0..count {
74            let value = self.generate_from_template(template_name, params)?;
75            results.push(value);
76        }
77        
78        Ok(results)
79    }
80
81    /// 列出所有生成器
82    pub fn list_generators(&self) -> Vec<super::GeneratorInfo> {
83        self.registry.list_generators()
84    }
85
86    /// 列出所有模板
87    pub fn list_templates(&self) -> Vec<String> {
88        self.templates.keys().cloned().collect()
89    }
90
91    /// 获取模板信息
92    pub fn get_template(&self, name: &str) -> Option<&Template> {
93        self.templates.get(name)
94    }
95
96    /// 删除模板
97    pub fn remove_template(&mut self, name: &str) -> bool {
98        self.templates.remove(name).is_some()
99    }
100
101    /// 验证生成器参数
102    pub fn validate_generator_params(&self, generator_name: &str, params: &HashMap<String, Value>) -> Result<()> {
103        let generator = self.registry.get(generator_name)
104            .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
105        
106        generator.validate_params(params)
107    }
108
109    /// 获取生成器参数模式
110    pub fn get_generator_param_schema(&self, generator_name: &str) -> Result<HashMap<String, ParamType>> {
111        let generator = self.registry.get(generator_name)
112            .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
113        
114        Ok(generator.param_schema())
115    }
116}
117
118impl Default for UserAPI {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124/// 数据生成模板
125#[derive(Debug, Clone)]
126pub struct Template {
127    pub name: String,
128    pub description: String,
129    pub fields: HashMap<String, FieldTemplate>,
130}
131
132/// 字段模板
133#[derive(Debug, Clone)]
134pub struct FieldTemplate {
135    pub generator: String,
136    pub params: Option<HashMap<String, Value>>,
137    pub nullable: bool,
138    pub null_probability: f64, // 0.0 - 1.0
139}
140
141impl Template {
142    /// 创建新模板
143    pub fn new(name: String, description: String) -> Self {
144        Self {
145            name,
146            description,
147            fields: HashMap::new(),
148        }
149    }
150
151    /// 添加字段
152    pub fn add_field(&mut self, field_name: String, field_template: FieldTemplate) {
153        self.fields.insert(field_name, field_template);
154    }
155
156    /// 验证模板
157    pub fn validate(&self, registry: &CustomGeneratorRegistry) -> Result<()> {
158        for (field_name, field_template) in &self.fields {
159            // 检查生成器是否存在
160            if registry.get(&field_template.generator).is_none() {
161                return Err(DataForgeError::validation(&format!(
162                    "Generator '{}' for field '{}' not found", 
163                    field_template.generator, field_name
164                )));
165            }
166
167            // 验证参数
168            if let Some(params) = &field_template.params {
169                let generator = registry.get(&field_template.generator).unwrap();
170                generator.validate_params(params)?;
171            }
172
173            // 验证null概率
174            if field_template.null_probability < 0.0 || field_template.null_probability > 1.0 {
175                return Err(DataForgeError::validation(&format!(
176                    "Invalid null probability {} for field '{}'", 
177                    field_template.null_probability, field_name
178                )));
179            }
180        }
181
182        Ok(())
183    }
184
185    /// 生成数据
186    pub fn generate(&self, registry: &CustomGeneratorRegistry, global_params: Option<&HashMap<String, Value>>) -> Result<Value> {
187        use rand::Rng;
188        let mut rng = rand::thread_rng();
189        let mut result = serde_json::Map::new();
190
191        for (field_name, field_template) in &self.fields {
192            // 检查是否应该生成null值
193            if field_template.nullable && rng.gen::<f64>() < field_template.null_probability {
194                result.insert(field_name.clone(), Value::Null);
195                continue;
196            }
197
198            // 合并全局参数和字段参数
199            let params = match (&field_template.params, global_params) {
200                (Some(field_params), Some(global_params)) => {
201                    let mut merged = global_params.clone();
202                    merged.extend(field_params.clone());
203                    Some(merged)
204                },
205                (Some(field_params), None) => Some(field_params.clone()),
206                (None, Some(global_params)) => Some(global_params.clone()),
207                (None, None) => None,
208            };
209
210            // 生成字段值
211            let value = registry.generate(&field_template.generator, params.as_ref())?;
212            result.insert(field_name.clone(), value);
213        }
214
215        Ok(Value::Object(result))
216    }
217}
218
219/// 模板构建器
220pub struct TemplateBuilder {
221    template: Template,
222}
223
224impl TemplateBuilder {
225    /// 创建新的模板构建器
226    pub fn new(name: String, description: String) -> Self {
227        Self {
228            template: Template::new(name, description),
229        }
230    }
231
232    /// 添加字段
233    pub fn field(mut self, name: &str, generator: &str) -> Self {
234        let field_template = FieldTemplate {
235            generator: generator.to_string(),
236            params: None,
237            nullable: false,
238            null_probability: 0.0,
239        };
240        self.template.add_field(name.to_string(), field_template);
241        self
242    }
243
244    /// 添加带参数的字段
245    pub fn field_with_params(mut self, name: &str, generator: &str, params: HashMap<String, Value>) -> Self {
246        let field_template = FieldTemplate {
247            generator: generator.to_string(),
248            params: Some(params),
249            nullable: false,
250            null_probability: 0.0,
251        };
252        self.template.add_field(name.to_string(), field_template);
253        self
254    }
255
256    /// 添加可空字段
257    pub fn nullable_field(mut self, name: &str, generator: &str, null_probability: f64) -> Self {
258        let field_template = FieldTemplate {
259            generator: generator.to_string(),
260            params: None,
261            nullable: true,
262            null_probability,
263        };
264        self.template.add_field(name.to_string(), field_template);
265        self
266    }
267
268    /// 构建模板
269    pub fn build(self) -> Template {
270        self.template
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::customization::RandomStringGenerator;

    /// Builds a `UserAPI` with `RandomStringGenerator` registered (the tests
    /// refer to it as `random_string`).
    fn api_with_random_string() -> UserAPI {
        let mut api = UserAPI::new();
        api.register_generator(RandomStringGenerator::new())
            .expect("registering RandomStringGenerator should succeed");
        api
    }

    /// Unwraps a JSON object, failing the test loudly on any other shape
    /// instead of silently skipping the assertions that follow.
    fn expect_object(value: Value) -> serde_json::Map<String, Value> {
        match value {
            Value::Object(obj) => obj,
            other => panic!("expected a JSON object, got {:?}", other),
        }
    }

    #[test]
    fn test_user_api() {
        let mut api = UserAPI::new();
        assert!(api.register_generator(RandomStringGenerator::new()).is_ok());

        assert!(api.generate("random_string", None).is_ok());

        let batch = api
            .generate_batch("random_string", 5, None)
            .expect("batch generation should succeed");
        assert_eq!(batch.len(), 5);
    }

    #[test]
    fn test_template_creation() {
        let mut api = api_with_random_string();

        let template = TemplateBuilder::new(
            "user_template".to_string(),
            "User data template".to_string(),
        )
        .field("name", "random_string")
        .field("email", "random_string")
        .build();

        assert!(api.create_template("user".to_string(), template).is_ok());

        let obj = expect_object(
            api.generate_from_template("user", None)
                .expect("generation from a valid template should succeed"),
        );
        assert!(obj.contains_key("name"));
        assert!(obj.contains_key("email"));

        let rows = api
            .generate_batch_from_template("user", 3, None)
            .expect("batch generation from a valid template should succeed");
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_template_with_params() {
        let mut api = api_with_random_string();

        let mut params = HashMap::new();
        params.insert("length".to_string(), Value::Number(serde_json::Number::from(10)));

        let template = TemplateBuilder::new(
            "fixed_length_template".to_string(),
            "Fixed length string template".to_string(),
        )
        .field_with_params("code", "random_string", params)
        .build();

        assert!(api.create_template("fixed_length".to_string(), template).is_ok());

        let obj = expect_object(
            api.generate_from_template("fixed_length", None)
                .expect("generation should succeed"),
        );
        match obj.get("code") {
            // Count characters, not UTF-8 bytes, so the check does not
            // depend on the generator's alphabet.
            Some(Value::String(code)) => assert_eq!(code.chars().count(), 10),
            other => panic!("expected `code` to be a string, got {:?}", other),
        }

        // Field-level params must win over conflicting global params.
        let mut global = HashMap::new();
        global.insert("length".to_string(), Value::Number(serde_json::Number::from(5)));
        let obj = expect_object(
            api.generate_from_template("fixed_length", Some(&global))
                .expect("generation with global params should succeed"),
        );
        match obj.get("code") {
            Some(Value::String(code)) => assert_eq!(code.chars().count(), 10),
            other => panic!("expected `code` to be a string, got {:?}", other),
        }
    }

    #[test]
    fn test_nullable_field() {
        let mut api = api_with_random_string();

        let template = TemplateBuilder::new(
            "nullable_template".to_string(),
            "Template with nullable field".to_string(),
        )
        .field("required_field", "random_string")
        .nullable_field("optional_field", "random_string", 0.5)
        .build();

        assert!(api.create_template("nullable".to_string(), template).is_ok());

        // Generate many rows to exercise the null probability.
        let total_tests = 100;
        let mut null_count = 0;

        for _ in 0..total_tests {
            let obj = expect_object(
                api.generate_from_template("nullable", None)
                    .expect("generation should succeed"),
            );
            assert!(
                matches!(obj.get("required_field"), Some(v) if !v.is_null()),
                "a non-nullable field must always be present and non-null"
            );
            if let Some(Value::Null) = obj.get("optional_field") {
                null_count += 1;
            }
        }

        // Expect roughly 50% nulls; the wide band keeps the test stable.
        assert!(
            null_count > 20 && null_count < 80,
            "null_count = {} out of {}",
            null_count,
            total_tests
        );
    }

    #[test]
    fn test_duplicate_template_rejected() {
        let mut api = api_with_random_string();
        let make = || {
            TemplateBuilder::new("t".to_string(), "d".to_string())
                .field("f", "random_string")
                .build()
        };

        assert!(api.create_template("dup".to_string(), make()).is_ok());
        assert!(api.create_template("dup".to_string(), make()).is_err());
        assert_eq!(api.list_templates(), vec!["dup".to_string()]);
    }

    #[test]
    fn test_unknown_generator_rejected() {
        let mut api = api_with_random_string();
        let template = TemplateBuilder::new("t".to_string(), "d".to_string())
            .field("f", "no_such_generator")
            .build();

        assert!(api.create_template("bad".to_string(), template).is_err());
        assert!(api.list_templates().is_empty());
    }

    #[test]
    fn test_invalid_null_probability_rejected() {
        let mut api = api_with_random_string();
        let template = TemplateBuilder::new("t".to_string(), "d".to_string())
            .nullable_field("f", "random_string", 1.5)
            .build();

        assert!(api.create_template("bad".to_string(), template).is_err());
        assert!(api.get_template("bad").is_none());
    }

    #[test]
    fn test_missing_and_removed_template() {
        let mut api = api_with_random_string();
        assert!(api.generate_from_template("missing", None).is_err());

        let template = TemplateBuilder::new("t".to_string(), "d".to_string())
            .field("f", "random_string")
            .build();
        assert!(api.create_template("t".to_string(), template).is_ok());
        assert!(api.get_template("t").is_some());

        assert!(api.remove_template("t"));
        assert!(!api.remove_template("t"));
        assert!(api.generate_from_template("t", None).is_err());
    }
}
379}