Skip to main content

aster_bench/
bench_session.rs

1use aster::conversation::Conversation;
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9#[derive(Debug, Deserialize, Serialize, Clone)]
10pub struct BenchAgentError {
11    pub message: String,
12    pub level: String, // ERROR, WARN, etc.
13    pub timestamp: DateTime<Utc>,
14}
15
16// avoid tying benchmarking to current session-impl.
17#[async_trait]
18pub trait BenchBaseSession: Send + Sync {
19    async fn headless(&mut self, message: String) -> anyhow::Result<()>;
20    fn message_history(&self) -> Conversation;
21    fn get_total_token_usage(&self) -> anyhow::Result<Option<i32>>;
22    fn get_session_id(&self) -> anyhow::Result<String>;
23}
24// struct for managing agent-session-access. to be passed to evals for benchmarking
25pub struct BenchAgent {
26    session: Box<dyn BenchBaseSession>,
27    errors: Arc<Mutex<Vec<BenchAgentError>>>,
28}
29
30impl BenchAgent {
31    pub fn new(session: Box<dyn BenchBaseSession>) -> Self {
32        let errors = Arc::new(Mutex::new(Vec::new()));
33        Self { session, errors }
34    }
35
36    pub(crate) async fn prompt(&mut self, p: String) -> anyhow::Result<Conversation> {
37        // Clear previous errors
38        {
39            let mut errors = self.errors.lock().await;
40            errors.clear();
41        }
42        self.session.headless(p).await?;
43        Ok(self.session.message_history())
44    }
45
46    pub async fn get_errors(&self) -> Vec<BenchAgentError> {
47        let errors = self.errors.lock().await;
48        errors.clone()
49    }
50
51    pub(crate) async fn get_token_usage(&self) -> Option<i32> {
52        self.session.get_total_token_usage().ok().flatten()
53    }
54
55    pub(crate) fn get_session_id(&self) -> anyhow::Result<String> {
56        self.session.get_session_id()
57    }
58}