1use crate::agent::Agent;
9use crate::credential_provider::CredentialProvider;
10use crate::harness::Harness;
11use crate::provider::DriverId;
12use crate::session::Session;
13use crate::tool_types::{ToolCall, ToolDefinition, ToolResult};
14use crate::traits::ResolvedModel;
15use crate::typed_id::{AgentId, EventId, HarnessId, MessageId, ModelId, SessionId};
16use async_trait::async_trait;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20use uuid::Uuid;
21
22use crate::error::Result;
23use crate::message::Message;
24use crate::message_filter::MessageQuery;
25use crate::message_retriever::{InputMessage, MessageRetriever};
26use crate::traits::{AgentStore, HarnessStore, ProviderStore, SessionStore, ToolExecutor};
27use chrono::Utc;
28
29#[derive(Debug, Default, Clone)]
41pub struct InMemoryMessageRetriever {
42 messages: Arc<RwLock<HashMap<SessionId, Vec<Message>>>>,
43}
44
45impl InMemoryMessageRetriever {
46 pub fn new() -> Self {
48 Self {
49 messages: Arc::new(RwLock::new(HashMap::new())),
50 }
51 }
52
53 pub async fn sessions(&self) -> Vec<SessionId> {
55 self.messages.read().await.keys().copied().collect()
56 }
57
58 pub async fn clear(&self) {
60 self.messages.write().await.clear();
61 }
62
63 pub async fn clear_session(&self, session_id: SessionId) {
65 self.messages.write().await.remove(&session_id);
66 }
67
68 pub async fn seed(&self, session_id: SessionId, messages: Vec<Message>) {
70 self.messages.write().await.insert(session_id, messages);
71 }
72
73 pub async fn add(&self, session_id: SessionId, input: InputMessage) -> Result<Message> {
78 let message = Message {
79 id: MessageId::new(),
80 role: input.role,
81 content: input.content,
82 phase: None,
83 thinking: None, thinking_signature: None,
85 controls: input.controls,
86 metadata: input.metadata,
87 external_actor: None,
88 created_at: Utc::now(),
89 };
90
91 self.messages
92 .write()
93 .await
94 .entry(session_id)
95 .or_default()
96 .push(message.clone());
97
98 Ok(message)
99 }
100
101 pub async fn store(&self, session_id: SessionId, message: Message) -> Result<()> {
106 self.messages
107 .write()
108 .await
109 .entry(session_id)
110 .or_default()
111 .push(message);
112 Ok(())
113 }
114}
115
116#[async_trait]
117impl MessageRetriever for InMemoryMessageRetriever {
118 async fn get(&self, session_id: SessionId, message_id: MessageId) -> Result<Option<Message>> {
119 Ok(self
120 .messages
121 .read()
122 .await
123 .get(&session_id)
124 .and_then(|messages| messages.iter().find(|m| m.id == message_id).cloned()))
125 }
126
127 async fn load(&self, session_id: SessionId) -> Result<Vec<Message>> {
128 Ok(self
129 .messages
130 .read()
131 .await
132 .get(&session_id)
133 .cloned()
134 .unwrap_or_default())
135 }
136
137 async fn load_filtered(&self, query: MessageQuery) -> Result<Vec<Message>> {
138 use crate::message_filter::MessageFilter;
139
140 let mut messages = self.load(query.session_id).await?;
141
142 for filter in &query.filters {
144 match filter {
145 MessageFilter::TimeRange { from, to } => {
146 messages.retain(|m| {
147 let after_from = from.is_none_or(|t| m.created_at >= t);
148 let before_to = to.is_none_or(|t| m.created_at <= t);
149 after_from && before_to
150 });
151 }
152 MessageFilter::Search(q) => {
153 let q_lower = q.to_lowercase();
154 messages.retain(|m| {
155 m.text()
156 .is_some_and(|t| t.to_lowercase().contains(&q_lower))
157 });
158 }
159 MessageFilter::Custom(predicate) => {
160 messages.retain(|m| predicate(m));
161 }
162 _ => {}
164 }
165 }
166
167 query.apply_windowing(&mut messages);
168
169 if query.has_injections() {
171 query.apply_injections(&mut messages);
172 }
173
174 Ok(messages)
175 }
176
177 async fn count(&self, session_id: SessionId) -> Result<usize> {
178 Ok(self
179 .messages
180 .read()
181 .await
182 .get(&session_id)
183 .map(|m| m.len())
184 .unwrap_or(0))
185 }
186}
187
188#[derive(Debug, Default, Clone)]
197pub struct InMemoryAgentStore {
198 agents: Arc<RwLock<HashMap<AgentId, Agent>>>,
199}
200
201impl InMemoryAgentStore {
202 pub fn new() -> Self {
204 Self {
205 agents: Arc::new(RwLock::new(HashMap::new())),
206 }
207 }
208
209 pub async fn add_agent(&self, agent: Agent) {
211 self.agents.write().await.insert(agent.public_id, agent);
212 }
213
214 pub async fn agent_ids(&self) -> Vec<AgentId> {
216 self.agents.read().await.keys().copied().collect()
217 }
218
219 pub async fn clear(&self) {
221 self.agents.write().await.clear();
222 }
223}
224
225#[async_trait]
226impl AgentStore for InMemoryAgentStore {
227 async fn get_agent(&self, agent_id: AgentId) -> Result<Option<Agent>> {
228 Ok(self.agents.read().await.get(&agent_id).cloned())
229 }
230}
231
232#[derive(Debug, Default, Clone)]
241pub struct InMemoryHarnessStore {
242 harnesses: Arc<RwLock<HashMap<HarnessId, Harness>>>,
243}
244
245impl InMemoryHarnessStore {
246 pub fn new() -> Self {
248 Self {
249 harnesses: Arc::new(RwLock::new(HashMap::new())),
250 }
251 }
252
253 pub async fn add_harness(&self, harness: Harness) {
255 self.harnesses.write().await.insert(harness.id, harness);
256 }
257}
258
259#[async_trait]
260impl HarnessStore for InMemoryHarnessStore {
261 async fn get_harness_chain(&self, harness_id: HarnessId) -> Result<Vec<Harness>> {
262 Ok(self
263 .harnesses
264 .read()
265 .await
266 .get(&harness_id)
267 .cloned()
268 .into_iter()
269 .collect())
270 }
271}
272
273#[derive(Debug, Default, Clone)]
282pub struct InMemorySessionStore {
283 sessions: Arc<RwLock<HashMap<SessionId, Session>>>,
284}
285
286impl InMemorySessionStore {
287 pub fn new() -> Self {
289 Self {
290 sessions: Arc::new(RwLock::new(HashMap::new())),
291 }
292 }
293
294 pub async fn add_session(&self, session: Session) {
296 self.sessions.write().await.insert(session.id, session);
297 }
298
299 pub async fn session_ids(&self) -> Vec<SessionId> {
301 self.sessions.read().await.keys().copied().collect()
302 }
303
304 pub async fn clear(&self) {
306 self.sessions.write().await.clear();
307 }
308}
309
310#[async_trait]
311impl SessionStore for InMemorySessionStore {
312 async fn get_session(&self, session_id: SessionId) -> Result<Option<Session>> {
313 Ok(self.sessions.read().await.get(&session_id).cloned())
314 }
315}
316
317#[derive(Debug, Default, Clone)]
336pub struct InMemoryProviderStore {
337 models: Arc<RwLock<HashMap<ModelId, ResolvedModel>>>,
338 default_model: Arc<RwLock<Option<ResolvedModel>>>,
339}
340
341impl InMemoryProviderStore {
342 pub fn new() -> Self {
344 Self {
345 models: Arc::new(RwLock::new(HashMap::new())),
346 default_model: Arc::new(RwLock::new(None)),
347 }
348 }
349
350 pub async fn from_credential_provider(provider: &dyn CredentialProvider) -> Self {
358 let store = Self::new();
359
360 if let Some(creds) = provider
362 .resolve(&DriverId::OpenAI)
363 .filter(|c| c.api_key.is_some())
364 {
365 store
366 .set_default_model(ResolvedModel {
367 model: "gpt-5.4".to_string(),
368 provider_type: DriverId::OpenAI,
369 api_key: creds.api_key,
370 base_url: creds.base_url,
371 provider_metadata: None,
372 })
373 .await;
374 } else if let Some(creds) = provider
375 .resolve(&DriverId::Anthropic)
376 .filter(|c| c.api_key.is_some())
377 {
378 store
379 .set_default_model(ResolvedModel {
380 model: "claude-sonnet-4-20250514".to_string(),
381 provider_type: DriverId::Anthropic,
382 api_key: creds.api_key,
383 base_url: creds.base_url,
384 provider_metadata: None,
385 })
386 .await;
387 }
388
389 store
390 }
391
392 pub async fn with_default(model: ResolvedModel) -> Self {
394 let store = Self::new();
395 store.set_default_model(model).await;
396 store
397 }
398
399 pub async fn add_model(&self, model_id: ModelId, model: ResolvedModel) {
401 self.models.write().await.insert(model_id, model);
402 }
403
404 pub async fn set_default_model(&self, model: ResolvedModel) {
406 *self.default_model.write().await = Some(model);
407 }
408
409 pub async fn clear(&self) {
411 self.models.write().await.clear();
412 *self.default_model.write().await = None;
413 }
414}
415
416#[async_trait]
417impl ProviderStore for InMemoryProviderStore {
418 async fn get_resolved_model(&self, model_id: ModelId) -> Result<Option<ResolvedModel>> {
419 Ok(self.models.read().await.get(&model_id).cloned())
420 }
421
422 async fn get_default_model(&self) -> Result<Option<ResolvedModel>> {
423 Ok(self.default_model.read().await.clone())
424 }
425}
426
427#[derive(Debug, Default)]
435pub struct MockToolExecutor {
436 results: Arc<RwLock<HashMap<String, serde_json::Value>>>,
437 call_log: Arc<RwLock<Vec<ToolCall>>>,
438}
439
440impl MockToolExecutor {
441 pub fn new() -> Self {
443 Self {
444 results: Arc::new(RwLock::new(HashMap::new())),
445 call_log: Arc::new(RwLock::new(Vec::new())),
446 }
447 }
448
449 pub async fn set_result(&self, tool_name: impl Into<String>, result: serde_json::Value) {
451 self.results.write().await.insert(tool_name.into(), result);
452 }
453
454 pub async fn calls(&self) -> Vec<ToolCall> {
456 self.call_log.read().await.clone()
457 }
458
459 pub async fn clear_calls(&self) {
461 self.call_log.write().await.clear();
462 }
463}
464
465#[async_trait]
466impl ToolExecutor for MockToolExecutor {
467 async fn execute(
468 &self,
469 tool_call: &ToolCall,
470 _tool_def: &ToolDefinition,
471 ) -> Result<ToolResult> {
472 self.call_log.write().await.push(tool_call.clone());
474
475 let result = self
477 .results
478 .read()
479 .await
480 .get(&tool_call.name)
481 .cloned()
482 .unwrap_or_else(|| serde_json::json!({"status": "ok"}));
483
484 Ok(ToolResult {
485 tool_call_id: tool_call.id.clone(),
486 result: Some(result),
487 images: None,
488 error: None,
489 connection_required: None,
490 raw_output: None,
491 })
492 }
493}
494
/// Stateless [`ToolExecutor`] that reflects each call back to the caller.
///
/// The result JSON holds the tool name under `"echoed_tool"` and the raw
/// arguments under `"echoed_arguments"`.
#[derive(Debug, Default, Clone, Copy)]
pub struct EchoToolExecutor;

impl EchoToolExecutor {
    /// Creates the echo executor (it carries no state).
    pub fn new() -> Self {
        Self::default()
    }
}
510
511#[async_trait]
512impl ToolExecutor for EchoToolExecutor {
513 async fn execute(
514 &self,
515 tool_call: &ToolCall,
516 _tool_def: &ToolDefinition,
517 ) -> Result<ToolResult> {
518 Ok(ToolResult {
519 tool_call_id: tool_call.id.clone(),
520 result: Some(serde_json::json!({
521 "echoed_tool": tool_call.name,
522 "echoed_arguments": tool_call.arguments
523 })),
524 images: None,
525 error: None,
526 connection_required: None,
527 raw_output: None,
528 })
529 }
530}
531
/// [`ToolExecutor`] whose every call reports a failure.
///
/// The configured message is placed in the `error` field of the returned
/// [`ToolResult`].
#[derive(Debug, Clone)]
pub struct FailingToolExecutor {
    error_message: String,
}

impl FailingToolExecutor {
    /// Creates an executor that reports `error_message` for every call.
    pub fn new(error_message: impl Into<String>) -> Self {
        let error_message = error_message.into();
        Self { error_message }
    }
}

impl Default for FailingToolExecutor {
    /// Uses the generic message `"Tool execution failed"`.
    fn default() -> Self {
        Self::new("Tool execution failed")
    }
}
557
558#[async_trait]
559impl ToolExecutor for FailingToolExecutor {
560 async fn execute(
561 &self,
562 tool_call: &ToolCall,
563 _tool_def: &ToolDefinition,
564 ) -> Result<ToolResult> {
565 Ok(ToolResult {
566 tool_call_id: tool_call.id.clone(),
567 result: None,
568 images: None,
569 error: Some(self.error_message.clone()),
570 connection_required: None,
571 raw_output: None,
572 })
573 }
574}
575
576use crate::driver_registry::{
581 ChatDriver, LlmCallConfig, LlmMessage, LlmResponseStream, LlmStreamEvent,
582};
583use crate::events::{Event, EventRequest};
584use crate::traits::EventEmitter;
585use futures::stream;
586
587#[derive(Debug, Default)]
591pub struct MockProvider {
592 responses: Arc<RwLock<Vec<MockLlmResponse>>>,
593 call_index: Arc<RwLock<usize>>,
594 call_log: Arc<RwLock<Vec<Vec<LlmMessage>>>>,
595}
596
597#[derive(Debug, Clone)]
599pub struct MockLlmResponse {
600 pub text: String,
601 pub tool_calls: Option<Vec<ToolCall>>,
602}
603
604impl MockLlmResponse {
605 pub fn text(text: impl Into<String>) -> Self {
607 Self {
608 text: text.into(),
609 tool_calls: None,
610 }
611 }
612
613 pub fn with_tools(text: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
615 Self {
616 text: text.into(),
617 tool_calls: Some(tool_calls),
618 }
619 }
620}
621
622impl MockProvider {
623 pub fn new() -> Self {
625 Self {
626 responses: Arc::new(RwLock::new(Vec::new())),
627 call_index: Arc::new(RwLock::new(0)),
628 call_log: Arc::new(RwLock::new(Vec::new())),
629 }
630 }
631
632 pub async fn add_response(&self, response: MockLlmResponse) {
634 self.responses.write().await.push(response);
635 }
636
637 pub async fn set_responses(&self, responses: Vec<MockLlmResponse>) {
639 *self.responses.write().await = responses;
640 *self.call_index.write().await = 0;
641 }
642
643 pub async fn calls(&self) -> Vec<Vec<LlmMessage>> {
645 self.call_log.read().await.clone()
646 }
647
648 pub async fn reset(&self) {
650 self.responses.write().await.clear();
651 *self.call_index.write().await = 0;
652 self.call_log.write().await.clear();
653 }
654}
655
656#[async_trait]
657impl ChatDriver for MockProvider {
658 async fn chat_completion_stream(
659 &self,
660 messages: Vec<LlmMessage>,
661 _config: &LlmCallConfig,
662 ) -> Result<LlmResponseStream> {
663 self.call_log.write().await.push(messages);
665
666 let mut index = self.call_index.write().await;
668 let responses = self.responses.read().await;
669
670 let response = responses.get(*index).cloned().unwrap_or_else(|| {
671 MockLlmResponse::text("Mock response (no more responses configured)")
672 });
673
674 *index += 1;
675 drop(index);
676 drop(responses);
677
678 let events = vec![
680 Ok(LlmStreamEvent::TextDelta(response.text.clone())),
681 if let Some(tool_calls) = response.tool_calls {
682 Ok(LlmStreamEvent::ToolCalls(tool_calls))
683 } else {
684 Ok(LlmStreamEvent::Done(Box::default()))
685 },
686 Ok(LlmStreamEvent::Done(Box::default())),
687 ];
688
689 Ok(Box::pin(stream::iter(events)))
690 }
691}
692
693#[derive(Debug, Default, Clone)]
716pub struct InMemoryEventEmitter {
717 events: Arc<RwLock<Vec<Event>>>,
718 sequence: Arc<RwLock<i32>>,
719}
720
721impl InMemoryEventEmitter {
722 pub fn new() -> Self {
724 Self {
725 events: Arc::new(RwLock::new(Vec::new())),
726 sequence: Arc::new(RwLock::new(0)),
727 }
728 }
729
730 pub async fn events(&self) -> Vec<Event> {
732 self.events.read().await.clone()
733 }
734
735 pub async fn event_count(&self) -> usize {
737 self.events.read().await.len()
738 }
739
740 pub async fn clear(&self) {
742 self.events.write().await.clear();
743 *self.sequence.write().await = 0;
744 }
745
746 pub async fn events_by_type(&self, event_type: &str) -> Vec<Event> {
748 self.events
749 .read()
750 .await
751 .iter()
752 .filter(|e| e.event_type == event_type)
753 .cloned()
754 .collect()
755 }
756
757 pub async fn events_for_session(&self, session_id: Uuid) -> Vec<Event> {
759 self.events
760 .read()
761 .await
762 .iter()
763 .filter(|e| e.session_uuid() == session_id)
764 .cloned()
765 .collect()
766 }
767}
768
769#[async_trait]
770impl EventEmitter for InMemoryEventEmitter {
771 async fn emit(&self, request: EventRequest) -> Result<Event> {
772 let mut sequence = self.sequence.write().await;
773 *sequence += 1;
774 let seq = *sequence;
775 drop(sequence);
776
777 let event = request.into_event(EventId::new(), seq);
779 self.events.write().await.push(event.clone());
780 Ok(event)
781 }
782}
783
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Returns a fresh, time-ordered session id so tests never share state.
    fn fresh_session() -> SessionId {
        Uuid::now_v7().into()
    }

    #[tokio::test]
    async fn test_in_memory_message_retriever() {
        let retriever = InMemoryMessageRetriever::new();
        let session = fresh_session();

        retriever
            .store(session, Message::user("Hello"))
            .await
            .expect("store should succeed");

        let history = retriever.load(session).await.expect("load should succeed");
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].text(), Some("Hello"));
    }

    #[tokio::test]
    async fn test_in_memory_message_retriever_add_and_get() {
        let retriever = InMemoryMessageRetriever::new();
        let session = fresh_session();

        let added = retriever
            .add(session, InputMessage::user("Hello via add"))
            .await
            .expect("add should succeed");

        let fetched = retriever
            .get(session, added.id)
            .await
            .expect("get should succeed")
            .expect("message added via add() should be retrievable");
        assert_eq!(fetched.text(), Some("Hello via add"));

        // An id that was never stored must not resolve to anything.
        let absent = retriever
            .get(session, MessageId::new())
            .await
            .expect("get should succeed");
        assert!(absent.is_none());
    }

    /// `add()` must hand back the same id under which the message is stored.
    #[tokio::test]
    async fn test_message_retriever_add_returns_consistent_id() {
        let retriever = InMemoryMessageRetriever::new();
        let session = fresh_session();

        let added = retriever
            .add(session, InputMessage::user("Test consistency"))
            .await
            .expect("add should succeed");

        let fetched = retriever
            .get(session, added.id)
            .await
            .expect("get should succeed")
            .expect("Message must be retrievable by the ID returned from add()");
        assert_eq!(
            fetched.id, added.id,
            "Retrieved message ID must match the ID returned from add()"
        );

        let history = retriever.load(session).await.expect("load should succeed");
        assert!(
            history.iter().any(|msg| msg.id == added.id),
            "Message with returned ID must appear in load() results"
        );
    }

    #[tokio::test]
    async fn test_mock_tool_executor() {
        let executor = MockToolExecutor::new();
        executor
            .set_result("get_weather", serde_json::json!({"temp": 72}))
            .await;

        let call = ToolCall {
            id: "call_1".to_string(),
            name: "get_weather".to_string(),
            arguments: serde_json::json!({"city": "NYC"}),
        };

        let definition = ToolDefinition::Builtin(crate::tool_types::BuiltinTool {
            name: "get_weather".to_string(),
            display_name: None,
            description: "Get weather".to_string(),
            parameters: serde_json::json!({}),
            policy: crate::tool_types::ToolPolicy::Auto,
            category: None,
            deferrable: crate::tool_types::DeferrablePolicy::default(),
            hints: crate::tool_types::ToolHints::default(),
            full_parameters: None,
        });

        let outcome = executor
            .execute(&call, &definition)
            .await
            .expect("mock executor should not fail");

        assert!(outcome.error.is_none());
        assert_eq!(outcome.result, Some(serde_json::json!({"temp": 72})));
    }

    #[tokio::test]
    async fn test_in_memory_event_emitter() {
        use crate::events::{EventContext, EventRequest, InputMessageData};

        let emitter = InMemoryEventEmitter::new();
        let session = fresh_session();
        let context = EventContext::empty();

        // Sequence numbers are handed out as 1, 2, ... in emission order.
        for (expected, text) in [(1, "test1"), (2, "test2")] {
            let event = emitter
                .emit(EventRequest::new(
                    session,
                    context.clone(),
                    InputMessageData::new(Message::user(text)),
                ))
                .await
                .expect("emit should succeed");
            assert_eq!(event.sequence, Some(expected));
        }

        assert_eq!(emitter.events().await.len(), 2);
        assert_eq!(emitter.event_count().await, 2);
    }

    #[tokio::test]
    async fn test_in_memory_event_emitter_filter_by_type() {
        use crate::events::{
            EventContext, EventRequest, INPUT_MESSAGE, InputMessageData, REASON_STARTED,
            ReasonStartedData,
        };

        let emitter = InMemoryEventEmitter::new();
        let session = fresh_session();
        let context = EventContext::empty();

        emitter
            .emit(EventRequest::new(
                session,
                context.clone(),
                InputMessageData::new(Message::user("test")),
            ))
            .await
            .expect("emit should succeed");

        emitter
            .emit(EventRequest::new(
                session,
                context,
                ReasonStartedData {
                    harness_id: HarnessId::from_seed(1),
                    agent_id: Some(AgentId::new()),
                    metadata: None,
                },
            ))
            .await
            .expect("emit should succeed");

        // One event of each kind was emitted, so each filter matches exactly one.
        assert_eq!(emitter.events_by_type(INPUT_MESSAGE).await.len(), 1);
        assert_eq!(emitter.events_by_type(REASON_STARTED).await.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_event_emitter_filter_by_session() {
        use crate::events::{EventContext, EventRequest, InputMessageData};

        let emitter = InMemoryEventEmitter::new();
        let first = fresh_session();
        let second = fresh_session();
        let context = EventContext::empty();

        for (session, text) in [(first, "session1"), (second, "session2")] {
            emitter
                .emit(EventRequest::new(
                    session,
                    context.clone(),
                    InputMessageData::new(Message::user(text)),
                ))
                .await
                .expect("emit should succeed");
        }

        // Each session received exactly one event.
        for session in [first, second] {
            assert_eq!(emitter.events_for_session(session.uuid()).await.len(), 1);
        }
    }

    #[tokio::test]
    async fn test_in_memory_event_emitter_clear() {
        use crate::events::{EventContext, EventRequest, InputMessageData};

        let emitter = InMemoryEventEmitter::new();

        emitter
            .emit(EventRequest::new(
                fresh_session(),
                EventContext::empty(),
                InputMessageData::new(Message::user("test")),
            ))
            .await
            .expect("emit should succeed");
        assert_eq!(emitter.event_count().await, 1);

        emitter.clear().await;
        assert_eq!(emitter.event_count().await, 0);
    }
}