1use std::sync::{Arc, Mutex};
8
9use arcan_core::error::CoreError;
10use arcan_core::protocol::{ModelTurn, ToolResult};
11use arcan_core::runtime::{Middleware, ProviderRequest, RunOutput, ToolContext};
12use nous_core::{EvalContext, EvalHook, EvalScore, EvaluatorRegistry};
13use tracing::{debug, warn};
14
/// Mutable state gathered while the middleware hooks run.
///
/// Nothing in this file resets these values, so they keep growing for as long
/// as the owning middleware instance is used.
#[derive(Debug, Default)]
struct EvalAccumulator {
    /// Every score returned by a successful evaluator call so far.
    scores: Vec<EvalScore>,
    /// Number of tool results seen by `post_tool_call`.
    tool_call_count: u32,
    /// Number of those tool results that had `is_error` set.
    tool_error_count: u32,
}
22
/// Shared callback that receives a reference to each score as it is produced.
///
/// The `Send + Sync` bounds allow the callback to be stored in a middleware
/// that is used from more than one thread.
type ScoreCallback = Arc<dyn Fn(&EvalScore) + Send + Sync>;
25
/// Middleware that runs the evaluators of an [`EvaluatorRegistry`] at each
/// runtime hook and records the scores they return.
pub struct NousMiddleware {
    /// Source of the evaluators to run; queried per hook.
    registry: EvaluatorRegistry,
    /// Collected scores and tool-call counters. Kept behind a mutex because
    /// every hook only receives `&self`.
    accumulator: Mutex<EvalAccumulator>,
    /// Optional callback invoked for every score as it is produced.
    on_score: Option<ScoreCallback>,
}
37
38impl NousMiddleware {
39 pub fn new(registry: EvaluatorRegistry) -> Self {
41 Self {
42 registry,
43 accumulator: Mutex::new(EvalAccumulator::default()),
44 on_score: None,
45 }
46 }
47
48 pub fn with_on_score(registry: EvaluatorRegistry, on_score: ScoreCallback) -> Self {
50 Self {
51 registry,
52 accumulator: Mutex::new(EvalAccumulator::default()),
53 on_score: Some(on_score),
54 }
55 }
56
57 pub fn with_defaults() -> Result<Self, nous_core::NousError> {
59 let registry = nous_heuristics::default_registry()?;
60 Ok(Self::new(registry))
61 }
62
63 pub fn registry_len(&self) -> usize {
65 self.registry.len()
66 }
67
68 pub fn scores(&self) -> Vec<EvalScore> {
70 self.accumulator
71 .lock()
72 .expect("accumulator lock poisoned")
73 .scores
74 .clone()
75 }
76
77 fn run_evaluators(&self, hook: EvalHook, ctx: &EvalContext) {
79 for evaluator in self.registry.evaluators_for(hook) {
80 match evaluator.evaluate(ctx) {
81 Ok(scores) => {
82 for score in &scores {
83 debug!(
84 evaluator = score.evaluator,
85 value = score.value,
86 label = score.label.as_str(),
87 layer = %score.layer,
88 hook = hook.as_str(),
89 "nous eval score"
90 );
91 if let Some(ref cb) = self.on_score {
92 cb(score);
93 }
94 }
95 if let Ok(mut acc) = self.accumulator.lock() {
96 acc.scores.extend(scores);
97 }
98 }
99 Err(e) => {
100 warn!(
101 evaluator = evaluator.name(),
102 error = %e,
103 hook = hook.as_str(),
104 "nous evaluator failed"
105 );
106 }
107 }
108 }
109 }
110
111 fn ctx_from_request(&self, request: &ProviderRequest) -> EvalContext {
113 let mut ctx = EvalContext::new(&request.session_id);
114 ctx.run_id = Some(request.run_id.clone());
115 ctx.iteration = Some(request.iteration);
116 ctx
117 }
118
119 fn ctx_from_response(&self, request: &ProviderRequest, response: &ModelTurn) -> EvalContext {
121 let mut ctx = self.ctx_from_request(request);
122 if let Some(ref usage) = response.usage {
123 ctx.input_tokens = Some(usage.input_tokens);
124 ctx.output_tokens = Some(usage.output_tokens);
125 }
126 ctx
127 }
128}
129
130impl Middleware for NousMiddleware {
131 fn before_model_call(&self, request: &ProviderRequest) -> Result<(), CoreError> {
132 let ctx = self.ctx_from_request(request);
133 self.run_evaluators(EvalHook::BeforeModelCall, &ctx);
134 Ok(())
135 }
136
137 fn after_model_call(
138 &self,
139 request: &ProviderRequest,
140 response: &ModelTurn,
141 ) -> Result<(), CoreError> {
142 let ctx = self.ctx_from_response(request, response);
143 self.run_evaluators(EvalHook::AfterModelCall, &ctx);
144 Ok(())
145 }
146
147 fn pre_tool_call(
148 &self,
149 context: &ToolContext,
150 call: &arcan_core::protocol::ToolCall,
151 ) -> Result<(), CoreError> {
152 let mut ctx = EvalContext::new(&context.session_id);
153 ctx.run_id = Some(context.run_id.clone());
154 ctx.iteration = Some(context.iteration);
155 ctx.tool_name = Some(call.tool_name.clone());
156 self.run_evaluators(EvalHook::PreToolCall, &ctx);
157 Ok(())
158 }
159
160 fn post_tool_call(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
161 if let Ok(mut acc) = self.accumulator.lock() {
163 acc.tool_call_count += 1;
164 if result.is_error {
165 acc.tool_error_count += 1;
166 }
167 }
168
169 let mut ctx = EvalContext::new(&context.session_id);
170 ctx.run_id = Some(context.run_id.clone());
171 ctx.iteration = Some(context.iteration);
172 ctx.tool_name = Some(result.tool_name.clone());
173 ctx.tool_errored = Some(result.is_error);
174
175 self.run_evaluators(EvalHook::PostToolCall, &ctx);
176 Ok(())
177 }
178
179 fn on_run_finished(&self, output: &RunOutput) -> Result<(), CoreError> {
180 let acc = self.accumulator.lock().expect("accumulator lock poisoned");
181
182 let mut ctx = EvalContext::new(&output.session_id);
183 ctx.run_id = Some(output.run_id.clone());
184 ctx.tool_call_count = Some(acc.tool_call_count);
185 ctx.tool_error_count = Some(acc.tool_error_count);
186 ctx.input_tokens = Some(output.total_usage.input_tokens);
187 ctx.output_tokens = Some(output.total_usage.output_tokens);
188
189 if let Some(arcan_core::protocol::AgentEvent::RunStarted { max_iterations, .. }) =
191 output.events.first()
192 {
193 ctx.max_iterations = Some(*max_iterations);
194 }
195
196 let iteration_count = output
198 .events
199 .iter()
200 .filter(|e| matches!(e, arcan_core::protocol::AgentEvent::IterationStarted { .. }))
201 .count() as u32;
202 ctx.iteration = Some(iteration_count);
203
204 drop(acc);
206
207 self.run_evaluators(EvalHook::OnRunFinished, &ctx);
208 Ok(())
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use arcan_core::protocol::{AgentEvent, RunStopReason, TokenUsage};
    use arcan_core::state::AppState;

    /// Middleware backed by the default heuristic registry.
    fn default_middleware() -> NousMiddleware {
        NousMiddleware::with_defaults().unwrap()
    }

    /// First-iteration provider request with no messages and no tools.
    fn sample_request() -> ProviderRequest {
        ProviderRequest {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration: 1,
            messages: vec![],
            tools: vec![],
            state: AppState::default(),
        }
    }

    /// Model turn that ended normally and reports 1000 input / 200 output tokens.
    fn sample_response() -> ModelTurn {
        ModelTurn {
            directives: vec![],
            stop_reason: arcan_core::protocol::ModelStopReason::EndTurn,
            usage: Some(TokenUsage {
                input_tokens: 1000,
                output_tokens: 200,
                cache_read_tokens: 0,
                cache_creation_tokens: 0,
            }),
        }
    }

    /// Tool context matching `sample_request`.
    fn sample_tool_context() -> ToolContext {
        ToolContext {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration: 1,
        }
    }

    /// Tool result for call `c1` with the given name, output and error flag.
    fn sample_tool_result(
        tool_name: &'static str,
        output: serde_json::Value,
        is_error: bool,
    ) -> ToolResult {
        ToolResult {
            call_id: "c1".into(),
            tool_name: tool_name.into(),
            output,
            content: None,
            is_error,
            state_patch: None,
        }
    }

    /// Usage totals used by the finished-run fixture.
    fn run_usage() -> TokenUsage {
        TokenUsage {
            input_tokens: 500,
            output_tokens: 200,
            cache_read_tokens: 0,
            cache_creation_tokens: 0,
        }
    }

    #[test]
    fn middleware_with_defaults_creates() {
        let middleware = default_middleware();
        assert!(!middleware.registry.is_empty());
    }

    #[test]
    fn middleware_accumulates_scores_on_after_model() {
        let middleware = default_middleware();

        let outcome = middleware.after_model_call(&sample_request(), &sample_response());
        assert!(outcome.is_ok());

        let collected = middleware.scores();
        assert!(
            !collected.is_empty(),
            "should have at least one score from token_efficiency"
        );
    }

    #[test]
    fn middleware_tracks_tool_calls() {
        let middleware = default_middleware();
        let tool_result = sample_tool_result(
            "read_file",
            serde_json::json!({"content": "hello"}),
            false,
        );

        middleware
            .post_tool_call(&sample_tool_context(), &tool_result)
            .unwrap();

        let state = middleware.accumulator.lock().unwrap();
        assert_eq!(state.tool_call_count, 1);
        assert_eq!(state.tool_error_count, 0);
    }

    #[test]
    fn middleware_tracks_tool_errors() {
        let middleware = default_middleware();
        let tool_result = sample_tool_result(
            "write_file",
            serde_json::json!({"error": "permission denied"}),
            true,
        );

        middleware
            .post_tool_call(&sample_tool_context(), &tool_result)
            .unwrap();

        let state = middleware.accumulator.lock().unwrap();
        assert_eq!(state.tool_call_count, 1);
        assert_eq!(state.tool_error_count, 1);
    }

    #[test]
    fn middleware_on_run_finished_fires_evaluators() {
        let middleware = default_middleware();

        // Seed the counters directly; the guard is released at the end of the
        // scope so `on_run_finished` can take the lock itself.
        {
            let mut state = middleware.accumulator.lock().unwrap();
            state.tool_call_count = 5;
            state.tool_error_count = 1;
        }

        let events = vec![
            AgentEvent::RunStarted {
                run_id: "run-1".into(),
                session_id: "sess-1".into(),
                provider: "mock".into(),
                max_iterations: 24,
            },
            AgentEvent::IterationStarted {
                run_id: "run-1".into(),
                session_id: "sess-1".into(),
                iteration: 1,
            },
            AgentEvent::IterationStarted {
                run_id: "run-1".into(),
                session_id: "sess-1".into(),
                iteration: 2,
            },
            AgentEvent::RunFinished {
                run_id: "run-1".into(),
                session_id: "sess-1".into(),
                reason: RunStopReason::Completed,
                total_iterations: 2,
                final_answer: Some("done".into()),
                usage: Some(run_usage()),
            },
        ];

        let output = RunOutput {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            branch_id: "main".into(),
            events,
            messages: vec![],
            state: AppState::default(),
            reason: RunStopReason::Completed,
            final_answer: Some("done".into()),
            total_usage: run_usage(),
        };

        assert!(middleware.on_run_finished(&output).is_ok());

        let collected = middleware.scores();
        let run_finished_scores: Vec<_> = collected
            .iter()
            .filter(|s| s.evaluator == "tool_correctness" || s.evaluator == "step_efficiency")
            .collect();
        assert!(
            run_finished_scores.len() >= 2,
            "expected tool_correctness and step_efficiency scores, got {:?}",
            run_finished_scores
                .iter()
                .map(|s| &s.evaluator)
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn on_score_callback_fires() {
        let fired = Arc::new(Mutex::new(0u32));
        let fired_in_callback = fired.clone();

        let middleware = NousMiddleware::with_on_score(
            nous_heuristics::default_registry().unwrap(),
            Arc::new(move |_score| {
                *fired_in_callback.lock().unwrap() += 1;
            }),
        );

        middleware
            .after_model_call(&sample_request(), &sample_response())
            .unwrap();

        let count = *fired.lock().unwrap();
        assert!(count > 0, "callback should have fired at least once");
    }
}