/**
 * Exports of the instantiated WebAssembly module, or null until initWasm
 * succeeds. NOTE(review): no code in this file reads it yet; the compute
 * paths are pure JavaScript.
 */
let wasmExports = null;

/**
 * Fetch and instantiate the WebAssembly module at `wasmUrl`, storing its
 * exports in `wasmExports`.
 *
 * The outcome is reported to the owning thread via postMessage rather than by
 * rejecting: `{ type: "ready" }` on success, `{ type: "error", id: -1, message }`
 * on any failure (id -1 marks an initialization error, not a failed request).
 *
 * @param {string | URL} wasmUrl - Location of the .wasm binary.
 * @returns {Promise<void>}
 */
async function initWasm(wasmUrl) {
  try {
    const response = await fetch(wasmUrl);
    // fetch() only rejects on network failure. Without this check an HTTP
    // error page (404, 500, ...) would be handed to WebAssembly.instantiate
    // and surface as a confusing compile error instead of the real cause.
    if (!response.ok) {
      throw new Error(
        `Failed to fetch WASM module from ${wasmUrl}: HTTP ${response.status} ${response.statusText}`
      );
    }
    const buffer = await response.arrayBuffer();
    const { instance } = await WebAssembly.instantiate(buffer, {
      // Import namespace offered to the module; currently provides nothing.
      env: {},
    });
    wasmExports = instance.exports;
    self.postMessage({ type: "ready" });
  } catch (err) {
    self.postMessage({ type: "error", id: -1, message: String(err) });
  }
}
/**
 * Dense matrix product C = A · B of two n×n matrices stored row-major in flat
 * arrays (element (r, c) at index r * n + c).
 *
 * Loops run in i-k-j order so the innermost loop walks a row of B and a row of
 * the accumulator contiguously. Each output row is accumulated in float64 and
 * rounded to float32 once when copied into the result. Rounding after every
 * multiply-add instead would let error grow with n and drop small terms
 * entirely.
 *
 * @param {ArrayLike<number>} a - Left operand, at least n*n elements, row-major.
 * @param {ArrayLike<number>} b - Right operand, at least n*n elements, row-major.
 * @param {number} n - Matrix dimension; must be a non-negative integer.
 * @returns {Float32Array} The n*n product, row-major.
 * @throws {RangeError} If n is not a non-negative integer, or an operand has
 *   fewer than n*n elements (which would otherwise silently produce NaN).
 */
function localMatmul(a, b, n) {
  if (!Number.isInteger(n) || n < 0) {
    throw new RangeError(`matmul: n must be a non-negative integer, got ${n}`);
  }
  const size = n * n;
  if (a.length < size || b.length < size) {
    throw new RangeError(
      `matmul: expected at least ${size} elements per operand, got a=${a.length}, b=${b.length}`
    );
  }

  const c = new Float32Array(size);
  const rowAcc = new Float64Array(n);
  for (let i = 0; i < n; i++) {
    rowAcc.fill(0);
    const aRow = i * n;
    for (let k = 0; k < n; k++) {
      const aik = a[aRow + k];
      const bRow = k * n;
      for (let j = 0; j < n; j++) {
        rowAcc[j] += aik * b[bRow + j];
      }
    }
    // TypedArray#set converts each float64 to float32 on copy.
    c.set(rowAcc, aRow);
  }
  return c;
}
/**
 * Summary statistics of a numeric sequence: arithmetic mean, population
 * standard deviation (squared deviations divided by the element count, not
 * count - 1), minimum and maximum.
 *
 * Empty input yields mean and std of NaN, min of Infinity and max of -Infinity.
 *
 * @param {ArrayLike<number>} data
 * @returns {{ mean: number, std: number, min: number, max: number }}
 */
function localStats(data) {
  const count = data.length;

  // First pass: running total and extremes.
  let total = 0;
  let smallest = Infinity;
  let largest = -Infinity;
  for (let idx = 0; idx < count; idx += 1) {
    const value = data[idx];
    total += value;
    smallest = value < smallest ? value : smallest;
    largest = value > largest ? value : largest;
  }
  const mean = total / count;

  // Second pass: sum of squared deviations from the mean.
  let squaredDeviations = 0;
  for (let idx = 0; idx < count; idx += 1) {
    squaredDeviations += (data[idx] - mean) ** 2;
  }

  return {
    mean,
    std: Math.sqrt(squaredDeviations / count),
    min: smallest,
    max: largest,
  };
}
/**
 * Direct O(n²) discrete Fourier transform of a real-valued signal:
 *
 *   X[k] = Σ_t x[t] · e^(-2πi·k·t/n),   k = 0 … n-1
 *
 * Note: although the worker routes "fft" requests here, this is not a fast
 * transform; cost is quadratic in the signal length.
 *
 * The twiddle factors e^(-2πi·m/n) are precomputed once for m = 0 … n-1 and
 * looked up at m = (k·t) mod n. By periodicity this is exact, and it cuts trig
 * calls from 2n² to 2n. It also keeps every trig argument in [0, 2π), whereas
 * the raw k·t grows up to n² and loses precision. Sums are accumulated in
 * float64 and rounded to float32 once per output bin.
 *
 * @param {ArrayLike<number>} re - Real-valued input samples.
 * @returns {{ re: Float32Array, im: Float32Array }} Real and imaginary parts of
 *   the spectrum, each of length `re.length`.
 */
function localDft(re) {
  const n = re.length;
  const outRe = new Float32Array(n);
  const outIm = new Float32Array(n);

  const cosTable = new Float64Array(n);
  const sinTable = new Float64Array(n);
  for (let m = 0; m < n; m++) {
    const angle = (-2 * Math.PI * m) / n;
    cosTable[m] = Math.cos(angle);
    sinTable[m] = Math.sin(angle);
  }

  for (let k = 0; k < n; k++) {
    let sumRe = 0;
    let sumIm = 0;
    for (let t = 0; t < n; t++) {
      const m = (k * t) % n;
      const sample = re[t];
      sumRe += sample * cosTable[m];
      sumIm += sample * sinTable[m];
    }
    outRe[k] = sumRe;
    outIm[k] = sumIm;
  }
  return { re: outRe, im: outIm };
}
/**
 * Compute request handlers keyed by message `type`. Each handler takes the
 * incoming message and returns the `result` to post back, plus the
 * ArrayBuffers to transfer (rather than copy) along with it.
 *
 * A Map rather than a plain object, so that a `type` such as "constructor" or
 * "__proto__" cannot resolve to an inherited Object.prototype member.
 */
const computeHandlers = new Map([
  [
    "matmul",
    (msg) => {
      const result = localMatmul(msg.a, msg.b, msg.n);
      return { result, transfer: [result.buffer] };
    },
  ],
  [
    "fft",
    (msg) => {
      const { re, im } = localDft(msg.data);
      return { result: { re, im }, transfer: [re.buffer, im.buffer] };
    },
  ],
  ["stats", (msg) => ({ result: localStats(msg.data), transfer: [] })],
]);

/**
 * Worker-side message router.
 *
 * Protocol:
 *   { type: "init", wasmUrl }                -> initWasm replies "ready" or an id -1 error
 *   { type: "matmul" | "fft" | "stats", id } -> { type: "result", id, result }
 *                                               or { type: "error", id, message }
 *   unknown type                             -> { type: "error", id: msg.id ?? -1, message }
 *   non-object message                       -> { type: "error", id: -1, message }
 *
 * @param {MessageEvent} event
 */
function handleWorkerMessage(event) {
  const msg = event.data;
  // Structured clone admits any value. Reply instead of letting `msg.type`
  // throw, which would surface only as an uncaught worker error with no reply.
  if (msg === null || typeof msg !== "object") {
    self.postMessage({
      type: "error",
      id: -1,
      message: `Malformed message: expected an object, got ${msg === null ? "null" : typeof msg}`,
    });
    return;
  }

  if (msg.type === "init") {
    // initWasm catches its own failures and reports them via postMessage.
    void initWasm(msg.wasmUrl);
    return;
  }

  const handler = computeHandlers.get(msg.type);
  if (handler === undefined) {
    self.postMessage({
      type: "error",
      id: msg.id ?? -1,
      message: `Unknown message type: ${msg.type}`,
    });
    return;
  }

  try {
    const { result, transfer } = handler(msg);
    self.postMessage({ type: "result", id: msg.id, result }, transfer);
  } catch (err) {
    self.postMessage({ type: "error", id: msg.id, message: String(err) });
  }
}

// This file also hosts the main-thread factory and a CommonJS export, so only
// attach the router inside a worker. Unguarded, a require() in Node (where
// `self` is not a global) throws before the export is reached, and loading on
// a browser main thread would attach this router to `window`.
if (typeof WorkerGlobalScope !== "undefined" && self instanceof WorkerGlobalScope) {
  self.addEventListener("message", handleWorkerMessage);
}
/**
 * Spawn a Worker from `workerUrl` and return a promise-based client for it.
 *
 * An `init` message carrying `wasmUrl` is posted immediately. `ready` resolves
 * on the worker's `{ type: "ready" }` reply and rejects on an `error` reply
 * with `id === -1`. Compute calls may be issued before `ready` settles.
 *
 * Each compute call sends `{ type, id, ...payload }` and settles when a reply
 * with the same `id` arrives: it resolves with `result` for `type: "result"`,
 * and otherwise rejects with `message`. Every outstanding call, and a
 * still-pending `ready`, is rejected when the worker fires an `error` event or
 * `terminate()` is called. Calls made after `terminate()` reject immediately.
 *
 * ArrayBuffer-backed typed arrays passed to `matmul` and `fft` are
 * *transferred*, not copied, so they are detached in the calling thread once
 * the request is sent. Plain arrays and views over a SharedArrayBuffer are
 * sent by structured clone instead. `stats` never transfers.
 *
 * @param {string | URL} workerUrl - Script URL passed to `new Worker`.
 * @param {string | URL} wasmUrl - Forwarded to the worker in the init message.
 * @returns {{ ready: Promise<void>, matmul: Function, fft: Function,
 *   stats: Function, terminate: () => void }}
 */
function createScirs2Worker(workerUrl, wasmUrl) {
  const worker = new Worker(workerUrl);
  /** Outstanding requests: id -> { resolve, reject }. */
  const pending = new Map();
  let nextId = 0;
  let terminated = false;

  let resolveReady;
  let rejectReady;
  const ready = new Promise((resolve, reject) => {
    resolveReady = resolve;
    rejectReady = reject;
  });
  // Keeps an init failure from being reported as an unhandled rejection when
  // the caller never awaits `ready`; callers that await it still see the error.
  ready.catch(() => {});

  /** Reject and forget every outstanding request. */
  function rejectAll(error) {
    for (const { reject } of pending.values()) reject(error);
    pending.clear();
  }

  // One persistent listener routes both the handshake and request replies.
  // A `{ once: true }` handshake listener would be consumed by whichever
  // message arrives first (e.g. a result for a request sent before init
  // finished), leaving `ready` pending forever. Settling an already-settled
  // promise is a no-op, so late handshake messages are harmless.
  worker.addEventListener("message", (event) => {
    const { type, id, result, message } = event.data ?? {};
    if (type === "ready") {
      resolveReady();
      return;
    }
    if (id === -1) {
      if (type === "error") rejectReady(new Error(message));
      return;
    }
    if (id === undefined) return;
    const handlers = pending.get(id);
    if (!handlers) return;
    pending.delete(id);
    if (type === "result") handlers.resolve(result);
    else handlers.reject(new Error(message));
  });

  // Script-load failures and uncaught exceptions inside the worker arrive as
  // `error` events that cannot be tied to a request id, so fail everything.
  worker.addEventListener("error", (event) => {
    // `||` on purpose: an empty message is as unhelpful as a missing one.
    const error = new Error(`Worker error: ${event?.message || "unknown error"}`);
    rejectReady(error);
    rejectAll(error);
  });

  // Distinct ArrayBuffers backing the given views. Listing the same buffer
  // twice makes postMessage throw DataCloneError (e.g. matmul(a, a, n)).
  // Plain arrays have no `.buffer`, and SharedArrayBuffers cannot be
  // transferred, so both fall back to structured cloning.
  function transferablesOf(...views) {
    const buffers = new Set();
    for (const view of views) {
      const buffer = view?.buffer;
      if (buffer instanceof ArrayBuffer) buffers.add(buffer);
    }
    return [...buffers];
  }

  function dispatch(type, payload, transferable = []) {
    if (terminated) {
      return Promise.reject(new Error("Worker has been terminated"));
    }
    return new Promise((resolve, reject) => {
      const id = nextId++;
      pending.set(id, { resolve, reject });
      try {
        worker.postMessage({ type, id, ...payload }, transferable);
      } catch (err) {
        // Nothing was sent (e.g. DataCloneError), so no reply will arrive;
        // drop the entry rather than leaking it.
        pending.delete(id);
        reject(err);
      }
    });
  }

  worker.postMessage({ type: "init", wasmUrl });

  return {
    ready,
    matmul: (a, b, n) => dispatch("matmul", { a, b, n }, transferablesOf(a, b)),
    fft: (data) => dispatch("fft", { data }, transferablesOf(data)),
    stats: (data) => dispatch("stats", { data }),
    terminate: () => {
      if (terminated) return;
      terminated = true;
      worker.terminate();
      const error = new Error("Worker terminated");
      rejectReady(error);
      rejectAll(error);
    },
  };
}
// Expose the factory to CommonJS consumers. Where no CommonJS `module` binding
// exists (e.g. a plain browser worker or an ES module), nothing is exported.
if (typeof module !== "undefined") {
  const api = { createScirs2Worker };
  module.exports = api;
}