1use async_trait::async_trait;
16use oharness_core::{CompletionRequest, ConversationView, Message, Task};
17use oharness_llm::Llm;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
/// A simulated user that drives a conversation for a given [`Task`].
///
/// Implementations supply the opening user message and then, turn by turn,
/// decide whether to say something else or end the conversation.
#[async_trait]
pub trait UserSimulator: Send + Sync {
    /// Name identifying this simulator.
    fn name(&self) -> &str;

    /// Produces the first user message for `task`.
    ///
    /// # Errors
    ///
    /// Returns a [`UserError`] when the simulator cannot produce an opening
    /// message.
    async fn initial_message(&self, task: &Task) -> Result<String, UserError>;

    /// Produces the user's next action given the conversation so far.
    ///
    /// Returns [`UserAction::Say`] with the next user message, or
    /// [`UserAction::EndConversation`] when the simulated user is done.
    ///
    /// # Errors
    ///
    /// Returns a [`UserError`] when the simulator fails to decide on an action.
    async fn respond(
        &self,
        conversation: ConversationView<'_>,
        task: &Task,
    ) -> Result<UserAction, UserError>;
}
38
/// What the simulated user does on its turn.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserAction {
    /// Send this text as the next user message.
    Say(String),
    /// Stop the conversation; no further user message is produced.
    EndConversation,
}
45
/// Errors a [`UserSimulator`] can report.
#[derive(Debug, thiserror::Error)]
pub enum UserError {
    /// Free-form failure described by the contained message.
    #[error("user simulator: {0}")]
    Other(String),
    /// Failure reported by the underlying LLM; `?` converts an
    /// `oharness_llm::LlmError` into this variant via `#[from]`.
    #[error("user simulator llm: {0}")]
    Llm(#[from] oharness_llm::LlmError),
}
57
/// User simulator that replays a fixed list of messages in order.
///
/// A single shared cursor tracks how many script entries have been handed
/// out so far.
pub struct ScriptedUserSimulator {
    /// Messages to emit, in order.
    script: Vec<String>,
    /// Index of the next script entry to hand out.
    cursor: AtomicUsize,
    /// Name reported for this simulator.
    name: String,
}

impl ScriptedUserSimulator {
    /// Builds a simulator from `script`, starting at the first entry and
    /// using the default name `"scripted-user"`.
    pub fn new(script: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let mut lines: Vec<String> = Vec::new();
        for entry in script {
            lines.push(entry.into());
        }
        Self {
            name: String::from("scripted-user"),
            cursor: AtomicUsize::new(0),
            script: lines,
        }
    }

    /// Replaces the simulator's name, builder-style.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..self
        }
    }
}
92
93#[async_trait]
94impl UserSimulator for ScriptedUserSimulator {
95 fn name(&self) -> &str {
96 &self.name
97 }
98
99 async fn initial_message(&self, task: &Task) -> Result<String, UserError> {
100 let idx = self.cursor.fetch_add(1, Ordering::SeqCst);
101 self.script
102 .get(idx)
103 .cloned()
104 .ok_or_else(|| UserError::Other(format!("empty script (task={})", task.instruction)))
105 }
106
107 async fn respond(
108 &self,
109 _conversation: ConversationView<'_>,
110 _task: &Task,
111 ) -> Result<UserAction, UserError> {
112 let idx = self.cursor.fetch_add(1, Ordering::SeqCst);
113 match self.script.get(idx) {
114 Some(msg) => Ok(UserAction::Say(msg.clone())),
115 None => Ok(UserAction::EndConversation),
116 }
117 }
118}
119
120pub struct LlmUserSimulator {
135 llm: Arc<dyn Llm>,
136 persona: String,
137 prompt_template: String,
138 end_sentinel: String,
139 name: String,
140}
141
142impl LlmUserSimulator {
143 pub fn default_template() -> &'static str {
146 "You are role-playing a user with this persona:\n\n{persona}\n\n\
147 The user's underlying task is:\n\n{task}\n\n\
148 Respond to the assistant's most recent turn as the user would. \
149 Keep responses short. When the task is fully resolved, include \
150 the literal token `<end>` anywhere in your reply to end the \
151 conversation. Do not prefix your reply with `USER:` or any role \
152 label."
153 }
154
155 pub fn new(
156 llm: Arc<dyn Llm>,
157 persona: impl Into<String>,
158 prompt_template: impl Into<String>,
159 ) -> Self {
160 Self {
161 llm,
162 persona: persona.into(),
163 prompt_template: prompt_template.into(),
164 end_sentinel: "<end>".to_string(),
165 name: "llm-user".to_string(),
166 }
167 }
168
169 pub fn with_end_sentinel(mut self, sentinel: impl Into<String>) -> Self {
170 self.end_sentinel = sentinel.into();
171 self
172 }
173
174 pub fn with_name(mut self, name: impl Into<String>) -> Self {
175 self.name = name.into();
176 self
177 }
178
179 fn render_system(&self, task: &Task) -> String {
180 self.prompt_template
181 .replace("{persona}", &self.persona)
182 .replace("{task}", &task.instruction)
183 }
184}
185
186#[async_trait]
187impl UserSimulator for LlmUserSimulator {
188 fn name(&self) -> &str {
189 &self.name
190 }
191
192 async fn initial_message(&self, task: &Task) -> Result<String, UserError> {
193 Ok(task.instruction.clone())
197 }
198
199 async fn respond(
200 &self,
201 conversation: ConversationView<'_>,
202 task: &Task,
203 ) -> Result<UserAction, UserError> {
204 let transcript = render_transcript(conversation);
205 let mut req = CompletionRequest::new(vec![Message::user_text(transcript)]);
206 req.system = Some(self.render_system(task));
207 let res = self.llm.complete(req).await?;
208 let text = res
209 .content
210 .iter()
211 .filter_map(|c| match c {
212 oharness_core::Content::Text { text } => Some(text.as_str()),
213 _ => None,
214 })
215 .collect::<Vec<_>>()
216 .join("\n");
217 let text_lower = text.to_ascii_lowercase();
218 let sentinel_lower = self.end_sentinel.to_ascii_lowercase();
219 if text_lower.contains(&sentinel_lower) {
220 let _ = strip_case_insensitive(&text, &self.end_sentinel);
224 Ok(UserAction::EndConversation)
225 } else {
226 Ok(UserAction::Say(text))
227 }
228 }
229}
230
231fn render_transcript(view: ConversationView<'_>) -> String {
232 let mut out = String::new();
233 for m in view.user_visible() {
234 match m {
235 Message::System { content, .. } => {
236 out.push_str("SYSTEM: ");
237 out.push_str(&content);
238 out.push('\n');
239 }
240 Message::User { content, .. } => {
241 out.push_str("USER: ");
242 out.push_str(&flatten_text(&content));
243 out.push('\n');
244 }
245 Message::Assistant { content, .. } => {
246 out.push_str("ASSISTANT: ");
247 out.push_str(&flatten_text(&content));
248 out.push('\n');
249 }
250 }
251 }
252 out
253}
254
255fn flatten_text(content: &[oharness_core::Content]) -> String {
256 content
257 .iter()
258 .filter_map(|c| match c {
259 oharness_core::Content::Text { text } => Some(text.as_str()),
260 _ => None,
261 })
262 .collect::<Vec<_>>()
263 .join("\n")
264}
265
/// Returns `haystack` with every non-overlapping occurrence of `needle`
/// removed, matching ASCII case-insensitively.
///
/// Occurrences are found in one left-to-right pass; text that only becomes a
/// match after a removal is not stripped again. An empty `needle` matches
/// nothing, so `haystack` is returned unchanged.
fn strip_case_insensitive(haystack: &str, needle: &str) -> String {
    // Guard: an empty needle would otherwise "match" at every position
    // without consuming any input.
    if needle.is_empty() {
        return haystack.to_string();
    }

    // ASCII lowercasing preserves byte lengths and char boundaries, so byte
    // offsets found in the lowered haystack are valid in the original.
    let hay_lower = haystack.to_ascii_lowercase();
    let needle_lower = needle.to_ascii_lowercase();

    let mut out = String::with_capacity(haystack.len());
    let mut copied_to = 0;
    for (start, matched) in hay_lower.match_indices(needle_lower.as_str()) {
        out.push_str(&haystack[copied_to..start]);
        copied_to = start + matched.len();
    }
    out.push_str(&haystack[copied_to..]);
    out
}
284
// Unit tests for the scripted and LLM-backed user simulators.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use oharness_core::{
        CompletionResponse, Content, LlmCapabilities, Message, ModelId, StopReason, Task, Usage,
    };
    use oharness_llm::{ChunkStream, LlmError};
    use std::sync::Mutex;

    // The first script entry is the initial message, later entries come back
    // from `respond` in order, and an exhausted script ends the conversation.
    #[tokio::test]
    async fn scripted_returns_initial_then_sequenced_responses() {
        let sim = ScriptedUserSimulator::new(["hi", "more please", "thanks"]);
        let task = Task::new("chat");
        let first = sim.initial_message(&task).await.unwrap();
        assert_eq!(first, "hi");

        let empty: Vec<Message> = Vec::new();
        let v = ConversationView::new(&empty);
        match sim.respond(v, &task).await.unwrap() {
            UserAction::Say(s) => assert_eq!(s, "more please"),
            other => panic!("expected Say, got {other:?}"),
        }
        let v = ConversationView::new(&empty);
        match sim.respond(v, &task).await.unwrap() {
            UserAction::Say(s) => assert_eq!(s, "thanks"),
            other => panic!("expected Say, got {other:?}"),
        }
        let v = ConversationView::new(&empty);
        // Script is exhausted at this point.
        assert_eq!(
            sim.respond(v, &task).await.unwrap(),
            UserAction::EndConversation
        );
    }

    // An empty script has no initial message to offer and must report an error.
    #[tokio::test]
    async fn scripted_empty_script_errors_on_initial_message() {
        let sim: ScriptedUserSimulator = ScriptedUserSimulator::new(std::iter::empty::<String>());
        match sim.initial_message(&Task::new("t")).await {
            Err(UserError::Other(msg)) => assert!(msg.contains("empty script")),
            other => panic!("expected Err(UserError::Other), got {other:?}"),
        }
    }

    // Test double: `complete` hands out the stored response once and returns
    // `LlmError::Unsupported` on every later call (or immediately when built
    // with `None`).
    struct OneShot(Mutex<Option<CompletionResponse>>);
    #[async_trait]
    impl Llm for OneShot {
        fn name(&self) -> &str {
            "one-shot-user"
        }
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities::default()
        }
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            self.0
                .lock()
                .unwrap()
                .take()
                .ok_or(LlmError::Unsupported("one-shot"))
        }
        async fn stream(&self, _req: CompletionRequest) -> Result<ChunkStream, LlmError> {
            Err(LlmError::Unsupported("stream"))
        }
    }

    // Builds a completion response whose content is a single text part.
    fn text_response(text: &str) -> CompletionResponse {
        CompletionResponse {
            id: "u".into(),
            model: ModelId::new("m"),
            content: vec![Content::text(text)],
            stop_reason: StopReason::EndTurn,
            usage: Usage::default(),
        }
    }

    // The initial message is the task instruction; the LLM (which would fail
    // here, since it holds no response) is never consulted.
    #[tokio::test]
    async fn llm_user_initial_message_replays_task_instruction() {
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(None)));
        let sim = LlmUserSimulator::new(llm, "friendly user", LlmUserSimulator::default_template());
        let task = Task::new("help me debug this bug");
        assert_eq!(
            sim.initial_message(&task).await.unwrap(),
            "help me debug this bug"
        );
    }

    // A reply without the sentinel is passed through as `Say`.
    #[tokio::test]
    async fn llm_user_respond_emits_say_on_plain_response() {
        let llm: Arc<dyn Llm> =
            Arc::new(OneShot(Mutex::new(Some(text_response("that's helpful")))));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        match sim
            .respond(ConversationView::new(&empty), &Task::new("t"))
            .await
            .unwrap()
        {
            UserAction::Say(s) => assert_eq!(s, "that's helpful"),
            other => panic!("expected Say, got {other:?}"),
        }
    }

    // A reply containing the default sentinel ends the conversation.
    #[tokio::test]
    async fn llm_user_respond_emits_end_on_sentinel() {
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(Some(text_response("done <end>")))));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        assert_eq!(
            sim.respond(ConversationView::new(&empty), &Task::new("t"))
                .await
                .unwrap(),
            UserAction::EndConversation
        );
    }

    // Sentinel detection ignores ASCII case.
    #[tokio::test]
    async fn llm_user_respond_is_case_insensitive_on_sentinel() {
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(Some(text_response("<END>")))));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        assert_eq!(
            sim.respond(ConversationView::new(&empty), &Task::new("t"))
                .await
                .unwrap(),
            UserAction::EndConversation
        );
    }

    // An LLM failure surfaces as `UserError::Llm`.
    #[tokio::test]
    async fn llm_user_respond_errors_on_llm_error() {
        // `OneShot` with no stored response fails on the first `complete` call.
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(None)));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        match sim
            .respond(ConversationView::new(&empty), &Task::new("t"))
            .await
        {
            Err(UserError::Llm(_)) => {}
            other => panic!("expected Err(UserError::Llm), got {other:?}"),
        }
    }
}