1#![warn(clippy::all, clippy::pedantic)]
2pub mod core;
3pub mod caption;
4pub mod metadata;
5pub mod reasoning;
6pub mod st;
7pub mod concat;
8mod hf;
9use log::info;
10use anyhow::{Context, Result};
12use serde_json::Value;
13use std::{
14 io, path::{Path, PathBuf},
15 sync::Arc,
16};
17use memmap2;
// Test-only module tree: compiled only when `cfg(test)` is set. It declares
// the `e621_tests` and `text_tests` submodules, whose bodies live in their
// own source files.
#[cfg(test)]
mod tests {
    pub mod e621_tests;
    pub mod text_tests;
}
23fn get_json_metadata(path: &Path) -> Result<Value> {
24 use safetensors::SafeTensors;
25 use memmap2::MmapOptions;
26 use std::fs::File;
27 let file = File::open(path).context("Failed to open file")?;
28 let mmap = unsafe { MmapOptions::new().map(&file).context("Failed to mmap file")? };
29 let (_header_size, metadata) = SafeTensors::read_metadata(&mmap)
30 .context("Failed to read metadata")?;
31 let metadata_json: Value = serde_json::to_value(&metadata)
32 .context("Failed to convert metadata to JSON value")?;
33 let training_metadata = metadata::extract_training_metadata(&metadata_json);
34 Ok(training_metadata)
35}
36pub async fn process_safetensors_file(path: &Path) -> Result<()> {
37 let json = get_json_metadata(path)?;
38 let pretty_json = serde_json::to_string_pretty(&json)?;
39 info!("{pretty_json}");
40 tokio::fs::write(path.with_extension("json"), pretty_json).await?;
41 Ok(())
42}
/// Processes a single caption file.
///
/// This is a thin wrapper that forwards `path` to `caption::process_file`
/// and returns its result unchanged.
///
/// # Errors
///
/// Returns whatever error `caption::process_file` reports for `path`.
pub async fn process_caption_file(path: &Path) -> Result<()> {
    caption::process_file(path).await
}
46#[must_use = "Processes a JSON file and requires handling of the result to ensure proper file processing"]
47pub async fn process_json_file<F, Fut>(file_path: &Path, processor: F) -> io::Result<()>
48where
49 F: FnOnce(Value) -> Fut + Send,
50 Fut: std::future::Future<Output = io::Result<()>> + Send,
51{
52 let content = tokio::fs::read_to_string(file_path).await?;
53 let data: Value = serde_json::from_str(&content)?;
54 processor(data).await
55}
56#[must_use = "Formats a JSON file and requires handling of the result to ensure the file is properly formatted"]
57pub async fn format_json_file(path: PathBuf) -> Result<()> {
58 info!("Processing file: {}", path.display());
59 let file_content = tokio::fs::read_to_string(path.clone())
60 .await
61 .context("Failed to read file content")?;
62 let json: Value = serde_json::from_str(&file_content)
63 .context("Failed to parse JSON")?;
64 let pretty_json = serde_json::to_string_pretty(&json)
65 .context("Failed to format JSON")?;
66 tokio::fs::write(path.clone(), pretty_json)
67 .await
68 .context("Failed to write formatted JSON")?;
69 info!("Formatted {} successfully.", path.display());
70 Ok(())
71}
/// Splits caption text into a list of tags and a trailing block of sentences.
///
/// The text before the first `"., "` separator is treated as a
/// comma-separated tag list; every tag is trimmed of surrounding whitespace.
/// Everything after that first separator is returned, trimmed, as the
/// sentence text. When the separator is absent, the whole input is treated as
/// the tag list and the sentence text is empty.
///
/// Only the first separator splits the input, so sentence text that itself
/// contains `"., "` (for instance after an abbreviation such as "etc.") is
/// returned in full instead of being cut off at the second separator.
///
/// Empty tag segments are kept as empty strings; in particular an empty input
/// yields a single empty tag.
#[must_use = "Splits content into tags and sentences and the result should be checked"]
pub fn split_content(content: &str) -> (Vec<String>, String) {
    // `split_once` keeps the remainder intact, unlike `split(..).nth(1)`.
    let (tag_part, sentence_part) = content.split_once("., ").unwrap_or((content, ""));

    let tags = tag_part
        .split(',')
        .map(|tag| tag.trim().to_owned())
        .collect();

    (tags, sentence_part.trim().to_owned())
}
79#[must_use = "Processes a JSON file to create a caption file and requires handling of the result to ensure proper conversion"]
80pub async fn process_json_to_caption(input_path: &Path) -> io::Result<()> {
81 if input_path.extension().and_then(|s| s.to_str()) != Some("json") {
82 return Ok(());
83 }
84 let content = tokio::fs::read_to_string(input_path).await?;
85 let json: Value = serde_json::from_str(&content)?;
86 info!("Processing JSON: {}", json);
87 let mut tags = Vec::new();
88 if let Value::Object(map) = json {
89 for (tag, prob) in map {
90 if let Value::Number(prob) = prob {
91 if let Some(prob) = prob.as_f64() {
92 if prob >= 0.2 {
93 tags.push((tag, prob));
94 }
95 }
96 }
97 }
98 }
99 tags.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
100 let tags: Vec<_> = tags
101 .into_iter()
102 .map(|(tag, _)| { tag.replace('(', "\\(").replace(')', "\\)") })
103 .collect();
104 let output = tags.join(", ");
105 tokio::fs::write(input_path.with_extension("txt"), output).await?;
106 Ok(())
107}
108#[must_use = "Renames a file and requires handling of the result to ensure the file is properly renamed"]
109pub async fn rename_file_without_image_extension(path: &Path) -> io::Result<()> {
110 let file_name = path
111 .file_name()
112 .and_then(|n| n.to_str())
113 .ok_or_else(|| io::Error::new(
114 io::ErrorKind::InvalidInput,
115 "Invalid file name",
116 ))?;
117 let parts: Vec<&str> = file_name.split('.').collect();
118 if parts.len() >= 3 {
119 let mut has_image_ext = false;
120 for ext in &parts[1..parts.len() - 1] {
121 if matches!(ext.to_lowercase().as_str(), "jpg" | "jpeg" | "png") {
122 has_image_ext = true;
123 break;
124 }
125 }
126 if has_image_ext {
127 let mut new_name = String::from(parts[0]);
128 let last_ext = parts.last().unwrap();
129 new_name.push('.');
130 new_name.push_str(last_ext);
131 let parent = path.parent().unwrap_or_else(|| Path::new(""));
132 let new_path = parent.join(new_name);
133 tokio::fs::rename(path, &new_path).await?;
134 info!("Renamed {} to {}", path.display(), new_path.display());
135 }
136 }
137 Ok(())
138}
139pub async fn process_e621_json_file(
140 file_path: &Path,
141 config: Option<caption::E621Config>,
142) -> Result<()> {
143 let content = tokio::fs::read_to_string(file_path).await?;
144 let data_owned: Value = serde_json::from_str(&content)?;
145 let file_path = Arc::new(file_path.to_path_buf());
146 caption::process_e621_json_data(&data_owned, &file_path, config).await
147}
148pub use caption::{
149 caption_file_exists_and_not_empty, format_text_content, json_to_text, process_file,
150 replace_special_chars, replace_string,
151};
152
153pub use core::{
154 TrainingFormat, TrainingSample, TrainingDataset, DatasetStats,
155 BCOSample, BCODataset, DPOSample, DPODataset, PPOSample, PPODataset,
156 SFTSample, SFTDataset, DatasetQualityReport,
157 GenericJSONDataset,
158};
159
160pub use hf::{HfProcessor, HfDatasetConfig, HuggingFaceDataset};