// oharness_loop/user_simulator.rs

1//! [`UserSimulator`] trait and shipped simulator implementations
2//! (plan §12.3, §12.4).
3//!
4//! A simulator stands in for a human user during a [`ConversationLoop`]
5//! run: it produces an initial message from the task, then responds to
6//! each assistant turn with either [`UserAction::Say`] (send a follow-up)
7//! or [`UserAction::EndConversation`] (terminate the loop).
8//!
9//! Simulators receive a [`oharness_core::ConversationView`] — not raw
10//! messages — so they can call `.user_visible()` to strip internal
11//! reasoning / tool calls / tool results and respond only to the
12//! human-facing thread. Simulators that want stricter or looser views
13//! can compose their own.
14
15use async_trait::async_trait;
16use oharness_core::{CompletionRequest, ConversationView, Message, Task};
17use oharness_llm::Llm;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
/// Stand-in for a human user in a conversation loop: it supplies the opening
/// message and then reacts to the conversation, one turn at a time.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
#[async_trait]
pub trait UserSimulator: Send + Sync {
    /// Name identifying this simulator.
    fn name(&self) -> &str;

    /// Produce the first user message from the task. Called once at the
    /// start of a conversation loop.
    ///
    /// # Errors
    ///
    /// Returns a [`UserError`] when no opening message can be produced.
    async fn initial_message(&self, task: &Task) -> Result<String, UserError>;

    /// Respond to the current conversation. Return
    /// [`UserAction::Say`] with the next user message, or
    /// [`UserAction::EndConversation`] to terminate the loop.
    ///
    /// # Errors
    ///
    /// Returns a [`UserError`] when the simulator fails to decide on an
    /// action; implementations should not mask failures by ending the
    /// conversation instead.
    async fn respond(
        &self,
        conversation: ConversationView<'_>,
        task: &Task,
    ) -> Result<UserAction, UserError>;
}
38
/// What the simulator wants to do next.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserAction {
    /// Send the contained text as the next user message.
    Say(String),
    /// Stop the conversation; no further user message is sent.
    EndConversation,
}
45
/// Simulator-side errors. Any error here is promoted to
/// `Termination::Failed { reason: "user_simulator_error" }` by the
/// [`ConversationLoop`] — simulators cannot silently fall back to
/// [`UserAction::EndConversation`] since that would hide bugs.
#[derive(Debug, thiserror::Error)]
pub enum UserError {
    /// Free-form failure described by the contained message.
    #[error("user simulator: {0}")]
    Other(String),
    /// Failure reported by the underlying LLM; `#[from]` lets `?` convert an
    /// [`oharness_llm::LlmError`] into this variant.
    #[error("user simulator llm: {0}")]
    Llm(#[from] oharness_llm::LlmError),
}
57
58// ======================================================================
59// ScriptedUserSimulator (plan §12.4)
60// ======================================================================
61
/// A simulator that replays a fixed sequence of user utterances. The
/// first entry is returned from [`UserSimulator::initial_message`]; each
/// subsequent call to [`UserSimulator::respond`] returns the next entry,
/// [`UserAction::Say`]-wrapped, until the script is exhausted — at which
/// point the simulator returns [`UserAction::EndConversation`].
///
/// Both trait methods advance the same cursor, so every call — whichever
/// method it goes through — consumes one script entry.
///
/// Useful for tests, reproducible evaluation runs, and
/// conversation-loop smoke paths where no LLM should be involved on the
/// user side.
pub struct ScriptedUserSimulator {
    // Utterances in the order they are replayed.
    script: Vec<String>,
    // Index of the next script entry to hand out; atomic so it can be
    // advanced through `&self`.
    cursor: AtomicUsize,
    // Value returned by `UserSimulator::name`.
    name: String,
}
76
77impl ScriptedUserSimulator {
78    pub fn new(script: impl IntoIterator<Item = impl Into<String>>) -> Self {
79        let script = script.into_iter().map(Into::into).collect::<Vec<_>>();
80        Self {
81            script,
82            cursor: AtomicUsize::new(0),
83            name: "scripted-user".to_string(),
84        }
85    }
86
87    pub fn with_name(mut self, name: impl Into<String>) -> Self {
88        self.name = name.into();
89        self
90    }
91}
92
93#[async_trait]
94impl UserSimulator for ScriptedUserSimulator {
95    fn name(&self) -> &str {
96        &self.name
97    }
98
99    async fn initial_message(&self, task: &Task) -> Result<String, UserError> {
100        let idx = self.cursor.fetch_add(1, Ordering::SeqCst);
101        self.script
102            .get(idx)
103            .cloned()
104            .ok_or_else(|| UserError::Other(format!("empty script (task={})", task.instruction)))
105    }
106
107    async fn respond(
108        &self,
109        _conversation: ConversationView<'_>,
110        _task: &Task,
111    ) -> Result<UserAction, UserError> {
112        let idx = self.cursor.fetch_add(1, Ordering::SeqCst);
113        match self.script.get(idx) {
114            Some(msg) => Ok(UserAction::Say(msg.clone())),
115            None => Ok(UserAction::EndConversation),
116        }
117    }
118}
119
120// ======================================================================
121// LlmUserSimulator (plan §12.4)
122// ======================================================================
123
/// A simulator that drives a user LLM with a persona + template. Each
/// `respond` call builds a [`CompletionRequest`] whose system prompt is
/// the rendered `prompt_template` (with `{persona}` and `{task}`
/// substituted) and whose user message is the serialized conversation
/// so far.
///
/// The simulator looks for a terminal sentinel — by default `<end>` —
/// in the LLM's response to decide whether to return
/// [`UserAction::EndConversation`]. The sentinel is matched
/// ASCII-case-insensitively; when it is present the whole reply is
/// discarded, including any text around the sentinel.
pub struct LlmUserSimulator {
    // Model that plays the user.
    llm: Arc<dyn Llm>,
    // Text substituted for `{persona}` in the template.
    persona: String,
    // System-prompt template containing `{persona}` / `{task}` placeholders.
    prompt_template: String,
    // Token whose presence in a reply ends the conversation.
    end_sentinel: String,
    // Value returned by `UserSimulator::name`.
    name: String,
}
141
142impl LlmUserSimulator {
143    /// A reasonable default template — users with opinionated
144    /// simulators should supply their own.
145    pub fn default_template() -> &'static str {
146        "You are role-playing a user with this persona:\n\n{persona}\n\n\
147         The user's underlying task is:\n\n{task}\n\n\
148         Respond to the assistant's most recent turn as the user would. \
149         Keep responses short. When the task is fully resolved, include \
150         the literal token `<end>` anywhere in your reply to end the \
151         conversation. Do not prefix your reply with `USER:` or any role \
152         label."
153    }
154
155    pub fn new(
156        llm: Arc<dyn Llm>,
157        persona: impl Into<String>,
158        prompt_template: impl Into<String>,
159    ) -> Self {
160        Self {
161            llm,
162            persona: persona.into(),
163            prompt_template: prompt_template.into(),
164            end_sentinel: "<end>".to_string(),
165            name: "llm-user".to_string(),
166        }
167    }
168
169    pub fn with_end_sentinel(mut self, sentinel: impl Into<String>) -> Self {
170        self.end_sentinel = sentinel.into();
171        self
172    }
173
174    pub fn with_name(mut self, name: impl Into<String>) -> Self {
175        self.name = name.into();
176        self
177    }
178
179    fn render_system(&self, task: &Task) -> String {
180        self.prompt_template
181            .replace("{persona}", &self.persona)
182            .replace("{task}", &task.instruction)
183    }
184}
185
186#[async_trait]
187impl UserSimulator for LlmUserSimulator {
188    fn name(&self) -> &str {
189        &self.name
190    }
191
192    async fn initial_message(&self, task: &Task) -> Result<String, UserError> {
193        // Cheap: the task's instruction is already a user-shaped prompt.
194        // Callers that want a more elaborate kickoff can subclass /
195        // wrap — for the default, we just replay the instruction.
196        Ok(task.instruction.clone())
197    }
198
199    async fn respond(
200        &self,
201        conversation: ConversationView<'_>,
202        task: &Task,
203    ) -> Result<UserAction, UserError> {
204        let transcript = render_transcript(conversation);
205        let mut req = CompletionRequest::new(vec![Message::user_text(transcript)]);
206        req.system = Some(self.render_system(task));
207        let res = self.llm.complete(req).await?;
208        let text = res
209            .content
210            .iter()
211            .filter_map(|c| match c {
212                oharness_core::Content::Text { text } => Some(text.as_str()),
213                _ => None,
214            })
215            .collect::<Vec<_>>()
216            .join("\n");
217        let text_lower = text.to_ascii_lowercase();
218        let sentinel_lower = self.end_sentinel.to_ascii_lowercase();
219        if text_lower.contains(&sentinel_lower) {
220            // Sentinel present — end the conversation. Any surrounding
221            // text is dropped; users who want a trailing message should
222            // emit it on a previous turn rather than glue it to `<end>`.
223            let _ = strip_case_insensitive(&text, &self.end_sentinel);
224            Ok(UserAction::EndConversation)
225        } else {
226            Ok(UserAction::Say(text))
227        }
228    }
229}
230
231fn render_transcript(view: ConversationView<'_>) -> String {
232    let mut out = String::new();
233    for m in view.user_visible() {
234        match m {
235            Message::System { content, .. } => {
236                out.push_str("SYSTEM: ");
237                out.push_str(&content);
238                out.push('\n');
239            }
240            Message::User { content, .. } => {
241                out.push_str("USER: ");
242                out.push_str(&flatten_text(&content));
243                out.push('\n');
244            }
245            Message::Assistant { content, .. } => {
246                out.push_str("ASSISTANT: ");
247                out.push_str(&flatten_text(&content));
248                out.push('\n');
249            }
250        }
251    }
252    out
253}
254
255fn flatten_text(content: &[oharness_core::Content]) -> String {
256    content
257        .iter()
258        .filter_map(|c| match c {
259            oharness_core::Content::Text { text } => Some(text.as_str()),
260            _ => None,
261        })
262        .collect::<Vec<_>>()
263        .join("\n")
264}
265
/// Remove every occurrence of `needle` from `haystack`, comparing ASCII
/// letters case-insensitively. Occurrences are matched left to right and do
/// not overlap.
///
/// An empty `needle` matches nothing, so `haystack` is returned unchanged
/// (scanning for an empty pattern would otherwise never advance).
fn strip_case_insensitive(haystack: &str, needle: &str) -> String {
    if needle.is_empty() {
        return haystack.to_string();
    }

    // ASCII lowercasing keeps every byte offset and char boundary intact, so
    // positions found in the lowered copy index the original directly.
    let hay_lower = haystack.to_ascii_lowercase();
    let needle_lower = needle.to_ascii_lowercase();

    let mut out = String::with_capacity(haystack.len());
    let mut cursor = 0;
    while let Some(offset) = hay_lower[cursor..].find(needle_lower.as_str()) {
        let hit = cursor + offset;
        out.push_str(&haystack[cursor..hit]);
        cursor = hit + needle.len();
    }
    out.push_str(&haystack[cursor..]);
    out
}
284
// Unit tests for the two shipped simulators. The LLM-backed simulator is
// exercised against a canned, single-use `Llm` stub, so no network or real
// model is involved.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use oharness_core::{
        CompletionResponse, Content, LlmCapabilities, Message, ModelId, StopReason, Task, Usage,
    };
    use oharness_llm::{ChunkStream, LlmError};
    use std::sync::Mutex;

    // ---------- ScriptedUserSimulator ----------

    // The first script entry comes from `initial_message`, the rest from
    // successive `respond` calls, and an exhausted script ends the loop.
    #[tokio::test]
    async fn scripted_returns_initial_then_sequenced_responses() {
        let sim = ScriptedUserSimulator::new(["hi", "more please", "thanks"]);
        let task = Task::new("chat");
        let first = sim.initial_message(&task).await.unwrap();
        assert_eq!(first, "hi");

        // The scripted simulator ignores the conversation, so an empty view
        // is enough.
        let empty: Vec<Message> = Vec::new();
        let v = ConversationView::new(&empty);
        match sim.respond(v, &task).await.unwrap() {
            UserAction::Say(s) => assert_eq!(s, "more please"),
            other => panic!("expected Say, got {other:?}"),
        }
        let v = ConversationView::new(&empty);
        match sim.respond(v, &task).await.unwrap() {
            UserAction::Say(s) => assert_eq!(s, "thanks"),
            other => panic!("expected Say, got {other:?}"),
        }
        // Exhausted -> EndConversation.
        let v = ConversationView::new(&empty);
        assert_eq!(
            sim.respond(v, &task).await.unwrap(),
            UserAction::EndConversation
        );
    }

    // An empty script cannot produce an opening message and must error.
    #[tokio::test]
    async fn scripted_empty_script_errors_on_initial_message() {
        let sim: ScriptedUserSimulator = ScriptedUserSimulator::new(std::iter::empty::<String>());
        match sim.initial_message(&Task::new("t")).await {
            Err(UserError::Other(msg)) => assert!(msg.contains("empty script")),
            other => panic!("expected Err(UserError::Other), got {other:?}"),
        }
    }

    // ---------- LlmUserSimulator ----------

    // `Llm` stub holding at most one canned response: the first `complete`
    // call takes it, and any call made while the slot is empty fails with
    // `LlmError::Unsupported`. Streaming is always unsupported.
    struct OneShot(Mutex<Option<CompletionResponse>>);
    #[async_trait]
    impl Llm for OneShot {
        fn name(&self) -> &str {
            "one-shot-user"
        }
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities::default()
        }
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            self.0
                .lock()
                .unwrap()
                .take()
                .ok_or(LlmError::Unsupported("one-shot"))
        }
        async fn stream(&self, _req: CompletionRequest) -> Result<ChunkStream, LlmError> {
            Err(LlmError::Unsupported("stream"))
        }
    }

    // Build a completion response carrying a single text part.
    fn text_response(text: &str) -> CompletionResponse {
        CompletionResponse {
            id: "u".into(),
            model: ModelId::new("m"),
            content: vec![Content::text(text)],
            stop_reason: StopReason::EndTurn,
            usage: Usage::default(),
        }
    }

    // The opening message is the task instruction; the stub holds no
    // response, so this would fail if the LLM were consulted.
    #[tokio::test]
    async fn llm_user_initial_message_replays_task_instruction() {
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(None)));
        let sim = LlmUserSimulator::new(llm, "friendly user", LlmUserSimulator::default_template());
        let task = Task::new("help me debug this bug");
        assert_eq!(
            sim.initial_message(&task).await.unwrap(),
            "help me debug this bug"
        );
    }

    // A reply without the sentinel is forwarded verbatim as `Say`.
    #[tokio::test]
    async fn llm_user_respond_emits_say_on_plain_response() {
        let llm: Arc<dyn Llm> =
            Arc::new(OneShot(Mutex::new(Some(text_response("that's helpful")))));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        match sim
            .respond(ConversationView::new(&empty), &Task::new("t"))
            .await
            .unwrap()
        {
            UserAction::Say(s) => assert_eq!(s, "that's helpful"),
            other => panic!("expected Say, got {other:?}"),
        }
    }

    // A reply containing the sentinel ends the conversation, even with
    // other text around it.
    #[tokio::test]
    async fn llm_user_respond_emits_end_on_sentinel() {
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(Some(text_response("done <end>")))));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        assert_eq!(
            sim.respond(ConversationView::new(&empty), &Task::new("t"))
                .await
                .unwrap(),
            UserAction::EndConversation
        );
    }

    // Sentinel detection ignores ASCII case.
    #[tokio::test]
    async fn llm_user_respond_is_case_insensitive_on_sentinel() {
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(Some(text_response("<END>")))));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        assert_eq!(
            sim.respond(ConversationView::new(&empty), &Task::new("t"))
                .await
                .unwrap(),
            UserAction::EndConversation
        );
    }

    // An LLM failure surfaces as `UserError::Llm` rather than being
    // swallowed.
    #[tokio::test]
    async fn llm_user_respond_errors_on_llm_error() {
        // One-shot with no response -> first respond errors.
        let llm: Arc<dyn Llm> = Arc::new(OneShot(Mutex::new(None)));
        let sim = LlmUserSimulator::new(llm, "user", LlmUserSimulator::default_template());
        let empty: Vec<Message> = Vec::new();
        match sim
            .respond(ConversationView::new(&empty), &Task::new("t"))
            .await
        {
            Err(UserError::Llm(_)) => {}
            other => panic!("expected Err(UserError::Llm), got {other:?}"),
        }
    }
}