{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e34afb69-8f87-4073-9560-cb80c67e6ae5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import sklearn\n",
"from sklearn import linear_model\n",
"import polars\n",
"from sklearn.model_selection import train_test_split\n",
"from scipy import sparse"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "b02b8789-e927-4df5-96e1-d26706ff5fa2",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(49998, 4098)\n",
"(599976, 4096)\n",
"(599976,)\n",
"(449982, 4096) (149994, 4096)\n"
]
}
],
"source": [
"\n",
"df_it = polars.read_csv_batched(\"train.csv\", has_header=True)\n",
"ys = []\n",
"Xs = []\n",
"idxs = []\n",
"for i in range(12):\n",
" df = df_it.next_batches(1)[0]\n",
" df[0, 0:10]\n",
" print(df.shape)\n",
" y = df[:, 1].to_numpy()\n",
" ys.append(y)\n",
" X = sparse.csr_matrix(np.float32(df[:, 2:].to_numpy()))\n",
" X = sklearn.preprocessing.normalize(X)\n",
" Xs.append(X)\n",
" idx = df[:, 0]\n",
" idxs.append(idx)\n",
"\n",
"X = sparse.vstack(Xs)\n",
"y = np.hstack(ys)\n",
"\n",
"del Xs\n",
"del ys\n",
"\n",
"print(X.shape)\n",
"print(y.shape)\n",
"(X_train, X_test, y_train, y_test) = train_test_split(X, y, shuffle=False)\n",
"del X\n",
"del y\n",
"n_train = X_train.shape[0]\n",
"print(X_train.shape, X_test.shape)\n",
"\n",
"\n",
"# print(df.shape)\n",
"# print(df[:, 2:].mean())\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "4e6d2504-69d1-4135-844f-1d80a1d96e04",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"160, 171, 173, 187, 192, 196, 199, 200, 201, 202, 205, 214, 220, 223, 224, 225, 226, 227, 228, 231, 232, 233, 234, 235, 236, 237, 238, 239, 242, 243, 244, 245, 246, 249, 250, 251, 252, 333, 339, 8201, 8211, 8212, 8217, 8220, 8221, 8222, 8239\n"
]
}
],
"source": [
"from collections import Counter\n",
"counter = Counter()\n",
"\n",
"c = 0\n",
"for line in open(\"dataset/archive/sentences.prepared.csv\"):\n",
" (rid, lang, sentence) = line.strip().split(\"\\t\", 2)\n",
" if lang not in {\"fra\",\"eng\", \"ita\", \"deu\", \"esp\", \"por\"}:\n",
" continue\n",
" c += 1\n",
" if c > 100_000:\n",
" break\n",
" for chr in sentence:\n",
" if ord(chr) > 128:\n",
" counter[chr] += 1\n",
"letters = sorted(ord(letter) for (letter, count) in counter.most_common(100) if count >= 10)\n",
"print(\", \".join(map(str, letters)))\n",
"#print(letters)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "5f7408bc-da6b-4608-84e0-af67b203a748",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
" This problem is unconstrained.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"RUNNING THE L-BFGS-B CODE\n",
"\n",
" * * *\n",
"\n",
"Machine precision = 2.220D-16\n",
" N = 65552 M = 10\n",
"\n",
"At X0 0 variables are exactly at the bounds\n",
"\n",
"At iterate 0 f= 1.24762D+06 |proj g|= 2.45139D+04\n",
"\n",
"At iterate 50 f= 1.21037D+04 |proj g|= 3.55325D+02\n",
"\n",
"At iterate 100 f= 4.94206D+03 |proj g|= 1.13277D+02\n",
"\n",
"At iterate 150 f= 3.82850D+03 |proj g|= 1.43956D+01\n",
"\n",
"At iterate 200 f= 3.43115D+03 |proj g|= 1.99731D+01\n",
"\n",
"At iterate 250 f= 3.29184D+03 |proj g|= 1.39921D+01\n",
"\n",
"At iterate 300 f= 3.23884D+03 |proj g|= 4.21846D+00\n",
"\n",
"At iterate 350 f= 3.21968D+03 |proj g|= 1.86934D+00\n",
"\n",
"At iterate 400 f= 3.21289D+03 |proj g|= 7.79969D-01\n",
"\n",
"At iterate 450 f= 3.21017D+03 |proj g|= 1.19910D+00\n",
"\n",
"At iterate 500 f= 3.20925D+03 |proj g|= 7.50076D-01\n",
"\n",
" * * *\n",
"\n",
"Tit = total number of iterations\n",
"Tnf = total number of function evaluations\n",
"Tnint = total number of segments explored during Cauchy searches\n",
"Skip = number of BFGS updates skipped\n",
"Nact = number of active bounds at final generalized Cauchy point\n",
"Projg = norm of the final projected gradient\n",
"F = final function value\n",
"\n",
" * * *\n",
"\n",
" N Tit Tnf Tnint Skip Nact Projg F\n",
"65552 500 549 1 0 0 7.501D-01 3.209D+03\n",
" F = 3209.2533390579683 \n",
"\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT \n",
"\n",
"\n",
"\n",
"\n",
"\n",
"----------\n",
"64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fulmicoton/miniconda3/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n",
"[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 8.0min finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9997377672884693\n",
"0.9961465125271678\n"
]
}
],
"source": [
"for C in [64]:\n",
" model = sklearn.linear_model.LogisticRegression(max_iter=500, penalty='l2', multi_class='multinomial', C=C, verbose=1, class_weight='balanced') #, l1_ratio=0.1,) # penalty='elasticnet', solver='saga') \n",
" model.fit(X_train, y_train)\n",
" print(\"\\n\\n\\n\\n\\n----------\")\n",
" print(C)\n",
" print((model.predict(X_train) == y_train).mean())\n",
" print((model.predict(X_test) == y_test).mean())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "d47b1c8a-2d26-41aa-a537-91c86ecf77c6",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
" 'spa' 'swe' 'tur' 'vie']\n",
"[[13558 13 4 2 26 1 0 3]\n",
" [ 20 36656 23 10 27 10 0 20]\n",
" [ 2 9 11541 9 2 5 0 11]\n",
" [ 1 10 25 18634 4 28 0 41]\n",
" [ 14 9 2 4 3663 1 0 0]\n",
" [ 3 3 16 28 0 9036 0 104]\n",
" [ 0 0 0 0 0 0 21220 0]\n",
" [ 1 4 10 29 1 68 0 8532]]\n"
]
},
{
"data": {
"text/plain": [
"array([[ 209, 0, 0],\n",
" [ 0, 5127, 3],\n",
" [ 0, 1, 1704]])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import time\n",
"time.sleep(100)\n",
"print(\"start\")from sklearn import metrics\n",
"print(model.classes_)\n",
"print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
"\n",
"sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "6ba8d76f-4dde-45d9-b07d-b12b4355b0d6",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
" 'spa' 'swe' 'tur' 'vie']\n",
"[[13565 7 2 2 22 3 0 3]\n",
" [ 13 36698 8 12 20 7 0 9]\n",
" [ 1 6 11555 7 1 4 0 5]\n",
" [ 3 6 20 18666 4 22 0 29]\n",
" [ 11 12 2 4 3665 5 0 1]\n",
" [ 4 3 6 22 1 9064 0 89]\n",
" [ 0 0 0 0 0 0 21220 0]\n",
" [ 2 3 9 34 0 58 0 8539]]\n"
]
},
{
"data": {
"text/plain": [
"array([[ 209, 0, 0],\n",
" [ 0, 5128, 3],\n",
" [ 0, 2, 1703]])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"from sklearn import metrics\n",
"print(model.classes_)\n",
"print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
"\n",
"sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0efb14c6-50db-4bd6-bb33-cd2dcee69493",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
" 'spa' 'swe' 'tur' 'vie']\n",
"[[13562 11 2 3 21 2 0 3]\n",
" [ 14 36684 11 13 21 5 0 15]\n",
" [ 1 6 11551 8 2 4 0 7]\n",
" [ 3 6 21 18662 5 22 0 30]\n",
" [ 11 12 2 4 3666 5 0 0]\n",
" [ 4 3 8 23 1 9060 0 90]\n",
" [ 0 0 0 0 0 0 21220 0]\n",
" [ 2 3 8 35 0 57 0 8539]]\n"
]
},
{
"data": {
"text/plain": [
"array([[ 209, 0, 0],\n",
" [ 0, 5128, 3],\n",
" [ 0, 1, 1704]])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn import metrics\n",
"print(model.classes_)\n",
"print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
"\n",
"sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c4252d1b-1ed4-4ffb-8f6a-0196b32622f9",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
" 'spa' 'swe' 'tur' 'vie']\n",
"[[ 6866 8 2 6 10 2 0 7]\n",
" [ 4 18245 22 11 30 3 0 22]\n",
" [ 0 2 5838 9 3 7 0 12]\n",
" [ 1 6 14 9446 0 27 0 40]\n",
" [ 10 7 1 3 1770 0 0 4]\n",
" [ 1 2 10 31 4 4517 0 58]\n",
" [ 0 0 0 0 0 0 10485 0]\n",
" [ 2 4 10 34 1 43 0 4133]]\n"
]
},
{
"data": {
"text/plain": [
"array([[ 101, 0, 0],\n",
" [ 0, 2561, 5],\n",
" [ 0, 0, 812]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn import metrics\n",
"print(model.classes_)\n",
"print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
"\n",
"sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "476077a6-2110-447a-8197-f0848777a508",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(565,)\n",
"(75000,)\n",
"ita por 8484775\n",
"ita spa 3553712\n",
"ita por 2614727\n",
"ita spa 3213159\n",
"tur nld 4129221\n",
"por spa 6992976\n",
"por spa 7888818\n",
"deu fra 843713\n",
"eng swe 1164430\n",
"por spa 972364\n"
]
}
],
"source": [
"y_predict = model.predict(X_test)\n",
"print(np.where((y_predict == y_test) == False)[0].shape)\n",
"print(y_predict.shape)\n",
"i = 0 \n",
"for row in list(np.where((y_predict == y_test) == False))[0]:\n",
" i += 1\n",
" print(y_test[row], y_predict[row], idx[int(n_train + row)])\n",
" if i == 10:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "ac15d1d4-c839-4e96-89cc-724af639d231",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(16, 4096)\n",
"-0.0047384733\n",
"[-4.73847311e-03 -2.28266937e-03 -5.36394365e-01 5.88218350e-01\n",
" -1.03133002e-01 -5.35878639e-04 -2.13580000e+00 -2.86555783e-03\n",
" -1.19892844e-03 -3.64568806e-01 -3.39155160e-01 -5.44463552e-04\n",
" 8.38150022e-01 1.53047103e+00 5.64755284e-01 -3.03773764e-02]\n"
]
}
],
"source": [
"(LANG, DIM) = model.coef_.shape\n",
"print(model.coef_.shape)\n",
"coef = np.float32(model.coef_)\n",
"\n",
"print(coef[0,0])\n",
"print(model.coef_[:,0])\n",
"\n",
"f = open(\"src/weights.rs\", \"w\")\n",
"\n",
"f.write(\"#[derive(Clone, Copy, Debug, Eq, PartialEq)]\\n\")\n",
"f.write(\"pub enum Lang {\\n\")\n",
"for lang in model.classes_:\n",
" f.write(\"\\t%s,\\n\" % lang.capitalize(),)\n",
"f.write(\"}\\n\\n\")\n",
"\n",
"f.write(\"\"\"\n",
"impl Lang {\n",
" pub fn three_letter_code(self)-> &'static str {\n",
" match self {\n",
"\"\"\")\n",
"for lang in model.classes_:\n",
" f.write(\"\\t\\t\\tLang::%s => \\\"%s\\\",\\n\" % (lang.capitalize(), lang))\n",
"f.write(\"\\t\\t}\\t}\\n}\\n\\n\\n\")\n",
"\n",
"\n",
"f.write(\"pub const LANGUAGES: [Lang; %d] = [\\n\\t\" % LANG)\n",
"for lang in model.classes_:\n",
" f.write(\"Lang::%s, \" % lang.capitalize())\n",
"f.write(\"];\\n\\n\")\n",
"\n",
"f.write(\"pub const WEIGHTS: [f32; %d] = [\\n\" % (LANG * DIM))\n",
"for i in range(DIM):\n",
" f.write(\"\\t\")\n",
" for val in coef[:, i]:\n",
" f.write(\"%f, \" % val)\n",
" f.write(\"\\n\")\n",
"f.write(\"];\\n\\n\")\n",
"\n",
"\n",
"f.write(\"pub const INTERCEPTS: [f32; %d] = [\\n\\t\" % LANG)\n",
"for val in model.intercept_:\n",
" f.write(\"%f, \" % val)\n",
"f.write(\"];\\n\\n\")\n",
"\n",
"\n",
"f.flush()\n",
"f.close()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}