use super::backend::JsNullBackend;
use super::diagnostics::JsInMemoryTraceSink;
use super::graph::JsBuildGraph;
use crate::pipeline::{BuildContext, HashingStage, PipelineStage, ValidationStage};
use wasm_bindgen::prelude::*;
/// Build context exposed to JavaScript.
///
/// Bundles the backend and trace sink whose inner values
/// `JsBuildPipeline::execute` passes to the Rust `BuildContext`.
#[wasm_bindgen]
pub struct JsBuildContext {
    /// Backend whose `inner()` value is given to `BuildContext::new`.
    backend: JsNullBackend,
    /// Trace sink whose `inner()` value is given to `BuildContext::new`.
    trace: JsInMemoryTraceSink,
}
#[wasm_bindgen]
impl JsBuildContext {
    /// Creates a context using the backend from `JsNullBackend::cpu()` and a
    /// newly constructed in-memory trace sink.
    #[wasm_bindgen(js_name = withNullBackend)]
    pub fn with_null_backend() -> Self {
        Self {
            backend: JsNullBackend::cpu(),
            trace: JsInMemoryTraceSink::new(),
        }
    }

    /// Creates a context from an explicitly supplied backend and trace sink.
    #[wasm_bindgen(constructor)]
    pub fn new(backend: JsNullBackend, trace: JsInMemoryTraceSink) -> Self {
        Self { backend, trace }
    }

    /// Returns the backend name, as reported by `JsNullBackend::name`.
    #[wasm_bindgen(js_name = backendName)]
    pub fn backend_name(&self) -> String {
        self.backend.name()
    }

    /// Returns the backend device name, as reported by `JsNullBackend::device`.
    #[wasm_bindgen(js_name = deviceName)]
    pub fn device_name(&self) -> String {
        self.backend.device()
    }

    /// Returns this context's own trace sink rather than a freshly created one.
    ///
    /// Returning `JsInMemoryTraceSink::new()` here would hand JavaScript an
    /// empty sink that is unrelated to the one pipeline execution receives, so
    /// anything recorded through the context could never be inspected.
    ///
    /// NOTE(review): relies on `JsInMemoryTraceSink: Clone`, and presumably
    /// its clone shares the underlying storage with `self.trace`. Confirm in
    /// the `diagnostics` module.
    #[wasm_bindgen(getter)]
    pub fn trace(&self) -> JsInMemoryTraceSink {
        self.trace.clone()
    }
}
/// Build pipeline exposed to JavaScript.
///
/// Holds an ordered list of stage names to run plus a list of names to skip.
#[wasm_bindgen]
pub struct JsBuildPipeline {
    /// Stage names, executed in this order by `execute`.
    stages: Vec<String>,
    /// Stage names that `execute` skips even when they appear in `stages`.
    skip_stages: Vec<String>,
}
#[wasm_bindgen]
impl JsBuildPipeline {
    /// Creates an empty pipeline: no stages and nothing skipped.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            stages: Vec::new(),
            skip_stages: Vec::new(),
        }
    }

    /// Pipeline that runs `"validation"` followed by `"hashing"`.
    #[wasm_bindgen]
    pub fn standard() -> Self {
        Self {
            stages: vec!["validation".to_string(), "hashing".to_string()],
            skip_stages: Vec::new(),
        }
    }

    /// Pipeline that runs only `"hashing"` (no `"validation"` stage).
    #[wasm_bindgen(js_name = forInference)]
    pub fn for_inference() -> Self {
        Self {
            stages: vec!["hashing".to_string()],
            skip_stages: Vec::new(),
        }
    }

    /// Pipeline that runs `"validation"` followed by `"hashing"`.
    ///
    /// Currently the same stage list as [`JsBuildPipeline::standard`].
    #[wasm_bindgen(js_name = forTraining)]
    pub fn for_training() -> Self {
        Self {
            stages: vec!["validation".to_string(), "hashing".to_string()],
            skip_stages: Vec::new(),
        }
    }

    /// Returns a new pipeline with `stage_name` appended to the stage list.
    ///
    /// `self` is left unchanged, so the receiver is `&self`. A `&mut self`
    /// receiver would make wasm-bindgen take an exclusive borrow for a method
    /// that never mutates.
    #[wasm_bindgen(js_name = withStage)]
    pub fn with_stage(&self, stage_name: &str) -> JsBuildPipeline {
        let mut stages = self.stages.clone();
        stages.push(stage_name.to_string());
        Self {
            stages,
            skip_stages: self.skip_stages.clone(),
        }
    }

    /// Returns a new pipeline with `stage_name` added to the skip list.
    ///
    /// `self` is left unchanged, so the receiver is `&self` (see `with_stage`).
    #[wasm_bindgen(js_name = withSkipStage)]
    pub fn with_skip_stage(&self, stage_name: &str) -> JsBuildPipeline {
        let mut skip_stages = self.skip_stages.clone();
        skip_stages.push(stage_name.to_string());
        Self {
            stages: self.stages.clone(),
            skip_stages,
        }
    }

    /// Runs the configured stages, in order, against `graph` and returns the
    /// resulting graph.
    ///
    /// Stage names found in `skip_stages` are not run. The recognised names are
    /// `"validation"` (runs `ValidationStage`) and `"hashing"` (runs
    /// `HashingStage`). Any other name is ignored without error.
    ///
    /// `graph` is consumed. On error, no graph is returned.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` whose message starts with `"Validation failed: "` or
    /// `"Hashing failed: "` when that stage fails. Later stages are not run.
    #[wasm_bindgen]
    pub fn execute(
        &self,
        ctx: &JsBuildContext,
        graph: JsBuildGraph,
    ) -> Result<JsBuildGraph, JsError> {
        let backend = ctx.backend.inner();
        let trace = ctx.trace.inner();
        let rust_ctx = BuildContext::new(backend, trace);
        let mut rust_graph = graph.into_inner();
        for stage_name in &self.stages {
            if self.skip_stages.contains(stage_name) {
                continue;
            }
            match stage_name.as_str() {
                "validation" => {
                    let stage = ValidationStage;
                    stage
                        .execute(&rust_ctx, &mut rust_graph)
                        .map_err(|e| JsError::new(&format!("Validation failed: {}", e)))?;
                }
                "hashing" => {
                    let stage = HashingStage;
                    stage
                        .execute(&rust_ctx, &mut rust_graph)
                        .map_err(|e| JsError::new(&format!("Hashing failed: {}", e)))?;
                }
                _ => {
                    // Unrecognised stage names are skipped silently.
                    // NOTE(review): consider rejecting them so typos passed to
                    // `withStage` surface. Left as-is because callers may rely on it.
                }
            }
        }
        Ok(JsBuildGraph::from_inner(rust_graph))
    }

    /// Returns a copy of the configured stage names, in execution order.
    ///
    /// Skipped stages are still included.
    #[wasm_bindgen(getter)]
    pub fn stages(&self) -> Vec<String> {
        self.stages.clone()
    }
}
impl Default for JsBuildPipeline {
fn default() -> Self {
Self::new()
}
}