Skip to main content

hs_predict/
pipeline.rs

1//! Main classification pipeline.
2//!
3//! Runs classification in priority order:
4//! 1. User-provided CAS → HS mappings (confidence = 1.0)
5//! 2. Embedded static rule table (CAS + shape + purity)
//! 3. SMILES-based rule engine (4-digit heading from functional-group analysis)
7//! 4. *(placeholder)* LLM API — v0.4
8
9use std::collections::HashMap;
10
11use crate::error::{HsPredictError, Result};
12use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
13use crate::rules::matcher::find_best_rule;
14use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
15
/// Confidence thresholds that decide how a prediction is handed back.
///
/// The pipeline compares a prediction's confidence against both thresholds
/// (each comparison is inclusive, i.e. `>=`) to pick the
/// [`RecommendedAction`](crate::types::RecommendedAction):
///
/// - `confidence >= confidence_threshold_direct` → `Accept`
/// - `confidence >= confidence_threshold_llm_required` → `VerifyWithLlm`
/// - otherwise → `ExpertReview`
///
/// User-provided mappings bypass these thresholds: they are always reported
/// with `confidence = 1.0` and `Accept`.
///
/// `PartialEq` is derived so configurations can be compared directly, e.g.
/// against `PipelineConfig::default()` in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct PipelineConfig {
    /// Confidence at or above which a result is returned directly,
    /// without asking for LLM confirmation.
    pub confidence_threshold_direct: f32,

    /// Confidence below which an LLM is required.
    ///
    /// Results whose confidence lies in
    /// `confidence_threshold_llm_required..confidence_threshold_direct`
    /// are returned with `RecommendedAction::VerifyWithLlm`. SMILES-derived
    /// results below this value are not returned at all.
    pub confidence_threshold_llm_required: f32,
}
28
29impl Default for PipelineConfig {
30    fn default() -> Self {
31        Self {
32            confidence_threshold_direct: 0.85,
33            confidence_threshold_llm_required: 0.50,
34        }
35    }
36}
37
38/// Main HS code classification pipeline.
39///
40/// # Example — direct (sync)
41/// ```rust,no_run
42/// use hs_predict::pipeline::HsPipeline;
43/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
44///
45/// let pipeline = HsPipeline::new();
46///
47/// let product = ProductDescription {
48///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
49///     physical_form: Some(PhysicalForm::Solid),
50///     purity_pct: None,
51///     purity_type: None,
52///     mixture_components: None,
53///     intended_use: None,
54///     additional_context: None,
55/// };
56///
57/// let prediction = pipeline.classify(&product).unwrap();
58/// assert_eq!(&prediction.hs_code, "281511");
59/// ```
60///
61/// # Example — with PubChem enrichment (async, `pubchem` feature)
62/// ```rust,no_run
63/// # #[cfg(feature = "pubchem")]
64/// # async fn example() -> hs_predict::Result<()> {
65/// use hs_predict::pipeline::HsPipeline;
66/// use hs_predict::pubchem::PubChemClient;
67/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
68///
69/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
70///
71/// let mut product = ProductDescription {
72///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
73///     physical_form: Some(PhysicalForm::Solid),
74///     purity_pct: None,
75///     purity_type: None,
76///     mixture_components: None,
77///     intended_use: None,
78///     additional_context: None,
79/// };
80///
81/// pipeline.enrich(&mut product).await?;   // fills SMILES, InChI, IUPAC name …
82/// let prediction = pipeline.classify(&product)?;
83/// println!("{}", prediction.display());   // "28.15.11"
84/// # Ok(())
85/// # }
86/// ```
87#[derive(Debug, Default)]
88pub struct HsPipeline {
89    /// User-supplied CAS → HS code overrides. Highest priority.
90    user_mappings: HashMap<String, String>,
91
92    config: PipelineConfig,
93
94    /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
95    #[cfg(feature = "pubchem")]
96    pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
97}
98
99impl HsPipeline {
100    /// Create a pipeline with default configuration.
101    pub fn new() -> Self {
102        Self::default()
103    }
104
105    /// Add a user-provided CAS → HS code mapping.
106    ///
107    /// These mappings override the embedded rule table with `confidence = 1.0`.
108    pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
109        self.user_mappings.insert(cas.into(), hs_code.into());
110        self
111    }
112
113    /// Override the default pipeline configuration.
114    pub fn with_config(mut self, config: PipelineConfig) -> Self {
115        self.config = config;
116        self
117    }
118
119    /// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
120    /// automatic identifier enrichment before classification.
121    ///
122    /// Requires the **`pubchem`** Cargo feature.
123    #[cfg(feature = "pubchem")]
124    pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
125        self.pubchem = Some(std::sync::Arc::new(client));
126        self
127    }
128
129    /// Enrich a [`ProductDescription`] with PubChem data.
130    ///
131    /// Fills in any missing fields of the main identifier and each mixture
132    /// component's identifier (SMILES, InChI, InChIKey, IUPAC name, CID).
133    ///
134    /// This is a **best-effort** operation:
135    /// - "Not found" and "no usable identifier" results are silently ignored.
136    /// - Network / parse errors **are** propagated.
137    /// - If no PubChem client is configured, returns `Ok(())` immediately.
138    ///
139    /// Requires the **`pubchem`** Cargo feature.
140    #[cfg(feature = "pubchem")]
141    pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
142        let Some(ref client) = self.pubchem else {
143            return Ok(());
144        };
145
146        client.enrich(&mut product.identifier).await?;
147
148        if let Some(ref mut comps) = product.mixture_components {
149            for comp in comps.iter_mut() {
150                client.enrich(&mut comp.substance).await?;
151            }
152        }
153
154        Ok(())
155    }
156
157    /// Classify a product and return an HS code prediction.
158    ///
159    /// Priority order:
160    /// 1. User-provided mapping
161    /// 2. Embedded static rule table
162    /// 3. (v0.3) SMILES rule engine
163    /// 4. (v0.4) LLM fallback
164    pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
165        // ── Priority 1: User-provided mappings ────────────────────────
166        if let Some(ref cas) = product.identifier.cas {
167            if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
168                let jp = find_jp_rule(hs_code);
169                return Ok(HsPrediction {
170                    hs_code: hs_code.clone(),
171                    heading_description: String::new(),
172                    confidence: 1.0,
173                    source: PredictionSource::UserMapping,
174                    notes: vec!["From user-provided mapping".to_string()],
175                    alternatives: vec![],
176                    recommended_action: RecommendedAction::Accept,
177                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
178                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
179                });
180            }
181        }
182
183        // ── Priority 2: Embedded static rule table ────────────────────
184        if let Some(ref cas) = product.identifier.cas {
185            if let Some(rule) = find_best_rule(
186                cas,
187                product.physical_form.as_ref(),
188                product.purity_pct,
189            ) {
190                let action = self.recommended_action(rule.confidence);
191                let jp = find_jp_rule(rule.hs_code);
192                return Ok(HsPrediction {
193                    hs_code: rule.hs_code.to_string(),
194                    heading_description: rule.heading_description.to_string(),
195                    confidence: rule.confidence,
196                    source: PredictionSource::EmbeddedRule {
197                        rule_id: format!("{}:{}", rule.cas, rule.hs_code),
198                    },
199                    notes: self.build_notes(product),
200                    alternatives: vec![],
201                    recommended_action: action,
202                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
203                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
204                });
205            }
206        }
207
208        // ── Priority 3: SMILES-based rule engine ─────────────────────────
209        if let Some(ref smiles) = product.identifier.smiles {
210            if let Some(classification) = crate::smiles::classify_smiles(smiles) {
211                let hint = &classification.heading_hint;
212                // Only emit a result when we have at least a 4-digit heading
213                // and confidence meets the LLM-required threshold.
214                if let Some(heading) = hint.heading {
215                    if hint.confidence >= self.config.confidence_threshold_llm_required {
216                        // Pad to 6 digits with "00" sub-heading (best guess)
217                        let hs_code = format!("{:04}00", heading);
218                        let jp = find_jp_rule(&hs_code);
219                        let action = self.recommended_action(hint.confidence);
220
221                        let mut notes = self.build_notes(product);
222                        notes.push(
223                            "Heading is derived from SMILES functional-group analysis. \
224                             Sub-heading (last two digits) is a placeholder — \
225                             verify the exact 6-digit code with the product specification."
226                                .to_string(),
227                        );
228
229                        let matched_rules: Vec<String> = classification
230                            .functional_groups
231                            .iter()
232                            .map(|g| g.label().to_string())
233                            .collect();
234
235                        return Ok(HsPrediction {
236                            hs_code,
237                            heading_description: hint.rationale.to_string(),
238                            confidence: hint.confidence,
239                            source: PredictionSource::RuleEngine { matched_rules },
240                            notes,
241                            alternatives: vec![],
242                            recommended_action: action,
243                            jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
244                            jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
245                        });
246                    }
247                }
248            }
249        }
250
251        // ── Priority 4: LLM fallback (v0.4 placeholder) ───────────────
252        // TODO: implement LLM client in v0.4.
253
254        // No rule matched — return low-confidence placeholder
255        Err(HsPredictError::LowConfidenceNoLlm {
256            confidence: 0.0,
257            threshold: self.config.confidence_threshold_llm_required,
258        })
259    }
260
261    // ─── Private helpers ──────────────────────────────────────────────
262
263    fn recommended_action(&self, confidence: f32) -> RecommendedAction {
264        if confidence >= self.config.confidence_threshold_direct {
265            RecommendedAction::Accept
266        } else if confidence >= self.config.confidence_threshold_llm_required {
267            RecommendedAction::VerifyWithLlm
268        } else {
269            RecommendedAction::ExpertReview
270        }
271    }
272
273    /// Build supplementary notes about shape / purity caveats.
274    fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
275        let mut notes = Vec::new();
276
277        match &product.physical_form {
278            None | Some(PhysicalForm::Unknown) => {
279                notes.push(
280                    "Physical form not specified — the HS subheading may differ \
281                     (e.g. solid vs. solution).".to_string(),
282                );
283            }
284            Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
285                notes.push(
286                    "Solution concentration not specified — subheading may differ \
287                     (e.g. fuming vs. standard grade).".to_string(),
288                );
289            }
290            _ => {}
291        }
292
293        if product.purity_pct.is_none() {
294            notes.push(
295                "Purity not specified — some headings require a minimum purity threshold."
296                    .to_string(),
297            );
298        }
299
300        notes
301    }
302}