rustrees 0.2.4

Decision trees in Rust
Documentation
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a22a5ebb-54fe-431f-bc8c-667d36f6f798",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-03-06T20:47:04.654741Z",
     "iopub.status.busy": "2023-03-06T20:47:04.654422Z",
     "iopub.status.idle": "2023-03-06T20:47:04.659468Z",
     "shell.execute_reply": "2023-03-06T20:47:04.657869Z",
     "shell.execute_reply.started": "2023-03-06T20:47:04.654711Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.metrics import r2_score, accuracy_score\n",
    "from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier\n",
    "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
    "import rustrees.decision_tree as rt_dt\n",
    "import rustrees.random_forest as rt_rf\n",
    "import time\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7339d8b2-1b14-445c-8bf3-6ed0c050437c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-03-07T16:36:23.400487Z",
     "iopub.status.busy": "2023-03-07T16:36:23.400172Z",
     "iopub.status.idle": "2023-03-07T16:36:23.407520Z",
     "shell.execute_reply": "2023-03-07T16:36:23.406510Z",
     "shell.execute_reply.started": "2023-03-07T16:36:23.400459Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Benchmark datasets keyed by problem type: \"reg\" = regression, \"clf\" = classification.\n",
    "datasets = dict(\n",
    "    reg=[\"diabetes\", \"housing\", \"dgp\"],\n",
    "    clf=[\"breast_cancer\", \"titanic\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f8d469ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fit_and_score(estimator, X_train, y_train, X_test, y_test, metric_fn, n_repeats):\n",
    "    \"\"\"Refit `estimator` n_repeats times, scoring each fit on the test split.\n",
    "\n",
    "    Returns (mean score, std of scores, mean wall-clock seconds per repeat);\n",
    "    the timing covers both `fit` and `predict`.\n",
    "    \"\"\"\n",
    "    scores = []\n",
    "    # perf_counter is monotonic and high-resolution, unlike time.time().\n",
    "    start_time = time.perf_counter()\n",
    "    for _ in range(n_repeats):\n",
    "        estimator.fit(X_train, y_train)\n",
    "        scores.append(metric_fn(y_test, estimator.predict(X_test)))\n",
    "    elapsed = (time.perf_counter() - start_time) / n_repeats\n",
    "    return np.mean(scores), np.std(scores), elapsed\n",
    "\n",
    "\n",
    "def evaluate_dataset(dataset, problem, model, max_depth, n_repeats, n_estimators=None):\n",
    "    \"\"\"Benchmark scikit-learn against rustrees on one train/test dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset : str\n",
    "        Reads ``../../datasets/{dataset}_train.csv`` and ``../../datasets/{dataset}_test.csv``;\n",
    "        both must contain a ``target`` column, every other column is used as a feature.\n",
    "    problem : {\"reg\", \"clf\"}\n",
    "        Regression (scored with R^2) or classification (scored with accuracy).\n",
    "    model : {\"dt\", \"rf\"}\n",
    "        Single decision tree or random forest.\n",
    "    max_depth : int\n",
    "        Maximum tree depth, passed to both implementations.\n",
    "    n_repeats : int\n",
    "        Number of fit/predict repetitions per implementation; must be >= 1.\n",
    "    n_estimators : int, optional\n",
    "        Number of trees; required when ``model == \"rf\"``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(dataset, sk_mean, rt_mean, sk_std, rt_std, sk_time, rt_time, metric)``;\n",
    "        times are mean seconds per fit+predict repeat.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        For an unknown ``problem`` or ``model``, ``n_repeats < 1``, or a missing\n",
    "        ``n_estimators`` when ``model == \"rf\"``.\n",
    "    \"\"\"\n",
    "    # Validate up front so bad options fail with a clear message (instead of an\n",
    "    # UnboundLocalError on model_sk) and before any CSV is read.\n",
    "    if problem not in (\"reg\", \"clf\"):\n",
    "        raise ValueError(f\"problem must be 'reg' or 'clf', got {problem!r}\")\n",
    "    if model not in (\"dt\", \"rf\"):\n",
    "        raise ValueError(f\"model must be 'dt' or 'rf', got {model!r}\")\n",
    "    if model == \"rf\" and n_estimators is None:\n",
    "        raise ValueError(\"n_estimators is required when model='rf'\")\n",
    "    if n_repeats < 1:\n",
    "        raise ValueError(f\"n_repeats must be >= 1, got {n_repeats}\")\n",
    "\n",
    "    if problem == \"reg\":\n",
    "        metric_fn, metric = r2_score, \"r2\"\n",
    "        if model == \"dt\":\n",
    "            model_sk = DecisionTreeRegressor(max_depth=max_depth)\n",
    "            model_rt = rt_dt.DecisionTreeRegressor(max_depth=max_depth)\n",
    "        else:\n",
    "            model_sk = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)\n",
    "            model_rt = rt_rf.RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)\n",
    "    else:\n",
    "        metric_fn, metric = accuracy_score, \"acc\"\n",
    "        if model == \"dt\":\n",
    "            model_sk = DecisionTreeClassifier(max_depth=max_depth)\n",
    "            model_rt = rt_dt.DecisionTreeClassifier(max_depth=max_depth)\n",
    "        else:\n",
    "            model_sk = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)\n",
    "            model_rt = rt_rf.RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)\n",
    "\n",
    "    df_train = pd.read_csv(f\"../../datasets/{dataset}_train.csv\")\n",
    "    df_test = pd.read_csv(f\"../../datasets/{dataset}_test.csv\")\n",
    "    # Split features/target once, outside the timed loops, so the reported times\n",
    "    # measure fit + predict rather than repeated DataFrame copies.\n",
    "    # NOTE(review): assumes neither library mutates X/y in place — verify for rustrees.\n",
    "    X_train, y_train = df_train.drop(\"target\", axis=1), df_train.target\n",
    "    X_test, y_test = df_test.drop(\"target\", axis=1), df_test.target\n",
    "\n",
    "    sk_mean, sk_std, sk_time = _fit_and_score(model_sk, X_train, y_train, X_test, y_test, metric_fn, n_repeats)\n",
    "    rt_mean, rt_std, rt_time = _fit_and_score(model_rt, X_train, y_train, X_test, y_test, metric_fn, n_repeats)\n",
    "\n",
    "    return (dataset, sk_mean, rt_mean, sk_std, rt_std, sk_time, rt_time, metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8a2ae87c-9213-4c02-bd19-1a844eff5f05",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-03-07T16:36:24.510409Z",
     "iopub.status.busy": "2023-03-07T16:36:24.510122Z",
     "iopub.status.idle": "2023-03-07T16:36:24.610884Z",
     "shell.execute_reply": "2023-03-07T16:36:24.610170Z",
     "shell.execute_reply.started": "2023-03-07T16:36:24.510384Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>sk_mean</th>\n",
       "      <th>rt_mean</th>\n",
       "      <th>sk_std</th>\n",
       "      <th>rt_std</th>\n",
       "      <th>sk_time(s)</th>\n",
       "      <th>rt_time(s)</th>\n",
       "      <th>metric</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>diabetes</td>\n",
       "      <td>0.315319</td>\n",
       "      <td>0.270029</td>\n",
       "      <td>3.251468e-02</td>\n",
       "      <td>1.780794e-02</td>\n",
       "      <td>0.002659</td>\n",
       "      <td>0.003520</td>\n",
       "      <td>r2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>housing</td>\n",
       "      <td>0.599732</td>\n",
       "      <td>0.598390</td>\n",
       "      <td>1.336886e-16</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.042986</td>\n",
       "      <td>0.060472</td>\n",
       "      <td>r2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>dgp</td>\n",
       "      <td>0.993509</td>\n",
       "      <td>0.993510</td>\n",
       "      <td>4.440892e-16</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.056852</td>\n",
       "      <td>0.360891</td>\n",
       "      <td>r2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>breast_cancer</td>\n",
       "      <td>0.928702</td>\n",
       "      <td>0.929018</td>\n",
       "      <td>6.747068e-03</td>\n",
       "      <td>6.746612e-03</td>\n",
       "      <td>0.004165</td>\n",
       "      <td>0.006442</td>\n",
       "      <td>acc</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>titanic</td>\n",
       "      <td>0.786441</td>\n",
       "      <td>0.806780</td>\n",
       "      <td>1.110223e-16</td>\n",
       "      <td>3.330669e-16</td>\n",
       "      <td>0.002300</td>\n",
       "      <td>0.002896</td>\n",
       "      <td>acc</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         dataset   sk_mean   rt_mean        sk_std        rt_std  sk_time(s)  \\\n",
       "0       diabetes  0.315319  0.270029  3.251468e-02  1.780794e-02    0.002659   \n",
       "1        housing  0.599732  0.598390  1.336886e-16  0.000000e+00    0.042986   \n",
       "2            dgp  0.993509  0.993510  4.440892e-16  0.000000e+00    0.056852   \n",
       "3  breast_cancer  0.928702  0.929018  6.747068e-03  6.746612e-03    0.004165   \n",
       "4        titanic  0.786441  0.806780  1.110223e-16  3.330669e-16    0.002300   \n",
       "\n",
       "   rt_time(s) metric  \n",
       "0    0.003520     r2  \n",
       "1    0.060472     r2  \n",
       "2    0.360891     r2  \n",
       "3    0.006442    acc  \n",
       "4    0.002896    acc  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decision-tree benchmark: depth-5 trees, 100 fit/predict repeats per library.\n",
    "results_reg = [\n",
    "    evaluate_dataset(name, \"reg\", model=\"dt\", max_depth=5, n_repeats=100)\n",
    "    for name in datasets[\"reg\"]\n",
    "]\n",
    "results_clf = [\n",
    "    evaluate_dataset(name, \"clf\", model=\"dt\", max_depth=5, n_repeats=100)\n",
    "    for name in datasets[\"clf\"]\n",
    "]\n",
    "results = [*results_reg, *results_clf]\n",
    "\n",
    "cols = [\"dataset\", \"sk_mean\", \"rt_mean\", \"sk_std\", \"rt_std\", \"sk_time(s)\", \"rt_time(s)\", \"metric\"]\n",
    "\n",
    "pd.DataFrame(results, columns=cols)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "75c713ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>sk_mean</th>\n",
       "      <th>rt_mean</th>\n",
       "      <th>sk_std</th>\n",
       "      <th>rt_std</th>\n",
       "      <th>sk_time(s)</th>\n",
       "      <th>rt_time(s)</th>\n",
       "      <th>metric</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>diabetes</td>\n",
       "      <td>0.437938</td>\n",
       "      <td>0.432859</td>\n",
       "      <td>0.009338</td>\n",
       "      <td>0.005773</td>\n",
       "      <td>0.114510</td>\n",
       "      <td>0.010676</td>\n",
       "      <td>r2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>housing</td>\n",
       "      <td>0.439645</td>\n",
       "      <td>0.440555</td>\n",
       "      <td>0.000613</td>\n",
       "      <td>0.000857</td>\n",
       "      <td>0.255593</td>\n",
       "      <td>0.401618</td>\n",
       "      <td>r2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>dgp</td>\n",
       "      <td>0.756377</td>\n",
       "      <td>0.756061</td>\n",
       "      <td>0.000342</td>\n",
       "      <td>0.000276</td>\n",
       "      <td>0.322776</td>\n",
       "      <td>2.913919</td>\n",
       "      <td>r2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>breast_cancer</td>\n",
       "      <td>0.946667</td>\n",
       "      <td>0.937193</td>\n",
       "      <td>0.003438</td>\n",
       "      <td>0.003663</td>\n",
       "      <td>0.126519</td>\n",
       "      <td>0.025618</td>\n",
       "      <td>acc</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>titanic</td>\n",
       "      <td>0.763390</td>\n",
       "      <td>0.772881</td>\n",
       "      <td>0.004982</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.140300</td>\n",
       "      <td>0.011944</td>\n",
       "      <td>acc</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         dataset   sk_mean   rt_mean    sk_std    rt_std  sk_time(s)  \\\n",
       "0       diabetes  0.437938  0.432859  0.009338  0.005773    0.114510   \n",
       "1        housing  0.439645  0.440555  0.000613  0.000857    0.255593   \n",
       "2            dgp  0.756377  0.756061  0.000342  0.000276    0.322776   \n",
       "3  breast_cancer  0.946667  0.937193  0.003438  0.003663    0.126519   \n",
       "4        titanic  0.763390  0.772881  0.004982  0.000000    0.140300   \n",
       "\n",
       "   rt_time(s) metric  \n",
       "0    0.010676     r2  \n",
       "1    0.401618     r2  \n",
       "2    2.913919     r2  \n",
       "3    0.025618    acc  \n",
       "4    0.011944    acc  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random-forest benchmark: 100 depth-2 trees per forest, 10 fit/predict repeats per library.\n",
    "results_reg = [\n",
    "    evaluate_dataset(name, \"reg\", model=\"rf\", max_depth=2, n_estimators=100, n_repeats=10)\n",
    "    for name in datasets[\"reg\"]\n",
    "]\n",
    "results_clf = [\n",
    "    evaluate_dataset(name, \"clf\", model=\"rf\", max_depth=2, n_estimators=100, n_repeats=10)\n",
    "    for name in datasets[\"clf\"]\n",
    "]\n",
    "results = [*results_reg, *results_clf]\n",
    "\n",
    "cols = [\"dataset\", \"sk_mean\", \"rt_mean\", \"sk_std\", \"rt_std\", \"sk_time(s)\", \"rt_time(s)\", \"metric\"]\n",
    "\n",
    "pd.DataFrame(results, columns=cols)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b795007c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}