{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "163e1754-2bda-47b1-b40f-d07e03450788",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pytblis\n",
"import torch\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1fa45174-7885-44c4-8fb6-f894a80dbee7",
"metadata": {},
"outputs": [],
"source": [
"nao = 96\n",
"e = np.cos(np.arange(nao**4) + 0.2).reshape(nao, nao, nao, nao)\n",
"e_torch = torch.asarray(e)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "68c78b52-a386-4ef0-a17e-0521fc49bb32",
"metadata": {},
"outputs": [],
"source": [
"def fp(arr):\n",
" return np.cos(np.arange(arr.size)) @ arr.reshape(-1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "042162b6-5986-4e00-82ff-7a0d701dc411",
"metadata": {},
"outputs": [],
"source": [
"def fp_torch(arr):\n",
" return torch.cos(torch.arange(np.prod(list(arr.size())), dtype=torch.double)) @ arr.reshape(-1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "980fca0e-72a0-4241-957c-8c55dbf2e186",
"metadata": {},
"outputs": [],
"source": [
"subscripts_list = [\n",
" \"abxy, xycd -> abcd\", # naive gemm case, 2 * n^6\n",
" \"axyz, xyzb -> ab\", # naive gemm case, 2 * n^5\n",
" \"axyz, bxyz -> ab\", # naive syrk case, n^5\n",
" \"axyz, ybzx -> ab\", # comp gemm case, 2 * n^5\n",
" \"axby, yacx -> abc\", # batch gemm case, 2 * n^5\n",
" \"xpay, aybx -> ab\", # complicate case, 2 * n^4\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "554f3ebc-6f31-4388-b4d5-bcf226ce4f8b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NumPy einsum\n",
"Subscripts: abxy, xycd -> abcd\n",
"elapsed time: 2.132740 sec (avg of 5 repeats)\n",
"fingerprint : -19471467.265266474336\n",
"Subscripts: axyz, xyzb -> ab\n",
"elapsed time: 0.063124 sec (avg of 20 repeats)\n",
"fingerprint : 48.288443230390\n",
"Subscripts: axyz, bxyz -> ab\n",
"elapsed time: 0.293240 sec (avg of 20 repeats)\n",
"fingerprint : -217920.505845849111\n",
"Subscripts: axyz, ybzx -> ab\n",
"elapsed time: 0.207656 sec (avg of 20 repeats)\n",
"fingerprint : 2.131216642236\n",
"Subscripts: axby, yacx -> abc\n",
"elapsed time: 29.650491 sec (avg of 1 repeats)\n",
"fingerprint : -134.741201125226\n",
"Subscripts: xpay, aybx -> ab\n",
"elapsed time: 33.931122 sec (avg of 1 repeats)\n",
"fingerprint : 4.640285999007\n"
]
}
],
"source": [
"print(\"NumPy einsum\")\n",
"repeat_list = [5, 20, 20, 20, 1, 1]\n",
"for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
" print(f\"Subscripts: {subscripts}\")\n",
" t = time.time()\n",
" for _ in range(nrepeat):\n",
" v = np.einsum(subscripts, e, e, optimize=True)\n",
" print(f\"elapsed time: {(time.time() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
" print(f\"fingerprint : {fp(v):20.12f}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b408c155-71c5-4b8d-9a8b-40911b2d1209",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTBLIS einsum\n",
"Subscripts: abxy, xycd -> abcd\n",
"elapsed time: 1.958193 sec (avg of 5 repeats)\n",
"fingerprint : -19471467.265266474336\n",
"Subscripts: axyz, xyzb -> ab\n",
"elapsed time: 0.142870 sec (avg of 20 repeats)\n",
"fingerprint : 48.288443230387\n",
"Subscripts: axyz, bxyz -> ab\n",
"elapsed time: 0.116316 sec (avg of 20 repeats)\n",
"fingerprint : -217920.505846078857\n",
"Subscripts: axyz, ybzx -> ab\n",
"elapsed time: 0.142035 sec (avg of 20 repeats)\n",
"fingerprint : 2.131216642223\n",
"Subscripts: axby, yacx -> abc\n",
"elapsed time: 29.598574 sec (avg of 1 repeats)\n",
"fingerprint : -134.741201125226\n",
"Subscripts: xpay, aybx -> ab\n",
"elapsed time: 33.830630 sec (avg of 1 repeats)\n",
"fingerprint : 4.640285999007\n"
]
}
],
"source": [
"print(\"PyTBLIS einsum\")\n",
"repeat_list = [5, 20, 20, 20, 1, 1]\n",
"for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
" print(f\"Subscripts: {subscripts}\")\n",
" t = time.time()\n",
" for _ in range(nrepeat):\n",
" v = pytblis.einsum(subscripts, e, e, optimize=\"greedy\")\n",
" print(f\"elapsed time: {(time.time() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
" print(f\"fingerprint : {fp(v):20.12f}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e27aeb8f-5f3f-48eb-aabf-e9c9ba9eaa50",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch einsum\n",
"Subscripts: abxy, xycd -> abcd\n",
"elapsed time: 1.981685 sec (avg of 5 repeats)\n",
"fingerprint : -19471467.265266474336\n",
"Subscripts: axyz, xyzb -> ab\n",
"elapsed time: 0.063420 sec (avg of 20 repeats)\n",
"fingerprint : 48.288443230390\n",
"Subscripts: axyz, bxyz -> ab\n",
"elapsed time: 0.037417 sec (avg of 20 repeats)\n",
"fingerprint : -217920.505845895852\n",
"Subscripts: axyz, ybzx -> ab\n",
"elapsed time: 0.211103 sec (avg of 20 repeats)\n",
"fingerprint : 2.131216642236\n",
"Subscripts: axby, yacx -> abc\n",
"elapsed time: 0.179182 sec (avg of 20 repeats)\n",
"fingerprint : -134.741201125241\n",
"Subscripts: xpay, aybx -> ab\n",
"elapsed time: 0.106856 sec (avg of 20 repeats)\n",
"fingerprint : 4.640285999005\n"
]
}
],
"source": [
"print(\"PyTorch einsum\")\n",
"repeat_list = [5, 20, 20, 20, 20, 20]\n",
"for subscripts, nrepeat in zip(subscripts_list, repeat_list):\n",
" print(f\"Subscripts: {subscripts}\")\n",
" t = time.time()\n",
" for _ in range(nrepeat):\n",
" v = torch.einsum(subscripts, e_torch, e_torch)\n",
" print(f\"elapsed time: {(time.time() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)\")\n",
" print(f\"fingerprint : {fp_torch(v):20.12f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}