opensession_core/
scoring.rs1use crate::{extract::extract_file_metadata, EventType, Session};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// Plugin id that [`SessionScoreRegistry::default`] configures as its default scorer.
pub const DEFAULT_SCORE_PLUGIN: &str = "heuristic_v1";
6
7pub trait SessionScorePlugin: Send + Sync {
9 fn id(&self) -> &'static str;
10 fn score(&self, session: &Session) -> i64;
11}
12
/// Result of scoring a session: which plugin produced the score, and the score itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionScore {
    /// Id of the plugin that was asked to compute the score.
    pub plugin: String,
    /// Value returned by that plugin.
    pub score: i64,
}
18
/// Errors produced when scoring a session through the registry.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SessionScoreError {
    /// The requested plugin id is not registered.
    UnknownPlugin {
        /// The id that was looked up.
        requested: String,
        /// Ids that were registered at the time of the lookup.
        available: Vec<String>,
    },
}
26
27impl std::fmt::Display for SessionScoreError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 Self::UnknownPlugin {
31 requested,
32 available,
33 } => {
34 write!(
35 f,
36 "unknown session score plugin '{requested}'. available: {}",
37 available.join(", ")
38 )
39 }
40 }
41 }
42}
43
44impl std::error::Error for SessionScoreError {}
45
46pub struct SessionScoreRegistry {
48 default_plugin: String,
49 plugins: HashMap<String, Arc<dyn SessionScorePlugin>>,
50}
51
52impl Default for SessionScoreRegistry {
53 fn default() -> Self {
54 let mut registry = Self::new(DEFAULT_SCORE_PLUGIN);
55 registry.register(HeuristicV1ScorePlugin);
56 registry.register(ZeroV1ScorePlugin);
57 registry
58 }
59}
60
61impl SessionScoreRegistry {
62 pub fn new(default_plugin: &str) -> Self {
63 Self {
64 default_plugin: default_plugin.to_string(),
65 plugins: HashMap::new(),
66 }
67 }
68
69 pub fn register<P>(&mut self, plugin: P)
70 where
71 P: SessionScorePlugin + 'static,
72 {
73 self.plugins
74 .insert(plugin.id().to_string(), Arc::new(plugin));
75 }
76
77 pub fn available_plugins(&self) -> Vec<String> {
78 let mut names: Vec<String> = self.plugins.keys().cloned().collect();
79 names.sort();
80 names
81 }
82
83 pub fn score_default(&self, session: &Session) -> Result<SessionScore, SessionScoreError> {
84 self.score_with(self.default_plugin.as_str(), session)
85 }
86
87 pub fn score_with(
88 &self,
89 plugin_id: &str,
90 session: &Session,
91 ) -> Result<SessionScore, SessionScoreError> {
92 let plugin =
93 self.plugins
94 .get(plugin_id)
95 .ok_or_else(|| SessionScoreError::UnknownPlugin {
96 requested: plugin_id.to_string(),
97 available: self.available_plugins(),
98 })?;
99 Ok(SessionScore {
100 plugin: plugin_id.to_string(),
101 score: plugin.score(session),
102 })
103 }
104}
105
106pub struct HeuristicV1ScorePlugin;
116
117impl SessionScorePlugin for HeuristicV1ScorePlugin {
118 fn id(&self) -> &'static str {
119 "heuristic_v1"
120 }
121
122 fn score(&self, session: &Session) -> i64 {
123 let (_, _, has_errors) = extract_file_metadata(session);
124 let shell_failures = count_shell_failures(session) as i64;
125 let tool_errors = count_tool_errors(session) as i64;
126 let recoveries = count_recoveries(session) as i64;
127
128 let mut score = 100i64;
129 if has_errors {
130 score -= 15;
131 }
132 score -= (shell_failures * 5).min(30);
133 score -= (tool_errors * 4).min(20);
134 score += (recoveries * 5).min(20);
135 score.clamp(0, 100)
136 }
137}
138
139pub struct ZeroV1ScorePlugin;
141
142impl SessionScorePlugin for ZeroV1ScorePlugin {
143 fn id(&self) -> &'static str {
144 "zero_v1"
145 }
146
147 fn score(&self, _session: &Session) -> i64 {
148 0
149 }
150}
151
152fn count_shell_failures(session: &Session) -> usize {
153 session
154 .events
155 .iter()
156 .filter(|event| {
157 matches!(
158 &event.event_type,
159 EventType::ShellCommand {
160 exit_code: Some(code),
161 ..
162 } if *code != 0
163 )
164 })
165 .count()
166}
167
168fn count_tool_errors(session: &Session) -> usize {
169 session
170 .events
171 .iter()
172 .filter(|event| {
173 matches!(
174 &event.event_type,
175 EventType::ToolResult { is_error: true, .. }
176 )
177 })
178 .count()
179}
180
/// Maps an optional task id to the key used for grouping events by task.
///
/// The id is trimmed of surrounding whitespace. A missing id, or one that is
/// empty after trimming, maps to the shared key `"__global__"`.
fn event_task_key(task_id: &Option<String>) -> String {
    let trimmed = match task_id {
        Some(raw) => raw.trim(),
        None => "",
    };
    if trimmed.is_empty() {
        "__global__".to_string()
    } else {
        trimmed.to_string()
    }
}
189
190fn count_recoveries(session: &Session) -> usize {
191 let mut pending_failures: HashMap<String, usize> = HashMap::new();
192 let mut recoveries = 0usize;
193
194 for event in &session.events {
195 let key = event_task_key(&event.task_id);
196 match &event.event_type {
197 EventType::ShellCommand {
198 exit_code: Some(code),
199 ..
200 } if *code != 0 => {
201 *pending_failures.entry(key).or_default() += 1;
202 }
203 EventType::ToolResult { is_error: true, .. } => {
204 *pending_failures.entry(key).or_default() += 1;
205 }
206 EventType::ShellCommand {
207 exit_code: Some(0), ..
208 }
209 | EventType::ToolResult {
210 is_error: false, ..
211 } => {
212 let mut remove = false;
213 if let Some(pending) = pending_failures.get_mut(&key) {
214 if *pending > 0 {
215 *pending -= 1;
216 recoveries += 1;
217 }
218 if *pending == 0 {
219 remove = true;
220 }
221 }
222 if remove {
223 pending_failures.remove(&key);
224 }
225 }
226 _ => {}
227 }
228 }
229
230 recoveries
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{testing, Session};

    /// Wraps `events` in a fresh test session and recomputes its stats.
    fn build_session(events: Vec<crate::Event>) -> Session {
        let mut session = Session::new("score-test".to_string(), testing::agent());
        session.events = events;
        session.recompute_stats();
        session
    }

    #[test]
    fn registry_contains_builtin_plugins() {
        let names = SessionScoreRegistry::default().available_plugins();
        for expected in ["heuristic_v1", "zero_v1"] {
            assert!(names.iter().any(|name| name == expected));
        }
    }

    #[test]
    fn heuristic_v1_penalizes_failures_and_rewards_recovery() {
        // Both events run the same command in task "t1"; only the exit code differs.
        let shell_event = |code| {
            let mut event = testing::event(
                EventType::ShellCommand {
                    command: "cargo test".to_string(),
                    exit_code: Some(code),
                },
                "",
            );
            event.task_id = Some("t1".to_string());
            event
        };

        let session = build_session(vec![shell_event(101), shell_event(0)]);
        let result = SessionScoreRegistry::default()
            .score_with("heuristic_v1", &session)
            .expect("heuristic scorer must exist");

        assert_eq!(result.score, 85);
    }

    #[test]
    fn zero_plugin_returns_zero() {
        let events = vec![testing::event(EventType::UserMessage, "hello")];
        let session = build_session(events);
        let result = SessionScoreRegistry::default()
            .score_with("zero_v1", &session)
            .expect("zero scorer must exist");
        assert_eq!(result.score, 0);
    }

    #[test]
    fn unknown_plugin_reports_available_names() {
        let session = build_session(Vec::new());
        let err = SessionScoreRegistry::default()
            .score_with("missing_plugin", &session)
            .expect_err("must fail for unknown plugin");

        let SessionScoreError::UnknownPlugin {
            requested,
            available,
        } = err;
        assert_eq!(requested, "missing_plugin");
        assert!(available.iter().any(|name| name == "heuristic_v1"));
    }
}