tblis 0.2.5

TBLIS wrapper in Rust
Documentation
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "163e1754-2bda-47b1-b40f-d07e03450788",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pytblis\n",
    "import torch\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1fa45174-7885-44c4-8fb6-f894a80dbee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem sizes: the full tensor holds nfull^4 float64 values (128^4 * 8 bytes = 2 GiB);\n",
    "# every benchmark contracts only its leading nao^4 sub-block.\n",
    "nfull, nao = 128, 96\n",
    "\n",
    "# Deterministic test data (no RNG): element k of the flattened tensor is cos(k + 0.2).\n",
    "e_full = np.cos(0.2 + np.arange(nfull**4)).reshape((nfull,) * 4)\n",
    "# Basic slicing returns a view, so e is a strided (non-contiguous) operand.\n",
    "e = e_full[:nao, :nao, :nao, :nao]\n",
    "# Same values as e for the PyTorch benchmark (torch.asarray may share e's memory instead of copying).\n",
    "e_torch = torch.asarray(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "68c78b52-a386-4ef0-a17e-0521fc49bb32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fp(arr):\n",
    "    \"\"\"Scalar fingerprint of an array: sum_i cos(i) * arr.flat[i] (C order).\n",
    "\n",
    "    Used to compare einsum results across backends without printing whole arrays.\n",
    "    \"\"\"\n",
    "    flat = arr.reshape(-1)\n",
    "    weights = np.cos(np.arange(flat.size))\n",
    "    return weights @ flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "042162b6-5986-4e00-82ff-7a0d701dc411",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fp_torch(arr):\n",
    "    \"\"\"PyTorch counterpart of fp: sum_i cos(i) * arr.reshape(-1)[i] (row-major order).\n",
    "\n",
    "    The element count comes from arr.numel(), so the helper stays pure PyTorch.\n",
    "    The weights are float64 (torch.double), so arr is expected to be float64 as well.\n",
    "    \"\"\"\n",
    "    flat = arr.reshape(-1)\n",
    "    weights = torch.cos(torch.arange(flat.numel(), dtype=torch.double))\n",
    "    return weights @ flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "980fca0e-72a0-4241-957c-8c55dbf2e186",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contraction patterns to benchmark. Comments give the kernel class and the FLOP count,\n",
    "# where n is the extent of every index (all operands are the same nao^4 tensor).\n",
    "subscripts_list = [\n",
    "    \"abxy, xycd -> abcd\",  # naive gemm case,  2 * n^6\n",
    "    \"axyz, xyzb -> ab\",    # naive gemm case,  2 * n^5\n",
    "    \"axyz, bxyz -> ab\",    # naive syrk case,      n^5\n",
    "    \"axyz, ybzx -> ab\",    # comp  gemm case,  2 * n^5\n",
    "    \"axby, yacx -> abc\",   # batch gemm case,  2 * n^5\n",
    "    \"xpay, aybx -> ab\",    # complicated case, 2 * n^4\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "554f3ebc-6f31-4388-b4d5-bcf226ce4f8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NumPy einsum\n",
      "Subscripts: abxy, xycd -> abcd\n",
      "elapsed time:     7.306426 sec (avg of  5 repeats)\n",
      "fingerprint :     -20.188290390819\n",
      "Subscripts: axyz, xyzb -> ab\n",
      "elapsed time:     0.776790 sec (avg of 20 repeats)\n",
      "fingerprint :      20.343405116707\n",
      "Subscripts: axyz, bxyz -> ab\n",
      "elapsed time:     0.455486 sec (avg of 20 repeats)\n",
      "fingerprint : -200211.721311474335\n",
      "Subscripts: axyz, ybzx -> ab\n",
      "elapsed time:     0.725011 sec (avg of 20 repeats)\n",
      "fingerprint :       0.274707781823\n",
      "Subscripts: axby, yacx -> abc\n",
      "elapsed time:    27.076382 sec (avg of  1 repeats)\n",
      "fingerprint :       0.466623082298\n",
      "Subscripts: xpay, aybx -> ab\n",
      "elapsed time:   248.522274 sec (avg of  1 repeats)\n",
      "fingerprint :       0.134542958876\n"
     ]
    }
   ],
   "source": [
    "print(\"NumPy einsum\")\n",
    "# The last two patterns take tens to hundreds of seconds per call with np.einsum\n",
    "# (see recorded output), so they are run only once.\n",
    "repeat_list = [5, 20, 20, 20, 1, 1]\n",
    "for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
    "    print(f\"Subscripts: {subscripts}\")\n",
    "    # perf_counter is monotonic and has the highest available resolution; time.time()\n",
    "    # follows the wall clock and can jump if the system time is adjusted mid-run.\n",
    "    t = time.perf_counter()\n",
    "    for _ in range(nrepeat):\n",
    "        v = np.einsum(subscripts, e, e, optimize=True)\n",
    "    print(f\"elapsed time: {(time.perf_counter() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
    "    print(f\"fingerprint : {fp(v):20.12f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b408c155-71c5-4b8d-9a8b-40911b2d1209",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTBLIS einsum\n",
      "Subscripts: abxy, xycd -> abcd\n",
      "elapsed time:     1.951329 sec (avg of  5 repeats)\n",
      "fingerprint :     -20.188290390824\n",
      "Subscripts: axyz, xyzb -> ab\n",
      "elapsed time:     0.141577 sec (avg of 20 repeats)\n",
      "fingerprint :      20.343405116705\n",
      "Subscripts: axyz, bxyz -> ab\n",
      "elapsed time:     0.114764 sec (avg of 20 repeats)\n",
      "fingerprint : -200211.721311338129\n",
      "Subscripts: axyz, ybzx -> ab\n",
      "elapsed time:     0.139257 sec (avg of 20 repeats)\n",
      "fingerprint :       0.274707781825\n",
      "Subscripts: axby, yacx -> abc\n",
      "elapsed time:    27.224720 sec (avg of  1 repeats)\n",
      "fingerprint :       0.466623082298\n",
      "Subscripts: xpay, aybx -> ab\n",
      "elapsed time:   249.175256 sec (avg of  1 repeats)\n",
      "fingerprint :       0.134542958876\n"
     ]
    }
   ],
   "source": [
    "print(\"PyTBLIS einsum\")\n",
    "# The last two patterns take tens to hundreds of seconds per call here as well\n",
    "# (see recorded output), so they are run only once.\n",
    "repeat_list = [5, 20, 20, 20, 1, 1]\n",
    "for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
    "    print(f\"Subscripts: {subscripts}\")\n",
    "    # perf_counter is monotonic and has the highest available resolution; time.time()\n",
    "    # follows the wall clock and can jump if the system time is adjusted mid-run.\n",
    "    t = time.perf_counter()\n",
    "    for _ in range(nrepeat):\n",
    "        v = pytblis.einsum(subscripts, e, e, optimize=\"greedy\")\n",
    "    print(f\"elapsed time: {(time.perf_counter() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
    "    print(f\"fingerprint : {fp(v):20.12f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e27aeb8f-5f3f-48eb-aabf-e9c9ba9eaa50",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch einsum\n",
      "Subscripts: abxy, xycd -> abcd\n",
      "elapsed time:     2.105942 sec (avg of  5 repeats)\n",
      "fingerprint :     -20.188290390819\n",
      "Subscripts: axyz, xyzb -> ab\n",
      "elapsed time:     0.204239 sec (avg of 20 repeats)\n",
      "fingerprint :      20.343405116707\n",
      "Subscripts: axyz, bxyz -> ab\n",
      "elapsed time:     0.211401 sec (avg of 20 repeats)\n",
      "fingerprint : -200211.721311473870\n",
      "Subscripts: axyz, ybzx -> ab\n",
      "elapsed time:     0.406712 sec (avg of 20 repeats)\n",
      "fingerprint :       0.274707781823\n",
      "Subscripts: axby, yacx -> abc\n",
      "elapsed time:     0.263642 sec (avg of 20 repeats)\n",
      "fingerprint :       0.466623082310\n",
      "Subscripts: xpay, aybx -> ab\n",
      "elapsed time:     0.147302 sec (avg of 20 repeats)\n",
      "fingerprint :       0.134542958873\n"
     ]
    }
   ],
   "source": [
    "print(\"PyTorch einsum\")\n",
    "# PyTorch is fast on every pattern (see recorded output), so only the 2 * n^6 case\n",
    "# uses fewer repeats.\n",
    "repeat_list = [5, 20, 20, 20, 20, 20]\n",
    "for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
    "    print(f\"Subscripts: {subscripts}\")\n",
    "    # perf_counter is monotonic and has the highest available resolution; time.time()\n",
    "    # follows the wall clock and can jump if the system time is adjusted mid-run.\n",
    "    t = time.perf_counter()\n",
    "    for _ in range(nrepeat):\n",
    "        v = torch.einsum(subscripts, e_torch, e_torch)\n",
    "    print(f\"elapsed time: {(time.perf_counter() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
    "    print(f\"fingerprint : {fp_torch(v):20.12f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}