Skip to main content

hs_predict_wasm/
lib.rs

1//! WebAssembly bindings for `hs-predict`.
2//!
3//! Exposes three APIs:
4//!
5//! - **`classify_smiles(smiles)`** — SMILES → functional groups + HS heading hint
6//! - **`classify_product(product_json)`** — full rule-based pipeline (Priorities 1–3)
7//! - **`WasmSession`** — Akinator-style interactive session
8//!
9//! Build with:
10//! ```bash
11//! wasm-pack build --target web --release
12//! ```
13//!
14//! # JavaScript usage
15//! ```js
//! import init, { classify_smiles, classify_product, WasmSession } from './pkg/hs_predict_wasm.js';
17//! await init();
18//!
19//! // 1. SMILES classification
20//! const r = classify_smiles('CC(O)=O');
21//! // → { organic_class: "organic", functional_groups: ["CarboxylicAcid"],
22//! //     heading_hint: { chapter: 29, heading: 2915, confidence: 0.6, rationale: "..." } }
23//!
24//! // 2. Rule-based pipeline
25//! const pred = classify_product(JSON.stringify({
26//!   identifier: { cas: "1310-73-2" },
27//!   physical_form: "Solid"
28//! }));
29//! // → { hs_code: "281511", confidence: 1.0, ... }
30//!
31//! // 3. Interactive session
32//! const session = new WasmSession();
33//! const q1 = session.start();
34//! // → { step: "Identifier", prompt: "Please enter a CAS number...", type: "text" }
35//! const r1 = session.answer(JSON.stringify({ Text: "1310-73-2" }));
36//! // → { type: "NeedMoreInfo", next_question: { step: "IsMixture", ... } }
37//! ```
38
39use wasm_bindgen::prelude::*;
40
41use hs_predict::pipeline::HsPipeline;
42use hs_predict::session::{Answer, ClassificationSession};
43use hs_predict::types::ProductDescription;
44
45// ─────────────────────────────────────────────────────────────────────────────
46// Internal helpers
47// ─────────────────────────────────────────────────────────────────────────────
48
49fn to_js<T: serde::Serialize>(val: &T) -> JsValue {
50    serde_wasm_bindgen::to_value(val).unwrap_or(JsValue::NULL)
51}
52
53fn err_js(e: impl std::fmt::Display) -> JsValue {
54    JsValue::from_str(&e.to_string())
55}
56
57// ─────────────────────────────────────────────────────────────────────────────
58// API 1 — SMILES classification
59// ─────────────────────────────────────────────────────────────────────────────
60
61/// Analyse a SMILES string and return functional-group + HS heading hint.
62///
63/// Returns a `SmilesClassification` JS object, or `null` if the SMILES string
64/// is empty or cannot be parsed.
65///
66/// # JS return shape
67/// ```json
68/// {
69///   "organic_class": "organic",
70///   "functional_groups": ["CarboxylicAcid"],
71///   "heading_hint": {
72///     "chapter": 29,
73///     "heading": 2915,
74///     "rationale": "Carboxylic acid → heading 29.15",
75///     "confidence": 0.60
76///   }
77/// }
78/// ```
79#[wasm_bindgen]
80pub fn classify_smiles(smiles: &str) -> JsValue {
81    match hs_predict::smiles::classify_smiles(smiles) {
82        Some(r) => to_js(&r),
83        None => JsValue::NULL,
84    }
85}
86
87// ─────────────────────────────────────────────────────────────────────────────
88// API 2 — Full rule-based pipeline
89// ─────────────────────────────────────────────────────────────────────────────
90
91/// Classify a chemical product using the full rule-based pipeline (Priorities 1–3).
92///
93/// `product_json` must be a JSON-serialised `ProductDescription`:
94/// ```json
95/// {
96///   "identifier": { "cas": "1310-73-2" },
97///   "physical_form": "Solid",
98///   "purity_pct": null,
99///   "purity_type": null,
100///   "mixture_components": null,
101///   "intended_use": null,
102///   "additional_context": null
103/// }
104/// ```
105///
106/// Returns a `HsPrediction` JS object on success, or throws a JS error string.
107#[wasm_bindgen]
108pub fn classify_product(product_json: &str) -> Result<JsValue, JsValue> {
109    let product: ProductDescription =
110        serde_json::from_str(product_json).map_err(|e| err_js(e))?;
111    let pipeline = HsPipeline::new();
112    pipeline
113        .classify(&product)
114        .map(|pred| to_js(&pred))
115        .map_err(|e| err_js(e))
116}
117
118// ─────────────────────────────────────────────────────────────────────────────
119// API 3 — Interactive Akinator session
120// ─────────────────────────────────────────────────────────────────────────────
121
122/// Interactive Akinator-style HS classification session.
123///
124/// # Typical flow
125/// ```js
126/// const session = new WasmSession();       // English
127/// // const session = WasmSession.new_ja(); // Japanese
128///
129/// const q1 = session.start();
130/// // q1 = { step: "Identifier", prompt: "...", type: "text", choices: null, ... }
131///
132/// const r1 = session.answer(JSON.stringify({ Text: "64-19-7" }));
133/// // r1 = { type: "NeedMoreInfo", next_question: { step: "IsMixture", ... } }
134///
135/// // … repeat until r.type === "Ready" …
136///
137/// const prediction = session.classify();
138/// // { hs_code: "291511", confidence: 0.95, ... }
139/// ```
140#[wasm_bindgen]
141pub struct WasmSession {
142    session: ClassificationSession,
143    pipeline: HsPipeline,
144}
145
146#[wasm_bindgen]
147impl WasmSession {
148    /// Create a new session with English prompts.
149    #[wasm_bindgen(constructor)]
150    pub fn new() -> WasmSession {
151        WasmSession {
152            session: ClassificationSession::new(),
153            pipeline: HsPipeline::new(),
154        }
155    }
156
157    /// Create a new session with Japanese prompts.
158    pub fn new_ja() -> WasmSession {
159        WasmSession {
160            session: ClassificationSession::new_ja(),
161            pipeline: HsPipeline::new(),
162        }
163    }
164
165    /// Start the session and return the first `Question` as a JS object.
166    ///
167    /// # JS return shape
168    /// ```json
169    /// {
170    ///   "step": "Identifier",
171    ///   "prompt": "Please enter a CAS number, IUPAC name, SMILES, or InChIKey",
172    ///   "type": "text",
173    ///   "choices": null,
174    ///   "number_range": null
175    /// }
176    /// ```
177    pub fn start(&mut self) -> JsValue {
178        let q = self.session.start();
179        to_js(&q)
180    }
181
182    /// Submit an answer and advance the session.
183    ///
184    /// `answer_json` must be the JSON representation of an `Answer` variant:
185    /// - `JSON.stringify({ Text: "1310-73-2" })` — free-text answer
186    /// - `JSON.stringify({ YesNo: true })` — yes/no answer
187    /// - `JSON.stringify({ Choice: 0 })` — single-choice index
188    /// - `JSON.stringify({ MultiChoice: [0, 2] })` — multi-choice indices
189    /// - `JSON.stringify({ Number: 30.5 })` — numeric answer
190    ///
191    /// # JS return shape
192    /// ```json
193    /// { "type": "NeedMoreInfo", "next_question": { "step": "...", ... } }
194    /// { "type": "Ready" }
195    /// { "type": "RequiresLlm" }
196    /// ```
197    pub fn answer(&mut self, answer_json: &str) -> Result<JsValue, JsValue> {
198        let answer: Answer = serde_json::from_str(answer_json).map_err(|e| err_js(e))?;
199        self.session
200            .answer(answer)
201            .map(|result| to_js(&result))
202            .map_err(|e| err_js(e))
203    }
204
205    /// Classify the product once the session is `Ready`.
206    ///
207    /// Call this after `answer()` returns `{ type: "Ready" }`.
208    ///
209    /// Returns a `HsPrediction` JS object on success, or throws a JS error string.
210    pub fn classify(&self) -> Result<JsValue, JsValue> {
211        let product = self.session.to_product_description();
212        self.pipeline
213            .classify(&product)
214            .map(|pred| to_js(&pred))
215            .map_err(|e| err_js(e))
216    }
217}
218
219// ─────────────────────────────────────────────────────────────────────────────
220// WASM tests (run with: wasm-pack test --node)
221// ─────────────────────────────────────────────────────────────────────────────
222//
223// These tests run in a JS/WASM environment only.
224// Core classification logic is already covered by hs-predict's own test suite.
225
#[cfg(all(test, target_arch = "wasm32"))]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    // Deliberately no `wasm_bindgen_test_configure!(run_in_browser)`.
    // Nothing here needs a DOM, and without that macro the suite targets Node.js,
    // which matches the documented `wasm-pack test --node` command. A
    // browser-only suite is not run under `--node`.

    #[wasm_bindgen_test]
    fn classify_smiles_acetic_acid_returns_non_null() {
        let r = classify_smiles("CC(O)=O");
        assert!(!r.is_null(), "acetic acid SMILES should return a classification");
    }

    #[wasm_bindgen_test]
    fn classify_smiles_empty_returns_null() {
        let r = classify_smiles("");
        assert!(r.is_null(), "empty SMILES should return null");
    }

    #[wasm_bindgen_test]
    fn classify_product_naoh_solid_ok() {
        let json = r#"{
            "identifier": { "cas": "1310-73-2", "smiles": null, "iupac_name": null,
                            "inchi": null, "inchi_key": null, "cid": null },
            "physical_form": "Solid",
            "purity_pct": null,
            "purity_type": null,
            "mixture_components": null,
            "intended_use": null,
            "additional_context": null
        }"#;
        let r = classify_product(json).expect("NaOH solid should classify successfully");
        // `to_js` maps serialization failures to `null`. Checking for a non-null
        // value keeps a swallowed error from passing as success.
        assert!(!r.is_null(), "successful classification should not serialize to null");
    }

    #[wasm_bindgen_test]
    fn classify_product_invalid_json_returns_err() {
        let r = classify_product("not json at all");
        assert!(r.is_err());
    }

    #[wasm_bindgen_test]
    fn wasm_session_start_returns_question() {
        let mut s = WasmSession::new();
        let q = s.start();
        assert!(!q.is_null(), "start() should return a question");
    }

    #[wasm_bindgen_test]
    fn wasm_session_ja_start_returns_question() {
        let mut s = WasmSession::new_ja();
        let q = s.start();
        assert!(!q.is_null(), "start() on a Japanese session should return a question");
    }

    #[wasm_bindgen_test]
    fn wasm_session_answer_valid_cas() {
        let mut s = WasmSession::new();
        s.start();
        let r = s
            .answer(r#"{"Text":"1310-73-2"}"#)
            .expect("valid CAS should be accepted");
        assert!(!r.is_null(), "an accepted answer should produce a session result");
    }

    #[wasm_bindgen_test]
    fn wasm_session_answer_invalid_json_returns_err() {
        let mut s = WasmSession::new();
        s.start();
        let r = s.answer("not json");
        assert!(r.is_err());
    }
}