Skip to main content

opensession_core/
scoring.rs

1use crate::{extract::extract_file_metadata, EventType, Session};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// Plugin id that [`SessionScoreRegistry::default`] installs as its default scorer.
pub const DEFAULT_SCORE_PLUGIN: &str = "heuristic_v1";
6
7/// A scoring plugin maps one session to one numeric score.
8pub trait SessionScorePlugin: Send + Sync {
9    fn id(&self) -> &'static str;
10    fn score(&self, session: &Session) -> i64;
11}
12
/// Result of scoring one session: the numeric score together with the id of
/// the plugin that was asked to produce it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionScore {
    /// Id of the plugin the score was requested from.
    pub plugin: String,
    /// Value returned by that plugin's `score` method.
    pub score: i64,
}
18
/// Errors produced while resolving a score plugin.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionScoreError {
    /// No plugin is registered under the requested id.
    UnknownPlugin {
        /// The id the caller asked for.
        requested: String,
        /// Ids that were registered at the time of the lookup.
        available: Vec<String>,
    },
}

impl std::fmt::Display for SessionScoreError {
    /// Renders the error as a one-line message listing the available plugin
    /// ids separated by `", "`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this destructuring is irrefutable;
        // adding a variant later turns it into a compile error here.
        let Self::UnknownPlugin {
            requested,
            available,
        } = self;
        let listing = available.join(", ");
        write!(
            f,
            "unknown session score plugin '{requested}'. available: {}",
            listing
        )
    }
}

impl std::error::Error for SessionScoreError {}
45
46/// Runtime registry for session score plugins.
47pub struct SessionScoreRegistry {
48    default_plugin: String,
49    plugins: HashMap<String, Arc<dyn SessionScorePlugin>>,
50}
51
52impl Default for SessionScoreRegistry {
53    fn default() -> Self {
54        let mut registry = Self::new(DEFAULT_SCORE_PLUGIN);
55        registry.register(HeuristicV1ScorePlugin);
56        registry.register(ZeroV1ScorePlugin);
57        registry
58    }
59}
60
61impl SessionScoreRegistry {
62    pub fn new(default_plugin: &str) -> Self {
63        Self {
64            default_plugin: default_plugin.to_string(),
65            plugins: HashMap::new(),
66        }
67    }
68
69    pub fn register<P>(&mut self, plugin: P)
70    where
71        P: SessionScorePlugin + 'static,
72    {
73        self.plugins
74            .insert(plugin.id().to_string(), Arc::new(plugin));
75    }
76
77    pub fn available_plugins(&self) -> Vec<String> {
78        let mut names: Vec<String> = self.plugins.keys().cloned().collect();
79        names.sort();
80        names
81    }
82
83    pub fn score_default(&self, session: &Session) -> Result<SessionScore, SessionScoreError> {
84        self.score_with(self.default_plugin.as_str(), session)
85    }
86
87    pub fn score_with(
88        &self,
89        plugin_id: &str,
90        session: &Session,
91    ) -> Result<SessionScore, SessionScoreError> {
92        let plugin =
93            self.plugins
94                .get(plugin_id)
95                .ok_or_else(|| SessionScoreError::UnknownPlugin {
96                    requested: plugin_id.to_string(),
97                    available: self.available_plugins(),
98                })?;
99        Ok(SessionScore {
100            plugin: plugin_id.to_string(),
101            score: plugin.score(session),
102        })
103    }
104}
105
/// Built-in heuristic scorer, registered under the id `heuristic_v1`.
///
/// Scoring rules:
/// - start from 100
/// - subtract 15 when the session is flagged `has_errors`
/// - subtract 5 per shell command with a non-zero exit code, at most 30 in total
/// - subtract 4 per tool result with `is_error = true`, at most 20 in total
/// - add 5 per recovery (a failure later followed by a success in the same
///   task lane), at most 20 in total
/// - clamp the result to the range 0 to 100, both ends included
pub struct HeuristicV1ScorePlugin;
116
117impl SessionScorePlugin for HeuristicV1ScorePlugin {
118    fn id(&self) -> &'static str {
119        "heuristic_v1"
120    }
121
122    fn score(&self, session: &Session) -> i64 {
123        let (_, _, has_errors) = extract_file_metadata(session);
124        let shell_failures = count_shell_failures(session) as i64;
125        let tool_errors = count_tool_errors(session) as i64;
126        let recoveries = count_recoveries(session) as i64;
127
128        let mut score = 100i64;
129        if has_errors {
130            score -= 15;
131        }
132        score -= (shell_failures * 5).min(30);
133        score -= (tool_errors * 4).min(20);
134        score += (recoveries * 5).min(20);
135        score.clamp(0, 100)
136    }
137}
138
/// Constant scorer that gives every session a score of zero; registered under
/// the id `zero_v1`. Handy for tests and compatibility checks.
pub struct ZeroV1ScorePlugin;
141
142impl SessionScorePlugin for ZeroV1ScorePlugin {
143    fn id(&self) -> &'static str {
144        "zero_v1"
145    }
146
147    fn score(&self, _session: &Session) -> i64 {
148        0
149    }
150}
151
152fn count_shell_failures(session: &Session) -> usize {
153    session
154        .events
155        .iter()
156        .filter(|event| {
157            matches!(
158                &event.event_type,
159                EventType::ShellCommand {
160                    exit_code: Some(code),
161                    ..
162                } if *code != 0
163            )
164        })
165        .count()
166}
167
168fn count_tool_errors(session: &Session) -> usize {
169    session
170        .events
171        .iter()
172        .filter(|event| {
173            matches!(
174                &event.event_type,
175                EventType::ToolResult { is_error: true, .. }
176            )
177        })
178        .count()
179}
180
/// Maps an optional task id to the lane key used when grouping events.
///
/// The id is trimmed of surrounding whitespace. A missing id, or one that is
/// empty after trimming, maps to the shared `"__global__"` lane.
fn event_task_key(task_id: &Option<String>) -> String {
    match task_id {
        Some(raw) if !raw.trim().is_empty() => raw.trim().to_owned(),
        _ => "__global__".to_owned(),
    }
}
189
190fn count_recoveries(session: &Session) -> usize {
191    let mut pending_failures: HashMap<String, usize> = HashMap::new();
192    let mut recoveries = 0usize;
193
194    for event in &session.events {
195        let key = event_task_key(&event.task_id);
196        match &event.event_type {
197            EventType::ShellCommand {
198                exit_code: Some(code),
199                ..
200            } if *code != 0 => {
201                *pending_failures.entry(key).or_default() += 1;
202            }
203            EventType::ToolResult { is_error: true, .. } => {
204                *pending_failures.entry(key).or_default() += 1;
205            }
206            EventType::ShellCommand {
207                exit_code: Some(0), ..
208            }
209            | EventType::ToolResult {
210                is_error: false, ..
211            } => {
212                let mut remove = false;
213                if let Some(pending) = pending_failures.get_mut(&key) {
214                    if *pending > 0 {
215                        *pending -= 1;
216                        recoveries += 1;
217                    }
218                    if *pending == 0 {
219                        remove = true;
220                    }
221                }
222                if remove {
223                    pending_failures.remove(&key);
224                }
225            }
226            _ => {}
227        }
228    }
229
230    recoveries
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{testing, Session};

    /// Wraps `events` in a throwaway session and recomputes its stats.
    fn build_session(events: Vec<crate::Event>) -> Session {
        let mut fixture = Session::new("score-test".to_string(), testing::agent());
        fixture.events = events;
        fixture.recompute_stats();
        fixture
    }

    #[test]
    fn registry_contains_builtin_plugins() {
        let names = SessionScoreRegistry::default().available_plugins();
        assert!(names.iter().any(|name| name == "heuristic_v1"));
        assert!(names.iter().any(|name| name == "zero_v1"));
    }

    #[test]
    fn heuristic_v1_penalizes_failures_and_rewards_recovery() {
        // Builds a `cargo test` shell event in task lane "t1" with the given
        // exit code.
        let shell_in_t1 = |exit_code| {
            let mut event = testing::event(
                EventType::ShellCommand {
                    command: "cargo test".to_string(),
                    exit_code: Some(exit_code),
                },
                "",
            );
            event.task_id = Some("t1".to_string());
            event
        };

        let fail = shell_in_t1(101);
        let success = shell_in_t1(0);
        let session = build_session(vec![fail, success]);

        let result = SessionScoreRegistry::default()
            .score_with("heuristic_v1", &session)
            .expect("heuristic scorer must exist");

        // 100 - 15 (has_errors) - 5 (one shell failure) + 5 (one recovery)
        assert_eq!(result.score, 85);
    }

    #[test]
    fn zero_plugin_returns_zero() {
        let hello = testing::event(EventType::UserMessage, "hello");
        let session = build_session(vec![hello]);

        let result = SessionScoreRegistry::default()
            .score_with("zero_v1", &session)
            .expect("zero scorer must exist");

        assert_eq!(result.score, 0);
    }

    #[test]
    fn unknown_plugin_reports_available_names() {
        let session = build_session(Vec::new());

        let err = SessionScoreRegistry::default()
            .score_with("missing_plugin", &session)
            .expect_err("must fail for unknown plugin");

        // The error enum has a single variant, so this pattern is irrefutable.
        let SessionScoreError::UnknownPlugin {
            requested,
            available,
        } = err;
        assert_eq!(requested, "missing_plugin");
        assert!(available.iter().any(|name| name == "heuristic_v1"));
    }
}