Skip to main content

hs_predict/
pipeline.rs

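//! Main classification pipeline.
//!
//! Runs classification in priority order:
//! 1. User-provided CAS → HS mappings (confidence = 1.0)
//! 2. Embedded static rule table (CAS + shape + purity)
//! 3. SMILES-based rule engine (v0.3)
//! 4. LLM fallback via the [`LlmClassifier`](crate::llm::LlmClassifier) trait hook
//!    (v0.4, `llm` feature)
//!
//! Stages 1–3 run synchronously in `HsPipeline::classify`; stage 4 is only
//! reached through the async `HsPipeline::classify_with_llm`.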
8
9use std::collections::HashMap;
10#[cfg(feature = "llm")]
11use std::sync::Arc;
12
13use crate::error::{HsPredictError, Result};
14use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
15use crate::rules::matcher::find_best_rule;
16use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
17
/// Confidence thresholds that decide what the pipeline recommends doing
/// with a prediction.
///
/// A prediction's confidence is bucketed as follows:
/// - `>= confidence_threshold_direct` → `RecommendedAction::Accept`
/// - `>= confidence_threshold_llm_required` (but below the direct threshold)
///   → `RecommendedAction::VerifyWithLlm`
/// - anything lower → `RecommendedAction::ExpertReview`
///
/// NOTE(review): nothing enforces `confidence_threshold_llm_required <=
/// confidence_threshold_direct`; swapping them makes `VerifyWithLlm`
/// unreachable.
#[derive(Clone, Debug)]
pub struct PipelineConfig {
    /// Minimum confidence for a result to be accepted as-is, with no
    /// LLM confirmation recommended.
    pub confidence_threshold_direct: f32,

    /// Minimum confidence for a result to be usable at all. Results between
    /// this value and `confidence_threshold_direct` come back flagged with
    /// `RecommendedAction::VerifyWithLlm`.
    pub confidence_threshold_llm_required: f32,
}
30
31impl Default for PipelineConfig {
32    fn default() -> Self {
33        Self {
34            confidence_threshold_direct: 0.85,
35            confidence_threshold_llm_required: 0.50,
36        }
37    }
38}
39
40/// Main HS code classification pipeline.
41///
42/// # Example — direct (sync)
43/// ```rust,no_run
44/// use hs_predict::pipeline::HsPipeline;
45/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
46///
47/// let pipeline = HsPipeline::new();
48///
49/// let product = ProductDescription {
50///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
51///     physical_form: Some(PhysicalForm::Solid),
52///     purity_pct: None,
53///     purity_type: None,
54///     mixture_components: None,
55///     intended_use: None,
56///     additional_context: None,
57/// };
58///
59/// let prediction = pipeline.classify(&product).unwrap();
60/// assert_eq!(&prediction.hs_code, "281511");
61/// ```
62///
63/// # Example — with PubChem enrichment (async, `pubchem` feature)
64/// ```rust,no_run
65/// # #[cfg(feature = "pubchem")]
66/// # async fn example() -> hs_predict::Result<()> {
67/// use hs_predict::pipeline::HsPipeline;
68/// use hs_predict::pubchem::PubChemClient;
69/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
70///
71/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
72///
73/// let mut product = ProductDescription {
74///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
75///     physical_form: Some(PhysicalForm::Solid),
76///     purity_pct: None,
77///     purity_type: None,
78///     mixture_components: None,
79///     intended_use: None,
80///     additional_context: None,
81/// };
82///
83/// pipeline.enrich(&mut product).await?;   // fills SMILES, InChI, IUPAC name …
84/// let prediction = pipeline.classify(&product)?;
85/// println!("{}", prediction.display());   // "28.15.11"
86/// # Ok(())
87/// # }
88/// ```
89///
90/// # Example — with LLM fallback (async, `llm` feature)
91/// ```rust,no_run
92/// # #[cfg(feature = "llm")]
93/// # async fn example() -> hs_predict::Result<()> {
94/// use hs_predict::pipeline::HsPipeline;
95/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse};
96/// use futures::future::BoxFuture;
97///
98/// struct MyClient;
99/// impl LlmClassifier for MyClient {
100///     fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
101///         Box::pin(async move { todo!() })
102///     }
103/// }
104///
105/// let pipeline = HsPipeline::new().with_llm(MyClient);
106/// use hs_predict::types::{ProductDescription, SubstanceIdentifier};
107/// let product = ProductDescription {
108///     identifier: SubstanceIdentifier::from_cas("12-34-5"),
109///     physical_form: None, purity_pct: None, purity_type: None,
110///     mixture_components: None, intended_use: None, additional_context: None,
111/// };
112/// let prediction = pipeline.classify_with_llm(&product).await?;
113/// println!("{}", prediction.display());
114/// # Ok(())
115/// # }
116/// ```
117#[derive(Default)]
118pub struct HsPipeline {
119    /// User-supplied CAS → HS code overrides. Highest priority.
120    user_mappings: HashMap<String, String>,
121
122    config: PipelineConfig,
123
124    /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
125    #[cfg(feature = "pubchem")]
126    pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
127
128    /// LLM classifier hook (v0.4, `llm` feature).
129    #[cfg(feature = "llm")]
130    llm: Option<Arc<dyn crate::llm::LlmClassifier>>,
131}
132
133impl std::fmt::Debug for HsPipeline {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        let mut s = f.debug_struct("HsPipeline");
136        s.field("user_mappings", &self.user_mappings);
137        s.field("config", &self.config);
138        #[cfg(feature = "pubchem")]
139        s.field("pubchem", &self.pubchem.as_ref().map(|_| "<PubChemClient>"));
140        #[cfg(feature = "llm")]
141        s.field("llm", &self.llm.as_ref().map(|_| "<dyn LlmClassifier>"));
142        s.finish()
143    }
144}
145
146impl HsPipeline {
147    /// Create a pipeline with default configuration.
148    pub fn new() -> Self {
149        Self::default()
150    }
151
152    /// Add a user-provided CAS → HS code mapping.
153    ///
154    /// These mappings override the embedded rule table with `confidence = 1.0`.
155    pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
156        self.user_mappings.insert(cas.into(), hs_code.into());
157        self
158    }
159
160    /// Override the default pipeline configuration.
161    pub fn with_config(mut self, config: PipelineConfig) -> Self {
162        self.config = config;
163        self
164    }
165
166    /// Attach an [`LlmClassifier`](crate::llm::LlmClassifier) implementation to
167    /// enable the LLM fallback (Priority 4).
168    ///
169    /// The LLM is called by [`classify_with_llm`](Self::classify_with_llm) when
170    /// the rule-based pipeline returns a result with
171    /// `recommended_action != Accept`, or returns
172    /// [`LowConfidenceNoLlm`](crate::HsPredictError::LowConfidenceNoLlm).
173    ///
174    /// Requires the **`llm`** Cargo feature.
175    #[cfg(feature = "llm")]
176    pub fn with_llm(mut self, client: impl crate::llm::LlmClassifier + 'static) -> Self {
177        self.llm = Some(Arc::new(client));
178        self
179    }
180
181    /// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
182    /// automatic identifier enrichment before classification.
183    ///
184    /// Requires the **`pubchem`** Cargo feature.
185    #[cfg(feature = "pubchem")]
186    pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
187        self.pubchem = Some(std::sync::Arc::new(client));
188        self
189    }
190
191    /// Enrich a [`ProductDescription`] with PubChem data.
192    ///
193    /// Fills in any missing fields of the main identifier and each mixture
194    /// component's identifier (SMILES, InChI, InChIKey, IUPAC name, CID).
195    ///
196    /// This is a **best-effort** operation:
197    /// - "Not found" and "no usable identifier" results are silently ignored.
198    /// - Network / parse errors **are** propagated.
199    /// - If no PubChem client is configured, returns `Ok(())` immediately.
200    ///
201    /// Requires the **`pubchem`** Cargo feature.
202    #[cfg(feature = "pubchem")]
203    pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
204        let Some(ref client) = self.pubchem else {
205            return Ok(());
206        };
207
208        client.enrich(&mut product.identifier).await?;
209
210        if let Some(ref mut comps) = product.mixture_components {
211            for comp in comps.iter_mut() {
212                client.enrich(&mut comp.substance).await?;
213            }
214        }
215
216        Ok(())
217    }
218
219    /// Classify a product and return an HS code prediction.
220    ///
221    /// Priority order:
222    /// 1. User-provided mapping
223    /// 2. Embedded static rule table
224    /// 3. (v0.3) SMILES rule engine
225    /// 4. (v0.4) LLM fallback
226    pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
227        // ── Priority 1: User-provided mappings ────────────────────────
228        if let Some(ref cas) = product.identifier.cas {
229            if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
230                let jp = find_jp_rule(hs_code);
231                return Ok(HsPrediction {
232                    hs_code: hs_code.clone(),
233                    heading_description: String::new(),
234                    confidence: 1.0,
235                    source: PredictionSource::UserMapping,
236                    notes: vec!["From user-provided mapping".to_string()],
237                    alternatives: vec![],
238                    recommended_action: RecommendedAction::Accept,
239                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
240                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
241                });
242            }
243        }
244
245        // ── Priority 2: Embedded static rule table ────────────────────
246        if let Some(ref cas) = product.identifier.cas {
247            if let Some(rule) = find_best_rule(
248                cas,
249                product.physical_form.as_ref(),
250                product.purity_pct,
251            ) {
252                let action = self.recommended_action(rule.confidence);
253                let jp = find_jp_rule(rule.hs_code);
254                return Ok(HsPrediction {
255                    hs_code: rule.hs_code.to_string(),
256                    heading_description: rule.heading_description.to_string(),
257                    confidence: rule.confidence,
258                    source: PredictionSource::EmbeddedRule {
259                        rule_id: format!("{}:{}", rule.cas, rule.hs_code),
260                    },
261                    notes: self.build_notes(product),
262                    alternatives: vec![],
263                    recommended_action: action,
264                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
265                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
266                });
267            }
268        }
269
270        // ── Priority 3: SMILES-based rule engine ─────────────────────────
271        if let Some(ref smiles) = product.identifier.smiles {
272            if let Some(classification) = crate::smiles::classify_smiles(smiles) {
273                let hint = &classification.heading_hint;
274                // Only emit a result when we have at least a 4-digit heading
275                // and confidence meets the LLM-required threshold.
276                if let Some(heading) = hint.heading {
277                    if hint.confidence >= self.config.confidence_threshold_llm_required {
278                        // Pad to 6 digits with "00" sub-heading (best guess)
279                        let hs_code = format!("{:04}00", heading);
280                        let jp = find_jp_rule(&hs_code);
281                        let action = self.recommended_action(hint.confidence);
282
283                        let mut notes = self.build_notes(product);
284                        notes.push(
285                            "Heading is derived from SMILES functional-group analysis. \
286                             Sub-heading (last two digits) is a placeholder — \
287                             verify the exact 6-digit code with the product specification."
288                                .to_string(),
289                        );
290
291                        let matched_rules: Vec<String> = classification
292                            .functional_groups
293                            .iter()
294                            .map(|g| g.label().to_string())
295                            .collect();
296
297                        return Ok(HsPrediction {
298                            hs_code,
299                            heading_description: hint.rationale.to_string(),
300                            confidence: hint.confidence,
301                            source: PredictionSource::RuleEngine { matched_rules },
302                            notes,
303                            alternatives: vec![],
304                            recommended_action: action,
305                            jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
306                            jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
307                        });
308                    }
309                }
310            }
311        }
312
313        // ── Priority 4: LLM fallback ─────────────────────────────────
314        // (async path — use classify_with_llm for LLM support)
315        Err(HsPredictError::LowConfidenceNoLlm {
316            confidence: 0.0,
317            threshold: self.config.confidence_threshold_llm_required,
318        })
319    }
320
321    /// Classify a product, falling back to the configured LLM when the
322    /// rule-based pipeline returns a low-confidence or uncertain result.
323    ///
324    /// # Priority order (same as [`classify`](Self::classify) + LLM)
325    ///
326    /// 1. User-provided mapping → `Accept` → return immediately.
327    /// 2. Embedded static rule table → `Accept` → return immediately.
328    /// 3. SMILES rule engine → `Accept` → return immediately.
329    /// 4. Any result with `recommended_action != Accept`, or
330    ///    `LowConfidenceNoLlm` → forward to LLM.
331    ///
332    /// If no LLM client has been configured via [`with_llm`](Self::with_llm),
333    /// returns [`HsPredictError::LlmNotConfigured`].
334    ///
335    /// # Validation
336    /// The LLM's `hs_code` must be exactly 6 ASCII digits; otherwise
337    /// [`HsPredictError::ValidationFailed`] is returned.
338    ///
339    /// # Chapter consistency
340    /// If the LLM chapter differs from the SMILES engine's chapter hint, a
341    /// warning note is appended — this is **not** a hard error.
342    ///
343    /// Requires the **`llm`** Cargo feature.
344    #[cfg(feature = "llm")]
345    pub async fn classify_with_llm(
346        &self,
347        product: &ProductDescription,
348    ) -> Result<HsPrediction> {
349        use crate::llm::PromptBuilder;
350        use crate::types::AlternativePrediction;
351
352        // First try the synchronous rule-based pipeline.
353        let needs_llm = match self.classify(product) {
354            Ok(pred) if pred.recommended_action == RecommendedAction::Accept => {
355                return Ok(pred);
356            }
357            Ok(_pred) => true,  // low-confidence result → try LLM
358            Err(HsPredictError::LowConfidenceNoLlm { .. }) => true,
359            Err(e) => return Err(e),
360        };
361
362        debug_assert!(needs_llm);
363
364        // Require a configured LLM client.
365        let llm = self
366            .llm
367            .as_ref()
368            .ok_or(HsPredictError::LlmNotConfigured)?;
369
370        // Build prompt and call the LLM.
371        let prompt = PromptBuilder::new().build(product);
372        let resp = llm.classify(&prompt).await?;
373
374        // Validate: must be exactly 6 ASCII digits.
375        if resp.hs_code.len() != 6 || !resp.hs_code.chars().all(|c| c.is_ascii_digit()) {
376            return Err(HsPredictError::ValidationFailed { code: resp.hs_code });
377        }
378
379        // Chapter consistency check (warning only).
380        let mut notes = self.build_notes(product);
381        if let Some(ref analysis) = prompt.smiles_analysis {
382            let llm_chapter = &resp.hs_code[..2];
383            let expected_chapter = format!("{:02}", analysis.heading_hint.chapter);
384            if llm_chapter != expected_chapter {
385                notes.push(format!(
386                    "Chapter mismatch: LLM returned Chapter {} but SMILES engine \
387                     suggested Chapter {}. Verify with Chapter Notes.",
388                    llm_chapter, expected_chapter
389                ));
390            }
391        }
392
393        notes.push(format!("LLM rationale: {}", resp.rationale));
394
395        let jp = find_jp_rule(&resp.hs_code);
396        let action = self.recommended_action(resp.confidence);
397
398        let alternatives = resp
399            .alternatives
400            .into_iter()
401            .map(|a| AlternativePrediction {
402                hs_code: a.hs_code,
403                confidence: a.confidence,
404                reason: a.reason,
405            })
406            .collect();
407
408        Ok(HsPrediction {
409            hs_code: resp.hs_code,
410            heading_description: String::new(),
411            confidence: resp.confidence,
412            source: PredictionSource::LlmApi { model: String::new() },
413            notes,
414            alternatives,
415            recommended_action: action,
416            jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
417            jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
418        })
419    }
420
421    // ─── Private helpers ──────────────────────────────────────────────
422
423    fn recommended_action(&self, confidence: f32) -> RecommendedAction {
424        if confidence >= self.config.confidence_threshold_direct {
425            RecommendedAction::Accept
426        } else if confidence >= self.config.confidence_threshold_llm_required {
427            RecommendedAction::VerifyWithLlm
428        } else {
429            RecommendedAction::ExpertReview
430        }
431    }
432
433    /// Build supplementary notes about shape / purity caveats.
434    fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
435        let mut notes = Vec::new();
436
437        match &product.physical_form {
438            None | Some(PhysicalForm::Unknown) => {
439                notes.push(
440                    "Physical form not specified — the HS subheading may differ \
441                     (e.g. solid vs. solution).".to_string(),
442                );
443            }
444            Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
445                notes.push(
446                    "Solution concentration not specified — subheading may differ \
447                     (e.g. fuming vs. standard grade).".to_string(),
448                );
449            }
450            _ => {}
451        }
452
453        if product.purity_pct.is_none() {
454            notes.push(
455                "Purity not specified — some headings require a minimum purity threshold."
456                    .to_string(),
457            );
458        }
459
460        notes
461    }
462}
463
464// ─────────────────────────────────────────────────────────────────────────────
465// Pipeline integration tests
466// ─────────────────────────────────────────────────────────────────────────────
467
#[cfg(all(test, feature = "mock"))]
mod tests {
    use super::*;
    use crate::llm::MockLlmClassifier;
    use crate::types::SubstanceIdentifier;

    /// A product with no CAS number, so neither user mappings nor the static
    /// rule table can match, but with a SMILES string. The tests below rely
    /// on the SMILES engine not reaching `Accept` for it, so
    /// `classify_with_llm` has to go to the LLM stage.
    fn unknown_organic() -> ProductDescription {
        let identifier = SubstanceIdentifier {
            cas: None,
            // Acetic acid SMILES; the CAS is deliberately left out.
            smiles: Some(String::from("CC(O)=O")),
            iupac_name: None,
            inchi: None,
            inchi_key: None,
            cid: None,
        };
        ProductDescription {
            identifier,
            physical_form: None,
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    /// Pipeline backed by the default mock LLM.
    fn mock_pipeline() -> HsPipeline {
        HsPipeline::new().with_llm(MockLlmClassifier::new())
    }

    #[tokio::test]
    async fn classify_with_llm_mock_returns_6_digit_code() {
        let pipeline = mock_pipeline();
        let pred = pipeline.classify_with_llm(&unknown_organic()).await.unwrap();
        let code = pred.hs_code.as_str();
        assert_eq!(code.len(), 6);
        assert!(code.chars().all(|c| c.is_ascii_digit()));
    }

    #[tokio::test]
    async fn classify_with_llm_mock_chapter_29_for_smiles_acid() {
        let pipeline = mock_pipeline();
        let pred = pipeline.classify_with_llm(&unknown_organic()).await.unwrap();
        assert!(
            pred.hs_code.starts_with("29"),
            "acetic acid SMILES should yield Chapter 29, got {}",
            pred.hs_code
        );
    }

    #[tokio::test]
    async fn classify_with_llm_no_client_returns_error() {
        // Deliberately no `.with_llm(..)`.
        let pipeline = HsPipeline::new();
        let err = pipeline
            .classify_with_llm(&unknown_organic())
            .await
            .unwrap_err();
        assert!(
            matches!(err, HsPredictError::LlmNotConfigured),
            "expected LlmNotConfigured, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn classify_with_llm_skips_llm_for_high_confidence_rule() {
        // Solid NaOH hits the static rule table; that result must win even
        // though the mock would answer "999999".
        let pipeline =
            HsPipeline::new().with_llm(MockLlmClassifier::with_default("999999", 0.1));
        let naoh = ProductDescription {
            identifier: SubstanceIdentifier::from_cas("1310-73-2"),
            physical_form: Some(crate::types::PhysicalForm::Solid),
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        };
        let pred = pipeline.classify_with_llm(&naoh).await.unwrap();
        assert_eq!(pred.hs_code, "281511", "static rule should win over LLM");
    }

    #[tokio::test]
    async fn classify_with_llm_invalid_code_returns_validation_error() {
        /// LLM stub that always answers with a malformed HS code.
        struct BadMock;

        impl crate::llm::LlmClassifier for BadMock {
            fn classify<'a>(
                &'a self,
                _prompt: &'a crate::llm::LlmPrompt,
            ) -> futures::future::BoxFuture<'a, crate::Result<crate::llm::LlmResponse>> {
                Box::pin(async {
                    let response = crate::llm::LlmResponse {
                        hs_code: "BAD!!".to_string(),
                        confidence: 0.5,
                        rationale: "bad".to_string(),
                        alternatives: vec![],
                    };
                    Ok(response)
                })
            }
        }

        let pipeline = HsPipeline::new().with_llm(BadMock);
        let err = pipeline
            .classify_with_llm(&unknown_organic())
            .await
            .unwrap_err();
        assert!(
            matches!(err, HsPredictError::ValidationFailed { .. }),
            "expected ValidationFailed, got {:?}",
            err
        );
    }
}