1use crate::completion::{StopReason, Usage};
10use crate::event::Event;
11use crate::ids::SpanId;
12use crate::message::Message;
13use crate::outcome::RunOutcome;
14use crate::task::Task;
15use crate::trajectory::TrajectoryHandle;
16use crate::MetadataMap;
17use async_trait::async_trait;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use time::OffsetDateTime;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct AssistantTurn {
33 pub turn_index: u32,
34 pub span_id: SpanId,
35 pub message: Message,
37 pub tool_calls: Vec<ToolCall>,
38 pub usage: Usage,
39 pub stop_reason: StopReason,
40}
41
42impl AssistantTurn {
43 pub fn new(
47 turn_index: u32,
48 span_id: impl Into<SpanId>,
49 message: Message,
50 usage: Usage,
51 stop_reason: StopReason,
52 ) -> Self {
53 let tool_calls = tool_calls_from_message(&message);
54 Self {
55 turn_index,
56 span_id: span_id.into(),
57 message,
58 tool_calls,
59 usage,
60 stop_reason,
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ToolCall {
67 pub id: String,
68 pub name: String,
69 pub input: Value,
70}
71
72fn tool_calls_from_message(message: &Message) -> Vec<ToolCall> {
73 use crate::message::Content;
74 let Message::Assistant { content, .. } = message else {
75 return Vec::new();
76 };
77 content
78 .iter()
79 .filter_map(|c| match c {
80 Content::ToolUse { id, name, input } => Some(ToolCall {
81 id: id.clone(),
82 name: name.clone(),
83 input: input.clone(),
84 }),
85 _ => None,
86 })
87 .collect()
88}
89
90pub struct TrajectoryView<'a> {
102 events: &'a [Event],
103}
104
105impl<'a> TrajectoryView<'a> {
106 pub fn new(events: &'a [Event]) -> Self {
107 Self { events }
108 }
109
110 pub fn events(&self) -> &[Event] {
111 self.events
112 }
113
114 pub fn turn_count(&self) -> u32 {
116 use crate::event::EventKind;
117 self.events
118 .iter()
119 .filter(|e| matches!(e.kind, EventKind::TurnFinished(_)))
120 .count() as u32
121 }
122
123 pub fn to_handle(&self) -> TrajectoryHandle {
128 TrajectoryHandle::in_memory(self.events.to_vec())
129 }
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct EvaluationResult {
141 pub score: f64,
144 pub passed: bool,
147 #[serde(default, skip_serializing_if = "MetadataMap::is_empty")]
148 pub details: MetadataMap,
149}
150
151#[async_trait]
158pub trait TaskEvaluator: Send + Sync {
159 async fn evaluate(&self, task: &Task, outcome: &RunOutcome) -> EvaluationResult;
160}
161
162impl EvaluationResult {
163 pub fn pass() -> Self {
164 Self {
165 score: 1.0,
166 passed: true,
167 details: MetadataMap::new(),
168 }
169 }
170
171 pub fn fail() -> Self {
172 Self {
173 score: 0.0,
174 passed: false,
175 details: MetadataMap::new(),
176 }
177 }
178
179 pub fn scored(score: f64) -> Self {
182 Self {
183 score,
184 passed: score >= 0.5,
185 details: MetadataMap::new(),
186 }
187 }
188
189 pub fn with_details(mut self, details: MetadataMap) -> Self {
190 self.details = details;
191 self
192 }
193}
194
195pub struct Episode<'a> {
204 pub index: u32,
205 pub task: &'a Task,
206 pub outcome: &'a RunOutcome,
207 pub evaluation: &'a EvaluationResult,
208 pub prior_reflections: &'a [Reflection],
209}
210
211impl<'a> Episode<'a> {
212 pub fn to_owned(&self) -> OwnedEpisode {
213 OwnedEpisode {
214 index: self.index,
215 task: self.task.clone(),
216 outcome: self.outcome.clone(),
217 evaluation: self.evaluation.clone(),
218 prior_reflections: self.prior_reflections.to_vec(),
219 }
220 }
221}
222
223#[derive(Debug, Clone)]
226pub struct OwnedEpisode {
227 pub index: u32,
228 pub task: Task,
229 pub outcome: RunOutcome,
230 pub evaluation: EvaluationResult,
231 pub prior_reflections: Vec<Reflection>,
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct Reflection {
236 pub text: String,
237 #[serde(default, skip_serializing_if = "MetadataMap::is_empty")]
238 pub metadata: MetadataMap,
239 #[serde(with = "time::serde::rfc3339")]
240 pub created_at: OffsetDateTime,
241}
242
243impl Reflection {
244 pub fn new(text: impl Into<String>) -> Self {
245 Self {
246 text: text.into(),
247 metadata: MetadataMap::new(),
248 created_at: OffsetDateTime::now_utc(),
249 }
250 }
251
252 pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
253 self.metadata = metadata;
254 self
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Content;
    use crate::MetadataMap;

    /// `AssistantTurn::new` must pick up every `ToolUse` entry, in order, and
    /// ignore other content such as text.
    #[test]
    fn assistant_turn_extracts_tool_calls_from_message() {
        let msg = Message::Assistant {
            content: vec![
                Content::text("Let me check."),
                Content::ToolUse {
                    id: "tu_1".into(),
                    name: "fs_list".into(),
                    input: serde_json::json!({"path": "."}),
                },
                Content::ToolUse {
                    id: "tu_2".into(),
                    name: "bash".into(),
                    input: serde_json::json!({"cmd": "ls"}),
                },
            ],
            stop_reason: Some(StopReason::ToolUse),
            meta: MetadataMap::new(),
        };
        let turn = AssistantTurn::new(0, "span-0", msg, Usage::default(), StopReason::ToolUse);
        assert_eq!(turn.tool_calls.len(), 2);
        assert_eq!(turn.tool_calls[0].id, "tu_1");
        assert_eq!(turn.tool_calls[0].name, "fs_list");
        assert_eq!(turn.tool_calls[0].input, serde_json::json!({"path": "."}));
        assert_eq!(turn.tool_calls[1].id, "tu_2");
        assert_eq!(turn.tool_calls[1].name, "bash");
        assert_eq!(turn.tool_calls[1].input, serde_json::json!({"cmd": "ls"}));
    }

    /// `turn_count` counts only `TurnFinished` events, not `TurnStarted`.
    #[test]
    fn trajectory_view_turn_count_matches_finished_events() {
        use crate::event::{EventKind, TurnFinishedPayload, TurnPayload};
        use crate::ids::RunId;
        let run = RunId::new();
        let events = vec![
            Event::new(
                0,
                run,
                "turn-0",
                EventKind::TurnStarted(TurnPayload { turn_index: 0 }),
            ),
            Event::new(
                1,
                run,
                "turn-0",
                EventKind::TurnFinished(TurnFinishedPayload {
                    turn_index: 0,
                    stop_reason: StopReason::EndTurn,
                    usage: Usage::default(),
                    tool_calls: 0,
                }),
            ),
            Event::new(
                2,
                run,
                "turn-1",
                EventKind::TurnStarted(TurnPayload { turn_index: 1 }),
            ),
            Event::new(
                3,
                run,
                "turn-1",
                EventKind::TurnFinished(TurnFinishedPayload {
                    turn_index: 1,
                    stop_reason: StopReason::EndTurn,
                    usage: Usage::default(),
                    tool_calls: 0,
                }),
            ),
        ];
        let view = TrajectoryView::new(&events);
        assert_eq!(view.turn_count(), 2);
        assert_eq!(view.events().len(), 4);
    }

    /// An empty view has no events and no finished turns.
    #[test]
    fn trajectory_view_over_no_events_is_empty() {
        let view = TrajectoryView::new(&[]);
        assert_eq!(view.turn_count(), 0);
        assert!(view.events().is_empty());
    }

    /// Covers the fixed constructors plus the boundaries of `scored`.
    #[test]
    fn evaluation_result_constructors() {
        assert!(EvaluationResult::pass().passed);
        assert_eq!(EvaluationResult::pass().score, 1.0);
        assert!(EvaluationResult::pass().details.is_empty());
        assert!(!EvaluationResult::fail().passed);
        assert_eq!(EvaluationResult::fail().score, 0.0);
        assert!(EvaluationResult::scored(0.7).passed);
        assert_eq!(EvaluationResult::scored(0.7).score, 0.7);
        assert!(!EvaluationResult::scored(0.4).passed);
        // The threshold is inclusive.
        assert!(EvaluationResult::scored(0.5).passed);
        // NaN compares false against the threshold, so it never passes.
        assert!(!EvaluationResult::scored(f64::NAN).passed);
    }

    /// A reflection survives a JSON round trip, including its timestamp, and
    /// empty metadata is left out of the serialized form.
    #[test]
    fn reflection_round_trips_through_serde() {
        let r = Reflection::new("next time, check the imports first");
        let bytes = serde_json::to_vec(&r).unwrap();
        let back: Reflection = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(back.text, r.text);
        assert_eq!(back.created_at, r.created_at);
        assert!(back.metadata.is_empty());

        let json = serde_json::to_value(&r).unwrap();
        assert!(json.get("metadata").is_none());
        assert!(json.get("created_at").is_some());
    }
}