dset 0.1.12

A Rust library for processing and managing dataset-related files, with a focus on machine learning datasets, captions, and safetensors files
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
#![warn(clippy::all, clippy::pedantic)]

//! A library for processing and managing dataset-related files and metadata.
//!
//! This library provides functionality for:
//! - Processing safetensors files and extracting metadata
//! - Handling caption files
//! - Processing and formatting JSON files
//! - Converting between different file formats (JSON to caption)
//! - Concatenating files with different extensions
//! 
//! The library is organized into several modules:
//! - `caption`: Handles caption file processing
//! - `metadata`: Manages metadata extraction and processing
//! - `st`: SafeTensors-related functionality
//! - `reasoning`: Manages AI reasoning datasets for training
//! - `concat`: Utilities for concatenating files with different extensions

pub mod caption;
pub mod metadata;
pub mod reasoning;
pub mod st;
pub mod concat;

use log::info;
pub use xio;

// Private imports: error handling, JSON values, and filesystem/path types used below
use anyhow::{Context, Result};
use serde_json::Value;
use std::{
    io,
    path::{Path, PathBuf},
    sync::Arc,
};
use tokio::fs;

// Include test modules
// Unit-test module tree; the `cfg(test)` gate keeps it out of non-test builds.
#[cfg(test)]
mod tests {
    // Both submodules are declared here and defined in separate files.
    // NOTE(review): judging by the names, these cover e621 JSON handling and
    // text/caption helpers respectively — confirm against the test sources.
    pub mod e621_tests;
    pub mod text_tests;
}

/// Reads the metadata header of a safetensors file and returns it as JSON.
///
/// The file is memory-mapped, its metadata is read with
/// `SafeTensors::read_metadata`, serialized into a [`Value`], and then handed to
/// `metadata::extract_training_metadata`, whose result is returned.
///
/// Note that this function is synchronous: it uses `std::fs::File` and a memory
/// map rather than async I/O.
///
/// # Arguments
/// * `path` - Path to the safetensors file
///
/// # Returns
/// * `Result<Value>` - The processed JSON metadata if successful
///
/// # Errors
/// Returns an error if:
/// * The file cannot be opened
/// * Memory mapping fails
/// * Metadata cannot be read from the safetensors file
/// * Metadata cannot be converted to JSON
fn get_json_metadata(path: &Path) -> Result<Value> {
    use ::safetensors::SafeTensors;
    use memmap2::MmapOptions;
    use std::fs::File;

    let source = File::open(path).context("Failed to open file")?;

    // SAFETY: the mapping is only read and is dropped before this function
    // returns. Soundness relies on the underlying file not being truncated or
    // modified while it is mapped.
    // NOTE(review): nothing here guards against another process changing the
    // file concurrently — confirm this is acceptable for callers.
    let mapped = unsafe { MmapOptions::new().map(&source) }.context("Failed to mmap file")?;

    // Only the metadata is needed; the header size is discarded.
    let (_, raw_metadata) =
        SafeTensors::read_metadata(&mapped).context("Failed to read metadata")?;

    // Turn the typed metadata into a generic JSON value first ...
    let raw_json: Value =
        serde_json::to_value(&raw_metadata).context("Failed to convert metadata to JSON value")?;

    // ... then let the metadata module post-process it.
    Ok(crate::metadata::extract_training_metadata(&raw_json))
}

/// Extracts the metadata of a safetensors file and writes it next to the file as JSON.
///
/// Steps performed:
/// 1. Extract the metadata via `get_json_metadata`
/// 2. Pretty-print it and emit it through `log::info!`
/// 3. Write the pretty-printed text to the same path with its extension replaced by `json`
///
/// # Arguments
/// * `path` - Path to the safetensors file to process
///
/// # Returns
/// * `Result<()>` - Success or failure of the operation
///
/// # Errors
/// Returns an error if:
/// * Metadata extraction fails
/// * JSON formatting fails
/// * Writing the output file fails
pub async fn process_safetensors_file(path: &Path) -> Result<()> {
    let metadata = get_json_metadata(path)?;
    let pretty_json = serde_json::to_string_pretty(&metadata)?;
    info!("{pretty_json}");

    // `model.safetensors` -> `model.json`, in the same directory.
    let output_path = path.with_extension("json");
    fs::write(output_path, pretty_json).await?;
    Ok(())
}

/// Processes a caption file by delegating to [`caption::process_file`].
///
/// This function adds no logic of its own: the path is forwarded unchanged and
/// the callee's result is returned as-is.
///
/// # Arguments
/// * `path` - Path to the caption file to process
///
/// # Returns
/// * `Result<()>` - Whatever `caption::process_file` returns
///
/// # Errors
/// Returns an error if `caption::process_file` fails
pub async fn process_caption_file(path: &Path) -> Result<()> {
    caption::process_file(path).await
}

/// Processes a JSON file using a provided async processor function.
///
/// This function reads a JSON file, parses it, and applies a custom processor
/// function to the parsed data. The processor function can perform any desired
/// transformations or operations on the JSON data.
///
/// # Type Parameters
/// * `F` - The processor function type
/// * `Fut` - The future type returned by the processor function
///
/// # Arguments
/// * `file_path` - Path to the JSON file to process
/// * `processor` - Async function that processes the parsed JSON data
///
/// # Returns
/// * `io::Result<()>` - Success or failure of the operation
///
/// # Errors
/// Returns an error if:
/// * The file cannot be read
/// * The content cannot be parsed as JSON
/// * The processor function returns an error
#[must_use = "Processes a JSON file and requires handling of the result to ensure proper file processing"]
pub async fn process_json_file<F, Fut>(file_path: &Path, processor: F) -> io::Result<()>
where
    F: FnOnce(Value) -> Fut + Send,
    Fut: std::future::Future<Output = io::Result<()>> + Send,
{
    let content = fs::read_to_string(file_path).await?;
    let data: Value = serde_json::from_str(&content)?;
    processor(data).await
}

/// Formats a JSON file in place by pretty-printing its contents.
///
/// The file is read fully into memory, parsed, re-serialized with
/// `serde_json::to_string_pretty`, and written back to the same path,
/// overwriting the original. Progress is reported through `log::info!`.
///
/// # Arguments
/// * `path` - Path to the JSON file to format
///
/// # Returns
/// * `Result<()>` - Success or failure of the operation
///
/// # Errors
/// Returns an error if:
/// * The file cannot be read
/// * The content cannot be parsed as JSON
/// * The formatted JSON cannot be written back to the file
#[must_use = "Formats a JSON file and requires handling of the result to ensure the file is properly formatted"]
pub async fn format_json_file(path: PathBuf) -> Result<()> {
    info!("Processing file: {}", path.display());

    // The filesystem helpers accept `impl AsRef<Path>`, so borrowing the owned
    // path is enough; no clone is needed for either call.
    let file_content = fs::read_to_string(&path)
        .await
        .context("Failed to read file content")?;
    let json: Value = serde_json::from_str(&file_content).context("Failed to parse JSON")?;
    let pretty_json = serde_json::to_string_pretty(&json).context("Failed to format JSON")?;
    fs::write(&path, pretty_json)
        .await
        .context("Failed to write formatted JSON")?;

    info!("Formatted {} successfully.", path.display());
    Ok(())
}

/// Splits a content string into tags and sentence text.
///
/// The input is expected to look like `"tag1, tag2, tag3., Sentence text"`.
/// Everything before the **first** `"., "` is treated as a comma-separated tag
/// list; everything after it is returned as the sentence text. Any later
/// occurrence of `"., "` is kept as part of the sentence text.
///
/// Each tag and the sentence text are trimmed of surrounding whitespace.
/// If the separator is absent, the whole input is treated as the tag list and
/// the sentence text is empty. An empty input therefore yields a single empty tag.
///
/// # Arguments
/// * `content` - The string to split, expected to be in the format "tags., sentence"
///
/// # Returns
/// * `(Vec<String>, String)` - A tuple containing:
///   * A vector of tag strings
///   * The remaining sentence text
///
/// # Examples
/// ```
/// use dset::split_content;
///
/// let content = "tag1, tag2, tag3., This is a sentence.";
/// let (tags, sentence) = split_content(content);
/// assert_eq!(tags, vec!["tag1", "tag2", "tag3"]);
/// assert_eq!(sentence, "This is a sentence.");
/// ```
#[must_use = "Splits content into tags and sentences and the result should be checked"]
pub fn split_content(content: &str) -> (Vec<String>, String) {
    // Split only at the first separator so sentence text that itself contains
    // "., " is preserved in full.
    let (tag_part, sentence_part) = content.split_once("., ").unwrap_or((content, ""));

    let tags = tag_part
        .split(',')
        .map(|tag| tag.trim().to_string())
        .collect();

    (tags, sentence_part.trim().to_string())
}

/// Converts a JSON file containing tag probabilities into a caption file.
///
/// This function reads a JSON file containing tag-probability pairs, filters
/// tags based on a probability threshold (0.2), and writes the selected tags
/// to a new .txt file. Tags are sorted by probability in descending order.
///
/// # Arguments
/// * `input_path` - Path to the input JSON file
///
/// # Returns
/// * `io::Result<()>` - Success or failure of the operation
///
/// # Errors
/// Returns an error if:
/// * The input file cannot be read
/// * The content cannot be parsed as JSON
/// * The output file cannot be written
///
/// # Format
/// Input JSON should be in the format:
/// ```json
/// {
///     "tag1": 0.9,
///     "tag2": 0.5,
///     "tag3": 0.1
/// }
/// ```
#[must_use = "Processes a JSON file to create a caption file and requires handling of the result to ensure proper conversion"]
pub async fn process_json_to_caption(input_path: &Path) -> io::Result<()> {
    // Early return if not a JSON file
    if input_path.extension().and_then(|s| s.to_str()) != Some("json") {
        return Ok(());
    }

    let content = fs::read_to_string(input_path).await?;
    let json: Value = serde_json::from_str(&content)?;
    info!("Processing JSON: {}", json);

    let mut tags = Vec::new();
    if let Value::Object(map) = json {
        for (tag, prob) in map {
            if let Value::Number(prob) = prob {
                if let Some(prob) = prob.as_f64() {
                    if prob >= 0.2 {
                        tags.push((tag, prob));
                    }
                }
            }
        }
    }

    tags.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    let tags: Vec<_> = tags
        .into_iter()
        .map(|(tag, _)| {
            // Escape special characters with backslashes
            tag.replace('(', "\\(").replace(')', "\\)")
        })
        .collect();

    let output = tags.join(", ");
    fs::write(input_path.with_extension("txt"), output).await?;
    Ok(())
}

/// Renames a file to remove any image extensions that appear between the base filename and the actual extension.
///
/// Only the components strictly between the first and the last dot-separated
/// part are examined. Components equal to `jpg`, `jpeg` or `png` (compared
/// case-insensitively) are removed; every other component is kept.
///
/// For example:
/// - `image.jpg.toml` -> `image.toml`
/// - `image.jpeg.json` -> `image.json`
/// - `image.png` -> `image.png` (unchanged)
/// - `image.png.jpg` -> `image.jpg`
/// - `data.v2.png.json` -> `data.v2.json`
///
/// This is useful for cleaning up file names in datasets where image extensions might have been
/// accidentally preserved when converting files to other formats.
///
/// If the name contains no image extension in a middle position, nothing is
/// renamed and `Ok(())` is returned.
///
/// # Arguments
/// * `path` - Path to the file to rename
///
/// # Returns
/// * `io::Result<()>` - Success or failure of the operation
///
/// # Errors
/// Returns an error if:
/// * The path has no file name, or the file name is not valid UTF-8
/// * The file system rename operation fails
///
/// # Example
/// ```no_run
/// use std::path::Path;
/// use dset::rename_file_without_image_extension;
///
/// async fn example() -> std::io::Result<()> {
///     let path = Path::new("image.jpg.toml");
///     rename_file_without_image_extension(&path).await?;  // Will rename to "image.toml"
///     Ok(())
/// }
/// ```
#[must_use = "Renames a file and requires handling of the result to ensure the file is properly renamed"]
pub async fn rename_file_without_image_extension(path: &Path) -> io::Result<()> {
    let file_name = path
        .file_name()
        .and_then(|n| n.to_str())
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Invalid file name"))?;

    if let Some(new_name) = strip_image_extensions(file_name) {
        // Keep the renamed file in the same directory as the original.
        let parent = path.parent().unwrap_or_else(|| Path::new(""));
        let new_path = parent.join(new_name);

        fs::rename(path, &new_path).await?;
        info!("Renamed {} to {}", path.display(), new_path.display());
    }
    Ok(())
}

/// Returns `file_name` with image extensions removed from its middle components,
/// or `None` when no middle component is an image extension.
fn strip_image_extensions(file_name: &str) -> Option<String> {
    fn is_image_extension(part: &str) -> bool {
        ["jpg", "jpeg", "png"]
            .iter()
            .any(|ext| part.eq_ignore_ascii_case(ext))
    }

    let parts: Vec<&str> = file_name.split('.').collect();

    // At least `name.middle.ext` is required for there to be a middle component.
    if parts.len() < 3 {
        return None;
    }
    let (stem, rest) = parts.split_first()?;
    let (last_ext, middle) = rest.split_last()?;

    if !middle.iter().copied().any(is_image_extension) {
        return None;
    }

    // Rebuild the name, dropping only the image extensions so that other
    // dotted components (e.g. a version suffix) survive the rename.
    let mut new_name = String::with_capacity(file_name.len());
    new_name.push_str(stem);
    for part in middle.iter().copied().filter(|part| !is_image_extension(part)) {
        new_name.push('.');
        new_name.push_str(part);
    }
    new_name.push('.');
    new_name.push_str(last_ext);
    Some(new_name)
}

/// Reads an e621 JSON file and forwards its parsed contents to the caption module.
///
/// The file is read into memory and parsed into a [`Value`]. The parsed value,
/// the file path (wrapped in an [`Arc`]), and `config` are then passed to
/// `caption::process_e621_json_data`, which performs the actual processing.
///
/// # Arguments
///
/// * `file_path` - Path to the JSON file to process
/// * `config` - Optional configuration, passed through unchanged to
///   `caption::process_e621_json_data`. If None, uses default settings.
///
/// # Errors
///
/// This function will return an error if:
/// * The file cannot be read
/// * The file contains invalid JSON
/// * `caption::process_e621_json_data` returns an error
///
/// # Returns
///
/// Returns `Ok(())` on success, or an error if any step fails.
pub async fn process_e621_json_file(
    file_path: &Path,
    config: Option<caption::E621Config>,
) -> Result<()> {
    let raw = fs::read_to_string(file_path).await?;
    let parsed = serde_json::from_str::<Value>(&raw)?;

    // The caption module takes the path behind an `Arc`.
    let shared_path = Arc::new(file_path.to_path_buf());
    caption::process_e621_json_data(&parsed, &shared_path, config).await
}

pub use caption::{
    caption_file_exists_and_not_empty, format_text_content, json_to_text, process_file,
    replace_special_chars, replace_string,
};