hs_predict/pipeline.rs
1//! Main classification pipeline.
2//!
3//! Runs classification in priority order:
4//! 1. User-provided CAS → HS mappings (confidence = 1.0)
5//! 2. Embedded static rule table (CAS + shape + purity)
6//! 3. SMILES-based rule engine (v0.3)
7//! 4. LLM fallback via [`LlmClassifier`] trait hook (v0.4, `llm` feature)
8
9use std::collections::HashMap;
10#[cfg(feature = "llm")]
11use std::sync::Arc;
12
13use crate::error::{HsPredictError, Result};
14use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
15use crate::rules::matcher::find_best_rule;
16use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
17
/// Confidence thresholds that steer the classification pipeline.
///
/// The two values split a prediction's confidence `c` into three bands, which
/// the pipeline maps onto a [`RecommendedAction`]:
///
/// | condition                                     | action          |
/// |-----------------------------------------------|-----------------|
/// | `c >= confidence_threshold_direct`            | `Accept`        |
/// | `c >= confidence_threshold_llm_required`      | `VerifyWithLlm` |
/// | otherwise                                     | `ExpertReview`  |
///
/// The bands are evaluated top to bottom, so if
/// `confidence_threshold_llm_required` is set above
/// `confidence_threshold_direct` the `VerifyWithLlm` band is empty.
///
/// `PartialEq` is derived so that configurations can be compared directly
/// (e.g. in tests or when detecting a changed setting).
#[derive(Debug, Clone, PartialEq)]
pub struct PipelineConfig {
    /// Confidence at or above which a result is returned directly,
    /// without asking for LLM confirmation.
    pub confidence_threshold_direct: f32,

    /// Confidence below which LLM is required.
    ///
    /// Between `confidence_threshold_llm_required` (inclusive) and
    /// `confidence_threshold_direct` (exclusive) the result is returned with
    /// `RecommendedAction::VerifyWithLlm`. The SMILES rule engine also uses
    /// this value as the minimum confidence for emitting a prediction at all.
    pub confidence_threshold_llm_required: f32,
}
30
31impl Default for PipelineConfig {
32 fn default() -> Self {
33 Self {
34 confidence_threshold_direct: 0.85,
35 confidence_threshold_llm_required: 0.50,
36 }
37 }
38}
39
40/// Main HS code classification pipeline.
41///
42/// # Example — direct (sync)
43/// ```rust,no_run
44/// use hs_predict::pipeline::HsPipeline;
45/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
46///
47/// let pipeline = HsPipeline::new();
48///
49/// let product = ProductDescription {
50/// identifier: SubstanceIdentifier::from_cas("1310-73-2"),
51/// physical_form: Some(PhysicalForm::Solid),
52/// purity_pct: None,
53/// purity_type: None,
54/// mixture_components: None,
55/// intended_use: None,
56/// additional_context: None,
57/// };
58///
59/// let prediction = pipeline.classify(&product).unwrap();
60/// assert_eq!(&prediction.hs_code, "281511");
61/// ```
62///
63/// # Example — with PubChem enrichment (async, `pubchem` feature)
64/// ```rust,no_run
65/// # #[cfg(feature = "pubchem")]
66/// # async fn example() -> hs_predict::Result<()> {
67/// use hs_predict::pipeline::HsPipeline;
68/// use hs_predict::pubchem::PubChemClient;
69/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
70///
71/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
72///
73/// let mut product = ProductDescription {
74/// identifier: SubstanceIdentifier::from_cas("1310-73-2"),
75/// physical_form: Some(PhysicalForm::Solid),
76/// purity_pct: None,
77/// purity_type: None,
78/// mixture_components: None,
79/// intended_use: None,
80/// additional_context: None,
81/// };
82///
83/// pipeline.enrich(&mut product).await?; // fills SMILES, InChI, IUPAC name …
84/// let prediction = pipeline.classify(&product)?;
85/// println!("{}", prediction.display()); // "28.15.11"
86/// # Ok(())
87/// # }
88/// ```
89///
90/// # Example — with LLM fallback (async, `llm` feature)
91/// ```rust,no_run
92/// # #[cfg(feature = "llm")]
93/// # async fn example() -> hs_predict::Result<()> {
94/// use hs_predict::pipeline::HsPipeline;
95/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse};
96/// use futures::future::BoxFuture;
97///
98/// struct MyClient;
99/// impl LlmClassifier for MyClient {
100/// fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
101/// Box::pin(async move { todo!() })
102/// }
103/// }
104///
105/// let pipeline = HsPipeline::new().with_llm(MyClient);
106/// use hs_predict::types::{ProductDescription, SubstanceIdentifier};
107/// let product = ProductDescription {
108/// identifier: SubstanceIdentifier::from_cas("12-34-5"),
109/// physical_form: None, purity_pct: None, purity_type: None,
110/// mixture_components: None, intended_use: None, additional_context: None,
111/// };
112/// let prediction = pipeline.classify_with_llm(&product).await?;
113/// println!("{}", prediction.display());
114/// # Ok(())
115/// # }
116/// ```
117#[derive(Default)]
118pub struct HsPipeline {
119 /// User-supplied CAS → HS code overrides. Highest priority.
120 user_mappings: HashMap<String, String>,
121
122 config: PipelineConfig,
123
124 /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
125 #[cfg(feature = "pubchem")]
126 pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
127
128 /// LLM classifier hook (v0.4, `llm` feature).
129 #[cfg(feature = "llm")]
130 llm: Option<Arc<dyn crate::llm::LlmClassifier>>,
131}
132
133impl HsPipeline {
134 /// Create a pipeline with default configuration.
135 pub fn new() -> Self {
136 Self::default()
137 }
138
139 /// Add a user-provided CAS → HS code mapping.
140 ///
141 /// These mappings override the embedded rule table with `confidence = 1.0`.
142 pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
143 self.user_mappings.insert(cas.into(), hs_code.into());
144 self
145 }
146
147 /// Override the default pipeline configuration.
148 pub fn with_config(mut self, config: PipelineConfig) -> Self {
149 self.config = config;
150 self
151 }
152
/// Attach an [`LlmClassifier`](crate::llm::LlmClassifier) implementation to
/// enable the LLM fallback (Priority 4).
///
/// The LLM is called by [`classify_with_llm`](Self::classify_with_llm) when
/// the rule-based pipeline returns a result with
/// `recommended_action != Accept`, or returns
/// [`LowConfidenceNoLlm`](crate::HsPredictError::LowConfidenceNoLlm).
///
/// The client is stored behind an `Arc`; attaching a second client replaces
/// the first.
///
/// This is a consuming builder method: it returns the updated pipeline, so
/// the return value must be used.
///
/// Requires the **`llm`** Cargo feature.
#[cfg(feature = "llm")]
#[must_use = "builder methods return the updated pipeline; dropping it discards the pipeline"]
pub fn with_llm(mut self, client: impl crate::llm::LlmClassifier + 'static) -> Self {
    self.llm = Some(Arc::new(client));
    self
}
167
/// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
/// identifier enrichment via [`enrich`](Self::enrich).
///
/// The client is stored behind an `Arc`; attaching a second client replaces
/// the first.
///
/// This is a consuming builder method: it returns the updated pipeline, so
/// the return value must be used.
///
/// Requires the **`pubchem`** Cargo feature.
#[cfg(feature = "pubchem")]
#[must_use = "builder methods return the updated pipeline; dropping it discards the pipeline"]
pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
    self.pubchem = Some(std::sync::Arc::new(client));
    self
}
177
/// Fill in missing identifier data for a [`ProductDescription`] via PubChem.
///
/// The main identifier is enriched first, then the identifier of every
/// mixture component (if any), one request after another.
///
/// Behaviour:
/// - Without a configured PubChem client this is a no-op returning `Ok(())`.
/// - The first error returned by the client aborts the operation and is
///   propagated; identifiers processed before it keep their enriched data.
/// - NOTE(review): which conditions the client reports as errors (e.g.
///   "not found" vs. network failures) is decided inside
///   `PubChemClient::enrich`, which is not visible here.
///
/// Requires the **`pubchem`** Cargo feature.
#[cfg(feature = "pubchem")]
pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
    let client = match self.pubchem {
        Some(ref attached) => attached,
        None => return Ok(()),
    };

    client.enrich(&mut product.identifier).await?;

    if let Some(components) = product.mixture_components.as_mut() {
        for component in components.iter_mut() {
            client.enrich(&mut component.substance).await?;
        }
    }

    Ok(())
}
205
206 /// Classify a product and return an HS code prediction.
207 ///
208 /// Priority order:
209 /// 1. User-provided mapping
210 /// 2. Embedded static rule table
211 /// 3. (v0.3) SMILES rule engine
212 /// 4. (v0.4) LLM fallback
213 pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
214 // ── Priority 1: User-provided mappings ────────────────────────
215 if let Some(ref cas) = product.identifier.cas {
216 if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
217 let jp = find_jp_rule(hs_code);
218 return Ok(HsPrediction {
219 hs_code: hs_code.clone(),
220 heading_description: String::new(),
221 confidence: 1.0,
222 source: PredictionSource::UserMapping,
223 notes: vec!["From user-provided mapping".to_string()],
224 alternatives: vec![],
225 recommended_action: RecommendedAction::Accept,
226 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
227 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
228 });
229 }
230 }
231
232 // ── Priority 2: Embedded static rule table ────────────────────
233 if let Some(ref cas) = product.identifier.cas {
234 if let Some(rule) = find_best_rule(
235 cas,
236 product.physical_form.as_ref(),
237 product.purity_pct,
238 ) {
239 let action = self.recommended_action(rule.confidence);
240 let jp = find_jp_rule(rule.hs_code);
241 return Ok(HsPrediction {
242 hs_code: rule.hs_code.to_string(),
243 heading_description: rule.heading_description.to_string(),
244 confidence: rule.confidence,
245 source: PredictionSource::EmbeddedRule {
246 rule_id: format!("{}:{}", rule.cas, rule.hs_code),
247 },
248 notes: self.build_notes(product),
249 alternatives: vec![],
250 recommended_action: action,
251 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
252 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
253 });
254 }
255 }
256
257 // ── Priority 3: SMILES-based rule engine ─────────────────────────
258 if let Some(ref smiles) = product.identifier.smiles {
259 if let Some(classification) = crate::smiles::classify_smiles(smiles) {
260 let hint = &classification.heading_hint;
261 // Only emit a result when we have at least a 4-digit heading
262 // and confidence meets the LLM-required threshold.
263 if let Some(heading) = hint.heading {
264 if hint.confidence >= self.config.confidence_threshold_llm_required {
265 // Pad to 6 digits with "00" sub-heading (best guess)
266 let hs_code = format!("{:04}00", heading);
267 let jp = find_jp_rule(&hs_code);
268 let action = self.recommended_action(hint.confidence);
269
270 let mut notes = self.build_notes(product);
271 notes.push(
272 "Heading is derived from SMILES functional-group analysis. \
273 Sub-heading (last two digits) is a placeholder — \
274 verify the exact 6-digit code with the product specification."
275 .to_string(),
276 );
277
278 let matched_rules: Vec<String> = classification
279 .functional_groups
280 .iter()
281 .map(|g| g.label().to_string())
282 .collect();
283
284 return Ok(HsPrediction {
285 hs_code,
286 heading_description: hint.rationale.to_string(),
287 confidence: hint.confidence,
288 source: PredictionSource::RuleEngine { matched_rules },
289 notes,
290 alternatives: vec![],
291 recommended_action: action,
292 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
293 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
294 });
295 }
296 }
297 }
298 }
299
300 // ── Priority 4: LLM fallback ─────────────────────────────────
301 // (async path — use classify_with_llm for LLM support)
302 Err(HsPredictError::LowConfidenceNoLlm {
303 confidence: 0.0,
304 threshold: self.config.confidence_threshold_llm_required,
305 })
306 }
307
/// Classify a product, falling back to the configured LLM when the
/// rule-based pipeline returns a low-confidence or uncertain result.
///
/// # Priority order (same as [`classify`](Self::classify) + LLM)
///
/// 1. User-provided mapping → `Accept` → return immediately.
/// 2. Embedded static rule table → `Accept` → return immediately.
/// 3. SMILES rule engine → `Accept` → return immediately.
/// 4. Any result with `recommended_action != Accept`, or
///    `LowConfidenceNoLlm` → forward to LLM. The rule-based result is
///    discarded in that case; the returned prediction is built from the
///    LLM response.
///
/// # Errors
/// - [`HsPredictError::LlmNotConfigured`] if the LLM is needed but no client
///   has been configured via [`with_llm`](Self::with_llm).
/// - [`HsPredictError::ValidationFailed`] if the LLM's `hs_code` is not
///   exactly 6 ASCII digits.
/// - Any other error from [`classify`](Self::classify) or from the LLM
///   client is propagated unchanged.
///
/// # Confidence sanitising
/// The LLM's self-reported confidence is external input. It is clamped to
/// `[0.0, 1.0]`, and a non-finite value (NaN / ±∞) is replaced by `0.0`,
/// before the recommended action is derived. A note is appended whenever the
/// reported value had to be adjusted.
///
/// # Chapter consistency
/// If the LLM chapter differs from the SMILES engine's chapter hint, a
/// warning note is appended — this is **not** a hard error.
///
/// Requires the **`llm`** Cargo feature.
#[cfg(feature = "llm")]
pub async fn classify_with_llm(
    &self,
    product: &ProductDescription,
) -> Result<HsPrediction> {
    use crate::llm::PromptBuilder;
    use crate::types::AlternativePrediction;

    // First try the synchronous rule-based pipeline. Only a confident
    // (`Accept`) result short-circuits; anything uncertain falls through.
    match self.classify(product) {
        Ok(pred) if pred.recommended_action == RecommendedAction::Accept => {
            return Ok(pred);
        }
        Ok(_) | Err(HsPredictError::LowConfidenceNoLlm { .. }) => {}
        Err(e) => return Err(e),
    }

    // Require a configured LLM client.
    let llm = self
        .llm
        .as_ref()
        .ok_or(HsPredictError::LlmNotConfigured)?;

    // Build prompt and call the LLM.
    let prompt = PromptBuilder::new().build(product);
    let resp = llm.classify(&prompt).await?;

    // Validate: must be exactly 6 ASCII digits.
    if resp.hs_code.len() != 6 || !resp.hs_code.chars().all(|c| c.is_ascii_digit()) {
        return Err(HsPredictError::ValidationFailed { code: resp.hs_code });
    }

    // Sanitise the self-reported confidence: a value above 1.0 would
    // otherwise be accepted outright with a nonsensical confidence.
    let confidence = if resp.confidence.is_finite() {
        resp.confidence.clamp(0.0, 1.0)
    } else {
        0.0
    };

    // Chapter consistency check (warning only).
    let mut notes = self.build_notes(product);
    if let Some(ref analysis) = prompt.smiles_analysis {
        // Slicing is safe: the code was validated as 6 ASCII digits above.
        let llm_chapter = &resp.hs_code[..2];
        let expected_chapter = format!("{:02}", analysis.heading_hint.chapter);
        if llm_chapter != expected_chapter {
            notes.push(format!(
                "Chapter mismatch: LLM returned Chapter {} but SMILES engine \
                 suggested Chapter {}. Verify with Chapter Notes.",
                llm_chapter, expected_chapter
            ));
        }
    }

    notes.push(format!("LLM rationale: {}", resp.rationale));

    // `!=` is also true for NaN, so every adjusted value gets a note.
    if resp.confidence != confidence {
        notes.push(format!(
            "LLM reported an out-of-range confidence ({}); using {} instead.",
            resp.confidence, confidence
        ));
    }

    let jp = find_jp_rule(&resp.hs_code);
    let action = self.recommended_action(confidence);

    let alternatives = resp
        .alternatives
        .into_iter()
        .map(|a| AlternativePrediction {
            hs_code: a.hs_code,
            confidence: a.confidence,
            reason: a.reason,
        })
        .collect();

    Ok(HsPrediction {
        hs_code: resp.hs_code,
        heading_description: String::new(),
        confidence,
        source: PredictionSource::LlmApi { model: String::new() },
        notes,
        alternatives,
        recommended_action: action,
        jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
        jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
    })
}
407
408 // ─── Private helpers ──────────────────────────────────────────────
409
410 fn recommended_action(&self, confidence: f32) -> RecommendedAction {
411 if confidence >= self.config.confidence_threshold_direct {
412 RecommendedAction::Accept
413 } else if confidence >= self.config.confidence_threshold_llm_required {
414 RecommendedAction::VerifyWithLlm
415 } else {
416 RecommendedAction::ExpertReview
417 }
418 }
419
420 /// Build supplementary notes about shape / purity caveats.
421 fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
422 let mut notes = Vec::new();
423
424 match &product.physical_form {
425 None | Some(PhysicalForm::Unknown) => {
426 notes.push(
427 "Physical form not specified — the HS subheading may differ \
428 (e.g. solid vs. solution).".to_string(),
429 );
430 }
431 Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
432 notes.push(
433 "Solution concentration not specified — subheading may differ \
434 (e.g. fuming vs. standard grade).".to_string(),
435 );
436 }
437 _ => {}
438 }
439
440 if product.purity_pct.is_none() {
441 notes.push(
442 "Purity not specified — some headings require a minimum purity threshold."
443 .to_string(),
444 );
445 }
446
447 notes
448 }
449}
450
451// ─────────────────────────────────────────────────────────────────────────────
452// Pipeline integration tests
453// ─────────────────────────────────────────────────────────────────────────────
454
// The tests call `with_llm` / `classify_with_llm`, which only exist with the
// `llm` feature, and use `MockLlmClassifier`, hence both features are required.
#[cfg(all(test, feature = "llm", feature = "mock"))]
mod tests {
    use super::*;
    use crate::llm::MockLlmClassifier;
    use crate::types::SubstanceIdentifier;

    /// A product with no CAS number (so no user mapping or static rule can
    /// match) and only a SMILES string → intended to exercise the LLM path.
    fn unknown_organic() -> ProductDescription {
        ProductDescription {
            identifier: SubstanceIdentifier {
                cas: None,
                smiles: Some("CC(O)=O".to_string()), // acetic acid SMILES, unknown CAS
                iupac_name: None,
                inchi: None,
                inchi_key: None,
                cid: None,
            },
            physical_form: None,
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    /// Whatever stage answers, the returned code must be 6 ASCII digits.
    #[tokio::test]
    async fn classify_with_llm_mock_returns_6_digit_code() {
        let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
        let product = unknown_organic();
        let pred = pipeline.classify_with_llm(&product).await.unwrap();
        assert_eq!(pred.hs_code.len(), 6);
        assert!(pred.hs_code.chars().all(|c| c.is_ascii_digit()));
    }

    /// A carboxylic-acid SMILES should be classified under Chapter 29.
    #[tokio::test]
    async fn classify_with_llm_mock_chapter_29_for_smiles_acid() {
        let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
        let product = unknown_organic();
        let pred = pipeline.classify_with_llm(&product).await.unwrap();
        assert!(
            pred.hs_code.starts_with("29"),
            "acetic acid SMILES should yield Chapter 29, got {}",
            pred.hs_code
        );
    }

    /// Needing the LLM without having attached one is a dedicated error.
    #[tokio::test]
    async fn classify_with_llm_no_client_returns_error() {
        let pipeline = HsPipeline::new(); // no LLM attached
        let product = unknown_organic();
        let err = pipeline.classify_with_llm(&product).await.unwrap_err();
        assert!(
            matches!(err, HsPredictError::LlmNotConfigured),
            "expected LlmNotConfigured, got {:?}",
            err
        );
    }

    /// A confident static-rule hit must short-circuit before the LLM.
    #[tokio::test]
    async fn classify_with_llm_skips_llm_for_high_confidence_rule() {
        // NaOH solid → static rule, confidence = 1.0 → should NOT call LLM
        let pipeline = HsPipeline::new()
            .with_llm(MockLlmClassifier::with_default("999999", 0.1));
        let product = ProductDescription {
            identifier: SubstanceIdentifier::from_cas("1310-73-2"),
            physical_form: Some(crate::types::PhysicalForm::Solid),
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        };
        let pred = pipeline.classify_with_llm(&product).await.unwrap();
        // Should be the static rule result, not the mock's "999999"
        assert_eq!(pred.hs_code, "281511", "static rule should win over LLM");
    }

    /// An LLM answer that is not 6 ASCII digits is rejected.
    #[tokio::test]
    async fn classify_with_llm_invalid_code_returns_validation_error() {
        // Mock returning an invalid code
        struct BadMock;
        impl crate::llm::LlmClassifier for BadMock {
            fn classify<'a>(
                &'a self,
                _prompt: &'a crate::llm::LlmPrompt,
            ) -> futures::future::BoxFuture<'a, crate::Result<crate::llm::LlmResponse>> {
                Box::pin(async {
                    Ok(crate::llm::LlmResponse {
                        hs_code: "BAD!!".to_string(),
                        confidence: 0.5,
                        rationale: "bad".to_string(),
                        alternatives: vec![],
                    })
                })
            }
        }
        let pipeline = HsPipeline::new().with_llm(BadMock);
        let product = unknown_organic();
        let err = pipeline.classify_with_llm(&product).await.unwrap_err();
        assert!(
            matches!(err, HsPredictError::ValidationFailed { .. }),
            "expected ValidationFailed, got {:?}",
            err
        );
    }
}
561}