everruns_core/atoms/
input.rs1use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
13use super::{Atom, AtomContext};
14use crate::error::{AgentLoopError, Result};
15use crate::message::Message;
16use crate::message_retriever::MessageRetriever;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct InputAtomInput {
25 pub context: AtomContext,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct InputAtomResult {
32 pub message: Message,
34}
35
/// Atom that begins a turn by loading the input message named in its
/// [`AtomContext`].
///
/// Retrieval is delegated to the wrapped retriever `M`. The
/// `M: MessageRetriever` bound lives on the `impl` blocks that need it rather
/// than on the struct definition: struct bounds are not implied, so every impl
/// has to restate them anyway, and keeping the type itself unbounded is the
/// conventional Rust style.
pub struct InputAtom<M> {
    /// Source used to look up the turn's input message.
    message_retriever: M,
}
54
55impl<M> InputAtom<M>
56where
57 M: MessageRetriever,
58{
59 pub fn new(message_retriever: M) -> Self {
61 Self { message_retriever }
62 }
63}
64
65#[async_trait]
66impl<M> Atom for InputAtom<M>
67where
68 M: MessageRetriever + Send + Sync,
69{
70 type Input = InputAtomInput;
71 type Output = InputAtomResult;
72
73 fn name(&self) -> &'static str {
74 "input"
75 }
76
77 async fn execute(&self, input: Self::Input) -> Result<Self::Output> {
78 let InputAtomInput { context } = input;
79
80 tracing::debug!(
81 session_id = %context.session_id,
82 turn_id = %context.turn_id,
83 input_message_id = %context.input_message_id,
84 exec_id = %context.exec_id,
85 "InputAtom: retrieving user message"
86 );
87
88 let message = self
90 .message_retriever
91 .get(context.session_id, context.input_message_id)
92 .await?
93 .ok_or_else(|| {
94 AgentLoopError::store(format!(
95 "User message not found: {}",
96 context.input_message_id
97 ))
98 })?;
99
100 tracing::info!(
101 session_id = %context.session_id,
102 turn_id = %context.turn_id,
103 message_id = %message.id,
104 "InputAtom: turn started with user message"
105 );
106
107 Ok(InputAtomResult { message })
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::InMemoryMessageRetriever;
    use crate::message_retriever::InputMessage;
    use crate::typed_id::{MessageId, SessionId, TurnId};

    /// A stored user message is returned unchanged for its id.
    #[tokio::test]
    async fn test_input_atom_retrieves_message() {
        let retriever = InMemoryMessageRetriever::new();
        let session_id = SessionId::new();
        let turn_id = TurnId::new();

        let user_message = retriever
            .add(session_id, InputMessage::user("Hello, world!"))
            .await
            .unwrap();

        let context = AtomContext::new(session_id, turn_id, user_message.id);
        let atom = InputAtom::new(retriever);

        let result = atom.execute(InputAtomInput { context }).await.unwrap();

        assert_eq!(result.message.id, user_message.id);
        assert_eq!(result.message.text(), Some("Hello, world!"));
    }

    /// An unknown message id fails, and the error names the missing id.
    #[tokio::test]
    async fn test_input_atom_not_found() {
        let retriever = InMemoryMessageRetriever::new();
        let session_id = SessionId::new();
        let turn_id = TurnId::new();
        let missing_id = MessageId::new();

        let context = AtomContext::new(session_id, turn_id, missing_id);
        let atom = InputAtom::new(retriever);

        let err = atom
            .execute(InputAtomInput { context })
            .await
            .expect_err("executing with an unknown message id must fail");

        // The not-found error is built with the missing id in its text; make
        // sure it survives so failures are diagnosable.
        let rendered = format!("{:?}", err);
        assert!(
            rendered.contains(&missing_id.to_string()),
            "error should mention the missing message id, got: {}",
            rendered
        );
    }
}