use std::collections::HashMap;
use crate::error::{HsPredictError, Result};
use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
use crate::rules::matcher::find_best_rule;
use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};
/// Confidence thresholds that steer how a prediction is acted upon.
///
/// Both values are compared against a prediction's confidence with `>=`.
#[derive(Clone, Debug)]
pub struct PipelineConfig {
    /// Confidence at or above which a prediction is recommended for direct
    /// acceptance.
    pub confidence_threshold_direct: f32,

    /// Confidence at or above which (but below `confidence_threshold_direct`)
    /// a prediction is recommended for LLM verification; anything lower is
    /// routed to expert review. It is also the minimum confidence a
    /// SMILES-derived heading needs before the pipeline returns it.
    pub confidence_threshold_llm_required: f32,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
confidence_threshold_direct: 0.85,
confidence_threshold_llm_required: 0.50,
}
}
}
/// HS-code classification pipeline.
///
/// Bundles the user-supplied CAS → HS code overrides with the confidence
/// thresholds used when recommending an action for a prediction.
#[derive(Default, Debug)]
pub struct HsPipeline {
    /// User-provided overrides: key is the CAS string, value is the HS code.
    user_mappings: HashMap<String, String>,

    /// Thresholds consulted when grading a prediction's confidence.
    config: PipelineConfig,

    /// Optional shared PubChem client; the field only exists when the
    /// `pubchem` feature is enabled.
    #[cfg(feature = "pubchem")]
    pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
}
impl HsPipeline {
pub fn new() -> Self {
Self::default()
}
pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
self.user_mappings.insert(cas.into(), hs_code.into());
self
}
pub fn with_config(mut self, config: PipelineConfig) -> Self {
self.config = config;
self
}
#[cfg(feature = "pubchem")]
pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
self.pubchem = Some(std::sync::Arc::new(client));
self
}
#[cfg(feature = "pubchem")]
pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
let Some(ref client) = self.pubchem else {
return Ok(());
};
client.enrich(&mut product.identifier).await?;
if let Some(ref mut comps) = product.mixture_components {
for comp in comps.iter_mut() {
client.enrich(&mut comp.substance).await?;
}
}
Ok(())
}
pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
if let Some(ref cas) = product.identifier.cas {
if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
let jp = find_jp_rule(hs_code);
return Ok(HsPrediction {
hs_code: hs_code.clone(),
heading_description: String::new(),
confidence: 1.0,
source: PredictionSource::UserMapping,
notes: vec!["From user-provided mapping".to_string()],
alternatives: vec![],
recommended_action: RecommendedAction::Accept,
jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
});
}
}
if let Some(ref cas) = product.identifier.cas {
if let Some(rule) = find_best_rule(
cas,
product.physical_form.as_ref(),
product.purity_pct,
) {
let action = self.recommended_action(rule.confidence);
let jp = find_jp_rule(rule.hs_code);
return Ok(HsPrediction {
hs_code: rule.hs_code.to_string(),
heading_description: rule.heading_description.to_string(),
confidence: rule.confidence,
source: PredictionSource::EmbeddedRule {
rule_id: format!("{}:{}", rule.cas, rule.hs_code),
},
notes: self.build_notes(product),
alternatives: vec![],
recommended_action: action,
jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
});
}
}
if let Some(ref smiles) = product.identifier.smiles {
if let Some(classification) = crate::smiles::classify_smiles(smiles) {
let hint = &classification.heading_hint;
if let Some(heading) = hint.heading {
if hint.confidence >= self.config.confidence_threshold_llm_required {
let hs_code = format!("{:04}00", heading);
let jp = find_jp_rule(&hs_code);
let action = self.recommended_action(hint.confidence);
let mut notes = self.build_notes(product);
notes.push(
"Heading is derived from SMILES functional-group analysis. \
Sub-heading (last two digits) is a placeholder — \
verify the exact 6-digit code with the product specification."
.to_string(),
);
let matched_rules: Vec<String> = classification
.functional_groups
.iter()
.map(|g| g.label().to_string())
.collect();
return Ok(HsPrediction {
hs_code,
heading_description: hint.rationale.to_string(),
confidence: hint.confidence,
source: PredictionSource::RuleEngine { matched_rules },
notes,
alternatives: vec![],
recommended_action: action,
jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
});
}
}
}
}
Err(HsPredictError::LowConfidenceNoLlm {
confidence: 0.0,
threshold: self.config.confidence_threshold_llm_required,
})
}
fn recommended_action(&self, confidence: f32) -> RecommendedAction {
if confidence >= self.config.confidence_threshold_direct {
RecommendedAction::Accept
} else if confidence >= self.config.confidence_threshold_llm_required {
RecommendedAction::VerifyWithLlm
} else {
RecommendedAction::ExpertReview
}
}
fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
let mut notes = Vec::new();
match &product.physical_form {
None | Some(PhysicalForm::Unknown) => {
notes.push(
"Physical form not specified — the HS subheading may differ \
(e.g. solid vs. solution).".to_string(),
);
}
Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
notes.push(
"Solution concentration not specified — subheading may differ \
(e.g. fuming vs. standard grade).".to_string(),
);
}
_ => {}
}
if product.purity_pct.is_none() {
notes.push(
"Purity not specified — some headings require a minimum purity threshold."
.to_string(),
);
}
notes
}
}