Skip to main content

hs_predict/
pipeline.rs

1//! Main classification pipeline.
2//!
3//! Runs classification in priority order:
4//! 1. User-provided CAS → HS mappings (confidence = 1.0)
5//! 2. Embedded static rule table (CAS + shape + purity)
6//! 3. SMILES-based rule engine (v0.3)
7//! 4. LLM fallback via [`LlmClassifier`] trait hook (v0.4, `llm` feature)
8
9use std::collections::HashMap;
10#[cfg(feature = "llm")]
11use std::sync::Arc;
12
13use crate::error::{HsPredictError, Result};
14use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
15use crate::rules::matcher::find_best_rule;
16use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
17
/// Configuration for the classification pipeline.
///
/// The two thresholds split the confidence range into three bands, each
/// mapped to a [`RecommendedAction`]:
///
/// | confidence `c`                                        | action          |
/// |-------------------------------------------------------|-----------------|
/// | `c >= confidence_threshold_direct`                    | `Accept`        |
/// | `confidence_threshold_llm_required <= c < direct`     | `VerifyWithLlm` |
/// | `c < confidence_threshold_llm_required` (or NaN)      | `ExpertReview`  |
///
/// For the bands to make sense, `confidence_threshold_llm_required` should
/// not exceed `confidence_threshold_direct`; this is not enforced.
#[derive(Debug, Clone, PartialEq)]
pub struct PipelineConfig {
    /// Confidence at or above which a result is returned with
    /// `RecommendedAction::Accept`, i.e. without asking for LLM confirmation.
    pub confidence_threshold_direct: f32,

    /// Lower bound of the "verify" band.
    ///
    /// Results with a confidence in
    /// `[confidence_threshold_llm_required, confidence_threshold_direct)`
    /// are returned with `RecommendedAction::VerifyWithLlm`; anything below
    /// is tagged `RecommendedAction::ExpertReview`.
    ///
    /// The SMILES rule engine additionally uses this value as a gate: hints
    /// below it are discarded instead of being returned. It is also reported
    /// as `threshold` in `HsPredictError::LowConfidenceNoLlm`.
    pub confidence_threshold_llm_required: f32,
}
30
31impl Default for PipelineConfig {
32    fn default() -> Self {
33        Self {
34            confidence_threshold_direct: 0.85,
35            confidence_threshold_llm_required: 0.50,
36        }
37    }
38}
39
/// Main HS code classification pipeline.
///
/// A pipeline is assembled builder-style: start from [`HsPipeline::new`] and
/// chain `with_mapping` / `with_config` (plus `with_pubchem` / `with_llm`
/// when the matching Cargo features are enabled), then call `classify`.
///
/// # Example — direct (sync)
/// ```rust,no_run
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
///
/// let pipeline = HsPipeline::new();
///
/// let product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
///     physical_form: Some(PhysicalForm::Solid),
///     purity_pct: None,
///     purity_type: None,
///     mixture_components: None,
///     intended_use: None,
///     additional_context: None,
/// };
///
/// let prediction = pipeline.classify(&product).unwrap();
/// assert_eq!(&prediction.hs_code, "281511");
/// ```
///
/// # Example — with PubChem enrichment (async, `pubchem` feature)
/// ```rust,no_run
/// # #[cfg(feature = "pubchem")]
/// # async fn example() -> hs_predict::Result<()> {
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::pubchem::PubChemClient;
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
///
/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
///
/// let mut product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
///     physical_form: Some(PhysicalForm::Solid),
///     purity_pct: None,
///     purity_type: None,
///     mixture_components: None,
///     intended_use: None,
///     additional_context: None,
/// };
///
/// pipeline.enrich(&mut product).await?;   // fills SMILES, InChI, IUPAC name …
/// let prediction = pipeline.classify(&product)?;
/// println!("{}", prediction.display());   // "28.15.11"
/// # Ok(())
/// # }
/// ```
///
/// # Example — with LLM fallback (async, `llm` feature)
/// ```rust,no_run
/// # #[cfg(feature = "llm")]
/// # async fn example() -> hs_predict::Result<()> {
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse};
/// use futures::future::BoxFuture;
///
/// struct MyClient;
/// impl LlmClassifier for MyClient {
///     fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
///         Box::pin(async move { todo!() })
///     }
/// }
///
/// let pipeline = HsPipeline::new().with_llm(MyClient);
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier};
/// let product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("12-34-5"),
///     physical_form: None, purity_pct: None, purity_type: None,
///     mixture_components: None, intended_use: None, additional_context: None,
/// };
/// let prediction = pipeline.classify_with_llm(&product).await?;
/// println!("{}", prediction.display());
/// # Ok(())
/// # }
/// ```
#[derive(Default)]
pub struct HsPipeline {
    /// User-supplied CAS → HS code overrides. Highest priority.
    ///
    /// Keys are looked up with the product's CAS string exactly as given
    /// (no normalisation is applied by the pipeline).
    user_mappings: HashMap<String, String>,

    /// Confidence thresholds that decide the `RecommendedAction` attached to
    /// each prediction and gate the SMILES rule engine.
    config: PipelineConfig,

    /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
    /// `None` until `with_pubchem` is called; `enrich` is then a no-op.
    #[cfg(feature = "pubchem")]
    pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,

    /// LLM classifier hook (v0.4, `llm` feature).
    /// `None` until `with_llm` is called; `classify_with_llm` then fails with
    /// `LlmNotConfigured` whenever the rule-based result is not `Accept`.
    #[cfg(feature = "llm")]
    llm: Option<Arc<dyn crate::llm::LlmClassifier>>,
}
132
133impl HsPipeline {
134    /// Create a pipeline with default configuration.
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    /// Add a user-provided CAS → HS code mapping.
140    ///
141    /// These mappings override the embedded rule table with `confidence = 1.0`.
142    pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
143        self.user_mappings.insert(cas.into(), hs_code.into());
144        self
145    }
146
147    /// Override the default pipeline configuration.
148    pub fn with_config(mut self, config: PipelineConfig) -> Self {
149        self.config = config;
150        self
151    }
152
153    /// Attach an [`LlmClassifier`](crate::llm::LlmClassifier) implementation to
154    /// enable the LLM fallback (Priority 4).
155    ///
156    /// The LLM is called by [`classify_with_llm`](Self::classify_with_llm) when
157    /// the rule-based pipeline returns a result with
158    /// `recommended_action != Accept`, or returns
159    /// [`LowConfidenceNoLlm`](crate::HsPredictError::LowConfidenceNoLlm).
160    ///
161    /// Requires the **`llm`** Cargo feature.
162    #[cfg(feature = "llm")]
163    pub fn with_llm(mut self, client: impl crate::llm::LlmClassifier + 'static) -> Self {
164        self.llm = Some(Arc::new(client));
165        self
166    }
167
168    /// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
169    /// automatic identifier enrichment before classification.
170    ///
171    /// Requires the **`pubchem`** Cargo feature.
172    #[cfg(feature = "pubchem")]
173    pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
174        self.pubchem = Some(std::sync::Arc::new(client));
175        self
176    }
177
178    /// Enrich a [`ProductDescription`] with PubChem data.
179    ///
180    /// Fills in any missing fields of the main identifier and each mixture
181    /// component's identifier (SMILES, InChI, InChIKey, IUPAC name, CID).
182    ///
183    /// This is a **best-effort** operation:
184    /// - "Not found" and "no usable identifier" results are silently ignored.
185    /// - Network / parse errors **are** propagated.
186    /// - If no PubChem client is configured, returns `Ok(())` immediately.
187    ///
188    /// Requires the **`pubchem`** Cargo feature.
189    #[cfg(feature = "pubchem")]
190    pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
191        let Some(ref client) = self.pubchem else {
192            return Ok(());
193        };
194
195        client.enrich(&mut product.identifier).await?;
196
197        if let Some(ref mut comps) = product.mixture_components {
198            for comp in comps.iter_mut() {
199                client.enrich(&mut comp.substance).await?;
200            }
201        }
202
203        Ok(())
204    }
205
206    /// Classify a product and return an HS code prediction.
207    ///
208    /// Priority order:
209    /// 1. User-provided mapping
210    /// 2. Embedded static rule table
211    /// 3. (v0.3) SMILES rule engine
212    /// 4. (v0.4) LLM fallback
213    pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
214        // ── Priority 1: User-provided mappings ────────────────────────
215        if let Some(ref cas) = product.identifier.cas {
216            if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
217                let jp = find_jp_rule(hs_code);
218                return Ok(HsPrediction {
219                    hs_code: hs_code.clone(),
220                    heading_description: String::new(),
221                    confidence: 1.0,
222                    source: PredictionSource::UserMapping,
223                    notes: vec!["From user-provided mapping".to_string()],
224                    alternatives: vec![],
225                    recommended_action: RecommendedAction::Accept,
226                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
227                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
228                });
229            }
230        }
231
232        // ── Priority 2: Embedded static rule table ────────────────────
233        if let Some(ref cas) = product.identifier.cas {
234            if let Some(rule) = find_best_rule(
235                cas,
236                product.physical_form.as_ref(),
237                product.purity_pct,
238            ) {
239                let action = self.recommended_action(rule.confidence);
240                let jp = find_jp_rule(rule.hs_code);
241                return Ok(HsPrediction {
242                    hs_code: rule.hs_code.to_string(),
243                    heading_description: rule.heading_description.to_string(),
244                    confidence: rule.confidence,
245                    source: PredictionSource::EmbeddedRule {
246                        rule_id: format!("{}:{}", rule.cas, rule.hs_code),
247                    },
248                    notes: self.build_notes(product),
249                    alternatives: vec![],
250                    recommended_action: action,
251                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
252                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
253                });
254            }
255        }
256
257        // ── Priority 3: SMILES-based rule engine ─────────────────────────
258        if let Some(ref smiles) = product.identifier.smiles {
259            if let Some(classification) = crate::smiles::classify_smiles(smiles) {
260                let hint = &classification.heading_hint;
261                // Only emit a result when we have at least a 4-digit heading
262                // and confidence meets the LLM-required threshold.
263                if let Some(heading) = hint.heading {
264                    if hint.confidence >= self.config.confidence_threshold_llm_required {
265                        // Pad to 6 digits with "00" sub-heading (best guess)
266                        let hs_code = format!("{:04}00", heading);
267                        let jp = find_jp_rule(&hs_code);
268                        let action = self.recommended_action(hint.confidence);
269
270                        let mut notes = self.build_notes(product);
271                        notes.push(
272                            "Heading is derived from SMILES functional-group analysis. \
273                             Sub-heading (last two digits) is a placeholder — \
274                             verify the exact 6-digit code with the product specification."
275                                .to_string(),
276                        );
277
278                        let matched_rules: Vec<String> = classification
279                            .functional_groups
280                            .iter()
281                            .map(|g| g.label().to_string())
282                            .collect();
283
284                        return Ok(HsPrediction {
285                            hs_code,
286                            heading_description: hint.rationale.to_string(),
287                            confidence: hint.confidence,
288                            source: PredictionSource::RuleEngine { matched_rules },
289                            notes,
290                            alternatives: vec![],
291                            recommended_action: action,
292                            jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
293                            jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
294                        });
295                    }
296                }
297            }
298        }
299
300        // ── Priority 4: LLM fallback ─────────────────────────────────
301        // (async path — use classify_with_llm for LLM support)
302        Err(HsPredictError::LowConfidenceNoLlm {
303            confidence: 0.0,
304            threshold: self.config.confidence_threshold_llm_required,
305        })
306    }
307
308    /// Classify a product, falling back to the configured LLM when the
309    /// rule-based pipeline returns a low-confidence or uncertain result.
310    ///
311    /// # Priority order (same as [`classify`](Self::classify) + LLM)
312    ///
313    /// 1. User-provided mapping → `Accept` → return immediately.
314    /// 2. Embedded static rule table → `Accept` → return immediately.
315    /// 3. SMILES rule engine → `Accept` → return immediately.
316    /// 4. Any result with `recommended_action != Accept`, or
317    ///    `LowConfidenceNoLlm` → forward to LLM.
318    ///
319    /// If no LLM client has been configured via [`with_llm`](Self::with_llm),
320    /// returns [`HsPredictError::LlmNotConfigured`].
321    ///
322    /// # Validation
323    /// The LLM's `hs_code` must be exactly 6 ASCII digits; otherwise
324    /// [`HsPredictError::ValidationFailed`] is returned.
325    ///
326    /// # Chapter consistency
327    /// If the LLM chapter differs from the SMILES engine's chapter hint, a
328    /// warning note is appended — this is **not** a hard error.
329    ///
330    /// Requires the **`llm`** Cargo feature.
331    #[cfg(feature = "llm")]
332    pub async fn classify_with_llm(
333        &self,
334        product: &ProductDescription,
335    ) -> Result<HsPrediction> {
336        use crate::llm::PromptBuilder;
337        use crate::types::AlternativePrediction;
338
339        // First try the synchronous rule-based pipeline.
340        let needs_llm = match self.classify(product) {
341            Ok(pred) if pred.recommended_action == RecommendedAction::Accept => {
342                return Ok(pred);
343            }
344            Ok(_pred) => true,  // low-confidence result → try LLM
345            Err(HsPredictError::LowConfidenceNoLlm { .. }) => true,
346            Err(e) => return Err(e),
347        };
348
349        debug_assert!(needs_llm);
350
351        // Require a configured LLM client.
352        let llm = self
353            .llm
354            .as_ref()
355            .ok_or(HsPredictError::LlmNotConfigured)?;
356
357        // Build prompt and call the LLM.
358        let prompt = PromptBuilder::new().build(product);
359        let resp = llm.classify(&prompt).await?;
360
361        // Validate: must be exactly 6 ASCII digits.
362        if resp.hs_code.len() != 6 || !resp.hs_code.chars().all(|c| c.is_ascii_digit()) {
363            return Err(HsPredictError::ValidationFailed { code: resp.hs_code });
364        }
365
366        // Chapter consistency check (warning only).
367        let mut notes = self.build_notes(product);
368        if let Some(ref analysis) = prompt.smiles_analysis {
369            let llm_chapter = &resp.hs_code[..2];
370            let expected_chapter = format!("{:02}", analysis.heading_hint.chapter);
371            if llm_chapter != expected_chapter {
372                notes.push(format!(
373                    "Chapter mismatch: LLM returned Chapter {} but SMILES engine \
374                     suggested Chapter {}. Verify with Chapter Notes.",
375                    llm_chapter, expected_chapter
376                ));
377            }
378        }
379
380        notes.push(format!("LLM rationale: {}", resp.rationale));
381
382        let jp = find_jp_rule(&resp.hs_code);
383        let action = self.recommended_action(resp.confidence);
384
385        let alternatives = resp
386            .alternatives
387            .into_iter()
388            .map(|a| AlternativePrediction {
389                hs_code: a.hs_code,
390                confidence: a.confidence,
391                reason: a.reason,
392            })
393            .collect();
394
395        Ok(HsPrediction {
396            hs_code: resp.hs_code,
397            heading_description: String::new(),
398            confidence: resp.confidence,
399            source: PredictionSource::LlmApi { model: String::new() },
400            notes,
401            alternatives,
402            recommended_action: action,
403            jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
404            jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
405        })
406    }
407
408    // ─── Private helpers ──────────────────────────────────────────────
409
410    fn recommended_action(&self, confidence: f32) -> RecommendedAction {
411        if confidence >= self.config.confidence_threshold_direct {
412            RecommendedAction::Accept
413        } else if confidence >= self.config.confidence_threshold_llm_required {
414            RecommendedAction::VerifyWithLlm
415        } else {
416            RecommendedAction::ExpertReview
417        }
418    }
419
420    /// Build supplementary notes about shape / purity caveats.
421    fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
422        let mut notes = Vec::new();
423
424        match &product.physical_form {
425            None | Some(PhysicalForm::Unknown) => {
426                notes.push(
427                    "Physical form not specified — the HS subheading may differ \
428                     (e.g. solid vs. solution).".to_string(),
429                );
430            }
431            Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
432                notes.push(
433                    "Solution concentration not specified — subheading may differ \
434                     (e.g. fuming vs. standard grade).".to_string(),
435                );
436            }
437            _ => {}
438        }
439
440        if product.purity_pct.is_none() {
441            notes.push(
442                "Purity not specified — some headings require a minimum purity threshold."
443                    .to_string(),
444            );
445        }
446
447        notes
448    }
449}
450
451// ─────────────────────────────────────────────────────────────────────────────
452// Pipeline integration tests
453// ─────────────────────────────────────────────────────────────────────────────
454
#[cfg(all(test, feature = "mock"))]
mod tests {
    use super::*;
    use crate::llm::MockLlmClassifier;
    use crate::types::SubstanceIdentifier;

    /// Product carrying only `identifier`; every optional field is `None`.
    fn bare_product(identifier: SubstanceIdentifier) -> ProductDescription {
        ProductDescription {
            identifier,
            physical_form: None,
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    /// A product with no static rule and a SMILES → triggers LLM path.
    fn unknown_organic() -> ProductDescription {
        bare_product(SubstanceIdentifier {
            cas: None,
            // Acetic acid SMILES with no CAS, so the CAS-keyed stages are skipped.
            smiles: Some("CC(O)=O".to_string()),
            iupac_name: None,
            inchi: None,
            inchi_key: None,
            cid: None,
        })
    }

    #[tokio::test]
    async fn classify_with_llm_mock_returns_6_digit_code() {
        let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
        let input = unknown_organic();

        let prediction = pipeline.classify_with_llm(&input).await.unwrap();

        assert_eq!(prediction.hs_code.len(), 6);
        assert!(prediction.hs_code.chars().all(|c| c.is_ascii_digit()));
    }

    #[tokio::test]
    async fn classify_with_llm_mock_chapter_29_for_smiles_acid() {
        let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
        let input = unknown_organic();

        let prediction = pipeline.classify_with_llm(&input).await.unwrap();

        assert!(
            prediction.hs_code.starts_with("29"),
            "acetic acid SMILES should yield Chapter 29, got {}",
            prediction.hs_code
        );
    }

    #[tokio::test]
    async fn classify_with_llm_no_client_returns_error() {
        // Deliberately no `with_llm` call here.
        let pipeline = HsPipeline::new();
        let input = unknown_organic();

        let failure = pipeline.classify_with_llm(&input).await.unwrap_err();

        assert!(
            matches!(failure, HsPredictError::LlmNotConfigured),
            "expected LlmNotConfigured, got {:?}",
            failure
        );
    }

    #[tokio::test]
    async fn classify_with_llm_skips_llm_for_high_confidence_rule() {
        // Solid NaOH hits a static rule with confidence 1.0, so the mock
        // (which would answer "999999") must never be consulted.
        let decoy = MockLlmClassifier::with_default("999999", 0.1);
        let pipeline = HsPipeline::new().with_llm(decoy);
        let naoh_solid = ProductDescription {
            physical_form: Some(crate::types::PhysicalForm::Solid),
            ..bare_product(SubstanceIdentifier::from_cas("1310-73-2"))
        };

        let prediction = pipeline.classify_with_llm(&naoh_solid).await.unwrap();

        assert_eq!(prediction.hs_code, "281511", "static rule should win over LLM");
    }

    #[tokio::test]
    async fn classify_with_llm_invalid_code_returns_validation_error() {
        use crate::llm::{LlmClassifier, LlmPrompt, LlmResponse};

        /// Classifier stub that always answers with a malformed HS code.
        struct BadMock;

        impl LlmClassifier for BadMock {
            fn classify<'a>(
                &'a self,
                _prompt: &'a LlmPrompt,
            ) -> futures::future::BoxFuture<'a, crate::Result<LlmResponse>> {
                Box::pin(async {
                    Ok(LlmResponse {
                        hs_code: "BAD!!".to_string(),
                        confidence: 0.5,
                        rationale: "bad".to_string(),
                        alternatives: vec![],
                    })
                })
            }
        }

        let pipeline = HsPipeline::new().with_llm(BadMock);
        let input = unknown_organic();

        let failure = pipeline.classify_with_llm(&input).await.unwrap_err();

        assert!(
            matches!(failure, HsPredictError::ValidationFailed { .. }),
            "expected ValidationFailed, got {:?}",
            failure
        );
    }
}
561}