helix_core/json/
mod.rs

1#![warn(clippy::all, clippy::pedantic)]
2pub mod core;
3pub mod caption;
4pub mod metadata;
5pub mod reasoning;
6pub mod st;
7pub mod concat;
8mod hf;
9use log::info;
10// pub use xio; // xio module doesn't exist
11use anyhow::{Context, Result};
12use serde_json::Value;
13use std::{
14    io, path::{Path, PathBuf},
15    sync::Arc,
16};
17use memmap2;
18#[cfg(test)]
19mod tests {
20    pub mod e621_tests;
21    pub mod text_tests;
22}
23fn get_json_metadata(path: &Path) -> Result<Value> {
24    use safetensors::SafeTensors;
25    use memmap2::MmapOptions;
26    use std::fs::File;
27    let file = File::open(path).context("Failed to open file")?;
28    let mmap = unsafe { MmapOptions::new().map(&file).context("Failed to mmap file")? };
29    let (_header_size, metadata) = SafeTensors::read_metadata(&mmap)
30        .context("Failed to read metadata")?;
31    let metadata_json: Value = serde_json::to_value(&metadata)
32        .context("Failed to convert metadata to JSON value")?;
33    let training_metadata = metadata::extract_training_metadata(&metadata_json);
34    Ok(training_metadata)
35}
36pub async fn process_safetensors_file(path: &Path) -> Result<()> {
37    let json = get_json_metadata(path)?;
38    let pretty_json = serde_json::to_string_pretty(&json)?;
39    info!("{pretty_json}");
40    tokio::fs::write(path.with_extension("json"), pretty_json).await?;
41    Ok(())
42}
43pub async fn process_caption_file(path: &Path) -> Result<()> {
44    caption::process_file(path).await
45}
46#[must_use = "Processes a JSON file and requires handling of the result to ensure proper file processing"]
47pub async fn process_json_file<F, Fut>(file_path: &Path, processor: F) -> io::Result<()>
48where
49    F: FnOnce(Value) -> Fut + Send,
50    Fut: std::future::Future<Output = io::Result<()>> + Send,
51{
52    let content = tokio::fs::read_to_string(file_path).await?;
53    let data: Value = serde_json::from_str(&content)?;
54    processor(data).await
55}
56#[must_use = "Formats a JSON file and requires handling of the result to ensure the file is properly formatted"]
57pub async fn format_json_file(path: PathBuf) -> Result<()> {
58    info!("Processing file: {}", path.display());
59    let file_content = tokio::fs::read_to_string(path.clone())
60        .await
61        .context("Failed to read file content")?;
62    let json: Value = serde_json::from_str(&file_content)
63        .context("Failed to parse JSON")?;
64    let pretty_json = serde_json::to_string_pretty(&json)
65        .context("Failed to format JSON")?;
66    tokio::fs::write(path.clone(), pretty_json)
67        .await
68        .context("Failed to write formatted JSON")?;
69    info!("Formatted {} successfully.", path.display());
70    Ok(())
71}
/// Splits `content` into a list of tags and a trailing sentence string.
///
/// The text before the first `"., "` is treated as a comma-separated tag list;
/// each tag is trimmed. Everything after that first `"., "` is returned,
/// trimmed, as the sentence text. Only the first occurrence of the delimiter
/// is significant, so sentence text that itself contains `"., "` (for example
/// `"etc., "`) is preserved in full rather than being cut off.
///
/// When the delimiter is absent, the whole input is treated as tags and the
/// sentence string is empty. An empty input yields a single empty tag.
#[must_use = "Splits content into tags and sentences and the result should be checked"]
pub fn split_content(content: &str) -> (Vec<String>, String) {
    let (tag_part, sentence_part) = content.split_once("., ").unwrap_or((content, ""));
    let tags = tag_part
        .split(',')
        .map(|tag| tag.trim().to_owned())
        .collect();
    (tags, sentence_part.trim().to_owned())
}
79#[must_use = "Processes a JSON file to create a caption file and requires handling of the result to ensure proper conversion"]
80pub async fn process_json_to_caption(input_path: &Path) -> io::Result<()> {
81    if input_path.extension().and_then(|s| s.to_str()) != Some("json") {
82        return Ok(());
83    }
84    let content = tokio::fs::read_to_string(input_path).await?;
85    let json: Value = serde_json::from_str(&content)?;
86    info!("Processing JSON: {}", json);
87    let mut tags = Vec::new();
88    if let Value::Object(map) = json {
89        for (tag, prob) in map {
90            if let Value::Number(prob) = prob {
91                if let Some(prob) = prob.as_f64() {
92                    if prob >= 0.2 {
93                        tags.push((tag, prob));
94                    }
95                }
96            }
97        }
98    }
99    tags.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
100    let tags: Vec<_> = tags
101        .into_iter()
102        .map(|(tag, _)| { tag.replace('(', "\\(").replace(')', "\\)") })
103        .collect();
104    let output = tags.join(", ");
105    tokio::fs::write(input_path.with_extension("txt"), output).await?;
106    Ok(())
107}
108#[must_use = "Renames a file and requires handling of the result to ensure the file is properly renamed"]
109pub async fn rename_file_without_image_extension(path: &Path) -> io::Result<()> {
110    let file_name = path
111        .file_name()
112        .and_then(|n| n.to_str())
113        .ok_or_else(|| io::Error::new(
114            io::ErrorKind::InvalidInput,
115            "Invalid file name",
116        ))?;
117    let parts: Vec<&str> = file_name.split('.').collect();
118    if parts.len() >= 3 {
119        let mut has_image_ext = false;
120        for ext in &parts[1..parts.len() - 1] {
121            if matches!(ext.to_lowercase().as_str(), "jpg" | "jpeg" | "png") {
122                has_image_ext = true;
123                break;
124            }
125        }
126        if has_image_ext {
127            let mut new_name = String::from(parts[0]);
128            let last_ext = parts.last().unwrap();
129            new_name.push('.');
130            new_name.push_str(last_ext);
131            let parent = path.parent().unwrap_or_else(|| Path::new(""));
132            let new_path = parent.join(new_name);
133            tokio::fs::rename(path, &new_path).await?;
134            info!("Renamed {} to {}", path.display(), new_path.display());
135        }
136    }
137    Ok(())
138}
139pub async fn process_e621_json_file(
140    file_path: &Path,
141    config: Option<caption::E621Config>,
142) -> Result<()> {
143    let content = tokio::fs::read_to_string(file_path).await?;
144    let data_owned: Value = serde_json::from_str(&content)?;
145    let file_path = Arc::new(file_path.to_path_buf());
146    caption::process_e621_json_data(&data_owned, &file_path, config).await
147}
148pub use caption::{
149    caption_file_exists_and_not_empty, format_text_content, json_to_text, process_file,
150    replace_special_chars, replace_string,
151};
152
153pub use core::{
154    TrainingFormat, TrainingSample, TrainingDataset, DatasetStats,
155    BCOSample, BCODataset, DPOSample, DPODataset, PPOSample, PPODataset,
156    SFTSample, SFTDataset, DatasetQualityReport,
157    GenericJSONDataset,
158};
159
160pub use hf::{HfProcessor, HfDatasetConfig, HuggingFaceDataset};