import numpy as np
from pyts.classification import BOSSVS, SAXVSM, KNeighborsClassifier, TimeSeriesForest
def _classifier_fixture(test_name, estimator_cls, params, X_train, y_train,
                        X_test, y_test, tolerance):
    """Fit one classifier and package its outputs as a JSON-friendly fixture.

    Args:
        test_name: Identifier stored under ``"test_name"``.
        estimator_cls: Classifier class; instantiated as ``estimator_cls(**params)``
            so the recorded params are exactly the ones used.
        params: Constructor keyword arguments (copied into the fixture).
        X_train, y_train: Training data passed to ``fit``.
        X_test, y_test: Evaluation data passed to ``predict`` / ``score``.
        tolerance: Value stored under ``"tolerance"`` for the consumer.

    Returns:
        A dict with keys ``test_name``, ``params``, ``input``, ``expected``,
        ``tolerance``. ``input`` includes ``y_test`` so that ``expected.score``
        can be reproduced from the fixture alone.
    """
    clf = estimator_cls(**params)
    clf.fit(X_train, y_train)
    predictions = clf.predict(X_test)
    score = clf.score(X_test, y_test)
    return {
        "test_name": test_name,
        # Copy so later mutation of the fixture cannot alias the caller's dict.
        "params": dict(params),
        "input": {
            "X": X_train.tolist(),
            "y": y_train.tolist(),
            "X_test": X_test.tolist(),
            # Needed to recompute "score"; previously omitted.
            "y_test": y_test.tolist(),
        },
        "expected": {
            "predictions": np.asarray(predictions).tolist(),
            "score": float(score),
        },
        "tolerance": tolerance,
    }


def generate():
    """Build classifier fixtures from a fixed-seed random dataset.

    Fits KNN (euclidean and dtw), BOSSVS and SAXVSM on the same 20x30
    training set, then records predictions and score on a 6x30 test set.

    Returns:
        A list of fixture dicts, one per classifier configuration, in the
        order knn_euclidean, knn_dtw, bossvs_basic, saxvsm_basic.
    """
    rng = np.random.RandomState(42)
    # Draw order matters: X_train then X_test come from the same seeded stream.
    X_train = rng.randn(20, 30)
    y_train = np.array(["A"] * 10 + ["B"] * 10)
    X_test = rng.randn(6, 30)
    y_test = np.array(["A"] * 3 + ["B"] * 3)

    # (test_name, estimator class, constructor params, tolerance)
    cases = [
        ("knn_euclidean", KNeighborsClassifier,
         {"n_neighbors": 3, "metric": "euclidean"}, 1e-8),
        ("knn_dtw", KNeighborsClassifier,
         {"n_neighbors": 1, "metric": "dtw"}, 1e-8),
        ("bossvs_basic", BOSSVS,
         {"window_size": 8, "word_size": 4, "n_bins": 4}, 1e-6),
        ("saxvsm_basic", SAXVSM,
         {"window_size": 8, "word_size": 4, "n_bins": 4}, 1e-6),
    ]
    return [
        _classifier_fixture(name, cls, params, X_train, y_train,
                            X_test, y_test, tol)
        for name, cls, params, tol in cases
    ]