// lcpfs 2026.1.102
//
// LCP File System - A ZFS-inspired copy-on-write filesystem for Rust
// Copyright 2025 LunaOS Contributors
// SPDX-License-Identifier: Apache-2.0

//! Dictionary training algorithms.
//!
//! This module implements dictionary training by finding common substrings
//! across sample data and selecting the most valuable ones for compression.

use alloc::collections::BTreeMap;
use alloc::string::ToString;
use alloc::vec::Vec;

use super::types::{
    CompressionDict, DEFAULT_DICT_SIZE, DictError, DictResult, MAX_DICT_SIZE, MIN_DICT_SIZE,
    MIN_MATCH_LEN, SubstringEntry, TrainingOptions,
};

// ═══════════════════════════════════════════════════════════════════════════════
// DICTIONARY TRAINING
// ═══════════════════════════════════════════════════════════════════════════════

/// Train a compression dictionary from sample data.
///
/// Finds common substrings across samples and selects the most valuable
/// ones (by frequency * length) to include in the dictionary.
///
/// # Arguments
/// * `samples` - Sample data to train from (at least two samples required)
/// * `options` - Training options (size budget, substring length bounds,
///   minimum occurrence count)
///
/// # Returns
/// Raw dictionary data, at most `options.dict_size` bytes (clamped to
/// the `MIN_DICT_SIZE..=MAX_DICT_SIZE` range).
///
/// # Errors
/// Returns [`DictError::InsufficientSamples`] when fewer than two samples
/// are provided.
pub fn train_dictionary(samples: &[&[u8]], options: &TrainingOptions) -> DictResult<Vec<u8>> {
    use alloc::collections::BTreeSet;
    use alloc::collections::btree_map::Entry;

    // Validate inputs: a single sample gives no cross-sample commonality.
    if samples.len() < 2 {
        return Err(DictError::InsufficientSamples(samples.len()));
    }

    let dict_size = options.dict_size.clamp(MIN_DICT_SIZE, MAX_DICT_SIZE);
    let min_len = options.min_substring_len.max(MIN_MATCH_LEN);
    // Cap substring length at 256 to bound memory use and matching cost.
    let max_len = options.max_substring_len.min(256);

    // Count, for each substring, how many samples it occurs in.
    let mut substring_counts: BTreeMap<Vec<u8>, SubstringEntry> = BTreeMap::new();

    for sample in samples {
        // Track substrings already seen in this sample so each sample
        // contributes at most 1 to a substring's count.
        let mut seen_in_sample: BTreeSet<Vec<u8>> = BTreeSet::new();

        // Extract substrings of every length in [min_len, max_len].
        for len in min_len..=max_len.min(sample.len()) {
            for start in 0..=sample.len().saturating_sub(len) {
                let substr = sample[start..start + len].to_vec();

                // `insert` returns false when the value was already present,
                // so one set lookup both detects and records duplicates.
                if !seen_in_sample.insert(substr.clone()) {
                    continue;
                }

                // Single map lookup via the entry API; the key is cloned
                // only when a new entry has to be created.
                match substring_counts.entry(substr) {
                    Entry::Occupied(mut e) => e.get_mut().increment(),
                    Entry::Vacant(e) => {
                        let data = e.key().clone();
                        e.insert(SubstringEntry::new(data));
                    }
                }
            }
        }
    }

    // Keep substrings that occur often enough, paired with their score.
    let mut scored: Vec<(u64, Vec<u8>)> = substring_counts
        .into_iter()
        .filter(|(_, entry)| entry.count >= options.min_occurrences)
        .map(|(data, entry)| (entry.score(), data))
        .collect();

    // Sort by score (descending) so the most valuable substrings win.
    scored.sort_by(|a, b| b.0.cmp(&a.0));

    // Greedily pack the highest-scoring substrings into the dictionary.
    let mut dict_data = Vec::with_capacity(dict_size);
    let mut included: BTreeSet<Vec<u8>> = BTreeSet::new();

    for (_, substr) in scored {
        // Skip substrings related by containment in either direction:
        // a contained substring adds no new match coverage.
        let dominated = included
            .iter()
            .any(|inc| contains_substring(&substr, inc) || contains_substring(inc, &substr));

        if dominated {
            continue;
        }

        // Respect the dictionary size budget.
        if dict_data.len() + substr.len() > dict_size {
            // Stop entirely once not even a minimum-length match could fit;
            // otherwise keep scanning for a smaller substring that fits.
            if dict_data.len() + MIN_MATCH_LEN > dict_size {
                break;
            }
            continue;
        }

        included.insert(substr.clone());
        dict_data.extend_from_slice(&substr);
    }

    Ok(dict_data)
}

/// Train a dictionary from `samples` using [`TrainingOptions::default`].
pub fn train_dictionary_default(samples: &[&[u8]]) -> DictResult<Vec<u8>> {
    let options = TrainingOptions::default();
    train_dictionary(samples, &options)
}

/// Train a dictionary and create a CompressionDict.
///
/// Trains the raw dictionary data, then estimates the average
/// compressed/original size ratio over the non-empty samples and records
/// it as training metadata on the returned dictionary.
pub fn train(
    id: u64,
    name: &str,
    pattern: &str,
    dataset: &str,
    samples: &[&[u8]],
    options: &TrainingOptions,
    timestamp: u64,
) -> DictResult<CompressionDict> {
    let data = train_dictionary(samples, options)?;

    // Sum the per-sample ratios and count the samples that contributed;
    // empty samples are skipped (they have no meaningful ratio).
    let (total_ratio, sample_count) = samples
        .iter()
        .filter(|sample| !sample.is_empty())
        .fold((0.0f32, 0u32), |(ratio_sum, count), sample| {
            let estimate = estimate_compressed_size(&data, sample);
            (ratio_sum + estimate as f32 / sample.len() as f32, count + 1)
        });

    // An all-empty sample set gets a neutral ratio of 1.0.
    let avg_ratio = if sample_count == 0 {
        1.0
    } else {
        total_ratio / sample_count as f32
    };

    let dict = CompressionDict::new(id, name, data, pattern, dataset, timestamp);
    Ok(dict.with_training_info(sample_count, avg_ratio))
}

// ═══════════════════════════════════════════════════════════════════════════════
// HELPER FUNCTIONS
// ═══════════════════════════════════════════════════════════════════════════════

/// Check if `haystack` contains `needle` as a contiguous byte sequence.
///
/// Returns `false` for an empty needle or a needle longer than the
/// haystack. Worst case O(|haystack| * |needle|).
fn contains_substring(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows` panics on a zero-length window, so guard the empty case
    // explicitly; a needle longer than the haystack yields no windows.
    !needle.is_empty() && haystack.windows(needle.len()).any(|window| window == needle)
}

/// Estimate compressed size of `data` when encoded against `dict`.
///
/// Models the output as 5-byte dictionary references and literal runs
/// costing a 3-byte header plus the raw bytes. Used only to rank
/// dictionary quality; it does not produce actual compressed output.
fn estimate_compressed_size(dict: &[u8], data: &[u8]) -> usize {
    if data.is_empty() {
        return 0;
    }

    let mut pos = 0;
    let mut compressed = 0;

    while pos < data.len() {
        // Find the longest dictionary match at the current position;
        // the match offset is irrelevant for size estimation.
        let (_, match_len) = find_longest_match(dict, &data[pos..]);

        if match_len >= MIN_MATCH_LEN {
            // Dictionary reference: 5 bytes regardless of match length.
            compressed += 5;
            pos += match_len;
        } else {
            // Literal run: the current byte is already known not to start
            // a match, so advance past it before probing the next bytes.
            let literal_start = pos;
            pos += 1;
            while pos < data.len() {
                let (_, len) = find_longest_match(dict, &data[pos..]);
                if len >= MIN_MATCH_LEN {
                    break;
                }
                pos += 1;
            }
            // Literal run: 3-byte header plus the raw bytes.
            compressed += 3 + (pos - literal_start);
        }
    }

    compressed
}

/// Find the longest match for the start of `data` within `dict`.
///
/// Returns `(offset, length)` of the best match, or `(0, 0)` when the
/// dictionary is empty or no candidate reaches `MIN_MATCH_LEN`. Match
/// length is capped at 256 bytes. Ties keep the earliest offset.
fn find_longest_match(dict: &[u8], data: &[u8]) -> (usize, usize) {
    if dict.is_empty() || data.len() < MIN_MATCH_LEN {
        return (0, 0);
    }

    // Never match more than 256 bytes (or past the end of `data`).
    let max_check = data.len().min(256);

    let mut best = (0usize, 0usize);

    for offset in 0..dict.len() {
        // Count matching bytes, bounded by the dictionary tail, the data
        // length (via zip), and the 256-byte cap.
        let matched = dict[offset..]
            .iter()
            .take(max_check)
            .zip(data.iter())
            .take_while(|(d, s)| d == s)
            .count();

        if matched >= MIN_MATCH_LEN && matched > best.1 {
            best = (offset, matched);
        }
    }

    best
}

// ═══════════════════════════════════════════════════════════════════════════════
// TESTS
// ═══════════════════════════════════════════════════════════════════════════════

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    // Fewer than two samples must be rejected with the sample count.
    #[test]
    fn test_train_insufficient_samples() {
        let samples: &[&[u8]] = &[b"only one"];
        let result = train_dictionary(samples, &TrainingOptions::default());
        assert!(matches!(result, Err(DictError::InsufficientSamples(1))));
    }

    // Samples sharing words ("hello", "world") should yield a non-empty dict.
    #[test]
    fn test_train_basic() {
        let samples: &[&[u8]] = &[
            b"hello world hello",
            b"hello there world",
            b"world hello world",
        ];

        let dict = train_dictionary_default(samples).unwrap();

        // Dictionary should contain common substrings
        assert!(!dict.is_empty());
    }

    // Custom options: the size budget must be honored as an upper bound.
    #[test]
    fn test_train_with_options() {
        let samples: &[&[u8]] = &[
            b"the quick brown fox jumps over",
            b"the lazy dog jumps over the",
            b"quick brown fox the lazy dog",
        ];

        let options = TrainingOptions::default()
            .with_size(1024)
            .min_len(4)
            .min_count(2);

        let dict = train_dictionary(samples, &options).unwrap();

        // Should fit within size limit
        assert!(dict.len() <= 1024);
    }

    // Highly repetitive samples exercise the substring de-duplication path.
    #[test]
    fn test_train_repeated_patterns() {
        // Create samples with very repeated patterns
        let pattern = b"ABCD".repeat(100);
        let samples: &[&[u8]] = &[&pattern, &pattern, &pattern];

        let dict = train_dictionary_default(samples).unwrap();

        // Should include the repeated pattern
        assert!(!dict.is_empty());
    }

    // Covers prefix, suffix, interior, absent, too-long, and exact-equal cases.
    #[test]
    fn test_contains_substring() {
        assert!(contains_substring(b"hello world", b"world"));
        assert!(contains_substring(b"hello world", b"hello"));
        assert!(contains_substring(b"hello world", b"lo wo"));
        assert!(!contains_substring(b"hello", b"world"));
        assert!(!contains_substring(b"hi", b"hello"));
        assert!(contains_substring(b"aaa", b"aaa"));
    }

    #[test]
    fn test_find_longest_match() {
        let dict = b"hello world";

        // "hello " (with space) matches at offset 0, 6 chars
        let (offset, len) = find_longest_match(dict, b"hello there");
        assert_eq!(offset, 0);
        assert_eq!(len, 6); // "hello " matches

        // "world" at offset 6, 5 chars
        let (offset2, len2) = find_longest_match(dict, b"world hello");
        assert_eq!(offset2, 6);
        assert_eq!(len2, 5);

        let (_, len3) = find_longest_match(dict, b"xyz");
        assert_eq!(len3, 0); // No match
    }

    // Only sanity-checks that the estimator produces a positive size;
    // exact values depend on the reference/literal cost model.
    #[test]
    fn test_estimate_compressed_size() {
        let dict = b"hello world";
        let data = b"hello world hello";

        let estimated = estimate_compressed_size(dict, data);

        // Should be smaller than original if dict is useful
        // But with overhead it might not be much smaller
        assert!(estimated > 0);
    }

    // End-to-end: train() must propagate id/name/pattern and record
    // training metadata (sample_count) on the resulting CompressionDict.
    #[test]
    fn test_train_full() {
        let samples: &[&[u8]] = &[
            b"json data with common patterns",
            b"json data with repeated fields",
        ];

        let dict = train(
            42,
            "test_dict",
            "*.json",
            "pool/data",
            samples,
            &TrainingOptions::default().with_size(1024),
            12345,
        )
        .unwrap();

        assert_eq!(dict.id, 42);
        assert_eq!(dict.name, "test_dict");
        assert_eq!(dict.pattern, "*.json");
        assert!(dict.sample_count > 0);
    }
}