import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
# Path of the CSV file opened by parse_csv().
# NOTE(review): this name would shadow the stdlib `csv` module if that module
# were ever imported in this file.
csv = 'lr_loss.csv'
def parse_csv(path=None):
    """Read learning-rate / epochs / loss rows from a CSV file.

    The first line is treated as a header and skipped. Every following
    non-blank line must contain at least three comma-separated fields, in
    this order: learning rate (float), epochs (int), loss (float).

    Args:
        path: File to read. When omitted, the module-level ``csv`` path is
            used, which keeps existing zero-argument calls working.

    Returns:
        A ``(lr, loss, epochs)`` tuple of parallel lists. Note that loss comes
        before epochs here, unlike the column order in the file.

    Raises:
        ValueError: If a field cannot be converted to a number.
        IndexError: If a non-blank row has fewer than three fields.
    """
    if path is None:
        # Resolve the module-level default at call time, not definition time.
        path = csv

    lr = []
    epochs = []
    loss = []
    with open(path, 'r') as f:
        # Skip the header row; the default keeps an empty file from raising.
        next(f, None)
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at EOF) would otherwise
                # fail in float('').
                continue
            fields = stripped.split(',')
            lr.append(float(fields[0]))
            epochs.append(int(fields[1]))
            loss.append(float(fields[2]))
    return (lr, loss, epochs)
# Load the sweep results; parse_csv returns the lists as (lr, loss, epochs).
lr, loss, epochs = parse_csv()

# Figure 1: loss against learning rate, log-scaled y axis.
# Each 2D curve gets its own figure; previously both curves were drawn into
# the same implicit axes, so the second plot's labels/title overwrote the
# first and two unrelated x-scales were overplotted.
plt.figure()
plt.plot(lr, loss)
plt.yscale('log')
plt.xlabel('lr')
plt.ylabel('loss')
plt.title('lr-loss')

# Figure 2: loss against epochs, log-scaled y axis.
plt.figure()
plt.plot(epochs, loss)
plt.yscale('log')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('epochs-loss')

# Figure 3: triangulated 3D surface of loss over (lr, epochs).
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(
    np.array(lr),
    np.array(epochs),
    np.array(loss),
    linewidth=0,
    antialiased=False,
)
ax.set_xlabel('lr')
ax.set_ylabel('epochs')
ax.set_zlabel('loss')

# Mark the sample with the smallest loss (first occurrence on ties) in red.
min_loss = min(loss)
min_loss_index = loss.index(min_loss)
min_lr = lr[min_loss_index]
min_epoch = epochs[min_loss_index]
ax.scatter(min_lr, min_epoch, min_loss, color='r')
ax.set_title('lr-epoch-loss')

# Display all figures.
plt.show()