treant 0.4.0

High-performance, lock-free Monte Carlo Tree Search library for Rust.
Documentation
import { useCallback, useEffect, useRef, useState } from 'react';
import BrowserOnly from '@docusaurus/BrowserOnly';
import styles from './demos.module.css';

/** Statistics for a single move out of the search root. */
interface ChildStat {
  /** Move label as produced by the WASM game. */
  mov: string;
  /** Visit count for this move. */
  visits: number;
  /** Average reward recorded for this move. */
  avg_reward: number;
}

/** Search summary returned by the game's `get_stats()` and shown in the stats panel. */
interface SearchStats {
  total_playouts: number;
  total_nodes: number;
  /** Optional: may be absent (presumably before any playouts — TODO confirm). */
  best_move?: string;
  children: Array<ChildStat>;
}

/** Recursive tree snapshot returned by the game's `get_tree(depth)` for visualisation. */
interface TreeNode {
  visits: number;
  avg_reward: number;
  /**
   * Outgoing edges. Each edge carries the same fields as a `ChildStat`, plus an
   * optional `child` subtree (presumably omitted past the requested depth — TODO confirm).
   */
  children: Array<ChildStat & { child?: TreeNode }>;
}

/** Depth to which the search tree is fetched from WASM and rendered (kept in one place). */
const TREE_DEPTH = 5;

/**
 * Minimal view of the WASM `CountingGameWasm` handle, declaring only the
 * methods this demo calls. Return shapes are assumed to match `SearchStats` /
 * `TreeNode` — NOTE(review): confirm against the generated bindings.
 */
interface CountingGameHandle {
  free(): void;
  playout_n(n: number): unknown;
  get_stats(): SearchStats;
  get_tree(maxDepth: number): TreeNode;
}

/**
 * Interactive step-through demo of the counting-game search.
 *
 * Keeps a single WASM game handle in a ref (so re-renders do not recreate it),
 * lets the user run playouts one at a time or in batches, and renders the
 * resulting statistics and tree. Changing the exploration constant or pressing
 * reset discards the current search and creates a fresh game.
 *
 * Reads WASM state through `useWasm`, so it expects to be rendered inside
 * `WasmProvider` (as the default export below does).
 */
function StepThroughDemoInner() {
  // Required lazily rather than imported, so these modules are only evaluated
  // when this component renders, i.e. inside the BrowserOnly wrapper.
  const { useWasm } = require('../treant/WasmProvider');
  const TreeVisualization = require('../treant/TreeVisualization').default;
  const StatsPanel = require('../treant/StatsPanel').default;
  const ParameterControls = require('../treant/ParameterControls').default;
  const PlaybackControls = require('../treant/PlaybackControls').default;

  const { wasm, ready, error } = useWasm();
  const gameRef = useRef<CountingGameHandle | null>(null);
  const [c, setC] = useState(2.0);
  const [stats, setStats] = useState<SearchStats | null>(null);
  const [tree, setTree] = useState<TreeNode | null>(null);

  /** Replaces the current game with a fresh one using the given exploration constant. */
  const createGame = useCallback(
    (exploration: number) => {
      if (!wasm) return;
      // Drop the old handle and clear the ref *before* constructing the new one,
      // so a throwing constructor can never leave the ref pointing at a freed object.
      gameRef.current?.free();
      gameRef.current = null;
      setStats(null);
      setTree(null);
      gameRef.current = new wasm.CountingGameWasm(exploration);
    },
    [wasm],
  );

  useEffect(() => {
    if (ready) {
      createGame(c);
    }
    return () => {
      gameRef.current?.free();
      gameRef.current = null;
    };
    // Keyed on `ready` only: parameter changes recreate the game through
    // handleParamChange / handleReset, not through this effect.
  }, [ready]); // eslint-disable-line react-hooks/exhaustive-deps

  /** Pulls fresh stats and a depth-limited tree snapshot from the current game. */
  const refresh = useCallback(() => {
    const game = gameRef.current;
    if (!game) return;
    const nextStats = game.get_stats();
    const nextTree = game.get_tree(TREE_DEPTH);
    setStats(nextStats);
    setTree(nextTree);
  }, []);

  /** Runs `n` playouts on the current game, then refreshes the view. */
  const handleRun = useCallback(
    (n: number) => {
      const game = gameRef.current;
      if (!game) return;
      game.playout_n(n);
      refresh();
    },
    [refresh],
  );

  const handleStep = useCallback(() => {
    handleRun(1);
  }, [handleRun]);

  const handleReset = useCallback(() => {
    createGame(c);
  }, [createGame, c]);

  // Only one parameter (`c`) is exposed, so the key is ignored.
  const handleParamChange = useCallback(
    (_key: string, value: number) => {
      setC(value);
      createGame(value);
    },
    [createGame],
  );

  if (error) {
    // String() leaves string errors unchanged and avoids React throwing if an
    // Error object is ever supplied (objects are not valid React children).
    return <div className={styles.error}>Failed to load WASM: {String(error)}</div>;
  }

  if (!ready) {
    return <div className={styles.loading}>Loading...</div>;
  }

  return (
    <div className={styles.demo}>
      <div className={styles.section}>
        <ParameterControls
          params={{
            c: { label: 'C (exploration)', value: c, min: 0.1, max: 5.0, step: 0.1 },
          }}
          onChange={handleParamChange}
        />
      </div>

      <div className={styles.section}>
        <PlaybackControls
          onStep={handleStep}
          onRun={handleRun}
          onReset={handleReset}
          batchSizes={[10, 100, 1000]}
        />
      </div>

      {stats && (
        <div className={styles.section}>
          <StatsPanel
            totalPlayouts={stats.total_playouts}
            totalNodes={stats.total_nodes}
            bestMove={stats.best_move}
            children={stats.children}
          />
        </div>
      )}

      {tree && tree.visits > 0 && (
        <div className={styles.section}>
          <TreeVisualization tree={tree} maxDepth={TREE_DEPTH} />
        </div>
      )}
    </div>
  );
}

/**
 * Page-level entry point for the step-through demo.
 *
 * `BrowserOnly` shows the loading placeholder until it renders on the client.
 * Its render callback then requires the WASM provider lazily and mounts the
 * interactive demo inside it.
 */
export default function StepThroughDemo() {
  const placeholder = <div className={styles.loading}>Loading...</div>;

  function renderInBrowser() {
    const { WasmProvider } = require('../treant/WasmProvider');
    return (
      <WasmProvider>
        <StepThroughDemoInner />
      </WasmProvider>
    );
  }

  return <BrowserOnly fallback={placeholder}>{renderInBrowser}</BrowserOnly>;
}