tblis 0.2.4

TBLIS wrapper in Rust
Documentation
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "163e1754-2bda-47b1-b40f-d07e03450788",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pytblis\n",
    "import torch\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1fa45174-7885-44c4-8fb6-f894a80dbee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "nao = 96\n",
    "e = np.cos(np.arange(nao**4) + 0.2).reshape(nao, nao, nao, nao)\n",
    "e_torch = torch.asarray(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "68c78b52-a386-4ef0-a17e-0521fc49bb32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fp(arr):\n",
    "    return np.cos(np.arange(arr.size)) @ arr.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "042162b6-5986-4e00-82ff-7a0d701dc411",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fp_torch(arr):\n",
    "    return torch.cos(torch.arange(np.prod(list(arr.size())), dtype=torch.double)) @ arr.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "980fca0e-72a0-4241-957c-8c55dbf2e186",
   "metadata": {},
   "outputs": [],
   "source": [
    "subscripts_list = [\n",
    "    \"abxy, xycd -> abcd\",  # naive gemm case, 2 * n^6\n",
    "    \"axyz, xyzb -> ab\",    # naive gemm case, 2 * n^5\n",
    "    \"axyz, bxyz -> ab\",    # naive syrk case,     n^5\n",
    "    \"axyz, ybzx -> ab\",    # comp  gemm case, 2 * n^5\n",
    "    \"axby, yacx -> abc\",   # batch gemm case, 2 * n^5\n",
    "    \"xpay, aybx -> ab\",    # complicate case, 2 * n^4\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "554f3ebc-6f31-4388-b4d5-bcf226ce4f8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NumPy einsum\n",
      "Subscripts: abxy, xycd -> abcd\n",
      "elapsed time:     2.132740 sec (avg of  5 repeats)\n",
      "fingerprint : -19471467.265266474336\n",
      "Subscripts: axyz, xyzb -> ab\n",
      "elapsed time:     0.063124 sec (avg of 20 repeats)\n",
      "fingerprint :      48.288443230390\n",
      "Subscripts: axyz, bxyz -> ab\n",
      "elapsed time:     0.293240 sec (avg of 20 repeats)\n",
      "fingerprint : -217920.505845849111\n",
      "Subscripts: axyz, ybzx -> ab\n",
      "elapsed time:     0.207656 sec (avg of 20 repeats)\n",
      "fingerprint :       2.131216642236\n",
      "Subscripts: axby, yacx -> abc\n",
      "elapsed time:    29.650491 sec (avg of  1 repeats)\n",
      "fingerprint :    -134.741201125226\n",
      "Subscripts: xpay, aybx -> ab\n",
      "elapsed time:    33.931122 sec (avg of  1 repeats)\n",
      "fingerprint :       4.640285999007\n"
     ]
    }
   ],
   "source": [
    "print(\"NumPy einsum\")\n",
    "repeat_list = [5, 20, 20, 20, 1, 1]\n",
    "for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
    "    print(f\"Subscripts: {subscripts}\")\n",
    "    t = time.time()\n",
    "    for _ in range(nrepeat):\n",
    "        v = np.einsum(subscripts, e, e, optimize=True)\n",
    "    print(f\"elapsed time: {(time.time() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
    "    print(f\"fingerprint : {fp(v):20.12f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b408c155-71c5-4b8d-9a8b-40911b2d1209",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTBLIS einsum\n",
      "Subscripts: abxy, xycd -> abcd\n",
      "elapsed time:     1.958193 sec (avg of  5 repeats)\n",
      "fingerprint : -19471467.265266474336\n",
      "Subscripts: axyz, xyzb -> ab\n",
      "elapsed time:     0.142870 sec (avg of 20 repeats)\n",
      "fingerprint :      48.288443230387\n",
      "Subscripts: axyz, bxyz -> ab\n",
      "elapsed time:     0.116316 sec (avg of 20 repeats)\n",
      "fingerprint : -217920.505846078857\n",
      "Subscripts: axyz, ybzx -> ab\n",
      "elapsed time:     0.142035 sec (avg of 20 repeats)\n",
      "fingerprint :       2.131216642223\n",
      "Subscripts: axby, yacx -> abc\n",
      "elapsed time:    29.598574 sec (avg of  1 repeats)\n",
      "fingerprint :    -134.741201125226\n",
      "Subscripts: xpay, aybx -> ab\n",
      "elapsed time:    33.830630 sec (avg of  1 repeats)\n",
      "fingerprint :       4.640285999007\n"
     ]
    }
   ],
   "source": [
    "print(\"PyTBLIS einsum\")\n",
    "repeat_list = [5, 20, 20, 20, 1, 1]\n",
    "for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
    "    print(f\"Subscripts: {subscripts}\")\n",
    "    t = time.time()\n",
    "    for _ in range(nrepeat):\n",
    "        v = pytblis.einsum(subscripts, e, e, optimize=\"greedy\")\n",
    "    print(f\"elapsed time: {(time.time() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
    "    print(f\"fingerprint : {fp(v):20.12f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e27aeb8f-5f3f-48eb-aabf-e9c9ba9eaa50",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch einsum\n",
      "Subscripts: abxy, xycd -> abcd\n",
      "elapsed time:     1.981685 sec (avg of  5 repeats)\n",
      "fingerprint : -19471467.265266474336\n",
      "Subscripts: axyz, xyzb -> ab\n",
      "elapsed time:     0.063420 sec (avg of 20 repeats)\n",
      "fingerprint :      48.288443230390\n",
      "Subscripts: axyz, bxyz -> ab\n",
      "elapsed time:     0.037417 sec (avg of 20 repeats)\n",
      "fingerprint : -217920.505845895852\n",
      "Subscripts: axyz, ybzx -> ab\n",
      "elapsed time:     0.211103 sec (avg of 20 repeats)\n",
      "fingerprint :       2.131216642236\n",
      "Subscripts: axby, yacx -> abc\n",
      "elapsed time:     0.179182 sec (avg of 20 repeats)\n",
      "fingerprint :    -134.741201125241\n",
      "Subscripts: xpay, aybx -> ab\n",
      "elapsed time:     0.106856 sec (avg of 20 repeats)\n",
      "fingerprint :       4.640285999005\n"
     ]
    }
   ],
   "source": [
    "print(\"PyTorch einsum\")\n",
    "repeat_list = [5, 20, 20, 20, 20, 20]\n",
    "for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
    "    print(f\"Subscripts: {subscripts}\")\n",
    "    t = time.time()\n",
    "    for _ in range(nrepeat):\n",
    "        v = torch.einsum(subscripts, e_torch, e_torch)\n",
    "    print(f\"elapsed time: {(time.time() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
    "    print(f\"fingerprint : {fp_torch(v):20.12f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}