/// Resolves `model_path` to a local path, downloading from the HuggingFace Hub if needed.
///
/// The path is stringified (lossily) and checked with `is_hf_repo_id`. Anything that is
/// not a repo ID is returned unchanged; a repo ID is downloaded via `HfModelFetcher`
/// and the path of the returned artifact is used instead.
///
/// # Errors
/// Returns `Error::ConfigError` when the fetcher cannot be created or the download fails.
#[cfg(feature = "hub-publish")]
fn resolve_model_path(model_path: &Path) -> Result<PathBuf> {
    use crate::config::schema::is_hf_repo_id;
    use crate::hf_pipeline::{FetchOptions, HfModelFetcher};

    let path_str = model_path.to_string_lossy();

    if is_hf_repo_id(&path_str) {
        println!("Downloading {path_str} from HuggingFace Hub...");

        let fetcher = match HfModelFetcher::new() {
            Ok(created) => created,
            Err(e) => {
                return Err(Error::ConfigError(format!("HF fetcher initialization: {e}")));
            }
        };
        let artifact = match fetcher.download_model(&path_str, FetchOptions::new()) {
            Ok(downloaded) => downloaded,
            Err(e) => return Err(Error::ConfigError(format!("Model download failed: {e}"))),
        };

        println!(" Cached at: {}", artifact.path.display());
        Ok(artifact.path)
    } else {
        // Plain filesystem path: nothing to download.
        Ok(model_path.to_path_buf())
    }
}
27
28#[cfg(not(feature = "hub-publish"))]
29fn resolve_model_path(model_path: &Path) -> Result<PathBuf> {
30 use crate::config::schema::is_hf_repo_id;
31
32 let path_str = model_path.to_string_lossy();
33 if is_hf_repo_id(&path_str) {
34 return Err(Error::ConfigError(format!(
35 "HF model ID '{path_str}' requires the 'hub-publish' feature. \
36 Rebuild with: cargo install entrenar --features hub-publish"
37 )));
38 }
39 Ok(model_path.to_path_buf())
40}
41
42fn load_transformer_model(
48 model_path: &Path,
49 config: &TransformerConfig,
50 output_dir: &Path,
51) -> Result<(Option<Transformer>, usize)> {
52 if output_dir.is_dir() {
54 if let Some(result) = try_load_apr_delta(output_dir, config, model_path) {
57 return Ok(result);
58 }
59 if let Some(result) = try_load_apr(output_dir, config) {
61 return Ok(result);
62 }
63 if let Some(result) = try_load_safetensors_dir(output_dir, config) {
65 return Ok(result);
66 }
67 }
68
69 if !model_path.exists() {
70 println!(" Model path not found, using random initialization");
71 return Ok((None, 0));
72 }
73
74 println!("Loading model weights from {}...", model_path.display());
75
76 if let Some((model, _source_step)) = try_load_apr(model_path, config) {
84 return Ok((model, 0));
85 }
86
87 if let Some((model, _source_step)) = try_load_safetensors_dir(model_path, config) {
89 return Ok((model, 0));
90 }
91
92 eprintln!("Warning: No loadable checkpoint found, using random initialization");
93 Ok((None, 0))
94}
95
96fn try_load_safetensors_dir(
99 dir: &Path,
100 config: &TransformerConfig,
101) -> Option<(Option<Transformer>, usize)> {
102 let checkpoint_step = detect_checkpoint_step(dir);
103
104 match load_safetensors_weights(dir, Architecture::Auto) {
105 Ok(weights) => {
106 println!(" Found {} weight tensors (SafeTensors)", weights.len());
107 if let Some(transformer) = Transformer::from_params(config, &weights) {
108 let embed = &transformer.embed_tokens.weight;
109 let embed_data = embed.data();
110 let embed_slice = embed_data.as_slice().unwrap_or(&[]);
111 let (emin, emax, emean) = if embed_slice.is_empty() {
112 (0.0, 0.0, 0.0)
113 } else {
114 let min = embed_slice.iter().copied().fold(f32::INFINITY, f32::min);
115 let max = embed_slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
116 let mean = embed_slice.iter().sum::<f32>() / embed_slice.len() as f32;
117 (min, max, mean)
118 };
119 println!("✓ Loaded pre-trained weights successfully (SafeTensors)");
120 println!(" embed_tokens stats: min={emin:.4e} max={emax:.4e} mean={emean:.4e}");
121 return Some((Some(transformer), checkpoint_step));
122 }
123 None
124 }
125 Err(_) => None,
126 }
127}
128
/// Public wrapper around `try_load_apr`.
///
/// Forwards `model_path` and `config` unchanged and returns exactly what
/// `try_load_apr` returns: `Some((model, checkpoint_step))` when an APR file could be
/// loaded, `None` otherwise.
pub fn try_load_apr_for_inference(
    model_path: &Path,
    config: &TransformerConfig,
) -> Option<(Option<Transformer>, usize)> {
    try_load_apr(model_path, config)
}
141
142fn try_load_apr_delta(
150 output_dir: &Path,
151 config: &TransformerConfig,
152 base_model_path: &Path,
153) -> Option<(Option<Transformer>, usize)> {
154 use aprender::serialization::apr::AprReader;
155
156 let apr_path = find_latest_apr_checkpoint(output_dir)?;
157 let reader = AprReader::open(&apr_path).ok()?;
158
159 let format = reader.get_metadata("format").and_then(|v| v.as_str().map(String::from))?;
161 if format != "entrenar-delta-checkpoint" {
162 return None; }
164
165 let checkpoint_step = reader
166 .get_metadata("checkpoint_step")
167 .and_then(|v| v.as_str())
168 .and_then(|s| s.parse::<usize>().ok())
169 .unwrap_or(0);
170
171 println!(
172 " Delta checkpoint at step {checkpoint_step} (loading base from {})",
173 base_model_path.display()
174 );
175
176 let (base_model, _) = try_load_apr(base_model_path, config)
178 .or_else(|| try_load_safetensors_dir(base_model_path, config))?;
179 let mut transformer = base_model?;
180
181 let mut overlaid = 0usize;
183 for desc in &reader.tensors {
184 let name = &desc.name;
185 if name.starts_with("__training__") || name.starts_with("lora.") {
186 continue; }
188 if let Ok(data) = reader.read_tensor_as_f32(name) {
189 let tensor = crate::Tensor::from_vec(data, false);
190 if transformer.set_named_parameter(name, tensor) {
191 overlaid += 1;
192 }
193 }
194 }
195 println!(" ✓ Delta: {overlaid} tensors overlaid on base model");
196
197 Some((Some(transformer), checkpoint_step))
198}
199
200fn try_load_apr(
201 model_path: &Path,
202 config: &TransformerConfig,
203) -> Option<(Option<Transformer>, usize)> {
204 use aprender::serialization::apr::AprReader;
205 use std::collections::HashMap;
206
207 let apr_path =
209 if model_path.is_file() && model_path.extension().and_then(|e| e.to_str()) == Some("apr") {
210 model_path.to_path_buf()
211 } else if model_path.is_dir() {
212 find_latest_apr_checkpoint(model_path)?
213 } else {
214 return None;
215 };
216
217 let reader = match AprReader::open(&apr_path) {
218 Ok(r) => r,
219 Err(_) => return None,
220 };
221
222 let checkpoint_step = reader
224 .get_metadata("checkpoint_step")
225 .and_then(|v| v.as_str())
226 .and_then(|s| s.parse::<usize>().ok())
227 .unwrap_or_else(|| {
228 apr_path
229 .file_name()
230 .and_then(|n| n.to_str())
231 .and_then(parse_checkpoint_step)
232 .unwrap_or(0)
233 });
234
235 let mut weights = HashMap::new();
238 let is_gguf_names = reader.tensors.iter().any(|t| t.name == "token_embd.weight");
239 if is_gguf_names {
240 eprintln!("[PMAT-489] Detected GGUF tensor names in APR file, mapping to HF convention");
241 }
242 for desc in &reader.tensors {
243 let tensor_name = &desc.name;
244 if tensor_name.starts_with("__training__") {
245 continue;
246 }
247 match reader.read_tensor_as_f32(tensor_name) {
248 Ok(data) => {
249 let mapped_name = if is_gguf_names {
250 use crate::transformer::weights::{Architecture, mapping::map_weight_name};
251 map_weight_name(tensor_name, Architecture::Gguf)
252 } else {
253 tensor_name.clone()
254 };
255 weights.insert(mapped_name, crate::Tensor::from_vec(data, false));
256 }
257 Err(e) => {
258 eprintln!("Warning: Failed to read APR tensor '{tensor_name}': {e}");
259 return None;
260 }
261 }
262 }
263
264 println!(" Found {} weight tensors (APR)", weights.len());
265
266 let transformer = Transformer::from_params(config, &weights)?;
267
268 let embed = &transformer.embed_tokens.weight;
269 let embed_data = embed.data();
270 let embed_slice = embed_data.as_slice().unwrap_or(&[]);
271 let (emin, emax, emean) = if embed_slice.is_empty() {
272 (0.0, 0.0, 0.0)
273 } else {
274 let min = embed_slice.iter().copied().fold(f32::INFINITY, f32::min);
275 let max = embed_slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
276 let mean = embed_slice.iter().sum::<f32>() / embed_slice.len() as f32;
277 (min, max, mean)
278 };
279 println!("✓ Loaded pre-trained weights successfully (APR)");
280 println!(" embed_tokens stats: min={emin:.4e} max={emax:.4e} mean={emean:.4e}");
284
285 Some((Some(transformer), checkpoint_step))
286}
287
288fn find_latest_apr_checkpoint(dir: &Path) -> Option<std::path::PathBuf> {
292 let mut best_step = 0usize;
293 let mut best_path = None;
294
295 if let Ok(entries) = std::fs::read_dir(dir) {
296 for entry in entries.flatten() {
297 let path = entry.path();
298 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
299 if let Some(step) = parse_checkpoint_step(name) {
300 if step >= best_step {
301 best_step = step;
302 best_path = Some(path);
303 }
304 }
305 }
306 }
307 }
308
309 if best_path.is_some() {
310 return best_path;
311 }
312
313 let best = dir.join("model-best.apr");
315 if best.exists() {
316 return Some(best);
317 }
318 let model = dir.join("model.apr");
319 if model.exists() {
320 return Some(model);
321 }
322
323 None
324}
325
326fn detect_checkpoint_step(model_path: &Path) -> usize {
330 use crate::transformer::weights::parse_checkpoint_step_from_path;
331
332 if model_path.is_file() {
333 if let Some(name) = model_path.file_name().and_then(|n| n.to_str()) {
334 if let Some(step) = parse_checkpoint_step(name) {
335 return step;
336 }
337 }
338 return parse_checkpoint_step_from_path(model_path).unwrap_or(0);
339 }
340 if !model_path.is_dir() {
341 return 0;
342 }
343 let Ok(entries) = std::fs::read_dir(model_path) else { return 0 };
345 let mut max_step = 0usize;
346 for entry in entries.flatten() {
347 if let Some(name) = entry.file_name().to_str() {
348 if let Some(step) = parse_checkpoint_step(name) {
349 max_step = max_step.max(step);
350 }
351 }
352 if let Some(step) = parse_checkpoint_step_from_path(&entry.path()) {
353 max_step = max_step.max(step);
354 }
355 }
356 max_step
357}
358
359fn apply_architecture_overrides(
363 config: &mut TransformerConfig,
364 overrides: &crate::config::ArchitectureOverrides,
365) {
366 if let Some(v) = overrides.hidden_size {
367 config.hidden_size = v;
368 }
369 if let Some(v) = overrides.num_hidden_layers {
370 config.num_hidden_layers = v;
371 }
372 if let Some(v) = overrides.num_attention_heads {
373 config.num_attention_heads = v;
374 }
375 if let Some(v) = overrides.num_kv_heads {
376 config.num_kv_heads = v;
377 }
378 if let Some(v) = overrides.intermediate_size {
379 config.intermediate_size = v;
380 }
381 if let Some(v) = overrides.vocab_size {
382 config.vocab_size = v;
383 }
384 if let Some(v) = overrides.max_position_embeddings {
385 config.max_position_embeddings = v;
386 }
387 if let Some(v) = overrides.rms_norm_eps {
388 config.rms_norm_eps = v;
389 }
390 if let Some(v) = overrides.rope_theta {
391 config.rope_theta = v;
392 }
393 if let Some(v) = overrides.use_bias {
394 config.use_bias = v;
395 }
396 if let Some(v) = overrides.head_dim {
397 config.head_dim_override = Some(v);
398 }
399}
400
401fn build_transformer_config_from_spec(spec: &TrainSpec) -> Result<TransformerConfig> {
406 let config_path_resolved = spec.model.config.clone().or_else(|| {
408 let model_config = spec.model.path.join("config.json");
410 if model_config.exists() {
411 Some(model_config.to_string_lossy().into_owned())
412 } else {
413 None
414 }
415 });
416
417 let mut config = if let Some(config_path) = &config_path_resolved {
418 let config_file = std::path::Path::new(config_path);
419 if config_file.exists() {
420 let config_content = std::fs::read_to_string(config_file)
421 .map_err(|e| Error::ConfigError(format!("Failed to read model config: {e}")))?;
422
423 if let Ok(hf_config) = serde_json::from_str::<serde_json::Value>(&config_content) {
424 parse_hf_config(&hf_config)?
425 } else {
426 fallback_demo_config()
427 }
428 } else {
429 fallback_demo_config()
430 }
431 } else if let Some(ref overrides) = spec.model.architecture {
432 if let Some(cfg) = config_from_overrides(overrides) {
434 cfg
435 } else {
436 fallback_demo_config()
437 }
438 } else {
439 fallback_demo_config()
440 };
441
442 if let Some(ref overrides) = spec.model.architecture {
444 apply_architecture_overrides(&mut config, overrides);
445 }
446
447 Ok(config)
448}
449
450fn config_from_overrides(
453 overrides: &crate::config::ArchitectureOverrides,
454) -> Option<TransformerConfig> {
455 let hidden_size = overrides.hidden_size?;
456 let num_attention_heads = overrides.num_attention_heads?;
457 let num_hidden_layers = overrides.num_hidden_layers?;
458 let vocab_size = overrides.vocab_size?;
459 let intermediate_size = overrides.intermediate_size?;
460
461 Some(TransformerConfig {
462 hidden_size,
463 num_attention_heads,
464 num_kv_heads: overrides.num_kv_heads.unwrap_or(num_attention_heads),
465 intermediate_size,
466 num_hidden_layers,
467 vocab_size,
468 max_position_embeddings: overrides.max_position_embeddings.unwrap_or(2048),
469 rms_norm_eps: overrides.rms_norm_eps.unwrap_or(1e-5),
470 rope_theta: overrides.rope_theta.unwrap_or(10000.0),
471 use_bias: overrides.use_bias.unwrap_or(false),
472 head_dim_override: overrides.head_dim,
473 architecture: ModelArchitecture::Decoder,
474 hf_architecture: None,
475 hf_model_type: None,
476 tie_word_embeddings: false,
477 })
478}
479
/// Returns the hard-coded demo configuration used when no model config is available.
///
/// Prints a warning to stderr on every call. Dimensions come from the `QWEN_*`
/// constants; the remaining fields are fixed here (`rms_norm_eps` = 1e-6, no bias,
/// no head-dim override, decoder architecture, untied word embeddings).
fn fallback_demo_config() -> TransformerConfig {
    eprintln!("WARNING: No model config found — using Qwen2-0.5B demo config (NOT suitable for production training)");
    TransformerConfig {
        hidden_size: QWEN_HIDDEN_SIZE,
        num_attention_heads: QWEN_NUM_ATTENTION_HEADS,
        num_kv_heads: QWEN_NUM_KV_HEADS,
        intermediate_size: QWEN_INTERMEDIATE_SIZE,
        num_hidden_layers: QWEN_NUM_HIDDEN_LAYERS,
        vocab_size: QWEN_VOCAB_SIZE,
        max_position_embeddings: QWEN_MAX_POSITION_EMBEDDINGS,
        rms_norm_eps: 1e-6,
        // The constant is not declared as f32, hence the cast.
        rope_theta: QWEN_ROPE_THETA as f32,
        use_bias: false,
        head_dim_override: None,
        architecture: ModelArchitecture::Decoder,
        hf_architecture: None,
        hf_model_type: None,
        tie_word_embeddings: false,
    }
}
502
503fn parse_hf_config(hf_config: &serde_json::Value) -> Result<TransformerConfig> {
509 let hidden_size = hf_config["hidden_size"].as_u64().ok_or_else(|| {
510 Error::ConfigError(
511 "C-11: config.json missing 'hidden_size' — cannot train without model dimensions"
512 .into(),
513 )
514 })? as usize;
515 let num_attention_heads = hf_config["num_attention_heads"].as_u64().ok_or_else(|| {
516 Error::ConfigError("C-11: config.json missing 'num_attention_heads'".into())
517 })? as usize;
518 let num_hidden_layers = hf_config["num_hidden_layers"]
519 .as_u64()
520 .ok_or_else(|| Error::ConfigError("C-11: config.json missing 'num_hidden_layers'".into()))?
521 as usize;
522 let vocab_size = hf_config["vocab_size"]
523 .as_u64()
524 .ok_or_else(|| Error::ConfigError(
525 "C-10: config.json missing 'vocab_size' — training with wrong vocab corrupts embeddings".into()
526 ))? as usize;
527 let intermediate_size = hf_config["intermediate_size"]
528 .as_u64()
529 .ok_or_else(|| Error::ConfigError("C-11: config.json missing 'intermediate_size'".into()))?
530 as usize;
531
532 let num_kv_heads =
539 hf_config["num_key_value_heads"].as_u64().unwrap_or(num_attention_heads as u64) as usize;
540
541 let max_position_embeddings = match hf_config["max_position_embeddings"].as_u64() {
542 Some(v) => v as usize,
543 None => {
544 eprintln!("Warning: config.json missing 'max_position_embeddings', defaulting to 2048");
545 2048
546 }
547 };
548
549 let rope_theta = match hf_config["rope_theta"].as_f64() {
550 Some(v) => v as f32,
551 None => {
552 eprintln!(
553 "Warning: config.json missing 'rope_theta', defaulting to 10000.0 \
554 (Qwen models use 1000000.0 — check your config)"
555 );
556 10_000.0
557 }
558 };
559
560 let rms_norm_eps = hf_config["rms_norm_eps"].as_f64().unwrap_or_else(|| {
561 eprintln!(
562 "Warning: config.json missing 'rms_norm_eps', defaulting to 1e-6 \
563 (some models use 1e-5 or 1e-12 — check your config)"
564 );
565 1e-6
566 }) as f32;
567 let use_bias = hf_config["attention_bias"].as_bool().unwrap_or(false);
568 let head_dim_override = hf_config["head_dim"].as_u64().map(|v| v as usize);
569
570 let architecture = match hf_config["model_type"].as_str() {
572 Some("bert" | "roberta" | "distilbert" | "albert" | "electra" | "deberta") => {
573 ModelArchitecture::Encoder
574 }
575 _ => ModelArchitecture::Decoder,
576 };
577
578 let hf_architecture = hf_config["architectures"]
580 .as_array()
581 .and_then(|a| a.first())
582 .and_then(|v| v.as_str())
583 .map(String::from);
584 let hf_model_type = hf_config["model_type"].as_str().map(String::from);
585 let tie_word_embeddings = hf_config["tie_word_embeddings"].as_bool().unwrap_or(false);
586
587 Ok(TransformerConfig {
588 hidden_size,
589 num_attention_heads,
590 num_kv_heads,
591 intermediate_size,
592 num_hidden_layers,
593 vocab_size,
594 max_position_embeddings,
595 rms_norm_eps,
596 rope_theta,
597 use_bias,
598 head_dim_override,
599 architecture,
600 hf_architecture,
601 hf_model_type,
602 tie_word_embeddings,
603 })
604}
605
606fn load_lm_batches(spec: &TrainSpec) -> Result<Vec<LMBatch>> {
613 let batch_size = spec.data.batch_size;
614 let seq_len = spec.data.seq_len.unwrap_or_else(|| {
615 eprintln!("Warning: seq_len not specified, defaulting to 512 for LM batch loading");
616 512
617 });
618 let tokenizer = load_tokenizer(spec)?;
619
620 if let Some(result) = try_load_lm_from_file(spec, tokenizer.as_ref(), batch_size, seq_len) {
621 return result;
622 }
623
624 eprintln!(
625 "Warning: Training data not found at '{}', using demo LM batches",
626 spec.data.train.display()
627 );
628 create_demo_lm_batches(batch_size, seq_len)
629}
630
631fn try_load_lm_from_file(
633 spec: &TrainSpec,
634 tokenizer: Option<&HfTokenizer>,
635 batch_size: usize,
636 seq_len: usize,
637) -> Option<Result<Vec<LMBatch>>> {
638 if !spec.data.train.exists() {
639 return None;
640 }
641
642 if spec.data.train.is_dir() {
644 let tokenizer = tokenizer?;
645 return Some(load_lm_batches_from_parquet(
646 &spec.data.train,
647 tokenizer,
648 batch_size,
649 seq_len,
650 spec.data.input_column.as_deref().unwrap_or("text"),
651 ));
652 }
653
654 let ext = spec.data.train.extension()?;
655
656 if ext == "json" || ext == "jsonl" {
657 let content = std::fs::read_to_string(&spec.data.train).ok()?;
658 return Some(load_lm_batches_from_json(
659 &content,
660 tokenizer,
661 batch_size,
662 seq_len,
663 spec.data.input_column.as_deref(),
664 ));
665 }
666
667 if ext == "parquet" {
668 let tokenizer = tokenizer?;
669 return Some(load_lm_batches_from_parquet(
670 &spec.data.train,
671 tokenizer,
672 batch_size,
673 seq_len,
674 spec.data.input_column.as_deref().unwrap_or("text"),
675 ));
676 }
677
678 None
679}
680
681fn load_tokenizer(spec: &TrainSpec) -> Result<Option<HfTokenizer>> {
683 if let Some(ref tokenizer_path) = spec.data.tokenizer {
684 if tokenizer_path.exists() {
685 println!(" Loading tokenizer from: {}", tokenizer_path.display());
686 let tokenizer = HfTokenizer::from_file(tokenizer_path)
687 .map_err(|e| Error::ConfigError(format!("Failed to load tokenizer: {e}")))?;
688 println!(" Tokenizer vocab size: {}", tokenizer.vocab_size());
689 return Ok(Some(tokenizer));
690 }
691 eprintln!(
692 "Warning: Tokenizer not found at '{}', using default Qwen2 tokenizer",
693 tokenizer_path.display()
694 );
695 }
696
697 println!(" Using default Qwen2 tokenizer");
699 Ok(Some(HfTokenizer::qwen2()))
700}
701
702fn extract_texts_from_array(array: &[serde_json::Value], text_col: &str) -> Vec<String> {
704 array
705 .iter()
706 .filter_map(|e| {
707 e.get(text_col).or_else(|| e.get("content")).and_then(|v| v.as_str()).map(String::from)
708 })
709 .collect()
710}
711
712fn try_load_from_array(
714 array: &[serde_json::Value],
715 tokenizer: Option<&HfTokenizer>,
716 batch_size: usize,
717 seq_len: usize,
718 text_col: &str,
719 label: &str,
720) -> Option<Result<Vec<LMBatch>>> {
721 if array.first().and_then(|e| e.get("input_ids")).is_some() {
723 return Some(load_pretokenized_json(array, batch_size, seq_len));
724 }
725
726 let tokenizer = tokenizer?;
728 let texts = extract_texts_from_array(array, text_col);
729 if texts.is_empty() {
730 return None;
731 }
732
733 println!(" Loaded {} text examples from {label}, tokenizing...", texts.len());
734 Some(tokenize_texts_to_batches(&texts, tokenizer, batch_size, seq_len))
735}
736
737fn try_load_from_jsonl(
739 content: &str,
740 tokenizer: Option<&HfTokenizer>,
741 batch_size: usize,
742 seq_len: usize,
743 text_col: &str,
744) -> Option<Result<Vec<LMBatch>>> {
745 let tokenizer = tokenizer?;
746 let texts: Vec<String> = content
747 .lines()
748 .filter(|l| !l.trim().is_empty())
749 .filter_map(|line| {
750 serde_json::from_str::<serde_json::Value>(line).ok().and_then(|obj| {
751 obj.get(text_col)
752 .or_else(|| obj.get("content"))
753 .and_then(|v| v.as_str())
754 .map(String::from)
755 })
756 })
757 .collect();
758
759 if texts.is_empty() {
760 return None;
761 }
762
763 println!(" Loaded {} text examples from JSONL, tokenizing...", texts.len());
764 Some(tokenize_texts_to_batches(&texts, tokenizer, batch_size, seq_len))
765}
766
767fn try_load_from_json_value(
769 data: &serde_json::Value,
770 tokenizer: Option<&HfTokenizer>,
771 batch_size: usize,
772 seq_len: usize,
773 text_col: &str,
774) -> Option<Result<Vec<LMBatch>>> {
775 if let Some(examples) = data.get("examples").and_then(|e| e.as_array()) {
777 if let Some(result) =
778 try_load_from_array(examples, tokenizer, batch_size, seq_len, text_col, "JSON")
779 {
780 return Some(result);
781 }
782 }
783
784 if let Some(array) = data.as_array() {
786 if let Some(result) =
787 try_load_from_array(array, tokenizer, batch_size, seq_len, text_col, "JSON array")
788 {
789 return Some(result);
790 }
791 }
792
793 None
794}
795
796fn load_lm_batches_from_json(
803 content: &str,
804 tokenizer: Option<&HfTokenizer>,
805 batch_size: usize,
806 seq_len: usize,
807 input_column: Option<&str>,
808) -> Result<Vec<LMBatch>> {
809 let text_col = input_column.unwrap_or("text");
810
811 if let Ok(data) = serde_json::from_str::<serde_json::Value>(content) {
813 if let Some(result) =
814 try_load_from_json_value(&data, tokenizer, batch_size, seq_len, text_col)
815 {
816 return result;
817 }
818 }
819
820 if let Some(result) = try_load_from_jsonl(content, tokenizer, batch_size, seq_len, text_col) {
822 return result;
823 }
824
825 eprintln!("Warning: Could not parse training data, using demo LM batches");
827 create_demo_lm_batches(batch_size, seq_len)
828}
829
830fn load_pretokenized_json(
832 examples: &[serde_json::Value],
833 batch_size: usize,
834 seq_len: usize,
835) -> Result<Vec<LMBatch>> {
836 let mut all_sequences: Vec<Vec<u32>> = Vec::new();
837
838 for example in examples {
839 if let Some(tokens) = example.get("input_ids").and_then(|t| t.as_array()) {
840 let seq: Vec<u32> =
841 tokens.iter().filter_map(|t| t.as_u64().map(|v| v as u32)).collect();
842 if !seq.is_empty() {
843 all_sequences.push(seq);
844 }
845 }
846 }
847
848 if !all_sequences.is_empty() {
849 println!(" Loaded {} pre-tokenized sequences from JSON", all_sequences.len());
850 return create_lm_batches_from_sequences(&all_sequences, batch_size, seq_len);
851 }
852
853 eprintln!("Warning: No valid sequences found in JSON");
854 create_demo_lm_batches(batch_size, seq_len)
855}
856
857fn tokenize_texts_to_batches(
859 texts: &[String],
860 tokenizer: &HfTokenizer,
861 batch_size: usize,
862 seq_len: usize,
863) -> Result<Vec<LMBatch>> {
864 let sequences: Vec<Vec<u32>> = texts
865 .iter()
866 .map(|text| {
867 let mut tokens = tokenizer.encode_with_special(text);
868 tokens.truncate(seq_len);
869 tokens
870 })
871 .filter(|seq| seq.len() > 1) .collect();
873
874 if sequences.is_empty() {
875 eprintln!("Warning: No valid sequences after tokenization");
876 return create_demo_lm_batches(batch_size, seq_len);
877 }
878
879 println!(" Tokenized {} sequences", sequences.len());
880 create_lm_batches_from_sequences(&sequences, batch_size, seq_len)
881}
882
/// Loads LM batches from a Parquet file, or from a directory of Parquet shards.
///
/// Directories are handed to `load_lm_batches_from_parquet_dir`. For a single file,
/// a pre-tokenized column is preferred; otherwise the text column is extracted and
/// tokenized.
///
/// # Errors
/// Returns `Error::ConfigError` if the file cannot be loaded or no usable column is
/// found, and propagates errors from the helpers it calls.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn load_lm_batches_from_parquet(
    path: &std::path::Path,
    tokenizer: &HfTokenizer,
    batch_size: usize,
    seq_len: usize,
    text_column: &str,
) -> Result<Vec<LMBatch>> {
    use alimentar::{ArrowDataset, Dataset};

    if path.is_dir() {
        return load_lm_batches_from_parquet_dir(path, tokenizer, batch_size, seq_len, text_column);
    }

    println!(" Loading Parquet LM data: {}", path.display());

    // The dataset lives only inside this block, so it is dropped before batching.
    let extracted = {
        let dataset = match ArrowDataset::from_parquet(path) {
            Ok(loaded) => loaded,
            Err(e) => {
                return Err(Error::ConfigError(format!(
                    "Failed to load parquet {}: {e}",
                    path.display()
                )));
            }
        };

        println!(" Loaded {} rows from Parquet", dataset.len());

        let schema = dataset.schema();
        let column_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();

        match try_extract_pretokenized(&dataset, &column_names) {
            Some(token_sequences) => (Some(token_sequences), None),
            None => (None, Some(extract_text_column(&dataset, text_column, &column_names)?)),
        }
    };

    match extracted {
        (Some(sequences), _) => {
            println!(" Found pre-tokenized column, loaded {} sequences", sequences.len());
            create_lm_batches_from_sequences(&sequences, batch_size, seq_len)
        }
        (None, Some(texts)) => {
            println!(" Extracted {} text rows, tokenizing...", texts.len());
            tokenize_texts_to_batches(&texts, tokenizer, batch_size, seq_len)
        }
        (None, None) => Err(Error::ConfigError("No pre-tokenized or text column found".into())),
    }
}
942
/// Loads all batches from a directory of Parquet shards and shuffles them.
///
/// Shards are read through `StreamingParquetLoader` with `ShardConfig::single()`,
/// and a progress line is printed per shard. The batches are then shuffled with a
/// fixed-seed generator, so the order is the same on every run. The tokenizer and
/// text-column arguments are not used here.
///
/// # Errors
/// Returns `Error::ConfigError` if the loader cannot be created or a shard fails.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn load_lm_batches_from_parquet_dir(
    dir: &std::path::Path,
    _tokenizer: &HfTokenizer,
    batch_size: usize,
    seq_len: usize,
    _text_column: &str,
) -> Result<Vec<LMBatch>> {
    use crate::config::train::batches::streaming::{ShardConfig, StreamingParquetLoader};

    let mut loader = StreamingParquetLoader::new(dir, ShardConfig::single(), batch_size, seq_len)
        .map_err(Error::ConfigError)?;

    let total_shards = loader.total_files();
    println!(" Streaming {} Parquet shard(s) from {} (ALB-101)", total_shards, dir.display());

    let mut all_batches = Vec::new();
    let mut shards_read = 0usize;
    loop {
        let Some(shard_batches) = loader.next_batches().map_err(Error::ConfigError)? else {
            break;
        };
        shards_read += 1;
        let batches_in_shard = shard_batches.len();
        all_batches.extend(shard_batches);
        println!(
            " shard {}/{}: {} batches (cumulative: {})",
            shards_read,
            total_shards,
            batches_in_shard,
            all_batches.len()
        );
    }

    // Fisher-Yates shuffle from the back, driven by a 64-bit LCG seeded with 123.
    let mut rng_state = 123u64;
    let mut remaining = all_batches.len();
    while remaining > 1 {
        let last = remaining - 1;
        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        let pick = ((rng_state >> 33) as usize) % remaining;
        all_batches.swap(last, pick);
        remaining -= 1;
    }

    println!(" Total: {} batches from {} shards (shuffled)", all_batches.len(), total_shards);
    Ok(all_batches)
}
1003
/// Extracts pre-tokenized sequences from the first column named `input_ids` or
/// `token_ids`.
///
/// Returns `None` if no such column exists, its index cannot be resolved, or no
/// sequence could be extracted.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn try_extract_pretokenized(
    dataset: &alimentar::ArrowDataset,
    column_names: &[&str],
) -> Option<Vec<Vec<u32>>> {
    use alimentar::Dataset;

    let token_col = column_names
        .iter()
        .copied()
        .find(|name| matches!(*name, "input_ids" | "token_ids"))?;

    let schema = dataset.schema();
    let col_idx = schema.index_of(token_col).ok()?;

    let mut all_sequences = Vec::with_capacity(dataset.len());
    dataset
        .iter()
        .for_each(|batch| extract_sequences_from_column(batch.column(col_idx), &mut all_sequences));

    (!all_sequences.is_empty()).then_some(all_sequences)
}
1034
/// Appends the token sequences found in `col` to `sequences`.
///
/// A `ListArray` contributes one sequence per non-null row. Any other array is
/// treated as one single sequence. Empty sequences are never appended.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_sequences_from_column(col: &arrow::array::ArrayRef, sequences: &mut Vec<Vec<u32>>) {
    use arrow::array::{Array, ListArray};

    let Some(list_arr) = col.as_any().downcast_ref::<ListArray>() else {
        // Flat column: the whole array is one sequence.
        let seq = extract_u32_from_array(col.as_ref());
        if !seq.is_empty() {
            sequences.push(seq);
        }
        return;
    };

    sequences.extend(
        (0..list_arr.len())
            .filter(|&row| !list_arr.is_null(row))
            .map(|row| extract_u32_from_array(&list_arr.value(row)))
            .filter(|seq| !seq.is_empty()),
    );
}
1059
/// Converts a `UInt32Array`, `Int32Array` or `Int64Array` into a `Vec<u32>`.
///
/// Any other array type yields an empty vector.
///
/// NOTE(review): signed values are converted with `as u32`, so negative or oversized
/// IDs wrap instead of being rejected — confirm the input never contains such values.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_u32_from_array(array: &dyn arrow::array::Array) -> Vec<u32> {
    use arrow::array::{Int32Array, Int64Array, UInt32Array};

    let any = array.as_any();

    if let Some(unsigned) = any.downcast_ref::<UInt32Array>() {
        return unsigned.values().to_vec();
    }
    if let Some(signed32) = any.downcast_ref::<Int32Array>() {
        return signed32.values().iter().map(|&v| v as u32).collect();
    }
    if let Some(signed64) = any.downcast_ref::<Int64Array>() {
        return signed64.values().iter().map(|&v| v as u32).collect();
    }
    Vec::new()
}
1075
/// Chooses which column holds the text.
///
/// The requested `text_column` wins if present; otherwise the first of `text`,
/// `content`, `code` that exists in `column_names` is used.
///
/// # Errors
/// Returns `Error::ConfigError` listing the available columns when none matches.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn resolve_text_column_name(text_column: &str, column_names: &[&str]) -> Result<String> {
    let chosen = std::iter::once(text_column)
        .chain(["text", "content", "code"])
        .find(|candidate| column_names.contains(candidate));

    match chosen {
        Some(name) => Ok(name.to_string()),
        None => Err(Error::ConfigError(format!(
            "No text column found in Parquet (tried '{text_column}', 'text', 'content', 'code'). Available: {column_names:?}"
        ))),
    }
}
1093
/// Reads every non-empty string from the dataset's text column.
///
/// The column is chosen by `resolve_text_column_name`; rows are collected batch by
/// batch through `extract_strings_from_array`.
///
/// # Errors
/// Returns `Error::ConfigError` if no text column can be resolved or its index is
/// not found, and propagates errors from `extract_strings_from_array`.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_text_column(
    dataset: &alimentar::ArrowDataset,
    text_column: &str,
    column_names: &[&str],
) -> Result<Vec<String>> {
    use alimentar::Dataset;

    let col_name = resolve_text_column_name(text_column, column_names)?;

    let schema = dataset.schema();
    let col_idx = match schema.index_of(&col_name) {
        Ok(idx) => idx,
        Err(e) => {
            return Err(Error::ConfigError(format!("Column '{col_name}' not found: {e}")));
        }
    };

    let mut texts = Vec::new();
    dataset
        .iter()
        .try_for_each(|batch| extract_strings_from_array(batch.column(col_idx), &col_name, &mut texts))?;
    Ok(texts)
}
1117
/// Appends every non-null, non-empty string of `col` to `texts`.
///
/// # Errors
/// Returns `Error::ConfigError` if `col` is not a `StringArray`; `col_name` is only
/// used in that message.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_strings_from_array(
    col: &arrow::array::ArrayRef,
    col_name: &str,
    texts: &mut Vec<String>,
) -> Result<()> {
    use arrow::array::{Array, StringArray};

    let Some(str_arr) = col.as_any().downcast_ref::<StringArray>() else {
        return Err(Error::ConfigError(format!(
            "Column '{col_name}' is not a string type (found {:?})",
            col.data_type()
        )));
    };

    texts.extend(
        (0..str_arr.len())
            .filter(|&row| !str_arr.is_null(row))
            .map(|row| str_arr.value(row))
            .filter(|text| !text.is_empty())
            .map(|text| text.to_string()),
    );
    Ok(())
}
1144
1145#[cfg(not(all(not(target_arch = "wasm32"), feature = "parquet")))]
1147fn load_lm_batches_from_parquet(
1148 path: &std::path::Path,
1149 _tokenizer: &HfTokenizer,
1150 batch_size: usize,
1151 seq_len: usize,
1152 text_column: &str,
1153) -> Result<Vec<LMBatch>> {
1154 if !path.exists() {
1155 return Err(Error::Io(format!("Parquet path does not exist: {}", path.display())));
1156 }
1157 eprintln!(
1158 "Warning: Parquet LM loading requires the 'parquet' feature. \
1159 Build with: cargo build --features parquet"
1160 );
1161 eprintln!(
1162 " Alternatively, convert to JSONL: alimentar export {} -o train.jsonl --text-column {}",
1163 path.display(),
1164 text_column
1165 );
1166 create_demo_lm_batches(batch_size, seq_len)
1167}
1168
1169fn create_lm_batches_from_sequences(
1171 sequences: &[Vec<u32>],
1172 batch_size: usize,
1173 _seq_len: usize,
1174) -> Result<Vec<LMBatch>> {
1175 let num_batches = sequences.len().div_ceil(batch_size);
1177 let mut batches = Vec::with_capacity(num_batches);
1178 let pad_id = 0u32; let eos_id = 2u32; for chunk in sequences.chunks(batch_size) {
1182 let batch = LMBatch::from_sequences(chunk, pad_id, eos_id);
1183 batches.push(batch);
1184 }
1185
1186 Ok(batches)
1187}
1188
1189fn create_demo_lm_batches(batch_size: usize, seq_len: usize) -> Result<Vec<LMBatch>> {
1191 let mut batches = Vec::new();
1192
1193 for batch_idx in 0..4 {
1196 let mut sequences = Vec::new();
1197 for item in 0..batch_size {
1198 let offset = (batch_idx * batch_size + item) as u32;
1199 let seq: Vec<u32> = (0..seq_len.min(64))
1201 .map(|i| (offset + i as u32) % 1000 + 1) .collect();
1203 sequences.push(seq);
1204 }
1205
1206 let batch = LMBatch::from_sequences(&sequences, 0, 2);
1207 batches.push(batch);
1208 }
1209
1210 Ok(batches)
1211}
1212
/// Reports whether `yaml` looks like a training manifest.
///
/// True when any line begins, at column 0, with `entrenar:` or `entrenar :`.
/// Indented occurrences do not count.
fn is_manifest_format(yaml: &str) -> bool {
    yaml.lines()
        .filter_map(|line| line.strip_prefix("entrenar"))
        .any(|rest| rest.starts_with(':') || rest.starts_with(" :"))
}
1220
1221pub fn load_config<P: AsRef<Path>>(config_path: P) -> Result<TrainSpec> {
1228 let yaml_content = fs::read_to_string(config_path.as_ref()).map_err(|e| {
1229 Error::ConfigError(format!(
1230 "Failed to read config file {}: {}",
1231 config_path.as_ref().display(),
1232 e
1233 ))
1234 })?;
1235
1236 if is_manifest_format(&yaml_content) {
1237 let manifest: yaml_mode::TrainingManifest = serde_yaml::from_str(&yaml_content)
1239 .map_err(|e| Error::ConfigError(format!("Failed to parse manifest YAML: {e}")))?;
1240
1241 yaml_mode::validate_manifest(&manifest)
1242 .map_err(|e| Error::ConfigError(format!("Invalid manifest: {e}")))?;
1243
1244 let bridge_result = yaml_mode::manifest_to_spec(&manifest)
1245 .map_err(|e| Error::ConfigError(format!("Manifest conversion failed: {e}")))?;
1246
1247 for warning in &bridge_result.warnings {
1248 eprintln!("Warning: {warning}");
1249 }
1250
1251 validate_config(&bridge_result.spec)
1252 .map_err(|e| Error::ConfigError(format!("Invalid config after conversion: {e}")))?;
1253
1254 Ok(bridge_result.spec)
1255 } else {
1256 let spec: TrainSpec = serde_yaml::from_str(&yaml_content)
1258 .map_err(|e| Error::ConfigError(format!("Failed to parse YAML config: {e}")))?;
1259
1260 validate_config(&spec).map_err(|e| Error::ConfigError(format!("Invalid config: {e}")))?;
1261
1262 Ok(spec)
1263 }
1264}
1265