hs_predict/llm/mod.rs
1//! LLM-based HS code classification — **trait hook** (v0.4).
2//!
3//! `hs-predict` deliberately does **not** ship a concrete LLM API client.
4//! Instead it defines the [`LlmClassifier`] trait, which you implement with
5//! whatever HTTP transport, model, and prompt customisation your application
6//! requires.
7//!
8//! The library provides:
9//! - [`LlmPrompt`] — pre-built system + user text (EN/JA) ready to send
10//! - [`LlmResponse`] — the expected return value from your implementation
11//! - [`parse_llm_json`] — helper that strips markdown fences and deserialises
12//! the LLM's JSON reply into an [`LlmResponse`]
13//! - [`MockLlmClassifier`] — deterministic stub for unit tests (`mock` feature)
14//!
15//! Requires the **`llm`** Cargo feature.
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! # #[cfg(feature = "llm")]
21//! # mod example {
22//! use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse, parse_llm_json};
23//! use futures::future::BoxFuture;
24//!
25//! struct MyClient { api_key: String }
26//!
27//! impl LlmClassifier for MyClient {
28//! fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
29//! Box::pin(async move {
30//! // 1. Call your LLM API using prompt.system_text / prompt.user_text
31//! let raw_json: String = todo!("send HTTP request, receive text");
32//! // 2. Parse and return
33//! parse_llm_json(&raw_json)
34//! })
35//! }
36//! }
37//! # }
38//! ```
39
40pub mod mock;
41pub mod prompt;
42
43#[cfg(feature = "mock")]
44pub use mock::MockLlmClassifier;
45pub use prompt::PromptBuilder;
46
47use futures::future::BoxFuture;
48use serde::{Deserialize, Serialize};
49
50// ─────────────────────────────────────────────────────────────────────────────
51// LlmPrompt
52// ─────────────────────────────────────────────────────────────────────────────
53
54/// Input passed to [`LlmClassifier::classify`].
55///
56/// Contains pre-built prompt text as well as structured SMILES analysis
57/// for implementations that want to build a custom prompt.
58#[derive(Debug, Clone)]
59pub struct LlmPrompt {
60 /// Pre-built system prompt (role + format instructions + confidence guide).
61 pub system_text: String,
62
63 /// Pre-built user message (all product identifiers, physical description,
64 /// and SMILES functional-group hints if available).
65 pub user_text: String,
66
67 /// SMILES-based pre-classification, if a SMILES string was available.
68 /// Useful for building custom prompts or for post-call chapter validation.
69 pub smiles_analysis: Option<crate::smiles::SmilesClassification>,
70}
71
72// ─────────────────────────────────────────────────────────────────────────────
73// LlmResponse
74// ─────────────────────────────────────────────────────────────────────────────
75
76/// Response that [`LlmClassifier::classify`] must return.
77///
78/// All fields map directly to fields of [`HsPrediction`](crate::types::HsPrediction).
79///
80/// # JSON schema expected from the LLM
81/// ```json
82/// {
83/// "hs_code": "291511",
84/// "confidence": 0.85,
85/// "rationale": "Acetic acid → heading 29.15 (saturated acyclic carboxylic acid).",
86/// "alternatives": [
87/// { "hs_code": "291519", "confidence": 0.10, "reason": "If purity threshold not met." }
88/// ]
89/// }
90/// ```
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct LlmResponse {
93 /// Six-digit HS 2022 code, no punctuation (e.g. `"291511"`).
94 ///
95 /// The pipeline validates this is exactly 6 ASCII digits before accepting it.
96 pub hs_code: String,
97
98 /// Confidence score in [0.0, 1.0].
99 pub confidence: f32,
100
101 /// Natural-language rationale (1–3 sentences).
102 pub rationale: String,
103
104 /// Alternative HS codes with lower confidence. May be empty.
105 #[serde(default)]
106 pub alternatives: Vec<LlmAlternative>,
107}
108
109/// An alternative HS code suggestion returned by the LLM.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct LlmAlternative {
112 /// Six-digit HS 2022 code.
113 pub hs_code: String,
114 /// Confidence for this alternative, in [0.0, 1.0].
115 pub confidence: f32,
116 /// Why this alternative applies.
117 pub reason: String,
118}
119
120// ─────────────────────────────────────────────────────────────────────────────
121// LlmClassifier trait
122// ─────────────────────────────────────────────────────────────────────────────
123
124/// Trait for LLM-based HS code classification.
125///
126/// Implement this with your preferred LLM provider (Anthropic Claude,
127/// OpenAI GPT-4o, local Ollama, …) and attach it to the pipeline via
128/// [`HsPipeline::with_llm`](crate::pipeline::HsPipeline::with_llm).
129///
130/// # Contract
131/// - Must return an [`LlmResponse`] with `hs_code` that is exactly 6 ASCII
132/// digits. The pipeline validates this and returns
133/// [`HsPredictError::ValidationFailed`](crate::HsPredictError::ValidationFailed)
134/// if the code is malformed.
135/// - `confidence` should follow the guide in [`LlmPrompt::system_text`]:
136/// ≥ 0.90 for certain sub-heading, ≥ 0.70 for certain heading.
137/// - Must be `Send + Sync` (required for `Arc<dyn LlmClassifier>`).
138///
139/// # Minimal implementation
140/// ```rust,no_run
141/// # #[cfg(feature = "llm")]
142/// # {
143/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse, parse_llm_json};
144/// use futures::future::BoxFuture;
145///
146/// struct MyClient;
147///
148/// impl LlmClassifier for MyClient {
149/// fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
150/// Box::pin(async move {
151/// let raw = String::from(r#"{"hs_code":"291511","confidence":0.85,"rationale":"...","alternatives":[]}"#);
152/// parse_llm_json(&raw)
153/// })
154/// }
155/// }
156/// # }
157/// ```
158pub trait LlmClassifier: Send + Sync {
159 /// Classify the product described in `prompt` and return an HS code prediction.
160 fn classify<'a>(
161 &'a self,
162 prompt: &'a LlmPrompt,
163 ) -> BoxFuture<'a, crate::Result<LlmResponse>>;
164}
165
166// ─────────────────────────────────────────────────────────────────────────────
167// parse_llm_json helper
168// ─────────────────────────────────────────────────────────────────────────────
169
170/// Parse a raw LLM API text response into an [`LlmResponse`].
171///
172/// Handles the most common formatting quirks LLMs exhibit:
173/// - Plain JSON
174/// - JSON wrapped in ` ```json … ``` ` markdown fences
175/// - JSON wrapped in plain ` ``` … ``` ` fences
176/// - Leading / trailing whitespace
177///
178/// # Errors
179/// Returns [`HsPredictError::LlmResponseParseError`](crate::HsPredictError::LlmResponseParseError)
180/// if the string cannot be deserialised as [`LlmResponse`].
181///
182/// # Example
183/// ```rust
184/// # #[cfg(feature = "llm")]
185/// # {
186/// use hs_predict::llm::{parse_llm_json, LlmResponse};
187///
188/// let raw = r#"```json
189/// {"hs_code":"291511","confidence":0.85,"rationale":"Acetic acid.","alternatives":[]}
190/// ```"#;
191///
192/// let r: LlmResponse = parse_llm_json(raw).unwrap();
193/// assert_eq!(r.hs_code, "291511");
194/// # }
195/// ```
196pub fn parse_llm_json(raw: &str) -> crate::Result<LlmResponse> {
197 let json_str = strip_markdown_fences(raw);
198 serde_json::from_str::<LlmResponse>(json_str.trim()).map_err(|e| {
199 crate::HsPredictError::LlmResponseParseError {
200 source: e,
201 raw: raw.to_string(),
202 }
203 })
204}
205
/// Strip a surrounding markdown code fence (` ``` `, optionally tagged) from LLM output.
///
/// Returns `s` trimmed when it does not start with a fence. Otherwise:
/// - the opening ` ``` ` and any ASCII-alphanumeric info string glued to it
///   (`json`, `JSON`, `json5`, …) are removed. The payload expected here is a
///   JSON object, which starts with `{`, so this cannot consume payload.
/// - the text is cut at the closing fence. If the text ends with ` ``` `, it is
///   cut there. Otherwise it is cut at the first ` ``` ` in the body, which
///   drops commentary after the fence. With no closing fence at all
///   (unterminated fence), everything after the opening fence is kept.
/// - the result is trimmed.
fn strip_markdown_fences(s: &str) -> &str {
    let s = s.trim();
    let after_open = match s.strip_prefix("```") {
        Some(rest) => rest,
        None => return s,
    };
    // Accept any alphanumeric tag, not just lowercase `json`; otherwise a tag
    // like `JSON` would be left in front of the payload.
    let body = after_open.trim_start_matches(|c: char| c.is_ascii_alphanumeric());
    let inner = match body.strip_suffix("```") {
        Some(inner) => inner,
        None => body.find("```").map_or(body, |end| &body[..end]),
    };
    inner.trim()
}
216
217// ─────────────────────────────────────────────────────────────────────────────
218// Tests
219// ─────────────────────────────────────────────────────────────────────────────
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw`, failing the test with the parser error if it is rejected.
    fn parse_ok(raw: &str) -> LlmResponse {
        match parse_llm_json(raw) {
            Ok(response) => response,
            Err(err) => panic!("expected {:?} to parse, got {:?}", raw, err),
        }
    }

    #[test]
    fn parse_plain_json() {
        let response = parse_ok(
            r#"{"hs_code":"291511","confidence":0.85,"rationale":"test","alternatives":[]}"#,
        );
        assert_eq!(response.hs_code, "291511");
        assert!((response.confidence - 0.85).abs() < 0.001);
    }

    #[test]
    fn parse_json_with_json_fence() {
        let raw = concat!(
            "```json\n",
            r#"{"hs_code":"280511","confidence":0.9,"rationale":"ok","alternatives":[]}"#,
            "\n```"
        );
        assert_eq!(parse_ok(raw).hs_code, "280511");
    }

    #[test]
    fn parse_json_with_plain_fence() {
        let raw = concat!(
            "```\n",
            r#"{"hs_code":"290900","confidence":0.6,"rationale":"ether","alternatives":[]}"#,
            "\n```"
        );
        assert_eq!(parse_ok(raw).hs_code, "290900");
    }

    #[test]
    fn parse_alternatives_populated() {
        let raw = r#"{
            "hs_code": "291511",
            "confidence": 0.75,
            "rationale": "likely acetic acid",
            "alternatives": [
                { "hs_code": "291519", "confidence": 0.15, "reason": "other acids" }
            ]
        }"#;
        let response = parse_ok(raw);
        let codes: Vec<&str> = response
            .alternatives
            .iter()
            .map(|alt| alt.hs_code.as_str())
            .collect();
        assert_eq!(codes, ["291519"]);
    }

    #[test]
    fn parse_missing_alternatives_defaults_to_empty() {
        let response = parse_ok(r#"{"hs_code":"290900","confidence":0.7,"rationale":"ether"}"#);
        assert!(response.alternatives.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_error() {
        assert!(matches!(parse_llm_json("not json at all"), Err(_)));
    }
}