Skip to main content

lance_context_core/
export.rs

1//! Curate stored [`ContextRecord`]s into trainable datasets and export them as
2//! SFT / preference / RL-rollout JSONL plus a reproducible manifest.
3//!
4//! This is the downstream half of the post-training pipeline: raw logs are
5//! ingested into faithful `ContextRecord`s, then *curated* (lifecycle-correct
6//! filtering, semantic dedup, decontamination, reward thresholding), grouped
7//! into conversations/trajectories, and *exported* into the shape a trainer
8//! expects.
9//!
10//! ## Field conventions
11//!
12//! Reward / preference signals live in each record's free-form `metadata`
13//! object:
14//! - `metadata.reward` (number) — scalar reward / verifier score.
15//! - `metadata.reward_source` (string) — e.g. `"verifier"`, `"tests"`,
16//!   `"judge"`.
17//! - `metadata.group_id` (string) — links the N samples generated for one
18//!   prompt (GRPO/RLVR groups, candidate sets).
19//! - `metadata.label` (string `"chosen"`/`"rejected"`) — unpaired (KTO-style)
20//!   preference label.
21//! - `metadata.rank` (integer, 1 = best) — N-way ranking position.
22//!
23//! Messages are built from each record's `role` + `text_payload`.
24//!
25//! ## Curation lifecycle
26//!
27//! Curation is *selection only* — it never deletes or tombstones records. By
28//! default it keeps only records visible under default lifecycle (drops
29//! tombstoned/expired/retired/superseded) and additionally drops
30//! `contradicted` records.
31
32use std::collections::{BTreeMap, HashMap};
33use std::io::{BufWriter, Write};
34
35use chrono::{DateTime, Utc};
36use lance::{Error as LanceError, Result as LanceResult};
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39
40use crate::record::{ContextRecord, LifecycleQueryOptions, RecordFilters, LIFECYCLE_CONTRADICTED};
41use crate::store::ContextStore;
42
/// Version string for the export schema; every export manifest carries it.
pub const EXPORT_SCHEMA_VERSION: &str = "1";
45
46/// Which training shape to export.
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
48#[serde(rename_all = "lowercase")]
49pub enum ExportTask {
50    /// Supervised fine-tuning (also the rejection-sampling / Best-of-N target).
51    #[default]
52    Sft,
53    /// Preference optimization (DPO/SimPO/ORPO paired, KTO unpaired, judge N-way).
54    Preference,
55    /// RL rollout (PPO/GRPO/RLVR): prompt + response group(s) with rewards.
56    Rollout,
57}
58
59impl ExportTask {
60    #[must_use]
61    pub fn as_str(self) -> &'static str {
62        match self {
63            Self::Sft => "sft",
64            Self::Preference => "preference",
65            Self::Rollout => "rollout",
66        }
67    }
68}
69
70/// Shape of a preference export (selected by the caller to match their labels).
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
72#[serde(rename_all = "lowercase")]
73pub enum PreferenceForm {
74    /// Paired `chosen`/`rejected` (DPO/SimPO/ORPO).
75    #[default]
76    Paired,
77    /// One completion with an unpaired binary label (KTO).
78    Unpaired,
79    /// N-way ranked candidate list (LLM-judge / RULER).
80    Ranked,
81}
82
/// Field that decides which records are combined into one training example.
///
/// The default is [`GroupBy::SessionId`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum GroupBy {
    /// No grouping: every record is its own example.
    None,
    /// Group by the record's `session_id`.
    #[default]
    SessionId,
    /// Group by the record's `run_id`.
    RunId,
    /// Group by the record's `tenant`.
    Tenant,
    /// Group by the record's `source`.
    Source,
    /// Group by the record's `bot_id`.
    BotId,
    /// Group by the part of `external_id` that precedes the first occurrence
    /// of the delimiter (with delimiter `"#"`, `"doc-7#chunk-1"` groups under
    /// `"doc-7"`).
    ExternalIdPrefix(String),
}
98
99impl GroupBy {
100    fn label(&self) -> String {
101        match self {
102            Self::None => "none".to_string(),
103            Self::SessionId => "session_id".to_string(),
104            Self::RunId => "run_id".to_string(),
105            Self::Tenant => "tenant".to_string(),
106            Self::Source => "source".to_string(),
107            Self::BotId => "bot_id".to_string(),
108            Self::ExternalIdPrefix(delim) => format!("external_id_prefix:{delim}"),
109        }
110    }
111
112    /// Group key for a record; `None` records become their own singleton group.
113    fn key(&self, record: &ContextRecord) -> String {
114        let value = match self {
115            Self::None => None,
116            Self::SessionId => record.session_id.clone(),
117            Self::RunId => Some(record.run_id.clone()),
118            Self::Tenant => record.tenant.clone(),
119            Self::Source => record.source.clone(),
120            Self::BotId => record.bot_id.clone(),
121            Self::ExternalIdPrefix(delim) => record.external_id.as_ref().map(|external_id| {
122                external_id
123                    .split_once(delim.as_str())
124                    .map_or(external_id.as_str(), |(prefix, _)| prefix)
125                    .to_string()
126            }),
127        };
128        value.unwrap_or_else(|| format!("__rec__{}", record.id))
129    }
130}
131
132/// Reproducible train/eval split configuration.
133///
134/// Each record is assigned to train or eval by a stable hash of its `by`
135/// grouping key plus `seed`, so no group (e.g. session) spans both sides and
136/// the same `seed` reproduces the identical partition.
137#[derive(Clone)]
138pub struct SplitConfig {
139    /// Fraction of groups assigned to the eval side, in `0.0..=1.0`.
140    pub eval_fraction: f64,
141    /// Grouping key the split is disjoint on (defaults to `session_id`).
142    pub by: GroupBy,
143    pub seed: u64,
144}
145
146impl Default for SplitConfig {
147    fn default() -> Self {
148        Self {
149            eval_fraction: 0.1,
150            by: GroupBy::SessionId,
151            seed: 0,
152        }
153    }
154}
155
156/// Curation + export configuration.
157#[derive(Clone)]
158pub struct ExportConfig {
159    pub task: ExportTask,
160    pub group_by: GroupBy,
161    pub preference_form: PreferenceForm,
162    /// Metadata/quality filters applied before lifecycle curation.
163    pub filters: Option<RecordFilters>,
164    /// Lifecycle visibility; defaults to visible-only.
165    pub lifecycle: LifecycleQueryOptions,
166    /// Collapse near-duplicates whose cosine distance is `<=` this threshold.
167    pub dedup_threshold: Option<f32>,
168    /// Holdout embeddings to decontaminate against.
169    pub decontaminate_against: Vec<Vec<f32>>,
170    /// Drop records within this cosine distance of any holdout embedding.
171    pub decontaminate_threshold: Option<f32>,
172    /// Drop records whose `metadata.reward` is present and below this value
173    /// (rejection sampling / Best-of-N / quality threshold).
174    pub min_reward: Option<f64>,
175    /// Pin the export to a dataset version (time-travel); restored afterward.
176    pub version: Option<u64>,
177    /// Optional JSON summary of the applied selectors, recorded in the manifest.
178    pub filters_summary: Option<Value>,
179    /// When set, write group-disjoint `train` / `eval` outputs instead of one.
180    pub split: Option<SplitConfig>,
181    /// When `true`, also write a `<output_path>.stats.json` dataset report.
182    pub emit_stats: bool,
183}
184
185impl Default for ExportConfig {
186    fn default() -> Self {
187        Self {
188            task: ExportTask::Sft,
189            group_by: GroupBy::SessionId,
190            preference_form: PreferenceForm::Paired,
191            filters: None,
192            lifecycle: LifecycleQueryOptions::default(),
193            dedup_threshold: None,
194            decontaminate_against: Vec::new(),
195            decontaminate_threshold: None,
196            min_reward: None,
197            version: None,
198            filters_summary: None,
199            split: None,
200            emit_stats: false,
201        }
202    }
203}
204
205/// A single chat message in a training example.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct Message {
208    pub role: String,
209    pub content: String,
210}
211
212/// Where an exported example came from, for auditability.
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct Provenance {
215    pub context_uri: String,
216    pub version: u64,
217    pub record_ids: Vec<String>,
218    #[serde(skip_serializing_if = "Vec::is_empty", default)]
219    pub external_ids: Vec<String>,
220    #[serde(skip_serializing_if = "Option::is_none", default)]
221    pub tenant: Option<String>,
222    #[serde(skip_serializing_if = "Option::is_none", default)]
223    pub source: Option<String>,
224    #[serde(skip_serializing_if = "Option::is_none", default)]
225    pub bot_id: Option<String>,
226    #[serde(skip_serializing_if = "Option::is_none", default)]
227    pub session_id: Option<String>,
228    #[serde(skip_serializing_if = "Option::is_none", default)]
229    pub run_id: Option<String>,
230    pub created_at_start: DateTime<Utc>,
231    pub created_at_end: DateTime<Utc>,
232}
233
234/// SFT example: an ordered message list (doubles as the rejection-sampling /
235/// Best-of-N target when filtered by `min_reward`).
236#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct SftExample {
238    pub messages: Vec<Message>,
239    #[serde(skip_serializing_if = "Option::is_none", default)]
240    pub reward: Option<f64>,
241    pub provenance: Provenance,
242}
243
244/// One ranked candidate in an N-way preference example.
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct RankedCandidate {
247    pub messages: Vec<Message>,
248    pub rank: i64,
249    #[serde(skip_serializing_if = "Option::is_none", default)]
250    pub reward: Option<f64>,
251}
252
253/// Preference example in one of three forms, tagged by `form`.
254#[derive(Debug, Clone, Serialize, Deserialize)]
255#[serde(tag = "form", rename_all = "lowercase")]
256pub enum PreferenceExample {
257    Paired {
258        prompt: Vec<Message>,
259        chosen: Vec<Message>,
260        rejected: Vec<Message>,
261        provenance: Provenance,
262    },
263    Unpaired {
264        prompt: Vec<Message>,
265        completion: Vec<Message>,
266        /// `true` = chosen, `false` = rejected.
267        label: bool,
268        provenance: Provenance,
269    },
270    Ranked {
271        prompt: Vec<Message>,
272        candidates: Vec<RankedCandidate>,
273        provenance: Provenance,
274    },
275}
276
277/// One response within a rollout example.
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct RolloutResponse {
280    pub messages: Vec<Message>,
281    #[serde(skip_serializing_if = "Option::is_none", default)]
282    pub reward: Option<f64>,
283    #[serde(skip_serializing_if = "Option::is_none", default)]
284    pub reward_source: Option<String>,
285}
286
287/// RL rollout example: a prompt plus one or more rewarded responses, with an
288/// optional `group_id` linking the N samples for one prompt.
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct RolloutExample {
291    pub prompt: Vec<Message>,
292    pub responses: Vec<RolloutResponse>,
293    #[serde(skip_serializing_if = "Option::is_none", default)]
294    pub group_id: Option<String>,
295    pub provenance: Provenance,
296}
297
298/// Record counts at each curation stage.
299#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
300pub struct ExportCounts {
301    pub input_records: usize,
302    pub after_lifecycle: usize,
303    pub after_dedup: usize,
304    pub after_decontaminate: usize,
305    pub after_reward_filter: usize,
306    pub examples: usize,
307}
308
309/// Records which side of a train/eval split an output is, and the parameters
310/// needed to reproduce the partition.
311#[derive(Debug, Clone, Serialize, Deserialize)]
312pub struct SplitManifest {
313    /// `"train"` or `"eval"`.
314    pub side: String,
315    pub eval_fraction: f64,
316    pub by: String,
317    pub seed: u64,
318    /// Path of the complementary (other side's) output.
319    pub complement_path: String,
320}
321
322/// Reproducible description of an export, written next to the JSONL output.
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct ExportManifest {
325    pub context_uri: String,
326    pub version: u64,
327    pub task: String,
328    pub group_by: String,
329    pub schema_version: String,
330    #[serde(skip_serializing_if = "Option::is_none", default)]
331    pub preference_form: Option<String>,
332    #[serde(skip_serializing_if = "Option::is_none", default)]
333    pub filters: Option<Value>,
334    #[serde(skip_serializing_if = "Option::is_none", default)]
335    pub dedup_threshold: Option<f32>,
336    #[serde(skip_serializing_if = "Option::is_none", default)]
337    pub decontaminate_threshold: Option<f32>,
338    #[serde(skip_serializing_if = "Option::is_none", default)]
339    pub min_reward: Option<f64>,
340    #[serde(skip_serializing_if = "Option::is_none", default)]
341    pub split: Option<SplitManifest>,
342    #[serde(skip_serializing_if = "Option::is_none", default)]
343    pub created_at_start: Option<DateTime<Utc>>,
344    #[serde(skip_serializing_if = "Option::is_none", default)]
345    pub created_at_end: Option<DateTime<Utc>>,
346    pub source_record_ids: Vec<String>,
347    pub counts: ExportCounts,
348}
349
350/// Numeric distribution summary.
351#[derive(Debug, Clone, Default, Serialize, Deserialize)]
352pub struct Distribution {
353    pub count: usize,
354    pub min: f64,
355    pub median: f64,
356    pub p95: f64,
357    pub max: f64,
358    pub mean: f64,
359}
360
361impl Distribution {
362    fn from_sorted(mut values: Vec<f64>) -> Option<Self> {
363        if values.is_empty() {
364            return None;
365        }
366        values.sort_by(f64::total_cmp);
367        let count = values.len();
368        let sum: f64 = values.iter().sum();
369        let percentile = |p: f64| {
370            let idx = ((p * (count - 1) as f64).round() as usize).min(count - 1);
371            values[idx]
372        };
373        Some(Self {
374            count,
375            min: values[0],
376            median: percentile(0.5),
377            p95: percentile(0.95),
378            max: values[count - 1],
379            mean: sum / count as f64,
380        })
381    }
382}
383
384/// Token-length statistics and which source produced them.
385#[derive(Debug, Clone, Serialize, Deserialize)]
386pub struct TokenStats {
387    #[serde(flatten)]
388    pub distribution: Distribution,
389    /// `"tokens_used"`, `"length_proxy"`, or `"mixed"`.
390    pub source: String,
391}
392
393/// Source records excluded during curation, by reason.
394#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
395pub struct ExcludedCounts {
396    pub lifecycle: usize,
397    pub reward_threshold: usize,
398    pub dedup: usize,
399    pub decontaminate: usize,
400}
401
402/// Auditable dataset statistics for one export output, written to
403/// `<output_path>.stats.json`.
404#[derive(Debug, Clone, Serialize, Deserialize)]
405pub struct ExportStats {
406    pub task: String,
407    pub examples: usize,
408    pub num_groups: usize,
409    /// Record counts by `role`.
410    pub by_role: BTreeMap<String, usize>,
411    /// Record counts by `source` (`"__none__"` when absent).
412    pub by_source: BTreeMap<String, usize>,
413    /// Record counts by `tenant` (`"__none__"` when absent).
414    pub by_tenant: BTreeMap<String, usize>,
415    pub records_per_group: Distribution,
416    #[serde(skip_serializing_if = "Option::is_none", default)]
417    pub tokens: Option<TokenStats>,
418    pub excluded: ExcludedCounts,
419    #[serde(skip_serializing_if = "Option::is_none", default)]
420    pub preference_form: Option<String>,
421    /// Reward distribution over records carrying `metadata.reward`.
422    #[serde(skip_serializing_if = "Option::is_none", default)]
423    pub reward: Option<Distribution>,
424    /// Counts by `metadata.reward_source`.
425    #[serde(skip_serializing_if = "BTreeMap::is_empty", default)]
426    pub reward_sources: BTreeMap<String, usize>,
427}
428
429// ----- metadata helpers -----------------------------------------------------
430
431fn metadata_field<'a>(record: &'a ContextRecord, key: &str) -> Option<&'a Value> {
432    record.metadata.as_ref()?.get(key)
433}
434
435fn record_reward(record: &ContextRecord) -> Option<f64> {
436    metadata_field(record, "reward")?.as_f64()
437}
438
439fn record_reward_source(record: &ContextRecord) -> Option<String> {
440    Some(
441        metadata_field(record, "reward_source")?
442            .as_str()?
443            .to_string(),
444    )
445}
446
447fn record_group_id(record: &ContextRecord) -> Option<String> {
448    Some(metadata_field(record, "group_id")?.as_str()?.to_string())
449}
450
451fn record_label(record: &ContextRecord) -> Option<String> {
452    Some(metadata_field(record, "label")?.as_str()?.to_string())
453}
454
455fn record_rank(record: &ContextRecord) -> Option<i64> {
456    metadata_field(record, "rank")?.as_i64()
457}
458
459fn message_of(record: &ContextRecord) -> Message {
460    Message {
461        role: record.role.clone(),
462        content: record.text_payload.clone().unwrap_or_default(),
463    }
464}
465
466fn is_assistant(record: &ContextRecord) -> bool {
467    record.role.eq_ignore_ascii_case("assistant")
468}
469
/// Cosine distance in `0.0..=2.0`; `0` = identical direction.
///
/// Returns `None` for empty, mismatched-length, or zero-magnitude vectors,
/// and for inputs whose cosine is not finite (NaN / infinite components).
///
/// Sums are accumulated in `f64` and the result is clamped. With `f32`
/// accumulation, rounding could report a tiny non-zero (or negative) distance
/// for identical vectors, so a threshold of `0.0` could miss exact duplicates.
fn cosine_distance(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    // One square root over the product, so identical inputs divide `x` by
    // `sqrt(x * x)` and get a cosine of exactly 1.
    let cosine = dot / (norm_a * norm_b).sqrt();
    if !cosine.is_finite() {
        return None;
    }
    Some((1.0 - cosine).clamp(0.0, 2.0) as f32)
}
489
490// ----- grouping -------------------------------------------------------------
491
492/// Records grouped for one example, already ordered by `(created_at, id)`.
493struct Group {
494    key: String,
495    records: Vec<ContextRecord>,
496}
497
498fn group_records(records: Vec<ContextRecord>, group_by: &GroupBy) -> Vec<Group> {
499    let mut order: Vec<String> = Vec::new();
500    let mut groups: HashMap<String, Vec<ContextRecord>> = HashMap::new();
501    for record in records {
502        let key = group_by.key(&record);
503        if !groups.contains_key(&key) {
504            order.push(key.clone());
505        }
506        groups.entry(key).or_default().push(record);
507    }
508
509    let mut result: Vec<Group> = order
510        .into_iter()
511        .map(|key| {
512            let mut records = groups.remove(&key).unwrap_or_default();
513            records.sort_by(|a, b| {
514                a.created_at
515                    .cmp(&b.created_at)
516                    .then_with(|| a.id.cmp(&b.id))
517            });
518            Group { key, records }
519        })
520        .collect();
521
522    // Deterministic group order: by each group's earliest (created_at, id).
523    result.sort_by(|a, b| {
524        let left = a.records.first();
525        let right = b.records.first();
526        match (left, right) {
527            (Some(l), Some(r)) => l
528                .created_at
529                .cmp(&r.created_at)
530                .then_with(|| l.id.cmp(&r.id)),
531            _ => a.key.cmp(&b.key),
532        }
533    });
534    result
535}
536
537fn provenance_for(records: &[ContextRecord], context_uri: &str, version: u64) -> Provenance {
538    let first = records.first();
539    let created_at_start = records
540        .iter()
541        .map(|r| r.created_at)
542        .min()
543        .unwrap_or_else(Utc::now);
544    let created_at_end = records
545        .iter()
546        .map(|r| r.created_at)
547        .max()
548        .unwrap_or(created_at_start);
549    Provenance {
550        context_uri: context_uri.to_string(),
551        version,
552        record_ids: records.iter().map(|r| r.id.clone()).collect(),
553        external_ids: records
554            .iter()
555            .filter_map(|r| r.external_id.clone())
556            .collect(),
557        tenant: first.and_then(|r| r.tenant.clone()),
558        source: first.and_then(|r| r.source.clone()),
559        bot_id: first.and_then(|r| r.bot_id.clone()),
560        session_id: first.and_then(|r| r.session_id.clone()),
561        run_id: first.map(|r| r.run_id.clone()),
562        created_at_start,
563        created_at_end,
564    }
565}
566
567/// Split a group into leading non-assistant prompt messages and the assistant
568/// candidate records that follow.
569fn split_prompt_candidates(records: &[ContextRecord]) -> (Vec<Message>, Vec<&ContextRecord>) {
570    let mut prompt = Vec::new();
571    let mut candidates = Vec::new();
572    for record in records {
573        if is_assistant(record) {
574            candidates.push(record);
575        } else if candidates.is_empty() {
576            prompt.push(message_of(record));
577        } else {
578            // Non-assistant turn after a candidate (e.g. tool result) — keep it
579            // attached to the conversation prompt context.
580            prompt.push(message_of(record));
581        }
582    }
583    (prompt, candidates)
584}
585
586fn write_line<T: Serialize>(writer: &mut impl Write, value: &T) -> LanceResult<()> {
587    let line = serde_json::to_string(value).map_err(|err| LanceError::io(err.to_string()))?;
588    writeln!(writer, "{line}")?;
589    Ok(())
590}
591
592// ----- curation -------------------------------------------------------------
593
594/// Greedily collapse near-duplicates: keep a record only if its embedding is
595/// not within `threshold` cosine distance of an already-kept record. Records
596/// without an embedding are always kept.
597fn dedup(records: Vec<ContextRecord>, threshold: f32) -> Vec<ContextRecord> {
598    let mut kept: Vec<ContextRecord> = Vec::new();
599    for record in records {
600        let is_dup = record.embedding.as_ref().is_some_and(|embedding| {
601            kept.iter().any(|other| {
602                other
603                    .embedding
604                    .as_ref()
605                    .and_then(|other_embedding| cosine_distance(embedding, other_embedding))
606                    .is_some_and(|distance| distance <= threshold)
607            })
608        });
609        if !is_dup {
610            kept.push(record);
611        }
612    }
613    kept
614}
615
616/// Drop records whose embedding is within `threshold` cosine distance of any
617/// holdout embedding. Records without an embedding are kept.
618fn decontaminate(
619    records: Vec<ContextRecord>,
620    holdout: &[Vec<f32>],
621    threshold: f32,
622) -> Vec<ContextRecord> {
623    records
624        .into_iter()
625        .filter(|record| match &record.embedding {
626            Some(embedding) => !holdout.iter().any(|held| {
627                cosine_distance(embedding, held).is_some_and(|distance| distance <= threshold)
628            }),
629            None => true,
630        })
631        .collect()
632}
633
634/// Apply the curation pipeline (lifecycle/contradicted, reward threshold,
635/// dedup, decontamination) and return the curated records plus stage counts.
636fn curate(
637    records: Vec<ContextRecord>,
638    config: &ExportConfig,
639) -> (Vec<ContextRecord>, ExportCounts) {
640    let mut counts = ExportCounts {
641        input_records: records.len(),
642        ..ExportCounts::default()
643    };
644
645    // `list` already dropped tombstoned/expired/retired/superseded; also drop
646    // `contradicted`, which is visible by default.
647    let lifecycle: Vec<ContextRecord> = records
648        .into_iter()
649        .filter(|record| record.lifecycle_status != LIFECYCLE_CONTRADICTED)
650        .collect();
651    counts.after_lifecycle = lifecycle.len();
652
653    let rewarded: Vec<ContextRecord> = match config.min_reward {
654        Some(min) => lifecycle
655            .into_iter()
656            .filter(|record| record_reward(record).is_none_or(|reward| reward >= min))
657            .collect(),
658        None => lifecycle,
659    };
660    counts.after_reward_filter = rewarded.len();
661
662    let deduped = match config.dedup_threshold {
663        Some(threshold) => dedup(rewarded, threshold),
664        None => rewarded,
665    };
666    counts.after_dedup = deduped.len();
667
668    let clean = match config.decontaminate_threshold {
669        Some(threshold) if !config.decontaminate_against.is_empty() => {
670            decontaminate(deduped, &config.decontaminate_against, threshold)
671        }
672        _ => deduped,
673    };
674    counts.after_decontaminate = clean.len();
675
676    (clean, counts)
677}
678
679// ----- exporters ------------------------------------------------------------
680
681fn write_sft(
682    groups: &[Group],
683    writer: &mut impl Write,
684    context_uri: &str,
685    version: u64,
686) -> LanceResult<usize> {
687    let mut written = 0;
688    for group in groups {
689        if group.records.is_empty() {
690            continue;
691        }
692        let example = SftExample {
693            messages: group.records.iter().map(message_of).collect(),
694            reward: group
695                .records
696                .iter()
697                .filter_map(record_reward)
698                .reduce(f64::max),
699            provenance: provenance_for(&group.records, context_uri, version),
700        };
701        write_line(writer, &example)?;
702        written += 1;
703    }
704    Ok(written)
705}
706
707/// Preference score used to order candidates: an explicit label dominates a
708/// numeric reward, which dominates "unscored".
709fn preference_score(record: &ContextRecord) -> f64 {
710    match record_label(record).as_deref() {
711        Some("chosen") => f64::INFINITY,
712        Some("rejected") => f64::NEG_INFINITY,
713        _ => record_reward(record).unwrap_or(0.0),
714    }
715}
716
717fn unpaired_label(record: &ContextRecord, min_reward: Option<f64>) -> Option<bool> {
718    match record_label(record).as_deref() {
719        Some("chosen") => Some(true),
720        Some("rejected") => Some(false),
721        _ => match (record_reward(record), min_reward) {
722            (Some(reward), Some(min)) => Some(reward >= min),
723            _ => None,
724        },
725    }
726}
727
728fn write_preference(
729    groups: &[Group],
730    writer: &mut impl Write,
731    form: PreferenceForm,
732    min_reward: Option<f64>,
733    context_uri: &str,
734    version: u64,
735) -> LanceResult<usize> {
736    let mut written = 0;
737    for group in groups {
738        let (prompt, candidates) = split_prompt_candidates(&group.records);
739        if candidates.is_empty() {
740            continue;
741        }
742        let provenance = provenance_for(&group.records, context_uri, version);
743
744        match form {
745            PreferenceForm::Paired => {
746                if candidates.len() < 2 {
747                    continue;
748                }
749                let mut best = candidates[0];
750                let mut worst = candidates[0];
751                for candidate in &candidates {
752                    if preference_score(candidate) > preference_score(best) {
753                        best = candidate;
754                    }
755                    if preference_score(candidate) < preference_score(worst) {
756                        worst = candidate;
757                    }
758                }
759                if best.id == worst.id {
760                    continue; // no signal to separate chosen from rejected
761                }
762                let example = PreferenceExample::Paired {
763                    prompt,
764                    chosen: vec![message_of(best)],
765                    rejected: vec![message_of(worst)],
766                    provenance,
767                };
768                write_line(writer, &example)?;
769                written += 1;
770            }
771            PreferenceForm::Unpaired => {
772                for candidate in &candidates {
773                    let Some(label) = unpaired_label(candidate, min_reward) else {
774                        continue;
775                    };
776                    let example = PreferenceExample::Unpaired {
777                        prompt: prompt.clone(),
778                        completion: vec![message_of(candidate)],
779                        label,
780                        provenance: provenance.clone(),
781                    };
782                    write_line(writer, &example)?;
783                    written += 1;
784                }
785            }
786            PreferenceForm::Ranked => {
787                let mut ordered: Vec<&ContextRecord> = candidates.clone();
788                ordered.sort_by(|a, b| {
789                    let rank_a = record_rank(a);
790                    let rank_b = record_rank(b);
791                    match (rank_a, rank_b) {
792                        (Some(ra), Some(rb)) => ra.cmp(&rb),
793                        _ => preference_score(b).total_cmp(&preference_score(a)),
794                    }
795                });
796                let candidates: Vec<RankedCandidate> = ordered
797                    .iter()
798                    .enumerate()
799                    .map(|(index, candidate)| RankedCandidate {
800                        messages: vec![message_of(candidate)],
801                        rank: record_rank(candidate).unwrap_or((index + 1) as i64),
802                        reward: record_reward(candidate),
803                    })
804                    .collect();
805                let example = PreferenceExample::Ranked {
806                    prompt,
807                    candidates,
808                    provenance,
809                };
810                write_line(writer, &example)?;
811                written += 1;
812            }
813        }
814    }
815    Ok(written)
816}
817
818fn write_rollout(
819    groups: &[Group],
820    writer: &mut impl Write,
821    context_uri: &str,
822    version: u64,
823) -> LanceResult<usize> {
824    let mut written = 0;
825    for group in groups {
826        let (prompt, candidates) = split_prompt_candidates(&group.records);
827        if candidates.is_empty() {
828            continue;
829        }
830        let responses: Vec<RolloutResponse> = candidates
831            .iter()
832            .map(|candidate| RolloutResponse {
833                messages: vec![message_of(candidate)],
834                reward: record_reward(candidate),
835                reward_source: record_reward_source(candidate),
836            })
837            .collect();
838        let group_id = candidates
839            .iter()
840            .find_map(|candidate| record_group_id(candidate));
841        let example = RolloutExample {
842            prompt,
843            responses,
844            group_id,
845            provenance: provenance_for(&group.records, context_uri, version),
846        };
847        write_line(writer, &example)?;
848        written += 1;
849    }
850    Ok(written)
851}
852
853fn summarize_groups(
854    groups: &[Group],
855) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>, Vec<String>) {
856    let mut start: Option<DateTime<Utc>> = None;
857    let mut end: Option<DateTime<Utc>> = None;
858    let mut ids = Vec::new();
859    for group in groups {
860        for record in &group.records {
861            start = Some(start.map_or(record.created_at, |s| s.min(record.created_at)));
862            end = Some(end.map_or(record.created_at, |e| e.max(record.created_at)));
863            ids.push(record.id.clone());
864        }
865    }
866    (start, end, ids)
867}
868
/// Reproducible 64-bit FNV-1a digest of `seed` followed by `key`.
///
/// The seed is hashed as its eight little-endian bytes, then the UTF-8 bytes
/// of `key`. Only wrapping integer arithmetic is involved, so the result does
/// not depend on the platform; it is used for split assignment.
fn stable_hash(seed: u64, key: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let seed_bytes = seed.to_le_bytes();
    seed_bytes
        .iter()
        .chain(key.as_bytes())
        .fold(FNV_OFFSET_BASIS, |digest, &byte| {
            // FNV-1a step: xor the byte in first, then multiply by the prime.
            (digest ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
885
886/// Map a group key to a stable fraction in `0.0..1.0`.
887fn split_fraction(seed: u64, key: &str) -> f64 {
888    // FNV-1a alone has biased high bits; run the MurmurHash3 fmix64 finalizer
889    // for good avalanche before taking the top 53 bits.
890    let mut hash = stable_hash(seed, key);
891    hash ^= hash >> 33;
892    hash = hash.wrapping_mul(0xff51_afd7_ed55_8ccd);
893    hash ^= hash >> 33;
894    hash = hash.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
895    hash ^= hash >> 33;
896    (hash >> 11) as f64 / (1u64 << 53) as f64
897}
898
/// Derive the `train` / `eval` output paths by inserting the side before the
/// final extension of the file name (e.g. `cut.jsonl` -> `cut.train.jsonl`).
///
/// Returns `(train_path, eval_path)`.
///
/// Only the last path component is searched for an extension, using the
/// platform's separators (`/` everywhere, plus `\` on Windows), so a dot in a
/// directory name is never mistaken for one. Following the `std::path`
/// convention, a file name that merely starts with a dot (`.hidden`) has no
/// extension. When there is no extension the side is appended instead
/// (`out` -> `out.train`).
fn split_paths(output_path: &str) -> (String, String) {
    // Separators are single-byte ASCII, so `i + 1` is always a char boundary.
    let name_start = output_path
        .rfind(std::path::is_separator)
        .map_or(0, |i| i + 1);
    let file_name = &output_path[name_start..];

    match file_name.rfind('.') {
        // A dot at offset 0 marks a dotfile, not an extension.
        Some(rel_dot) if rel_dot > 0 => {
            let (stem, ext) = output_path.split_at(name_start + rel_dot);
            (format!("{stem}.train{ext}"), format!("{stem}.eval{ext}"))
        }
        _ => (
            format!("{output_path}.train"),
            format!("{output_path}.eval"),
        ),
    }
}
915
916/// Group `records`, write the task-shaped JSONL to `output_path`, write the
917/// sibling manifest, and return it.
918#[allow(clippy::too_many_arguments)]
919fn emit_export(
920    records: Vec<ContextRecord>,
921    config: &ExportConfig,
922    context_uri: &str,
923    version: u64,
924    mut counts: ExportCounts,
925    output_path: &str,
926    split: Option<SplitManifest>,
927) -> LanceResult<ExportManifest> {
928    let groups = group_records(records, &config.group_by);
929    let (created_at_start, created_at_end, source_record_ids) = summarize_groups(&groups);
930
931    let file = std::fs::File::create(output_path)?;
932    let mut writer = BufWriter::new(file);
933    let examples = match config.task {
934        ExportTask::Sft => write_sft(&groups, &mut writer, context_uri, version)?,
935        ExportTask::Preference => write_preference(
936            &groups,
937            &mut writer,
938            config.preference_form,
939            config.min_reward,
940            context_uri,
941            version,
942        )?,
943        ExportTask::Rollout => write_rollout(&groups, &mut writer, context_uri, version)?,
944    };
945    writer.flush()?;
946    counts.examples = examples;
947
948    let manifest = ExportManifest {
949        context_uri: context_uri.to_string(),
950        version,
951        task: config.task.as_str().to_string(),
952        group_by: config.group_by.label(),
953        schema_version: EXPORT_SCHEMA_VERSION.to_string(),
954        preference_form: matches!(config.task, ExportTask::Preference).then(|| {
955            match config.preference_form {
956                PreferenceForm::Paired => "paired",
957                PreferenceForm::Unpaired => "unpaired",
958                PreferenceForm::Ranked => "ranked",
959            }
960            .to_string()
961        }),
962        filters: config.filters_summary.clone(),
963        dedup_threshold: config.dedup_threshold,
964        decontaminate_threshold: config.decontaminate_threshold,
965        min_reward: config.min_reward,
966        split,
967        created_at_start,
968        created_at_end,
969        source_record_ids,
970        counts,
971    };
972
973    let manifest_json =
974        serde_json::to_string_pretty(&manifest).map_err(|err| LanceError::io(err.to_string()))?;
975    std::fs::write(format!("{output_path}.manifest.json"), manifest_json)?;
976
977    if config.emit_stats {
978        let stats = compute_stats(&groups, &counts, examples, config);
979        let stats_json =
980            serde_json::to_string_pretty(&stats).map_err(|err| LanceError::io(err.to_string()))?;
981        std::fs::write(format!("{output_path}.stats.json"), stats_json)?;
982    }
983
984    Ok(manifest)
985}
986
987/// Compute the dataset statistics report for one export output.
988fn compute_stats(
989    groups: &[Group],
990    counts: &ExportCounts,
991    examples: usize,
992    config: &ExportConfig,
993) -> ExportStats {
994    let mut by_role: BTreeMap<String, usize> = BTreeMap::new();
995    let mut by_source: BTreeMap<String, usize> = BTreeMap::new();
996    let mut by_tenant: BTreeMap<String, usize> = BTreeMap::new();
997    let mut token_values: Vec<f64> = Vec::new();
998    let mut used_tokens_used = false;
999    let mut used_fallback = false;
1000    let mut reward_values: Vec<f64> = Vec::new();
1001    let mut reward_sources: BTreeMap<String, usize> = BTreeMap::new();
1002    let mut records_per_group: Vec<f64> = Vec::new();
1003
1004    for group in groups {
1005        records_per_group.push(group.records.len() as f64);
1006        for record in &group.records {
1007            *by_role.entry(record.role.clone()).or_insert(0) += 1;
1008            *by_source
1009                .entry(
1010                    record
1011                        .source
1012                        .clone()
1013                        .unwrap_or_else(|| "__none__".to_string()),
1014                )
1015                .or_insert(0) += 1;
1016            *by_tenant
1017                .entry(
1018                    record
1019                        .tenant
1020                        .clone()
1021                        .unwrap_or_else(|| "__none__".to_string()),
1022                )
1023                .or_insert(0) += 1;
1024
1025            match record.state_metadata.as_ref().and_then(|m| m.tokens_used) {
1026                Some(tokens) if tokens >= 0 => {
1027                    token_values.push(f64::from(tokens));
1028                    used_tokens_used = true;
1029                }
1030                _ => {
1031                    // Fallback proxy: whitespace-delimited word count of content.
1032                    let proxy = record
1033                        .text_payload
1034                        .as_deref()
1035                        .map_or(0, |text| text.split_whitespace().count());
1036                    token_values.push(proxy as f64);
1037                    used_fallback = true;
1038                }
1039            }
1040
1041            if let Some(reward) = record_reward(record) {
1042                reward_values.push(reward);
1043            }
1044            if let Some(source) = record_reward_source(record) {
1045                *reward_sources.entry(source).or_insert(0) += 1;
1046            }
1047        }
1048    }
1049
1050    let tokens = Distribution::from_sorted(token_values).map(|distribution| TokenStats {
1051        distribution,
1052        source: match (used_tokens_used, used_fallback) {
1053            (true, true) => "mixed",
1054            (true, false) => "tokens_used",
1055            _ => "length_proxy",
1056        }
1057        .to_string(),
1058    });
1059
1060    let excluded = ExcludedCounts {
1061        lifecycle: counts.input_records.saturating_sub(counts.after_lifecycle),
1062        reward_threshold: counts
1063            .after_lifecycle
1064            .saturating_sub(counts.after_reward_filter),
1065        dedup: counts
1066            .after_reward_filter
1067            .saturating_sub(counts.after_dedup),
1068        decontaminate: counts
1069            .after_dedup
1070            .saturating_sub(counts.after_decontaminate),
1071    };
1072
1073    ExportStats {
1074        task: config.task.as_str().to_string(),
1075        examples,
1076        num_groups: groups.len(),
1077        by_role,
1078        by_source,
1079        by_tenant,
1080        records_per_group: Distribution::from_sorted(records_per_group).unwrap_or_default(),
1081        tokens,
1082        excluded,
1083        preference_form: matches!(config.task, ExportTask::Preference).then(|| {
1084            match config.preference_form {
1085                PreferenceForm::Paired => "paired",
1086                PreferenceForm::Unpaired => "unpaired",
1087                PreferenceForm::Ranked => "ranked",
1088            }
1089            .to_string()
1090        }),
1091        reward: Distribution::from_sorted(reward_values),
1092        reward_sources,
1093    }
1094}
1095
1096impl ContextStore {
1097    /// Curate stored records and export them as task-shaped JSONL plus a
1098    /// sibling `<output_path>.manifest.json`, returning the manifest.
1099    ///
1100    /// If `config.version` is set the export is pinned to that dataset version
1101    /// (time-travel) and the store is restored to its current version
1102    /// afterward. Output is written line-by-line, so the full JSONL is never
1103    /// held in memory at once.
1104    pub async fn export_training(
1105        &mut self,
1106        config: &ExportConfig,
1107        output_path: &str,
1108    ) -> LanceResult<ExportManifest> {
1109        let restore = match config.version {
1110            Some(target) => {
1111                let original = self.version();
1112                self.checkout(target).await?;
1113                Some(original)
1114            }
1115            None => None,
1116        };
1117
1118        let result = self.export_inner(config, output_path).await;
1119
1120        if let Some(original) = restore {
1121            self.checkout(original).await?;
1122        }
1123        result
1124    }
1125
1126    async fn export_inner(
1127        &self,
1128        config: &ExportConfig,
1129        output_path: &str,
1130    ) -> LanceResult<ExportManifest> {
1131        let context_uri = self.uri().to_string();
1132        let version = self.version();
1133
1134        let records = self
1135            .list_filtered_with_options(
1136                None,
1137                None,
1138                config.filters.as_ref(),
1139                config.lifecycle.clone(),
1140            )
1141            .await?;
1142        let (curated, counts) = curate(records, config);
1143
1144        let Some(split) = &config.split else {
1145            return emit_export(
1146                curated,
1147                config,
1148                &context_uri,
1149                version,
1150                counts,
1151                output_path,
1152                None,
1153            );
1154        };
1155
1156        // Group-disjoint, reproducible partition: a record's side is decided by
1157        // a stable hash of its `split.by` key, so no group spans both sides.
1158        let (train_path, eval_path) = split_paths(output_path);
1159        let mut train_records = Vec::new();
1160        let mut eval_records = Vec::new();
1161        for record in curated {
1162            let key = split.by.key(&record);
1163            if split_fraction(split.seed, &key) < split.eval_fraction {
1164                eval_records.push(record);
1165            } else {
1166                train_records.push(record);
1167            }
1168        }
1169
1170        let train_manifest = emit_export(
1171            train_records,
1172            config,
1173            &context_uri,
1174            version,
1175            counts,
1176            &train_path,
1177            Some(SplitManifest {
1178                side: "train".to_string(),
1179                eval_fraction: split.eval_fraction,
1180                by: split.by.label(),
1181                seed: split.seed,
1182                complement_path: eval_path.clone(),
1183            }),
1184        )?;
1185        emit_export(
1186            eval_records,
1187            config,
1188            &context_uri,
1189            version,
1190            counts,
1191            &eval_path,
1192            Some(SplitManifest {
1193                side: "eval".to_string(),
1194                eval_fraction: split.eval_fraction,
1195                by: split.by.label(),
1196                seed: split.seed,
1197                complement_path: train_path.clone(),
1198            }),
1199        )?;
1200        Ok(train_manifest)
1201    }
1202}
1203
1204#[cfg(test)]
1205mod tests {
1206    use super::*;
1207    use crate::record::LIFECYCLE_ACTIVE;
1208    use crate::store::{ContextStore, ContextStoreOptions};
1209    use chrono::TimeZone;
1210    use serde_json::json;
1211    use tempfile::TempDir;
1212
1213    const DIM: i32 = 4;
1214
1215    async fn open_store(dir: &TempDir) -> ContextStore {
1216        let uri = dir.path().join("ctx.lance").to_string_lossy().to_string();
1217        ContextStore::open_with_options(
1218            &uri,
1219            ContextStoreOptions {
1220                embedding_dim: Some(DIM),
1221                ..Default::default()
1222            },
1223        )
1224        .await
1225        .unwrap()
1226    }
1227
1228    fn rec(id: &str, role: &str, text: &str, secs: i64) -> ContextRecord {
1229        ContextRecord {
1230            id: id.to_string(),
1231            external_id: None,
1232            run_id: "run".to_string(),
1233            bot_id: None,
1234            session_id: Some("s1".to_string()),
1235            tenant: None,
1236            source: None,
1237            created_at: Utc.timestamp_opt(1_700_000_000 + secs, 0).unwrap(),
1238            role: role.to_string(),
1239            state_metadata: None,
1240            metadata: None,
1241            relationships: Vec::new(),
1242            expires_at: None,
1243            retention_policy: None,
1244            lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
1245            retired_at: None,
1246            retired_reason: None,
1247            supersedes_id: None,
1248            superseded_by_id: None,
1249            content_type: "text/plain".to_string(),
1250            text_payload: Some(text.to_string()),
1251            binary_payload: None,
1252            embedding: None,
1253        }
1254    }
1255
1256    fn emb(lead: &[f32]) -> Vec<f32> {
1257        let mut v = vec![0.0f32; DIM as usize];
1258        for (i, x) in lead.iter().enumerate() {
1259            v[i] = *x;
1260        }
1261        v
1262    }
1263
1264    fn read_lines(path: &str) -> Vec<Value> {
1265        std::fs::read_to_string(path)
1266            .unwrap()
1267            .lines()
1268            .filter(|l| !l.trim().is_empty())
1269            .map(|l| serde_json::from_str(l).unwrap())
1270            .collect()
1271    }
1272
1273    fn read_manifest(path: &str) -> ExportManifest {
1274        let raw = std::fs::read_to_string(format!("{path}.manifest.json")).unwrap();
1275        serde_json::from_str(&raw).unwrap()
1276    }
1277
1278    fn out_path(dir: &TempDir) -> String {
1279        dir.path().join("out.jsonl").to_string_lossy().to_string()
1280    }
1281
1282    #[test]
1283    fn sft_groups_session_into_ordered_conversation() {
1284        let dir = TempDir::new().unwrap();
1285        let runtime = tokio::runtime::Runtime::new().unwrap();
1286        runtime.block_on(async {
1287            let mut store = open_store(&dir).await;
1288            // Added out of order; export must order by (created_at, id).
1289            store
1290                .add(&[
1291                    rec("r2", "assistant", "hi there", 2),
1292                    rec("r1", "user", "hello", 1),
1293                    rec("r3", "user", "bye", 3),
1294                ])
1295                .await
1296                .unwrap();
1297
1298            let out = out_path(&dir);
1299            let manifest = store
1300                .export_training(
1301                    &ExportConfig {
1302                        task: ExportTask::Sft,
1303                        group_by: GroupBy::SessionId,
1304                        ..Default::default()
1305                    },
1306                    &out,
1307                )
1308                .await
1309                .unwrap();
1310
1311            let lines = read_lines(&out);
1312            assert_eq!(lines.len(), 1);
1313            let messages = lines[0]["messages"].as_array().unwrap();
1314            let contents: Vec<&str> = messages
1315                .iter()
1316                .map(|m| m["content"].as_str().unwrap())
1317                .collect();
1318            assert_eq!(contents, ["hello", "hi there", "bye"]);
1319            assert_eq!(manifest.task, "sft");
1320            assert_eq!(manifest.counts.examples, 1);
1321            assert_eq!(lines[0]["provenance"]["version"], json!(manifest.version));
1322
1323            // The sibling manifest file is written and matches the return value.
1324            let manifest_file = read_manifest(&out);
1325            assert_eq!(manifest_file.task, "sft");
1326            assert_eq!(manifest_file.counts.examples, 1);
1327            assert_eq!(manifest_file.schema_version, EXPORT_SCHEMA_VERSION);
1328            assert_eq!(manifest_file.source_record_ids.len(), 3);
1329        });
1330    }
1331
1332    #[test]
1333    fn sft_rejection_sampling_filters_low_reward() {
1334        let dir = TempDir::new().unwrap();
1335        let runtime = tokio::runtime::Runtime::new().unwrap();
1336        runtime.block_on(async {
1337            let mut store = open_store(&dir).await;
1338            let mut good = rec("g", "assistant", "good", 1);
1339            good.session_id = Some("a".to_string());
1340            good.metadata = Some(json!({"reward": 0.9}));
1341            let mut bad = rec("b", "assistant", "bad", 2);
1342            bad.session_id = Some("b".to_string());
1343            bad.metadata = Some(json!({"reward": 0.1}));
1344            store.add(&[good, bad]).await.unwrap();
1345
1346            let out = out_path(&dir);
1347            let manifest = store
1348                .export_training(
1349                    &ExportConfig {
1350                        task: ExportTask::Sft,
1351                        group_by: GroupBy::SessionId,
1352                        min_reward: Some(0.5),
1353                        ..Default::default()
1354                    },
1355                    &out,
1356                )
1357                .await
1358                .unwrap();
1359
1360            let lines = read_lines(&out);
1361            assert_eq!(lines.len(), 1, "only the high-reward record survives");
1362            assert_eq!(lines[0]["messages"][0]["content"], "good");
1363            assert_eq!(manifest.counts.after_reward_filter, 1);
1364            assert_eq!(manifest.min_reward, Some(0.5));
1365        });
1366    }
1367
1368    #[test]
1369    fn preference_paired_uses_reward_for_chosen_rejected() {
1370        let dir = TempDir::new().unwrap();
1371        let runtime = tokio::runtime::Runtime::new().unwrap();
1372        runtime.block_on(async {
1373            let mut store = open_store(&dir).await;
1374            let prompt = rec("p", "user", "question", 1);
1375            let mut hi = rec("hi", "assistant", "great answer", 2);
1376            hi.metadata = Some(json!({"reward": 0.9}));
1377            let mut lo = rec("lo", "assistant", "poor answer", 3);
1378            lo.metadata = Some(json!({"reward": 0.2}));
1379            store.add(&[prompt, hi, lo]).await.unwrap();
1380
1381            let out = out_path(&dir);
1382            store
1383                .export_training(
1384                    &ExportConfig {
1385                        task: ExportTask::Preference,
1386                        preference_form: PreferenceForm::Paired,
1387                        group_by: GroupBy::SessionId,
1388                        ..Default::default()
1389                    },
1390                    &out,
1391                )
1392                .await
1393                .unwrap();
1394
1395            let lines = read_lines(&out);
1396            assert_eq!(lines.len(), 1);
1397            assert_eq!(lines[0]["form"], "paired");
1398            assert_eq!(lines[0]["prompt"][0]["content"], "question");
1399            assert_eq!(lines[0]["chosen"][0]["content"], "great answer");
1400            assert_eq!(lines[0]["rejected"][0]["content"], "poor answer");
1401        });
1402    }
1403
1404    #[test]
1405    fn preference_unpaired_uses_kto_labels() {
1406        let dir = TempDir::new().unwrap();
1407        let runtime = tokio::runtime::Runtime::new().unwrap();
1408        runtime.block_on(async {
1409            let mut store = open_store(&dir).await;
1410            let prompt = rec("p", "user", "q", 1);
1411            let mut a = rec("a", "assistant", "yes", 2);
1412            a.metadata = Some(json!({"label": "chosen"}));
1413            let mut b = rec("b", "assistant", "no", 3);
1414            b.metadata = Some(json!({"label": "rejected"}));
1415            store.add(&[prompt, a, b]).await.unwrap();
1416
1417            let out = out_path(&dir);
1418            store
1419                .export_training(
1420                    &ExportConfig {
1421                        task: ExportTask::Preference,
1422                        preference_form: PreferenceForm::Unpaired,
1423                        group_by: GroupBy::SessionId,
1424                        ..Default::default()
1425                    },
1426                    &out,
1427                )
1428                .await
1429                .unwrap();
1430
1431            let lines = read_lines(&out);
1432            assert_eq!(lines.len(), 2);
1433            assert!(lines.iter().all(|l| l["form"] == "unpaired"));
1434            let chosen = lines
1435                .iter()
1436                .find(|l| l["completion"][0]["content"] == "yes")
1437                .unwrap();
1438            assert_eq!(chosen["label"], json!(true));
1439            let rejected = lines
1440                .iter()
1441                .find(|l| l["completion"][0]["content"] == "no")
1442                .unwrap();
1443            assert_eq!(rejected["label"], json!(false));
1444        });
1445    }
1446
1447    #[test]
1448    fn preference_ranked_orders_by_rank() {
1449        let dir = TempDir::new().unwrap();
1450        let runtime = tokio::runtime::Runtime::new().unwrap();
1451        runtime.block_on(async {
1452            let mut store = open_store(&dir).await;
1453            let prompt = rec("p", "user", "q", 1);
1454            let mut second = rec("c2", "assistant", "second", 2);
1455            second.metadata = Some(json!({"rank": 2}));
1456            let mut first = rec("c1", "assistant", "first", 3);
1457            first.metadata = Some(json!({"rank": 1}));
1458            store.add(&[prompt, second, first]).await.unwrap();
1459
1460            let out = out_path(&dir);
1461            store
1462                .export_training(
1463                    &ExportConfig {
1464                        task: ExportTask::Preference,
1465                        preference_form: PreferenceForm::Ranked,
1466                        group_by: GroupBy::SessionId,
1467                        ..Default::default()
1468                    },
1469                    &out,
1470                )
1471                .await
1472                .unwrap();
1473
1474            let lines = read_lines(&out);
1475            assert_eq!(lines.len(), 1);
1476            assert_eq!(lines[0]["form"], "ranked");
1477            let cands = lines[0]["candidates"].as_array().unwrap();
1478            assert_eq!(cands[0]["messages"][0]["content"], "first");
1479            assert_eq!(cands[0]["rank"], json!(1));
1480            assert_eq!(cands[1]["messages"][0]["content"], "second");
1481        });
1482    }
1483
1484    #[test]
1485    fn rollout_groups_responses_with_rewards() {
1486        let dir = TempDir::new().unwrap();
1487        let runtime = tokio::runtime::Runtime::new().unwrap();
1488        runtime.block_on(async {
1489            let mut store = open_store(&dir).await;
1490            let prompt = rec("p", "user", "solve x", 1);
1491            let mut r1 = rec("r1", "assistant", "ans1", 2);
1492            r1.metadata =
1493                Some(json!({"reward": 1.0, "reward_source": "verifier", "group_id": "g1"}));
1494            let mut r2 = rec("r2", "assistant", "ans2", 3);
1495            r2.metadata =
1496                Some(json!({"reward": 0.0, "reward_source": "verifier", "group_id": "g1"}));
1497            store.add(&[prompt, r1, r2]).await.unwrap();
1498
1499            let out = out_path(&dir);
1500            store
1501                .export_training(
1502                    &ExportConfig {
1503                        task: ExportTask::Rollout,
1504                        group_by: GroupBy::SessionId,
1505                        ..Default::default()
1506                    },
1507                    &out,
1508                )
1509                .await
1510                .unwrap();
1511
1512            let lines = read_lines(&out);
1513            assert_eq!(lines.len(), 1);
1514            assert_eq!(lines[0]["group_id"], "g1");
1515            assert_eq!(lines[0]["prompt"][0]["content"], "solve x");
1516            let responses = lines[0]["responses"].as_array().unwrap();
1517            assert_eq!(responses.len(), 2);
1518            assert_eq!(responses[0]["reward"], json!(1.0));
1519            assert_eq!(responses[0]["reward_source"], "verifier");
1520        });
1521    }
1522
1523    #[test]
1524    fn dedup_collapses_near_duplicates() {
1525        let dir = TempDir::new().unwrap();
1526        let runtime = tokio::runtime::Runtime::new().unwrap();
1527        runtime.block_on(async {
1528            let mut store = open_store(&dir).await;
1529            let mut a = rec("a", "user", "dup one", 1);
1530            a.session_id = Some("a".to_string());
1531            a.embedding = Some(emb(&[1.0, 0.0]));
1532            let mut b = rec("b", "user", "dup two", 2);
1533            b.session_id = Some("b".to_string());
1534            b.embedding = Some(emb(&[1.0, 0.0])); // identical direction -> distance 0
1535            store.add(&[a, b]).await.unwrap();
1536
1537            let out = out_path(&dir);
1538            let manifest = store
1539                .export_training(
1540                    &ExportConfig {
1541                        task: ExportTask::Sft,
1542                        group_by: GroupBy::SessionId,
1543                        dedup_threshold: Some(0.01),
1544                        ..Default::default()
1545                    },
1546                    &out,
1547                )
1548                .await
1549                .unwrap();
1550
1551            assert_eq!(
1552                manifest.counts.after_dedup, 1,
1553                "one near-duplicate collapsed"
1554            );
1555            assert_eq!(read_lines(&out).len(), 1);
1556        });
1557    }
1558
1559    #[test]
1560    fn decontaminate_drops_holdout_matches() {
1561        let dir = TempDir::new().unwrap();
1562        let runtime = tokio::runtime::Runtime::new().unwrap();
1563        runtime.block_on(async {
1564            let mut store = open_store(&dir).await;
1565            let mut keep = rec("k", "user", "keep", 1);
1566            keep.session_id = Some("k".to_string());
1567            keep.embedding = Some(emb(&[0.0, 1.0]));
1568            let mut leak = rec("l", "user", "leak", 2);
1569            leak.session_id = Some("l".to_string());
1570            leak.embedding = Some(emb(&[1.0, 0.0]));
1571            store.add(&[keep, leak]).await.unwrap();
1572
1573            let out = out_path(&dir);
1574            let manifest = store
1575                .export_training(
1576                    &ExportConfig {
1577                        task: ExportTask::Sft,
1578                        group_by: GroupBy::SessionId,
1579                        decontaminate_against: vec![emb(&[1.0, 0.0])], // matches "leak"
1580                        decontaminate_threshold: Some(0.01),
1581                        ..Default::default()
1582                    },
1583                    &out,
1584                )
1585                .await
1586                .unwrap();
1587
1588            assert_eq!(manifest.counts.after_decontaminate, 1);
1589            let lines = read_lines(&out);
1590            assert_eq!(lines.len(), 1);
1591            assert_eq!(lines[0]["messages"][0]["content"], "keep");
1592        });
1593    }
1594
1595    #[test]
1596    fn curation_drops_contradicted_records() {
1597        let dir = TempDir::new().unwrap();
1598        let runtime = tokio::runtime::Runtime::new().unwrap();
1599        runtime.block_on(async {
1600            let mut store = open_store(&dir).await;
1601            let good = rec("g", "user", "valid", 1);
1602            let mut bad = rec("c", "user", "contradicted", 2);
1603            bad.session_id = Some("other".to_string());
1604            bad.lifecycle_status = LIFECYCLE_CONTRADICTED.to_string();
1605            store.add(&[good, bad]).await.unwrap();
1606
1607            let out = out_path(&dir);
1608            let manifest = store
1609                .export_training(&ExportConfig::default(), &out)
1610                .await
1611                .unwrap();
1612
1613            assert_eq!(manifest.counts.after_lifecycle, 1);
1614            let lines = read_lines(&out);
1615            assert!(lines
1616                .iter()
1617                .all(|l| l["messages"][0]["content"] != "contradicted"));
1618        });
1619    }
1620
1621    #[test]
1622    fn version_pinning_exports_old_state_and_restores() {
1623        let dir = TempDir::new().unwrap();
1624        let runtime = tokio::runtime::Runtime::new().unwrap();
1625        runtime.block_on(async {
1626            let mut store = open_store(&dir).await;
1627            store.add(&[rec("r1", "user", "first", 1)]).await.unwrap();
1628            store.compact(None).await.unwrap();
1629            let pinned = store.version();
1630
1631            store.add(&[rec("r2", "user", "second", 2)]).await.unwrap();
1632            store.compact(None).await.unwrap();
1633            let latest = store.version();
1634
1635            let out = out_path(&dir);
1636            let manifest = store
1637                .export_training(
1638                    &ExportConfig {
1639                        task: ExportTask::Sft,
1640                        group_by: GroupBy::None,
1641                        version: Some(pinned),
1642                        ..Default::default()
1643                    },
1644                    &out,
1645                )
1646                .await
1647                .unwrap();
1648
1649            assert_eq!(manifest.version, pinned);
1650            assert_eq!(
1651                store.version(),
1652                latest,
1653                "store restored after pinned export"
1654            );
1655        });
1656    }
1657
/// Exporting twice from the same store with the same config must write
/// byte-identical JSONL files.
#[test]
fn export_is_reproducible() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;
        let turns = [
            rec("r1", "user", "a", 1),
            rec("r2", "assistant", "b", 2),
            rec("r3", "user", "c", 3),
        ];
        store.add(&turns).await.unwrap();

        let config = ExportConfig {
            task: ExportTask::Sft,
            group_by: GroupBy::SessionId,
            ..Default::default()
        };
        let first = out_path(&dir);
        let second = dir.path().join("out2.jsonl").to_string_lossy().to_string();
        store.export_training(&config, &first).await.unwrap();
        store.export_training(&config, &second).await.unwrap();

        // Compare the raw file contents rather than parsed JSON so that any
        // ordering or formatting drift between runs is caught.
        let first_output = std::fs::read_to_string(&first).unwrap();
        let second_output = std::fs::read_to_string(&second).unwrap();
        assert_eq!(
            first_output, second_output,
            "same version + config produces identical output"
        );
    });
}
1690
/// Exports with `GroupBy::ExternalIdPrefix("#")` and expects the two
/// `doc-7#…` records in one example and the `doc-8#…` record in another.
#[test]
fn external_id_prefix_grouping() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;

        // (id, role, text, rec's 4th argument, external_id)
        let records = [
            ("a", "user", "doc7 turn1", 1, "doc-7#chunk-1"),
            ("b", "assistant", "doc7 turn2", 2, "doc-7#chunk-2"),
            ("c", "user", "doc8 turn1", 3, "doc-8#chunk-1"),
        ]
        .map(|(id, role, text, n, external_id)| {
            let mut record = rec(id, role, text, n);
            record.external_id = Some(external_id.to_string());
            record
        });
        store.add(&records).await.unwrap();

        let config = ExportConfig {
            task: ExportTask::Sft,
            group_by: GroupBy::ExternalIdPrefix("#".to_string()),
            ..Default::default()
        };
        let out = out_path(&dir);
        store.export_training(&config, &out).await.unwrap();

        let lines = read_lines(&out);
        // doc-7 (2 turns) + doc-8 (1 turn) => 2 examples.
        assert_eq!(lines.len(), 2);
        let first_messages = lines[0]["messages"].as_array().unwrap();
        assert_eq!(first_messages.len(), 2);
    });
}
1724
1725    fn session_ids(path: &str) -> std::collections::HashSet<String> {
1726        read_lines(path)
1727            .iter()
1728            .filter_map(|line| {
1729                line["provenance"]["session_id"]
1730                    .as_str()
1731                    .map(str::to_string)
1732            })
1733            .collect()
1734    }
1735
/// Runs the same seeded train/eval split twice and checks that both runs
/// write identical files and that no session id lands on both sides.
#[test]
fn split_is_deterministic_and_group_disjoint() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;

        // Ten sessions, each a user turn followed by an assistant turn.
        let records: Vec<_> = (0..10)
            .flat_map(|s| {
                let mut user = rec(&format!("u{s}"), "user", "q", s * 2);
                user.session_id = Some(format!("s{s}"));
                let mut asst = rec(&format!("a{s}"), "assistant", "r", s * 2 + 1);
                asst.session_id = Some(format!("s{s}"));
                [user, asst]
            })
            .collect();
        store.add(&records).await.unwrap();

        let split = SplitConfig {
            eval_fraction: 0.5,
            by: GroupBy::SessionId,
            seed: 42,
        };
        let config = ExportConfig {
            task: ExportTask::Sft,
            group_by: GroupBy::SessionId,
            split: Some(split),
            ..Default::default()
        };

        let base1 = dir.path().join("a.jsonl").to_string_lossy().to_string();
        let base2 = dir.path().join("b.jsonl").to_string_lossy().to_string();
        store.export_training(&config, &base1).await.unwrap();
        store.export_training(&config, &base2).await.unwrap();

        let (train1, eval1) = split_paths(&base1);
        let (train2, eval2) = split_paths(&base2);

        // Determinism: identical partition across runs with the same seed.
        let train_first = std::fs::read_to_string(&train1).unwrap();
        let train_second = std::fs::read_to_string(&train2).unwrap();
        assert_eq!(train_first, train_second);
        let eval_first = std::fs::read_to_string(&eval1).unwrap();
        let eval_second = std::fs::read_to_string(&eval2).unwrap();
        assert_eq!(eval_first, eval_second);

        // Group-disjoint: no session appears on both sides.
        let train_sessions = session_ids(&train1);
        let eval_sessions = session_ids(&eval1);
        assert!(!train_sessions.is_empty() && !eval_sessions.is_empty());
        assert!(train_sessions.is_disjoint(&eval_sessions));
        assert_eq!(train_sessions.len() + eval_sessions.len(), 10);
    });
}
1790
/// With 200 single-record sessions and `eval_fraction: 0.25`, the eval side
/// must hold a share of the examples within 0.1 of the requested fraction.
#[test]
fn split_fraction_is_approximately_respected() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;

        // One user record per session so every session yields one example.
        let records: Vec<_> = (0..200)
            .map(|s| {
                let mut record = rec(&format!("r{s}"), "user", "q", s);
                record.session_id = Some(format!("s{s}"));
                record
            })
            .collect();
        store.add(&records).await.unwrap();

        let split = SplitConfig {
            eval_fraction: 0.25,
            by: GroupBy::SessionId,
            seed: 7,
        };
        let config = ExportConfig {
            task: ExportTask::Sft,
            group_by: GroupBy::SessionId,
            split: Some(split),
            ..Default::default()
        };
        let base = dir.path().join("c.jsonl").to_string_lossy().to_string();
        store.export_training(&config, &base).await.unwrap();

        let (train, eval) = split_paths(&base);
        let eval_count = read_lines(&eval).len();
        let train_count = read_lines(&train).len();
        // Nothing may be lost or duplicated by the split.
        assert_eq!(train_count + eval_count, 200);

        let fraction = eval_count as f64 / 200.0;
        let deviation = (fraction - 0.25).abs();
        assert!(
            deviation < 0.1,
            "eval fraction {fraction} too far from 0.25"
        );
    });
}
1829
/// A split export writes both a train and an eval file, and each side's
/// manifest records the split parameters; the train manifest also points at
/// the eval file as its complement.
#[test]
fn split_manifests_record_params_and_complement() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;

        let mut first = rec("a", "user", "x", 1);
        first.session_id = Some("s1".to_string());
        let mut second = rec("b", "user", "y", 2);
        second.session_id = Some("s2".to_string());
        store.add(&[first, second]).await.unwrap();

        let split_config = SplitConfig {
            eval_fraction: 0.5,
            by: GroupBy::SessionId,
            seed: 99,
        };
        let config = ExportConfig {
            task: ExportTask::Sft,
            group_by: GroupBy::SessionId,
            split: Some(split_config),
            ..Default::default()
        };
        let base = dir.path().join("d.jsonl").to_string_lossy().to_string();
        store.export_training(&config, &base).await.unwrap();

        // Both output files must exist on disk.
        let (train, eval) = split_paths(&base);
        assert!(std::fs::metadata(&train).is_ok());
        assert!(std::fs::metadata(&eval).is_ok());

        // The train manifest echoes the split parameters and its complement.
        let train_split = read_manifest(&train).split.unwrap();
        assert_eq!(train_split.side, "train");
        assert_eq!(train_split.seed, 99);
        assert_eq!(train_split.eval_fraction, 0.5);
        assert_eq!(train_split.by, "session_id");
        assert_eq!(train_split.complement_path, eval);

        let eval_split = read_manifest(&eval).split.unwrap();
        assert_eq!(eval_split.side, "eval");
    });
}
1871
1872    fn read_stats(path: &str) -> ExportStats {
1873        let raw = std::fs::read_to_string(format!("{path}.stats.json")).unwrap();
1874        serde_json::from_str(&raw).unwrap()
1875    }
1876
/// With `emit_stats` enabled, the stats file reports example/group counts,
/// per-role/source/tenant tallies, the lifecycle exclusion count and a token
/// distribution taken from `tokens_used`.
#[test]
fn stats_report_counts_roles_tokens_and_exclusions() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;

        // Two kept records sharing source and tenant, with explicit token counts.
        let mut user = rec("u", "user", "hello there", 1);
        let mut asst = rec("a", "assistant", "hi", 2);
        for (record, tokens) in [(&mut user, 5), (&mut asst, 11)] {
            record.source = Some("memory".to_string());
            record.tenant = Some("acme".to_string());
            record.state_metadata = Some(crate::record::StateMetadata {
                tokens_used: Some(tokens),
                ..Default::default()
            });
        }

        // a contradicted record that curation excludes by lifecycle
        let mut dropped = rec("d", "user", "nope", 3);
        dropped.session_id = Some("other".to_string());
        dropped.lifecycle_status = LIFECYCLE_CONTRADICTED.to_string();
        store.add(&[user, asst, dropped]).await.unwrap();

        let config = ExportConfig {
            task: ExportTask::Sft,
            group_by: GroupBy::SessionId,
            emit_stats: true,
            ..Default::default()
        };
        let out = out_path(&dir);
        store.export_training(&config, &out).await.unwrap();

        let stats = read_stats(&out);
        assert_eq!(stats.task, "sft");
        assert_eq!(stats.examples, 1);
        assert_eq!(stats.num_groups, 1);
        assert_eq!(stats.by_role.get("user"), Some(&1));
        assert_eq!(stats.by_role.get("assistant"), Some(&1));
        assert_eq!(stats.by_source.get("memory"), Some(&2));
        assert_eq!(stats.by_tenant.get("acme"), Some(&2));
        assert_eq!(stats.excluded.lifecycle, 1, "contradicted record excluded");

        // Token stats come from the two `tokens_used` values (5 and 11).
        let tokens = stats.tokens.unwrap();
        assert_eq!(tokens.source, "tokens_used");
        let distribution = tokens.distribution;
        assert_eq!(distribution.count, 2);
        assert_eq!(distribution.min, 5.0);
        assert_eq!(distribution.max, 11.0);
        assert_eq!(distribution.mean, 8.0);
    });
}
1935
/// When a record has no `state_metadata`, token stats are expected to report
/// the `length_proxy` source; a four-word payload yields a max of 4.
#[test]
fn stats_token_fallback_uses_length_proxy() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;
        // no state_metadata -> fallback to whitespace word count
        let only = rec("u", "user", "one two three four", 1);
        store.add(&[only]).await.unwrap();

        let config = ExportConfig {
            emit_stats: true,
            ..Default::default()
        };
        let out = out_path(&dir);
        store.export_training(&config, &out).await.unwrap();

        let stats = read_stats(&out);
        let tokens = stats.tokens.unwrap();
        assert_eq!(tokens.source, "length_proxy");
        assert_eq!(tokens.distribution.max, 4.0);
    });
}
1965
/// For a rollout export with `emit_stats`, the stats file must summarize the
/// `metadata.reward` values and tally `metadata.reward_source`.
#[test]
fn stats_report_reward_distribution_for_rollout() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;

        let prompt = rec("p", "user", "solve", 1);
        // Two assistant responses carrying rewards 1.0 and 0.0 from the same source.
        let [r1, r2] = [("r1", "a1", 2, 1.0), ("r2", "a2", 3, 0.0)].map(
            |(id, text, n, reward)| {
                let mut response = rec(id, "assistant", text, n);
                response.metadata =
                    Some(json!({"reward": reward, "reward_source": "verifier"}));
                response
            },
        );
        store.add(&[prompt, r1, r2]).await.unwrap();

        let config = ExportConfig {
            task: ExportTask::Rollout,
            group_by: GroupBy::SessionId,
            emit_stats: true,
            ..Default::default()
        };
        let out = out_path(&dir);
        store.export_training(&config, &out).await.unwrap();

        let stats = read_stats(&out);
        let reward = stats.reward.unwrap();
        assert_eq!(reward.count, 2);
        assert_eq!(reward.min, 0.0);
        assert_eq!(reward.max, 1.0);
        assert_eq!(stats.reward_sources.get("verifier"), Some(&2));
    });
}
2001
/// A default-config export must not leave a `<out>.stats.json` file behind.
#[test]
fn stats_not_written_without_flag() {
    let dir = TempDir::new().unwrap();
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let mut store = open_store(&dir).await;
        let only = rec("u", "user", "hi", 1);
        store.add(&[only]).await.unwrap();

        let out = out_path(&dir);
        let config = ExportConfig::default();
        store.export_training(&config, &out).await.unwrap();

        // A failed metadata lookup is how this test detects the missing file.
        let stats_path = format!("{out}.stats.json");
        assert!(std::fs::metadata(stats_path).is_err());
    });
}
2017}