Skip to main content

hs_predict/llm/
mod.rs

1//! LLM-based HS code classification — **trait hook** (v0.4).
2//!
3//! `hs-predict` deliberately does **not** ship a concrete LLM API client.
4//! Instead it defines the [`LlmClassifier`] trait, which you implement with
5//! whatever HTTP transport, model, and prompt customisation your application
6//! requires.
7//!
8//! The library provides:
9//! - [`LlmPrompt`] — pre-built system + user text (EN/JA) ready to send
10//! - [`LlmResponse`] — the expected return value from your implementation
11//! - [`parse_llm_json`] — helper that strips markdown fences and deserialises
12//!   the LLM's JSON reply into an [`LlmResponse`]
13//! - [`MockLlmClassifier`] — deterministic stub for unit tests (`mock` feature)
14//!
15//! Requires the **`llm`** Cargo feature.
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! # #[cfg(feature = "llm")]
21//! # mod example {
22//! use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse, parse_llm_json};
23//! use futures::future::BoxFuture;
24//!
25//! struct MyClient { api_key: String }
26//!
27//! impl LlmClassifier for MyClient {
28//!     fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
29//!         Box::pin(async move {
30//!             // 1. Call your LLM API using prompt.system_text / prompt.user_text
31//!             let raw_json: String = todo!("send HTTP request, receive text");
32//!             // 2. Parse and return
33//!             parse_llm_json(&raw_json)
34//!         })
35//!     }
36//! }
37//! # }
38//! ```
39
40pub mod mock;
41pub mod prompt;
42
43#[cfg(feature = "mock")]
44pub use mock::MockLlmClassifier;
45pub use prompt::PromptBuilder;
46
47use futures::future::BoxFuture;
48use serde::{Deserialize, Serialize};
49
50// ─────────────────────────────────────────────────────────────────────────────
51// LlmPrompt
52// ─────────────────────────────────────────────────────────────────────────────
53
/// Input passed to [`LlmClassifier::classify`].
///
/// Carries ready-to-send prompt text alongside the structured SMILES analysis.
/// An implementation can either forward `system_text` / `user_text` verbatim
/// to its LLM, or build its own prompt from `smiles_analysis`.
///
/// NOTE(review): no constructor is defined in this module; instances are
/// presumably assembled by [`PromptBuilder`]. Confirm against `prompt.rs`.
#[derive(Debug, Clone)]
pub struct LlmPrompt {
    /// Pre-built system prompt (role + format instructions + confidence guide).
    pub system_text: String,

    /// Pre-built user message (all product identifiers, physical description,
    /// and SMILES functional-group hints if available).
    pub user_text: String,

    /// SMILES-based pre-classification; `None` when no SMILES string was
    /// available. Useful for building custom prompts or for post-call
    /// chapter validation.
    pub smiles_analysis: Option<crate::smiles::SmilesClassification>,
}
71
72// ─────────────────────────────────────────────────────────────────────────────
73// LlmResponse
74// ─────────────────────────────────────────────────────────────────────────────
75
76/// Response that [`LlmClassifier::classify`] must return.
77///
78/// All fields map directly to fields of [`HsPrediction`](crate::types::HsPrediction).
79///
80/// # JSON schema expected from the LLM
81/// ```json
82/// {
83///   "hs_code":    "291511",
84///   "confidence": 0.85,
85///   "rationale":  "Acetic acid → heading 29.15 (saturated acyclic carboxylic acid).",
86///   "alternatives": [
87///     { "hs_code": "291519", "confidence": 0.10, "reason": "If purity threshold not met." }
88///   ]
89/// }
90/// ```
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct LlmResponse {
93    /// Six-digit HS 2022 code, no punctuation (e.g. `"291511"`).
94    ///
95    /// The pipeline validates this is exactly 6 ASCII digits before accepting it.
96    pub hs_code: String,
97
98    /// Confidence score in [0.0, 1.0].
99    pub confidence: f32,
100
101    /// Natural-language rationale (1–3 sentences).
102    pub rationale: String,
103
104    /// Alternative HS codes with lower confidence. May be empty.
105    #[serde(default)]
106    pub alternatives: Vec<LlmAlternative>,
107}
108
109/// An alternative HS code suggestion returned by the LLM.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct LlmAlternative {
112    /// Six-digit HS 2022 code.
113    pub hs_code: String,
114    /// Confidence for this alternative, in [0.0, 1.0].
115    pub confidence: f32,
116    /// Why this alternative applies.
117    pub reason: String,
118}
119
120// ─────────────────────────────────────────────────────────────────────────────
121// LlmClassifier trait
122// ─────────────────────────────────────────────────────────────────────────────
123
/// Trait for LLM-based HS code classification.
///
/// Implement this with your preferred LLM provider (Anthropic Claude,
/// OpenAI GPT-4o, local Ollama, …) and attach it to the pipeline via
/// [`HsPipeline::with_llm`](crate::pipeline::HsPipeline::with_llm).
///
/// `classify` returns a [`BoxFuture`] instead of being an `async fn`, which
/// keeps the trait usable as a trait object (`dyn LlmClassifier`). Note that
/// `BoxFuture` requires the future to be `Send`, so avoid holding non-`Send`
/// values (e.g. `Rc`) across `.await` points inside the implementation.
///
/// # Contract
/// - Must return an [`LlmResponse`] with `hs_code` that is exactly 6 ASCII
///   digits. The pipeline validates this and returns
///   [`HsPredictError::ValidationFailed`](crate::HsPredictError::ValidationFailed)
///   if the code is malformed.
/// - `confidence` should follow the guide in [`LlmPrompt::system_text`]:
///   ≥ 0.90 for certain sub-heading, ≥ 0.70 for certain heading.
/// - Must be `Send + Sync` (required for `Arc<dyn LlmClassifier>`).
///
/// # Minimal implementation
/// ```rust,no_run
/// # #[cfg(feature = "llm")]
/// # {
/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse, parse_llm_json};
/// use futures::future::BoxFuture;
///
/// struct MyClient;
///
/// impl LlmClassifier for MyClient {
///     fn classify<'a>(&'a self, _prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
///         Box::pin(async move {
///             let raw = String::from(r#"{"hs_code":"291511","confidence":0.85,"rationale":"...","alternatives":[]}"#);
///             parse_llm_json(&raw)
///         })
///     }
/// }
/// # }
/// ```
pub trait LlmClassifier: Send + Sync {
    /// Classify the product described in `prompt` and return an HS code prediction.
    ///
    /// The returned future borrows both `self` and `prompt` for `'a`.
    /// Implementations will usually finish by passing the model's raw text
    /// reply to [`parse_llm_json`], propagating its error on malformed output.
    fn classify<'a>(
        &'a self,
        prompt: &'a LlmPrompt,
    ) -> BoxFuture<'a, crate::Result<LlmResponse>>;
}
165
166// ─────────────────────────────────────────────────────────────────────────────
167// parse_llm_json helper
168// ─────────────────────────────────────────────────────────────────────────────
169
170/// Parse a raw LLM API text response into an [`LlmResponse`].
171///
172/// Handles the most common formatting quirks LLMs exhibit:
173/// - Plain JSON
174/// - JSON wrapped in ` ```json … ``` ` markdown fences
175/// - JSON wrapped in plain ` ``` … ``` ` fences
176/// - Leading / trailing whitespace
177///
178/// # Errors
179/// Returns [`HsPredictError::LlmResponseParseError`](crate::HsPredictError::LlmResponseParseError)
180/// if the string cannot be deserialised as [`LlmResponse`].
181///
182/// # Example
183/// ```rust
184/// # #[cfg(feature = "llm")]
185/// # {
186/// use hs_predict::llm::{parse_llm_json, LlmResponse};
187///
188/// let raw = r#"```json
189/// {"hs_code":"291511","confidence":0.85,"rationale":"Acetic acid.","alternatives":[]}
190/// ```"#;
191///
192/// let r: LlmResponse = parse_llm_json(raw).unwrap();
193/// assert_eq!(r.hs_code, "291511");
194/// # }
195/// ```
196pub fn parse_llm_json(raw: &str) -> crate::Result<LlmResponse> {
197    let json_str = strip_markdown_fences(raw);
198    serde_json::from_str::<LlmResponse>(json_str.trim()).map_err(|e| {
199        crate::HsPredictError::LlmResponseParseError {
200            source: e,
201            raw: raw.to_string(),
202        }
203    })
204}
205
/// Extract the JSON payload from LLM output that may be wrapped in a
/// markdown code fence.
///
/// Rules, applied to the trimmed input:
/// - If it starts with `{`, it is treated as plain JSON and returned as-is.
///   Backticks inside JSON string values therefore never trigger fence handling.
/// - Otherwise the first ` ``` ` opens a fence. Any info string right after it
///   (`json`, `JSON`, `jsonc`, …) is skipped, so the tag's case and spelling
///   do not matter. Text before the opening fence (e.g. "Here is the result:")
///   is dropped.
/// - The *last* ` ``` ` after the opening fence closes it, and anything after
///   it (e.g. a trailing remark) is dropped. A missing closing fence, as in
///   truncated output, is tolerated.
/// - Input containing no fence is returned unchanged.
///
/// The returned slice is always trimmed of surrounding whitespace.
fn strip_markdown_fences(s: &str) -> &str {
    const FENCE: &str = "```";

    let s = s.trim();
    if s.starts_with('{') {
        return s;
    }

    let open = match s.find(FENCE) {
        Some(idx) => idx,
        None => return s,
    };
    // `FENCE` is ASCII, so this offset is always a char boundary.
    let after_open = &s[open + FENCE.len()..];

    // Skip the info string: everything up to the first whitespace, or up to
    // the `{` that starts the object in single-line fences like ```json{...}```.
    let tag_len = after_open
        .find(|c: char| c.is_whitespace() || c == '{')
        .unwrap_or(after_open.len());
    let body = &after_open[tag_len..];

    // Use the last fence as the closer so stray backticks inside the JSON
    // don't truncate the payload; no closer at all means truncated output.
    let body = match body.rfind(FENCE) {
        Some(close) => &body[..close],
        None => body,
    };
    body.trim()
}
216
217// ─────────────────────────────────────────────────────────────────────────────
218// Tests
219// ─────────────────────────────────────────────────────────────────────────────
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_plain_json() {
        let raw = r#"{"hs_code":"291511","confidence":0.85,"rationale":"test","alternatives":[]}"#;
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.hs_code, "291511");
        assert!((r.confidence - 0.85).abs() < 0.001);
    }

    #[test]
    fn parse_json_with_json_fence() {
        let raw = "```json\n{\"hs_code\":\"280511\",\"confidence\":0.9,\"rationale\":\"ok\",\"alternatives\":[]}\n```";
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.hs_code, "280511");
    }

    #[test]
    fn parse_json_with_plain_fence() {
        let raw = "```\n{\"hs_code\":\"290900\",\"confidence\":0.6,\"rationale\":\"ether\",\"alternatives\":[]}\n```";
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.hs_code, "290900");
    }

    #[test]
    fn parse_alternatives_populated() {
        let raw = r#"{
            "hs_code": "291511",
            "confidence": 0.75,
            "rationale": "likely acetic acid",
            "alternatives": [
                { "hs_code": "291519", "confidence": 0.15, "reason": "other acids" }
            ]
        }"#;
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.alternatives.len(), 1);
        assert_eq!(r.alternatives[0].hs_code, "291519");
    }

    #[test]
    fn parse_missing_alternatives_defaults_to_empty() {
        let raw = r#"{"hs_code":"290900","confidence":0.7,"rationale":"ether"}"#;
        let r = parse_llm_json(raw).unwrap();
        assert!(r.alternatives.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_error() {
        let result = parse_llm_json("not json at all");
        assert!(result.is_err());
    }

    /// Surrounding whitespace around plain JSON is a documented quirk.
    #[test]
    fn parse_plain_json_with_surrounding_whitespace() {
        let raw = "\n\t  {\"hs_code\":\"291511\",\"confidence\":0.85,\"rationale\":\"ws\"}  \n";
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.hs_code, "291511");
    }

    /// Whitespace outside the fence must not stop the fence from being stripped.
    #[test]
    fn parse_fenced_json_with_surrounding_whitespace() {
        let raw = "\n\n```json\n{\"hs_code\":\"280511\",\"confidence\":0.9,\"rationale\":\"ok\"}\n```\n\n";
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.hs_code, "280511");
    }

    /// Backticks inside a JSON string of an unfenced reply must survive intact.
    #[test]
    fn parse_plain_json_keeps_backticks_in_strings() {
        let raw = r#"{"hs_code":"291511","confidence":0.8,"rationale":"see ``` note"}"#;
        let r = parse_llm_json(raw).unwrap();
        assert_eq!(r.rationale, "see ``` note");
    }

    /// The parse error must carry the original, unmodified input for debugging.
    #[test]
    fn parse_error_preserves_raw_input() {
        let raw = "not json at all";
        let err = parse_llm_json(raw).unwrap_err();
        assert!(matches!(
            &err,
            crate::HsPredictError::LlmResponseParseError { raw: kept, .. } if kept == raw
        ));
    }
}