#![cfg(any(test, feature = "test-utils"))]
#![allow(missing_docs)]
use crate::agent::AgentContext;
use crate::bus::{
ClaimHandleImpl, ClaimedJob, Headers, Job, JobQueue, KvEntry, KvStore, Lease, LeaseImpl, Msg,
MsgStream, Pubsub, RequestReply, Revision,
};
use crate::error::{BusError, LlmError, MemoryError, ToolError};
use crate::ids::{DurableName, FactId, JobId, RunId, ThreadId};
use crate::llm::{
Capabilities, ChatChunk, ChatRequest, ChatResponse, ChunkStream, Embedding, FinishReason,
LlmClient, Message, ToolCall,
};
use crate::memory::{
Episode, EpisodicMemory, Fact, LongTermMemory, RunFilter, RunSummary, Scope, ShortTermMemory,
};
use crate::runtime::{run_steps, RunOptions};
use crate::tool::{ToolCtx, ToolInvoker};
use async_trait::async_trait;
use bytes::Bytes;
use chrono::Utc;
use futures_core::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
/// One scripted reply handed out by `FakeLlmClient`'s `complete`.
///
/// Steps form a FIFO script: every `complete` call pops exactly one from the front.
#[derive(Clone)]
pub enum FakeLlmStep {
    /// Reply with an assistant message containing this text (finish reason `Stop`).
    Text(String),
    /// Reply with an empty assistant message requesting these tool calls
    /// (finish reason `ToolCalls`).
    ToolCalls(Vec<ToolCall>),
}
/// One scripted outcome for a call to `FakeLlmClient`'s `stream`.
pub enum FakeStreamStep {
    /// The `stream` call succeeds and the returned stream yields these items in order.
    Chunks(Vec<Result<ChatChunk, LlmError>>),
    /// The `stream` call itself fails with the contained error.
    ///
    /// A `None` payload makes `stream` fail with a generic `LlmError::Server` instead.
    InitErr(Option<LlmError>),
}

impl FakeStreamStep {
    /// Convenience constructor equivalent to `FakeStreamStep::InitErr(Some(e))`.
    pub fn init_err(e: LlmError) -> Self {
        let payload = Some(e);
        FakeStreamStep::InitErr(payload)
    }
}
/// Scripted, in-memory `LlmClient` for tests.
///
/// `complete` and `stream` each consume their own FIFO script of canned steps.
#[derive(Default)]
pub struct FakeLlmClient {
    /// Value reported by `LlmClient::name`.
    name: String,
    /// Value reported by `LlmClient::capabilities`.
    caps: Capabilities,
    /// Canned replies for `complete`, consumed front to back.
    script: Mutex<std::collections::VecDeque<FakeLlmStep>>,
    /// Canned outcomes for `stream`, consumed front to back.
    stream_script: Mutex<std::collections::VecDeque<FakeStreamStep>>,
    /// How many times `stream` has been called, including calls that failed.
    stream_calls: std::sync::atomic::AtomicU32,
}

impl FakeLlmClient {
    /// Creates a client called `name` with both scripts empty.
    ///
    /// It advertises tool calling only (no streaming, structured output, embeddings or
    /// vision) and a 32 000-token context window.
    pub fn new(name: impl Into<String>) -> Self {
        let caps = Capabilities {
            tool_calling: true,
            streaming: false,
            structured_output: false,
            embeddings: false,
            max_context_tokens: 32_000,
            vision: false,
        };
        Self {
            name: name.into(),
            caps,
            script: Mutex::new(std::collections::VecDeque::new()),
            stream_script: Mutex::new(std::collections::VecDeque::new()),
            stream_calls: std::sync::atomic::AtomicU32::new(0),
        }
    }

    /// Replaces the `complete` script; `steps[0]` is returned by the first call.
    pub fn with_steps(mut self, steps: Vec<FakeLlmStep>) -> Self {
        self.script = Mutex::new(std::collections::VecDeque::from(steps));
        self
    }

    /// Replaces the `stream` script; `steps[0]` is used by the first call.
    pub fn with_stream_steps(mut self, steps: Vec<FakeStreamStep>) -> Self {
        self.stream_script = Mutex::new(std::collections::VecDeque::from(steps));
        self
    }

    /// Number of `stream` invocations so far (successful or not).
    pub fn stream_call_count(&self) -> u32 {
        use std::sync::atomic::Ordering;
        self.stream_calls.load(Ordering::SeqCst)
    }
}
#[async_trait]
impl LlmClient for FakeLlmClient {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn capabilities(&self) -> &Capabilities {
        &self.caps
    }

    /// Pops the next scripted step and turns it into an assistant reply.
    ///
    /// The request is ignored and usage is left at its default.
    ///
    /// # Errors
    /// `LlmError::BadRequest` once the script is exhausted.
    async fn complete(&self, _req: ChatRequest) -> Result<ChatResponse, LlmError> {
        let step = match self.script.lock().await.pop_front() {
            Some(step) => step,
            None => {
                return Err(LlmError::BadRequest(
                    "FakeLlmClient: script exhausted".into(),
                ))
            }
        };
        let (content, tool_calls, finish_reason) = match step {
            FakeLlmStep::Text(text) => (text, Vec::new(), FinishReason::Stop),
            FakeLlmStep::ToolCalls(calls) => (String::new(), calls, FinishReason::ToolCalls),
        };
        let message = Message {
            role: crate::llm::Role::Assistant,
            content,
            tool_calls,
            tool_call_id: None,
        };
        Ok(ChatResponse {
            message,
            usage: Default::default(),
            finish_reason,
        })
    }

    /// Counts the call, then pops the next scripted stream outcome.
    ///
    /// # Errors
    /// `LlmError::Unsupported` when the stream script is exhausted; the scripted error
    /// for `InitErr` steps (or `LlmError::Server` if that step carries no error).
    async fn stream(&self, _req: ChatRequest) -> Result<ChunkStream, LlmError> {
        self.stream_calls
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let step = match self.stream_script.lock().await.pop_front() {
            Some(step) => step,
            None => return Err(LlmError::Unsupported("streaming".into())),
        };
        match step {
            FakeStreamStep::Chunks(items) => Ok(Box::pin(tokio_stream::iter(items))),
            FakeStreamStep::InitErr(Some(err)) => Err(err),
            FakeStreamStep::InitErr(None) => Err(LlmError::Server(
                "fake init err already taken".into(),
            )),
        }
    }

    /// Embeddings are never supported by the fake.
    async fn embed(&self, _texts: &[String]) -> Result<Vec<Embedding>, LlmError> {
        Err(LlmError::Unsupported("embeddings".into()))
    }
}
/// Shared synchronous handler behind one fake tool: JSON arguments in, JSON result out.
type ToolHandler =
    Arc<dyn Fn(serde_json::Value) -> Result<serde_json::Value, ToolError> + Send + Sync>;

/// `ToolInvoker` that dispatches to plain closures looked up by tool name.
#[derive(Default)]
pub struct FakeToolInvoker {
    /// Tool name to handler, consulted by `invoke`.
    handlers: HashMap<String, ToolHandler>,
    /// Definitions returned by `catalogue`, in registration order.
    catalogue: Vec<crate::llm::ToolDef>,
}
impl FakeToolInvoker {
    /// Creates an invoker with no tools registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `handler` under `name` and advertises it in the catalogue with a
    /// permissive `{"type": "object"}` JSON schema.
    ///
    /// Registering a `name` that already exists replaces both its handler and its
    /// catalogue entry; the entry keeps its original position. The catalogue therefore
    /// never lists the same tool twice.
    pub fn with_tool<F>(mut self, name: &str, description: &str, handler: F) -> Self
    where
        F: Fn(serde_json::Value) -> Result<serde_json::Value, ToolError> + Send + Sync + 'static,
    {
        self.handlers.insert(name.to_string(), Arc::new(handler));
        let def = crate::llm::ToolDef {
            name: name.to_string(),
            description: description.to_string(),
            json_schema: serde_json::json!({"type": "object"}),
        };
        // Keep the catalogue in sync with `handlers`, which is keyed (and deduplicated) by name.
        match self.catalogue.iter_mut().find(|existing| existing.name == name) {
            Some(slot) => *slot = def,
            None => self.catalogue.push(def),
        }
        self
    }
}
#[async_trait]
impl ToolInvoker for FakeToolInvoker {
    /// Runs the handler registered as `name` on `args`; the tool context is ignored.
    ///
    /// # Errors
    /// `ToolError::UnknownTool` if nothing is registered under `name`; otherwise whatever
    /// the handler returns.
    async fn invoke(
        &self,
        name: &str,
        args: serde_json::Value,
        _ctx: ToolCtx,
    ) -> Result<serde_json::Value, ToolError> {
        match self.handlers.get(name) {
            Some(handler) => handler(args),
            None => Err(ToolError::UnknownTool(name.into())),
        }
    }

    /// Returns a copy of every registered tool definition.
    fn catalogue(&self) -> Vec<crate::llm::ToolDef> {
        self.catalogue.to_vec()
    }
}
/// Short-term memory held in a map from thread id string to that thread's messages.
#[derive(Default)]
pub struct InMemoryShortTerm {
    inner: Mutex<HashMap<String, Vec<Message>>>,
}

#[async_trait]
impl ShortTermMemory for InMemoryShortTerm {
    /// Appends `msg` to `thread`'s history, creating the history on first use.
    async fn append(&self, thread: ThreadId, msg: Message) -> Result<(), MemoryError> {
        let mut threads = self.inner.lock().await;
        threads.entry(thread.0).or_default().push(msg);
        Ok(())
    }

    /// Returns a copy of `thread`'s whole history, or an empty list for unknown threads.
    ///
    /// `_max_tokens` is ignored: the history is never truncated.
    async fn load(
        &self,
        thread: ThreadId,
        _max_tokens: usize,
    ) -> Result<Vec<Message>, MemoryError> {
        let threads = self.inner.lock().await;
        let history = match threads.get(&thread.0) {
            Some(messages) => messages.clone(),
            None => Vec::new(),
        };
        Ok(history)
    }

    /// Drops `thread`'s history; a no-op for unknown threads.
    async fn clear(&self, thread: ThreadId) -> Result<(), MemoryError> {
        let mut threads = self.inner.lock().await;
        threads.remove(&thread.0);
        Ok(())
    }
}
/// Long-term memory held as a flat list of facts with their ids and scopes.
#[derive(Default)]
pub struct InMemoryLongTerm {
    /// Stored facts in insertion order.
    inner: Mutex<Vec<(FactId, Scope, Fact)>>,
    /// Last numeric id handed out by `remember` (0 before the first call).
    counter: Mutex<u64>,
}

#[async_trait]
impl LongTermMemory for InMemoryLongTerm {
    /// Stores `fact` under `scope` and returns its id, `fake-<n>` with `n` counting from 1.
    async fn remember(&self, scope: Scope, fact: Fact) -> Result<FactId, MemoryError> {
        // Counter guard is held across the push, so facts are appended in id order.
        let mut last_id = self.counter.lock().await;
        *last_id += 1;
        let fact_id = FactId(format!("fake-{}", *last_id));
        let mut facts = self.inner.lock().await;
        facts.push((fact_id.clone(), scope, fact));
        Ok(fact_id)
    }

    /// Returns up to `k` facts in `scope` whose text contains `query`, comparing
    /// lowercased strings, in insertion order.
    async fn recall(&self, scope: Scope, query: &str, k: usize) -> Result<Vec<Fact>, MemoryError> {
        let needle = query.to_lowercase();
        let facts = self.inner.lock().await;
        let mut hits = Vec::new();
        for (_, fact_scope, fact) in facts.iter() {
            if hits.len() == k {
                break;
            }
            if *fact_scope == scope && fact.text.to_lowercase().contains(&needle) {
                hits.push(fact.clone());
            }
        }
        Ok(hits)
    }

    /// Removes every fact whose id equals `id`.
    async fn forget(&self, id: FactId) -> Result<(), MemoryError> {
        let mut facts = self.inner.lock().await;
        facts.retain(|(fact_id, _, _)| fact_id != &id);
        Ok(())
    }
}
/// Episodic memory held in a map from run id to that run's episodes.
#[derive(Default)]
pub struct InMemoryEpisodic {
    inner: Mutex<HashMap<RunId, Vec<Episode>>>,
}

#[async_trait]
impl EpisodicMemory for InMemoryEpisodic {
    /// Appends `event` to the episodes of `run`.
    async fn record(&self, run: RunId, event: Episode) -> Result<(), MemoryError> {
        let mut runs = self.inner.lock().await;
        runs.entry(run).or_default().push(event);
        Ok(())
    }

    /// Returns a copy of `run`'s episodes in recording order (empty for unknown runs).
    async fn replay(&self, run: RunId) -> Result<Vec<Episode>, MemoryError> {
        let runs = self.inner.lock().await;
        Ok(runs.get(&run).map(|episodes| episodes.to_vec()).unwrap_or_default())
    }

    /// Returns one summary per recorded run.
    ///
    /// `_filter` is ignored. Only `run_id` and `episode_count` come from stored data:
    /// `agent` is empty, `started_at` is the time of this call and `finished_at` is
    /// `None`. Order follows `HashMap` iteration and is unspecified.
    async fn list_runs(&self, _filter: RunFilter) -> Result<Vec<RunSummary>, MemoryError> {
        let runs = self.inner.lock().await;
        let mut summaries = Vec::with_capacity(runs.len());
        for (run_id, episodes) in runs.iter() {
            summaries.push(RunSummary {
                run_id: *run_id,
                agent: String::new(),
                started_at: Utc::now(),
                finished_at: None,
                episode_count: episodes.len() as u32,
            });
        }
        Ok(summaries)
    }
}
/// Message stream that is already finished: every poll reports end-of-stream.
struct EmptyStream;

impl Stream for EmptyStream {
    type Item = Result<Msg, BusError>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Nothing can ever arrive, so completion is reported immediately (no waker needed).
        let finished: Option<Self::Item> = None;
        Poll::Ready(finished)
    }
}
/// `Pubsub` that discards every publish and whose subscriptions end without messages.
pub struct NoopPubsub;

#[async_trait]
impl Pubsub for NoopPubsub {
    /// Drops the message and reports success.
    async fn publish(
        &self,
        _subject: &str,
        _payload: Bytes,
        _headers: Headers,
    ) -> Result<(), BusError> {
        Ok(())
    }

    /// Returns a stream that is already exhausted.
    async fn subscribe(
        &self,
        _subject: &str,
        _durable: DurableName,
    ) -> Result<MsgStream, BusError> {
        Ok(Box::pin(EmptyStream))
    }
}
/// `RequestReply` with no responders: every request fails.
pub struct NoopRequestReply;

#[async_trait]
impl RequestReply for NoopRequestReply {
    /// Always fails with `BusError::NotFound`, regardless of subject or timeout.
    async fn request(
        &self,
        _subject: &str,
        _payload: Bytes,
        _timeout: Duration,
    ) -> Result<Bytes, BusError> {
        let reason = "noop bus".into();
        Err(BusError::NotFound(reason))
    }
}
/// `KvStore` that stores nothing: reads always miss, and writes and CAS always succeed
/// with revision 1.
pub struct NoopKv;

/// Lease handle whose heartbeat always succeeds.
struct NoopLease;

#[async_trait]
impl LeaseImpl for NoopLease {
    async fn heartbeat(&self) -> Result<(), BusError> {
        Ok(())
    }
}

#[async_trait]
impl KvStore for NoopKv {
    /// Always reports the key as absent.
    async fn get(&self, _b: &str, _k: &str) -> Result<Option<KvEntry>, BusError> {
        Ok(None)
    }

    /// Discards the value and reports revision 1.
    async fn put(&self, _b: &str, _k: &str, _v: Bytes) -> Result<Revision, BusError> {
        Ok(1)
    }

    /// Ignores `_expected` (never conflicts), discards the value and reports revision 1.
    async fn cas(
        &self,
        _b: &str,
        _k: &str,
        _v: Bytes,
        _expected: Option<Revision>,
    ) -> Result<Revision, BusError> {
        Ok(1)
    }

    /// Succeeds without doing anything.
    async fn delete(&self, _b: &str, _k: &str) -> Result<(), BusError> {
        Ok(())
    }

    /// Grants a no-op lease regardless of `_ttl`.
    async fn lease(&self, _b: &str, _k: &str, _ttl: Duration) -> Result<Lease, BusError> {
        Ok(Lease::new(Box::new(NoopLease)))
    }
}
/// `JobQueue` that accepts every job without storing it and never has work to claim.
pub struct NoopJobQueue;

/// `ClaimHandleImpl` whose ack, nak and dead-letter all succeed without effect.
// Nothing in this file constructs it (`NoopJobQueue::claim` always returns `None`),
// which is why the `dead_code` lint is silenced.
#[allow(dead_code)]
struct NoopClaim;

#[async_trait]
impl ClaimHandleImpl for NoopClaim {
    async fn ack(self: Box<Self>) -> Result<(), BusError> {
        Ok(())
    }

    async fn nak(self: Box<Self>, _delay: Duration) -> Result<(), BusError> {
        Ok(())
    }

    async fn dead_letter(self: Box<Self>, _reason: &str) -> Result<(), BusError> {
        Ok(())
    }
}

#[async_trait]
impl JobQueue for NoopJobQueue {
    /// Drops the job and returns the fixed id `noop-0`.
    async fn enqueue(&self, _queue: &str, _job: Job) -> Result<JobId, BusError> {
        Ok(JobId("noop-0".into()))
    }

    /// Never has a job available.
    async fn claim(
        &self,
        _queue: &str,
        _worker_id: &str,
        _ttl: Duration,
    ) -> Result<Option<ClaimedJob>, BusError> {
        Ok(None)
    }
}
/// Bus handles in the order pub/sub, request/reply, KV store, job queue.
pub type NoopBusHandles = (
    Arc<dyn Pubsub>,
    Arc<dyn RequestReply>,
    Arc<dyn KvStore>,
    Arc<dyn JobQueue>,
);

/// Builds a complete set of no-op bus handles.
pub fn noop_bus() -> NoopBusHandles {
    let pubsub: Arc<dyn Pubsub> = Arc::new(NoopPubsub);
    let request_reply: Arc<dyn RequestReply> = Arc::new(NoopRequestReply);
    let kv: Arc<dyn KvStore> = Arc::new(NoopKv);
    let jobs: Arc<dyn JobQueue> = Arc::new(NoopJobQueue);
    (pubsub, request_reply, kv, jobs)
}
/// Self-contained harness that wires fakes and in-memory stores into an `AgentContext`.
///
/// Configure it with the `with_*` builders, then call `run` or `run_steps_directly`.
pub struct TestContext {
    /// Scripted LLM, kept as the concrete type so `llm()` can expose it directly.
    llm: Arc<FakeLlmClient>,
    /// Tool invoker handed to the agent.
    tools: Arc<dyn ToolInvoker>,
    /// Short-term memory handed to the agent.
    short_term: Arc<dyn ShortTermMemory>,
    /// Long-term memory handed to the agent.
    long_term: Arc<dyn LongTermMemory>,
    /// Episodic memory handed to the agent; read back by `recorded_episodes`.
    episodic: Arc<dyn EpisodicMemory>,
    /// Pub/sub bus handle.
    pubsub: Arc<dyn Pubsub>,
    /// Request/reply bus handle.
    request_reply: Arc<dyn RequestReply>,
    /// Key/value store handle.
    kv: Arc<dyn KvStore>,
    /// Job queue handle.
    jobs: Arc<dyn JobQueue>,
    /// Cancellation token shared with the agent.
    cancel: CancellationToken,
    /// Run id used for the agent context and for `recorded_episodes`.
    run_id: RunId,
    /// Agent name placed in the agent context.
    agent_name: String,
    /// System prompt used by `run_steps_directly`.
    system_prompt: String,
    /// Options used by `run_steps_directly`.
    run_options: RunOptions,
    /// Messages appended to short-term memory on the next run.
    seeded_history: Vec<Message>,
    /// Tools (name, description, handler) installed into `tools` on the next run.
    pending_tools: Vec<(String, String, ToolHandler)>,
}
impl Default for TestContext {
fn default() -> Self {
let (pubsub, request_reply, kv, jobs) = noop_bus();
Self {
llm: Arc::new(FakeLlmClient::new("test")),
tools: Arc::new(FakeToolInvoker::new()),
short_term: Arc::new(InMemoryShortTerm::default()),
long_term: Arc::new(InMemoryLongTerm::default()),
episodic: Arc::new(InMemoryEpisodic::default()),
pubsub,
request_reply,
kv,
jobs,
cancel: CancellationToken::new(),
run_id: RunId::new(),
agent_name: "test-agent".into(),
system_prompt: String::new(),
run_options: RunOptions::default(),
seeded_history: Vec::new(),
pending_tools: Vec::new(),
}
}
}
impl TestContext {
    /// Short-term-memory thread used for seeded history and by `run_steps_directly`.
    const TEST_THREAD: &'static str = "test";

    /// Sets the script consumed by the fake LLM's `complete`, replacing any previous one.
    ///
    /// The existing client is updated in place. A stream script set with
    /// `with_canned_stream_responses`, before or after this call, is kept.
    pub fn with_canned_llm_responses(mut self, steps: Vec<FakeLlmStep>) -> Self {
        *self.llm_mut().script.get_mut() = std::collections::VecDeque::from(steps);
        self
    }

    /// Sets the script consumed by the fake LLM's `stream`, replacing any previous one.
    ///
    /// The existing client is updated in place. A `complete` script set with
    /// `with_canned_llm_responses`, before or after this call, is kept.
    pub fn with_canned_stream_responses(mut self, steps: Vec<FakeStreamStep>) -> Self {
        *self.llm_mut().stream_script.get_mut() = std::collections::VecDeque::from(steps);
        self
    }

    /// Queues a tool for registration; it is installed when the next run starts.
    pub fn with_tool<F>(mut self, name: &str, description: &str, f: F) -> Self
    where
        F: Fn(serde_json::Value) -> Result<serde_json::Value, ToolError> + Send + Sync + 'static,
    {
        self.pending_tools
            .push((name.into(), description.into(), Arc::new(f)));
        self
    }

    /// Messages appended to the test short-term thread when the next run starts.
    pub fn with_short_term_history(mut self, messages: Vec<Message>) -> Self {
        self.seeded_history = messages;
        self
    }

    /// Run options used by `run_steps_directly`.
    pub fn with_run_options(mut self, opts: RunOptions) -> Self {
        self.run_options = opts;
        self
    }

    /// System prompt used by `run_steps_directly`.
    pub fn with_system_prompt(mut self, prompt: &str) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Agent name placed in the agent context.
    pub fn with_agent_name(mut self, name: &str) -> Self {
        self.agent_name = name.into();
        self
    }

    /// Cancellation token shared with the agent.
    pub fn cancel_token(&self) -> &CancellationToken {
        &self.cancel
    }

    /// Run id shared with the agent.
    pub fn run_id(&self) -> RunId {
        self.run_id
    }

    /// The fake LLM, e.g. for `stream_call_count`.
    pub fn llm(&self) -> &FakeLlmClient {
        &self.llm
    }

    /// Episodes recorded under this context's run id; errors are mapped to an empty list.
    pub async fn recorded_episodes(&self) -> Vec<Episode> {
        self.episodic.replay(self.run_id).await.unwrap_or_default()
    }

    /// Exclusive access to the fake LLM so its scripts can be edited in place.
    ///
    /// The `Arc` is only shared while an `AgentContext` built from this context is still
    /// alive (e.g. retained by an agent after `run`). In that case a fresh client with
    /// the same name is installed, which drops previously configured scripts.
    fn llm_mut(&mut self) -> &mut FakeLlmClient {
        if Arc::get_mut(&mut self.llm).is_none() {
            let name = self.llm.name().to_string();
            self.llm = Arc::new(FakeLlmClient::new(name));
        }
        Arc::get_mut(&mut self.llm).expect("a freshly created Arc has no other owners")
    }

    /// Applies deferred configuration before a run.
    ///
    /// Pending tools, if any, become a new invoker that replaces the current one
    /// (including tools installed by an earlier run). Seeded history is appended to the
    /// test thread. Both queues are drained, so repeated runs do not re-apply them.
    async fn materialize(&mut self) -> Result<(), crate::Error> {
        if !self.pending_tools.is_empty() {
            let mut inv = FakeToolInvoker::new();
            for (name, desc, handler) in self.pending_tools.drain(..) {
                inv = inv.with_tool(&name, &desc, move |args| handler(args));
            }
            self.tools = Arc::new(inv);
        }
        if !self.seeded_history.is_empty() {
            let thread = ThreadId::new(Self::TEST_THREAD);
            for msg in self.seeded_history.drain(..) {
                self.short_term.append(thread.clone(), msg).await?;
            }
        }
        Ok(())
    }

    /// Builds an `AgentContext` sharing this harness's fakes, stores and token.
    fn build_agent_ctx(&self) -> AgentContext {
        AgentContext {
            llm: self.llm.clone() as Arc<dyn LlmClient>,
            short_term: self.short_term.clone(),
            long_term: self.long_term.clone(),
            episodic: self.episodic.clone(),
            pubsub: self.pubsub.clone(),
            kv: self.kv.clone(),
            request_reply: self.request_reply.clone(),
            jobs: self.jobs.clone(),
            tools: self.tools.clone(),
            run_id: self.run_id,
            cancel: self.cancel.clone(),
            agent_name: self.agent_name.clone(),
        }
    }

    /// Applies pending configuration, then runs `agent` on `input` with a fresh context.
    ///
    /// # Errors
    /// Errors from seeding short-term memory, or whatever the agent returns.
    pub async fn run<A>(&mut self, agent: &A, input: A::Input) -> Result<A::Output, A::Error>
    where
        A: crate::Agent<Error = crate::Error>,
    {
        self.materialize().await?;
        let ctx = self.build_agent_ctx();
        agent.run(ctx, input).await
    }

    /// Applies pending configuration, then drives `run_steps` on the test thread with
    /// the configured system prompt and run options.
    ///
    /// # Errors
    /// Errors from seeding short-term memory, or whatever `run_steps` returns.
    pub async fn run_steps_directly(&mut self) -> Result<String, crate::Error> {
        self.materialize().await?;
        let ctx = self.build_agent_ctx();
        let thread = ThreadId::new(Self::TEST_THREAD);
        run_steps(&ctx, &self.system_prompt, thread, self.run_options.clone()).await
    }
}