use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::process::Command;
use std::str::FromStr;
use crate::arena::errors::ExportError;
use crate::core::Card;
use super::{CFRState, NodeData};
const DEFAULT_FONT: &str = "Arial";
const COLOR_ROOT: &str = "lightblue"; const COLOR_CHANCE: &str = "lightgreen"; const COLOR_PLAYER: &str = "coral"; const COLOR_TERMINAL: &str = "lightgrey";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
Dot,
Png,
Svg,
All,
}
impl FromStr for ExportFormat {
type Err = ExportError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"dot" => Ok(ExportFormat::Dot),
"png" => Ok(ExportFormat::Png),
"svg" => Ok(ExportFormat::Svg),
"all" => Ok(ExportFormat::All),
_ => Err(ExportError::InvalidExportFormat(s.to_string())),
}
}
}
pub fn generate_dot(state: &CFRState) -> String {
let mut output = String::new();
output.push_str("digraph CFRTree {\n");
output.push_str(" // Graph styling\n");
output.push_str(" graph [rankdir=TB, splines=polyline, nodesep=1.0, ranksep=1.2, concentrate=true, compound=true];\n");
output.push_str(&format!(
" node [shape=box, style=\"rounded,filled\", fontname=\"{}\", margin=0.2];\n",
DEFAULT_FONT
));
output.push_str(&format!(" edge [fontname=\"{}\", penwidth=1.0, labelangle=25, labeldistance=1.8, labelfloat=true];\n", DEFAULT_FONT));
output.push_str(" // Add legend\n");
output.push_str(" subgraph cluster_legend {\n");
output.push_str(" graph [rank=sink];\n");
output.push_str(" label=\"Legend\";\n");
output.push_str(" style=rounded;\n");
output.push_str(" color=gray;\n");
output.push_str(" margin=16;\n");
output.push_str(" node [shape=plaintext, style=\"\"];\n");
output.push_str(" legend [label=<\n");
output.push_str(" <table border=\"0\" cellborder=\"0\" cellspacing=\"2\">\n");
output.push_str(" <tr><td align=\"left\"><b>Node Types:</b></td></tr>\n");
output.push_str(
" <tr><td align=\"left\">• Root (⬢): Light Blue - Starting state</td></tr>\n",
);
output.push_str(
" <tr><td align=\"left\">• Player (□): Coral - Decision points</td></tr>\n",
);
output.push_str(
" <tr><td align=\"left\">• Chance (○): Light Green - Card deals</td></tr>\n",
);
output.push_str(
" <tr><td align=\"left\">• Terminal (⬡): Light Grey - Final states</td></tr>\n",
);
output.push_str(" <tr><td><br/></td></tr>\n");
output.push_str(" <tr><td align=\"left\"><b>Edge Properties:</b></td></tr>\n");
output.push_str(" <tr><td align=\"left\">• Thickness: Usage frequency</td></tr>\n");
output.push_str(" <tr><td align=\"left\">• Labels: Action/Card</td></tr>\n");
output.push_str(" <tr><td align=\"left\">• Percent: Visit frequency</td></tr>\n");
output.push_str(" </table>\n");
output.push_str(" >];\n");
output.push_str(" }\n\n");
output.push_str(" // Node grouping\n");
output.push_str(" {rank=source; node_0;}\n");
let inner_state = state.internal_state();
let nodes = &inner_state.borrow().nodes;
for node in nodes {
let (color, shape, style) = match &node.data {
NodeData::Root => (COLOR_ROOT, "doubleoctagon", "filled"),
NodeData::Chance => (COLOR_CHANCE, "ellipse", "filled"),
NodeData::Player(_) => (COLOR_PLAYER, "box", "rounded,filled"),
NodeData::Terminal(_) => (COLOR_TERMINAL, "hexagon", "filled"),
};
let total_visits: u32 = (0..52).map(|i| node.get_count(i)).sum();
let label = match &node.data {
NodeData::Root => format!(
"Root Node\\nIndex: {}\\nTotal Visits: {}",
node.idx, total_visits
),
NodeData::Chance => format!(
"Chance Node\\nIndex: {}\\nTotal Visits: {}",
node.idx, total_visits
),
NodeData::Player(player_data) => {
let player_seat = player_data.player_idx;
format!(
"Player {} Node\\nIndex: {}\\nTotal Visits: {}",
player_seat, node.idx, total_visits
)
}
NodeData::Terminal(td) => format!(
"Terminal Node\\nIndex: {}\\nUtility: {:.2}\\nVisits: {}",
node.idx, td.total_utility, total_visits
),
};
let tooltip = match &node.data {
NodeData::Terminal(td) => format!(
"Average Utility: {:.2}",
if total_visits > 0 {
td.total_utility / total_visits as f32
} else {
0.0
}
),
_ => {
let (most_common_idx, most_common_count) = (0..52)
.map(|i| (i, node.get_count(i)))
.max_by_key(|&(_, count)| count)
.unwrap_or((0, 0));
format!(
"Most Common Action: {}\\nAction Frequency: {:.1}%",
most_common_idx,
if total_visits > 0 {
(most_common_count as f32 / total_visits as f32) * 100.0
} else {
0.0
}
)
}
};
output.push_str(&format!(
" node_{} [label=\"{}\", shape={}, style=\"{}\", fillcolor=\"{}\", tooltip=\"{}\"];\n",
node.idx, label, shape, style, color, tooltip
));
let total_count: u32 = (0..52).map(|i| node.get_count(i)).sum();
if let NodeData::Player(_) = node.data {
output.push_str(&format!(
" {{rank=same; node_{};}} // Group player nodes\n",
node.idx
));
}
for (child_idx, child_node_idx) in node.iter_children() {
let edge_label = match &node.data {
NodeData::Chance => Card::from(child_idx as u8).to_string(),
NodeData::Player(_) => {
if child_idx == 0 {
"Fold".to_string()
} else if child_idx == 1 {
"Check/Call".to_string()
} else {
format!("Bet/Raise {}", child_idx - 1)
}
}
_ => format!("{}", child_idx),
};
let count = node.get_count(child_idx);
let edge_style = if total_count > 0 {
let percentage = (count as f32 / total_count as f32) * 100.0;
let penwidth = 1.0 + (percentage / 10.0).min(8.0);
let color = format!(
"#{:02X}{:02X}FF",
(155.0 + percentage).min(255.0) as u8,
(155.0 + percentage).min(255.0) as u8
);
format!(
" [label=\"{}\", penwidth={}, color=\"{}\", tooltip=\"Frequency: {:.1}%\", xlabel=\"{:.0}%\", weight={}]",
edge_label,
penwidth,
color,
percentage,
percentage,
if percentage > 0.0 {
percentage as u32
} else {
1
}
)
} else {
format!(" [label=\"{}\", weight=1]", edge_label)
};
output.push_str(&format!(
" node_{} -> node_{}{}\n",
node.idx, child_node_idx, edge_style
));
}
}
output.push_str("}\n");
output
}
pub fn export_to_dot(state: &CFRState, output_path: &Path) -> Result<(), ExportError> {
let dot_content = generate_dot(state);
let mut file = File::create(output_path)?;
Ok(file.write_all(dot_content.as_bytes())?)
}
fn convert_with_graphviz(
dot_path: &Path,
output_path: &Path,
format: &str,
cleanup_dot: bool,
) -> Result<(), ExportError> {
let status = Command::new("dot")
.arg(format!("-T{}", format))
.arg(dot_path)
.arg("-o")
.arg(output_path)
.status()?;
if !status.success() {
return Err(ExportError::FailedToRunDot(status));
}
if cleanup_dot {
std::fs::remove_file(dot_path)?;
}
Ok(())
}
pub fn export_to_png(
state: &CFRState,
output_path: &Path,
cleanup_dot: bool,
) -> Result<(), ExportError> {
let dot_path = output_path.with_extension("dot");
export_to_dot(state, &dot_path)?;
convert_with_graphviz(&dot_path, output_path, "png", cleanup_dot)
}
pub fn export_to_svg(
state: &CFRState,
output_path: &Path,
cleanup_dot: bool,
) -> Result<(), ExportError> {
let dot_path = output_path.with_extension("dot");
export_to_dot(state, &dot_path)?;
convert_with_graphviz(&dot_path, output_path, "svg", cleanup_dot)
}
pub fn export_cfr_state(
state: &CFRState,
output_path: &Path,
format: ExportFormat,
) -> Result<(), ExportError> {
match format {
ExportFormat::Dot => export_to_dot(state, output_path),
ExportFormat::Png => export_to_png(state, output_path, true),
ExportFormat::Svg => export_to_svg(state, output_path, true),
ExportFormat::All => {
let dot_path = output_path.with_extension("dot");
let png_path = output_path.with_extension("png");
let svg_path = output_path.with_extension("svg");
export_to_dot(state, &dot_path)?;
convert_with_graphviz(&dot_path, &png_path, "png", false)?;
convert_with_graphviz(&dot_path, &svg_path, "svg", false)?;
Ok(())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::arena::GameState;
use crate::arena::cfr::{CFRState, NodeData, PlayerData, TerminalData};
use std::fs;
fn create_test_cfr_state() -> CFRState {
let game_state = GameState::new_starting(vec![100.0; 2], 10.0, 5.0, 0.0, 0);
let mut cfr_state = CFRState::new(game_state);
let player0_node = NodeData::Player(PlayerData {
regret_matcher: None,
player_idx: 0,
});
let player0_idx = cfr_state.add(0, 0, player0_node);
let terminal_fold = NodeData::Terminal(TerminalData::new(-10.0));
let _fold_idx = cfr_state.add(player0_idx, 0, terminal_fold);
let player0_call = cfr_state.add(player0_idx, 1, NodeData::Chance);
let player0_raise = cfr_state.add(player0_idx, 2, NodeData::Chance);
for i in 0..3 {
let player1_node = NodeData::Player(PlayerData {
regret_matcher: None,
player_idx: 1,
});
let player1_idx = cfr_state.add(player0_call, i, player1_node);
let p1_fold_terminal = NodeData::Terminal(TerminalData::new(15.0));
cfr_state.add(player1_idx, 0, p1_fold_terminal);
let p1_call_terminal = NodeData::Terminal(TerminalData::new(5.0));
cfr_state.add(player1_idx, 1, p1_call_terminal);
}
let player1_vs_raise = NodeData::Player(PlayerData {
regret_matcher: None,
player_idx: 1,
});
let player1_vs_raise_idx = cfr_state.add(player0_raise, 0, player1_vs_raise);
let p1_fold_vs_raise = NodeData::Terminal(TerminalData::new(20.0));
cfr_state.add(player1_vs_raise_idx, 0, p1_fold_vs_raise);
let chance_after_call_vs_raise = cfr_state.add(player1_vs_raise_idx, 1, NodeData::Chance);
let final_terminal = NodeData::Terminal(TerminalData::new(30.0));
cfr_state.add(chance_after_call_vs_raise, 0, final_terminal);
if let Some(mut node) = cfr_state.get_mut(player0_idx) {
node.increment_count(1); node.increment_count(2); node.increment_count(2);
}
cfr_state
}
#[test]
fn test_export_to_dot_creates_file() {
let cfr_state = create_test_cfr_state();
let temp_dir = tempfile::tempdir().unwrap();
let output_path = temp_dir.path().join("test_export.dot");
let result = export_to_dot(&cfr_state, &output_path);
assert!(
result.is_ok(),
"Failed to export to DOT: {:?}",
result.err()
);
assert!(output_path.exists(), "DOT file was not created");
temp_dir.close().unwrap();
}
#[test]
fn test_different_node_types_displayed_correctly() {
let cfr_state = create_test_cfr_state();
let temp_dir = tempfile::tempdir().unwrap();
let output_path = temp_dir.path().join("test_node_types.dot");
export_to_dot(&cfr_state, &output_path).unwrap();
let content = fs::read_to_string(&output_path).unwrap();
assert!(
content.contains("Root Node"),
"Root node not properly labeled"
);
assert!(
content.contains("lightblue"),
"Root node not properly colored"
);
assert!(
content.contains("Player 0") || content.contains("Player 1"),
"Player node not properly labeled"
);
assert!(
content.contains("coral"),
"Player node not properly colored"
);
assert!(
content.contains("Chance Node"),
"Chance node not properly labeled"
);
assert!(
content.contains("lightgreen"),
"Chance node not properly colored"
);
assert!(
content.contains("Terminal Node"),
"Terminal node not properly labeled"
);
assert!(
content.contains("Utility"),
"Terminal node utility not displayed"
);
assert!(
content.contains("lightgrey"),
"Terminal node not properly colored"
);
assert!(content.contains("Fold"), "Fold action not properly labeled");
assert!(
content.contains("Check/Call"),
"Call action not properly labeled"
);
assert!(
content.contains("Bet/Raise"),
"Raise action not properly labeled"
);
assert!(
content.contains('%'),
"Action percentages not properly displayed"
);
temp_dir.close().unwrap();
}
#[test]
fn test_export_creates_different_formats() {
if std::process::Command::new("dot")
.arg("-V")
.status()
.is_err()
{
println!("Skipping test_export_creates_different_formats - Graphviz not installed");
return;
}
let cfr_state = create_test_cfr_state();
let temp_dir = tempfile::tempdir().unwrap();
let dot_path = temp_dir.path().join("test.dot");
let dot_result = export_to_dot(&cfr_state, &dot_path);
assert!(
dot_result.is_ok(),
"DOT export failed: {:?}",
dot_result.err()
);
assert!(dot_path.exists(), "DOT file was not created");
let png_path = temp_dir.path().join("test.png");
let png_result = export_to_png(&cfr_state, &png_path, true);
assert!(
png_result.is_ok(),
"PNG export failed: {:?}",
png_result.err()
);
assert!(png_path.exists(), "PNG file was not created");
let svg_path = temp_dir.path().join("test.svg");
let svg_result = export_to_svg(&cfr_state, &svg_path, true);
assert!(
svg_result.is_ok(),
"SVG export failed: {:?}",
svg_result.err()
);
assert!(svg_path.exists(), "SVG file was not created");
let all_base_path = temp_dir.path().join("test_all");
let all_result = export_cfr_state(&cfr_state, &all_base_path, ExportFormat::All);
assert!(
all_result.is_ok(),
"All formats export failed: {:?}",
all_result.err()
);
let all_dot_path = all_base_path.with_extension("dot");
let all_png_path = all_base_path.with_extension("png");
let all_svg_path = all_base_path.with_extension("svg");
assert!(
all_dot_path.exists(),
"DOT file not created in 'all' format at {:?}",
all_dot_path
);
assert!(
all_png_path.exists(),
"PNG file not created in 'all' format at {:?}",
all_png_path
);
assert!(
all_svg_path.exists(),
"SVG file not created in 'all' format at {:?}",
all_svg_path
);
if !all_dot_path.exists() || !all_png_path.exists() || !all_svg_path.exists() {
println!("Directory contents:");
if let Ok(entries) = std::fs::read_dir(temp_dir.path()) {
for entry in entries.flatten() {
println!(" {:?}", entry.path());
}
}
}
temp_dir.close().unwrap();
}
#[test]
fn test_invalid_format_returns_error() {
let _cfr_state = create_test_cfr_state();
let temp_dir = tempfile::tempdir().unwrap();
let _invalid_path = temp_dir.path().join("invalid_format");
let result = ExportFormat::from_str("invalid_format");
assert!(result.is_err(), "Should error on invalid format string");
if let Err(e) = result {
assert!(
e.to_string().contains("Invalid export format"),
"Error message should mention invalid format"
);
}
temp_dir.close().unwrap();
}
#[test]
fn test_player_seat_labeling() {
let temp_dir = tempfile::tempdir().unwrap();
let output_path = temp_dir.path().join("player_seats.dot");
let game_state = GameState::new_starting(vec![100.0; 2], 10.0, 5.0, 0.0, 0);
let mut cfr_state = CFRState::new(game_state);
let player0_node = NodeData::Player(PlayerData {
regret_matcher: None,
player_idx: 0,
});
let player0_idx = cfr_state.add(0, 0, player0_node.clone());
let player1_node = NodeData::Player(PlayerData {
regret_matcher: None,
player_idx: 1,
});
let _player1_idx = cfr_state.add(player0_idx, 1, player1_node);
export_to_dot(&cfr_state, &output_path).unwrap();
let dot_content = fs::read_to_string(&output_path).unwrap();
assert!(dot_content.contains("Player 0 Node"));
assert!(dot_content.contains("Player 1 Node"));
}
#[test]
fn test_generate_dot_output() {
let cfr_state = create_test_cfr_state();
let dot_content = generate_dot(&cfr_state);
println!("Generated DOT content:\n{}", dot_content);
assert!(
dot_content.starts_with("digraph CFRTree {"),
"Missing graph header"
);
assert!(dot_content.ends_with("}\n"), "Missing graph closing");
assert!(
dot_content.contains(&format!("fontname=\"{}\"", DEFAULT_FONT)),
"Missing font settings"
);
assert!(
dot_content.contains("fillcolor=\"lightblue\""),
"Missing root node style"
);
assert!(
dot_content.contains("fillcolor=\"lightgreen\""),
"Missing chance node style"
);
assert!(
dot_content.contains("fillcolor=\"coral\""),
"Missing player node style"
);
assert!(
dot_content.contains("fillcolor=\"lightgrey\""),
"Missing terminal node style"
);
assert!(dot_content.contains("Root Node"), "Missing root node label");
assert!(
dot_content.contains("Player 0 Node"),
"Missing player 0 label"
);
assert!(
dot_content.contains("Player 1 Node"),
"Missing player 1 label"
);
assert!(
dot_content.contains("Terminal Node"),
"Missing terminal node label"
);
assert!(dot_content.contains("Utility:"), "Missing utility value");
assert!(
dot_content.contains("label=\"Fold\""),
"Missing fold action label"
);
assert!(
dot_content.contains("label=\"Check/Call\""),
"Missing call action label"
);
assert!(
dot_content.contains("xlabel=\"33%\""),
"Missing call percentage"
);
assert!(
dot_content.contains("label=\"Bet/Raise 1\""),
"Missing raise action label"
);
assert!(
dot_content.contains("xlabel=\"67%\""),
"Missing raise percentage"
);
assert!(
dot_content.contains("penwidth=4.333"),
"Edge thickness for call (33.3%) not correct"
);
assert!(
dot_content.contains("penwidth=7.666"),
"Edge thickness for raise (66.7%) not correct"
);
}
#[test]
fn test_dot_generation_matches_file_output() {
let cfr_state = create_test_cfr_state();
let dot_content = generate_dot(&cfr_state);
let temp_dir = tempfile::tempdir().unwrap();
let output_path = temp_dir.path().join("test_match.dot");
export_to_dot(&cfr_state, &output_path).unwrap();
let file_content = fs::read_to_string(&output_path).unwrap();
assert_eq!(
dot_content, file_content,
"Generated DOT content should match file output"
);
temp_dir.close().unwrap();
}
}