Skip to main content

proof_engine/ml/
procgen.rs

//! Procedural generation via ML models with non-ML fallbacks.
2
3use glam::Vec2;
4use super::tensor::Tensor;
5use super::model::{Model, Sequential};
6
/// Enemy formation: a list of `(position, enemy_type_id)` pairs.
///
/// NOTE(review): the coordinate space/units of `position` are not defined here —
/// confirm with the consumers of this type.
pub type Formation = Vec<(Vec2, u32)>;
9
/// Generated room layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoomLayout {
    /// Row-major 2-D grid indexed as `grid[y][x]`:
    /// 0=empty, 1=wall, 2=door, 3=obstacle, 4=spawn, 5=treasure
    pub grid: Vec<Vec<u8>>,
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}

impl RoomLayout {
    /// Cell value at column `x`, row `y`.
    ///
    /// Anything outside `width` x `height` reads as a wall (1). Because all fields are
    /// public and may disagree with each other, a cell missing from the actual `grid`
    /// storage also reads as a wall instead of panicking.
    pub fn get(&self, x: usize, y: usize) -> u8 {
        if x >= self.width || y >= self.height {
            return 1;
        }
        self.grid
            .get(y)
            .and_then(|row| row.get(x))
            .copied()
            .unwrap_or(1)
    }
}
24
/// Stats for a generated item.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ItemStats {
    /// Display name of the item.
    pub name: String,
    /// Damage stat.
    pub damage: f32,
    /// Defense stat.
    pub defense: f32,
    /// Speed stat.
    pub speed: f32,
    /// Magic stat.
    pub magic: f32,
    /// The rarity level the item was generated from.
    pub rarity_score: f32,
}
35
/// Room type hint for generation.
///
/// Declaration order is significant: the discriminant (`room_type as usize`) is mixed
/// into the seed of the heuristic room generator, so reordering variants changes the
/// rooms it produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoomType {
    Combat,
    Puzzle,
    Treasure,
    Boss,
    Corridor,
}
45
46// ── Formation Generator ─────────────────────────────────────────────────
47
48pub struct FormationGenerator {
49    pub model: Option<Model>,
50}
51
52impl FormationGenerator {
53    pub fn new() -> Self {
54        Self { model: None }
55    }
56
57    pub fn with_model(model: Model) -> Self {
58        Self { model: Some(model) }
59    }
60
61    /// Generate an enemy formation. `difficulty` in [0, 1], `enemy_types` is the
62    /// number of distinct enemy types available.
63    pub fn generate_formation(&self, difficulty: f32, enemy_types: u32) -> Formation {
64        if let Some(ref model) = self.model {
65            return self.generate_from_model(model, difficulty, enemy_types);
66        }
67        self.generate_fallback(difficulty, enemy_types)
68    }
69
70    fn generate_from_model(&self, model: &Model, difficulty: f32, enemy_types: u32) -> Formation {
71        let input = Tensor::from_vec(vec![difficulty, enemy_types as f32, 0.5, 0.5], vec![1, 4]);
72        let output = model.forward(&input);
73        // Interpret output as pairs of (x, y, type) triples
74        let mut formation = Vec::new();
75        let data = &output.data;
76        let mut i = 0;
77        while i + 2 < data.len() {
78            let x = data[i].abs() * 20.0;
79            let y = data[i + 1].abs() * 20.0;
80            let t = (data[i + 2].abs() * enemy_types as f32) as u32 % enemy_types.max(1);
81            formation.push((Vec2::new(x, y), t));
82            i += 3;
83        }
84        if formation.is_empty() {
85            return self.generate_fallback(difficulty, enemy_types);
86        }
87        formation
88    }
89
90    fn generate_fallback(&self, difficulty: f32, enemy_types: u32) -> Formation {
91        let count = (3.0 + difficulty * 10.0) as usize;
92        let mut formation = Vec::with_capacity(count);
93        let mut rng = (difficulty.to_bits() as u64).wrapping_add(1);
94
95        for i in 0..count {
96            rng ^= rng << 13;
97            rng ^= rng >> 7;
98            rng ^= rng << 17;
99            let x = (rng as u32 as f32 / u32::MAX as f32) * 16.0 + 2.0;
100            rng ^= rng << 13;
101            rng ^= rng >> 7;
102            rng ^= rng << 17;
103            let y = (rng as u32 as f32 / u32::MAX as f32) * 16.0 + 2.0;
104            let enemy_type = if enemy_types > 0 { (i as u32) % enemy_types } else { 0 };
105            formation.push((Vec2::new(x, y), enemy_type));
106        }
107        formation
108    }
109}
110
111// ── Room Layout Generator ───────────────────────────────────────────────
112
113pub struct RoomLayoutGenerator {
114    pub model: Option<Model>,
115}
116
117impl RoomLayoutGenerator {
118    pub fn new() -> Self {
119        Self { model: None }
120    }
121
122    pub fn with_model(model: Model) -> Self {
123        Self { model: Some(model) }
124    }
125
126    pub fn generate_room(&self, room_type: RoomType, width: usize, height: usize) -> RoomLayout {
127        if let Some(ref model) = self.model {
128            return self.generate_from_model(model, room_type, width, height);
129        }
130        self.generate_fallback(room_type, width, height)
131    }
132
133    fn generate_from_model(&self, model: &Model, room_type: RoomType, width: usize, height: usize) -> RoomLayout {
134        let rt = match room_type {
135            RoomType::Combat => 0.0,
136            RoomType::Puzzle => 0.25,
137            RoomType::Treasure => 0.5,
138            RoomType::Boss => 0.75,
139            RoomType::Corridor => 1.0,
140        };
141        let input = Tensor::from_vec(vec![rt, width as f32, height as f32, 0.0], vec![1, 4]);
142        let output = model.forward(&input);
143        // Try to interpret output as grid values
144        let mut grid = vec![vec![0u8; width]; height];
145        let mut idx = 0;
146        for y in 0..height {
147            for x in 0..width {
148                if idx < output.data.len() {
149                    let v = (output.data[idx].abs() * 6.0) as u8;
150                    grid[y][x] = v.min(5);
151                } else {
152                    grid[y][x] = 0;
153                }
154                idx += 1;
155            }
156        }
157        // Ensure walls on borders
158        for x in 0..width {
159            grid[0][x] = 1;
160            grid[height - 1][x] = 1;
161        }
162        for y in 0..height {
163            grid[y][0] = 1;
164            grid[y][width - 1] = 1;
165        }
166        RoomLayout { grid, width, height }
167    }
168
169    fn generate_fallback(&self, room_type: RoomType, width: usize, height: usize) -> RoomLayout {
170        let mut grid = vec![vec![0u8; width]; height];
171
172        // Border walls
173        for x in 0..width {
174            grid[0][x] = 1;
175            grid[height - 1][x] = 1;
176        }
177        for y in 0..height {
178            grid[y][0] = 1;
179            grid[y][width - 1] = 1;
180        }
181
182        // Door on left and right walls
183        if height > 2 {
184            grid[height / 2][0] = 2;
185            grid[height / 2][width - 1] = 2;
186        }
187
188        let mut rng = (width * height + room_type as usize) as u64;
189        let mut next_rng = || -> u64 {
190            rng ^= rng << 13;
191            rng ^= rng >> 7;
192            rng ^= rng << 17;
193            rng
194        };
195
196        match room_type {
197            RoomType::Combat => {
198                // Spawn points in center area
199                for _ in 0..3 {
200                    let x = 2 + (next_rng() as usize % (width.saturating_sub(4)).max(1));
201                    let y = 2 + (next_rng() as usize % (height.saturating_sub(4)).max(1));
202                    grid[y][x] = 4; // spawn
203                }
204                // A few obstacles
205                for _ in 0..2 {
206                    let x = 2 + (next_rng() as usize % (width.saturating_sub(4)).max(1));
207                    let y = 2 + (next_rng() as usize % (height.saturating_sub(4)).max(1));
208                    grid[y][x] = 3; // obstacle
209                }
210            }
211            RoomType::Puzzle => {
212                // Grid of obstacles
213                for y in (2..height - 2).step_by(2) {
214                    for x in (2..width - 2).step_by(2) {
215                        grid[y][x] = 3;
216                    }
217                }
218            }
219            RoomType::Treasure => {
220                // Treasure in center
221                let cx = width / 2;
222                let cy = height / 2;
223                grid[cy][cx] = 5;
224                // Surround with obstacles
225                if cx > 0 { grid[cy][cx - 1] = 3; }
226                if cx + 1 < width { grid[cy][cx + 1] = 3; }
227                if cy > 0 { grid[cy - 1][cx] = 3; }
228                if cy + 1 < height { grid[cy + 1][cx] = 3; }
229            }
230            RoomType::Boss => {
231                // Large open area with spawn in center
232                let cx = width / 2;
233                let cy = height / 2;
234                grid[cy][cx] = 4;
235                // Pillars in corners
236                if width > 4 && height > 4 {
237                    grid[2][2] = 3;
238                    grid[2][width - 3] = 3;
239                    grid[height - 3][2] = 3;
240                    grid[height - 3][width - 3] = 3;
241                }
242            }
243            RoomType::Corridor => {
244                // Fill top and bottom halves with walls, leave center row open
245                for y in 2..height / 2 {
246                    for x in 1..width - 1 {
247                        grid[y][x] = 1;
248                    }
249                }
250                for y in height / 2 + 2..height - 1 {
251                    for x in 1..width - 1 {
252                        grid[y][x] = 1;
253                    }
254                }
255            }
256        }
257
258        RoomLayout { grid, width, height }
259    }
260}
261
262// ── Item Stat Generator ─────────────────────────────────────────────────
263
264pub struct ItemStatGenerator {
265    pub model: Option<Model>,
266}
267
268impl ItemStatGenerator {
269    pub fn new() -> Self {
270        Self { model: None }
271    }
272
273    /// Generate item stats from a rarity level [0,1] and a type hint string.
274    pub fn generate_item(&self, rarity: f32, type_hint: &str) -> ItemStats {
275        let type_val = match type_hint {
276            "sword" => 0.0,
277            "shield" => 0.2,
278            "staff" => 0.4,
279            "bow" => 0.6,
280            "armor" => 0.8,
281            _ => 0.5,
282        };
283
284        if let Some(ref model) = self.model {
285            let input = Tensor::from_vec(vec![rarity, type_val, 0.5, 0.5], vec![1, 4]);
286            let out = model.forward(&input);
287            let d = &out.data;
288            return ItemStats {
289                name: format!("{}_r{}", type_hint, (rarity * 100.0) as u32),
290                damage: d.first().copied().unwrap_or(0.0).abs() * 100.0 * rarity,
291                defense: d.get(1).copied().unwrap_or(0.0).abs() * 80.0 * rarity,
292                speed: d.get(2).copied().unwrap_or(0.5).abs() * 50.0,
293                magic: d.get(3).copied().unwrap_or(0.0).abs() * 60.0 * rarity,
294                rarity_score: rarity,
295            };
296        }
297
298        // Fallback: heuristic generation
299        let base_mult = 1.0 + rarity * 4.0;
300        let (dmg, def, spd, mag) = match type_hint {
301            "sword" => (20.0 * base_mult, 5.0, 10.0 * base_mult, 0.0),
302            "shield" => (0.0, 30.0 * base_mult, 5.0, 5.0),
303            "staff" => (5.0, 5.0, 8.0, 25.0 * base_mult),
304            "bow" => (15.0 * base_mult, 0.0, 15.0 * base_mult, 0.0),
305            "armor" => (0.0, 25.0 * base_mult, 3.0, 10.0),
306            _ => (10.0 * base_mult, 10.0 * base_mult, 10.0, 10.0),
307        };
308
309        let prefix = match rarity {
310            r if r > 0.9 => "Legendary",
311            r if r > 0.7 => "Epic",
312            r if r > 0.4 => "Rare",
313            r if r > 0.2 => "Uncommon",
314            _ => "Common",
315        };
316
317        ItemStats {
318            name: format!("{} {}", prefix, type_hint),
319            damage: dmg,
320            defense: def,
321            speed: spd,
322            magic: mag,
323            rarity_score: rarity,
324        }
325    }
326}
327
328// ── Name Generator ──────────────────────────────────────────────────────
329
330/// Character-level RNN name generator. Falls back to Markov chain heuristic.
331pub struct NameGenerator {
332    pub model: Option<Model>,
333    /// Markov chain: for each pair of chars, probability of next char.
334    pub bigram_table: Vec<Vec<(char, f32)>>,
335}
336
337impl NameGenerator {
338    pub fn new() -> Self {
339        // Build a simple bigram table from common fantasy name patterns
340        let seed_names = [
341            "aerin", "balor", "celia", "darin", "elena", "feron", "galia",
342            "heron", "irene", "jorah", "kiera", "loric", "miren", "norin",
343            "orion", "perin", "quill", "rohan", "siren", "torin", "uriel",
344            "valen", "wren", "xyla", "yoren", "zara",
345        ];
346
347        let mut bigrams: std::collections::HashMap<char, std::collections::HashMap<char, u32>> =
348            std::collections::HashMap::new();
349
350        for name in &seed_names {
351            let chars: Vec<char> = name.chars().collect();
352            for w in chars.windows(2) {
353                *bigrams.entry(w[0]).or_default().entry(w[1]).or_insert(0) += 1;
354            }
355        }
356
357        let mut bigram_table = Vec::new();
358        let mut all_chars: Vec<char> = bigrams.keys().cloned().collect();
359        all_chars.sort();
360        for &ch in &all_chars {
361            if let Some(nexts) = bigrams.get(&ch) {
362                let total: u32 = nexts.values().sum();
363                let probs: Vec<(char, f32)> = nexts.iter()
364                    .map(|(&c, &count)| (c, count as f32 / total as f32))
365                    .collect();
366                bigram_table.push(probs);
367            } else {
368                bigram_table.push(vec![]);
369            }
370        }
371
372        Self { model: None, bigram_table }
373    }
374
375    /// Generate a name given a seed and max length.
376    pub fn generate_name(&self, seed: u64, max_len: usize) -> String {
377        let vowels = ['a', 'e', 'i', 'o', 'u'];
378        let consonants = ['b', 'c', 'd', 'f', 'g', 'h', 'k', 'l', 'm', 'n', 'p', 'r', 's', 't', 'v', 'w', 'z'];
379
380        let mut rng = seed.wrapping_add(1);
381        let mut next = || -> u64 {
382            rng ^= rng << 13;
383            rng ^= rng >> 7;
384            rng ^= rng << 17;
385            rng
386        };
387
388        let max_len = max_len.max(3).min(20);
389        let name_len = 3 + (next() as usize % (max_len - 2));
390
391        let mut name = String::with_capacity(name_len);
392        let mut use_vowel = next() % 2 == 0;
393
394        for i in 0..name_len {
395            let ch = if use_vowel {
396                vowels[next() as usize % vowels.len()]
397            } else {
398                consonants[next() as usize % consonants.len()]
399            };
400            if i == 0 {
401                name.push(ch.to_uppercase().next().unwrap());
402            } else {
403                name.push(ch);
404            }
405            // Alternate consonant/vowel with occasional doubles
406            if next() % 5 != 0 {
407                use_vowel = !use_vowel;
408            }
409        }
410
411        name
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    // Bindings avoid the name `gen`, which is a reserved keyword from Rust 2024 on.

    #[test]
    fn test_formation_fallback() {
        let generator = FormationGenerator::new();
        let formation = generator.generate_formation(0.5, 3);
        assert!(!formation.is_empty());
        for &(pos, etype) in &formation {
            assert!(pos.x >= 0.0 && pos.y >= 0.0);
            assert!(etype < 3);
        }
    }

    #[test]
    fn test_formation_difficulty_scales() {
        let generator = FormationGenerator::new();
        let easy = generator.generate_formation(0.0, 2);
        let hard = generator.generate_formation(1.0, 2);
        assert!(hard.len() > easy.len());
    }

    #[test]
    fn test_formation_fallback_is_deterministic() {
        let generator = FormationGenerator::new();
        assert_eq!(
            generator.generate_formation(0.3, 4),
            generator.generate_formation(0.3, 4)
        );
    }

    #[test]
    fn test_room_layout_combat() {
        let generator = RoomLayoutGenerator::new();
        let room = generator.generate_room(RoomType::Combat, 10, 10);
        assert_eq!(room.width, 10);
        assert_eq!(room.height, 10);
        // Top border must be wall (1) or door (2).
        assert!((0..10).all(|x| matches!(room.get(x, 0), 1 | 2)));
        // At least one spawn point.
        assert!(room.grid.iter().flatten().any(|&v| v == 4));
    }

    #[test]
    fn test_room_layout_get_out_of_bounds_is_wall() {
        let room = RoomLayoutGenerator::new().generate_room(RoomType::Combat, 6, 6);
        assert_eq!(room.get(6, 0), 1);
        assert_eq!(room.get(0, 6), 1);
        assert_eq!(room.get(100, 100), 1);
    }

    #[test]
    fn test_room_layout_treasure() {
        let generator = RoomLayoutGenerator::new();
        let room = generator.generate_room(RoomType::Treasure, 8, 8);
        assert!(room.grid.iter().flatten().any(|&v| v == 5));
    }

    #[test]
    fn test_room_layout_all_types() {
        let generator = RoomLayoutGenerator::new();
        let all = [
            RoomType::Combat,
            RoomType::Puzzle,
            RoomType::Treasure,
            RoomType::Boss,
            RoomType::Corridor,
        ];
        for &room_type in all.iter() {
            let room = generator.generate_room(room_type, 12, 12);
            assert_eq!((room.width, room.height), (12, 12));
        }
    }

    #[test]
    fn test_item_stat_generator() {
        let generator = ItemStatGenerator::new();
        let sword = generator.generate_item(0.5, "sword");
        assert!(sword.damage > 0.0);
        assert!(sword.name.contains("sword") || sword.name.contains("Sword"));

        let shield = generator.generate_item(0.8, "shield");
        assert!(shield.defense > 0.0);
    }

    #[test]
    fn test_item_rarity_scaling() {
        let generator = ItemStatGenerator::new();
        let common = generator.generate_item(0.1, "sword");
        let legendary = generator.generate_item(0.95, "sword");
        assert!(legendary.damage > common.damage);
        assert!(legendary.name.contains("Legendary"));
        assert!(common.name.contains("Common"));
    }

    #[test]
    fn test_item_unknown_hint_fallback() {
        let item = ItemStatGenerator::new().generate_item(0.0, "wand");
        assert_eq!(item.name, "Common wand");
        assert_eq!(item.damage, 10.0);
        assert_eq!(item.rarity_score, 0.0);
    }

    #[test]
    fn test_name_generator() {
        let generator = NameGenerator::new();
        let first = generator.generate_name(42, 8);
        let second = generator.generate_name(99, 8);
        assert!(first.len() >= 3);
        assert!(second.len() >= 3);
        assert_ne!(first, second);
        // Names are capitalised.
        assert!(first.chars().next().unwrap().is_uppercase());
    }

    #[test]
    fn test_name_generator_different_seeds() {
        let generator = NameGenerator::new();
        let unique: std::collections::HashSet<String> =
            (0..10).map(|seed| generator.generate_name(seed, 6)).collect();
        // At least some should be distinct.
        assert!(unique.len() > 3);
    }

    #[test]
    fn test_name_length_respects_max_len() {
        let generator = NameGenerator::new();
        for seed in 0..50 {
            let len = generator.generate_name(seed, 6).chars().count();
            assert!((3..=6).contains(&len), "seed {}: length {}", seed, len);
        }
    }
}
512}