hs_predict/pipeline.rs
1//! Main classification pipeline.
2//!
//! Runs classification in priority order:
//! 1. User-provided CAS → HS mappings (confidence = 1.0)
//! 2. Embedded static rule table (CAS + physical form + purity)
//! 3. SMILES-based rule engine (v0.3)
//! 4. LLM fallback via [`LlmClassifier`](crate::llm::LlmClassifier) trait hook
//!    (v0.4, `llm` feature; async only)
8
9use std::collections::HashMap;
10#[cfg(feature = "llm")]
11use std::sync::Arc;
12
13use crate::error::{HsPredictError, Result};
14use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
15use crate::rules::matcher::find_best_rule;
16use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
17
/// Tunable confidence thresholds for the classification pipeline.
///
/// The two thresholds split the confidence range into three bands:
///
/// * `confidence >= confidence_threshold_direct` → `RecommendedAction::Accept`
/// * `confidence_threshold_llm_required <= confidence < confidence_threshold_direct`
///   → `RecommendedAction::VerifyWithLlm`
/// * `confidence < confidence_threshold_llm_required` → `RecommendedAction::ExpertReview`
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Minimum confidence at which a result is returned directly, without
    /// asking for LLM confirmation.
    pub confidence_threshold_direct: f32,

    /// Confidence below which LLM (or expert) review is required.
    /// Results between this value and `confidence_threshold_direct` are
    /// returned with `RecommendedAction::VerifyWithLlm`.
    pub confidence_threshold_llm_required: f32,
}

impl Default for PipelineConfig {
    // Defaults: accept directly from 0.85 upwards, require review below 0.50.
    fn default() -> Self {
        let confidence_threshold_direct = 0.85;
        let confidence_threshold_llm_required = 0.50;

        PipelineConfig {
            confidence_threshold_direct,
            confidence_threshold_llm_required,
        }
    }
}
38}
39
/// Main HS code classification pipeline.
///
/// Holds the user-supplied CAS → HS overrides, the [`PipelineConfig`]
/// thresholds and, behind Cargo features, optional PubChem and LLM clients.
/// Construct it with [`HsPipeline::new`] and the `with_*` builder methods.
///
/// # Example — direct (sync)
/// ```rust,no_run
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
///
/// let pipeline = HsPipeline::new();
///
/// let product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
///     physical_form: Some(PhysicalForm::Solid),
///     purity_pct: None,
///     purity_type: None,
///     mixture_components: None,
///     intended_use: None,
///     additional_context: None,
/// };
///
/// let prediction = pipeline.classify(&product).unwrap();
/// assert_eq!(&prediction.hs_code, "281511");
/// ```
///
/// # Example — with PubChem enrichment (async, `pubchem` feature)
/// ```rust,no_run
/// # #[cfg(feature = "pubchem")]
/// # async fn example() -> hs_predict::Result<()> {
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::pubchem::PubChemClient;
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
///
/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
///
/// let mut product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
///     physical_form: Some(PhysicalForm::Solid),
///     purity_pct: None,
///     purity_type: None,
///     mixture_components: None,
///     intended_use: None,
///     additional_context: None,
/// };
///
/// pipeline.enrich(&mut product).await?; // fills SMILES, InChI, IUPAC name …
/// let prediction = pipeline.classify(&product)?;
/// println!("{}", prediction.display()); // "28.15.11"
/// # Ok(())
/// # }
/// ```
///
/// # Example — with LLM fallback (async, `llm` feature)
/// ```rust,no_run
/// # #[cfg(feature = "llm")]
/// # async fn example() -> hs_predict::Result<()> {
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse};
/// use futures::future::BoxFuture;
///
/// struct MyClient;
/// impl LlmClassifier for MyClient {
///     fn classify<'a>(&'a self, _prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
///         Box::pin(async move { todo!() })
///     }
/// }
///
/// let pipeline = HsPipeline::new().with_llm(MyClient);
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier};
/// let product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("12-34-5"),
///     physical_form: None, purity_pct: None, purity_type: None,
///     mixture_components: None, intended_use: None, additional_context: None,
/// };
/// let prediction = pipeline.classify_with_llm(&product).await?;
/// println!("{}", prediction.display());
/// # Ok(())
/// # }
/// ```
#[derive(Default)]
pub struct HsPipeline {
    /// User-supplied CAS → HS code overrides. Highest priority.
    /// Keys are compared verbatim with the product's CAS number.
    user_mappings: HashMap<String, String>,

    /// Confidence thresholds used to map a confidence score to a
    /// `RecommendedAction`.
    config: PipelineConfig,

    /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
    /// `None` (the default) until `with_pubchem` is called.
    #[cfg(feature = "pubchem")]
    pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,

    /// LLM classifier hook (v0.4, `llm` feature).
    /// `None` (the default) until `with_llm` is called.
    #[cfg(feature = "llm")]
    llm: Option<Arc<dyn crate::llm::LlmClassifier>>,
}
132
133impl std::fmt::Debug for HsPipeline {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 let mut s = f.debug_struct("HsPipeline");
136 s.field("user_mappings", &self.user_mappings);
137 s.field("config", &self.config);
138 #[cfg(feature = "pubchem")]
139 s.field("pubchem", &self.pubchem.as_ref().map(|_| "<PubChemClient>"));
140 #[cfg(feature = "llm")]
141 s.field("llm", &self.llm.as_ref().map(|_| "<dyn LlmClassifier>"));
142 s.finish()
143 }
144}
145
146impl HsPipeline {
147 /// Create a pipeline with default configuration.
148 pub fn new() -> Self {
149 Self::default()
150 }
151
152 /// Add a user-provided CAS → HS code mapping.
153 ///
154 /// These mappings override the embedded rule table with `confidence = 1.0`.
155 pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
156 self.user_mappings.insert(cas.into(), hs_code.into());
157 self
158 }
159
160 /// Override the default pipeline configuration.
161 pub fn with_config(mut self, config: PipelineConfig) -> Self {
162 self.config = config;
163 self
164 }
165
166 /// Attach an [`LlmClassifier`](crate::llm::LlmClassifier) implementation to
167 /// enable the LLM fallback (Priority 4).
168 ///
169 /// The LLM is called by [`classify_with_llm`](Self::classify_with_llm) when
170 /// the rule-based pipeline returns a result with
171 /// `recommended_action != Accept`, or returns
172 /// [`LowConfidenceNoLlm`](crate::HsPredictError::LowConfidenceNoLlm).
173 ///
174 /// Requires the **`llm`** Cargo feature.
175 #[cfg(feature = "llm")]
176 pub fn with_llm(mut self, client: impl crate::llm::LlmClassifier + 'static) -> Self {
177 self.llm = Some(Arc::new(client));
178 self
179 }
180
181 /// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
182 /// automatic identifier enrichment before classification.
183 ///
184 /// Requires the **`pubchem`** Cargo feature.
185 #[cfg(feature = "pubchem")]
186 pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
187 self.pubchem = Some(std::sync::Arc::new(client));
188 self
189 }
190
191 /// Enrich a [`ProductDescription`] with PubChem data.
192 ///
193 /// Fills in any missing fields of the main identifier and each mixture
194 /// component's identifier (SMILES, InChI, InChIKey, IUPAC name, CID).
195 ///
196 /// This is a **best-effort** operation:
197 /// - "Not found" and "no usable identifier" results are silently ignored.
198 /// - Network / parse errors **are** propagated.
199 /// - If no PubChem client is configured, returns `Ok(())` immediately.
200 ///
201 /// Requires the **`pubchem`** Cargo feature.
202 #[cfg(feature = "pubchem")]
203 pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
204 let Some(ref client) = self.pubchem else {
205 return Ok(());
206 };
207
208 client.enrich(&mut product.identifier).await?;
209
210 if let Some(ref mut comps) = product.mixture_components {
211 for comp in comps.iter_mut() {
212 client.enrich(&mut comp.substance).await?;
213 }
214 }
215
216 Ok(())
217 }
218
219 /// Classify a product and return an HS code prediction.
220 ///
221 /// Priority order:
222 /// 1. User-provided mapping
223 /// 2. Embedded static rule table
224 /// 3. (v0.3) SMILES rule engine
225 /// 4. (v0.4) LLM fallback
226 pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
227 // ── Priority 1: User-provided mappings ────────────────────────
228 if let Some(ref cas) = product.identifier.cas {
229 if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
230 let jp = find_jp_rule(hs_code);
231 return Ok(HsPrediction {
232 hs_code: hs_code.clone(),
233 heading_description: String::new(),
234 confidence: 1.0,
235 source: PredictionSource::UserMapping,
236 notes: vec!["From user-provided mapping".to_string()],
237 alternatives: vec![],
238 recommended_action: RecommendedAction::Accept,
239 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
240 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
241 });
242 }
243 }
244
245 // ── Priority 2: Embedded static rule table ────────────────────
246 if let Some(ref cas) = product.identifier.cas {
247 if let Some(rule) = find_best_rule(
248 cas,
249 product.physical_form.as_ref(),
250 product.purity_pct,
251 ) {
252 let action = self.recommended_action(rule.confidence);
253 let jp = find_jp_rule(rule.hs_code);
254 return Ok(HsPrediction {
255 hs_code: rule.hs_code.to_string(),
256 heading_description: rule.heading_description.to_string(),
257 confidence: rule.confidence,
258 source: PredictionSource::EmbeddedRule {
259 rule_id: format!("{}:{}", rule.cas, rule.hs_code),
260 },
261 notes: self.build_notes(product),
262 alternatives: vec![],
263 recommended_action: action,
264 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
265 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
266 });
267 }
268 }
269
270 // ── Priority 3: SMILES-based rule engine ─────────────────────────
271 if let Some(ref smiles) = product.identifier.smiles {
272 if let Some(classification) = crate::smiles::classify_smiles(smiles) {
273 let hint = &classification.heading_hint;
274 // Only emit a result when we have at least a 4-digit heading
275 // and confidence meets the LLM-required threshold.
276 if let Some(heading) = hint.heading {
277 if hint.confidence >= self.config.confidence_threshold_llm_required {
278 // Pad to 6 digits with "00" sub-heading (best guess)
279 let hs_code = format!("{:04}00", heading);
280 let jp = find_jp_rule(&hs_code);
281 let action = self.recommended_action(hint.confidence);
282
283 let mut notes = self.build_notes(product);
284 notes.push(
285 "Heading is derived from SMILES functional-group analysis. \
286 Sub-heading (last two digits) is a placeholder — \
287 verify the exact 6-digit code with the product specification."
288 .to_string(),
289 );
290
291 let matched_rules: Vec<String> = classification
292 .functional_groups
293 .iter()
294 .map(|g| g.label().to_string())
295 .collect();
296
297 return Ok(HsPrediction {
298 hs_code,
299 heading_description: hint.rationale.to_string(),
300 confidence: hint.confidence,
301 source: PredictionSource::RuleEngine { matched_rules },
302 notes,
303 alternatives: vec![],
304 recommended_action: action,
305 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
306 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
307 });
308 }
309 }
310 }
311 }
312
313 // ── Priority 4: LLM fallback ─────────────────────────────────
314 // (async path — use classify_with_llm for LLM support)
315 Err(HsPredictError::LowConfidenceNoLlm {
316 confidence: 0.0,
317 threshold: self.config.confidence_threshold_llm_required,
318 })
319 }
320
321 /// Classify a product, falling back to the configured LLM when the
322 /// rule-based pipeline returns a low-confidence or uncertain result.
323 ///
324 /// # Priority order (same as [`classify`](Self::classify) + LLM)
325 ///
326 /// 1. User-provided mapping → `Accept` → return immediately.
327 /// 2. Embedded static rule table → `Accept` → return immediately.
328 /// 3. SMILES rule engine → `Accept` → return immediately.
329 /// 4. Any result with `recommended_action != Accept`, or
330 /// `LowConfidenceNoLlm` → forward to LLM.
331 ///
332 /// If no LLM client has been configured via [`with_llm`](Self::with_llm),
333 /// returns [`HsPredictError::LlmNotConfigured`].
334 ///
335 /// # Validation
336 /// The LLM's `hs_code` must be exactly 6 ASCII digits; otherwise
337 /// [`HsPredictError::ValidationFailed`] is returned.
338 ///
339 /// # Chapter consistency
340 /// If the LLM chapter differs from the SMILES engine's chapter hint, a
341 /// warning note is appended — this is **not** a hard error.
342 ///
343 /// Requires the **`llm`** Cargo feature.
344 #[cfg(feature = "llm")]
345 pub async fn classify_with_llm(
346 &self,
347 product: &ProductDescription,
348 ) -> Result<HsPrediction> {
349 use crate::llm::PromptBuilder;
350 use crate::types::AlternativePrediction;
351
352 // First try the synchronous rule-based pipeline.
353 let needs_llm = match self.classify(product) {
354 Ok(pred) if pred.recommended_action == RecommendedAction::Accept => {
355 return Ok(pred);
356 }
357 Ok(_pred) => true, // low-confidence result → try LLM
358 Err(HsPredictError::LowConfidenceNoLlm { .. }) => true,
359 Err(e) => return Err(e),
360 };
361
362 debug_assert!(needs_llm);
363
364 // Require a configured LLM client.
365 let llm = self
366 .llm
367 .as_ref()
368 .ok_or(HsPredictError::LlmNotConfigured)?;
369
370 // Build prompt and call the LLM.
371 let prompt = PromptBuilder::new().build(product);
372 let resp = llm.classify(&prompt).await?;
373
374 // Validate: must be exactly 6 ASCII digits.
375 if resp.hs_code.len() != 6 || !resp.hs_code.chars().all(|c| c.is_ascii_digit()) {
376 return Err(HsPredictError::ValidationFailed { code: resp.hs_code });
377 }
378
379 // Chapter consistency check (warning only).
380 let mut notes = self.build_notes(product);
381 if let Some(ref analysis) = prompt.smiles_analysis {
382 let llm_chapter = &resp.hs_code[..2];
383 let expected_chapter = format!("{:02}", analysis.heading_hint.chapter);
384 if llm_chapter != expected_chapter {
385 notes.push(format!(
386 "Chapter mismatch: LLM returned Chapter {} but SMILES engine \
387 suggested Chapter {}. Verify with Chapter Notes.",
388 llm_chapter, expected_chapter
389 ));
390 }
391 }
392
393 notes.push(format!("LLM rationale: {}", resp.rationale));
394
395 let jp = find_jp_rule(&resp.hs_code);
396 let action = self.recommended_action(resp.confidence);
397
398 let alternatives = resp
399 .alternatives
400 .into_iter()
401 .map(|a| AlternativePrediction {
402 hs_code: a.hs_code,
403 confidence: a.confidence,
404 reason: a.reason,
405 })
406 .collect();
407
408 Ok(HsPrediction {
409 hs_code: resp.hs_code,
410 heading_description: String::new(),
411 confidence: resp.confidence,
412 source: PredictionSource::LlmApi { model: String::new() },
413 notes,
414 alternatives,
415 recommended_action: action,
416 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
417 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
418 })
419 }
420
421 // ─── Private helpers ──────────────────────────────────────────────
422
423 fn recommended_action(&self, confidence: f32) -> RecommendedAction {
424 if confidence >= self.config.confidence_threshold_direct {
425 RecommendedAction::Accept
426 } else if confidence >= self.config.confidence_threshold_llm_required {
427 RecommendedAction::VerifyWithLlm
428 } else {
429 RecommendedAction::ExpertReview
430 }
431 }
432
433 /// Build supplementary notes about shape / purity caveats.
434 fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
435 let mut notes = Vec::new();
436
437 match &product.physical_form {
438 None | Some(PhysicalForm::Unknown) => {
439 notes.push(
440 "Physical form not specified — the HS subheading may differ \
441 (e.g. solid vs. solution).".to_string(),
442 );
443 }
444 Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
445 notes.push(
446 "Solution concentration not specified — subheading may differ \
447 (e.g. fuming vs. standard grade).".to_string(),
448 );
449 }
450 _ => {}
451 }
452
453 if product.purity_pct.is_none() {
454 notes.push(
455 "Purity not specified — some headings require a minimum purity threshold."
456 .to_string(),
457 );
458 }
459
460 notes
461 }
462}
463
464// ─────────────────────────────────────────────────────────────────────────────
465// Pipeline integration tests
466// ─────────────────────────────────────────────────────────────────────────────
467
#[cfg(all(test, feature = "mock"))]
mod tests {
    use super::*;
    use crate::llm::MockLlmClassifier;
    use crate::types::SubstanceIdentifier;

    /// A product with no static rule and a SMILES → triggers LLM path.
    fn unknown_organic() -> ProductDescription {
        ProductDescription {
            identifier: SubstanceIdentifier {
                cas: None,
                smiles: Some("CC(O)=O".to_string()), // acetic acid SMILES, unknown CAS
                iupac_name: None,
                inchi: None,
                inchi_key: None,
                cid: None,
            },
            physical_form: None,
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    /// Solid sodium hydroxide (CAS 1310-73-2); the static rule table maps it
    /// to 281511 with full confidence.
    fn naoh_solid() -> ProductDescription {
        ProductDescription {
            identifier: SubstanceIdentifier::from_cas("1310-73-2"),
            physical_form: Some(PhysicalForm::Solid),
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    #[test]
    fn classify_user_mapping_overrides_static_rule() {
        let pipeline = HsPipeline::new().with_mapping("1310-73-2", "999999");
        let pred = pipeline.classify(&naoh_solid()).unwrap();
        assert_eq!(pred.hs_code, "999999", "user mapping must win over the static rule");
        assert_eq!(pred.confidence, 1.0);
        assert!(matches!(pred.source, PredictionSource::UserMapping));
        assert!(pred.recommended_action == RecommendedAction::Accept);
    }

    #[test]
    fn classify_without_identifiers_returns_low_confidence_error() {
        let mut product = unknown_organic();
        product.identifier.smiles = None; // no CAS, no SMILES → nothing to classify
        let err = HsPipeline::new().classify(&product).unwrap_err();
        assert!(
            matches!(err, HsPredictError::LowConfidenceNoLlm { confidence, .. } if confidence == 0.0),
            "expected LowConfidenceNoLlm with confidence 0.0, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn classify_with_llm_mock_returns_6_digit_code() {
        let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
        let product = unknown_organic();
        let pred = pipeline.classify_with_llm(&product).await.unwrap();
        assert_eq!(pred.hs_code.len(), 6);
        assert!(pred.hs_code.chars().all(|c| c.is_ascii_digit()));
    }

    #[tokio::test]
    async fn classify_with_llm_mock_chapter_29_for_smiles_acid() {
        let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
        let product = unknown_organic();
        let pred = pipeline.classify_with_llm(&product).await.unwrap();
        assert!(
            pred.hs_code.starts_with("29"),
            "acetic acid SMILES should yield Chapter 29, got {}",
            pred.hs_code
        );
    }

    #[tokio::test]
    async fn classify_with_llm_no_client_returns_error() {
        let pipeline = HsPipeline::new(); // no LLM attached
        let product = unknown_organic();
        let err = pipeline.classify_with_llm(&product).await.unwrap_err();
        assert!(
            matches!(err, HsPredictError::LlmNotConfigured),
            "expected LlmNotConfigured, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn classify_with_llm_skips_llm_for_high_confidence_rule() {
        // NaOH solid → static rule, confidence = 1.0 → should NOT call LLM
        let pipeline = HsPipeline::new()
            .with_llm(MockLlmClassifier::with_default("999999", 0.1));
        let pred = pipeline.classify_with_llm(&naoh_solid()).await.unwrap();
        // Should be the static rule result, not the mock's "999999"
        assert_eq!(pred.hs_code, "281511", "static rule should win over LLM");
    }

    #[tokio::test]
    async fn classify_with_llm_invalid_code_returns_validation_error() {
        // Mock returning an invalid code
        struct BadMock;
        impl crate::llm::LlmClassifier for BadMock {
            fn classify<'a>(
                &'a self,
                _prompt: &'a crate::llm::LlmPrompt,
            ) -> futures::future::BoxFuture<'a, crate::Result<crate::llm::LlmResponse>> {
                Box::pin(async {
                    Ok(crate::llm::LlmResponse {
                        hs_code: "BAD!!".to_string(),
                        confidence: 0.5,
                        rationale: "bad".to_string(),
                        alternatives: vec![],
                    })
                })
            }
        }
        let pipeline = HsPipeline::new().with_llm(BadMock);
        let product = unknown_organic();
        let err = pipeline.classify_with_llm(&product).await.unwrap_err();
        assert!(
            matches!(err, HsPredictError::ValidationFailed { .. }),
            "expected ValidationFailed, got {:?}",
            err
        );
    }
}