hs_predict/pipeline.rs
//! Main classification pipeline.
//!
//! Runs classification in priority order; the first source that yields a result wins:
//! 1. User-provided CAS → HS mappings (confidence = 1.0)
//! 2. Embedded static rule table (CAS + shape + purity)
//! 3. SMILES-based rule engine (heading level only; the sub-heading is padded with `00`)
//! 4. *(placeholder)* LLM API — v0.4
8
9use std::collections::HashMap;
10
11use crate::error::{HsPredictError, Result};
12use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
13use crate::rules::matcher::find_best_rule;
14use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
15
/// Confidence thresholds that control how the classification pipeline grades a prediction.
///
/// A prediction's confidence falls into one of three bands:
///
/// | confidence                               | recommended action |
/// |------------------------------------------|--------------------|
/// | `>= confidence_threshold_direct`         | `Accept`           |
/// | `>= confidence_threshold_llm_required`   | `VerifyWithLlm`    |
/// | anything lower                           | `ExpertReview`     |
///
/// The bands only make sense when
/// `confidence_threshold_llm_required <= confidence_threshold_direct`.
/// Nothing here enforces that ordering.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Minimum confidence at which a prediction is returned as-is,
    /// with no LLM confirmation suggested.
    pub confidence_threshold_direct: f32,

    /// Confidence below which an LLM is required.
    ///
    /// Results scoring between this value and
    /// [`confidence_threshold_direct`](Self::confidence_threshold_direct)
    /// are returned with `RecommendedAction::VerifyWithLlm`.
    pub confidence_threshold_llm_required: f32,
}

impl Default for PipelineConfig {
    /// Defaults: `0.85` for direct acceptance, `0.50` for the LLM-required floor.
    fn default() -> Self {
        let confidence_threshold_direct = 0.85;
        let confidence_threshold_llm_required = 0.50;
        PipelineConfig {
            confidence_threshold_direct,
            confidence_threshold_llm_required,
        }
    }
}
37
38/// Main HS code classification pipeline.
39///
40/// # Example — direct (sync)
41/// ```rust,no_run
42/// use hs_predict::pipeline::HsPipeline;
43/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
44///
45/// let pipeline = HsPipeline::new();
46///
47/// let product = ProductDescription {
48/// identifier: SubstanceIdentifier::from_cas("1310-73-2"),
49/// physical_form: Some(PhysicalForm::Solid),
50/// purity_pct: None,
51/// purity_type: None,
52/// mixture_components: None,
53/// intended_use: None,
54/// additional_context: None,
55/// };
56///
57/// let prediction = pipeline.classify(&product).unwrap();
58/// assert_eq!(&prediction.hs_code, "281511");
59/// ```
60///
61/// # Example — with PubChem enrichment (async, `pubchem` feature)
62/// ```rust,no_run
63/// # #[cfg(feature = "pubchem")]
64/// # async fn example() -> hs_predict::Result<()> {
65/// use hs_predict::pipeline::HsPipeline;
66/// use hs_predict::pubchem::PubChemClient;
67/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
68///
69/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
70///
71/// let mut product = ProductDescription {
72/// identifier: SubstanceIdentifier::from_cas("1310-73-2"),
73/// physical_form: Some(PhysicalForm::Solid),
74/// purity_pct: None,
75/// purity_type: None,
76/// mixture_components: None,
77/// intended_use: None,
78/// additional_context: None,
79/// };
80///
81/// pipeline.enrich(&mut product).await?; // fills SMILES, InChI, IUPAC name …
82/// let prediction = pipeline.classify(&product)?;
83/// println!("{}", prediction.display()); // "28.15.11"
84/// # Ok(())
85/// # }
86/// ```
87#[derive(Debug, Default)]
88pub struct HsPipeline {
89 /// User-supplied CAS → HS code overrides. Highest priority.
90 user_mappings: HashMap<String, String>,
91
92 config: PipelineConfig,
93
94 /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
95 #[cfg(feature = "pubchem")]
96 pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
97}
98
99impl HsPipeline {
100 /// Create a pipeline with default configuration.
101 pub fn new() -> Self {
102 Self::default()
103 }
104
105 /// Add a user-provided CAS → HS code mapping.
106 ///
107 /// These mappings override the embedded rule table with `confidence = 1.0`.
108 pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
109 self.user_mappings.insert(cas.into(), hs_code.into());
110 self
111 }
112
113 /// Override the default pipeline configuration.
114 pub fn with_config(mut self, config: PipelineConfig) -> Self {
115 self.config = config;
116 self
117 }
118
119 /// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
120 /// automatic identifier enrichment before classification.
121 ///
122 /// Requires the **`pubchem`** Cargo feature.
123 #[cfg(feature = "pubchem")]
124 pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
125 self.pubchem = Some(std::sync::Arc::new(client));
126 self
127 }
128
129 /// Enrich a [`ProductDescription`] with PubChem data.
130 ///
131 /// Fills in any missing fields of the main identifier and each mixture
132 /// component's identifier (SMILES, InChI, InChIKey, IUPAC name, CID).
133 ///
134 /// This is a **best-effort** operation:
135 /// - "Not found" and "no usable identifier" results are silently ignored.
136 /// - Network / parse errors **are** propagated.
137 /// - If no PubChem client is configured, returns `Ok(())` immediately.
138 ///
139 /// Requires the **`pubchem`** Cargo feature.
140 #[cfg(feature = "pubchem")]
141 pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
142 let Some(ref client) = self.pubchem else {
143 return Ok(());
144 };
145
146 client.enrich(&mut product.identifier).await?;
147
148 if let Some(ref mut comps) = product.mixture_components {
149 for comp in comps.iter_mut() {
150 client.enrich(&mut comp.substance).await?;
151 }
152 }
153
154 Ok(())
155 }
156
157 /// Classify a product and return an HS code prediction.
158 ///
159 /// Priority order:
160 /// 1. User-provided mapping
161 /// 2. Embedded static rule table
162 /// 3. (v0.3) SMILES rule engine
163 /// 4. (v0.4) LLM fallback
164 pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
165 // ── Priority 1: User-provided mappings ────────────────────────
166 if let Some(ref cas) = product.identifier.cas {
167 if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
168 let jp = find_jp_rule(hs_code);
169 return Ok(HsPrediction {
170 hs_code: hs_code.clone(),
171 heading_description: String::new(),
172 confidence: 1.0,
173 source: PredictionSource::UserMapping,
174 notes: vec!["From user-provided mapping".to_string()],
175 alternatives: vec![],
176 recommended_action: RecommendedAction::Accept,
177 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
178 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
179 });
180 }
181 }
182
183 // ── Priority 2: Embedded static rule table ────────────────────
184 if let Some(ref cas) = product.identifier.cas {
185 if let Some(rule) = find_best_rule(
186 cas,
187 product.physical_form.as_ref(),
188 product.purity_pct,
189 ) {
190 let action = self.recommended_action(rule.confidence);
191 let jp = find_jp_rule(rule.hs_code);
192 return Ok(HsPrediction {
193 hs_code: rule.hs_code.to_string(),
194 heading_description: rule.heading_description.to_string(),
195 confidence: rule.confidence,
196 source: PredictionSource::EmbeddedRule {
197 rule_id: format!("{}:{}", rule.cas, rule.hs_code),
198 },
199 notes: self.build_notes(product),
200 alternatives: vec![],
201 recommended_action: action,
202 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
203 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
204 });
205 }
206 }
207
208 // ── Priority 3: SMILES-based rule engine ─────────────────────────
209 if let Some(ref smiles) = product.identifier.smiles {
210 if let Some(classification) = crate::smiles::classify_smiles(smiles) {
211 let hint = &classification.heading_hint;
212 // Only emit a result when we have at least a 4-digit heading
213 // and confidence meets the LLM-required threshold.
214 if let Some(heading) = hint.heading {
215 if hint.confidence >= self.config.confidence_threshold_llm_required {
216 // Pad to 6 digits with "00" sub-heading (best guess)
217 let hs_code = format!("{:04}00", heading);
218 let jp = find_jp_rule(&hs_code);
219 let action = self.recommended_action(hint.confidence);
220
221 let mut notes = self.build_notes(product);
222 notes.push(
223 "Heading is derived from SMILES functional-group analysis. \
224 Sub-heading (last two digits) is a placeholder — \
225 verify the exact 6-digit code with the product specification."
226 .to_string(),
227 );
228
229 let matched_rules: Vec<String> = classification
230 .functional_groups
231 .iter()
232 .map(|g| g.label().to_string())
233 .collect();
234
235 return Ok(HsPrediction {
236 hs_code,
237 heading_description: hint.rationale.to_string(),
238 confidence: hint.confidence,
239 source: PredictionSource::RuleEngine { matched_rules },
240 notes,
241 alternatives: vec![],
242 recommended_action: action,
243 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
244 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
245 });
246 }
247 }
248 }
249 }
250
251 // ── Priority 4: LLM fallback (v0.4 placeholder) ───────────────
252 // TODO: implement LLM client in v0.4.
253
254 // No rule matched — return low-confidence placeholder
255 Err(HsPredictError::LowConfidenceNoLlm {
256 confidence: 0.0,
257 threshold: self.config.confidence_threshold_llm_required,
258 })
259 }
260
261 // ─── Private helpers ──────────────────────────────────────────────
262
263 fn recommended_action(&self, confidence: f32) -> RecommendedAction {
264 if confidence >= self.config.confidence_threshold_direct {
265 RecommendedAction::Accept
266 } else if confidence >= self.config.confidence_threshold_llm_required {
267 RecommendedAction::VerifyWithLlm
268 } else {
269 RecommendedAction::ExpertReview
270 }
271 }
272
273 /// Build supplementary notes about shape / purity caveats.
274 fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
275 let mut notes = Vec::new();
276
277 match &product.physical_form {
278 None | Some(PhysicalForm::Unknown) => {
279 notes.push(
280 "Physical form not specified — the HS subheading may differ \
281 (e.g. solid vs. solution).".to_string(),
282 );
283 }
284 Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
285 notes.push(
286 "Solution concentration not specified — subheading may differ \
287 (e.g. fuming vs. standard grade).".to_string(),
288 );
289 }
290 _ => {}
291 }
292
293 if product.purity_pct.is_none() {
294 notes.push(
295 "Purity not specified — some headings require a minimum purity threshold."
296 .to_string(),
297 );
298 }
299
300 notes
301 }
302}