Skip to main content

nous_middleware/
middleware.rs

//! `NousMiddleware` — the Arcan middleware that runs Nous evaluators.
//!
//! Implements `arcan_core::runtime::Middleware` to run inline evaluators
//! at each hook point in the agent lifecycle. Scores are logged via tracing
//! and accumulated for downstream consumers (Vigil, Autonomic, Lago).
6
7use std::sync::{Arc, Mutex};
8
9use arcan_core::error::CoreError;
10use arcan_core::protocol::{ModelTurn, ToolResult};
11use arcan_core::runtime::{Middleware, ProviderRequest, RunOutput, ToolContext};
12use nous_core::{EvalContext, EvalHook, EvalScore, EvaluatorRegistry};
13use tracing::{debug, warn};
14
15/// Accumulated eval state for a middleware instance.
16#[derive(Debug, Default)]
17struct EvalAccumulator {
18    scores: Vec<EvalScore>,
19    tool_call_count: u32,
20    tool_error_count: u32,
21}
22
23/// Callback type for score notifications.
24type ScoreCallback = Arc<dyn Fn(&EvalScore) + Send + Sync>;
25
26/// Arcan middleware that runs Nous evaluators at each hook point.
27///
28/// Created with a populated `EvaluatorRegistry` and runs the appropriate
29/// evaluators at each lifecycle hook. Scores are accumulated and can be
30/// retrieved after a run completes.
31pub struct NousMiddleware {
32    registry: EvaluatorRegistry,
33    accumulator: Mutex<EvalAccumulator>,
34    /// Callback invoked for each score produced (for Vigil/Lago integration).
35    on_score: Option<ScoreCallback>,
36}
37
38impl NousMiddleware {
39    /// Create a new middleware with the given evaluator registry.
40    pub fn new(registry: EvaluatorRegistry) -> Self {
41        Self {
42            registry,
43            accumulator: Mutex::new(EvalAccumulator::default()),
44            on_score: None,
45        }
46    }
47
48    /// Create a middleware with a score callback.
49    pub fn with_on_score(registry: EvaluatorRegistry, on_score: ScoreCallback) -> Self {
50        Self {
51            registry,
52            accumulator: Mutex::new(EvalAccumulator::default()),
53            on_score: Some(on_score),
54        }
55    }
56
57    /// Create a middleware with the default heuristic evaluators.
58    pub fn with_defaults() -> Result<Self, nous_core::NousError> {
59        let registry = nous_heuristics::default_registry()?;
60        Ok(Self::new(registry))
61    }
62
63    /// Number of registered evaluators.
64    pub fn registry_len(&self) -> usize {
65        self.registry.len()
66    }
67
68    /// Get all accumulated scores from evaluations in this middleware instance.
69    pub fn scores(&self) -> Vec<EvalScore> {
70        self.accumulator
71            .lock()
72            .expect("accumulator lock poisoned")
73            .scores
74            .clone()
75    }
76
77    /// Run evaluators for a given hook and context, accumulating scores.
78    fn run_evaluators(&self, hook: EvalHook, ctx: &EvalContext) {
79        for evaluator in self.registry.evaluators_for(hook) {
80            match evaluator.evaluate(ctx) {
81                Ok(scores) => {
82                    for score in &scores {
83                        debug!(
84                            evaluator = score.evaluator,
85                            value = score.value,
86                            label = score.label.as_str(),
87                            layer = %score.layer,
88                            hook = hook.as_str(),
89                            "nous eval score"
90                        );
91                        if let Some(ref cb) = self.on_score {
92                            cb(score);
93                        }
94                    }
95                    if let Ok(mut acc) = self.accumulator.lock() {
96                        acc.scores.extend(scores);
97                    }
98                }
99                Err(e) => {
100                    warn!(
101                        evaluator = evaluator.name(),
102                        error = %e,
103                        hook = hook.as_str(),
104                        "nous evaluator failed"
105                    );
106                }
107            }
108        }
109    }
110
111    /// Build an `EvalContext` from a `ProviderRequest`.
112    fn ctx_from_request(&self, request: &ProviderRequest) -> EvalContext {
113        let mut ctx = EvalContext::new(&request.session_id);
114        ctx.run_id = Some(request.run_id.clone());
115        ctx.iteration = Some(request.iteration);
116        ctx
117    }
118
119    /// Build an `EvalContext` from a `ProviderRequest` + `ModelTurn`.
120    fn ctx_from_response(&self, request: &ProviderRequest, response: &ModelTurn) -> EvalContext {
121        let mut ctx = self.ctx_from_request(request);
122        if let Some(ref usage) = response.usage {
123            ctx.input_tokens = Some(usage.input_tokens);
124            ctx.output_tokens = Some(usage.output_tokens);
125        }
126        ctx
127    }
128}
129
130impl Middleware for NousMiddleware {
131    fn before_model_call(&self, request: &ProviderRequest) -> Result<(), CoreError> {
132        let ctx = self.ctx_from_request(request);
133        self.run_evaluators(EvalHook::BeforeModelCall, &ctx);
134        Ok(())
135    }
136
137    fn after_model_call(
138        &self,
139        request: &ProviderRequest,
140        response: &ModelTurn,
141    ) -> Result<(), CoreError> {
142        let ctx = self.ctx_from_response(request, response);
143        self.run_evaluators(EvalHook::AfterModelCall, &ctx);
144        Ok(())
145    }
146
147    fn pre_tool_call(
148        &self,
149        context: &ToolContext,
150        call: &arcan_core::protocol::ToolCall,
151    ) -> Result<(), CoreError> {
152        let mut ctx = EvalContext::new(&context.session_id);
153        ctx.run_id = Some(context.run_id.clone());
154        ctx.iteration = Some(context.iteration);
155        ctx.tool_name = Some(call.tool_name.clone());
156        self.run_evaluators(EvalHook::PreToolCall, &ctx);
157        Ok(())
158    }
159
160    fn post_tool_call(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
161        // Update tool counters.
162        if let Ok(mut acc) = self.accumulator.lock() {
163            acc.tool_call_count += 1;
164            if result.is_error {
165                acc.tool_error_count += 1;
166            }
167        }
168
169        let mut ctx = EvalContext::new(&context.session_id);
170        ctx.run_id = Some(context.run_id.clone());
171        ctx.iteration = Some(context.iteration);
172        ctx.tool_name = Some(result.tool_name.clone());
173        ctx.tool_errored = Some(result.is_error);
174
175        self.run_evaluators(EvalHook::PostToolCall, &ctx);
176        Ok(())
177    }
178
179    fn on_run_finished(&self, output: &RunOutput) -> Result<(), CoreError> {
180        let acc = self.accumulator.lock().expect("accumulator lock poisoned");
181
182        let mut ctx = EvalContext::new(&output.session_id);
183        ctx.run_id = Some(output.run_id.clone());
184        ctx.tool_call_count = Some(acc.tool_call_count);
185        ctx.tool_error_count = Some(acc.tool_error_count);
186        ctx.input_tokens = Some(output.total_usage.input_tokens);
187        ctx.output_tokens = Some(output.total_usage.output_tokens);
188
189        // Find max_iterations from run events (look for RunStarted).
190        if let Some(arcan_core::protocol::AgentEvent::RunStarted { max_iterations, .. }) =
191            output.events.first()
192        {
193            ctx.max_iterations = Some(*max_iterations);
194        }
195
196        // Determine current iteration from events count.
197        let iteration_count = output
198            .events
199            .iter()
200            .filter(|e| matches!(e, arcan_core::protocol::AgentEvent::IterationStarted { .. }))
201            .count() as u32;
202        ctx.iteration = Some(iteration_count);
203
204        // Release lock before running evaluators (they may try to accumulate).
205        drop(acc);
206
207        self.run_evaluators(EvalHook::OnRunFinished, &ctx);
208        Ok(())
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use arcan_core::protocol::{AgentEvent, RunStopReason, TokenUsage};
    use arcan_core::state::AppState;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// First-iteration provider request for `run-1` / `sess-1`.
    fn sample_request() -> ProviderRequest {
        ProviderRequest {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration: 1,
            messages: vec![],
            tools: vec![],
            state: AppState::default(),
        }
    }

    /// Model turn that ended normally after 1000 input / 200 output tokens.
    fn sample_turn() -> ModelTurn {
        ModelTurn {
            directives: vec![],
            stop_reason: arcan_core::protocol::ModelStopReason::EndTurn,
            usage: Some(TokenUsage {
                input_tokens: 1000,
                output_tokens: 200,
                cache_read_tokens: 0,
                cache_creation_tokens: 0,
            }),
        }
    }

    /// Tool execution context for the first iteration of `run-1`.
    fn sample_tool_context() -> ToolContext {
        ToolContext {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration: 1,
        }
    }

    /// Tool result for call `c1` with the given tool name, payload, and error flag.
    fn tool_result(tool_name: &'static str, output: serde_json::Value, is_error: bool) -> ToolResult {
        ToolResult {
            call_id: "c1".into(),
            tool_name: tool_name.into(),
            output,
            content: None,
            is_error,
            state_patch: None,
        }
    }

    /// Feed one result through a fresh middleware; return `(calls, errors)`.
    fn counts_after(result: &ToolResult) -> (u32, u32) {
        let mw = NousMiddleware::with_defaults().unwrap();
        mw.post_tool_call(&sample_tool_context(), result).unwrap();
        let acc = mw.accumulator.lock().unwrap();
        (acc.tool_call_count, acc.tool_error_count)
    }

    #[test]
    fn middleware_with_defaults_creates() {
        let mw = NousMiddleware::with_defaults().unwrap();
        assert!(!mw.registry.is_empty());
    }

    #[test]
    fn middleware_accumulates_scores_on_after_model() {
        let mw = NousMiddleware::with_defaults().unwrap();

        assert!(mw.after_model_call(&sample_request(), &sample_turn()).is_ok());

        // token_efficiency is expected to score this turn. budget_adherence is
        // expected to stay silent because the context carries no budget data.
        assert!(
            !mw.scores().is_empty(),
            "should have at least one score from token_efficiency"
        );
    }

    #[test]
    fn middleware_tracks_tool_calls() {
        let ok = tool_result("read_file", serde_json::json!({"content": "hello"}), false);
        assert_eq!(counts_after(&ok), (1, 0));
    }

    #[test]
    fn middleware_tracks_tool_errors() {
        let failed = tool_result(
            "write_file",
            serde_json::json!({"error": "permission denied"}),
            true,
        );
        assert_eq!(counts_after(&failed), (1, 1));
    }

    #[test]
    fn middleware_on_run_finished_fires_evaluators() {
        let mw = NousMiddleware::with_defaults().unwrap();

        // Pretend five tool calls (one failing) happened during the run.
        {
            let mut acc = mw.accumulator.lock().unwrap();
            acc.tool_call_count = 5;
            acc.tool_error_count = 1;
        }

        let usage = || TokenUsage {
            input_tokens: 500,
            output_tokens: 200,
            cache_read_tokens: 0,
            cache_creation_tokens: 0,
        };
        let iteration_started = |iteration| AgentEvent::IterationStarted {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration,
        };

        let output = RunOutput {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            branch_id: "main".into(),
            events: vec![
                AgentEvent::RunStarted {
                    run_id: "run-1".into(),
                    session_id: "sess-1".into(),
                    provider: "mock".into(),
                    max_iterations: 24,
                },
                iteration_started(1),
                iteration_started(2),
                AgentEvent::RunFinished {
                    run_id: "run-1".into(),
                    session_id: "sess-1".into(),
                    reason: RunStopReason::Completed,
                    total_iterations: 2,
                    final_answer: Some("done".into()),
                    usage: Some(usage()),
                },
            ],
            messages: vec![],
            state: AppState::default(),
            reason: RunStopReason::Completed,
            final_answer: Some("done".into()),
            total_usage: usage(),
        };

        assert!(mw.on_run_finished(&output).is_ok());

        // Both run-level evaluators should have reported.
        let scores = mw.scores();
        let names: Vec<_> = scores
            .iter()
            .filter(|s| s.evaluator == "tool_correctness" || s.evaluator == "step_efficiency")
            .map(|s| &s.evaluator)
            .collect();
        assert!(
            names.len() >= 2,
            "expected tool_correctness and step_efficiency scores, got {:?}",
            names
        );
    }

    #[test]
    fn on_score_callback_fires() {
        let fired = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&fired);

        let mw = NousMiddleware::with_on_score(
            nous_heuristics::default_registry().unwrap(),
            Arc::new(move |_score| {
                sink.fetch_add(1, Ordering::SeqCst);
            }),
        );

        mw.after_model_call(&sample_request(), &sample_turn()).unwrap();

        let count = fired.load(Ordering::SeqCst);
        assert!(count > 0, "callback should have fired at least once");
    }
}