#![cfg(any(test, feature = "test-support"))]
use std::sync::Arc;
use async_trait::async_trait;
use parking_lot::Mutex;
use solo_core::{LlmClient, Message, Result, Role};
/// Deterministic, in-memory [`LlmClient`] for tests.
///
/// Serves queued ("canned") responses in FIFO order, falls back to a fixed
/// JSON payload once the queue is empty, and records every prompt it receives
/// so tests can assert on what was sent.
pub struct StubLlmClient {
    /// Label reported through `LlmClient::name`.
    name: String,
    /// Interior-mutable bookkeeping so `&self` methods can record calls.
    state: Mutex<StubState>,
    /// Value reported through `LlmClient::is_real_llm`; `false` unless
    /// changed with `pretend_real_llm`.
    is_real_llm_override: bool,
}
/// Mutable bookkeeping guarded by [`StubLlmClient`]'s mutex.
#[derive(Default)]
struct StubState {
    /// Number of `complete` calls observed so far.
    call_count: usize,
    /// Every message list passed to `complete`, oldest first.
    prompts: Vec<Vec<Message>>,
    /// Queued responses; `complete` pops from the front.
    canned: std::collections::VecDeque<String>,
}
impl StubLlmClient {
    /// Creates a stub called `name` with an empty canned-response queue.
    ///
    /// The stub reports `is_real_llm() == false` until changed with
    /// [`StubLlmClient::pretend_real_llm`].
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            is_real_llm_override: false,
            state: Mutex::new(StubState::default()),
        }
    }

    /// Builder-style setter for the value returned by `is_real_llm`.
    ///
    /// Lets tests exercise code paths gated on a "real" backend without
    /// contacting one. Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn pretend_real_llm(mut self, yes: bool) -> Self {
        self.is_real_llm_override = yes;
        self
    }

    /// Creates a stub with the conventional name `"stub-llm"`.
    #[must_use]
    pub fn default_stub() -> Self {
        Self::new("stub-llm")
    }

    /// Creates a stub called `name` whose first completion returns `response`.
    ///
    /// After that single response is consumed, completions fall back to
    /// [`StubLlmClient::default_response`].
    #[must_use]
    pub fn with_canned(name: impl Into<String>, response: impl Into<String>) -> Self {
        let stub = Self::new(name);
        stub.push_canned(response);
        stub
    }

    /// Appends `response` to the back of the canned-response queue.
    ///
    /// Queued responses are returned by `complete` in FIFO order; once the
    /// queue is empty, [`StubLlmClient::default_response`] is returned instead.
    pub fn push_canned(&self, response: impl Into<String>) {
        self.state.lock().canned.push_back(response.into());
    }

    /// Number of times `complete` has been invoked on this stub.
    #[must_use]
    pub fn call_count(&self) -> usize {
        self.state.lock().call_count
    }

    /// Returns a copy of every message list passed to `complete`, in call order.
    ///
    /// This clones the recorded history so the internal lock is not held by
    /// the caller.
    #[must_use]
    pub fn prompts(&self) -> Vec<Vec<Message>> {
        self.state.lock().prompts.clone()
    }

    /// JSON payload returned by `complete` when no canned response is queued.
    #[must_use]
    pub fn default_response() -> &'static str {
        r#"{"content":"(stub abstraction)","triples":[],"confidence":0.5}"#
    }
}

impl Default for StubLlmClient {
    /// Equivalent to [`StubLlmClient::default_stub`].
    fn default() -> Self {
        Self::default_stub()
    }
}
#[async_trait]
impl LlmClient for StubLlmClient {
    /// Reports the label supplied at construction time.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Records the prompt and bumps the call counter, then answers.
    ///
    /// The answer is the next queued canned response, or
    /// `StubLlmClient::default_response()` once the queue is empty. This
    /// implementation never returns an error.
    async fn complete(&self, messages: &[Message]) -> Result<Message> {
        // Confine the lock guard to this block so it is released before the
        // reply is built.
        let content = {
            let mut guard = self.state.lock();
            guard.prompts.push(messages.to_vec());
            guard.call_count += 1;
            match guard.canned.pop_front() {
                Some(canned) => canned,
                None => Self::default_response().to_owned(),
            }
        };
        Ok(Message {
            content,
            role: Role::Assistant,
        })
    }

    /// Returns whatever was configured through `pretend_real_llm` (default `false`).
    fn is_real_llm(&self) -> bool {
        self.is_real_llm_override
    }
}
/// Returns a fresh [`StubLlmClient::default_stub`] behind a shared trait object.
///
/// This is convenient for code that expects an `Arc<dyn LlmClient>`.
pub fn arc_stub() -> Arc<dyn LlmClient> {
    let stub = StubLlmClient::default_stub();
    Arc::new(stub)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-threaded Tokio runtime for driving the async client
    /// from synchronous `#[test]` functions.
    ///
    /// Tests that need several calls build one runtime and reuse it.
    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn default_response_is_parseable_json() {
        let s = StubLlmClient::default_stub();
        let resp = rt()
            .block_on(s.complete(&[Message::user("hello")]))
            .unwrap();
        assert_eq!(resp.role, Role::Assistant);
        let v: serde_json::Value =
            serde_json::from_str(&resp.content).expect("parses as JSON");
        assert_eq!(v["content"], "(stub abstraction)");
        assert_eq!(v["triples"], serde_json::json!([]));
        assert_eq!(v["confidence"], 0.5);
    }

    #[test]
    fn fresh_stub_has_no_recorded_calls() {
        let s = StubLlmClient::default_stub();
        assert_eq!(s.call_count(), 0);
        assert!(s.prompts().is_empty());
    }

    #[test]
    fn canned_responses_drain_in_fifo_order() {
        let runtime = rt();
        let s = StubLlmClient::default_stub();
        s.push_canned("first");
        s.push_canned("second");
        let r1 = runtime.block_on(s.complete(&[])).unwrap();
        let r2 = runtime.block_on(s.complete(&[])).unwrap();
        let r3 = runtime.block_on(s.complete(&[])).unwrap();
        assert_eq!(r1.content, "first");
        assert_eq!(r2.content, "second");
        assert_eq!(r3.content, StubLlmClient::default_response());
        assert_eq!(s.call_count(), 3);
    }

    #[test]
    fn prompts_records_every_call() {
        let runtime = rt();
        let s = StubLlmClient::default_stub();
        let _ = runtime
            .block_on(s.complete(&[Message::user("alpha")]))
            .unwrap();
        let _ = runtime
            .block_on(s.complete(&[Message::system("S"), Message::user("beta")]))
            .unwrap();
        let prompts = s.prompts();
        assert_eq!(prompts.len(), 2);
        assert_eq!(prompts[0].len(), 1);
        assert_eq!(prompts[0][0].content, "alpha");
        assert_eq!(prompts[1].len(), 2);
        assert_eq!(prompts[1][1].content, "beta");
        assert_eq!(s.call_count(), prompts.len());
    }

    #[test]
    fn name_is_returned_unchanged() {
        let s = StubLlmClient::new("my-test-backend");
        assert_eq!(s.name(), "my-test-backend");
        let s2 = StubLlmClient::default_stub();
        assert_eq!(s2.name(), "stub-llm");
    }

    #[test]
    fn with_canned_constructor_queues_first_response() {
        let runtime = rt();
        let s = StubLlmClient::with_canned("named", r#"{"x":1}"#);
        assert_eq!(s.name(), "named");
        let resp = runtime.block_on(s.complete(&[])).unwrap();
        assert_eq!(resp.content, r#"{"x":1}"#);
        // Only one response was queued; the next call must fall back.
        let resp = runtime.block_on(s.complete(&[])).unwrap();
        assert_eq!(resp.content, StubLlmClient::default_response());
    }

    #[test]
    fn is_real_llm_defaults_to_false_for_stub() {
        let s = StubLlmClient::default_stub();
        assert!(
            !s.is_real_llm(),
            "stub must report `is_real_llm` = false by default"
        );
    }

    #[test]
    fn pretend_real_llm_flips_is_real_llm_to_true() {
        let s = StubLlmClient::default_stub().pretend_real_llm(true);
        assert!(s.is_real_llm());
        let s2 = StubLlmClient::default_stub().pretend_real_llm(false);
        assert!(!s2.is_real_llm());
    }

    #[test]
    fn arc_stub_is_default_stub_behind_trait_object() {
        let client = arc_stub();
        assert_eq!(client.name(), "stub-llm");
        assert!(!client.is_real_llm());
        let resp = rt().block_on(client.complete(&[])).unwrap();
        assert_eq!(resp.role, Role::Assistant);
        assert_eq!(resp.content, StubLlmClient::default_response());
    }
}