aster_bench/
bench_session.rs1use aster::conversation::Conversation;
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9#[derive(Debug, Deserialize, Serialize, Clone)]
10pub struct BenchAgentError {
11 pub message: String,
12 pub level: String, pub timestamp: DateTime<Utc>,
14}
15
16#[async_trait]
18pub trait BenchBaseSession: Send + Sync {
19 async fn headless(&mut self, message: String) -> anyhow::Result<()>;
20 fn message_history(&self) -> Conversation;
21 fn get_total_token_usage(&self) -> anyhow::Result<Option<i32>>;
22 fn get_session_id(&self) -> anyhow::Result<String>;
23}
24pub struct BenchAgent {
26 session: Box<dyn BenchBaseSession>,
27 errors: Arc<Mutex<Vec<BenchAgentError>>>,
28}
29
30impl BenchAgent {
31 pub fn new(session: Box<dyn BenchBaseSession>) -> Self {
32 let errors = Arc::new(Mutex::new(Vec::new()));
33 Self { session, errors }
34 }
35
36 pub(crate) async fn prompt(&mut self, p: String) -> anyhow::Result<Conversation> {
37 {
39 let mut errors = self.errors.lock().await;
40 errors.clear();
41 }
42 self.session.headless(p).await?;
43 Ok(self.session.message_history())
44 }
45
46 pub async fn get_errors(&self) -> Vec<BenchAgentError> {
47 let errors = self.errors.lock().await;
48 errors.clone()
49 }
50
51 pub(crate) async fn get_token_usage(&self) -> Option<i32> {
52 self.session.get_total_token_usage().ok().flatten()
53 }
54
55 pub(crate) fn get_session_id(&self) -> anyhow::Result<String> {
56 self.session.get_session_id()
57 }
58}