Skip to main content

// entrenar/config/train/loader/data.rs

/// Turn a user-supplied model location into a local filesystem path.
///
/// When `is_hf_repo_id` recognises `model_path` as a HuggingFace repo ID
/// (e.g., "Qwen/Qwen2.5-Coder-0.5B"), the model is fetched through
/// `HfModelFetcher` and the path of the downloaded artifact is returned.
/// Any other value is handed back unchanged.
///
/// # Errors
///
/// Returns `Error::ConfigError` if the fetcher cannot be constructed or the
/// download fails.
#[cfg(feature = "hub-publish")]
fn resolve_model_path(model_path: &Path) -> Result<PathBuf> {
    use crate::config::schema::is_hf_repo_id;
    use crate::hf_pipeline::{FetchOptions, HfModelFetcher};

    let path_str = model_path.to_string_lossy();

    if is_hf_repo_id(&path_str) {
        println!("Downloading {path_str} from HuggingFace Hub...");

        let fetcher = HfModelFetcher::new()
            .map_err(|e| Error::ConfigError(format!("HF fetcher initialization: {e}")))?;
        let artifact = fetcher
            .download_model(&path_str, FetchOptions::new())
            .map_err(|e| Error::ConfigError(format!("Model download failed: {e}")))?;

        println!("  Cached at: {}", artifact.path.display());
        Ok(artifact.path)
    } else {
        // Plain local path: nothing to download.
        Ok(model_path.to_path_buf())
    }
}
27
28#[cfg(not(feature = "hub-publish"))]
29fn resolve_model_path(model_path: &Path) -> Result<PathBuf> {
30    use crate::config::schema::is_hf_repo_id;
31
32    let path_str = model_path.to_string_lossy();
33    if is_hf_repo_id(&path_str) {
34        return Err(Error::ConfigError(format!(
35            "HF model ID '{path_str}' requires the 'hub-publish' feature. \
36             Rebuild with: cargo install entrenar --features hub-publish"
37        )));
38    }
39    Ok(model_path.to_path_buf())
40}
41
42/// ALB-096: Load transformer model from APR or SafeTensors weights.
43///
44/// Tries APR first (sovereign format), then falls back to SafeTensors.
45/// Returns `(model, checkpoint_step)` where checkpoint_step is extracted
46/// from APR metadata or parsed from checkpoint filename.
47fn load_transformer_model(
48    model_path: &Path,
49    config: &TransformerConfig,
50    output_dir: &Path,
51) -> Result<(Option<Transformer>, usize)> {
52    // ALB-097: Check output_dir first for checkpoint resume (APR, then SafeTensors)
53    if output_dir.is_dir() {
54        // ENT-282: Check if checkpoint is a delta (NF4+QLoRA — no frozen base weights).
55        // If delta, load base model from model_path first, then overlay delta tensors.
56        if let Some(result) = try_load_apr_delta(output_dir, config, model_path) {
57            return Ok(result);
58        }
59        // Try full APR checkpoint
60        if let Some(result) = try_load_apr(output_dir, config) {
61            return Ok(result);
62        }
63        // Try SafeTensors from output_dir (backward compat with pre-APR checkpoints)
64        if let Some(result) = try_load_safetensors_dir(output_dir, config) {
65            return Ok(result);
66        }
67    }
68
69    if !model_path.exists() {
70        println!("  Model path not found, using random initialization");
71        return Ok((None, 0));
72    }
73
74    println!("Loading model weights from {}...", model_path.display());
75
76    // ALB-117: When loading from model_path (initial weights, NOT resume from output_dir),
77    // always return step=0. The checkpoint step from the source model is irrelevant —
78    // we're starting fresh training with pre-trained weights, not resuming a training run.
79    // Without this, loading model-step-14500.apr as initial weights would set step=14500,
80    // causing immediate exit when max_steps < 14500 (loss=0.0, no training executed).
81
82    // ALB-096: Try APR format from model_path (direct .apr file or HF download)
83    if let Some((model, _source_step)) = try_load_apr(model_path, config) {
84        return Ok((model, 0));
85    }
86
87    // Fallback: SafeTensors from model_path
88    if let Some((model, _source_step)) = try_load_safetensors_dir(model_path, config) {
89        return Ok((model, 0));
90    }
91
92    eprintln!("Warning: No loadable checkpoint found, using random initialization");
93    Ok((None, 0))
94}
95
96/// Try loading SafeTensors checkpoint from a directory.
97/// Returns `Some((model, step))` if successful, `None` to fall back.
98fn try_load_safetensors_dir(
99    dir: &Path,
100    config: &TransformerConfig,
101) -> Option<(Option<Transformer>, usize)> {
102    let checkpoint_step = detect_checkpoint_step(dir);
103
104    match load_safetensors_weights(dir, Architecture::Auto) {
105        Ok(weights) => {
106            println!("  Found {} weight tensors (SafeTensors)", weights.len());
107            if let Some(transformer) = Transformer::from_params(config, &weights) {
108                let embed = &transformer.embed_tokens.weight;
109                let embed_data = embed.data();
110                let embed_slice = embed_data.as_slice().unwrap_or(&[]);
111                let (emin, emax, emean) = if embed_slice.is_empty() {
112                    (0.0, 0.0, 0.0)
113                } else {
114                    let min = embed_slice.iter().copied().fold(f32::INFINITY, f32::min);
115                    let max = embed_slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
116                    let mean = embed_slice.iter().sum::<f32>() / embed_slice.len() as f32;
117                    (min, max, mean)
118                };
119                println!("✓ Loaded pre-trained weights successfully (SafeTensors)");
120                println!("  embed_tokens stats: min={emin:.4e} max={emax:.4e} mean={emean:.4e}");
121                return Some((Some(transformer), checkpoint_step));
122            }
123            None
124        }
125        Err(_) => None,
126    }
127}
128
129/// ALB-096: Try to load a model from APR format.
130///
131/// Looks for `.apr` files: direct file, `model-best.apr`, or latest `model-step-*.apr`.
132/// Returns `Some((model, step))` if successful, `None` to fall back to SafeTensors.
133///
134/// Public variant for use by `CudaTransformerTrainer::for_inference` (ALB-089).
135pub fn try_load_apr_for_inference(
136    model_path: &Path,
137    config: &TransformerConfig,
138) -> Option<(Option<Transformer>, usize)> {
139    try_load_apr(model_path, config)
140}
141
142/// ENT-282: Load a delta checkpoint (NF4+QLoRA lazy save).
143///
144/// Delta checkpoints only contain trainable/updated weights (norms, embed, lm_head, LoRA)
145/// and skip frozen NF4 base weights (~15 GB). On resume, base weights are loaded from the
146/// original model_path first, then delta tensors are overlaid.
147///
148/// Returns None if the checkpoint is not a delta (falls through to full-checkpoint load).
149fn try_load_apr_delta(
150    output_dir: &Path,
151    config: &TransformerConfig,
152    base_model_path: &Path,
153) -> Option<(Option<Transformer>, usize)> {
154    use aprender::serialization::apr::AprReader;
155
156    let apr_path = find_latest_apr_checkpoint(output_dir)?;
157    let reader = AprReader::open(&apr_path).ok()?;
158
159    // Only handle delta checkpoints
160    let format = reader.get_metadata("format").and_then(|v| v.as_str().map(String::from))?;
161    if format != "entrenar-delta-checkpoint" {
162        return None; // Not a delta — fall through to full load
163    }
164
165    let checkpoint_step = reader
166        .get_metadata("checkpoint_step")
167        .and_then(|v| v.as_str())
168        .and_then(|s| s.parse::<usize>().ok())
169        .unwrap_or(0);
170
171    println!(
172        "  Delta checkpoint at step {checkpoint_step} (loading base from {})",
173        base_model_path.display()
174    );
175
176    // Step 1: Load full base model from original pretrained weights
177    let (base_model, _) = try_load_apr(base_model_path, config)
178        .or_else(|| try_load_safetensors_dir(base_model_path, config))?;
179    let mut transformer = base_model?;
180
181    // Step 2: Overlay delta tensors (norms, embed, lm_head)
182    let mut overlaid = 0usize;
183    for desc in &reader.tensors {
184        let name = &desc.name;
185        if name.starts_with("__training__") || name.starts_with("lora.") {
186            continue; // Handled separately by restore_lora_from_apr / load_optimizer_state_apr
187        }
188        if let Ok(data) = reader.read_tensor_as_f32(name) {
189            let tensor = crate::Tensor::from_vec(data, false);
190            if transformer.set_named_parameter(name, tensor) {
191                overlaid += 1;
192            }
193        }
194    }
195    println!("  ✓ Delta: {overlaid} tensors overlaid on base model");
196
197    Some((Some(transformer), checkpoint_step))
198}
199
200fn try_load_apr(
201    model_path: &Path,
202    config: &TransformerConfig,
203) -> Option<(Option<Transformer>, usize)> {
204    use aprender::serialization::apr::AprReader;
205    use std::collections::HashMap;
206
207    // Determine which APR file to load
208    let apr_path =
209        if model_path.is_file() && model_path.extension().and_then(|e| e.to_str()) == Some("apr") {
210            model_path.to_path_buf()
211        } else if model_path.is_dir() {
212            find_latest_apr_checkpoint(model_path)?
213        } else {
214            return None;
215        };
216
217    let reader = match AprReader::open(&apr_path) {
218        Ok(r) => r,
219        Err(_) => return None,
220    };
221
222    // Extract checkpoint step from APR metadata
223    let checkpoint_step = reader
224        .get_metadata("checkpoint_step")
225        .and_then(|v| v.as_str())
226        .and_then(|s| s.parse::<usize>().ok())
227        .unwrap_or_else(|| {
228            apr_path
229                .file_name()
230                .and_then(|n| n.to_str())
231                .and_then(parse_checkpoint_step)
232                .unwrap_or(0)
233        });
234
235    // Load weight tensors (skip __training__.* namespace)
236    // Detect GGUF tensor names and map to HF convention (PMAT-489)
237    let mut weights = HashMap::new();
238    let is_gguf_names = reader.tensors.iter().any(|t| t.name == "token_embd.weight");
239    if is_gguf_names {
240        eprintln!("[PMAT-489] Detected GGUF tensor names in APR file, mapping to HF convention");
241    }
242    for desc in &reader.tensors {
243        let tensor_name = &desc.name;
244        if tensor_name.starts_with("__training__") {
245            continue;
246        }
247        match reader.read_tensor_as_f32(tensor_name) {
248            Ok(data) => {
249                let mapped_name = if is_gguf_names {
250                    use crate::transformer::weights::{Architecture, mapping::map_weight_name};
251                    map_weight_name(tensor_name, Architecture::Gguf)
252                } else {
253                    tensor_name.clone()
254                };
255                weights.insert(mapped_name, crate::Tensor::from_vec(data, false));
256            }
257            Err(e) => {
258                eprintln!("Warning: Failed to read APR tensor '{tensor_name}': {e}");
259                return None;
260            }
261        }
262    }
263
264    println!("  Found {} weight tensors (APR)", weights.len());
265
266    let transformer = Transformer::from_params(config, &weights)?;
267
268    let embed = &transformer.embed_tokens.weight;
269    let embed_data = embed.data();
270    let embed_slice = embed_data.as_slice().unwrap_or(&[]);
271    let (emin, emax, emean) = if embed_slice.is_empty() {
272        (0.0, 0.0, 0.0)
273    } else {
274        let min = embed_slice.iter().copied().fold(f32::INFINITY, f32::min);
275        let max = embed_slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
276        let mean = embed_slice.iter().sum::<f32>() / embed_slice.len() as f32;
277        (min, max, mean)
278    };
279    println!("✓ Loaded pre-trained weights successfully (APR)");
280    // ALB-117: Don't print "Resuming from step" here — the caller decides whether
281    // this is a genuine resume (output_dir) or fresh training (model_path).
282    // The caller at set_initial_step prints the resume message when appropriate.
283    println!("  embed_tokens stats: min={emin:.4e} max={emax:.4e} mean={emean:.4e}");
284
285    Some((Some(transformer), checkpoint_step))
286}
287
288/// Find the latest APR checkpoint in a directory.
289///
290/// Priority: latest `model-step-N.apr` by step number. Falls back to `model-best.apr`.
291fn find_latest_apr_checkpoint(dir: &Path) -> Option<std::path::PathBuf> {
292    let mut best_step = 0usize;
293    let mut best_path = None;
294
295    if let Ok(entries) = std::fs::read_dir(dir) {
296        for entry in entries.flatten() {
297            let path = entry.path();
298            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
299                if let Some(step) = parse_checkpoint_step(name) {
300                    if step >= best_step {
301                        best_step = step;
302                        best_path = Some(path);
303                    }
304                }
305            }
306        }
307    }
308
309    if best_path.is_some() {
310        return best_path;
311    }
312
313    // Fallback: model-best.apr, then model.apr (final checkpoint)
314    let best = dir.join("model-best.apr");
315    if best.exists() {
316        return Some(best);
317    }
318    let model = dir.join("model.apr");
319    if model.exists() {
320        return Some(model);
321    }
322
323    None
324}
325
326/// Detect the step number from the latest checkpoint in a directory.
327///
328/// Checks both APR (`.apr`) and legacy SafeTensors (`.safetensors`) checkpoint files.
329fn detect_checkpoint_step(model_path: &Path) -> usize {
330    use crate::transformer::weights::parse_checkpoint_step_from_path;
331
332    if model_path.is_file() {
333        if let Some(name) = model_path.file_name().and_then(|n| n.to_str()) {
334            if let Some(step) = parse_checkpoint_step(name) {
335                return step;
336            }
337        }
338        return parse_checkpoint_step_from_path(model_path).unwrap_or(0);
339    }
340    if !model_path.is_dir() {
341        return 0;
342    }
343    // Check for model-step-*.{apr,safetensors} files
344    let Ok(entries) = std::fs::read_dir(model_path) else { return 0 };
345    let mut max_step = 0usize;
346    for entry in entries.flatten() {
347        if let Some(name) = entry.file_name().to_str() {
348            if let Some(step) = parse_checkpoint_step(name) {
349                max_step = max_step.max(step);
350            }
351        }
352        if let Some(step) = parse_checkpoint_step_from_path(&entry.path()) {
353            max_step = max_step.max(step);
354        }
355    }
356    max_step
357}
358
359/// Apply architecture overrides to a `TransformerConfig`.
360///
361/// Only `Some` fields in the overrides replace the corresponding base config field.
362fn apply_architecture_overrides(
363    config: &mut TransformerConfig,
364    overrides: &crate::config::ArchitectureOverrides,
365) {
366    if let Some(v) = overrides.hidden_size {
367        config.hidden_size = v;
368    }
369    if let Some(v) = overrides.num_hidden_layers {
370        config.num_hidden_layers = v;
371    }
372    if let Some(v) = overrides.num_attention_heads {
373        config.num_attention_heads = v;
374    }
375    if let Some(v) = overrides.num_kv_heads {
376        config.num_kv_heads = v;
377    }
378    if let Some(v) = overrides.intermediate_size {
379        config.intermediate_size = v;
380    }
381    if let Some(v) = overrides.vocab_size {
382        config.vocab_size = v;
383    }
384    if let Some(v) = overrides.max_position_embeddings {
385        config.max_position_embeddings = v;
386    }
387    if let Some(v) = overrides.rms_norm_eps {
388        config.rms_norm_eps = v;
389    }
390    if let Some(v) = overrides.rope_theta {
391        config.rope_theta = v;
392    }
393    if let Some(v) = overrides.use_bias {
394        config.use_bias = v;
395    }
396    if let Some(v) = overrides.head_dim {
397        config.head_dim_override = Some(v);
398    }
399}
400
401/// Build TransformerConfig from TrainSpec
402///
403/// Uses config file if specified, otherwise defaults to a small model.
404/// Architecture overrides from the YAML manifest are applied on top.
405fn build_transformer_config_from_spec(spec: &TrainSpec) -> Result<TransformerConfig> {
406    // Check if config file is specified (explicit path or auto-detect from model dir)
407    let config_path_resolved = spec.model.config.clone().or_else(|| {
408        // Auto-detect config.json in model directory
409        let model_config = spec.model.path.join("config.json");
410        if model_config.exists() {
411            Some(model_config.to_string_lossy().into_owned())
412        } else {
413            None
414        }
415    });
416
417    let mut config = if let Some(config_path) = &config_path_resolved {
418        let config_file = std::path::Path::new(config_path);
419        if config_file.exists() {
420            let config_content = std::fs::read_to_string(config_file)
421                .map_err(|e| Error::ConfigError(format!("Failed to read model config: {e}")))?;
422
423            if let Ok(hf_config) = serde_json::from_str::<serde_json::Value>(&config_content) {
424                parse_hf_config(&hf_config)?
425            } else {
426                fallback_demo_config()
427            }
428        } else {
429            fallback_demo_config()
430        }
431    } else if let Some(ref overrides) = spec.model.architecture {
432        // No config file specified — try to build entirely from architecture overrides
433        if let Some(cfg) = config_from_overrides(overrides) {
434            cfg
435        } else {
436            fallback_demo_config()
437        }
438    } else {
439        fallback_demo_config()
440    };
441
442    // Apply architecture overrides from YAML manifest (handles partial overrides on top of base)
443    if let Some(ref overrides) = spec.model.architecture {
444        apply_architecture_overrides(&mut config, overrides);
445    }
446
447    Ok(config)
448}
449
450/// Build a TransformerConfig directly from architecture overrides if all required fields are present.
451/// Required: hidden_size, num_attention_heads, num_hidden_layers, vocab_size, intermediate_size.
452fn config_from_overrides(
453    overrides: &crate::config::ArchitectureOverrides,
454) -> Option<TransformerConfig> {
455    let hidden_size = overrides.hidden_size?;
456    let num_attention_heads = overrides.num_attention_heads?;
457    let num_hidden_layers = overrides.num_hidden_layers?;
458    let vocab_size = overrides.vocab_size?;
459    let intermediate_size = overrides.intermediate_size?;
460
461    Some(TransformerConfig {
462        hidden_size,
463        num_attention_heads,
464        num_kv_heads: overrides.num_kv_heads.unwrap_or(num_attention_heads),
465        intermediate_size,
466        num_hidden_layers,
467        vocab_size,
468        max_position_embeddings: overrides.max_position_embeddings.unwrap_or(2048),
469        rms_norm_eps: overrides.rms_norm_eps.unwrap_or(1e-5),
470        rope_theta: overrides.rope_theta.unwrap_or(10000.0),
471        use_bias: overrides.use_bias.unwrap_or(false),
472        head_dim_override: overrides.head_dim,
473        architecture: ModelArchitecture::Decoder,
474        hf_architecture: None,
475        hf_model_type: None,
476        tie_word_embeddings: false,
477    })
478}
479
480/// R-05 (Meyer DbC): Explicit Qwen2-0.5B demo config — NOT a generic default.
481/// This path is ONLY for testing without a model. Production callers must provide config.json.
482fn fallback_demo_config() -> TransformerConfig {
483    eprintln!("WARNING: No model config found — using Qwen2-0.5B demo config (NOT suitable for production training)");
484    TransformerConfig {
485        hidden_size: QWEN_HIDDEN_SIZE,
486        num_attention_heads: QWEN_NUM_ATTENTION_HEADS,
487        num_kv_heads: QWEN_NUM_KV_HEADS,
488        intermediate_size: QWEN_INTERMEDIATE_SIZE,
489        num_hidden_layers: QWEN_NUM_HIDDEN_LAYERS,
490        vocab_size: QWEN_VOCAB_SIZE,
491        max_position_embeddings: QWEN_MAX_POSITION_EMBEDDINGS,
492        rms_norm_eps: 1e-6,
493        rope_theta: QWEN_ROPE_THETA as f32,
494        use_bias: false,
495        head_dim_override: None,
496        architecture: ModelArchitecture::Decoder,
497        hf_architecture: None,
498        hf_model_type: None,
499        tie_word_embeddings: false,
500    }
501}
502
503/// Parse HuggingFace config.json into `TransformerConfig`.
504///
505/// C-10/C-11 (Meyer DbC): Required fields (hidden_size, num_attention_heads,
506/// num_hidden_layers, vocab_size, intermediate_size) must be present — no silent defaults.
507/// R-04: Optional fields use generic defaults with warnings for likely-wrong values.
508fn parse_hf_config(hf_config: &serde_json::Value) -> Result<TransformerConfig> {
509    let hidden_size = hf_config["hidden_size"].as_u64().ok_or_else(|| {
510        Error::ConfigError(
511            "C-11: config.json missing 'hidden_size' — cannot train without model dimensions"
512                .into(),
513        )
514    })? as usize;
515    let num_attention_heads = hf_config["num_attention_heads"].as_u64().ok_or_else(|| {
516        Error::ConfigError("C-11: config.json missing 'num_attention_heads'".into())
517    })? as usize;
518    let num_hidden_layers = hf_config["num_hidden_layers"]
519        .as_u64()
520        .ok_or_else(|| Error::ConfigError("C-11: config.json missing 'num_hidden_layers'".into()))?
521        as usize;
522    let vocab_size = hf_config["vocab_size"]
523        .as_u64()
524        .ok_or_else(|| Error::ConfigError(
525            "C-10: config.json missing 'vocab_size' — training with wrong vocab corrupts embeddings".into()
526        ))? as usize;
527    let intermediate_size = hf_config["intermediate_size"]
528        .as_u64()
529        .ok_or_else(|| Error::ConfigError("C-11: config.json missing 'intermediate_size'".into()))?
530        as usize;
531
532    // R-04 (Meyer DbC): Optional fields with generic defaults.
533    // num_kv_heads → num_attention_heads is the correct GQA→MHA fallback.
534    // max_position_embeddings → 2048 is a conservative safe minimum.
535    // rms_norm_eps → 1e-6 is the most common default.
536    // rope_theta → 10000 is the LLaMA/Mistral standard (WRONG for Qwen at 1M).
537    // use_bias → false is correct for most modern architectures.
538    let num_kv_heads =
539        hf_config["num_key_value_heads"].as_u64().unwrap_or(num_attention_heads as u64) as usize;
540
541    let max_position_embeddings = match hf_config["max_position_embeddings"].as_u64() {
542        Some(v) => v as usize,
543        None => {
544            eprintln!("Warning: config.json missing 'max_position_embeddings', defaulting to 2048");
545            2048
546        }
547    };
548
549    let rope_theta = match hf_config["rope_theta"].as_f64() {
550        Some(v) => v as f32,
551        None => {
552            eprintln!(
553                "Warning: config.json missing 'rope_theta', defaulting to 10000.0 \
554                (Qwen models use 1000000.0 — check your config)"
555            );
556            10_000.0
557        }
558    };
559
560    let rms_norm_eps = hf_config["rms_norm_eps"].as_f64().unwrap_or_else(|| {
561        eprintln!(
562            "Warning: config.json missing 'rms_norm_eps', defaulting to 1e-6 \
563            (some models use 1e-5 or 1e-12 — check your config)"
564        );
565        1e-6
566    }) as f32;
567    let use_bias = hf_config["attention_bias"].as_bool().unwrap_or(false);
568    let head_dim_override = hf_config["head_dim"].as_u64().map(|v| v as usize);
569
570    // Detect encoder architectures from HuggingFace model_type
571    let architecture = match hf_config["model_type"].as_str() {
572        Some("bert" | "roberta" | "distilbert" | "albert" | "electra" | "deberta") => {
573            ModelArchitecture::Encoder
574        }
575        _ => ModelArchitecture::Decoder,
576    };
577
578    // Preserve HuggingFace architecture metadata for checkpoint config.json (#259)
579    let hf_architecture = hf_config["architectures"]
580        .as_array()
581        .and_then(|a| a.first())
582        .and_then(|v| v.as_str())
583        .map(String::from);
584    let hf_model_type = hf_config["model_type"].as_str().map(String::from);
585    let tie_word_embeddings = hf_config["tie_word_embeddings"].as_bool().unwrap_or(false);
586
587    Ok(TransformerConfig {
588        hidden_size,
589        num_attention_heads,
590        num_kv_heads,
591        intermediate_size,
592        num_hidden_layers,
593        vocab_size,
594        max_position_embeddings,
595        rms_norm_eps,
596        rope_theta,
597        use_bias,
598        head_dim_override,
599        architecture,
600        hf_architecture,
601        hf_model_type,
602        tie_word_embeddings,
603    })
604}
605
606/// Load training data as LMBatches for transformer training
607///
608/// Supports:
609/// 1. Pre-tokenized JSON with `input_ids` arrays
610/// 2. Text JSON/JSONL with `text` or `content` fields (requires tokenizer)
611/// 3. Demo mode fallback for testing
612fn load_lm_batches(spec: &TrainSpec) -> Result<Vec<LMBatch>> {
613    let batch_size = spec.data.batch_size;
614    let seq_len = spec.data.seq_len.unwrap_or_else(|| {
615        eprintln!("Warning: seq_len not specified, defaulting to 512 for LM batch loading");
616        512
617    });
618    let tokenizer = load_tokenizer(spec)?;
619
620    if let Some(result) = try_load_lm_from_file(spec, tokenizer.as_ref(), batch_size, seq_len) {
621        return result;
622    }
623
624    eprintln!(
625        "Warning: Training data not found at '{}', using demo LM batches",
626        spec.data.train.display()
627    );
628    create_demo_lm_batches(batch_size, seq_len)
629}
630
631/// Attempt to load LM batches from the training data file or directory
632fn try_load_lm_from_file(
633    spec: &TrainSpec,
634    tokenizer: Option<&HfTokenizer>,
635    batch_size: usize,
636    seq_len: usize,
637) -> Option<Result<Vec<LMBatch>>> {
638    if !spec.data.train.exists() {
639        return None;
640    }
641
642    // Handle directory of Parquet shards (ALB-007)
643    if spec.data.train.is_dir() {
644        let tokenizer = tokenizer?;
645        return Some(load_lm_batches_from_parquet(
646            &spec.data.train,
647            tokenizer,
648            batch_size,
649            seq_len,
650            spec.data.input_column.as_deref().unwrap_or("text"),
651        ));
652    }
653
654    let ext = spec.data.train.extension()?;
655
656    if ext == "json" || ext == "jsonl" {
657        let content = std::fs::read_to_string(&spec.data.train).ok()?;
658        return Some(load_lm_batches_from_json(
659            &content,
660            tokenizer,
661            batch_size,
662            seq_len,
663            spec.data.input_column.as_deref(),
664        ));
665    }
666
667    if ext == "parquet" {
668        let tokenizer = tokenizer?;
669        return Some(load_lm_batches_from_parquet(
670            &spec.data.train,
671            tokenizer,
672            batch_size,
673            seq_len,
674            spec.data.input_column.as_deref().unwrap_or("text"),
675        ));
676    }
677
678    None
679}
680
681/// Load HfTokenizer from spec if tokenizer path is specified
682fn load_tokenizer(spec: &TrainSpec) -> Result<Option<HfTokenizer>> {
683    if let Some(ref tokenizer_path) = spec.data.tokenizer {
684        if tokenizer_path.exists() {
685            println!("  Loading tokenizer from: {}", tokenizer_path.display());
686            let tokenizer = HfTokenizer::from_file(tokenizer_path)
687                .map_err(|e| Error::ConfigError(format!("Failed to load tokenizer: {e}")))?;
688            println!("  Tokenizer vocab size: {}", tokenizer.vocab_size());
689            return Ok(Some(tokenizer));
690        }
691        eprintln!(
692            "Warning: Tokenizer not found at '{}', using default Qwen2 tokenizer",
693            tokenizer_path.display()
694        );
695    }
696
697    // No tokenizer specified - use default for transformer mode
698    println!("  Using default Qwen2 tokenizer");
699    Ok(Some(HfTokenizer::qwen2()))
700}
701
702/// Extract text strings from a JSON array using the given column names
703fn extract_texts_from_array(array: &[serde_json::Value], text_col: &str) -> Vec<String> {
704    array
705        .iter()
706        .filter_map(|e| {
707            e.get(text_col).or_else(|| e.get("content")).and_then(|v| v.as_str()).map(String::from)
708        })
709        .collect()
710}
711
712/// Try loading from a JSON array (either pre-tokenized or text)
713fn try_load_from_array(
714    array: &[serde_json::Value],
715    tokenizer: Option<&HfTokenizer>,
716    batch_size: usize,
717    seq_len: usize,
718    text_col: &str,
719    label: &str,
720) -> Option<Result<Vec<LMBatch>>> {
721    // Check for pre-tokenized
722    if array.first().and_then(|e| e.get("input_ids")).is_some() {
723        return Some(load_pretokenized_json(array, batch_size, seq_len));
724    }
725
726    // Extract text and tokenize
727    let tokenizer = tokenizer?;
728    let texts = extract_texts_from_array(array, text_col);
729    if texts.is_empty() {
730        return None;
731    }
732
733    println!("  Loaded {} text examples from {label}, tokenizing...", texts.len());
734    Some(tokenize_texts_to_batches(&texts, tokenizer, batch_size, seq_len))
735}
736
737/// Try loading from JSONL (newline-delimited JSON)
738fn try_load_from_jsonl(
739    content: &str,
740    tokenizer: Option<&HfTokenizer>,
741    batch_size: usize,
742    seq_len: usize,
743    text_col: &str,
744) -> Option<Result<Vec<LMBatch>>> {
745    let tokenizer = tokenizer?;
746    let texts: Vec<String> = content
747        .lines()
748        .filter(|l| !l.trim().is_empty())
749        .filter_map(|line| {
750            serde_json::from_str::<serde_json::Value>(line).ok().and_then(|obj| {
751                obj.get(text_col)
752                    .or_else(|| obj.get("content"))
753                    .and_then(|v| v.as_str())
754                    .map(String::from)
755            })
756        })
757        .collect();
758
759    if texts.is_empty() {
760        return None;
761    }
762
763    println!("  Loaded {} text examples from JSONL, tokenizing...", texts.len());
764    Some(tokenize_texts_to_batches(&texts, tokenizer, batch_size, seq_len))
765}
766
767/// Try to load LM batches from a parsed JSON value (object or array)
768fn try_load_from_json_value(
769    data: &serde_json::Value,
770    tokenizer: Option<&HfTokenizer>,
771    batch_size: usize,
772    seq_len: usize,
773    text_col: &str,
774) -> Option<Result<Vec<LMBatch>>> {
775    // Try {"examples": [...]} format
776    if let Some(examples) = data.get("examples").and_then(|e| e.as_array()) {
777        if let Some(result) =
778            try_load_from_array(examples, tokenizer, batch_size, seq_len, text_col, "JSON")
779        {
780            return Some(result);
781        }
782    }
783
784    // Try top-level array format
785    if let Some(array) = data.as_array() {
786        if let Some(result) =
787            try_load_from_array(array, tokenizer, batch_size, seq_len, text_col, "JSON array")
788        {
789            return Some(result);
790        }
791    }
792
793    None
794}
795
796/// Load LM batches from JSON content
797///
798/// Supports formats:
799/// - Pre-tokenized: `{"examples": [{"input_ids": [...]}]}`
800/// - Text data: `{"examples": [{"text": "..."}]}` or `[{"text": "..."}]`
801/// - JSONL: `{"text": "..."}\n{"text": "..."}`
802fn load_lm_batches_from_json(
803    content: &str,
804    tokenizer: Option<&HfTokenizer>,
805    batch_size: usize,
806    seq_len: usize,
807    input_column: Option<&str>,
808) -> Result<Vec<LMBatch>> {
809    let text_col = input_column.unwrap_or("text");
810
811    // Try parsing as single JSON object or array
812    if let Ok(data) = serde_json::from_str::<serde_json::Value>(content) {
813        if let Some(result) =
814            try_load_from_json_value(&data, tokenizer, batch_size, seq_len, text_col)
815        {
816            return result;
817        }
818    }
819
820    // Try JSONL format
821    if let Some(result) = try_load_from_jsonl(content, tokenizer, batch_size, seq_len, text_col) {
822        return result;
823    }
824
825    // Fallback to demo batches
826    eprintln!("Warning: Could not parse training data, using demo LM batches");
827    create_demo_lm_batches(batch_size, seq_len)
828}
829
830/// Load pre-tokenized sequences from JSON
831fn load_pretokenized_json(
832    examples: &[serde_json::Value],
833    batch_size: usize,
834    seq_len: usize,
835) -> Result<Vec<LMBatch>> {
836    let mut all_sequences: Vec<Vec<u32>> = Vec::new();
837
838    for example in examples {
839        if let Some(tokens) = example.get("input_ids").and_then(|t| t.as_array()) {
840            let seq: Vec<u32> =
841                tokens.iter().filter_map(|t| t.as_u64().map(|v| v as u32)).collect();
842            if !seq.is_empty() {
843                all_sequences.push(seq);
844            }
845        }
846    }
847
848    if !all_sequences.is_empty() {
849        println!("  Loaded {} pre-tokenized sequences from JSON", all_sequences.len());
850        return create_lm_batches_from_sequences(&all_sequences, batch_size, seq_len);
851    }
852
853    eprintln!("Warning: No valid sequences found in JSON");
854    create_demo_lm_batches(batch_size, seq_len)
855}
856
857/// Tokenize texts and create LM batches
858fn tokenize_texts_to_batches(
859    texts: &[String],
860    tokenizer: &HfTokenizer,
861    batch_size: usize,
862    seq_len: usize,
863) -> Result<Vec<LMBatch>> {
864    let sequences: Vec<Vec<u32>> = texts
865        .iter()
866        .map(|text| {
867            let mut tokens = tokenizer.encode_with_special(text);
868            tokens.truncate(seq_len);
869            tokens
870        })
871        .filter(|seq| seq.len() > 1) // Need at least 2 tokens for causal LM
872        .collect();
873
874    if sequences.is_empty() {
875        eprintln!("Warning: No valid sequences after tokenization");
876        return create_demo_lm_batches(batch_size, seq_len);
877    }
878
879    println!("  Tokenized {} sequences", sequences.len());
880    create_lm_batches_from_sequences(&sequences, batch_size, seq_len)
881}
882
/// Load LM batches from a Parquet file or shard directory (ALB-007).
///
/// Behaviour by input:
/// - **Directory**: delegated to `load_lm_batches_from_parquet_dir`.
/// - **File with a pre-tokenized column**: token ID lists are read directly
///   (see `try_extract_pretokenized`).
/// - **File with a text column**: strings are read from `text_column` (or a
///   fallback name, see `extract_text_column`) and tokenized with `tokenizer`.
///
/// Pre-tokenized data takes precedence; the text column is only consulted when
/// no pre-tokenized sequences were found.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn load_lm_batches_from_parquet(
    path: &std::path::Path,
    tokenizer: &HfTokenizer,
    batch_size: usize,
    seq_len: usize,
    text_column: &str,
) -> Result<Vec<LMBatch>> {
    use alimentar::{ArrowDataset, Dataset};

    // A directory means a set of parquet shards.
    if path.is_dir() {
        return load_lm_batches_from_parquet_dir(path, tokenizer, batch_size, seq_len, text_column);
    }

    println!("  Loading Parquet LM data: {}", path.display());

    // ALB-099: keep the ArrowDataset inside this inner scope so it is dropped
    // before any LMBatch is built, instead of holding Arrow data, the extracted
    // vectors and the batches all at once.
    let (pretokenized, raw_texts) = {
        let dataset = ArrowDataset::from_parquet(path).map_err(|e| {
            Error::ConfigError(format!("Failed to load parquet {}: {e}", path.display()))
        })?;

        println!("  Loaded {} rows from Parquet", dataset.len());

        let schema = dataset.schema();
        let column_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();

        // Pre-tokenized columns win; text is extracted only as a second choice.
        let pretokenized = try_extract_pretokenized(&dataset, &column_names);
        let raw_texts = match pretokenized {
            Some(_) => None,
            None => Some(extract_text_column(&dataset, text_column, &column_names)?),
        };
        (pretokenized, raw_texts)
    };

    match (pretokenized, raw_texts) {
        (Some(sequences), _) => {
            println!("  Found pre-tokenized column, loaded {} sequences", sequences.len());
            create_lm_batches_from_sequences(&sequences, batch_size, seq_len)
        }
        (None, Some(texts)) => {
            println!("  Extracted {} text rows, tokenizing...", texts.len());
            tokenize_texts_to_batches(&texts, tokenizer, batch_size, seq_len)
        }
        (None, None) => Err(Error::ConfigError("No pre-tokenized or text column found".into())),
    }
}
942
/// Load LM batches from a directory of Parquet shards (ALB-007, ALB-101).
///
/// Shards are pulled one at a time through `StreamingParquetLoader`, and the
/// batches from every shard are appended to a single `Vec`. Once all shards are
/// read, the combined batch list is shuffled with a fixed-seed generator, so
/// the resulting order is deterministic.
///
/// The `_tokenizer` and `_text_column` arguments are accepted for signature
/// parity with the single-file loader but are not used here.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn load_lm_batches_from_parquet_dir(
    dir: &std::path::Path,
    _tokenizer: &HfTokenizer,
    batch_size: usize,
    seq_len: usize,
    _text_column: &str,
) -> Result<Vec<LMBatch>> {
    use crate::config::train::batches::streaming::{ShardConfig, StreamingParquetLoader};

    // Constants of the 64-bit linear congruential generator driving the shuffle.
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    let mut loader = StreamingParquetLoader::new(dir, ShardConfig::single(), batch_size, seq_len)
        .map_err(Error::ConfigError)?;

    let total_shards = loader.total_files();
    println!("  Streaming {} Parquet shard(s) from {} (ALB-101)", total_shards, dir.display());

    let mut all_batches = Vec::new();
    let mut shards_read = 0usize;

    loop {
        let Some(shard_batches) = loader.next_batches().map_err(Error::ConfigError)? else {
            break;
        };
        shards_read += 1;
        let added = shard_batches.len();
        all_batches.extend(shard_batches);
        println!(
            "    shard {}/{}: {} batches (cumulative: {})",
            shards_read,
            total_shards,
            added,
            all_batches.len()
        );
    }

    // entrenar#315: shuffle batches across shards. Without this the model sees
    // every sequence of shard 1 before any of shard 2 and memorizes shard 1
    // (v22: val_ppl regressed from 9.44 to 118 between step 2K and 3K).
    // PyTorch's DataLoader with shuffle=True gives the same mixing by default.
    //
    // The seed is currently a fixed constant; the config seed is not wired in.
    let mut rng_state = 123u64; // TODO: propagate config seed

    // Fisher-Yates: walk i from the last index down to 1, swapping with a
    // pseudo-random position in 0..=i.
    let mut i = all_batches.len();
    while i > 1 {
        i -= 1;
        rng_state = rng_state.wrapping_mul(LCG_MULTIPLIER).wrapping_add(LCG_INCREMENT);
        let j = (rng_state >> 33) as usize % (i + 1);
        all_batches.swap(i, j);
    }

    println!("  Total: {} batches from {} shards (shuffled)", all_batches.len(), total_shards);
    Ok(all_batches)
}
1003
/// Pull pre-tokenized sequences out of a Parquet dataset, if present (ALB-007).
///
/// The first column named `input_ids` or `token_ids` (in `column_names` order)
/// is read batch by batch through `extract_sequences_from_column`.
///
/// Returns `None` when no such column exists, when the schema lookup fails, or
/// when the column produced no sequences at all.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn try_extract_pretokenized(
    dataset: &alimentar::ArrowDataset,
    column_names: &[&str],
) -> Option<Vec<Vec<u32>>> {
    use alimentar::Dataset;

    let token_col =
        column_names.iter().copied().find(|&name| name == "input_ids" || name == "token_ids")?;
    let col_idx = dataset.schema().index_of(token_col).ok()?;

    // ALB-099: reserve up front using the dataset's row count.
    let mut collected = Vec::with_capacity(dataset.len());
    for batch in dataset.iter() {
        extract_sequences_from_column(batch.column(col_idx), &mut collected);
    }

    (!collected.is_empty()).then_some(collected)
}
1034
/// Append the token sequences found in one Arrow column to `sequences`.
///
/// - `ListArray` column: each non-null row becomes one sequence.
/// - Any other column: the whole column is treated as a single sequence.
///
/// Sequences that come back empty from `extract_u32_from_array` are not pushed.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_sequences_from_column(col: &arrow::array::ArrayRef, sequences: &mut Vec<Vec<u32>>) {
    use arrow::array::{Array, ListArray};

    let Some(list_arr) = col.as_any().downcast_ref::<ListArray>() else {
        // Flat column: one sequence made of every value in the column.
        let flat = extract_u32_from_array(col.as_ref());
        if !flat.is_empty() {
            sequences.push(flat);
        }
        return;
    };

    let rows = (0..list_arr.len())
        .filter(|&row| !list_arr.is_null(row))
        .map(|row| {
            let values = list_arr.value(row);
            extract_u32_from_array(&values)
        })
        .filter(|seq| !seq.is_empty());
    sequences.extend(rows);
}
1059
/// Extract `u32` token IDs from an Arrow integer array.
///
/// Supported element types are `UInt32`, `Int32` and `Int64`; any other array
/// type yields an empty vector.
///
/// Signed values that cannot be represented as `u32` (negative numbers, or
/// `Int64` values above `u32::MAX`) are dropped. A plain `as u32` cast would
/// wrap them into large, unrelated token IDs.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_u32_from_array(array: &dyn arrow::array::Array) -> Vec<u32> {
    use arrow::array::{Int32Array, Int64Array, UInt32Array};

    let any = array.as_any();
    if let Some(arr) = any.downcast_ref::<UInt32Array>() {
        arr.values().to_vec()
    } else if let Some(arr) = any.downcast_ref::<Int32Array>() {
        arr.values()
            .iter()
            // Non-negative i32 always fits in u32, so the cast is lossless.
            .filter(|&&v| v >= 0)
            .map(|&v| v as u32)
            .collect()
    } else if let Some(arr) = any.downcast_ref::<Int64Array>() {
        arr.values()
            .iter()
            // Range-check first so the cast below can never truncate.
            .filter(|&&v| (0..=i64::from(u32::MAX)).contains(&v))
            .map(|&v| v as u32)
            .collect()
    } else {
        Vec::new()
    }
}
1075
/// Pick the text column to read from the available Parquet columns (ALB-007).
///
/// Candidates are checked in this order: the requested `text_column`, then
/// `text`, `content`, `code`. The first one present in `column_names` is
/// returned as an owned `String`.
///
/// # Errors
///
/// Returns `Error::ConfigError` listing the available columns when none of the
/// candidates exists.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn resolve_text_column_name(text_column: &str, column_names: &[&str]) -> Result<String> {
    std::iter::once(text_column)
        .chain(["text", "content", "code"])
        .find(|candidate| column_names.contains(candidate))
        .map(String::from)
        .ok_or_else(|| {
            Error::ConfigError(format!(
                "No text column found in Parquet (tried '{text_column}', 'text', 'content', 'code'). Available: {column_names:?}"
            ))
        })
}
1093
/// Collect the text strings of one column across all batches of a Parquet
/// dataset (ALB-007).
///
/// The column is chosen by `resolve_text_column_name`, so `text_column` may be
/// replaced by a fallback name. Strings are gathered per batch through
/// `extract_strings_from_array`.
///
/// # Errors
///
/// Returns `Error::ConfigError` when no suitable column exists, when the schema
/// lookup for the resolved name fails, or when a batch's column cannot be read
/// as strings.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_text_column(
    dataset: &alimentar::ArrowDataset,
    text_column: &str,
    column_names: &[&str],
) -> Result<Vec<String>> {
    use alimentar::Dataset;

    let col_name = resolve_text_column_name(text_column, column_names)?;

    let col_idx = match dataset.schema().index_of(&col_name) {
        Ok(idx) => idx,
        Err(e) => {
            return Err(Error::ConfigError(format!("Column '{col_name}' not found: {e}")));
        }
    };

    let mut collected = Vec::new();
    for batch in dataset.iter() {
        extract_strings_from_array(batch.column(col_idx), &col_name, &mut collected)?;
    }
    Ok(collected)
}
1117
/// Append every non-null, non-empty string of a `StringArray` column to `texts`.
///
/// `col_name` is used only for the error message.
///
/// # Errors
///
/// Returns `Error::ConfigError` (including the column's actual data type) when
/// `col` is not a `StringArray`.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_strings_from_array(
    col: &arrow::array::ArrayRef,
    col_name: &str,
    texts: &mut Vec<String>,
) -> Result<()> {
    use arrow::array::{Array, StringArray};

    let Some(str_arr) = col.as_any().downcast_ref::<StringArray>() else {
        return Err(Error::ConfigError(format!(
            "Column '{col_name}' is not a string type (found {:?})",
            col.data_type()
        )));
    };

    texts.extend(
        (0..str_arr.len())
            .filter(|&row| !str_arr.is_null(row))
            .map(|row| str_arr.value(row))
            .filter(|text| !text.is_empty())
            .map(String::from),
    );
    Ok(())
}
1144
1145/// Fallback: Parquet loading without alimentar feature
1146#[cfg(not(all(not(target_arch = "wasm32"), feature = "parquet")))]
1147fn load_lm_batches_from_parquet(
1148    path: &std::path::Path,
1149    _tokenizer: &HfTokenizer,
1150    batch_size: usize,
1151    seq_len: usize,
1152    text_column: &str,
1153) -> Result<Vec<LMBatch>> {
1154    if !path.exists() {
1155        return Err(Error::Io(format!("Parquet path does not exist: {}", path.display())));
1156    }
1157    eprintln!(
1158        "Warning: Parquet LM loading requires the 'parquet' feature. \
1159         Build with: cargo build --features parquet"
1160    );
1161    eprintln!(
1162        "  Alternatively, convert to JSONL: alimentar export {} -o train.jsonl --text-column {}",
1163        path.display(),
1164        text_column
1165    );
1166    create_demo_lm_batches(batch_size, seq_len)
1167}
1168
1169/// Create LMBatches from tokenized sequences
1170fn create_lm_batches_from_sequences(
1171    sequences: &[Vec<u32>],
1172    batch_size: usize,
1173    _seq_len: usize,
1174) -> Result<Vec<LMBatch>> {
1175    // ALB-099: Pre-allocate with known batch count
1176    let num_batches = sequences.len().div_ceil(batch_size);
1177    let mut batches = Vec::with_capacity(num_batches);
1178    let pad_id = 0u32; // Standard pad token
1179    let eos_id = 2u32; // Standard EOS token
1180
1181    for chunk in sequences.chunks(batch_size) {
1182        let batch = LMBatch::from_sequences(chunk, pad_id, eos_id);
1183        batches.push(batch);
1184    }
1185
1186    Ok(batches)
1187}
1188
1189/// Create demo LM batches for testing
1190fn create_demo_lm_batches(batch_size: usize, seq_len: usize) -> Result<Vec<LMBatch>> {
1191    let mut batches = Vec::new();
1192
1193    // Create a few demo sequences with synthetic tokens
1194    // Simulating simple patterns like: [1, 2, 3, 4, ...] with slight variations
1195    for batch_idx in 0..4 {
1196        let mut sequences = Vec::new();
1197        for item in 0..batch_size {
1198            let offset = (batch_idx * batch_size + item) as u32;
1199            // Create a sequence with incrementing tokens
1200            let seq: Vec<u32> = (0..seq_len.min(64))
1201                .map(|i| (offset + i as u32) % 1000 + 1) // Keep in reasonable token range
1202                .collect();
1203            sequences.push(seq);
1204        }
1205
1206        let batch = LMBatch::from_sequences(&sequences, 0, 2);
1207        batches.push(batch);
1208    }
1209
1210    Ok(batches)
1211}
1212
/// Report whether YAML text uses the manifest format.
///
/// The check is purely textual: it is true when some line begins, at column
/// zero, with `entrenar:` or `entrenar :`. Indented occurrences do not count.
fn is_manifest_format(yaml: &str) -> bool {
    const MARKERS: [&str; 2] = ["entrenar:", "entrenar :"];

    yaml.lines().any(|line| MARKERS.iter().any(|marker| line.starts_with(*marker)))
}
1220
1221/// Load training spec from YAML file (without running training)
1222///
1223/// Auto-detects format:
1224/// - If the YAML contains `entrenar:`, it's parsed as a `TrainingManifest` and
1225///   converted to `TrainSpec` via the bridge converter.
1226/// - Otherwise, it's parsed directly as `TrainSpec` (legacy format).
1227pub fn load_config<P: AsRef<Path>>(config_path: P) -> Result<TrainSpec> {
1228    let yaml_content = fs::read_to_string(config_path.as_ref()).map_err(|e| {
1229        Error::ConfigError(format!(
1230            "Failed to read config file {}: {}",
1231            config_path.as_ref().display(),
1232            e
1233        ))
1234    })?;
1235
1236    if is_manifest_format(&yaml_content) {
1237        // New declarative manifest format
1238        let manifest: yaml_mode::TrainingManifest = serde_yaml::from_str(&yaml_content)
1239            .map_err(|e| Error::ConfigError(format!("Failed to parse manifest YAML: {e}")))?;
1240
1241        yaml_mode::validate_manifest(&manifest)
1242            .map_err(|e| Error::ConfigError(format!("Invalid manifest: {e}")))?;
1243
1244        let bridge_result = yaml_mode::manifest_to_spec(&manifest)
1245            .map_err(|e| Error::ConfigError(format!("Manifest conversion failed: {e}")))?;
1246
1247        for warning in &bridge_result.warnings {
1248            eprintln!("Warning: {warning}");
1249        }
1250
1251        validate_config(&bridge_result.spec)
1252            .map_err(|e| Error::ConfigError(format!("Invalid config after conversion: {e}")))?;
1253
1254        Ok(bridge_result.spec)
1255    } else {
1256        // Legacy TrainSpec format
1257        let spec: TrainSpec = serde_yaml::from_str(&yaml_content)
1258            .map_err(|e| Error::ConfigError(format!("Failed to parse YAML config: {e}")))?;
1259
1260        validate_config(&spec).map_err(|e| Error::ConfigError(format!("Invalid config: {e}")))?;
1261
1262        Ok(spec)
1263    }
1264}
1265