pub mod flow;
pub mod messages;
pub mod question;
pub mod state;
use serde::{Deserialize, Serialize};
pub use question::{Answer, QAPair, Question, QuestionStep, SessionResult};
pub use state::{ClassificationState, PartialComponent};
use crate::error::{HsPredictError, Result};
use crate::session::flow::{
choice_index_to_intended_use, choice_index_to_organic_inorganic,
choice_index_to_physical_form, multi_choice_indices_to_functional_groups, next_question,
};
use crate::types::{Language, MixtureComponent, PhysicalForm, ProductDescription, SubstanceIdentifier};
/// Interactive question-and-answer session that gathers the facts needed to
/// describe a product for classification.
///
/// The session hands out one question at a time, records each accepted answer
/// in its state and history, and can be serialized and restored between
/// answers (it derives `Serialize` and `Deserialize`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationSession {
/// Facts collected so far; updated as answers are applied.
state: ClassificationState,
/// Every accepted question/answer pair, in the order they were answered.
history: Vec<QAPair>,
/// Question currently awaiting an answer; `None` in a freshly constructed
/// session and after the flow has run out of questions.
current_question: Option<Question>,
/// Flow step that produced `current_question`; set and cleared together with it.
current_step: Option<QuestionStep>,
/// Language passed to the flow when generating questions.
language: Language,
}
impl ClassificationSession {
pub fn new() -> Self {
Self {
state: ClassificationState::default(),
history: Vec::new(),
current_question: None,
current_step: None,
language: Language::En,
}
}
pub fn new_ja() -> Self {
Self::new().with_language(Language::Ja)
}
pub fn with_language(mut self, language: Language) -> Self {
self.language = language;
self
}
pub fn start(&mut self) -> Question {
let (q, step) = next_question(&self.state, self.language)
.expect("new session should always have a first question");
self.current_question = Some(q.clone());
self.current_step = Some(step);
q
}
pub fn answer(&mut self, answer: Answer) -> Result<SessionResult> {
let question = self
.current_question
.clone()
.ok_or(HsPredictError::NoActiveQuestion)?;
self.validate_and_apply(&question, &answer)?;
self.history.push(QAPair {
question: question.clone(),
answer,
});
self.try_resolve_smiles();
match next_question(&self.state, self.language) {
Some((q, step)) => {
self.current_question = Some(q.clone());
self.current_step = Some(step);
Ok(SessionResult::NeedMoreInfo { next_question: q })
}
None => {
self.state.is_complete = true;
self.current_question = None;
self.current_step = None;
if self.state.confidence_estimate() < 0.25 {
Ok(SessionResult::RequiresLlm)
} else {
Ok(SessionResult::Ready)
}
}
}
}
pub fn to_product_description(&self) -> ProductDescription {
let mixture_components = if self.state.is_mixture == Some(true) {
Some(
self.state
.components
.iter()
.map(|c| MixtureComponent {
substance: c.identifier.clone(),
weight_fraction_pct: c.weight_fraction_pct,
volume_fraction_pct: None,
is_solvent: c.is_solvent,
})
.collect(),
)
} else {
None
};
ProductDescription {
identifier: self.state.identifier.clone(),
physical_form: self.state.physical_form.clone(),
purity_pct: self.state.purity_pct,
purity_type: None,
mixture_components,
intended_use: self.state.intended_use.clone(),
additional_context: None,
}
}
pub fn state(&self) -> &ClassificationState {
&self.state
}
pub fn history(&self) -> &[QAPair] {
&self.history
}
pub fn question_count(&self) -> usize {
self.history.len()
}
pub fn is_complete(&self) -> bool {
self.state.is_complete
}
pub fn language(&self) -> Language {
self.language
}
pub fn current_step(&self) -> Option<QuestionStep> {
self.current_step
}
fn validate_and_apply(&mut self, question: &Question, answer: &Answer) -> Result<()> {
match (question, answer) {
(Question::Text { .. }, Answer::Text(text)) => {
self.apply_identifier_input(text);
}
(Question::Text { .. }, Answer::Skip) => {
if !self.state.has_identifier()
&& self.current_step != Some(QuestionStep::ComponentIdentifier)
{
return Err(HsPredictError::MissingIdentifier);
}
}
(Question::YesNo { .. }, Answer::YesNo(val)) => {
self.apply_yes_no(*val);
}
(Question::Number { min, max, .. }, Answer::Number(val)) => {
if *val < *min || *val > *max {
return Err(HsPredictError::NumberOutOfRange {
value: *val,
min: *min,
max: *max,
});
}
self.apply_number(*val);
}
(Question::Choice { options, .. }, Answer::Choice(idx)) => {
if *idx >= options.len() {
return Err(HsPredictError::InvalidChoiceIndex {
index: *idx,
max: options.len() - 1,
});
}
self.apply_choice(*idx);
}
(Question::MultiChoice { options, .. }, Answer::MultiChoice(indices)) => {
for &idx in indices {
if idx >= options.len() {
return Err(HsPredictError::InvalidChoiceIndex {
index: idx,
max: options.len() - 1,
});
}
}
self.apply_multi_choice(indices);
}
_ => {
return Err(HsPredictError::AnswerTypeMismatch {
expected: question_kind_name(question),
got: answer.kind_name(),
});
}
}
Ok(())
}
fn apply_identifier_input(&mut self, input: &str) {
let input = input.trim();
let in_mixture = self.state.is_mixture == Some(true)
&& self.state.current_component_index < self.state.component_count.unwrap_or(0);
if in_mixture {
let idx = self.state.current_component_index;
while self.state.components.len() <= idx {
self.state.components.push(PartialComponent::default());
}
self.state.components[idx].identifier = parse_identifier(input);
} else {
self.state.identifier = parse_identifier(input);
}
}
fn apply_yes_no(&mut self, val: bool) {
match self.current_step {
Some(QuestionStep::IsMixture) => {
self.state.is_mixture = Some(val);
}
_ => {}
}
}
fn apply_number(&mut self, val: f64) {
match self.current_step {
Some(QuestionStep::ComponentCount) => {
self.state.component_count = Some(val as usize);
}
Some(QuestionStep::ComponentFraction) => {
let idx = self.state.current_component_index;
if idx < self.state.components.len() {
self.state.components[idx].weight_fraction_pct =
if val > 0.0 { Some(val) } else { None };
self.state.current_component_index += 1;
}
}
Some(QuestionStep::SolutionConcentration) => {
if let Some(PhysicalForm::Solution { concentration_pct_ww, .. }) =
&mut self.state.physical_form
{
*concentration_pct_ww = if val > 0.0 { Some(val) } else { None };
}
}
_ => {}
}
}
fn apply_choice(&mut self, idx: usize) {
match self.current_step {
Some(QuestionStep::PhysicalForm) => {
self.state.physical_form = Some(choice_index_to_physical_form(idx));
}
Some(QuestionStep::IntendedUse) => {
self.state.intended_use = Some(choice_index_to_intended_use(idx));
}
Some(QuestionStep::OrganicInorganic) => {
self.state.organic_inorganic = Some(choice_index_to_organic_inorganic(idx));
}
_ => {}
}
}
fn apply_multi_choice(&mut self, indices: &[usize]) {
self.state.detected_functional_groups = multi_choice_indices_to_functional_groups(indices);
}
fn try_resolve_smiles(&mut self) {
if self.state.identifier.smiles.is_none() {
if let Some(ref iupac) = self.state.identifier.iupac_name.clone() {
if let Some(smiles) = resolve_iupac_to_smiles(iupac) {
self.state.identifier.smiles = Some(smiles);
}
}
}
for comp in &mut self.state.components {
if comp.identifier.smiles.is_none() {
if let Some(ref iupac) = comp.identifier.iupac_name.clone() {
if let Some(smiles) = resolve_iupac_to_smiles(iupac) {
comp.identifier.smiles = Some(smiles);
}
}
}
}
}
}
impl Default for ClassificationSession {
fn default() -> Self {
Self::new()
}
}
/// Guesses which kind of identifier `input` is and wraps it accordingly.
///
/// After trimming, the checks run in this order and the first match wins:
/// CAS number, InChIKey, InChI (`InChI=` prefix), SMILES, and finally an
/// IUPAC name as the fallback.
///
/// The SMILES test is a heuristic: no spaces, and at least one of
/// `( ) = # [ ] + -`.
// NOTE(review): hyphenated names written without spaces (e.g. `2-propanol`)
// also satisfy the SMILES heuristic, while punctuation-free SMILES such as
// `CCO` fall through to the IUPAC branch.
fn parse_identifier(input: &str) -> SubstanceIdentifier {
    const SMILES_MARKERS: [char; 8] = ['(', ')', '=', '#', '[', ']', '+', '-'];

    let candidate = input.trim();
    if is_cas_format(candidate) {
        SubstanceIdentifier::from_cas(candidate)
    } else if is_inchi_key_format(candidate) {
        SubstanceIdentifier {
            inchi_key: Some(candidate.to_string()),
            ..Default::default()
        }
    } else if candidate.starts_with("InChI=") {
        SubstanceIdentifier {
            inchi: Some(candidate.to_string()),
            ..Default::default()
        }
    } else if !candidate.contains(' ') && candidate.contains(&SMILES_MARKERS[..]) {
        SubstanceIdentifier::from_smiles(candidate)
    } else {
        SubstanceIdentifier::from_iupac_name(candidate)
    }
}
/// Returns `true` when `s` has the shape of a CAS registry number:
/// exactly three hyphen-separated groups of ASCII digits, the first at least
/// two digits long, the second exactly two and the third exactly one.
///
/// Only the shape is checked; the check digit is not verified.
fn is_cas_format(s: &str) -> bool {
    let all_digits = |group: &str| group.bytes().all(|b| b.is_ascii_digit());

    let mut groups = s.split('-');
    match (groups.next(), groups.next(), groups.next(), groups.next()) {
        (Some(head), Some(middle), Some(check), None) => {
            head.len() >= 2
                && all_digits(head)
                && middle.len() == 2
                && all_digits(middle)
                && check.len() == 1
                && all_digits(check)
        }
        _ => false,
    }
}
/// Returns `true` when `s` has the shape of an InChIKey: three
/// hyphen-separated blocks of 14, 10 and 1 characters, consisting solely of
/// ASCII uppercase letters.
fn is_inchi_key_format(s: &str) -> bool {
    let mut blocks = s.split('-').map(str::len);
    let shape_matches = matches!(
        (blocks.next(), blocks.next(), blocks.next(), blocks.next()),
        (Some(14), Some(10), Some(1), None)
    );
    shape_matches && s.bytes().all(|b| b.is_ascii_uppercase() || b == b'-')
}
/// Looks up `iupac_name` with `chem_name_resolver` and returns the SMILES of
/// the result. Any resolver error is discarded and reported as `None`.
fn resolve_iupac_to_smiles(iupac_name: &str) -> Option<String> {
    match chem_name_resolver::resolve(iupac_name) {
        Ok(resolution) => Some(resolution.smiles),
        Err(_) => None,
    }
}
/// Short label for the kind of `q`, used as the `expected` field of
/// `AnswerTypeMismatch` errors.
fn question_kind_name(q: &Question) -> &'static str {
    let label = match *q {
        Question::YesNo { .. } => "yes_no",
        Question::Number { .. } => "number",
        Question::Text { .. } => "text",
        Question::Choice { .. } => "choice",
        Question::MultiChoice { .. } => "multi_choice",
    };
    label
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{IntendedUse, OrganicInorganic, PhysicalForm};

    /// Unwraps the follow-up question, panicking on any other result.
    fn next_q(result: SessionResult) -> Question {
        match result {
            SessionResult::NeedMoreInfo { next_question } => next_question,
            other => panic!("expected NeedMoreInfo, got {:?}", std::mem::discriminant(&other)),
        }
    }

    /// English session that has already issued its first question.
    fn started() -> ClassificationSession {
        let mut session = ClassificationSession::new();
        session.start();
        session
    }

    /// Japanese session that has already issued its first question.
    fn started_ja() -> ClassificationSession {
        let mut session = ClassificationSession::new_ja();
        session.start();
        session
    }

    /// Shorthand for a free-text answer.
    fn text(value: &str) -> Answer {
        Answer::Text(value.to_string())
    }

    #[test]
    fn session_starts_with_identifier_question() {
        let mut session = ClassificationSession::new();
        let first = session.start();
        assert!(matches!(first, Question::Text { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::Identifier));
    }

    #[test]
    fn session_pure_cas_inorganic_full_flow() {
        let mut session = started();

        let r = session.answer(text("1310-73-2")).unwrap();
        assert!(matches!(next_q(r), Question::YesNo { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::IsMixture));

        let r = session.answer(Answer::YesNo(false)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::PhysicalForm));

        let r = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::IntendedUse));

        let r = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::OrganicInorganic));

        let r = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(product.identifier.cas.as_deref(), Some("1310-73-2"));
        assert!(matches!(product.physical_form, Some(PhysicalForm::Solid)));
        assert_eq!(product.intended_use, Some(IntendedUse::Industrial));
        assert_eq!(session.question_count(), 5);
        assert!(session.is_complete());
    }

    #[test]
    fn session_smiles_input_skips_organic_inorganic_question() {
        let mut session = started();

        let r = session.answer(text("[Na+].[OH-]")).unwrap();
        assert!(matches!(next_q(r), Question::YesNo { .. }));

        let r = session.answer(Answer::YesNo(false)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));

        let r = session.answer(Answer::Choice(3)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));

        let r = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(r, SessionResult::Ready));
        assert_eq!(session.question_count(), 4);

        let product = session.to_product_description();
        assert!(product.identifier.smiles.is_some());
    }

    #[test]
    fn session_organic_cas_asks_functional_groups() {
        let mut session = started();
        session.answer(text("108-88-3")).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();

        let r = session.answer(Answer::Choice(0)).unwrap();
        let q = next_q(r);
        assert!(matches!(q, Question::MultiChoice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::FunctionalGroups));

        let r = session.answer(Answer::MultiChoice(vec![10])).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let state = session.state();
        assert_eq!(state.organic_inorganic, Some(OrganicInorganic::Organic));
        assert!(state.detected_functional_groups.contains(&"aromatic".to_string()));
    }

    #[test]
    fn session_solution_asks_concentration() {
        let mut session = started();
        session.answer(text("7647-01-0")).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();

        let r = session.answer(Answer::Choice(4)).unwrap();
        let q = next_q(r);
        assert!(matches!(q, Question::Number { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::SolutionConcentration));

        let r = session.answer(Answer::Number(35.0)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));

        session.answer(Answer::Choice(0)).unwrap();
        let r = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(
            product.physical_form,
            Some(PhysicalForm::Solution {
                solvent: None,
                concentration_pct_ww: Some(35.0),
            })
        );
    }

    #[test]
    fn session_mixture_two_components() {
        let mut session = started();
        session.answer(text("7664-93-9")).unwrap();

        session.answer(Answer::YesNo(true)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentCount));

        session.answer(Answer::Number(2.0)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentIdentifier));

        session.answer(text("7664-93-9")).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentFraction));

        session.answer(Answer::Number(70.0)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentIdentifier));

        session.answer(text("7732-18-5")).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentFraction));

        let r = session.answer(Answer::Number(30.0)).unwrap();
        assert!(matches!(r, SessionResult::Ready | SessionResult::RequiresLlm));

        let product = session.to_product_description();
        let comps = product.mixture_components.unwrap();
        assert_eq!(comps.len(), 2);
        assert_eq!(comps[0].substance.cas.as_deref(), Some("7664-93-9"));
        assert_eq!(comps[0].weight_fraction_pct, Some(70.0));
        assert_eq!(comps[1].substance.cas.as_deref(), Some("7732-18-5"));
        assert_eq!(comps[1].weight_fraction_pct, Some(30.0));
    }

    #[test]
    fn error_no_active_question_before_start() {
        let mut session = ClassificationSession::new();
        let err = session.answer(text("1310-73-2")).unwrap_err();
        assert!(matches!(err, HsPredictError::NoActiveQuestion));
    }

    #[test]
    fn error_answer_type_mismatch() {
        let mut session = started();
        let err = session.answer(Answer::YesNo(true)).unwrap_err();
        assert!(matches!(err, HsPredictError::AnswerTypeMismatch { .. }));
    }

    #[test]
    fn error_choice_index_out_of_range() {
        let mut session = started();
        session.answer(text("1310-73-2")).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();
        let err = session.answer(Answer::Choice(99)).unwrap_err();
        assert!(matches!(err, HsPredictError::InvalidChoiceIndex { .. }));
    }

    #[test]
    fn error_number_out_of_range() {
        let mut session = started();
        session.answer(text("1310-73-2")).unwrap();
        session.answer(Answer::YesNo(true)).unwrap();
        let err = session.answer(Answer::Number(1.0)).unwrap_err();
        assert!(matches!(err, HsPredictError::NumberOutOfRange { .. }));
    }

    #[test]
    fn japanese_session_prompts_are_in_japanese() {
        let mut session = ClassificationSession::new_ja();
        let q = session.start();
        // Any non-ASCII character in the prompt counts as evidence.
        assert!(q.prompt().chars().any(|c| c as u32 > 0x7F));
    }

    #[test]
    fn japanese_session_completes_same_as_english() {
        let mut session = started_ja();
        session.answer(text("1310-73-2")).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        let r = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(product.identifier.cas.as_deref(), Some("1310-73-2"));
    }

    #[test]
    fn session_serializes_and_deserializes() {
        let mut session = started();
        session.answer(text("1310-73-2")).unwrap();

        let json = serde_json::to_string(&session).unwrap();
        let restored: ClassificationSession = serde_json::from_str(&json).unwrap();

        assert_eq!(
            restored.state().identifier.cas.as_deref(),
            Some("1310-73-2")
        );
        assert_eq!(restored.language(), Language::En);
        assert_eq!(restored.current_step(), Some(QuestionStep::IsMixture));
    }
}