1use std::collections::{BTreeMap, HashMap};
33use std::io::{BufWriter, Write};
34
35use chrono::{DateTime, Utc};
36use lance::{Error as LanceError, Result as LanceResult};
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39
40use crate::record::{ContextRecord, LifecycleQueryOptions, RecordFilters, LIFECYCLE_CONTRADICTED};
41use crate::store::ContextStore;
42
/// Schema version stamped into every [`ExportManifest`] (its `schema_version` field).
pub const EXPORT_SCHEMA_VERSION: &str = "1";
45
/// Kind of training data an export produces.
///
/// Serialized in lowercase (`"sft"`, `"preference"`, `"rollout"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ExportTask {
    /// One conversation per group, written as [`SftExample`] lines (the default).
    #[default]
    Sft,
    /// Prompt plus assistant candidates, written as [`PreferenceExample`] lines.
    Preference,
    /// Prompt plus every assistant response with its reward, written as [`RolloutExample`] lines.
    Rollout,
}
58
59impl ExportTask {
60 #[must_use]
61 pub fn as_str(self) -> &'static str {
62 match self {
63 Self::Sft => "sft",
64 Self::Preference => "preference",
65 Self::Rollout => "rollout",
66 }
67 }
68}
69
/// Shape of the examples written for [`ExportTask::Preference`].
///
/// Serialized in lowercase (`"paired"`, `"unpaired"`, `"ranked"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PreferenceForm {
    /// One chosen/rejected pair per group, built from the highest- and lowest-scoring
    /// candidates (the default).
    #[default]
    Paired,
    /// One line per candidate with a boolean label, taken from `label` metadata or from
    /// comparing `reward` against `min_reward`.
    Unpaired,
    /// All candidates of a group in one line, ordered by `rank` metadata or by score.
    Ranked,
}
82
/// Record field used to gather records into one example (for example, one conversation).
///
/// Records whose grouping field is missing each form their own singleton group.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum GroupBy {
    /// Every record is its own group.
    None,
    /// Group by `session_id` (the default).
    #[default]
    SessionId,
    /// Group by `run_id`.
    RunId,
    /// Group by `tenant`.
    Tenant,
    /// Group by `source`.
    Source,
    /// Group by `bot_id`.
    BotId,
    /// Group by the part of `external_id` before the first occurrence of the given
    /// delimiter (the whole id when the delimiter does not occur).
    ExternalIdPrefix(String),
}
98
99impl GroupBy {
100 fn label(&self) -> String {
101 match self {
102 Self::None => "none".to_string(),
103 Self::SessionId => "session_id".to_string(),
104 Self::RunId => "run_id".to_string(),
105 Self::Tenant => "tenant".to_string(),
106 Self::Source => "source".to_string(),
107 Self::BotId => "bot_id".to_string(),
108 Self::ExternalIdPrefix(delim) => format!("external_id_prefix:{delim}"),
109 }
110 }
111
112 fn key(&self, record: &ContextRecord) -> String {
114 let value = match self {
115 Self::None => None,
116 Self::SessionId => record.session_id.clone(),
117 Self::RunId => Some(record.run_id.clone()),
118 Self::Tenant => record.tenant.clone(),
119 Self::Source => record.source.clone(),
120 Self::BotId => record.bot_id.clone(),
121 Self::ExternalIdPrefix(delim) => record.external_id.as_ref().map(|external_id| {
122 external_id
123 .split_once(delim.as_str())
124 .map_or(external_id.as_str(), |(prefix, _)| prefix)
125 .to_string()
126 }),
127 };
128 value.unwrap_or_else(|| format!("__rec__{}", record.id))
129 }
130}
131
/// Deterministic train/eval split applied to the curated records.
///
/// Each record is routed by hashing its `by` key together with `seed`. All records that share
/// a key therefore land on the same side.
#[derive(Clone)]
pub struct SplitConfig {
    /// A record goes to eval when its key's hash fraction in `[0, 1)` is below this value.
    pub eval_fraction: f64,
    /// Key that keeps related records on the same side of the split.
    pub by: GroupBy,
    /// Seed mixed into the hash; change it to draw a different split.
    pub seed: u64,
}
145
146impl Default for SplitConfig {
147 fn default() -> Self {
148 Self {
149 eval_fraction: 0.1,
150 by: GroupBy::SessionId,
151 seed: 0,
152 }
153 }
154}
155
/// Options for [`ContextStore::export_training`].
#[derive(Clone)]
pub struct ExportConfig {
    /// Kind of examples to write.
    pub task: ExportTask,
    /// How records are grouped into examples.
    pub group_by: GroupBy,
    /// Example shape for [`ExportTask::Preference`]; ignored by the other tasks.
    pub preference_form: PreferenceForm,
    /// Record filters passed to the store query.
    pub filters: Option<RecordFilters>,
    /// Lifecycle options passed to the store query.
    pub lifecycle: LifecycleQueryOptions,
    /// Cosine-distance threshold for near-duplicate removal; `None` disables dedup.
    pub dedup_threshold: Option<f32>,
    /// Held-out embeddings (for example, an eval set) that exported records must not be
    /// close to.
    pub decontaminate_against: Vec<Vec<f32>>,
    /// Cosine-distance threshold for decontamination. Only applied when
    /// `decontaminate_against` is non-empty.
    pub decontaminate_threshold: Option<f32>,
    /// Threshold compared against the `reward` metadata field. It is used during curation
    /// (see `curate`) and as the label boundary for unpaired preference examples.
    pub min_reward: Option<f64>,
    /// Store version to export; the previously checked-out version is restored afterwards.
    /// `None` exports the currently checked-out version.
    pub version: Option<u64>,
    /// JSON copied verbatim into the manifest's `filters` field.
    pub filters_summary: Option<Value>,
    /// Optional deterministic train/eval split.
    pub split: Option<SplitConfig>,
    /// Also write `{output}.stats.json` with summary statistics.
    pub emit_stats: bool,
}
184
185impl Default for ExportConfig {
186 fn default() -> Self {
187 Self {
188 task: ExportTask::Sft,
189 group_by: GroupBy::SessionId,
190 preference_form: PreferenceForm::Paired,
191 filters: None,
192 lifecycle: LifecycleQueryOptions::default(),
193 dedup_threshold: None,
194 decontaminate_against: Vec::new(),
195 decontaminate_threshold: None,
196 min_reward: None,
197 version: None,
198 filters_summary: None,
199 split: None,
200 emit_stats: false,
201 }
202 }
203}
204
/// One chat turn built from a record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Role string copied from the record (for example `user` or `assistant`).
    pub role: String,
    /// Text payload of the record; empty when the record has none.
    pub content: String,
}
211
/// Lineage attached to every exported example: where it came from and which records built it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Provenance {
    /// URI of the context store the records were read from.
    pub context_uri: String,
    /// Store version the export read.
    pub version: u64,
    /// Ids of the records that make up the example, in the example's record order.
    pub record_ids: Vec<String>,
    /// External ids of those records; records without one are skipped. Omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub external_ids: Vec<String>,
    /// Tenant of the first record.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub tenant: Option<String>,
    /// Source of the first record.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub source: Option<String>,
    /// Bot id of the first record.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub bot_id: Option<String>,
    /// Session id of the first record.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub session_id: Option<String>,
    /// Run id of the first record.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub run_id: Option<String>,
    /// Earliest `created_at` among the records.
    pub created_at_start: DateTime<Utc>,
    /// Latest `created_at` among the records.
    pub created_at_end: DateTime<Utc>,
}
233
/// One supervised fine-tuning line: a whole group rendered as a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SftExample {
    /// Every record of the group as a message, in group order.
    pub messages: Vec<Message>,
    /// Highest `reward` metadata value in the group; omitted when no record has one.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub reward: Option<f64>,
    /// Lineage of the example.
    pub provenance: Provenance,
}
243
/// One candidate in a ranked preference example.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedCandidate {
    /// The candidate's assistant message.
    pub messages: Vec<Message>,
    /// `rank` metadata if present, otherwise the candidate's 1-based position in the ordering.
    pub rank: i64,
    /// The candidate's `reward` metadata, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub reward: Option<f64>,
}
252
/// One preference line, internally tagged by a `form` field (`paired`, `unpaired` or `ranked`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "form", rename_all = "lowercase")]
pub enum PreferenceExample {
    /// Preferred and dispreferred completion for the same prompt.
    Paired {
        /// Non-assistant messages of the group.
        prompt: Vec<Message>,
        /// The preferred completion.
        chosen: Vec<Message>,
        /// The dispreferred completion.
        rejected: Vec<Message>,
        /// Lineage of the group.
        provenance: Provenance,
    },
    /// A single completion with a desirability label.
    Unpaired {
        /// Non-assistant messages of the group.
        prompt: Vec<Message>,
        /// The labelled completion.
        completion: Vec<Message>,
        /// `true` when the completion is desirable.
        label: bool,
        /// Lineage of the group.
        provenance: Provenance,
    },
    /// All completions of a prompt in preference order.
    Ranked {
        /// Non-assistant messages of the group.
        prompt: Vec<Message>,
        /// Candidates, best first.
        candidates: Vec<RankedCandidate>,
        /// Lineage of the group.
        provenance: Provenance,
    },
}
276
/// One sampled response in a rollout example.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RolloutResponse {
    /// The response's assistant message.
    pub messages: Vec<Message>,
    /// `reward` metadata of the response, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub reward: Option<f64>,
    /// `reward_source` metadata of the response, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub reward_source: Option<String>,
}
286
/// One rollout line: a prompt together with every assistant response in the group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RolloutExample {
    /// Non-assistant messages of the group.
    pub prompt: Vec<Message>,
    /// All assistant responses, in group order.
    pub responses: Vec<RolloutResponse>,
    /// First `group_id` metadata value found among the responses.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub group_id: Option<String>,
    /// Lineage of the group.
    pub provenance: Provenance,
}
297
/// Record counts through the curation pipeline, plus the number of examples written.
///
/// Field order does not follow pipeline order: `curate` applies the reward filter before
/// dedup and decontamination.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct ExportCounts {
    /// Records returned by the store query.
    pub input_records: usize,
    /// Records left after dropping contradicted ones.
    pub after_lifecycle: usize,
    /// Records left after near-duplicate removal.
    pub after_dedup: usize,
    /// Records left after decontamination, i.e. the curated set.
    pub after_decontaminate: usize,
    /// Records left after the reward filter.
    pub after_reward_filter: usize,
    /// JSONL lines written to this output file.
    pub examples: usize,
}
308
/// Split details recorded in the manifest of each side of a train/eval export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitManifest {
    /// `"train"` or `"eval"`.
    pub side: String,
    /// Configured eval fraction.
    pub eval_fraction: f64,
    /// Label of the split key (see [`GroupBy`]).
    pub by: String,
    /// Hash seed used for the split.
    pub seed: u64,
    /// Path of the other side's output file.
    pub complement_path: String,
}
321
/// Metadata written next to an export as `{output}.manifest.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportManifest {
    /// URI of the exported store.
    pub context_uri: String,
    /// Store version that was read.
    pub version: u64,
    /// Task name ([`ExportTask::as_str`]).
    pub task: String,
    /// Grouping label, for example `session_id`.
    pub group_by: String,
    /// [`EXPORT_SCHEMA_VERSION`] at export time.
    pub schema_version: String,
    /// Preference form; present only for preference exports.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub preference_form: Option<String>,
    /// Copy of [`ExportConfig::filters_summary`].
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub filters: Option<Value>,
    /// Dedup threshold used, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub dedup_threshold: Option<f32>,
    /// Decontamination threshold configured, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub decontaminate_threshold: Option<f32>,
    /// Reward threshold configured, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub min_reward: Option<f64>,
    /// Split details; present only for split exports.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub split: Option<SplitManifest>,
    /// Earliest `created_at` among the records in this file; absent when there are none.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub created_at_start: Option<DateTime<Utc>>,
    /// Latest `created_at` among the records in this file; absent when there are none.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub created_at_end: Option<DateTime<Utc>>,
    /// Ids of every record that reached this file, in group order.
    pub source_record_ids: Vec<String>,
    /// Curation counts; `examples` refers to this file.
    pub counts: ExportCounts,
}
349
/// Summary statistics of a list of values.
///
/// `median` and `p95` are nearest-rank percentiles: the element at index
/// `round(p * (count - 1))` of the sorted values. `Default` (all zeros) stands in for an
/// empty `records_per_group` list.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Distribution {
    /// Number of values.
    pub count: usize,
    /// Smallest value.
    pub min: f64,
    /// 50th percentile (nearest rank).
    pub median: f64,
    /// 95th percentile (nearest rank).
    pub p95: f64,
    /// Largest value.
    pub max: f64,
    /// Arithmetic mean.
    pub mean: f64,
}
360
361impl Distribution {
362 fn from_sorted(mut values: Vec<f64>) -> Option<Self> {
363 if values.is_empty() {
364 return None;
365 }
366 values.sort_by(f64::total_cmp);
367 let count = values.len();
368 let sum: f64 = values.iter().sum();
369 let percentile = |p: f64| {
370 let idx = ((p * (count - 1) as f64).round() as usize).min(count - 1);
371 values[idx]
372 };
373 Some(Self {
374 count,
375 min: values[0],
376 median: percentile(0.5),
377 p95: percentile(0.95),
378 max: values[count - 1],
379 mean: sum / count as f64,
380 })
381 }
382}
383
/// Per-record token-count distribution plus how the counts were obtained.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenStats {
    /// Distribution fields, flattened into this object.
    #[serde(flatten)]
    pub distribution: Distribution,
    /// `"tokens_used"` (all counts reported), `"length_proxy"` (all counts are
    /// whitespace word counts) or `"mixed"`.
    pub source: String,
}
392
/// Number of records removed by each curation stage.
///
/// Each value is the difference between consecutive [`ExportCounts`] stages.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct ExcludedCounts {
    /// Dropped as contradicted.
    pub lifecycle: usize,
    /// Dropped by the reward filter.
    pub reward_threshold: usize,
    /// Dropped as near-duplicates.
    pub dedup: usize,
    /// Dropped as too close to the held-out set.
    pub decontaminate: usize,
}
401
/// Summary written as `{output}.stats.json` when [`ExportConfig::emit_stats`] is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportStats {
    /// Task name.
    pub task: String,
    /// JSONL lines written.
    pub examples: usize,
    /// Groups formed, including groups that produced no example.
    pub num_groups: usize,
    /// Record count per role.
    pub by_role: BTreeMap<String, usize>,
    /// Record count per source; records without one count under `"__none__"`.
    pub by_source: BTreeMap<String, usize>,
    /// Record count per tenant; records without one count under `"__none__"`.
    pub by_tenant: BTreeMap<String, usize>,
    /// Distribution of group sizes.
    pub records_per_group: Distribution,
    /// Per-record token counts; absent when there are no records.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub tokens: Option<TokenStats>,
    /// Records removed by each curation stage.
    pub excluded: ExcludedCounts,
    /// Preference form; present only for preference exports.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub preference_form: Option<String>,
    /// Distribution of `reward` metadata values; absent when no record has one.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub reward: Option<Distribution>,
    /// Record count per `reward_source` metadata value; omitted when empty.
    #[serde(skip_serializing_if = "BTreeMap::is_empty", default)]
    pub reward_sources: BTreeMap<String, usize>,
}
428
429fn metadata_field<'a>(record: &'a ContextRecord, key: &str) -> Option<&'a Value> {
432 record.metadata.as_ref()?.get(key)
433}
434
435fn record_reward(record: &ContextRecord) -> Option<f64> {
436 metadata_field(record, "reward")?.as_f64()
437}
438
439fn record_reward_source(record: &ContextRecord) -> Option<String> {
440 Some(
441 metadata_field(record, "reward_source")?
442 .as_str()?
443 .to_string(),
444 )
445}
446
447fn record_group_id(record: &ContextRecord) -> Option<String> {
448 Some(metadata_field(record, "group_id")?.as_str()?.to_string())
449}
450
451fn record_label(record: &ContextRecord) -> Option<String> {
452 Some(metadata_field(record, "label")?.as_str()?.to_string())
453}
454
455fn record_rank(record: &ContextRecord) -> Option<i64> {
456 metadata_field(record, "rank")?.as_i64()
457}
458
459fn message_of(record: &ContextRecord) -> Message {
460 Message {
461 role: record.role.clone(),
462 content: record.text_payload.clone().unwrap_or_default(),
463 }
464}
465
466fn is_assistant(record: &ContextRecord) -> bool {
467 record.role.eq_ignore_ascii_case("assistant")
468}
469
/// Cosine distance `1 - cos(a, b)` between two embeddings, in the range `[0, 2]` up to
/// rounding.
///
/// Returns `None` when the lengths differ, the vectors are empty, or either has zero norm.
/// Callers can then treat "not comparable" separately from "far apart".
fn cosine_distance(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.is_empty() || a.len() != b.len() {
        return None;
    }
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        None
    } else {
        Some(1.0 - dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt()))
    }
}
489
/// Records that share a grouping key, as built by `group_records`.
struct Group {
    /// Key produced by `GroupBy::key` for every record in the group.
    key: String,
    /// Member records, ordered by `created_at` then `id` by `group_records`.
    records: Vec<ContextRecord>,
}
497
498fn group_records(records: Vec<ContextRecord>, group_by: &GroupBy) -> Vec<Group> {
499 let mut order: Vec<String> = Vec::new();
500 let mut groups: HashMap<String, Vec<ContextRecord>> = HashMap::new();
501 for record in records {
502 let key = group_by.key(&record);
503 if !groups.contains_key(&key) {
504 order.push(key.clone());
505 }
506 groups.entry(key).or_default().push(record);
507 }
508
509 let mut result: Vec<Group> = order
510 .into_iter()
511 .map(|key| {
512 let mut records = groups.remove(&key).unwrap_or_default();
513 records.sort_by(|a, b| {
514 a.created_at
515 .cmp(&b.created_at)
516 .then_with(|| a.id.cmp(&b.id))
517 });
518 Group { key, records }
519 })
520 .collect();
521
522 result.sort_by(|a, b| {
524 let left = a.records.first();
525 let right = b.records.first();
526 match (left, right) {
527 (Some(l), Some(r)) => l
528 .created_at
529 .cmp(&r.created_at)
530 .then_with(|| l.id.cmp(&r.id)),
531 _ => a.key.cmp(&b.key),
532 }
533 });
534 result
535}
536
537fn provenance_for(records: &[ContextRecord], context_uri: &str, version: u64) -> Provenance {
538 let first = records.first();
539 let created_at_start = records
540 .iter()
541 .map(|r| r.created_at)
542 .min()
543 .unwrap_or_else(Utc::now);
544 let created_at_end = records
545 .iter()
546 .map(|r| r.created_at)
547 .max()
548 .unwrap_or(created_at_start);
549 Provenance {
550 context_uri: context_uri.to_string(),
551 version,
552 record_ids: records.iter().map(|r| r.id.clone()).collect(),
553 external_ids: records
554 .iter()
555 .filter_map(|r| r.external_id.clone())
556 .collect(),
557 tenant: first.and_then(|r| r.tenant.clone()),
558 source: first.and_then(|r| r.source.clone()),
559 bot_id: first.and_then(|r| r.bot_id.clone()),
560 session_id: first.and_then(|r| r.session_id.clone()),
561 run_id: first.map(|r| r.run_id.clone()),
562 created_at_start,
563 created_at_end,
564 }
565}
566
567fn split_prompt_candidates(records: &[ContextRecord]) -> (Vec<Message>, Vec<&ContextRecord>) {
570 let mut prompt = Vec::new();
571 let mut candidates = Vec::new();
572 for record in records {
573 if is_assistant(record) {
574 candidates.push(record);
575 } else if candidates.is_empty() {
576 prompt.push(message_of(record));
577 } else {
578 prompt.push(message_of(record));
581 }
582 }
583 (prompt, candidates)
584}
585
586fn write_line<T: Serialize>(writer: &mut impl Write, value: &T) -> LanceResult<()> {
587 let line = serde_json::to_string(value).map_err(|err| LanceError::io(err.to_string()))?;
588 writeln!(writer, "{line}")?;
589 Ok(())
590}
591
592fn dedup(records: Vec<ContextRecord>, threshold: f32) -> Vec<ContextRecord> {
598 let mut kept: Vec<ContextRecord> = Vec::new();
599 for record in records {
600 let is_dup = record.embedding.as_ref().is_some_and(|embedding| {
601 kept.iter().any(|other| {
602 other
603 .embedding
604 .as_ref()
605 .and_then(|other_embedding| cosine_distance(embedding, other_embedding))
606 .is_some_and(|distance| distance <= threshold)
607 })
608 });
609 if !is_dup {
610 kept.push(record);
611 }
612 }
613 kept
614}
615
616fn decontaminate(
619 records: Vec<ContextRecord>,
620 holdout: &[Vec<f32>],
621 threshold: f32,
622) -> Vec<ContextRecord> {
623 records
624 .into_iter()
625 .filter(|record| match &record.embedding {
626 Some(embedding) => !holdout.iter().any(|held| {
627 cosine_distance(embedding, held).is_some_and(|distance| distance <= threshold)
628 }),
629 None => true,
630 })
631 .collect()
632}
633
634fn curate(
637 records: Vec<ContextRecord>,
638 config: &ExportConfig,
639) -> (Vec<ContextRecord>, ExportCounts) {
640 let mut counts = ExportCounts {
641 input_records: records.len(),
642 ..ExportCounts::default()
643 };
644
645 let lifecycle: Vec<ContextRecord> = records
648 .into_iter()
649 .filter(|record| record.lifecycle_status != LIFECYCLE_CONTRADICTED)
650 .collect();
651 counts.after_lifecycle = lifecycle.len();
652
653 let rewarded: Vec<ContextRecord> = match config.min_reward {
654 Some(min) => lifecycle
655 .into_iter()
656 .filter(|record| record_reward(record).is_none_or(|reward| reward >= min))
657 .collect(),
658 None => lifecycle,
659 };
660 counts.after_reward_filter = rewarded.len();
661
662 let deduped = match config.dedup_threshold {
663 Some(threshold) => dedup(rewarded, threshold),
664 None => rewarded,
665 };
666 counts.after_dedup = deduped.len();
667
668 let clean = match config.decontaminate_threshold {
669 Some(threshold) if !config.decontaminate_against.is_empty() => {
670 decontaminate(deduped, &config.decontaminate_against, threshold)
671 }
672 _ => deduped,
673 };
674 counts.after_decontaminate = clean.len();
675
676 (clean, counts)
677}
678
679fn write_sft(
682 groups: &[Group],
683 writer: &mut impl Write,
684 context_uri: &str,
685 version: u64,
686) -> LanceResult<usize> {
687 let mut written = 0;
688 for group in groups {
689 if group.records.is_empty() {
690 continue;
691 }
692 let example = SftExample {
693 messages: group.records.iter().map(message_of).collect(),
694 reward: group
695 .records
696 .iter()
697 .filter_map(record_reward)
698 .reduce(f64::max),
699 provenance: provenance_for(&group.records, context_uri, version),
700 };
701 write_line(writer, &example)?;
702 written += 1;
703 }
704 Ok(written)
705}
706
707fn preference_score(record: &ContextRecord) -> f64 {
710 match record_label(record).as_deref() {
711 Some("chosen") => f64::INFINITY,
712 Some("rejected") => f64::NEG_INFINITY,
713 _ => record_reward(record).unwrap_or(0.0),
714 }
715}
716
717fn unpaired_label(record: &ContextRecord, min_reward: Option<f64>) -> Option<bool> {
718 match record_label(record).as_deref() {
719 Some("chosen") => Some(true),
720 Some("rejected") => Some(false),
721 _ => match (record_reward(record), min_reward) {
722 (Some(reward), Some(min)) => Some(reward >= min),
723 _ => None,
724 },
725 }
726}
727
728fn write_preference(
729 groups: &[Group],
730 writer: &mut impl Write,
731 form: PreferenceForm,
732 min_reward: Option<f64>,
733 context_uri: &str,
734 version: u64,
735) -> LanceResult<usize> {
736 let mut written = 0;
737 for group in groups {
738 let (prompt, candidates) = split_prompt_candidates(&group.records);
739 if candidates.is_empty() {
740 continue;
741 }
742 let provenance = provenance_for(&group.records, context_uri, version);
743
744 match form {
745 PreferenceForm::Paired => {
746 if candidates.len() < 2 {
747 continue;
748 }
749 let mut best = candidates[0];
750 let mut worst = candidates[0];
751 for candidate in &candidates {
752 if preference_score(candidate) > preference_score(best) {
753 best = candidate;
754 }
755 if preference_score(candidate) < preference_score(worst) {
756 worst = candidate;
757 }
758 }
759 if best.id == worst.id {
760 continue; }
762 let example = PreferenceExample::Paired {
763 prompt,
764 chosen: vec![message_of(best)],
765 rejected: vec![message_of(worst)],
766 provenance,
767 };
768 write_line(writer, &example)?;
769 written += 1;
770 }
771 PreferenceForm::Unpaired => {
772 for candidate in &candidates {
773 let Some(label) = unpaired_label(candidate, min_reward) else {
774 continue;
775 };
776 let example = PreferenceExample::Unpaired {
777 prompt: prompt.clone(),
778 completion: vec![message_of(candidate)],
779 label,
780 provenance: provenance.clone(),
781 };
782 write_line(writer, &example)?;
783 written += 1;
784 }
785 }
786 PreferenceForm::Ranked => {
787 let mut ordered: Vec<&ContextRecord> = candidates.clone();
788 ordered.sort_by(|a, b| {
789 let rank_a = record_rank(a);
790 let rank_b = record_rank(b);
791 match (rank_a, rank_b) {
792 (Some(ra), Some(rb)) => ra.cmp(&rb),
793 _ => preference_score(b).total_cmp(&preference_score(a)),
794 }
795 });
796 let candidates: Vec<RankedCandidate> = ordered
797 .iter()
798 .enumerate()
799 .map(|(index, candidate)| RankedCandidate {
800 messages: vec![message_of(candidate)],
801 rank: record_rank(candidate).unwrap_or((index + 1) as i64),
802 reward: record_reward(candidate),
803 })
804 .collect();
805 let example = PreferenceExample::Ranked {
806 prompt,
807 candidates,
808 provenance,
809 };
810 write_line(writer, &example)?;
811 written += 1;
812 }
813 }
814 }
815 Ok(written)
816}
817
818fn write_rollout(
819 groups: &[Group],
820 writer: &mut impl Write,
821 context_uri: &str,
822 version: u64,
823) -> LanceResult<usize> {
824 let mut written = 0;
825 for group in groups {
826 let (prompt, candidates) = split_prompt_candidates(&group.records);
827 if candidates.is_empty() {
828 continue;
829 }
830 let responses: Vec<RolloutResponse> = candidates
831 .iter()
832 .map(|candidate| RolloutResponse {
833 messages: vec![message_of(candidate)],
834 reward: record_reward(candidate),
835 reward_source: record_reward_source(candidate),
836 })
837 .collect();
838 let group_id = candidates
839 .iter()
840 .find_map(|candidate| record_group_id(candidate));
841 let example = RolloutExample {
842 prompt,
843 responses,
844 group_id,
845 provenance: provenance_for(&group.records, context_uri, version),
846 };
847 write_line(writer, &example)?;
848 written += 1;
849 }
850 Ok(written)
851}
852
853fn summarize_groups(
854 groups: &[Group],
855) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>, Vec<String>) {
856 let mut start: Option<DateTime<Utc>> = None;
857 let mut end: Option<DateTime<Utc>> = None;
858 let mut ids = Vec::new();
859 for group in groups {
860 for record in &group.records {
861 start = Some(start.map_or(record.created_at, |s| s.min(record.created_at)));
862 end = Some(end.map_or(record.created_at, |e| e.max(record.created_at)));
863 ids.push(record.id.clone());
864 }
865 }
866 (start, end, ids)
867}
868
/// 64-bit FNV-1a hash of `seed` (little-endian bytes) followed by the UTF-8 bytes of `key`.
///
/// A fixed, hand-written hash keeps splits reproducible across toolchains. By contrast, the
/// algorithm behind `std`'s `DefaultHasher` is unspecified.
fn stable_hash(seed: u64, key: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    seed.to_le_bytes()
        .iter()
        .chain(key.as_bytes())
        .fold(FNV_OFFSET_BASIS, |hash, &byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}

/// Maps `key` to a pseudo-uniform value in `[0, 1)` for the train/eval split.
///
/// The FNV hash is passed through the MurmurHash3 `fmix64` finalizer to spread its bits. The
/// top 53 bits of the result are then scaled into `[0, 1)`.
fn split_fraction(seed: u64, key: &str) -> f64 {
    let mut h = stable_hash(seed, key);
    for multiplier in [0xff51_afd7_ed55_8ccd_u64, 0xc4ce_b9fe_1a85_ec53] {
        h ^= h >> 33;
        h = h.wrapping_mul(multiplier);
    }
    h ^= h >> 33;
    let top_53_bits = h >> 11;
    top_53_bits as f64 / (1u64 << 53) as f64
}
898
/// Derives the train and eval output paths from `output_path`.
///
/// `.train` / `.eval` is inserted before the file name's last extension
/// (`out.jsonl` → `out.train.jsonl`), or appended when the file name has no `.`.
///
/// Only the final path component is inspected, so dots in directory names are ignored. Any
/// platform path separator ends the directory part: `/`, plus `\` on Windows.
fn split_paths(output_path: &str) -> (String, String) {
    // `is_separator` only accepts ASCII characters, so `i + 1` is always a char boundary.
    let name_start = output_path
        .rfind(std::path::is_separator)
        .map_or(0, |i| i + 1);
    match output_path[name_start..].rfind('.') {
        Some(rel_dot) => {
            let (stem, ext) = output_path.split_at(name_start + rel_dot);
            (format!("{stem}.train{ext}"), format!("{stem}.eval{ext}"))
        }
        None => (
            format!("{output_path}.train"),
            format!("{output_path}.eval"),
        ),
    }
}
915
916#[allow(clippy::too_many_arguments)]
919fn emit_export(
920 records: Vec<ContextRecord>,
921 config: &ExportConfig,
922 context_uri: &str,
923 version: u64,
924 mut counts: ExportCounts,
925 output_path: &str,
926 split: Option<SplitManifest>,
927) -> LanceResult<ExportManifest> {
928 let groups = group_records(records, &config.group_by);
929 let (created_at_start, created_at_end, source_record_ids) = summarize_groups(&groups);
930
931 let file = std::fs::File::create(output_path)?;
932 let mut writer = BufWriter::new(file);
933 let examples = match config.task {
934 ExportTask::Sft => write_sft(&groups, &mut writer, context_uri, version)?,
935 ExportTask::Preference => write_preference(
936 &groups,
937 &mut writer,
938 config.preference_form,
939 config.min_reward,
940 context_uri,
941 version,
942 )?,
943 ExportTask::Rollout => write_rollout(&groups, &mut writer, context_uri, version)?,
944 };
945 writer.flush()?;
946 counts.examples = examples;
947
948 let manifest = ExportManifest {
949 context_uri: context_uri.to_string(),
950 version,
951 task: config.task.as_str().to_string(),
952 group_by: config.group_by.label(),
953 schema_version: EXPORT_SCHEMA_VERSION.to_string(),
954 preference_form: matches!(config.task, ExportTask::Preference).then(|| {
955 match config.preference_form {
956 PreferenceForm::Paired => "paired",
957 PreferenceForm::Unpaired => "unpaired",
958 PreferenceForm::Ranked => "ranked",
959 }
960 .to_string()
961 }),
962 filters: config.filters_summary.clone(),
963 dedup_threshold: config.dedup_threshold,
964 decontaminate_threshold: config.decontaminate_threshold,
965 min_reward: config.min_reward,
966 split,
967 created_at_start,
968 created_at_end,
969 source_record_ids,
970 counts,
971 };
972
973 let manifest_json =
974 serde_json::to_string_pretty(&manifest).map_err(|err| LanceError::io(err.to_string()))?;
975 std::fs::write(format!("{output_path}.manifest.json"), manifest_json)?;
976
977 if config.emit_stats {
978 let stats = compute_stats(&groups, &counts, examples, config);
979 let stats_json =
980 serde_json::to_string_pretty(&stats).map_err(|err| LanceError::io(err.to_string()))?;
981 std::fs::write(format!("{output_path}.stats.json"), stats_json)?;
982 }
983
984 Ok(manifest)
985}
986
987fn compute_stats(
989 groups: &[Group],
990 counts: &ExportCounts,
991 examples: usize,
992 config: &ExportConfig,
993) -> ExportStats {
994 let mut by_role: BTreeMap<String, usize> = BTreeMap::new();
995 let mut by_source: BTreeMap<String, usize> = BTreeMap::new();
996 let mut by_tenant: BTreeMap<String, usize> = BTreeMap::new();
997 let mut token_values: Vec<f64> = Vec::new();
998 let mut used_tokens_used = false;
999 let mut used_fallback = false;
1000 let mut reward_values: Vec<f64> = Vec::new();
1001 let mut reward_sources: BTreeMap<String, usize> = BTreeMap::new();
1002 let mut records_per_group: Vec<f64> = Vec::new();
1003
1004 for group in groups {
1005 records_per_group.push(group.records.len() as f64);
1006 for record in &group.records {
1007 *by_role.entry(record.role.clone()).or_insert(0) += 1;
1008 *by_source
1009 .entry(
1010 record
1011 .source
1012 .clone()
1013 .unwrap_or_else(|| "__none__".to_string()),
1014 )
1015 .or_insert(0) += 1;
1016 *by_tenant
1017 .entry(
1018 record
1019 .tenant
1020 .clone()
1021 .unwrap_or_else(|| "__none__".to_string()),
1022 )
1023 .or_insert(0) += 1;
1024
1025 match record.state_metadata.as_ref().and_then(|m| m.tokens_used) {
1026 Some(tokens) if tokens >= 0 => {
1027 token_values.push(f64::from(tokens));
1028 used_tokens_used = true;
1029 }
1030 _ => {
1031 let proxy = record
1033 .text_payload
1034 .as_deref()
1035 .map_or(0, |text| text.split_whitespace().count());
1036 token_values.push(proxy as f64);
1037 used_fallback = true;
1038 }
1039 }
1040
1041 if let Some(reward) = record_reward(record) {
1042 reward_values.push(reward);
1043 }
1044 if let Some(source) = record_reward_source(record) {
1045 *reward_sources.entry(source).or_insert(0) += 1;
1046 }
1047 }
1048 }
1049
1050 let tokens = Distribution::from_sorted(token_values).map(|distribution| TokenStats {
1051 distribution,
1052 source: match (used_tokens_used, used_fallback) {
1053 (true, true) => "mixed",
1054 (true, false) => "tokens_used",
1055 _ => "length_proxy",
1056 }
1057 .to_string(),
1058 });
1059
1060 let excluded = ExcludedCounts {
1061 lifecycle: counts.input_records.saturating_sub(counts.after_lifecycle),
1062 reward_threshold: counts
1063 .after_lifecycle
1064 .saturating_sub(counts.after_reward_filter),
1065 dedup: counts
1066 .after_reward_filter
1067 .saturating_sub(counts.after_dedup),
1068 decontaminate: counts
1069 .after_dedup
1070 .saturating_sub(counts.after_decontaminate),
1071 };
1072
1073 ExportStats {
1074 task: config.task.as_str().to_string(),
1075 examples,
1076 num_groups: groups.len(),
1077 by_role,
1078 by_source,
1079 by_tenant,
1080 records_per_group: Distribution::from_sorted(records_per_group).unwrap_or_default(),
1081 tokens,
1082 excluded,
1083 preference_form: matches!(config.task, ExportTask::Preference).then(|| {
1084 match config.preference_form {
1085 PreferenceForm::Paired => "paired",
1086 PreferenceForm::Unpaired => "unpaired",
1087 PreferenceForm::Ranked => "ranked",
1088 }
1089 .to_string()
1090 }),
1091 reward: Distribution::from_sorted(reward_values),
1092 reward_sources,
1093 }
1094}
1095
1096impl ContextStore {
1097 pub async fn export_training(
1105 &mut self,
1106 config: &ExportConfig,
1107 output_path: &str,
1108 ) -> LanceResult<ExportManifest> {
1109 let restore = match config.version {
1110 Some(target) => {
1111 let original = self.version();
1112 self.checkout(target).await?;
1113 Some(original)
1114 }
1115 None => None,
1116 };
1117
1118 let result = self.export_inner(config, output_path).await;
1119
1120 if let Some(original) = restore {
1121 self.checkout(original).await?;
1122 }
1123 result
1124 }
1125
1126 async fn export_inner(
1127 &self,
1128 config: &ExportConfig,
1129 output_path: &str,
1130 ) -> LanceResult<ExportManifest> {
1131 let context_uri = self.uri().to_string();
1132 let version = self.version();
1133
1134 let records = self
1135 .list_filtered_with_options(
1136 None,
1137 None,
1138 config.filters.as_ref(),
1139 config.lifecycle.clone(),
1140 )
1141 .await?;
1142 let (curated, counts) = curate(records, config);
1143
1144 let Some(split) = &config.split else {
1145 return emit_export(
1146 curated,
1147 config,
1148 &context_uri,
1149 version,
1150 counts,
1151 output_path,
1152 None,
1153 );
1154 };
1155
1156 let (train_path, eval_path) = split_paths(output_path);
1159 let mut train_records = Vec::new();
1160 let mut eval_records = Vec::new();
1161 for record in curated {
1162 let key = split.by.key(&record);
1163 if split_fraction(split.seed, &key) < split.eval_fraction {
1164 eval_records.push(record);
1165 } else {
1166 train_records.push(record);
1167 }
1168 }
1169
1170 let train_manifest = emit_export(
1171 train_records,
1172 config,
1173 &context_uri,
1174 version,
1175 counts,
1176 &train_path,
1177 Some(SplitManifest {
1178 side: "train".to_string(),
1179 eval_fraction: split.eval_fraction,
1180 by: split.by.label(),
1181 seed: split.seed,
1182 complement_path: eval_path.clone(),
1183 }),
1184 )?;
1185 emit_export(
1186 eval_records,
1187 config,
1188 &context_uri,
1189 version,
1190 counts,
1191 &eval_path,
1192 Some(SplitManifest {
1193 side: "eval".to_string(),
1194 eval_fraction: split.eval_fraction,
1195 by: split.by.label(),
1196 seed: split.seed,
1197 complement_path: train_path.clone(),
1198 }),
1199 )?;
1200 Ok(train_manifest)
1201 }
1202}
1203
1204#[cfg(test)]
1205mod tests {
1206 use super::*;
1207 use crate::record::LIFECYCLE_ACTIVE;
1208 use crate::store::{ContextStore, ContextStoreOptions};
1209 use chrono::TimeZone;
1210 use serde_json::json;
1211 use tempfile::TempDir;
1212
1213 const DIM: i32 = 4;
1214
1215 async fn open_store(dir: &TempDir) -> ContextStore {
1216 let uri = dir.path().join("ctx.lance").to_string_lossy().to_string();
1217 ContextStore::open_with_options(
1218 &uri,
1219 ContextStoreOptions {
1220 embedding_dim: Some(DIM),
1221 ..Default::default()
1222 },
1223 )
1224 .await
1225 .unwrap()
1226 }
1227
1228 fn rec(id: &str, role: &str, text: &str, secs: i64) -> ContextRecord {
1229 ContextRecord {
1230 id: id.to_string(),
1231 external_id: None,
1232 run_id: "run".to_string(),
1233 bot_id: None,
1234 session_id: Some("s1".to_string()),
1235 tenant: None,
1236 source: None,
1237 created_at: Utc.timestamp_opt(1_700_000_000 + secs, 0).unwrap(),
1238 role: role.to_string(),
1239 state_metadata: None,
1240 metadata: None,
1241 relationships: Vec::new(),
1242 expires_at: None,
1243 retention_policy: None,
1244 lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
1245 retired_at: None,
1246 retired_reason: None,
1247 supersedes_id: None,
1248 superseded_by_id: None,
1249 content_type: "text/plain".to_string(),
1250 text_payload: Some(text.to_string()),
1251 binary_payload: None,
1252 embedding: None,
1253 }
1254 }
1255
1256 fn emb(lead: &[f32]) -> Vec<f32> {
1257 let mut v = vec![0.0f32; DIM as usize];
1258 for (i, x) in lead.iter().enumerate() {
1259 v[i] = *x;
1260 }
1261 v
1262 }
1263
1264 fn read_lines(path: &str) -> Vec<Value> {
1265 std::fs::read_to_string(path)
1266 .unwrap()
1267 .lines()
1268 .filter(|l| !l.trim().is_empty())
1269 .map(|l| serde_json::from_str(l).unwrap())
1270 .collect()
1271 }
1272
1273 fn read_manifest(path: &str) -> ExportManifest {
1274 let raw = std::fs::read_to_string(format!("{path}.manifest.json")).unwrap();
1275 serde_json::from_str(&raw).unwrap()
1276 }
1277
1278 fn out_path(dir: &TempDir) -> String {
1279 dir.path().join("out.jsonl").to_string_lossy().to_string()
1280 }
1281
1282 #[test]
1283 fn sft_groups_session_into_ordered_conversation() {
1284 let dir = TempDir::new().unwrap();
1285 let runtime = tokio::runtime::Runtime::new().unwrap();
1286 runtime.block_on(async {
1287 let mut store = open_store(&dir).await;
1288 store
1290 .add(&[
1291 rec("r2", "assistant", "hi there", 2),
1292 rec("r1", "user", "hello", 1),
1293 rec("r3", "user", "bye", 3),
1294 ])
1295 .await
1296 .unwrap();
1297
1298 let out = out_path(&dir);
1299 let manifest = store
1300 .export_training(
1301 &ExportConfig {
1302 task: ExportTask::Sft,
1303 group_by: GroupBy::SessionId,
1304 ..Default::default()
1305 },
1306 &out,
1307 )
1308 .await
1309 .unwrap();
1310
1311 let lines = read_lines(&out);
1312 assert_eq!(lines.len(), 1);
1313 let messages = lines[0]["messages"].as_array().unwrap();
1314 let contents: Vec<&str> = messages
1315 .iter()
1316 .map(|m| m["content"].as_str().unwrap())
1317 .collect();
1318 assert_eq!(contents, ["hello", "hi there", "bye"]);
1319 assert_eq!(manifest.task, "sft");
1320 assert_eq!(manifest.counts.examples, 1);
1321 assert_eq!(lines[0]["provenance"]["version"], json!(manifest.version));
1322
1323 let manifest_file = read_manifest(&out);
1325 assert_eq!(manifest_file.task, "sft");
1326 assert_eq!(manifest_file.counts.examples, 1);
1327 assert_eq!(manifest_file.schema_version, EXPORT_SCHEMA_VERSION);
1328 assert_eq!(manifest_file.source_record_ids.len(), 3);
1329 });
1330 }
1331
1332 #[test]
1333 fn sft_rejection_sampling_filters_low_reward() {
1334 let dir = TempDir::new().unwrap();
1335 let runtime = tokio::runtime::Runtime::new().unwrap();
1336 runtime.block_on(async {
1337 let mut store = open_store(&dir).await;
1338 let mut good = rec("g", "assistant", "good", 1);
1339 good.session_id = Some("a".to_string());
1340 good.metadata = Some(json!({"reward": 0.9}));
1341 let mut bad = rec("b", "assistant", "bad", 2);
1342 bad.session_id = Some("b".to_string());
1343 bad.metadata = Some(json!({"reward": 0.1}));
1344 store.add(&[good, bad]).await.unwrap();
1345
1346 let out = out_path(&dir);
1347 let manifest = store
1348 .export_training(
1349 &ExportConfig {
1350 task: ExportTask::Sft,
1351 group_by: GroupBy::SessionId,
1352 min_reward: Some(0.5),
1353 ..Default::default()
1354 },
1355 &out,
1356 )
1357 .await
1358 .unwrap();
1359
1360 let lines = read_lines(&out);
1361 assert_eq!(lines.len(), 1, "only the high-reward record survives");
1362 assert_eq!(lines[0]["messages"][0]["content"], "good");
1363 assert_eq!(manifest.counts.after_reward_filter, 1);
1364 assert_eq!(manifest.min_reward, Some(0.5));
1365 });
1366 }
1367
1368 #[test]
1369 fn preference_paired_uses_reward_for_chosen_rejected() {
1370 let dir = TempDir::new().unwrap();
1371 let runtime = tokio::runtime::Runtime::new().unwrap();
1372 runtime.block_on(async {
1373 let mut store = open_store(&dir).await;
1374 let prompt = rec("p", "user", "question", 1);
1375 let mut hi = rec("hi", "assistant", "great answer", 2);
1376 hi.metadata = Some(json!({"reward": 0.9}));
1377 let mut lo = rec("lo", "assistant", "poor answer", 3);
1378 lo.metadata = Some(json!({"reward": 0.2}));
1379 store.add(&[prompt, hi, lo]).await.unwrap();
1380
1381 let out = out_path(&dir);
1382 store
1383 .export_training(
1384 &ExportConfig {
1385 task: ExportTask::Preference,
1386 preference_form: PreferenceForm::Paired,
1387 group_by: GroupBy::SessionId,
1388 ..Default::default()
1389 },
1390 &out,
1391 )
1392 .await
1393 .unwrap();
1394
1395 let lines = read_lines(&out);
1396 assert_eq!(lines.len(), 1);
1397 assert_eq!(lines[0]["form"], "paired");
1398 assert_eq!(lines[0]["prompt"][0]["content"], "question");
1399 assert_eq!(lines[0]["chosen"][0]["content"], "great answer");
1400 assert_eq!(lines[0]["rejected"][0]["content"], "poor answer");
1401 });
1402 }
1403
1404 #[test]
1405 fn preference_unpaired_uses_kto_labels() {
1406 let dir = TempDir::new().unwrap();
1407 let runtime = tokio::runtime::Runtime::new().unwrap();
1408 runtime.block_on(async {
1409 let mut store = open_store(&dir).await;
1410 let prompt = rec("p", "user", "q", 1);
1411 let mut a = rec("a", "assistant", "yes", 2);
1412 a.metadata = Some(json!({"label": "chosen"}));
1413 let mut b = rec("b", "assistant", "no", 3);
1414 b.metadata = Some(json!({"label": "rejected"}));
1415 store.add(&[prompt, a, b]).await.unwrap();
1416
1417 let out = out_path(&dir);
1418 store
1419 .export_training(
1420 &ExportConfig {
1421 task: ExportTask::Preference,
1422 preference_form: PreferenceForm::Unpaired,
1423 group_by: GroupBy::SessionId,
1424 ..Default::default()
1425 },
1426 &out,
1427 )
1428 .await
1429 .unwrap();
1430
1431 let lines = read_lines(&out);
1432 assert_eq!(lines.len(), 2);
1433 assert!(lines.iter().all(|l| l["form"] == "unpaired"));
1434 let chosen = lines
1435 .iter()
1436 .find(|l| l["completion"][0]["content"] == "yes")
1437 .unwrap();
1438 assert_eq!(chosen["label"], json!(true));
1439 let rejected = lines
1440 .iter()
1441 .find(|l| l["completion"][0]["content"] == "no")
1442 .unwrap();
1443 assert_eq!(rejected["label"], json!(false));
1444 });
1445 }
1446
1447 #[test]
1448 fn preference_ranked_orders_by_rank() {
1449 let dir = TempDir::new().unwrap();
1450 let runtime = tokio::runtime::Runtime::new().unwrap();
1451 runtime.block_on(async {
1452 let mut store = open_store(&dir).await;
1453 let prompt = rec("p", "user", "q", 1);
1454 let mut second = rec("c2", "assistant", "second", 2);
1455 second.metadata = Some(json!({"rank": 2}));
1456 let mut first = rec("c1", "assistant", "first", 3);
1457 first.metadata = Some(json!({"rank": 1}));
1458 store.add(&[prompt, second, first]).await.unwrap();
1459
1460 let out = out_path(&dir);
1461 store
1462 .export_training(
1463 &ExportConfig {
1464 task: ExportTask::Preference,
1465 preference_form: PreferenceForm::Ranked,
1466 group_by: GroupBy::SessionId,
1467 ..Default::default()
1468 },
1469 &out,
1470 )
1471 .await
1472 .unwrap();
1473
1474 let lines = read_lines(&out);
1475 assert_eq!(lines.len(), 1);
1476 assert_eq!(lines[0]["form"], "ranked");
1477 let cands = lines[0]["candidates"].as_array().unwrap();
1478 assert_eq!(cands[0]["messages"][0]["content"], "first");
1479 assert_eq!(cands[0]["rank"], json!(1));
1480 assert_eq!(cands[1]["messages"][0]["content"], "second");
1481 });
1482 }
1483
1484 #[test]
1485 fn rollout_groups_responses_with_rewards() {
1486 let dir = TempDir::new().unwrap();
1487 let runtime = tokio::runtime::Runtime::new().unwrap();
1488 runtime.block_on(async {
1489 let mut store = open_store(&dir).await;
1490 let prompt = rec("p", "user", "solve x", 1);
1491 let mut r1 = rec("r1", "assistant", "ans1", 2);
1492 r1.metadata =
1493 Some(json!({"reward": 1.0, "reward_source": "verifier", "group_id": "g1"}));
1494 let mut r2 = rec("r2", "assistant", "ans2", 3);
1495 r2.metadata =
1496 Some(json!({"reward": 0.0, "reward_source": "verifier", "group_id": "g1"}));
1497 store.add(&[prompt, r1, r2]).await.unwrap();
1498
1499 let out = out_path(&dir);
1500 store
1501 .export_training(
1502 &ExportConfig {
1503 task: ExportTask::Rollout,
1504 group_by: GroupBy::SessionId,
1505 ..Default::default()
1506 },
1507 &out,
1508 )
1509 .await
1510 .unwrap();
1511
1512 let lines = read_lines(&out);
1513 assert_eq!(lines.len(), 1);
1514 assert_eq!(lines[0]["group_id"], "g1");
1515 assert_eq!(lines[0]["prompt"][0]["content"], "solve x");
1516 let responses = lines[0]["responses"].as_array().unwrap();
1517 assert_eq!(responses.len(), 2);
1518 assert_eq!(responses[0]["reward"], json!(1.0));
1519 assert_eq!(responses[0]["reward_source"], "verifier");
1520 });
1521 }
1522
1523 #[test]
1524 fn dedup_collapses_near_duplicates() {
1525 let dir = TempDir::new().unwrap();
1526 let runtime = tokio::runtime::Runtime::new().unwrap();
1527 runtime.block_on(async {
1528 let mut store = open_store(&dir).await;
1529 let mut a = rec("a", "user", "dup one", 1);
1530 a.session_id = Some("a".to_string());
1531 a.embedding = Some(emb(&[1.0, 0.0]));
1532 let mut b = rec("b", "user", "dup two", 2);
1533 b.session_id = Some("b".to_string());
1534 b.embedding = Some(emb(&[1.0, 0.0])); store.add(&[a, b]).await.unwrap();
1536
1537 let out = out_path(&dir);
1538 let manifest = store
1539 .export_training(
1540 &ExportConfig {
1541 task: ExportTask::Sft,
1542 group_by: GroupBy::SessionId,
1543 dedup_threshold: Some(0.01),
1544 ..Default::default()
1545 },
1546 &out,
1547 )
1548 .await
1549 .unwrap();
1550
1551 assert_eq!(
1552 manifest.counts.after_dedup, 1,
1553 "one near-duplicate collapsed"
1554 );
1555 assert_eq!(read_lines(&out).len(), 1);
1556 });
1557 }
1558
1559 #[test]
1560 fn decontaminate_drops_holdout_matches() {
1561 let dir = TempDir::new().unwrap();
1562 let runtime = tokio::runtime::Runtime::new().unwrap();
1563 runtime.block_on(async {
1564 let mut store = open_store(&dir).await;
1565 let mut keep = rec("k", "user", "keep", 1);
1566 keep.session_id = Some("k".to_string());
1567 keep.embedding = Some(emb(&[0.0, 1.0]));
1568 let mut leak = rec("l", "user", "leak", 2);
1569 leak.session_id = Some("l".to_string());
1570 leak.embedding = Some(emb(&[1.0, 0.0]));
1571 store.add(&[keep, leak]).await.unwrap();
1572
1573 let out = out_path(&dir);
1574 let manifest = store
1575 .export_training(
1576 &ExportConfig {
1577 task: ExportTask::Sft,
1578 group_by: GroupBy::SessionId,
1579 decontaminate_against: vec![emb(&[1.0, 0.0])], decontaminate_threshold: Some(0.01),
1581 ..Default::default()
1582 },
1583 &out,
1584 )
1585 .await
1586 .unwrap();
1587
1588 assert_eq!(manifest.counts.after_decontaminate, 1);
1589 let lines = read_lines(&out);
1590 assert_eq!(lines.len(), 1);
1591 assert_eq!(lines[0]["messages"][0]["content"], "keep");
1592 });
1593 }
1594
1595 #[test]
1596 fn curation_drops_contradicted_records() {
1597 let dir = TempDir::new().unwrap();
1598 let runtime = tokio::runtime::Runtime::new().unwrap();
1599 runtime.block_on(async {
1600 let mut store = open_store(&dir).await;
1601 let good = rec("g", "user", "valid", 1);
1602 let mut bad = rec("c", "user", "contradicted", 2);
1603 bad.session_id = Some("other".to_string());
1604 bad.lifecycle_status = LIFECYCLE_CONTRADICTED.to_string();
1605 store.add(&[good, bad]).await.unwrap();
1606
1607 let out = out_path(&dir);
1608 let manifest = store
1609 .export_training(&ExportConfig::default(), &out)
1610 .await
1611 .unwrap();
1612
1613 assert_eq!(manifest.counts.after_lifecycle, 1);
1614 let lines = read_lines(&out);
1615 assert!(lines
1616 .iter()
1617 .all(|l| l["messages"][0]["content"] != "contradicted"));
1618 });
1619 }
1620
1621 #[test]
1622 fn version_pinning_exports_old_state_and_restores() {
1623 let dir = TempDir::new().unwrap();
1624 let runtime = tokio::runtime::Runtime::new().unwrap();
1625 runtime.block_on(async {
1626 let mut store = open_store(&dir).await;
1627 store.add(&[rec("r1", "user", "first", 1)]).await.unwrap();
1628 store.compact(None).await.unwrap();
1629 let pinned = store.version();
1630
1631 store.add(&[rec("r2", "user", "second", 2)]).await.unwrap();
1632 store.compact(None).await.unwrap();
1633 let latest = store.version();
1634
1635 let out = out_path(&dir);
1636 let manifest = store
1637 .export_training(
1638 &ExportConfig {
1639 task: ExportTask::Sft,
1640 group_by: GroupBy::None,
1641 version: Some(pinned),
1642 ..Default::default()
1643 },
1644 &out,
1645 )
1646 .await
1647 .unwrap();
1648
1649 assert_eq!(manifest.version, pinned);
1650 assert_eq!(
1651 store.version(),
1652 latest,
1653 "store restored after pinned export"
1654 );
1655 });
1656 }
1657
1658 #[test]
1659 fn export_is_reproducible() {
1660 let dir = TempDir::new().unwrap();
1661 let runtime = tokio::runtime::Runtime::new().unwrap();
1662 runtime.block_on(async {
1663 let mut store = open_store(&dir).await;
1664 store
1665 .add(&[
1666 rec("r1", "user", "a", 1),
1667 rec("r2", "assistant", "b", 2),
1668 rec("r3", "user", "c", 3),
1669 ])
1670 .await
1671 .unwrap();
1672
1673 let config = ExportConfig {
1674 task: ExportTask::Sft,
1675 group_by: GroupBy::SessionId,
1676 ..Default::default()
1677 };
1678 let first = out_path(&dir);
1679 let second = dir.path().join("out2.jsonl").to_string_lossy().to_string();
1680 store.export_training(&config, &first).await.unwrap();
1681 store.export_training(&config, &second).await.unwrap();
1682
1683 assert_eq!(
1684 std::fs::read_to_string(&first).unwrap(),
1685 std::fs::read_to_string(&second).unwrap(),
1686 "same version + config produces identical output"
1687 );
1688 });
1689 }
1690
1691 #[test]
1692 fn external_id_prefix_grouping() {
1693 let dir = TempDir::new().unwrap();
1694 let runtime = tokio::runtime::Runtime::new().unwrap();
1695 runtime.block_on(async {
1696 let mut store = open_store(&dir).await;
1697 let mut a = rec("a", "user", "doc7 turn1", 1);
1698 a.external_id = Some("doc-7#chunk-1".to_string());
1699 let mut b = rec("b", "assistant", "doc7 turn2", 2);
1700 b.external_id = Some("doc-7#chunk-2".to_string());
1701 let mut c = rec("c", "user", "doc8 turn1", 3);
1702 c.external_id = Some("doc-8#chunk-1".to_string());
1703 store.add(&[a, b, c]).await.unwrap();
1704
1705 let out = out_path(&dir);
1706 store
1707 .export_training(
1708 &ExportConfig {
1709 task: ExportTask::Sft,
1710 group_by: GroupBy::ExternalIdPrefix("#".to_string()),
1711 ..Default::default()
1712 },
1713 &out,
1714 )
1715 .await
1716 .unwrap();
1717
1718 let lines = read_lines(&out);
1719 assert_eq!(lines.len(), 2);
1721 assert_eq!(lines[0]["messages"].as_array().unwrap().len(), 2);
1722 });
1723 }
1724
1725 fn session_ids(path: &str) -> std::collections::HashSet<String> {
1726 read_lines(path)
1727 .iter()
1728 .filter_map(|line| {
1729 line["provenance"]["session_id"]
1730 .as_str()
1731 .map(str::to_string)
1732 })
1733 .collect()
1734 }
1735
1736 #[test]
1737 fn split_is_deterministic_and_group_disjoint() {
1738 let dir = TempDir::new().unwrap();
1739 let runtime = tokio::runtime::Runtime::new().unwrap();
1740 runtime.block_on(async {
1741 let mut store = open_store(&dir).await;
1742 let mut records = Vec::new();
1743 for s in 0..10 {
1744 let mut user = rec(&format!("u{s}"), "user", "q", s * 2);
1745 user.session_id = Some(format!("s{s}"));
1746 let mut asst = rec(&format!("a{s}"), "assistant", "r", s * 2 + 1);
1747 asst.session_id = Some(format!("s{s}"));
1748 records.push(user);
1749 records.push(asst);
1750 }
1751 store.add(&records).await.unwrap();
1752
1753 let config = ExportConfig {
1754 task: ExportTask::Sft,
1755 group_by: GroupBy::SessionId,
1756 split: Some(SplitConfig {
1757 eval_fraction: 0.5,
1758 by: GroupBy::SessionId,
1759 seed: 42,
1760 }),
1761 ..Default::default()
1762 };
1763
1764 let base1 = dir.path().join("a.jsonl").to_string_lossy().to_string();
1765 let base2 = dir.path().join("b.jsonl").to_string_lossy().to_string();
1766 store.export_training(&config, &base1).await.unwrap();
1767 store.export_training(&config, &base2).await.unwrap();
1768
1769 let (train1, eval1) = split_paths(&base1);
1770 let (train2, eval2) = split_paths(&base2);
1771
1772 assert_eq!(
1774 std::fs::read_to_string(&train1).unwrap(),
1775 std::fs::read_to_string(&train2).unwrap()
1776 );
1777 assert_eq!(
1778 std::fs::read_to_string(&eval1).unwrap(),
1779 std::fs::read_to_string(&eval2).unwrap()
1780 );
1781
1782 let train_sessions = session_ids(&train1);
1784 let eval_sessions = session_ids(&eval1);
1785 assert!(!train_sessions.is_empty() && !eval_sessions.is_empty());
1786 assert!(train_sessions.is_disjoint(&eval_sessions));
1787 assert_eq!(train_sessions.len() + eval_sessions.len(), 10);
1788 });
1789 }
1790
1791 #[test]
1792 fn split_fraction_is_approximately_respected() {
1793 let dir = TempDir::new().unwrap();
1794 let runtime = tokio::runtime::Runtime::new().unwrap();
1795 runtime.block_on(async {
1796 let mut store = open_store(&dir).await;
1797 let mut records = Vec::new();
1798 for s in 0..200 {
1799 let mut r = rec(&format!("r{s}"), "user", "q", s);
1800 r.session_id = Some(format!("s{s}"));
1801 records.push(r);
1802 }
1803 store.add(&records).await.unwrap();
1804
1805 let config = ExportConfig {
1806 task: ExportTask::Sft,
1807 group_by: GroupBy::SessionId,
1808 split: Some(SplitConfig {
1809 eval_fraction: 0.25,
1810 by: GroupBy::SessionId,
1811 seed: 7,
1812 }),
1813 ..Default::default()
1814 };
1815 let base = dir.path().join("c.jsonl").to_string_lossy().to_string();
1816 store.export_training(&config, &base).await.unwrap();
1817
1818 let (train, eval) = split_paths(&base);
1819 let eval_count = read_lines(&eval).len();
1820 let train_count = read_lines(&train).len();
1821 assert_eq!(train_count + eval_count, 200);
1822 let fraction = eval_count as f64 / 200.0;
1823 assert!(
1824 (fraction - 0.25).abs() < 0.1,
1825 "eval fraction {fraction} too far from 0.25"
1826 );
1827 });
1828 }
1829
1830 #[test]
1831 fn split_manifests_record_params_and_complement() {
1832 let dir = TempDir::new().unwrap();
1833 let runtime = tokio::runtime::Runtime::new().unwrap();
1834 runtime.block_on(async {
1835 let mut store = open_store(&dir).await;
1836 let mut a = rec("a", "user", "x", 1);
1837 a.session_id = Some("s1".to_string());
1838 let mut b = rec("b", "user", "y", 2);
1839 b.session_id = Some("s2".to_string());
1840 store.add(&[a, b]).await.unwrap();
1841
1842 let config = ExportConfig {
1843 task: ExportTask::Sft,
1844 group_by: GroupBy::SessionId,
1845 split: Some(SplitConfig {
1846 eval_fraction: 0.5,
1847 by: GroupBy::SessionId,
1848 seed: 99,
1849 }),
1850 ..Default::default()
1851 };
1852 let base = dir.path().join("d.jsonl").to_string_lossy().to_string();
1853 store.export_training(&config, &base).await.unwrap();
1854
1855 let (train, eval) = split_paths(&base);
1856 assert!(std::fs::metadata(&train).is_ok());
1857 assert!(std::fs::metadata(&eval).is_ok());
1858
1859 let train_manifest = read_manifest(&train);
1860 let split = train_manifest.split.unwrap();
1861 assert_eq!(split.side, "train");
1862 assert_eq!(split.seed, 99);
1863 assert_eq!(split.eval_fraction, 0.5);
1864 assert_eq!(split.by, "session_id");
1865 assert_eq!(split.complement_path, eval);
1866
1867 let eval_manifest = read_manifest(&eval);
1868 assert_eq!(eval_manifest.split.unwrap().side, "eval");
1869 });
1870 }
1871
1872 fn read_stats(path: &str) -> ExportStats {
1873 let raw = std::fs::read_to_string(format!("{path}.stats.json")).unwrap();
1874 serde_json::from_str(&raw).unwrap()
1875 }
1876
1877 #[test]
1878 fn stats_report_counts_roles_tokens_and_exclusions() {
1879 let dir = TempDir::new().unwrap();
1880 let runtime = tokio::runtime::Runtime::new().unwrap();
1881 runtime.block_on(async {
1882 let mut store = open_store(&dir).await;
1883 let mut user = rec("u", "user", "hello there", 1);
1884 user.source = Some("memory".to_string());
1885 user.tenant = Some("acme".to_string());
1886 user.state_metadata = Some(crate::record::StateMetadata {
1887 tokens_used: Some(5),
1888 ..Default::default()
1889 });
1890 let mut asst = rec("a", "assistant", "hi", 2);
1891 asst.source = Some("memory".to_string());
1892 asst.tenant = Some("acme".to_string());
1893 asst.state_metadata = Some(crate::record::StateMetadata {
1894 tokens_used: Some(11),
1895 ..Default::default()
1896 });
1897 let mut dropped = rec("d", "user", "nope", 3);
1899 dropped.session_id = Some("other".to_string());
1900 dropped.lifecycle_status = LIFECYCLE_CONTRADICTED.to_string();
1901 store.add(&[user, asst, dropped]).await.unwrap();
1902
1903 let out = out_path(&dir);
1904 store
1905 .export_training(
1906 &ExportConfig {
1907 task: ExportTask::Sft,
1908 group_by: GroupBy::SessionId,
1909 emit_stats: true,
1910 ..Default::default()
1911 },
1912 &out,
1913 )
1914 .await
1915 .unwrap();
1916
1917 let stats = read_stats(&out);
1918 assert_eq!(stats.task, "sft");
1919 assert_eq!(stats.examples, 1);
1920 assert_eq!(stats.num_groups, 1);
1921 assert_eq!(stats.by_role.get("user"), Some(&1));
1922 assert_eq!(stats.by_role.get("assistant"), Some(&1));
1923 assert_eq!(stats.by_source.get("memory"), Some(&2));
1924 assert_eq!(stats.by_tenant.get("acme"), Some(&2));
1925 assert_eq!(stats.excluded.lifecycle, 1, "contradicted record excluded");
1926
1927 let tokens = stats.tokens.unwrap();
1928 assert_eq!(tokens.source, "tokens_used");
1929 assert_eq!(tokens.distribution.count, 2);
1930 assert_eq!(tokens.distribution.min, 5.0);
1931 assert_eq!(tokens.distribution.max, 11.0);
1932 assert_eq!(tokens.distribution.mean, 8.0);
1933 });
1934 }
1935
1936 #[test]
1937 fn stats_token_fallback_uses_length_proxy() {
1938 let dir = TempDir::new().unwrap();
1939 let runtime = tokio::runtime::Runtime::new().unwrap();
1940 runtime.block_on(async {
1941 let mut store = open_store(&dir).await;
1942 store
1944 .add(&[rec("u", "user", "one two three four", 1)])
1945 .await
1946 .unwrap();
1947
1948 let out = out_path(&dir);
1949 store
1950 .export_training(
1951 &ExportConfig {
1952 emit_stats: true,
1953 ..Default::default()
1954 },
1955 &out,
1956 )
1957 .await
1958 .unwrap();
1959
1960 let tokens = read_stats(&out).tokens.unwrap();
1961 assert_eq!(tokens.source, "length_proxy");
1962 assert_eq!(tokens.distribution.max, 4.0);
1963 });
1964 }
1965
1966 #[test]
1967 fn stats_report_reward_distribution_for_rollout() {
1968 let dir = TempDir::new().unwrap();
1969 let runtime = tokio::runtime::Runtime::new().unwrap();
1970 runtime.block_on(async {
1971 let mut store = open_store(&dir).await;
1972 let prompt = rec("p", "user", "solve", 1);
1973 let mut r1 = rec("r1", "assistant", "a1", 2);
1974 r1.metadata = Some(json!({"reward": 1.0, "reward_source": "verifier"}));
1975 let mut r2 = rec("r2", "assistant", "a2", 3);
1976 r2.metadata = Some(json!({"reward": 0.0, "reward_source": "verifier"}));
1977 store.add(&[prompt, r1, r2]).await.unwrap();
1978
1979 let out = out_path(&dir);
1980 store
1981 .export_training(
1982 &ExportConfig {
1983 task: ExportTask::Rollout,
1984 group_by: GroupBy::SessionId,
1985 emit_stats: true,
1986 ..Default::default()
1987 },
1988 &out,
1989 )
1990 .await
1991 .unwrap();
1992
1993 let stats = read_stats(&out);
1994 let reward = stats.reward.unwrap();
1995 assert_eq!(reward.count, 2);
1996 assert_eq!(reward.min, 0.0);
1997 assert_eq!(reward.max, 1.0);
1998 assert_eq!(stats.reward_sources.get("verifier"), Some(&2));
1999 });
2000 }
2001
2002 #[test]
2003 fn stats_not_written_without_flag() {
2004 let dir = TempDir::new().unwrap();
2005 let runtime = tokio::runtime::Runtime::new().unwrap();
2006 runtime.block_on(async {
2007 let mut store = open_store(&dir).await;
2008 store.add(&[rec("u", "user", "hi", 1)]).await.unwrap();
2009 let out = out_path(&dir);
2010 store
2011 .export_training(&ExportConfig::default(), &out)
2012 .await
2013 .unwrap();
2014 assert!(std::fs::metadata(format!("{out}.stats.json")).is_err());
2015 });
2016 }
2017}