use crate::agent::Agent;
use crate::harness::Harness;
use crate::llm_models::LlmProviderType;
use crate::session::Session;
use crate::tool_types::{ToolCall, ToolDefinition, ToolResult};
use crate::traits::ModelWithProvider;
use crate::typed_id::{AgentId, EventId, HarnessId, MessageId, ModelId, SessionId};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::error::Result;
use crate::message::Message;
use crate::message_filter::MessageQuery;
use crate::message_retriever::{InputMessage, MessageRetriever};
use crate::traits::{AgentStore, HarnessStore, LlmProviderStore, SessionStore, ToolExecutor};
use chrono::Utc;
/// In-memory [`MessageRetriever`] keeping each session's message history in a
/// shared map. Clones share the same underlying storage (it lives behind an `Arc`).
#[derive(Debug, Default, Clone)]
pub struct InMemoryMessageRetriever {
    messages: Arc<RwLock<HashMap<SessionId, Vec<Message>>>>,
}

impl InMemoryMessageRetriever {
    /// Creates a retriever with no sessions.
    pub fn new() -> Self {
        let empty: HashMap<SessionId, Vec<Message>> = HashMap::new();
        Self {
            messages: Arc::new(RwLock::new(empty)),
        }
    }

    /// Returns the IDs of every session that currently has an entry (order unspecified).
    pub async fn sessions(&self) -> Vec<SessionId> {
        let guard = self.messages.read().await;
        guard.keys().copied().collect()
    }

    /// Removes every session and all of its messages.
    pub async fn clear(&self) {
        let mut guard = self.messages.write().await;
        guard.clear();
    }

    /// Removes the history of `session_id`, if present.
    pub async fn clear_session(&self, session_id: SessionId) {
        let mut guard = self.messages.write().await;
        guard.remove(&session_id);
    }

    /// Replaces the whole history of `session_id` with `messages`.
    pub async fn seed(&self, session_id: SessionId, messages: Vec<Message>) {
        let mut guard = self.messages.write().await;
        guard.insert(session_id, messages);
    }

    /// Builds a [`Message`] from `input` with a fresh [`MessageId`] and the current
    /// UTC time, appends it to `session_id`, and returns the stored copy.
    ///
    /// `phase`, `thinking`, `thinking_signature` and `external_actor` are left `None`.
    pub async fn add(&self, session_id: SessionId, input: InputMessage) -> Result<Message> {
        let created = Message {
            id: MessageId::new(),
            role: input.role,
            content: input.content,
            phase: None,
            thinking: None,
            thinking_signature: None,
            controls: input.controls,
            metadata: input.metadata,
            external_actor: None,
            created_at: Utc::now(),
        };
        self.store(session_id, created.clone()).await?;
        Ok(created)
    }

    /// Appends an already-built `message` to `session_id`, creating the session
    /// entry on first use. Always returns `Ok(())`.
    pub async fn store(&self, session_id: SessionId, message: Message) -> Result<()> {
        let mut guard = self.messages.write().await;
        guard.entry(session_id).or_default().push(message);
        Ok(())
    }
}
#[async_trait]
impl MessageRetriever for InMemoryMessageRetriever {
    /// Looks up a single message by ID within `session_id`.
    async fn get(&self, session_id: SessionId, message_id: MessageId) -> Result<Option<Message>> {
        let guard = self.messages.read().await;
        let found = guard
            .get(&session_id)
            .and_then(|history| history.iter().find(|msg| msg.id == message_id))
            .cloned();
        Ok(found)
    }

    /// Returns a copy of the full history of `session_id` (empty if unknown).
    async fn load(&self, session_id: SessionId) -> Result<Vec<Message>> {
        let guard = self.messages.read().await;
        let history = match guard.get(&session_id) {
            Some(stored) => stored.clone(),
            None => Vec::new(),
        };
        Ok(history)
    }

    /// Loads the session, applies the supported filters in order, then windowing
    /// and (when present) injections from `query`.
    ///
    /// `TimeRange` bounds are inclusive. `Search` is a case-insensitive substring
    /// match on the message text, and messages without text never match. Other
    /// filter variants are ignored by this in-memory implementation.
    async fn load_filtered(&self, query: MessageQuery) -> Result<Vec<Message>> {
        use crate::message_filter::MessageFilter;

        let mut selected = self.load(query.session_id).await?;
        for filter in &query.filters {
            match filter {
                MessageFilter::TimeRange { from, to } => selected.retain(|msg| {
                    let not_before = from.is_none_or(|start| msg.created_at >= start);
                    let not_after = to.is_none_or(|end| msg.created_at <= end);
                    not_before && not_after
                }),
                MessageFilter::Search(q) => {
                    // Lower-case the needle once, outside the per-message closure.
                    let needle = q.to_lowercase();
                    selected.retain(|msg| {
                        msg.text()
                            .is_some_and(|body| body.to_lowercase().contains(&needle))
                    });
                }
                MessageFilter::Custom(predicate) => selected.retain(|msg| predicate(msg)),
                _ => {}
            }
        }

        query.apply_windowing(&mut selected);
        if query.has_injections() {
            query.apply_injections(&mut selected);
        }
        Ok(selected)
    }

    /// Number of messages stored for `session_id` (0 if unknown).
    async fn count(&self, session_id: SessionId) -> Result<usize> {
        let guard = self.messages.read().await;
        Ok(guard.get(&session_id).map_or(0, Vec::len))
    }
}
/// In-memory [`AgentStore`] keyed by each agent's `public_id`.
/// Clones share the same underlying map.
#[derive(Debug, Default, Clone)]
pub struct InMemoryAgentStore {
    agents: Arc<RwLock<HashMap<AgentId, Agent>>>,
}

impl InMemoryAgentStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `agent` under its `public_id`, replacing any previous entry.
    pub async fn add_agent(&self, agent: Agent) {
        let key = agent.public_id;
        let mut guard = self.agents.write().await;
        guard.insert(key, agent);
    }

    /// Returns the IDs of all stored agents (order unspecified).
    pub async fn agent_ids(&self) -> Vec<AgentId> {
        let guard = self.agents.read().await;
        guard.keys().copied().collect()
    }

    /// Removes every stored agent.
    pub async fn clear(&self) {
        let mut guard = self.agents.write().await;
        guard.clear();
    }
}

#[async_trait]
impl AgentStore for InMemoryAgentStore {
    /// Returns a copy of the agent registered under `agent_id`, if any.
    async fn get_agent(&self, agent_id: AgentId) -> Result<Option<Agent>> {
        let guard = self.agents.read().await;
        Ok(guard.get(&agent_id).cloned())
    }
}
/// In-memory [`HarnessStore`] keyed by harness `id`. Clones share the same map.
#[derive(Debug, Default, Clone)]
pub struct InMemoryHarnessStore {
    harnesses: Arc<RwLock<HashMap<HarnessId, Harness>>>,
}

impl InMemoryHarnessStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            harnesses: Arc::default(),
        }
    }

    /// Inserts `harness` under its `id`, replacing any previous entry.
    pub async fn add_harness(&self, harness: Harness) {
        let key = harness.id;
        let mut guard = self.harnesses.write().await;
        guard.insert(key, harness);
    }
}

#[async_trait]
impl HarnessStore for InMemoryHarnessStore {
    /// Returns a one-element chain holding the harness registered under
    /// `harness_id`, or an empty chain if there is none.
    ///
    /// NOTE(review): despite the name, no parent/ancestor lookup happens here.
    /// Only the single registered harness is returned. Confirm that callers
    /// do not expect a longer chain from this fixture.
    async fn get_harness_chain(&self, harness_id: HarnessId) -> Result<Vec<Harness>> {
        let guard = self.harnesses.read().await;
        let chain = match guard.get(&harness_id) {
            Some(found) => vec![found.clone()],
            None => Vec::new(),
        };
        Ok(chain)
    }
}
/// In-memory [`SessionStore`] keyed by session `id`. Clones share the same map.
#[derive(Debug, Default, Clone)]
pub struct InMemorySessionStore {
    sessions: Arc<RwLock<HashMap<SessionId, Session>>>,
}

impl InMemorySessionStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        let empty: HashMap<SessionId, Session> = HashMap::new();
        Self {
            sessions: Arc::new(RwLock::new(empty)),
        }
    }

    /// Inserts `session` under its `id`, replacing any previous entry.
    pub async fn add_session(&self, session: Session) {
        let key = session.id;
        let mut guard = self.sessions.write().await;
        guard.insert(key, session);
    }

    /// Returns the IDs of all stored sessions (order unspecified).
    pub async fn session_ids(&self) -> Vec<SessionId> {
        let guard = self.sessions.read().await;
        guard.keys().copied().collect()
    }

    /// Removes every stored session.
    pub async fn clear(&self) {
        let mut guard = self.sessions.write().await;
        guard.clear();
    }
}

#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Returns a copy of the session registered under `session_id`, if any.
    async fn get_session(&self, session_id: SessionId) -> Result<Option<Session>> {
        let guard = self.sessions.read().await;
        Ok(guard.get(&session_id).cloned())
    }
}
/// In-memory [`LlmProviderStore`] holding per-[`ModelId`] registrations plus an
/// optional default model. Clones share the same underlying state.
#[derive(Debug, Default, Clone)]
pub struct InMemoryLlmProviderStore {
    models: Arc<RwLock<HashMap<ModelId, ModelWithProvider>>>,
    default_model: Arc<RwLock<Option<ModelWithProvider>>>,
}

impl InMemoryLlmProviderStore {
    /// Creates a store with no registered models and no default model.
    pub fn new() -> Self {
        Self {
            models: Arc::new(RwLock::new(HashMap::new())),
            default_model: Arc::new(RwLock::new(None)),
        }
    }

    /// Creates a store whose default model is chosen from environment variables.
    ///
    /// * If `OPENAI_API_KEY` is set, the default is `gpt-5.4` via
    ///   [`LlmProviderType::Openai`]. `OPENAI_BASE_URL`, if set, becomes the base URL.
    /// * Otherwise, if `ANTHROPIC_API_KEY` is set, the default is
    ///   `claude-sonnet-4-20250514` via [`LlmProviderType::Anthropic`].
    ///   `ANTHROPIC_BASE_URL`, if set, becomes the base URL.
    /// * Otherwise, no default model is configured.
    ///
    /// A variable that is set but blank (empty or whitespace-only) is treated
    /// as unset. A leftover `OPENAI_API_KEY=` therefore neither shadows a real
    /// Anthropic key nor yields an empty API key or base URL.
    pub async fn from_env() -> Self {
        let store = Self::new();
        if let Some(api_key) = Self::non_empty_env("OPENAI_API_KEY") {
            let model = ModelWithProvider {
                model: "gpt-5.4".to_string(),
                provider_type: LlmProviderType::Openai,
                api_key: Some(api_key),
                base_url: Self::non_empty_env("OPENAI_BASE_URL"),
            };
            store.set_default_model(model).await;
        } else if let Some(api_key) = Self::non_empty_env("ANTHROPIC_API_KEY") {
            let model = ModelWithProvider {
                model: "claude-sonnet-4-20250514".to_string(),
                provider_type: LlmProviderType::Anthropic,
                api_key: Some(api_key),
                base_url: Self::non_empty_env("ANTHROPIC_BASE_URL"),
            };
            store.set_default_model(model).await;
        }
        store
    }

    /// Reads environment variable `name`. Returns `None` when it is unset, not
    /// valid Unicode, or blank. A non-blank value is returned unmodified.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .filter(|value| !value.trim().is_empty())
    }

    /// Creates a store whose default model is `model`.
    pub async fn with_default(model: ModelWithProvider) -> Self {
        let store = Self::new();
        store.set_default_model(model).await;
        store
    }

    /// Registers `model` under `model_id`, replacing any previous registration.
    pub async fn add_model(&self, model_id: ModelId, model: ModelWithProvider) {
        self.models.write().await.insert(model_id, model);
    }

    /// Sets (or replaces) the default model.
    pub async fn set_default_model(&self, model: ModelWithProvider) {
        *self.default_model.write().await = Some(model);
    }

    /// Removes all registered models and the default model.
    pub async fn clear(&self) {
        self.models.write().await.clear();
        *self.default_model.write().await = None;
    }
}
#[async_trait]
impl LlmProviderStore for InMemoryLlmProviderStore {
    /// Returns a copy of the model registered under `model_id`, if any.
    async fn get_model_with_provider(
        &self,
        model_id: ModelId,
    ) -> Result<Option<ModelWithProvider>> {
        let registered = self.models.read().await;
        let hit = registered.get(&model_id).cloned();
        Ok(hit)
    }

    /// Returns a copy of the current default model, if one is set.
    async fn get_default_model(&self) -> Result<Option<ModelWithProvider>> {
        let current = self.default_model.read().await;
        let snapshot: Option<ModelWithProvider> = (*current).clone();
        Ok(snapshot)
    }
}
/// Scripted [`ToolExecutor`] for tests. It returns per-tool canned results and
/// records every call it receives.
///
/// Like the other in-memory fixtures in this file, it is `Clone`. Clones share
/// the same result table and call log (both live behind an `Arc`), so a test can
/// keep one handle for inspection while handing another to the system under test.
#[derive(Debug, Default, Clone)]
pub struct MockToolExecutor {
    results: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    call_log: Arc<RwLock<Vec<ToolCall>>>,
}

impl MockToolExecutor {
    /// Creates an executor with no canned results and an empty call log.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the JSON value returned whenever the tool named `tool_name` is
    /// executed, replacing any previous value for that tool.
    pub async fn set_result(&self, tool_name: impl Into<String>, result: serde_json::Value) {
        self.results.write().await.insert(tool_name.into(), result);
    }

    /// Returns a copy of every tool call received so far, in call order.
    pub async fn calls(&self) -> Vec<ToolCall> {
        self.call_log.read().await.clone()
    }

    /// Empties the call log. Canned results are kept.
    pub async fn clear_calls(&self) {
        self.call_log.write().await.clear();
    }
}
#[async_trait]
impl ToolExecutor for MockToolExecutor {
    /// Logs `tool_call`, then returns the canned result registered for its name.
    /// When no result is registered, it returns `{"status": "ok"}`. Never
    /// reports an error.
    async fn execute(
        &self,
        tool_call: &ToolCall,
        _tool_def: &ToolDefinition,
    ) -> Result<ToolResult> {
        // Record first, so that every invocation shows up in `calls()`.
        {
            let mut log = self.call_log.write().await;
            log.push(tool_call.clone());
        }

        let canned = {
            let table = self.results.read().await;
            table.get(&tool_call.name).cloned()
        };
        let payload = match canned {
            Some(value) => value,
            None => serde_json::json!({"status": "ok"}),
        };

        Ok(ToolResult {
            tool_call_id: tool_call.id.clone(),
            result: Some(payload),
            images: None,
            error: None,
            connection_required: None,
            raw_output: None,
        })
    }
}
/// Stateless [`ToolExecutor`] that returns each call's name and arguments as its result.
#[derive(Debug, Default, Clone, Copy)]
pub struct EchoToolExecutor;

impl EchoToolExecutor {
    /// Returns the executor. It is a unit struct, so this is the same as `Default`.
    pub fn new() -> Self {
        EchoToolExecutor
    }
}
#[async_trait]
impl ToolExecutor for EchoToolExecutor {
    /// Returns `{"echoed_tool": <name>, "echoed_arguments": <arguments>}` for
    /// `tool_call`. Never reports an error.
    async fn execute(
        &self,
        tool_call: &ToolCall,
        _tool_def: &ToolDefinition,
    ) -> Result<ToolResult> {
        let echoed = serde_json::json!({
            "echoed_tool": tool_call.name,
            "echoed_arguments": tool_call.arguments
        });
        let outcome = ToolResult {
            tool_call_id: tool_call.id.clone(),
            result: Some(echoed),
            images: None,
            error: None,
            connection_required: None,
            raw_output: None,
        };
        Ok(outcome)
    }
}
/// [`ToolExecutor`] whose every call reports `error_message` as a tool-level error.
#[derive(Debug, Clone)]
pub struct FailingToolExecutor {
    error_message: String,
}

impl FailingToolExecutor {
    /// Creates an executor that reports `error_message` for every call.
    pub fn new(error_message: impl Into<String>) -> Self {
        let error_message: String = error_message.into();
        Self { error_message }
    }
}

impl Default for FailingToolExecutor {
    /// Uses the message `"Tool execution failed"`.
    fn default() -> Self {
        const DEFAULT_MESSAGE: &str = "Tool execution failed";
        Self::new(DEFAULT_MESSAGE)
    }
}
#[async_trait]
impl ToolExecutor for FailingToolExecutor {
    /// Returns a [`ToolResult`] with no result and `error` set to the configured
    /// message. The failure is reported in the result, not as an `Err`.
    async fn execute(
        &self,
        tool_call: &ToolCall,
        _tool_def: &ToolDefinition,
    ) -> Result<ToolResult> {
        let failure = ToolResult {
            tool_call_id: tool_call.id.clone(),
            result: None,
            images: None,
            error: Some(self.error_message.clone()),
            connection_required: None,
            raw_output: None,
        };
        Ok(failure)
    }
}
use crate::events::{Event, EventRequest};
use crate::llm_driver_registry::{
LlmCallConfig, LlmDriver, LlmMessage, LlmResponseStream, LlmStreamEvent,
};
use crate::traits::EventEmitter;
use futures::stream;
/// Scripted [`LlmDriver`] for tests. It hands out queued [`MockLlmResponse`]s in
/// order and records the messages of every call.
#[derive(Debug, Default)]
pub struct MockLlmProvider {
    // Queued responses; the next one handed out is `responses[call_index]`.
    responses: Arc<RwLock<Vec<MockLlmResponse>>>,
    // Index of the next response to hand out; advanced on every call.
    call_index: Arc<RwLock<usize>>,
    // Message lists passed to each call, in call order.
    call_log: Arc<RwLock<Vec<Vec<LlmMessage>>>>,
}

/// One canned LLM reply: streamed text plus optional tool calls.
#[derive(Debug, Clone)]
pub struct MockLlmResponse {
    pub text: String,
    pub tool_calls: Option<Vec<ToolCall>>,
}

impl MockLlmResponse {
    /// A plain-text reply without tool calls.
    pub fn text(text: impl Into<String>) -> Self {
        Self::from_parts(text.into(), None)
    }

    /// A reply carrying `text` together with `tool_calls`.
    pub fn with_tools(text: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
        Self::from_parts(text.into(), Some(tool_calls))
    }

    // Shared constructor for the two public builders.
    fn from_parts(text: String, tool_calls: Option<Vec<ToolCall>>) -> Self {
        Self { text, tool_calls }
    }
}
impl MockLlmProvider {
pub fn new() -> Self {
Self {
responses: Arc::new(RwLock::new(Vec::new())),
call_index: Arc::new(RwLock::new(0)),
call_log: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn add_response(&self, response: MockLlmResponse) {
self.responses.write().await.push(response);
}
pub async fn set_responses(&self, responses: Vec<MockLlmResponse>) {
*self.responses.write().await = responses;
*self.call_index.write().await = 0;
}
pub async fn calls(&self) -> Vec<Vec<LlmMessage>> {
self.call_log.read().await.clone()
}
pub async fn reset(&self) {
self.responses.write().await.clear();
*self.call_index.write().await = 0;
self.call_log.write().await.clear();
}
}
#[async_trait]
impl LlmDriver for MockLlmProvider {
    /// Records `messages`, then streams the next queued response.
    ///
    /// The emitted stream is `TextDelta(text)`, followed by `ToolCalls(..)` when
    /// the response has tool calls, followed by exactly one `Done`. Previously a
    /// text-only response emitted `Done` twice.
    ///
    /// Once the queue is exhausted, every further call gets the placeholder text
    /// "Mock response (no more responses configured)". The call index keeps
    /// advancing regardless.
    async fn chat_completion_stream(
        &self,
        messages: Vec<LlmMessage>,
        _config: &LlmCallConfig,
    ) -> Result<LlmResponseStream> {
        self.call_log.write().await.push(messages);

        // Pick and advance under the index write lock. Concurrent calls therefore
        // each get a distinct slot.
        let response = {
            let mut index = self.call_index.write().await;
            let responses = self.responses.read().await;
            let picked = responses.get(*index).cloned().unwrap_or_else(|| {
                MockLlmResponse::text("Mock response (no more responses configured)")
            });
            *index += 1;
            picked
        };

        let mut events = Vec::with_capacity(3);
        events.push(Ok(LlmStreamEvent::TextDelta(response.text)));
        if let Some(tool_calls) = response.tool_calls {
            events.push(Ok(LlmStreamEvent::ToolCalls(tool_calls)));
        }
        // A single terminal event, whether or not tool calls were emitted.
        events.push(Ok(LlmStreamEvent::Done(Box::default())));
        Ok(Box::pin(stream::iter(events)))
    }
}
#[derive(Debug, Default, Clone)]
pub struct InMemoryEventEmitter {
events: Arc<RwLock<Vec<Event>>>,
sequence: Arc<RwLock<i32>>,
}
impl InMemoryEventEmitter {
pub fn new() -> Self {
Self {
events: Arc::new(RwLock::new(Vec::new())),
sequence: Arc::new(RwLock::new(0)),
}
}
pub async fn events(&self) -> Vec<Event> {
self.events.read().await.clone()
}
pub async fn event_count(&self) -> usize {
self.events.read().await.len()
}
pub async fn clear(&self) {
self.events.write().await.clear();
*self.sequence.write().await = 0;
}
pub async fn events_by_type(&self, event_type: &str) -> Vec<Event> {
self.events
.read()
.await
.iter()
.filter(|e| e.event_type == event_type)
.cloned()
.collect()
}
pub async fn events_for_session(&self, session_id: Uuid) -> Vec<Event> {
self.events
.read()
.await
.iter()
.filter(|e| e.session_uuid() == session_id)
.cloned()
.collect()
}
}
#[async_trait]
impl EventEmitter for InMemoryEventEmitter {
    /// Converts `request` into an [`Event`] with a fresh [`EventId`] and the next
    /// sequence number (the first emit gets 1), records it, and returns a copy.
    ///
    /// The sequence guard is held until the event has been appended. Concurrent
    /// emitters therefore cannot interleave between taking a number and pushing,
    /// and `events()` always lists events in ascending sequence order.
    async fn emit(&self, request: EventRequest) -> Result<Event> {
        let mut sequence = self.sequence.write().await;
        *sequence += 1;
        let event = request.into_event(EventId::new(), *sequence);
        self.events.write().await.push(event.clone());
        // Release the counter only after the event is visible in the log.
        drop(sequence);
        Ok(event)
    }
}
// Unit tests for the in-memory fixtures in this file. Each test constructs its
// own fixture instance, so tests do not share state.
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // `store` followed by `load` returns the stored message for that session.
    #[tokio::test]
    async fn test_in_memory_message_retriever() {
        let store = InMemoryMessageRetriever::new();
        let session_id: SessionId = Uuid::now_v7().into();
        store
            .store(session_id, Message::user("Hello"))
            .await
            .unwrap();
        let messages = store.load(session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text(), Some("Hello"));
    }

    // `add` returns a message that `get` can find by ID. An unknown ID yields `None`.
    #[tokio::test]
    async fn test_in_memory_message_retriever_add_and_get() {
        let store = InMemoryMessageRetriever::new();
        let session_id: SessionId = Uuid::now_v7().into();
        let message = store
            .add(session_id, InputMessage::user("Hello via add"))
            .await
            .unwrap();
        let retrieved = store.get(session_id, message.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().text(), Some("Hello via add"));
        let missing = store.get(session_id, MessageId::new()).await.unwrap();
        assert!(missing.is_none());
    }

    // The ID returned by `add` matches the stored copy seen through both `get` and `load`.
    #[tokio::test]
    async fn test_message_retriever_add_returns_consistent_id() {
        let store = InMemoryMessageRetriever::new();
        let session_id: SessionId = Uuid::now_v7().into();
        let added = store
            .add(session_id, InputMessage::user("Test consistency"))
            .await
            .unwrap();
        let retrieved = store.get(session_id, added.id).await.unwrap();
        assert!(
            retrieved.is_some(),
            "Message must be retrievable by the ID returned from add()"
        );
        let retrieved = retrieved.unwrap();
        assert_eq!(
            retrieved.id, added.id,
            "Retrieved message ID must match the ID returned from add()"
        );
        let all_messages = store.load(session_id).await.unwrap();
        let found = all_messages.iter().find(|m| m.id == added.id);
        assert!(
            found.is_some(),
            "Message with returned ID must appear in load() results"
        );
    }

    // A result registered with `set_result` is returned for a call to that tool name.
    #[tokio::test]
    async fn test_mock_tool_executor() {
        let executor = MockToolExecutor::new();
        executor
            .set_result("get_weather", serde_json::json!({"temp": 72}))
            .await;
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            name: "get_weather".to_string(),
            arguments: serde_json::json!({"city": "NYC"}),
        };
        let tool_def = ToolDefinition::Builtin(crate::tool_types::BuiltinTool {
            name: "get_weather".to_string(),
            display_name: None,
            description: "Get weather".to_string(),
            parameters: serde_json::json!({}),
            policy: crate::tool_types::ToolPolicy::Auto,
            category: None,
            deferrable: crate::tool_types::DeferrablePolicy::default(),
            hints: crate::tool_types::ToolHints::default(),
            full_parameters: None,
        });
        let result = executor.execute(&tool_call, &tool_def).await.unwrap();
        assert!(result.error.is_none());
        assert_eq!(result.result, Some(serde_json::json!({"temp": 72})));
    }

    // Successive emits get sequence numbers 1, 2, ... and are all recorded.
    #[tokio::test]
    async fn test_in_memory_event_emitter() {
        use crate::events::{EventContext, EventRequest, InputMessageData};
        let emitter = InMemoryEventEmitter::new();
        let session_id: SessionId = Uuid::now_v7().into();
        let event_context = EventContext::empty();
        let event1 = emitter
            .emit(EventRequest::new(
                session_id,
                event_context.clone(),
                InputMessageData::new(Message::user("test1")),
            ))
            .await
            .unwrap();
        assert_eq!(event1.sequence, Some(1));
        let event2 = emitter
            .emit(EventRequest::new(
                session_id,
                event_context,
                InputMessageData::new(Message::user("test2")),
            ))
            .await
            .unwrap();
        assert_eq!(event2.sequence, Some(2));
        let events = emitter.events().await;
        assert_eq!(events.len(), 2);
        assert_eq!(emitter.event_count().await, 2);
    }

    // `events_by_type` separates events of different types.
    #[tokio::test]
    async fn test_in_memory_event_emitter_filter_by_type() {
        use crate::events::{
            EventContext, EventRequest, INPUT_MESSAGE, InputMessageData, REASON_STARTED,
            ReasonStartedData,
        };
        let emitter = InMemoryEventEmitter::new();
        let session_id: SessionId = Uuid::now_v7().into();
        let event_context = EventContext::empty();
        emitter
            .emit(EventRequest::new(
                session_id,
                event_context.clone(),
                InputMessageData::new(Message::user("test")),
            ))
            .await
            .unwrap();
        emitter
            .emit(EventRequest::new(
                session_id,
                event_context,
                ReasonStartedData {
                    harness_id: HarnessId::from_seed(1),
                    agent_id: Some(AgentId::new()),
                    metadata: None,
                },
            ))
            .await
            .unwrap();
        let received_events = emitter.events_by_type(INPUT_MESSAGE).await;
        assert_eq!(received_events.len(), 1);
        let started_events = emitter.events_by_type(REASON_STARTED).await;
        assert_eq!(started_events.len(), 1);
    }

    // `events_for_session` returns only the events of the requested session.
    #[tokio::test]
    async fn test_in_memory_event_emitter_filter_by_session() {
        use crate::events::{EventContext, EventRequest, InputMessageData};
        let emitter = InMemoryEventEmitter::new();
        let session1: SessionId = Uuid::now_v7().into();
        let session2: SessionId = Uuid::now_v7().into();
        let context = EventContext::empty();
        emitter
            .emit(EventRequest::new(
                session1,
                context.clone(),
                InputMessageData::new(Message::user("session1")),
            ))
            .await
            .unwrap();
        emitter
            .emit(EventRequest::new(
                session2,
                context,
                InputMessageData::new(Message::user("session2")),
            ))
            .await
            .unwrap();
        let session1_events = emitter.events_for_session(session1.uuid()).await;
        assert_eq!(session1_events.len(), 1);
        let session2_events = emitter.events_for_session(session2.uuid()).await;
        assert_eq!(session2_events.len(), 1);
    }

    // `clear` empties the recorded event log.
    #[tokio::test]
    async fn test_in_memory_event_emitter_clear() {
        use crate::events::{EventContext, EventRequest, InputMessageData};
        let emitter = InMemoryEventEmitter::new();
        let session_id: SessionId = Uuid::now_v7().into();
        let event_context = EventContext::empty();
        emitter
            .emit(EventRequest::new(
                session_id,
                event_context,
                InputMessageData::new(Message::user("test")),
            ))
            .await
            .unwrap();
        assert_eq!(emitter.event_count().await, 1);
        emitter.clear().await;
        assert_eq!(emitter.event_count().await, 0);
    }
}
use crate::memory_store::{
Memory, MemoryContentPart, MemoryKind, MemoryQuery, MemoryStoreBackend, MemoryStoreEntity,
};
use crate::typed_id::{MemoryId, MemoryStoreId, OrgId};
/// In-memory [`MemoryStoreBackend`] keeping stores and memories in flat vectors.
/// Clones share the same underlying data.
#[derive(Debug, Default, Clone)]
pub struct InMemoryMemoryStore {
    stores: Arc<RwLock<Vec<MemoryStoreEntity>>>,
    memories: Arc<RwLock<Vec<Memory>>>,
}

impl InMemoryMemoryStore {
    /// Creates a backend with no stores and no memories.
    pub fn new() -> Self {
        Self {
            stores: Arc::default(),
            memories: Arc::default(),
        }
    }
}
#[async_trait]
impl MemoryStoreBackend for InMemoryMemoryStore {
    /// Returns the org's default store, creating one named `"default"` on first use.
    ///
    /// The write lock is held across the lookup and the insert. Two concurrent
    /// callers therefore cannot both create a default store for the same org.
    async fn get_or_create_default_store(&self, org_id: OrgId) -> Result<MemoryStoreEntity> {
        let mut stores = self.stores.write().await;
        if let Some(store) = stores.iter().find(|s| s.org_id == org_id && s.is_default) {
            return Ok(store.clone());
        }
        let store = MemoryStoreEntity {
            id: MemoryStoreId::new(),
            org_id,
            name: "default".to_string(),
            is_default: true,
            created_at: chrono::Utc::now(),
        };
        stores.push(store.clone());
        Ok(store)
    }

    /// Returns a copy of the store with `store_id`, if any.
    async fn get_store(&self, store_id: MemoryStoreId) -> Result<Option<MemoryStoreEntity>> {
        Ok(self
            .stores
            .read()
            .await
            .iter()
            .find(|s| s.id == store_id)
            .cloned())
    }

    /// Creates an active memory in `store_id` and returns it.
    /// `importance` is clamped to `1..=10`.
    async fn create_memory(
        &self,
        store_id: MemoryStoreId,
        content: String,
        content_parts: Vec<MemoryContentPart>,
        kind: MemoryKind,
        importance: u8,
        tags: Vec<String>,
    ) -> Result<Memory> {
        let now = chrono::Utc::now();
        let memory = Memory {
            id: MemoryId::new(),
            store_id,
            content,
            content_parts,
            kind,
            importance: importance.clamp(1, 10),
            tags,
            active: true,
            created_at: now,
            updated_at: now,
        };
        self.memories.write().await.push(memory.clone());
        Ok(memory)
    }

    /// Returns up to `query.limit` active memories matching every supplied
    /// criterion, plus the total number of matches before the limit was applied.
    ///
    /// The criteria are: store, kind, all listed tags present, and a
    /// case-insensitive substring match of `query.query` against the content or
    /// any tag. Results are ordered by importance (descending), then by newest
    /// first. A `limit` of 0 means 10.
    async fn recall(&self, query: MemoryQuery) -> Result<(Vec<Memory>, usize)> {
        // Lower-case the search text once, rather than once per candidate memory.
        let needle = query.query.as_ref().map(|q| q.to_lowercase());
        let memories = self.memories.read().await;
        let mut results: Vec<&Memory> = memories
            .iter()
            .filter(|m| m.active)
            .filter(|m| {
                if let Some(ref sid) = query.store_id {
                    m.store_id == *sid
                } else {
                    true
                }
            })
            .filter(|m| {
                if let Some(ref kind) = query.kind {
                    m.kind == *kind
                } else {
                    true
                }
            })
            .filter(|m| {
                if let Some(ref tags) = query.tags {
                    tags.iter().all(|t| m.tags.contains(t))
                } else {
                    true
                }
            })
            .filter(|m| match needle.as_deref() {
                Some(q_lower) => {
                    m.content.to_lowercase().contains(q_lower)
                        || m.tags.iter().any(|t| t.to_lowercase().contains(q_lower))
                }
                None => true,
            })
            .collect();
        // Stable sort: ties on (importance, created_at) keep their insertion order.
        results.sort_by(|a, b| {
            b.importance
                .cmp(&a.importance)
                .then_with(|| b.created_at.cmp(&a.created_at))
        });
        let total = results.len();
        let limit = if query.limit > 0 { query.limit } else { 10 };
        let results: Vec<Memory> = results.into_iter().take(limit).cloned().collect();
        Ok((results, total))
    }

    /// Soft-deletes an active memory by marking it inactive and bumping `updated_at`.
    /// Returns `false` if no active memory with that ID exists in `store_id`.
    async fn forget(&self, store_id: MemoryStoreId, memory_id: MemoryId) -> Result<bool> {
        let mut memories = self.memories.write().await;
        if let Some(m) = memories
            .iter_mut()
            .find(|m| m.id == memory_id && m.store_id == store_id && m.active)
        {
            m.active = false;
            m.updated_at = chrono::Utc::now();
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Number of active memories in `store_id`.
    async fn count_active(&self, store_id: MemoryStoreId) -> Result<usize> {
        Ok(self
            .memories
            .read()
            .await
            .iter()
            .filter(|m| m.store_id == store_id && m.active)
            .count())
    }
}