swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
//! DPO (Direct Preference Optimization) LearnModel
//!
//! group_id でグループ化された Episode を比較し、DPO 学習用データを生成する。

use std::collections::HashMap;

use super::super::episode::{Episode, EpisodeContext, Outcome};
use super::super::record::Record;
use super::super::training::TrainingData;
use super::{LearnError, LearnModel};
use crate::types::GroupId;

/// DPO 学習用の比較ペア
///
/// 同じ group_id 内の成功/失敗 Episode から生成される。
#[derive(Debug, Clone)]
pub struct DpoPair {
    /// 成功した Episode
    pub chosen: Episode,
    /// 失敗した Episode
    pub rejected: Episode,
    /// 共通の group_id
    pub group_id: GroupId,
    /// 品質差(chosen.score - rejected.score)
    pub quality_gap: f64,
}

impl DpoPair {
    /// 新しい DpoPair を作成
    pub fn new(chosen: Episode, rejected: Episode, group_id: GroupId) -> Self {
        let chosen_score = chosen.outcome.score();
        let rejected_score = rejected.outcome.score();
        let quality_gap = chosen_score - rejected_score;

        Self {
            chosen,
            rejected,
            group_id,
            quality_gap,
        }
    }
}

/// DPO LearnModel の設定
#[derive(Debug, Clone)]
pub struct DpoConfig {
    /// 最小品質差(この差未満のペアは除外)
    pub min_quality_gap: f64,
    /// 最大ペア数(None なら無制限)
    pub max_pairs: Option<usize>,
    /// 同じエピソードの重複使用を許可
    pub allow_reuse: bool,
}

impl Default for DpoConfig {
    fn default() -> Self {
        Self {
            min_quality_gap: 0.1, // 10% 以上の差
            max_pairs: None,
            allow_reuse: true,
        }
    }
}

/// 汎用 DPO LearnModel
///
/// group_id でグループ化された Episode を比較し、DPO 学習用データを生成する。
///
/// ## 設計思想
///
/// DPO 学習では「同じ条件で複数回実行した結果を比較」する。
/// - group_id: 同じ条件での実行グループ(Eval -n 5 で 5 回実行など)
/// - 成功 Episode と失敗 Episode をペアにして比較
///
/// ## 使用方法
///
/// ```ignore
/// // Eval で group_id 付きの Episode を収集
/// let episodes: Vec<Episode> = ...;
///
/// // DPO ペアを生成
/// let dpo_learn = DpoLearnModel::new();
/// let pairs = dpo_learn.build_pairs(&episodes);
///
/// // TrainingData に変換
/// let training_data: Vec<TrainingData> = pairs
///     .iter()
///     .filter_map(|pair| dpo_learn.convert_pair(pair).ok())
///     .collect();
/// ```
pub struct DpoLearnModel<F>
where
    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
{
    /// システムプロンプト
    system_prompt: String,
    /// 設定
    config: DpoConfig,
    /// Episode から (prompt, response) を抽出する関数
    extractor: F,
}

impl<F> DpoLearnModel<F>
where
    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
{
    /// 新しい DpoLearnModel を作成
    pub fn new(extractor: F) -> Self {
        Self {
            system_prompt: String::new(),
            config: DpoConfig::default(),
            extractor,
        }
    }

    /// システムプロンプトを設定
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// 設定を適用
    pub fn with_config(mut self, config: DpoConfig) -> Self {
        self.config = config;
        self
    }

    /// 最小品質差を設定
    pub fn with_min_quality_gap(mut self, gap: f64) -> Self {
        self.config.min_quality_gap = gap;
        self
    }

    /// 最大ペア数を設定
    pub fn with_max_pairs(mut self, max: usize) -> Self {
        self.config.max_pairs = Some(max);
        self
    }

    /// group_id でグループ化された Episode から DPO ペアを生成
    pub fn build_pairs(&self, episodes: &[Episode]) -> Vec<DpoPair> {
        // group_id でグループ化
        let mut by_group: HashMap<GroupId, Vec<&Episode>> = HashMap::new();
        for ep in episodes {
            if let Some(gid) = ep.group_id {
                by_group.entry(gid).or_default().push(ep);
            }
        }

        let mut pairs = Vec::new();

        for (group_id, group_episodes) in by_group {
            // 成功/失敗で分類
            let (successes, failures): (Vec<_>, Vec<_>) = group_episodes
                .into_iter()
                .partition(|ep| ep.outcome.is_success());

            if successes.is_empty() || failures.is_empty() {
                continue;
            }

            // スコアでソート(高い順)
            let mut sorted_successes: Vec<_> = successes;
            sorted_successes.sort_by(|a, b| {
                let a_score = a.outcome.score();
                let b_score = b.outcome.score();
                b_score
                    .partial_cmp(&a_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // スコアでソート(低い順)
            let mut sorted_failures: Vec<_> = failures;
            sorted_failures.sort_by(|a, b| {
                let a_score = a.outcome.score();
                let b_score = b.outcome.score();
                a_score
                    .partial_cmp(&b_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // ペア作成
            for success_ep in &sorted_successes {
                for failure_ep in &sorted_failures {
                    let chosen_score = success_ep.outcome.score();
                    let rejected_score = failure_ep.outcome.score();
                    let gap = chosen_score - rejected_score;

                    if gap < self.config.min_quality_gap {
                        continue;
                    }

                    let pair = DpoPair::new((*success_ep).clone(), (*failure_ep).clone(), group_id);
                    pairs.push(pair);

                    if !self.config.allow_reuse {
                        break;
                    }
                }

                if !self.config.allow_reuse {
                    break;
                }
            }
        }

        // 品質差でソート(大きい順)
        pairs.sort_by(|a, b| {
            b.quality_gap
                .partial_cmp(&a.quality_gap)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // 最大数で制限
        if let Some(max) = self.config.max_pairs {
            pairs.truncate(max);
        }

        pairs
    }

    /// DPO ペアを TrainingData に変換
    pub fn convert_pair(&self, pair: &DpoPair) -> Result<TrainingData, LearnError> {
        let (chosen_prompt, chosen_response) = (self.extractor)(&pair.chosen)
            .ok_or_else(|| LearnError::MissingData("chosen prompt/response".into()))?;

        let (rejected_prompt, rejected_response) = (self.extractor)(&pair.rejected)
            .ok_or_else(|| LearnError::MissingData("rejected prompt/response".into()))?;

        // prompt が一致することを確認(正規化後)
        if chosen_prompt != rejected_prompt {
            return Err(LearnError::InvalidEpisode(format!(
                "Prompt mismatch: '{}' vs '{}'",
                chosen_prompt, rejected_prompt
            )));
        }

        let training = if self.system_prompt.is_empty() {
            TrainingData::dpo(&chosen_prompt, &chosen_response, &rejected_response)
        } else {
            TrainingData::dpo_with_system(
                &self.system_prompt,
                &chosen_prompt,
                &chosen_response,
                &rejected_response,
            )
        };

        Ok(training
            .with_episode_id(pair.chosen.id.to_string())
            .with_custom("rejected_episode_id", pair.rejected.id.to_string())
            .with_custom("quality_gap", pair.quality_gap.to_string())
            .with_custom("group_id", pair.group_id.0.to_string()))
    }

    /// 複数のペアを一括変換
    pub fn convert_pairs(&self, pairs: &[DpoPair]) -> Vec<TrainingData> {
        pairs
            .iter()
            .filter_map(|pair| self.convert_pair(pair).ok())
            .collect()
    }
}

/// LearnModel trait の実装(Record ベースの Episode 構築用)
///
/// DPO は通常、既存の Episode を比較するため、build_episodes は空を返す。
/// 実際の DPO ペア生成は build_pairs メソッドを使用。
impl<F> LearnModel for DpoLearnModel<F>
where
    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
{
    fn name(&self) -> &str {
        "dpo"
    }

    fn objective(&self) -> &str {
        "Learn preferences from success/failure Episode pairs within the same group"
    }

    fn build_episodes(&self, _records: &[Record]) -> Vec<Episode> {
        // DPO は既存の Episode を比較するため、Record から Episode は生成しない
        vec![]
    }

    fn evaluate(&self, _context: &EpisodeContext) -> Outcome {
        // DpoLearnModel は複数 Episode を group_id でグルーピングし、
        // 成功/失敗のペアを比較して学習する。
        // 個々の Episode を evaluate() するのは設計として不適切。
        //
        // DPO のフロー:
        //   1. Eval 実行時に Episode が生成される(Outcome は Eval 側で設定)
        //   2. build_pairs() で group_id ごとにグルーピング
        //   3. 成功/失敗 Episode のペアから TrainingData を生成
        //
        // この evaluate() が呼ばれるのは実装ミス。
        panic!(
            "DpoLearnModel::evaluate() should not be called.\n\
             DPO learning compares multiple Episodes by group_id, not individual Episode evaluation.\n\
             Use build_pairs() to generate training pairs from Episodes."
        );
    }

    fn convert(&self, _episode: &Episode) -> Result<TrainingData, LearnError> {
        // 単一の Episode からは DPO TrainingData は生成できない
        // convert_pair を使用すること
        Err(LearnError::InvalidEpisode(
            "DPO requires pairs, use convert_pair instead".into(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::EpisodeBuilder;
    use crate::learn::record::ActionRecord;
    use crate::types::TaskId;

    fn create_test_episode(
        task_id: TaskId,
        group_id: GroupId,
        success: bool,
        score: f64,
    ) -> Episode {
        let outcome = if success {
            Outcome::success(score)
        } else {
            Outcome::failure("test failure")
        };

        EpisodeBuilder::default()
            .learn_model("test")
            .task_id(task_id)
            .group_id(group_id)
            .record(ActionRecord::new(1, 0, "TestAction").success(success))
            .outcome(outcome)
            .build()
    }

    fn test_extractor(ep: &Episode) -> Option<(String, String)> {
        // テスト用: 固定の prompt/response を返す
        Some((
            "test prompt".to_string(),
            format!("response for {:?}", ep.id),
        ))
    }

    #[test]
    fn test_build_pairs_basic() {
        let group_id = GroupId::new();
        let task1 = TaskId::new();
        let task2 = TaskId::new();

        let episodes = vec![
            create_test_episode(task1, group_id, true, 0.9),
            create_test_episode(task2, group_id, false, 0.0),
        ];

        let dpo = DpoLearnModel::new(test_extractor);
        let pairs = dpo.build_pairs(&episodes);

        assert_eq!(pairs.len(), 1);
        assert!(pairs[0].quality_gap > 0.0);
    }

    #[test]
    fn test_build_pairs_different_groups() {
        let group1 = GroupId::new();
        let group2 = GroupId::new();

        let episodes = vec![
            create_test_episode(TaskId::new(), group1, true, 0.9),
            create_test_episode(TaskId::new(), group2, false, 0.0),
        ];

        let dpo = DpoLearnModel::new(test_extractor);
        let pairs = dpo.build_pairs(&episodes);

        // 異なる group_id なのでペアにならない
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_min_quality_gap() {
        let group_id = GroupId::new();

        let episodes = vec![
            create_test_episode(TaskId::new(), group_id, true, 0.6),
            create_test_episode(TaskId::new(), group_id, false, 0.0),
        ];

        // 0.5 以上の差を要求
        let dpo = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.5);
        let pairs = dpo.build_pairs(&episodes);

        // 0.6 - 0.0 = 0.6 なのでペアになる
        assert_eq!(pairs.len(), 1);

        // 0.7 以上の差を要求
        let dpo = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.7);
        let pairs = dpo.build_pairs(&episodes);

        // 差が足りないのでペアにならない
        assert!(pairs.is_empty());
    }
}