1use rand::prelude::*;
12use std::collections::HashSet;
13
14use super::python_enum::PythonEnumerator;
15use super::GeneratedCode;
16
/// A Python language feature that a swarm batch can enable or disable.
///
/// Generated programs are classified into these features by a lexical scan of
/// their source code (see `SwarmGenerator::extract_features`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Feature {
    /// Integer literals (a core feature, always enabled).
    IntLiterals,
    /// Floating-point literals.
    FloatLiterals,
    /// String literals.
    StringLiterals,
    /// The boolean literals `True` / `False`.
    BoolLiterals,
    /// The `None` literal.
    NoneLiteral,
    /// Variable names (a core feature, always enabled).
    Variables,
    /// Assignment statements (a core feature, always enabled).
    Assignments,
    /// Arithmetic operators such as `+`, `-`, `*`, `/`, `%`.
    ArithmeticOps,
    /// The logical operators `and` / `or`.
    LogicalOps,
    /// Unary operators; detection looks for `not`.
    UnaryOps,
    /// Comparison operators such as `==`, `!=`, `<`, `>`.
    Comparisons,
    /// `if` statements.
    IfStatements,
    /// `while` loops.
    WhileLoops,
    /// `for` loops.
    ForLoops,
    /// Function definitions (`def`).
    Functions,
    /// Function calls.
    FunctionCalls,
    /// `return` statements.
    Returns,
    /// Lists (source containing both `[` and `]`).
    Lists,
    /// The `break`, `continue`, and `pass` statements.
    ControlFlow,
}
59
60impl Feature {
61 #[must_use]
63 pub fn all() -> Vec<Self> {
64 vec![
65 Self::IntLiterals,
66 Self::FloatLiterals,
67 Self::StringLiterals,
68 Self::BoolLiterals,
69 Self::NoneLiteral,
70 Self::Variables,
71 Self::Assignments,
72 Self::ArithmeticOps,
73 Self::LogicalOps,
74 Self::UnaryOps,
75 Self::Comparisons,
76 Self::IfStatements,
77 Self::WhileLoops,
78 Self::ForLoops,
79 Self::Functions,
80 Self::FunctionCalls,
81 Self::Returns,
82 Self::Lists,
83 Self::ControlFlow,
84 ]
85 }
86
87 #[must_use]
90 pub fn core() -> Vec<Self> {
91 vec![Self::IntLiterals, Self::Variables, Self::Assignments]
92 }
93}
94
/// The set of features enabled for one swarm batch, normally produced by
/// [`SwarmConfig::random`].
#[derive(Debug, Clone)]
pub struct SwarmConfig {
    /// Features programs in this batch may use. `SwarmGenerator` keeps only
    /// programs whose detected features are all contained in this set.
    pub enabled_features: HashSet<Feature>,
    /// Base seed the configuration was derived from.
    pub seed: u64,
    /// Index of the batch this configuration was drawn for.
    pub batch_id: usize,
}
105
106impl SwarmConfig {
107 #[must_use]
109 pub fn random(seed: u64, features_per_batch: usize, batch_id: usize) -> Self {
110 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(batch_id as u64));
111 let all_features = Feature::all();
112
113 let mut enabled: HashSet<Feature> = Feature::core().into_iter().collect();
115
116 let optional_features: Vec<Feature> = all_features
118 .into_iter()
119 .filter(|f| !enabled.contains(f))
120 .collect();
121
122 let to_select = features_per_batch.saturating_sub(enabled.len());
124 let selected: Vec<&Feature> = optional_features
125 .choose_multiple(&mut rng, to_select)
126 .collect();
127
128 for feature in selected {
129 enabled.insert(*feature);
130 }
131
132 Self {
133 enabled_features: enabled,
134 seed,
135 batch_id,
136 }
137 }
138
139 #[must_use]
141 pub fn is_enabled(&self, feature: Feature) -> bool {
142 self.enabled_features.contains(&feature)
143 }
144
145 #[must_use]
147 pub fn feature_count(&self) -> usize {
148 self.enabled_features.len()
149 }
150}
151
/// Swarm-testing program generator.
///
/// Each batch enables a random subset of [`Feature`]s and keeps only the
/// enumerated programs whose detected features fall inside that subset.
#[derive(Debug)]
pub struct SwarmGenerator {
    /// Maximum depth handed to `PythonEnumerator::new` for each batch.
    max_depth: usize,
    /// Base seed; combined with the batch index to seed each batch's configuration.
    seed: u64,
    /// Target number of enabled features per batch, core features included.
    features_per_batch: usize,
    /// Index of the next batch; used as the next configuration's `batch_id`.
    current_batch: usize,
    /// Statistics accumulated since construction or the last `reset`.
    stats: SwarmStats,
}
169
/// Running statistics collected by a [`SwarmGenerator`].
#[derive(Debug, Clone, Default)]
pub struct SwarmStats {
    /// Number of batches generated.
    pub batches_generated: usize,
    /// Total number of programs returned across all batches.
    pub programs_generated: usize,
    /// Union of the features *enabled* by any batch's configuration. These are
    /// not necessarily features that the returned programs actually use.
    pub feature_coverage: HashSet<Feature>,
    /// Presumably per-feature program counts.
    ///
    /// NOTE(review): `SwarmGenerator` in this module never writes to this field,
    /// so it stays empty unless populated elsewhere. Confirm the intended use.
    pub programs_per_feature: Vec<(Feature, usize)>,
}
182
183impl SwarmStats {
184 #[must_use]
186 pub fn coverage_percentage(&self) -> f64 {
187 let total = Feature::all().len();
188 if total == 0 {
189 return 0.0;
190 }
191 (self.feature_coverage.len() as f64 / total as f64) * 100.0
192 }
193}
194
195impl SwarmGenerator {
196 #[must_use]
198 pub fn new(max_depth: usize, features_per_batch: usize) -> Self {
199 Self {
200 max_depth,
201 seed: 42,
202 features_per_batch,
203 current_batch: 0,
204 stats: SwarmStats::default(),
205 }
206 }
207
208 #[must_use]
210 pub fn with_seed(mut self, seed: u64) -> Self {
211 self.seed = seed;
212 self
213 }
214
215 pub fn generate_batch(&mut self, batch_size: usize) -> Vec<GeneratedCode> {
217 let config = SwarmConfig::random(self.seed, self.features_per_batch, self.current_batch);
218 self.current_batch += 1;
219 self.stats.batches_generated += 1;
220
221 for feature in &config.enabled_features {
223 self.stats.feature_coverage.insert(*feature);
224 }
225
226 let programs = self.generate_with_config(&config, batch_size);
228 self.stats.programs_generated += programs.len();
229
230 programs
231 }
232
233 fn generate_with_config(&self, config: &SwarmConfig, count: usize) -> Vec<GeneratedCode> {
235 let enumerator = PythonEnumerator::new(self.max_depth);
236 let all_programs = enumerator.enumerate_programs();
237
238 let filtered: Vec<GeneratedCode> = all_programs
240 .into_iter()
241 .filter(|prog| self.matches_config(prog, config))
242 .take(count)
243 .map(|mut prog| {
244 prog.features
246 .push(format!("swarm_batch_{}", config.batch_id));
247 prog.features
248 .push(format!("swarm_features_{}", config.feature_count()));
249 prog
250 })
251 .collect();
252
253 filtered
254 }
255
256 fn matches_config(&self, prog: &GeneratedCode, config: &SwarmConfig) -> bool {
258 let used_features = self.extract_features(&prog.code);
260
261 for feature in &used_features {
263 if !config.is_enabled(*feature) {
264 return false;
265 }
266 }
267
268 true
269 }
270
271 fn extract_features(&self, code: &str) -> HashSet<Feature> {
273 let mut features = HashSet::new();
274
275 if code.chars().any(|c| c.is_ascii_digit()) {
277 features.insert(Feature::IntLiterals);
278 }
279 if code.contains('.') && code.chars().any(|c| c.is_ascii_digit()) {
280 if code
282 .split_whitespace()
283 .any(|s| s.parse::<f64>().is_ok() && s.contains('.'))
284 {
285 features.insert(Feature::FloatLiterals);
286 }
287 }
288 if code.contains('"') || code.contains('\'') {
289 features.insert(Feature::StringLiterals);
290 }
291 if code.contains("True") || code.contains("False") {
292 features.insert(Feature::BoolLiterals);
293 }
294 if code.contains("None") {
295 features.insert(Feature::NoneLiteral);
296 }
297
298 for op in ['+', '-', '*', '/', '%'] {
300 if code.contains(op) {
301 features.insert(Feature::ArithmeticOps);
302 break;
303 }
304 }
305 if code.contains("**") || code.contains("//") {
306 features.insert(Feature::ArithmeticOps);
307 }
308 if code.contains(" and ") || code.contains(" or ") {
309 features.insert(Feature::LogicalOps);
310 }
311 if code.contains("not ") {
312 features.insert(Feature::UnaryOps);
313 }
314
315 for op in ["==", "!=", "<=", ">=", " < ", " > "] {
317 if code.contains(op) {
318 features.insert(Feature::Comparisons);
319 break;
320 }
321 }
322
323 if code.contains("if ") {
325 features.insert(Feature::IfStatements);
326 }
327 if code.contains("while ") {
328 features.insert(Feature::WhileLoops);
329 }
330 if code.contains("for ") {
331 features.insert(Feature::ForLoops);
332 }
333 if code.contains("def ") {
334 features.insert(Feature::Functions);
335 }
336 if code.contains("return") {
337 features.insert(Feature::Returns);
338 }
339 if code.contains("break") || code.contains("continue") || code.contains("pass") {
340 features.insert(Feature::ControlFlow);
341 }
342
343 if code.contains('[') && code.contains(']') {
345 features.insert(Feature::Lists);
346 }
347
348 if code.contains("print(") || code.contains("len(") || code.contains("range(") {
350 features.insert(Feature::FunctionCalls);
351 }
352
353 if code.contains(" = ") {
355 features.insert(Feature::Assignments);
356 features.insert(Feature::Variables);
357 }
358
359 features
360 }
361
362 pub fn generate(&mut self, total_count: usize, batch_size: usize) -> Vec<GeneratedCode> {
364 let mut all_programs = Vec::with_capacity(total_count);
365 let num_batches = (total_count + batch_size - 1) / batch_size;
366
367 for _ in 0..num_batches {
368 let remaining = total_count - all_programs.len();
369 let this_batch_size = remaining.min(batch_size);
370 let batch = self.generate_batch(this_batch_size);
371 all_programs.extend(batch);
372
373 if all_programs.len() >= total_count {
374 break;
375 }
376 }
377
378 all_programs.truncate(total_count);
379 all_programs
380 }
381
382 #[must_use]
384 pub fn stats(&self) -> &SwarmStats {
385 &self.stats
386 }
387
388 pub fn reset(&mut self) {
390 self.current_batch = 0;
391 self.stats = SwarmStats::default();
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    // Generator bindings are named `generator` rather than `gen`, which is a
    // reserved keyword in the Rust 2024 edition.

    #[test]
    fn test_feature_all() {
        let features = Feature::all();
        assert!(features.len() >= 15, "Should have many features");

        // `Feature::all` is maintained by hand; guard against duplicate entries.
        let unique: HashSet<Feature> = features.iter().copied().collect();
        assert_eq!(
            unique.len(),
            features.len(),
            "Feature::all should not repeat variants"
        );
    }

    #[test]
    fn test_feature_core() {
        let core = Feature::core();
        assert!(core.contains(&Feature::IntLiterals));
        assert!(core.contains(&Feature::Variables));
        assert!(core.contains(&Feature::Assignments));

        let all = Feature::all();
        assert!(
            core.iter().all(|feature| all.contains(feature)),
            "core features must be a subset of all features"
        );
    }

    #[test]
    fn test_swarm_config_random() {
        let config = SwarmConfig::random(42, 8, 0);
        assert!(config.feature_count() >= 3, "Should have core features");
        assert!(config.is_enabled(Feature::IntLiterals));
        // 3 core features plus 5 distinct optional features.
        assert_eq!(config.feature_count(), 8);
    }

    #[test]
    fn test_swarm_config_small_budget_is_core_only() {
        let config = SwarmConfig::random(42, 1, 0);
        let core: HashSet<Feature> = Feature::core().into_iter().collect();
        assert_eq!(config.enabled_features, core);
    }

    #[test]
    fn test_swarm_config_large_budget_enables_everything() {
        let config = SwarmConfig::random(1, 100, 0);
        assert_eq!(config.feature_count(), Feature::all().len());
    }

    #[test]
    fn test_swarm_config_is_deterministic() {
        let first = SwarmConfig::random(7, 6, 3);
        let second = SwarmConfig::random(7, 6, 3);
        assert_eq!(first.enabled_features, second.enabled_features);
    }

    #[test]
    fn test_swarm_config_different_batches() {
        let config1 = SwarmConfig::random(42, 8, 0);
        let config2 = SwarmConfig::random(42, 8, 1);
        assert_ne!(config1.enabled_features, config2.enabled_features);
    }

    #[test]
    fn test_swarm_generator_new() {
        let generator = SwarmGenerator::new(3, 8);
        assert_eq!(generator.max_depth, 3);
        assert_eq!(generator.features_per_batch, 8);
    }

    #[test]
    fn test_swarm_generator_with_seed() {
        let generator = SwarmGenerator::new(3, 8).with_seed(123);
        assert_eq!(generator.seed, 123);
    }

    #[test]
    fn test_swarm_generator_generate_batch() {
        let mut generator = SwarmGenerator::new(2, 5).with_seed(42);
        let programs = generator.generate_batch(10);
        assert!(!programs.is_empty(), "Should generate some programs");
        assert!(programs.len() <= 10, "Batch must not exceed the requested size");

        for prog in &programs {
            assert!(
                prog.features.iter().any(|f| f.starts_with("swarm_")),
                "Should have swarm metadata"
            );
        }
    }

    #[test]
    fn test_swarm_generator_stats() {
        let mut generator = SwarmGenerator::new(2, 5).with_seed(42);
        generator.generate_batch(10);

        let stats = generator.stats();
        assert_eq!(stats.batches_generated, 1);
        assert!(stats.programs_generated > 0);
        assert!(!stats.feature_coverage.is_empty());
    }

    #[test]
    fn test_swarm_generator_multiple_batches() {
        let mut generator = SwarmGenerator::new(2, 6).with_seed(42);

        generator.generate_batch(5);
        generator.generate_batch(5);
        generator.generate_batch(5);

        let stats = generator.stats();
        assert_eq!(stats.batches_generated, 3);
        assert!(
            stats.coverage_percentage() > 20.0,
            "Should have decent coverage"
        );
    }

    #[test]
    fn test_swarm_generator_generate() {
        let mut generator = SwarmGenerator::new(2, 6).with_seed(42);
        let programs = generator.generate(20, 5);

        assert!(!programs.is_empty());
        assert!(programs.len() <= 20, "Must not exceed total_count");
        let stats = generator.stats();
        assert!(stats.batches_generated >= 1);
    }

    #[test]
    fn test_swarm_generator_reset() {
        let mut generator = SwarmGenerator::new(2, 5).with_seed(42);
        generator.generate_batch(10);

        assert!(generator.stats().batches_generated > 0);

        generator.reset();
        assert_eq!(generator.stats().batches_generated, 0);
        assert_eq!(generator.stats().programs_generated, 0);
    }

    #[test]
    fn test_swarm_stats_coverage_percentage() {
        let mut stats = SwarmStats::default();
        assert!((stats.coverage_percentage() - 0.0).abs() < 0.001);

        stats.feature_coverage.insert(Feature::IntLiterals);
        stats.feature_coverage.insert(Feature::Assignments);
        assert!(stats.coverage_percentage() > 0.0);
        let expected = (2.0 / Feature::all().len() as f64) * 100.0;
        assert!((stats.coverage_percentage() - expected).abs() < 1e-9);

        stats.feature_coverage.extend(Feature::all());
        assert!((stats.coverage_percentage() - 100.0).abs() < 1e-9);
    }

    #[test]
    fn test_swarm_stats_debug() {
        let stats = SwarmStats::default();
        let debug = format!("{:?}", stats);
        assert!(debug.contains("SwarmStats"));
    }

    #[test]
    fn test_swarm_config_debug() {
        let config = SwarmConfig::random(42, 5, 0);
        let debug = format!("{:?}", config);
        assert!(debug.contains("SwarmConfig"));
    }

    #[test]
    fn test_extract_features_arithmetic() {
        let generator = SwarmGenerator::new(2, 5);
        let features = generator.extract_features("x = 1 + 2");
        assert!(features.contains(&Feature::ArithmeticOps));
        assert!(features.contains(&Feature::IntLiterals));
        assert!(features.contains(&Feature::Assignments));
    }

    #[test]
    fn test_extract_features_control_flow() {
        let generator = SwarmGenerator::new(2, 5);
        let features = generator.extract_features("if x > 0:\n    pass");
        assert!(features.contains(&Feature::IfStatements));
        assert!(features.contains(&Feature::Comparisons));
        assert!(features.contains(&Feature::ControlFlow));
    }

    #[test]
    fn test_extract_features_loops() {
        let generator = SwarmGenerator::new(2, 5);

        let features = generator.extract_features("while x > 0:\n    x = x - 1");
        assert!(features.contains(&Feature::WhileLoops));

        let features = generator.extract_features("for i in range(10):\n    pass");
        assert!(features.contains(&Feature::ForLoops));
        assert!(features.contains(&Feature::FunctionCalls));
    }

    #[test]
    fn test_extract_features_functions() {
        let generator = SwarmGenerator::new(2, 5);
        let features = generator.extract_features("def foo():\n    return 1");
        assert!(features.contains(&Feature::Functions));
        assert!(features.contains(&Feature::Returns));
    }

    #[test]
    fn test_extract_features_logical() {
        let generator = SwarmGenerator::new(2, 5);
        let features = generator.extract_features("x = True and False");
        assert!(features.contains(&Feature::LogicalOps));
        assert!(features.contains(&Feature::BoolLiterals));
    }

    #[test]
    fn test_extract_features_lists() {
        let generator = SwarmGenerator::new(2, 5);
        let features = generator.extract_features("x = [1, 2, 3]");
        assert!(features.contains(&Feature::Lists));
    }

    #[test]
    fn test_extract_features_none() {
        let generator = SwarmGenerator::new(2, 5);
        let features = generator.extract_features("x = None");
        assert!(features.contains(&Feature::NoneLiteral));
    }
}