import { useCallback, useEffect, useRef, useState } from 'react';
import BrowserOnly from '@docusaurus/BrowserOnly';
import styles from './demos.module.css';
interface ChildStat {
mov: string;
visits: number;
avg_reward: number;
prior?: number;
}
interface SearchStats {
total_playouts: number;
total_nodes: number;
best_move?: string;
children: ChildStat[];
}
function UCTvsPUCTDemoInner() {
const { useWasm } = require('../treant/WasmProvider');
const BarChart = require('../treant/BarChart').default;
const ParameterControls = require('../treant/ParameterControls').default;
const PlaybackControls = require('../treant/PlaybackControls').default;
const SideBySide = require('../treant/SideBySide').default;
const StatsPanel = require('../treant/StatsPanel').default;
const { wasm, ready, error } = useWasm();
const uctRef = useRef<any>(null);
const puctRef = useRef<any>(null);
const [uctC, setUctC] = useState(2.0);
const [puctC, setPuctC] = useState(1.5);
const [uctStats, setUctStats] = useState<SearchStats | null>(null);
const [puctStats, setPuctStats] = useState<SearchStats | null>(null);
const createBoth = useCallback(
(uc: number, pc: number) => {
if (!wasm) return;
if (uctRef.current) uctRef.current.free();
if (puctRef.current) puctRef.current.free();
uctRef.current = new wasm.PriorGameUctWasm(uc);
puctRef.current = new wasm.PriorGamePuctWasm(pc);
setUctStats(null);
setPuctStats(null);
},
[wasm],
);
useEffect(() => {
if (ready) {
createBoth(uctC, puctC);
}
return () => {
if (uctRef.current) {
uctRef.current.free();
uctRef.current = null;
}
if (puctRef.current) {
puctRef.current.free();
puctRef.current = null;
}
};
}, [ready]); // eslint-disable-line react-hooks/exhaustive-deps
const refresh = useCallback(() => {
if (uctRef.current) {
setUctStats(uctRef.current.get_stats());
}
if (puctRef.current) {
setPuctStats(puctRef.current.get_stats());
}
}, []);
const handleStep = useCallback(() => {
if (uctRef.current) uctRef.current.playout_n(1);
if (puctRef.current) puctRef.current.playout_n(1);
refresh();
}, [refresh]);
const handleRun = useCallback(
(n: number) => {
if (uctRef.current) uctRef.current.playout_n(n);
if (puctRef.current) puctRef.current.playout_n(n);
refresh();
},
[refresh],
);
const handleReset = useCallback(() => {
createBoth(uctC, puctC);
}, [createBoth, uctC, puctC]);
const handleParamChange = useCallback(
(key: string, value: number) => {
if (key === 'uctC') {
setUctC(value);
createBoth(value, puctC);
} else {
setPuctC(value);
createBoth(uctC, value);
}
},
[createBoth, uctC, puctC],
);
if (error) {
return <div className={styles.error}>Failed to load WASM: {error}</div>;
}
if (!ready) {
return <div className={styles.loading}>Loading...</div>;
}
const toUctBarItems = (s: SearchStats | null) =>
s?.children.map((c) => ({
label: c.mov,
value: c.visits,
secondary: c.avg_reward,
})) ?? [];
const toPuctBarItems = (s: SearchStats | null) =>
s?.children.map((c) => ({
label: c.prior != null ? `${c.mov} (p=${c.prior.toFixed(1)})` : c.mov,
value: c.visits,
secondary: c.avg_reward,
})) ?? [];
const maxVisits = Math.max(
...(uctStats?.children.map((c) => c.visits) ?? [1]),
...(puctStats?.children.map((c) => c.visits) ?? [1]),
);
return (
<div className={styles.demo}>
<div className={styles.section}>
<ParameterControls
params={{
uctC: { label: 'UCT C', value: uctC, min: 0.1, max: 5.0, step: 0.1 },
puctC: { label: 'PUCT C', value: puctC, min: 0.1, max: 5.0, step: 0.1 },
}}
onChange={handleParamChange}
/>
</div>
<div className={styles.section}>
<PlaybackControls
onStep={handleStep}
onRun={handleRun}
onReset={handleReset}
batchSizes={[10, 100, 1000]}
/>
</div>
<div className={styles.section}>
<SideBySide
leftLabel="UCT (no priors)"
rightLabel="PUCT (with priors)"
left={
<>
<BarChart items={toUctBarItems(uctStats)} maxValue={maxVisits} />
{uctStats && (
<StatsPanel
totalPlayouts={uctStats.total_playouts}
totalNodes={uctStats.total_nodes}
bestMove={uctStats.best_move}
children={uctStats.children}
/>
)}
</>
}
right={
<>
<BarChart items={toPuctBarItems(puctStats)} maxValue={maxVisits} />
{puctStats && (
<StatsPanel
totalPlayouts={puctStats.total_playouts}
totalNodes={puctStats.total_nodes}
bestMove={puctStats.best_move}
children={puctStats.children}
/>
)}
</>
}
/>
</div>
</div>
);
}
export default function UCTvsPUCTDemo() {
return (
<BrowserOnly fallback={<div className={styles.loading}>Loading...</div>}>
{() => {
const { WasmProvider } = require('../treant/WasmProvider');
return (
<WasmProvider>
<UCTvsPUCTDemoInner />
</WasmProvider>
);
}}
</BrowserOnly>
);
}