1use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use std::future::Future;
6use tokio::sync::{mpsc, oneshot};
7use tokio::task::JoinHandle;
8
9pub struct AgentHandle<T> {
13 join: JoinHandle<SageResult<T>>,
14 message_tx: mpsc::Sender<Message>,
15}
16
17impl<T> AgentHandle<T> {
18 pub async fn result(self) -> SageResult<T> {
20 self.join.await?
21 }
22
23 pub async fn send<M>(&self, msg: M) -> SageResult<()>
27 where
28 M: serde::Serialize,
29 {
30 let message = Message::new(msg)?;
31 self.message_tx
32 .send(message)
33 .await
34 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct Message {
41 pub payload: serde_json::Value,
43}
44
45impl Message {
46 pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
48 Ok(Self {
49 payload: serde_json::to_value(value)?,
50 })
51 }
52}
53
54pub struct AgentContext<T> {
58 pub llm: LlmClient,
60 result_tx: Option<oneshot::Sender<T>>,
62 message_rx: mpsc::Receiver<Message>,
64 emitted: bool,
66}
67
68impl<T> AgentContext<T> {
69 fn new(
71 llm: LlmClient,
72 result_tx: oneshot::Sender<T>,
73 message_rx: mpsc::Receiver<Message>,
74 ) -> Self {
75 Self {
76 llm,
77 result_tx: Some(result_tx),
78 message_rx,
79 emitted: false,
80 }
81 }
82
83 pub fn emit(&mut self, value: T) -> SageResult<T>
88 where
89 T: Clone,
90 {
91 if self.emitted {
92 return Ok(value);
94 }
95 self.emitted = true;
96 if let Some(tx) = self.result_tx.take() {
97 let _ = tx.send(value.clone());
99 }
100 Ok(value)
101 }
102
103 pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
105 where
106 R: serde::de::DeserializeOwned,
107 {
108 self.llm.infer(prompt).await
109 }
110
111 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
113 self.llm.infer_string(prompt).await
114 }
115
116 pub async fn receive<M>(&mut self) -> SageResult<M>
121 where
122 M: serde::de::DeserializeOwned,
123 {
124 let msg = self
125 .message_rx
126 .recv()
127 .await
128 .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
129
130 serde_json::from_value(msg.payload)
131 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
132 }
133
134 pub async fn receive_timeout<M>(
138 &mut self,
139 timeout: std::time::Duration,
140 ) -> SageResult<Option<M>>
141 where
142 M: serde::de::DeserializeOwned,
143 {
144 match tokio::time::timeout(timeout, self.message_rx.recv()).await {
145 Ok(Some(msg)) => {
146 let value = serde_json::from_value(msg.payload)
147 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
148 Ok(Some(value))
149 }
150 Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
151 Err(_) => Ok(None), }
153 }
154}
155
156pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
160where
161 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
162 F: Future<Output = SageResult<T>> + Send,
163 T: Send + 'static,
164{
165 let (result_tx, result_rx) = oneshot::channel();
166 let (message_tx, message_rx) = mpsc::channel(32);
167
168 let llm = LlmClient::from_env();
169 let ctx = AgentContext::new(llm, result_tx, message_rx);
170
171 let join = tokio::spawn(async move { agent(ctx).await });
172
173 drop(result_rx);
176
177 AgentHandle { join, message_tx }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Structured payload used to exercise typed message delivery.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    /// An agent that emits a constant yields that constant as its result.
    #[tokio::test]
    async fn spawn_simple_agent() {
        let agent = spawn(|mut ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let outcome = agent.result().await.expect("agent should succeed");
        assert_eq!(outcome, 42);
    }

    /// An agent may compute its result before emitting it.
    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let agent = spawn(|mut ctx: AgentContext<i64>| async move {
            let total: i64 = (1..=10).sum();
            ctx.emit(total)
        });

        let outcome = agent.result().await.expect("agent should succeed");
        assert_eq!(outcome, 55);
    }

    /// A structured message sent through the handle reaches the agent intact.
    #[tokio::test]
    async fn agent_receives_message() {
        let agent = spawn(|mut ctx: AgentContext<String>| async move {
            let task: TaskMessage = ctx.receive().await?;
            ctx.emit(format!("Got task {}: {}", task.id, task.content))
        });

        let task = TaskMessage {
            id: 42,
            content: "Hello".to_string(),
        };
        agent.send(task).await.expect("send should succeed");

        let outcome = agent.result().await.expect("agent should succeed");
        assert_eq!(outcome, "Got task 42: Hello");
    }

    /// Several messages sent in a row are all delivered to the agent.
    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let agent = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut total = 0;
            for _ in 0..3 {
                total += ctx.receive::<i32>().await?;
            }
            ctx.emit(total)
        });

        for value in [10, 20, 30] {
            agent.send(value).await.expect("send should succeed");
        }

        let outcome = agent.result().await.expect("agent should succeed");
        assert_eq!(outcome, 60);
    }

    /// With no message sent, `receive_timeout` reports a timeout, not an error.
    #[tokio::test]
    async fn agent_receive_timeout() {
        let agent = spawn(|mut ctx: AgentContext<String>| async move {
            let wait = std::time::Duration::from_millis(10);
            let received: Option<i32> = ctx.receive_timeout(wait).await?;
            let summary = match received {
                Some(n) => format!("Got {n}"),
                None => "Timeout".to_string(),
            };
            ctx.emit(summary)
        });

        let outcome = agent.result().await.expect("agent should succeed");
        assert_eq!(outcome, "Timeout");
    }
}