use crate::model::{Atom, Message};
use crate::writer::write;
use chrono::{DateTime, Utc};
use rand::{RngExt, SeedableRng};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
/// Settings for a corpus generation run.
///
/// See the `Default` impl for the default values of each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusConfig {
    /// Random seed for the run (default `42`).
    pub seed: u64,
    /// Requested item count (default `100`) — presumably the number of messages to
    /// generate; usage is not visible here.
    pub count: usize,
    /// Batch size (default `50`) — NOTE(review): consumer not visible here; confirm semantics.
    pub batch_size: usize,
    /// Optional output directory (default `None`).
    pub output_dir: Option<String>,
    /// Whether train/validation/test splits should be produced (default `false`).
    pub create_splits: bool,
    /// `(train, validation, test)` ratios (default `Some((0.7, 0.15, 0.15))`); presumably
    /// passed to `CorpusManifest::create_splits` — confirm at the call site.
    pub split_ratios: Option<(f64, f64, f64)>,
}
impl Default for CorpusConfig {
    /// Builds the baseline configuration: seed `42`, count `100`, batch size `50`, no output
    /// directory, and splitting disabled (with `0.7 / 0.15 / 0.15` ratios preset).
    fn default() -> Self {
        let split_ratios = Some((0.7, 0.15, 0.15));
        Self {
            split_ratios,
            create_splits: false,
            output_dir: None,
            batch_size: 50,
            count: 100,
            seed: 42,
        }
    }
}
/// A template file recorded in a [`CorpusManifest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateInfo {
    /// Path of the template, as supplied to `CorpusManifest::add_template`.
    pub path: String,
    /// Lowercase hex SHA-256 digest of the template content.
    pub sha256: String,
}
/// A profile file recorded in a [`CorpusManifest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfileInfo {
    /// Path of the profile, as supplied to `CorpusManifest::add_profile`.
    pub path: String,
    /// Lowercase hex SHA-256 digest of the profile content.
    pub sha256: String,
}
/// A generated message recorded in a [`CorpusManifest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInfo {
    /// Path of the message, as supplied to `CorpusManifest::add_message`.
    pub path: String,
    /// Lowercase hex SHA-256 digest of the message content string.
    pub sha256: String,
    /// Message type label — presumably the value from `extract_message_type`
    /// (e.g. `ADT^A01`); verify against the caller.
    pub message_type: String,
    /// Index of the template used — presumably into `CorpusManifest::templates`; confirm.
    pub template_index: usize,
}
/// Message paths assigned to each dataset split (all empty by default).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CorpusSplits {
    /// Paths of messages in the training split.
    pub train: Vec<String>,
    /// Paths of messages in the validation split.
    pub validation: Vec<String>,
    /// Paths of messages in the test split.
    pub test: Vec<String>,
}
/// Manifest describing a generated corpus.
///
/// Records the inputs (templates, profiles) and outputs (messages), each with a SHA-256
/// digest, plus the split assignment. Serialized to and from JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusManifest {
    /// Manifest format version; `CorpusManifest::new` sets `"1.0.0"`.
    pub version: String,
    /// Version of the generating tool; `new` fills this from `CARGO_PKG_VERSION`.
    pub tool_version: String,
    /// Seed of the corpus; also seeds the shuffle used by `create_splits`.
    pub seed: u64,
    /// Template files and their content digests.
    pub templates: Vec<TemplateInfo>,
    /// Profile files and their content digests. Deserializes as empty when the key is
    /// absent — presumably so manifests written before this field existed still load.
    #[serde(default)]
    pub profiles: Vec<ProfileInfo>,
    /// Generated messages, in the order they were added.
    pub messages: Vec<MessageInfo>,
    /// Creation timestamp in UTC; `new` stamps the current time.
    pub generated_at: DateTime<Utc>,
    /// Train/validation/test assignment of message paths. Deserializes as empty when the
    /// key is absent.
    #[serde(default)]
    pub splits: CorpusSplits,
}
impl CorpusManifest {
pub fn new(seed: u64) -> Self {
Self {
version: "1.0.0".to_string(),
tool_version: env!("CARGO_PKG_VERSION").to_string(),
seed,
templates: Vec::new(),
profiles: Vec::new(),
messages: Vec::new(),
generated_at: Utc::now(),
splits: CorpusSplits::default(),
}
}
pub fn add_template(&mut self, path: &str, content: &str) {
let sha256 = compute_sha256(content);
self.templates.push(TemplateInfo {
path: path.to_string(),
sha256,
});
}
pub fn add_profile(&mut self, path: &str, content: &str) {
let sha256 = compute_sha256(content);
self.profiles.push(ProfileInfo {
path: path.to_string(),
sha256,
});
}
pub fn add_message(
&mut self,
path: &str,
content: &str,
message_type: &str,
template_index: usize,
) {
let sha256 = compute_sha256(content);
self.messages.push(MessageInfo {
path: path.to_string(),
sha256,
message_type: message_type.to_string(),
template_index,
});
}
pub fn to_json(&self) -> Result<String, CorpusError> {
serde_json::to_string_pretty(self)
.map_err(|e| CorpusError::SerializationError(e.to_string()))
}
pub fn from_json(json: &str) -> Result<Self, CorpusError> {
serde_json::from_str(json).map_err(|e| CorpusError::SerializationError(e.to_string()))
}
pub fn message_count(&self) -> usize {
self.messages.len()
}
pub fn message_type_counts(&self) -> HashMap<String, usize> {
let mut counts = HashMap::new();
for msg in &self.messages {
let count = counts.entry(msg.message_type.clone()).or_insert(0usize);
*count = count.saturating_add(1);
}
counts
}
pub fn create_splits(&mut self, ratios: (f64, f64, f64)) {
let total = self.messages.len();
if total == 0 {
return;
}
let train_count = rounded_ratio_count(total, ratios.0);
let remaining_after_train = total.saturating_sub(train_count);
let val_count = rounded_ratio_count(total, ratios.1).min(remaining_after_train);
let validation_end = train_count.saturating_add(val_count);
let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
let mut indices: Vec<usize> = (0..total).collect();
for i in (1..total).rev() {
let j = rng.random_range(0..=i);
indices.swap(i, j);
}
self.splits.train = indices
.get(..train_count)
.unwrap_or_default()
.iter()
.filter_map(|&i| self.messages.get(i).map(|message| message.path.clone()))
.collect();
self.splits.validation = indices
.get(train_count..validation_end)
.unwrap_or_default()
.iter()
.filter_map(|&i| self.messages.get(i).map(|message| message.path.clone()))
.collect();
self.splits.test = indices
.get(validation_end..)
.unwrap_or_default()
.iter()
.filter_map(|&i| self.messages.get(i).map(|message| message.path.clone()))
.collect();
}
}
/// Converts a split `ratio` into an element count out of `total`.
///
/// `total * ratio` is rounded to the nearest integer (halves round away from zero, as
/// with [`f64::round`]) and clamped to `0..=total`. NaN, infinite, zero or negative
/// ratios yield `0`.
#[expect(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::cast_sign_loss,
    reason = "split ratios are configured as f64 percentages by the public API"
)]
fn rounded_ratio_count(total: usize, ratio: f64) -> usize {
    let usable = ratio.is_finite() && ratio > 0.0;
    if !usable {
        return 0;
    }
    let upper = total as f64;
    match (upper * ratio).round() {
        scaled if scaled <= 0.0 => 0,
        scaled if scaled >= upper => total,
        scaled => scaled as usize,
    }
}
/// Returns the SHA-256 digest of the UTF-8 bytes of `content` as a lowercase hex string.
pub fn compute_sha256(content: &str) -> String {
    let digest = Sha256::digest(content.as_bytes());
    format!("{digest:x}")
}
pub fn compute_message_hash(message: &Message) -> String {
let message_bytes = write(message);
let message_string = String::from_utf8_lossy(&message_bytes);
compute_sha256(&message_string)
}
/// Errors for corpus generation and manifest handling.
///
/// Underlying errors are stored as their message `String`s rather than the source error
/// values — presumably so the enum can derive `Clone`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CorpusError {
    /// JSON serialization or deserialization failed; holds the underlying message.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// An I/O operation failed; holds a description of the failure.
    #[error("IO error: {0}")]
    IoError(String),
    /// A configuration value was rejected; holds a description of the problem.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// Split ratios were rejected; the message states they must sum to 1.0.
    #[error("Invalid split ratios: must sum to 1.0")]
    InvalidSplitRatios,
}
/// Returns the message type from the first `MSH` segment that carries one.
///
/// For each `MSH` segment, reads field index 7 and takes its first repetition. Field
/// index 7 is presumably MSH-9 "Message Type", assuming the parser does not store MSH-1
/// as a field — confirm against the model. The text of each component is joined with
/// `^` (e.g. `ADT^A01`). Components whose first sub-component is not [`Atom::Text`] are
/// skipped.
///
/// Returns `"UNKNOWN"` when no `MSH` segment has a non-empty textual message type.
pub fn extract_message_type(message: &Message) -> String {
    message
        .segments
        .iter()
        .filter(|segment| &segment.id == b"MSH")
        .find_map(|segment| {
            let rep = segment.fields.get(7)?.reps.first()?;
            let parts: Vec<String> = rep
                .comps
                .iter()
                .filter_map(|c| match c.subs.first() {
                    Some(Atom::Text(t)) => Some(t.clone()),
                    _ => None,
                })
                .collect();
            // A field with no text (no components, only non-text atoms, or only empty
            // strings) would otherwise produce "" instead of the "UNKNOWN" sentinel.
            if parts.iter().all(String::is_empty) {
                None
            } else {
                Some(parts.join("^"))
            }
        })
        .unwrap_or_else(|| "UNKNOWN".to_string())
}