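//! `proof_engine/graph/level_gen.rs` — procedural dungeon level generation:
//! room placement, corridor graph construction, force-directed layout and grid snapping.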

1use glam::Vec2;
2use std::collections::{HashMap, HashSet, VecDeque};
3use super::graph_core::{Graph, GraphKind, NodeId, EdgeId};
4
/// Gameplay role of a room in a generated level.
///
/// `Hash` is derived alongside `Eq` so room types can be used as keys in
/// `HashMap`/`HashSet` (e.g. for per-type statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoomType {
    /// Entry room; the dungeon generator assigns it to the first room it creates.
    Start,
    /// Exit room; the dungeon generator assigns it to the last room it creates.
    End,
    /// Ordinary room with no special role.
    Normal,
    /// Room holding treasure.
    Treasure,
    /// Room holding a boss encounter.
    Boss,
    /// Hidden room.
    Secret,
}
14
15#[derive(Debug, Clone)]
16pub struct RoomNode {
17    pub room_type: RoomType,
18    pub position: Vec2,
19    pub size: Vec2,
20    pub connections: Vec<NodeId>,
21}
22
23#[derive(Debug, Clone)]
24pub struct LevelGraph {
25    pub graph: Graph<RoomNode, f32>,
26    pub rooms: HashMap<NodeId, RoomNode>,
27}
28
29impl LevelGraph {
30    pub fn new() -> Self {
31        Self {
32            graph: Graph::new(GraphKind::Undirected),
33            rooms: HashMap::new(),
34        }
35    }
36
37    pub fn room_count(&self) -> usize {
38        self.graph.node_count()
39    }
40
41    pub fn corridor_count(&self) -> usize {
42        self.graph.edge_count()
43    }
44
45    pub fn get_room(&self, id: NodeId) -> Option<&RoomNode> {
46        self.rooms.get(&id)
47    }
48
49    pub fn room_ids(&self) -> Vec<NodeId> {
50        self.graph.node_ids()
51    }
52
53    /// Check if all rooms are connected (graph is connected).
54    pub fn is_connected(&self) -> bool {
55        let ids = self.graph.node_ids();
56        if ids.is_empty() { return true; }
57        let visited: Vec<NodeId> = self.graph.bfs(ids[0]).collect();
58        visited.len() == ids.len()
59    }
60}
61
/// Stateless, hash-based pseudo-random value derived from `seed` and stream index `i`.
///
/// The same `(seed, i)` pair always produces the same value. The two inputs are
/// combined with 64-bit multipliers, then scrambled with xor-shift/multiply rounds.
/// The result is divided by `u64::MAX`, so it lies in `[0.0, 1.0]`; rounding in
/// the `u64 -> f64` conversion means the upper bound 1.0 can be hit.
fn pseudo_random(seed: u64, i: u64) -> f64 {
    const SEED_MUL: u64 = 6364136223846793005;
    const INDEX_MUL: u64 = 1442695040888963407;
    const MIX_MUL: u64 = 0xff51afd7ed558ccd;

    let combined = seed
        .wrapping_mul(SEED_MUL)
        .wrapping_add(i.wrapping_mul(INDEX_MUL));
    let scrambled = {
        let h = combined ^ (combined >> 33);
        let h = h.wrapping_mul(MIX_MUL);
        h ^ (h >> 33)
    };
    scrambled as f64 / u64::MAX as f64
}
69
70/// Generate a dungeon level graph.
71/// `room_count`: number of rooms
72/// `connectivity`: 0.0 = tree (minimum edges), 1.0 = many extra edges
73pub fn generate_dungeon(room_count: usize, connectivity: f32) -> LevelGraph {
74    generate_dungeon_seeded(room_count, connectivity, 12345)
75}
76
77fn generate_dungeon_seeded(room_count: usize, connectivity: f32, seed: u64) -> LevelGraph {
78    let mut level = LevelGraph::new();
79    if room_count == 0 { return level; }
80
81    let connectivity = connectivity.clamp(0.0, 1.0);
82
83    // Create room nodes with random positions
84    let spread = (room_count as f32).sqrt() * 100.0;
85    let mut node_ids = Vec::new();
86    for i in 0..room_count {
87        let x = (pseudo_random(seed, i as u64 * 2) as f32 - 0.5) * spread;
88        let y = (pseudo_random(seed, i as u64 * 2 + 1) as f32 - 0.5) * spread;
89        let pos = Vec2::new(x, y);
90
91        let room_type = if i == 0 {
92            RoomType::Start
93        } else if i == room_count - 1 {
94            RoomType::End
95        } else if pseudo_random(seed + 100, i as u64) < 0.1 {
96            RoomType::Treasure
97        } else if pseudo_random(seed + 200, i as u64) < 0.05 {
98            RoomType::Boss
99        } else if pseudo_random(seed + 300, i as u64) < 0.05 {
100            RoomType::Secret
101        } else {
102            RoomType::Normal
103        };
104
105        let size = Vec2::new(
106            40.0 + pseudo_random(seed + 400, i as u64) as f32 * 60.0,
107            40.0 + pseudo_random(seed + 500, i as u64) as f32 * 60.0,
108        );
109
110        let room = RoomNode {
111            room_type,
112            position: pos,
113            size,
114            connections: Vec::new(),
115        };
116        let nid = level.graph.add_node_with_pos(room.clone(), pos);
117        level.rooms.insert(nid, room);
118        node_ids.push(nid);
119    }
120
121    if room_count <= 1 { return level; }
122
123    // Build minimum spanning tree using Prim's algorithm for connectivity
124    let mut in_tree: HashSet<NodeId> = HashSet::new();
125    in_tree.insert(node_ids[0]);
126    let mut edges_added = Vec::new();
127
128    while in_tree.len() < room_count {
129        let mut best_dist = f32::INFINITY;
130        let mut best_pair = (node_ids[0], node_ids[1]);
131
132        for &a in &in_tree {
133            for &b in &node_ids {
134                if in_tree.contains(&b) { continue; }
135                let pa = level.graph.node_position(a);
136                let pb = level.graph.node_position(b);
137                let dist = (pa - pb).length();
138                if dist < best_dist {
139                    best_dist = dist;
140                    best_pair = (a, b);
141                }
142            }
143        }
144
145        let (a, b) = best_pair;
146        in_tree.insert(b);
147        let eid = level.graph.add_edge_weighted(a, b, best_dist, best_dist);
148        edges_added.push((a, b));
149        level.rooms.get_mut(&a).unwrap().connections.push(b);
150        level.rooms.get_mut(&b).unwrap().connections.push(a);
151    }
152
153    // Add extra edges based on connectivity
154    let max_extra = (room_count as f32 * connectivity * 1.5) as usize;
155    let mut seed_counter = seed + 10000;
156    let mut extra_added = 0;
157    for i in 0..room_count {
158        if extra_added >= max_extra { break; }
159        for j in (i + 1)..room_count {
160            if extra_added >= max_extra { break; }
161            let a = node_ids[i];
162            let b = node_ids[j];
163            if level.graph.find_edge(a, b).is_some() { continue; }
164
165            let pa = level.graph.node_position(a);
166            let pb = level.graph.node_position(b);
167            let dist = (pa - pb).length();
168            let threshold = spread * 0.3;
169
170            if dist < threshold && pseudo_random(seed_counter, (i * room_count + j) as u64) < connectivity as f64 * 0.5 {
171                seed_counter += 1;
172                level.graph.add_edge_weighted(a, b, dist, dist);
173                level.rooms.get_mut(&a).unwrap().connections.push(b);
174                level.rooms.get_mut(&b).unwrap().connections.push(a);
175                extra_added += 1;
176            }
177        }
178    }
179
180    // Force-directed layout refinement then snap to grid
181    apply_force_layout(&mut level, 50);
182    snap_to_grid(&mut level, 50.0);
183
184    level
185}
186
187fn apply_force_layout(level: &mut LevelGraph, iterations: usize) {
188    let node_ids = level.graph.node_ids();
189    let n = node_ids.len();
190    if n <= 1 { return; }
191
192    let k = 120.0f32; // optimal distance
193    let mut temperature = 200.0f32;
194
195    for _ in 0..iterations {
196        let mut displacements: HashMap<NodeId, Vec2> = HashMap::new();
197        for &nid in &node_ids {
198            displacements.insert(nid, Vec2::ZERO);
199        }
200
201        // Repulsive forces
202        for i in 0..n {
203            for j in (i + 1)..n {
204                let ni = node_ids[i];
205                let nj = node_ids[j];
206                let pi = level.graph.node_position(ni);
207                let pj = level.graph.node_position(nj);
208                let delta = pi - pj;
209                let dist = delta.length().max(1.0);
210                let force = k * k / dist;
211                let d = delta / dist * force;
212                *displacements.get_mut(&ni).unwrap() += d;
213                *displacements.get_mut(&nj).unwrap() -= d;
214            }
215        }
216
217        // Attractive forces along edges
218        for edge in level.graph.edges() {
219            let pi = level.graph.node_position(edge.from);
220            let pj = level.graph.node_position(edge.to);
221            let delta = pi - pj;
222            let dist = delta.length().max(1.0);
223            let force = dist * dist / k;
224            let d = delta / dist * force;
225            *displacements.get_mut(&edge.from).unwrap() -= d;
226            *displacements.get_mut(&edge.to).unwrap() += d;
227        }
228
229        for &nid in &node_ids {
230            let disp = displacements[&nid];
231            let len = disp.length().max(0.01);
232            let clamped = disp / len * len.min(temperature);
233            let pos = level.graph.node_position(nid) + clamped;
234            level.graph.set_node_position(nid, pos);
235        }
236
237        temperature *= 0.95;
238    }
239
240    // Update room positions
241    for &nid in &node_ids {
242        let pos = level.graph.node_position(nid);
243        if let Some(room) = level.rooms.get_mut(&nid) {
244            room.position = pos;
245        }
246    }
247}
248
249fn snap_to_grid(level: &mut LevelGraph, grid_size: f32) {
250    for nid in level.graph.node_ids() {
251        let pos = level.graph.node_position(nid);
252        let snapped = Vec2::new(
253            (pos.x / grid_size).round() * grid_size,
254            (pos.y / grid_size).round() * grid_size,
255        );
256        level.graph.set_node_position(nid, snapped);
257        if let Some(room) = level.rooms.get_mut(&nid) {
258            room.position = snapped;
259        }
260    }
261}
262
263/// Generate a corridor path between two positions.
264/// Uses L-shaped corridors (horizontal then vertical) or straight if aligned.
265pub fn corridor_path(from: Vec2, to: Vec2) -> Vec<Vec2> {
266    let dx = (to.x - from.x).abs();
267    let dy = (to.y - from.y).abs();
268
269    if dx < 1.0 || dy < 1.0 {
270        // Straight corridor
271        vec![from, to]
272    } else {
273        // L-shaped: go horizontal first, then vertical
274        let midpoint = Vec2::new(to.x, from.y);
275        vec![from, midpoint, to]
276    }
277}
278
279/// Alternative corridor: Z-shaped (horizontal, vertical, horizontal).
280pub fn corridor_path_z(from: Vec2, to: Vec2) -> Vec<Vec2> {
281    let mid_y = (from.y + to.y) / 2.0;
282    vec![
283        from,
284        Vec2::new(from.x, mid_y),
285        Vec2::new(to.x, mid_y),
286        to,
287    ]
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts rooms per type in a fixed variant order, so two levels can be compared
    /// without relying on NodeId values or node iteration order.
    fn room_type_counts(level: &LevelGraph) -> [usize; 6] {
        let mut counts = [0usize; 6];
        for room in level.rooms.values() {
            let slot = match room.room_type {
                RoomType::Start => 0,
                RoomType::End => 1,
                RoomType::Normal => 2,
                RoomType::Treasure => 3,
                RoomType::Boss => 4,
                RoomType::Secret => 5,
            };
            counts[slot] += 1;
        }
        counts
    }

    #[test]
    fn test_generate_dungeon_basic() {
        let level = generate_dungeon(10, 0.3);
        assert_eq!(level.room_count(), 10);
        assert!(level.corridor_count() >= 9); // at least spanning tree
        assert!(level.is_connected());
    }

    #[test]
    fn test_generate_dungeon_minimal() {
        let level = generate_dungeon(2, 0.0);
        assert_eq!(level.room_count(), 2);
        assert_eq!(level.corridor_count(), 1);
        assert!(level.is_connected());
    }

    #[test]
    fn test_generate_dungeon_empty() {
        let level = generate_dungeon(0, 0.5);
        assert_eq!(level.room_count(), 0);
    }

    #[test]
    fn test_generate_dungeon_single() {
        let level = generate_dungeon(1, 0.5);
        assert_eq!(level.room_count(), 1);
        assert_eq!(level.corridor_count(), 0);
    }

    #[test]
    fn test_empty_level_is_connected() {
        let level = LevelGraph::new();
        assert_eq!(level.room_count(), 0);
        assert!(level.is_connected());
    }

    #[test]
    fn test_connectivity_increases_edges() {
        let low = generate_dungeon_seeded(15, 0.0, 999);
        let high = generate_dungeon_seeded(15, 1.0, 999);
        assert!(high.corridor_count() >= low.corridor_count());
    }

    #[test]
    fn test_out_of_range_connectivity_is_clamped() {
        // Below 0.0 clamps to a pure spanning tree.
        let low = generate_dungeon(12, -1.0);
        assert_eq!(low.room_count(), 12);
        assert_eq!(low.corridor_count(), 11);
        assert!(low.is_connected());

        // Above 1.0 behaves exactly like 1.0.
        let high = generate_dungeon(12, 5.0);
        let one = generate_dungeon(12, 1.0);
        assert_eq!(high.corridor_count(), one.corridor_count());
        assert!(high.is_connected());
    }

    #[test]
    fn test_same_seed_same_level_shape() {
        let a = generate_dungeon_seeded(12, 0.5, 7);
        let b = generate_dungeon_seeded(12, 0.5, 7);
        assert_eq!(a.room_count(), b.room_count());
        assert_eq!(a.corridor_count(), b.corridor_count());
        assert_eq!(room_type_counts(&a), room_type_counts(&b));
    }

    #[test]
    fn test_connections_mirror_corridors() {
        let level = generate_dungeon(15, 0.7);
        let total: usize = level.rooms.values().map(|r| r.connections.len()).sum();
        assert_eq!(total, 2 * level.corridor_count());
        for id in level.room_ids() {
            let room = level.get_room(id).expect("room entry for every node");
            for other in &room.connections {
                let back = level.get_room(*other).expect("connected room exists");
                assert!(back.connections.contains(&id));
            }
        }
    }

    #[test]
    fn test_room_positions_match_graph() {
        let level = generate_dungeon(8, 0.4);
        assert_eq!(level.rooms.len(), level.room_count());
        for id in level.room_ids() {
            let room = level.get_room(id).expect("room entry for every node");
            assert_eq!(room.position, level.graph.node_position(id));
        }
    }

    #[test]
    fn test_room_types() {
        let level = generate_dungeon(20, 0.3);
        let rooms: Vec<&RoomNode> = level.rooms.values().collect();
        let start_count = rooms.iter().filter(|r| r.room_type == RoomType::Start).count();
        let end_count = rooms.iter().filter(|r| r.room_type == RoomType::End).count();
        assert_eq!(start_count, 1);
        assert_eq!(end_count, 1);
    }

    #[test]
    fn test_pseudo_random_in_unit_range_and_deterministic() {
        for i in 0..1000u64 {
            let v = pseudo_random(42, i);
            assert!((0.0..=1.0).contains(&v), "out of range: {}", v);
            assert_eq!(v.to_bits(), pseudo_random(42, i).to_bits());
        }
    }

    #[test]
    fn test_corridor_path_straight() {
        let path = corridor_path(Vec2::new(0.0, 5.0), Vec2::new(10.0, 5.0));
        assert_eq!(path.len(), 2);
    }

    #[test]
    fn test_corridor_path_straight_vertical() {
        let path = corridor_path(Vec2::new(5.0, 0.0), Vec2::new(5.0, 10.0));
        assert_eq!(path, vec![Vec2::new(5.0, 0.0), Vec2::new(5.0, 10.0)]);
    }

    #[test]
    fn test_corridor_path_l_shaped() {
        let path = corridor_path(Vec2::new(0.0, 0.0), Vec2::new(10.0, 10.0));
        assert_eq!(path.len(), 3);
        assert_eq!(path[1], Vec2::new(10.0, 0.0)); // horizontal then vertical
    }

    #[test]
    fn test_corridor_path_z() {
        let path = corridor_path_z(Vec2::new(0.0, 0.0), Vec2::new(10.0, 10.0));
        assert_eq!(path.len(), 4);
    }

    #[test]
    fn test_corridor_path_z_segments_axis_aligned() {
        let start = Vec2::new(0.0, 0.0);
        let end = Vec2::new(10.0, 20.0);
        let path = corridor_path_z(start, end);
        assert_eq!(path.first(), Some(&start));
        assert_eq!(path.last(), Some(&end));
        for pair in path.windows(2) {
            assert!(
                pair[0].x == pair[1].x || pair[0].y == pair[1].y,
                "diagonal segment {:?} -> {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn test_grid_snapping() {
        let level = generate_dungeon(5, 0.0);
        for nid in level.room_ids() {
            let pos = level.graph.node_position(nid);
            assert_eq!(pos.x % 50.0, 0.0, "X not snapped: {}", pos.x);
            assert_eq!(pos.y % 50.0, 0.0, "Y not snapped: {}", pos.y);
        }
    }

    #[test]
    fn test_rooms_have_sizes() {
        let level = generate_dungeon(5, 0.3);
        for room in level.rooms.values() {
            assert!(room.size.x >= 40.0);
            assert!(room.size.y >= 40.0);
        }
    }
}