use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::process::Command;
use std::str::FromStr;
use crate::arena::errors::ExportError;
use crate::core::Card;
use super::{CFRState, NodeData};
/// Font family applied to node and edge labels in the generated DOT text.
const DEFAULT_FONT: &str = "Arial";
/// Fill colour for the root node.
const COLOR_ROOT: &str = "lightblue";
/// Fill colour for chance (card-deal) nodes.
const COLOR_CHANCE: &str = "lightgreen";
/// Fill colour for player decision nodes.
const COLOR_PLAYER: &str = "coral";
/// Fill colour for terminal nodes.
const COLOR_TERMINAL: &str = "lightgrey";
/// Output formats understood by `export_cfr_state`.
///
/// `Dot` writes only the Graphviz source text, `Png` and `Svg` render it with
/// the external `dot` tool, and `All` produces all three files side by side.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExportFormat {
    /// Graphviz DOT source text.
    Dot,
    /// PNG image rendered by Graphviz.
    Png,
    /// SVG image rendered by Graphviz.
    Svg,
    /// DOT, PNG and SVG files that share one base path.
    All,
}
impl FromStr for ExportFormat {
    type Err = ExportError;

    /// Parses a format name without regard to case: `"dot"`, `"png"`, `"svg"`
    /// or `"all"`.
    ///
    /// # Errors
    /// Any other input yields `ExportError::InvalidExportFormat`, which carries
    /// the caller's original (non-lowercased) string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s.to_lowercase().as_str() {
            "dot" => Self::Dot,
            "png" => Self::Png,
            "svg" => Self::Svg,
            "all" => Self::All,
            _ => return Err(ExportError::InvalidExportFormat(s.to_string())),
        };
        Ok(parsed)
    }
}
pub fn generate_dot(state: &CFRState) -> String {
let mut output = String::new();
output.push_str("digraph CFRTree {\n");
output.push_str(" // Graph styling\n");
output.push_str(" graph [rankdir=TB, splines=polyline, nodesep=1.0, ranksep=1.2, concentrate=true, compound=true];\n");
output.push_str(&format!(
" node [shape=box, style=\"rounded,filled\", fontname=\"{DEFAULT_FONT}\", margin=0.2];\n"
));
output.push_str(&format!(" edge [fontname=\"{DEFAULT_FONT}\", penwidth=1.0, labelangle=25, labeldistance=1.8, labelfloat=true];\n"));
output.push_str(" // Add legend\n");
output.push_str(" subgraph cluster_legend {\n");
output.push_str(" graph [rank=sink];\n");
output.push_str(" label=\"Legend\";\n");
output.push_str(" style=rounded;\n");
output.push_str(" color=gray;\n");
output.push_str(" margin=16;\n");
output.push_str(" node [shape=plaintext, style=\"\"];\n");
output.push_str(" legend [label=<\n");
output.push_str(" <table border=\"0\" cellborder=\"0\" cellspacing=\"2\">\n");
output.push_str(" <tr><td align=\"left\"><b>Node Types:</b></td></tr>\n");
output.push_str(
" <tr><td align=\"left\">• Root (⬢): Light Blue - Starting state</td></tr>\n",
);
output.push_str(
" <tr><td align=\"left\">• Player (□): Coral - Decision points</td></tr>\n",
);
output.push_str(
" <tr><td align=\"left\">• Chance (○): Light Green - Card deals</td></tr>\n",
);
output.push_str(
" <tr><td align=\"left\">• Terminal (⬡): Light Grey - Final states</td></tr>\n",
);
output.push_str(" <tr><td><br/></td></tr>\n");
output.push_str(" <tr><td align=\"left\"><b>Edge Properties:</b></td></tr>\n");
output.push_str(" <tr><td align=\"left\">• Labels: Action/Card</td></tr>\n");
output.push_str(" </table>\n");
output.push_str(" >];\n");
output.push_str(" }\n\n");
output.push_str(" // Node grouping\n");
output.push_str(" {rank=source; node_0;}\n");
let arena = state.arena();
for node in arena.iter() {
if let (Some(parent_idx), Some(parent_child_idx)) =
(node.get_parent(), node.get_parent_child_idx())
&& parent_idx != node.idx as usize
&& arena.get(parent_idx).get_child(parent_child_idx) != Some(node.idx as usize)
{
continue;
}
let data = node.read_data();
let (color, shape, style) = match &*data {
NodeData::Root => (COLOR_ROOT, "doubleoctagon", "filled"),
NodeData::Chance => (COLOR_CHANCE, "ellipse", "filled"),
NodeData::Player(_) => (COLOR_PLAYER, "box", "rounded,filled"),
NodeData::Terminal(_) => (COLOR_TERMINAL, "hexagon", "filled"),
};
let label = match &*data {
NodeData::Root => format!("Root Node\\nIndex: {}", node.idx),
NodeData::Chance => format!("Chance Node\\nIndex: {}", node.idx),
NodeData::Player(player_data) => {
let player_seat = player_data.player_idx;
format!("Player {} Node\\nIndex: {}", player_seat, node.idx)
}
NodeData::Terminal(td) => format!(
"Terminal Node\\nIndex: {}\\nUtility: {:.2}",
node.idx, td.total_utility
),
};
let tooltip = match &*data {
NodeData::Terminal(td) => {
format!("Utility: {:.2}", td.total_utility)
}
_ => format!("Index: {}", node.idx),
};
output.push_str(&format!(
" node_{} [label=\"{}\", shape={}, style=\"{}\", fillcolor=\"{}\", tooltip=\"{}\"];\n",
node.idx, label, shape, style, color, tooltip
));
if let NodeData::Player(_) = &*data {
output.push_str(&format!(
" {{rank=same; node_{};}} // Group player nodes\n",
node.idx
));
}
let is_chance = data.is_chance();
let is_player = data.is_player();
drop(data);
for (child_idx, child_node_idx) in node.iter_children() {
let edge_label = if is_chance {
Card::from(child_idx as u8).to_string()
} else if is_player {
if child_idx == 0 {
"Fold".to_string()
} else if child_idx == 1 {
"Check/Call".to_string()
} else {
format!("Bet/Raise {}", child_idx - 1)
}
} else {
format!("{child_idx}")
};
output.push_str(&format!(
" node_{} -> node_{} [label=\"{}\", weight=1]\n",
node.idx, child_node_idx, edge_label
));
}
}
output.push_str("}\n");
output
}
/// Writes the DOT rendering of `state` (from `generate_dot`) to `output_path`.
///
/// The file is created, or truncated if it already exists.
///
/// # Errors
/// Returns the I/O failure, converted into `ExportError`, if the file cannot
/// be created or written.
pub fn export_to_dot(state: &CFRState, output_path: &Path) -> Result<(), ExportError> {
    // Render first, so the target file is only touched once the text exists.
    let rendered = generate_dot(state);
    File::create(output_path)?.write_all(rendered.as_bytes())?;
    Ok(())
}
/// Runs the Graphviz `dot` binary to render `dot_path` into `output_path`.
///
/// `format` is passed to `dot` as `-T{format}` (for example `"png"` or `"svg"`).
/// When `cleanup_dot` is true and rendering succeeds, the DOT source file is
/// deleted. The one exception is when it is the same path as `output_path`,
/// because deleting it would then delete the rendered output.
///
/// # Errors
/// * The I/O error, converted into `ExportError`, if `dot` cannot be spawned
///   (for example when Graphviz is not installed) or the DOT file cannot be
///   removed.
/// * `ExportError::FailedToRunDot` when `dot` exits unsuccessfully. The DOT
///   file is left in place in that case.
fn convert_with_graphviz(
    dot_path: &Path,
    output_path: &Path,
    format: &str,
    cleanup_dot: bool,
) -> Result<(), ExportError> {
    let status = Command::new("dot")
        .arg(format!("-T{format}"))
        .arg(dot_path)
        .arg("-o")
        .arg(output_path)
        .status()?;
    if !status.success() {
        return Err(ExportError::FailedToRunDot(status));
    }
    // `export_to_png(state, "graph.dot", true)` derives a DOT path equal to the
    // output path; removing it would throw away the image just rendered.
    if cleanup_dot && dot_path != output_path {
        std::fs::remove_file(dot_path)?;
    }
    Ok(())
}
/// Renders `state` as a PNG image at `output_path` using Graphviz.
///
/// The DOT source is first written beside the output (same path, `.dot`
/// extension) and then converted by the external `dot` tool. With
/// `cleanup_dot` set, the intermediate DOT file is removed after a successful
/// conversion.
///
/// # Errors
/// Propagates failures from writing the DOT file or from running `dot`.
pub fn export_to_png(
    state: &CFRState,
    output_path: &Path,
    cleanup_dot: bool,
) -> Result<(), ExportError> {
    let intermediate_dot = output_path.with_extension("dot");
    export_to_dot(state, &intermediate_dot).and_then(|()| {
        convert_with_graphviz(&intermediate_dot, output_path, "png", cleanup_dot)
    })
}
/// Renders `state` as an SVG image at `output_path` using Graphviz.
///
/// The DOT source is first written beside the output (same path, `.dot`
/// extension) and then converted by the external `dot` tool. With
/// `cleanup_dot` set, the intermediate DOT file is removed after a successful
/// conversion.
///
/// # Errors
/// Propagates failures from writing the DOT file or from running `dot`.
pub fn export_to_svg(
    state: &CFRState,
    output_path: &Path,
    cleanup_dot: bool,
) -> Result<(), ExportError> {
    let intermediate_dot = output_path.with_extension("dot");
    export_to_dot(state, &intermediate_dot).and_then(|()| {
        convert_with_graphviz(&intermediate_dot, output_path, "svg", cleanup_dot)
    })
}
/// Exports `state` in the requested `format`.
///
/// * `ExportFormat::Dot` writes DOT text to exactly `output_path`.
/// * `ExportFormat::Png` / `ExportFormat::Svg` render to `output_path` with
///   `cleanup_dot` enabled, so the intermediate `.dot` file is removed after a
///   successful render.
/// * `ExportFormat::All` treats `output_path` as a base name. Its extension is
///   replaced with `dot`, `png` and `svg`, and all three files are kept.
///
/// # Errors
/// Returns the first failure from writing the DOT file or running Graphviz.
/// For `All`, a failed PNG conversion means the SVG is never attempted.
pub fn export_cfr_state(
    state: &CFRState,
    output_path: &Path,
    format: ExportFormat,
) -> Result<(), ExportError> {
    let remove_intermediate_dot = true;
    match format {
        ExportFormat::Dot => export_to_dot(state, output_path),
        ExportFormat::Png => export_to_png(state, output_path, remove_intermediate_dot),
        ExportFormat::Svg => export_to_svg(state, output_path, remove_intermediate_dot),
        ExportFormat::All => {
            // One shared DOT source feeds both renders and is kept afterwards.
            let shared_dot = output_path.with_extension("dot");
            export_to_dot(state, &shared_dot)?;
            for extension in ["png", "svg"] {
                let target = output_path.with_extension(extension);
                convert_with_graphviz(&shared_dot, &target, extension, false)?;
            }
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    // Tests for DOT generation, file export and format parsing.
    use super::*;
    use crate::arena::GameStateBuilder;
    use crate::arena::cfr::{CFRState, NodeData, PlayerData, TerminalData};
    use std::fs;

    /// Builds a small two-player tree that contains every node kind:
    ///
    /// - root -> P0 { fold: terminal, call: chance, raise: chance }
    /// - call chance -> three P1 nodes, each with fold/call terminals
    /// - raise chance -> P1 { fold: terminal, call: chance -> terminal }
    fn create_test_cfr_state() -> CFRState {
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let cfr_state = CFRState::new(game_state);
        let player0_node = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 0,
        });
        let player0_idx = cfr_state.add(0, 0, player0_node);
        let terminal_fold = NodeData::Terminal(TerminalData::new(-10.0));
        let _fold_idx = cfr_state.add(player0_idx, 0, terminal_fold);
        let player0_call = cfr_state.add(player0_idx, 1, NodeData::Chance);
        let player0_raise = cfr_state.add(player0_idx, 2, NodeData::Chance);
        for i in 0..3 {
            let player1_node = NodeData::Player(PlayerData {
                regret_matcher: None,
                player_idx: 1,
            });
            let player1_idx = cfr_state.add(player0_call, i, player1_node);
            let p1_fold_terminal = NodeData::Terminal(TerminalData::new(15.0));
            cfr_state.add(player1_idx, 0, p1_fold_terminal);
            let p1_call_terminal = NodeData::Terminal(TerminalData::new(5.0));
            cfr_state.add(player1_idx, 1, p1_call_terminal);
        }
        let player1_vs_raise = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 1,
        });
        let player1_vs_raise_idx = cfr_state.add(player0_raise, 0, player1_vs_raise);
        let p1_fold_vs_raise = NodeData::Terminal(TerminalData::new(20.0));
        cfr_state.add(player1_vs_raise_idx, 0, p1_fold_vs_raise);
        let chance_after_call_vs_raise = cfr_state.add(player1_vs_raise_idx, 1, NodeData::Chance);
        let final_terminal = NodeData::Terminal(TerminalData::new(30.0));
        cfr_state.add(chance_after_call_vs_raise, 0, final_terminal);
        cfr_state
    }

    #[test]
    fn test_export_to_dot_creates_file() {
        let cfr_state = create_test_cfr_state();
        let temp_dir = tempfile::tempdir().unwrap();
        let output_path = temp_dir.path().join("test_export.dot");
        let result = export_to_dot(&cfr_state, &output_path);
        assert!(
            result.is_ok(),
            "Failed to export to DOT: {:?}",
            result.err()
        );
        assert!(output_path.exists(), "DOT file was not created");
        temp_dir.close().unwrap();
    }

    #[test]
    fn test_different_node_types_displayed_correctly() {
        let cfr_state = create_test_cfr_state();
        let temp_dir = tempfile::tempdir().unwrap();
        let output_path = temp_dir.path().join("test_node_types.dot");
        export_to_dot(&cfr_state, &output_path).unwrap();
        let content = fs::read_to_string(&output_path).unwrap();
        assert!(
            content.contains("Root Node"),
            "Root node not properly labeled"
        );
        assert!(
            content.contains("lightblue"),
            "Root node not properly colored"
        );
        assert!(
            content.contains("Player 0") || content.contains("Player 1"),
            "Player node not properly labeled"
        );
        assert!(
            content.contains("coral"),
            "Player node not properly colored"
        );
        assert!(
            content.contains("Chance Node"),
            "Chance node not properly labeled"
        );
        assert!(
            content.contains("lightgreen"),
            "Chance node not properly colored"
        );
        assert!(
            content.contains("Terminal Node"),
            "Terminal node not properly labeled"
        );
        assert!(
            content.contains("Utility"),
            "Terminal node utility not displayed"
        );
        assert!(
            content.contains("lightgrey"),
            "Terminal node not properly colored"
        );
        assert!(content.contains("Fold"), "Fold action not properly labeled");
        assert!(
            content.contains("Check/Call"),
            "Call action not properly labeled"
        );
        assert!(
            content.contains("Bet/Raise"),
            "Raise action not properly labeled"
        );
        assert!(
            content.contains("label="),
            "Edge labels not properly displayed"
        );
        temp_dir.close().unwrap();
    }

    #[test]
    fn test_export_creates_different_formats() {
        // Rendering needs the Graphviz `dot` binary; skip when it cannot be spawned.
        if std::process::Command::new("dot")
            .arg("-V")
            .status()
            .is_err()
        {
            println!("Skipping test_export_creates_different_formats - Graphviz not installed");
            return;
        }
        let cfr_state = create_test_cfr_state();
        let temp_dir = tempfile::tempdir().unwrap();

        let dot_path = temp_dir.path().join("test.dot");
        let dot_result = export_to_dot(&cfr_state, &dot_path);
        assert!(
            dot_result.is_ok(),
            "DOT export failed: {:?}",
            dot_result.err()
        );
        assert!(dot_path.exists(), "DOT file was not created");

        let png_path = temp_dir.path().join("test.png");
        let png_result = export_to_png(&cfr_state, &png_path, true);
        assert!(
            png_result.is_ok(),
            "PNG export failed: {:?}",
            png_result.err()
        );
        assert!(png_path.exists(), "PNG file was not created");

        let svg_path = temp_dir.path().join("test.svg");
        let svg_result = export_to_svg(&cfr_state, &svg_path, true);
        assert!(
            svg_result.is_ok(),
            "SVG export failed: {:?}",
            svg_result.err()
        );
        assert!(svg_path.exists(), "SVG file was not created");

        let all_base_path = temp_dir.path().join("test_all");
        let all_result = export_cfr_state(&cfr_state, &all_base_path, ExportFormat::All);
        assert!(
            all_result.is_ok(),
            "All formats export failed: {:?}",
            all_result.err()
        );
        let all_dot_path = all_base_path.with_extension("dot");
        let all_png_path = all_base_path.with_extension("png");
        let all_svg_path = all_base_path.with_extension("svg");

        // List the directory *before* asserting. A failing assert ends the
        // test, so a listing placed after the asserts could never run when
        // it is actually needed.
        if !all_dot_path.exists() || !all_png_path.exists() || !all_svg_path.exists() {
            println!("Directory contents:");
            if let Ok(entries) = fs::read_dir(temp_dir.path()) {
                for entry in entries.flatten() {
                    println!(" {:?}", entry.path());
                }
            }
        }
        assert!(
            all_dot_path.exists(),
            "DOT file not created in 'all' format at {all_dot_path:?}"
        );
        assert!(
            all_png_path.exists(),
            "PNG file not created in 'all' format at {all_png_path:?}"
        );
        assert!(
            all_svg_path.exists(),
            "SVG file not created in 'all' format at {all_svg_path:?}"
        );
        temp_dir.close().unwrap();
    }

    #[test]
    fn test_invalid_format_returns_error() {
        let err = ExportFormat::from_str("invalid_format")
            .expect_err("Should error on invalid format string");
        assert!(
            err.to_string().contains("Invalid export format"),
            "Error message should mention invalid format"
        );
    }

    #[test]
    fn test_format_parsing_is_case_insensitive() {
        assert_eq!(ExportFormat::from_str("dot").unwrap(), ExportFormat::Dot);
        assert_eq!(ExportFormat::from_str("PNG").unwrap(), ExportFormat::Png);
        assert_eq!(ExportFormat::from_str("Svg").unwrap(), ExportFormat::Svg);
        assert_eq!(ExportFormat::from_str("ALL").unwrap(), ExportFormat::All);
    }

    #[test]
    fn test_player_seat_labeling() {
        let temp_dir = tempfile::tempdir().unwrap();
        let output_path = temp_dir.path().join("player_seats.dot");
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let cfr_state = CFRState::new(game_state);
        let player0_node = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 0,
        });
        let player0_idx = cfr_state.add(0, 0, player0_node);
        let player1_node = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 1,
        });
        let _player1_idx = cfr_state.add(player0_idx, 1, player1_node);
        export_to_dot(&cfr_state, &output_path).unwrap();
        let dot_content = fs::read_to_string(&output_path).unwrap();
        assert!(dot_content.contains("Player 0 Node"));
        assert!(dot_content.contains("Player 1 Node"));
        temp_dir.close().unwrap();
    }

    #[test]
    fn test_generate_dot_output() {
        let cfr_state = create_test_cfr_state();
        let dot_content = generate_dot(&cfr_state);
        println!("Generated DOT content:\n{dot_content}");
        assert!(
            dot_content.starts_with("digraph CFRTree {"),
            "Missing graph header"
        );
        assert!(dot_content.ends_with("}\n"), "Missing graph closing");
        assert!(
            dot_content.contains(&format!("fontname=\"{DEFAULT_FONT}\"")),
            "Missing font settings"
        );
        assert!(
            dot_content.contains("fillcolor=\"lightblue\""),
            "Missing root node style"
        );
        assert!(
            dot_content.contains("fillcolor=\"lightgreen\""),
            "Missing chance node style"
        );
        assert!(
            dot_content.contains("fillcolor=\"coral\""),
            "Missing player node style"
        );
        assert!(
            dot_content.contains("fillcolor=\"lightgrey\""),
            "Missing terminal node style"
        );
        assert!(dot_content.contains("Root Node"), "Missing root node label");
        assert!(
            dot_content.contains("Player 0 Node"),
            "Missing player 0 label"
        );
        assert!(
            dot_content.contains("Player 1 Node"),
            "Missing player 1 label"
        );
        assert!(
            dot_content.contains("Terminal Node"),
            "Missing terminal node label"
        );
        assert!(dot_content.contains("Utility:"), "Missing utility value");
        assert!(
            dot_content.contains("label=\"Fold\""),
            "Missing fold action label"
        );
        assert!(
            dot_content.contains("label=\"Check/Call\""),
            "Missing call action label"
        );
        assert!(
            dot_content.contains("label=\"Bet/Raise 1\""),
            "Missing raise action label"
        );
    }

    #[test]
    fn test_dot_generation_matches_file_output() {
        let cfr_state = create_test_cfr_state();
        let dot_content = generate_dot(&cfr_state);
        let temp_dir = tempfile::tempdir().unwrap();
        let output_path = temp_dir.path().join("test_match.dot");
        export_to_dot(&cfr_state, &output_path).unwrap();
        let file_content = fs::read_to_string(&output_path).unwrap();
        assert_eq!(
            dot_content, file_content,
            "Generated DOT content should match file output"
        );
        temp_dir.close().unwrap();
    }
}