1use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use crate::session::{ProtocolViolation, SenderHandle, SessionId, SharedSessionRegistry};
6use std::future::Future;
7use tokio::sync::{mpsc, oneshot};
8
9#[cfg(not(target_arch = "wasm32"))]
10use tokio::task::JoinHandle;
11
#[cfg(not(target_arch = "wasm32"))]
/// Handle to an agent started by `spawn` / `spawn_with_llm_config` on a
/// native (non-wasm) target.
///
/// Await [`AgentHandle::result`] for the agent body's return value, and use
/// [`AgentHandle::send`] / [`AgentHandle::send_message`] to deliver messages
/// to the running agent.
pub struct AgentHandle<T> {
    /// Tokio task running the agent body; resolves to the body's result.
    join: JoinHandle<SageResult<T>>,
    /// Sending half of the agent's inbound message channel.
    message_tx: mpsc::Sender<Message>,
}
24
#[cfg(target_arch = "wasm32")]
/// Handle to an agent started by `spawn` / `spawn_with_llm_config` on wasm32.
///
/// The agent body runs as a local task; its result is delivered through a
/// oneshot channel rather than a join handle.
pub struct AgentHandle<T> {
    /// Receives the agent body's result once the local task finishes.
    result_rx: oneshot::Receiver<SageResult<T>>,
    /// Sending half of the agent's inbound message channel.
    message_tx: mpsc::Sender<Message>,
}
34
35#[cfg(not(target_arch = "wasm32"))]
40impl<T> AgentHandle<T> {
41 pub async fn result(self) -> SageResult<T> {
43 self.join.await?
44 }
45}
46
#[cfg(target_arch = "wasm32")]
impl<T> AgentHandle<T> {
    /// Waits for the local agent task to report its result.
    ///
    /// # Errors
    /// Returns the agent body's own error, or `SageError::Agent` if the task
    /// was dropped before it could send a result.
    pub async fn result(self) -> SageResult<T> {
        match self.result_rx.await {
            Ok(agent_outcome) => agent_outcome,
            Err(_) => Err(SageError::Agent("Agent task dropped".to_string())),
        }
    }
}
56
57impl<T> AgentHandle<T> {
62 pub async fn send<M>(&self, msg: M) -> SageResult<()>
66 where
67 M: serde::Serialize,
68 {
69 let message = Message::new(msg)?;
70 self.message_tx
71 .send(message)
72 .await
73 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
74 }
75
76 pub async fn send_message(&self, message: Message) -> SageResult<()> {
81 self.message_tx
82 .send(message)
83 .await
84 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
85 }
86}
87
#[derive(Debug, Clone)]
/// A message delivered to an agent's mailbox.
pub struct Message {
    /// Message body, already converted to JSON.
    pub payload: serde_json::Value,
    /// Protocol session this message belongs to, if it was sent within one.
    pub session_id: Option<SessionId>,
    /// Handle that `AgentContext::reply` uses to answer this message, if any.
    pub sender: Option<SenderHandle>,
    /// Optional message type name (set via `with_session` / `with_type_name`).
    pub type_name: Option<String>,
}
100
101impl Message {
102 pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
104 Ok(Self {
105 payload: serde_json::to_value(value)?,
106 session_id: None,
107 sender: None,
108 type_name: None,
109 })
110 }
111
112 pub fn with_session<T: serde::Serialize>(
114 value: T,
115 session_id: SessionId,
116 sender: SenderHandle,
117 type_name: impl Into<String>,
118 ) -> SageResult<Self> {
119 Ok(Self {
120 payload: serde_json::to_value(value)?,
121 session_id: Some(session_id),
122 sender: Some(sender),
123 type_name: Some(type_name.into()),
124 })
125 }
126
127 #[must_use]
129 pub fn with_type_name(mut self, type_name: impl Into<String>) -> Self {
130 self.type_name = Some(type_name.into());
131 self
132 }
133}
134
/// Execution context handed to an agent body by `spawn`.
///
/// Gives the agent access to the LLM, its inbound mailbox, result emission
/// and protocol-session bookkeeping.
pub struct AgentContext<T> {
    /// LLM client used by `infer` / `infer_string`.
    pub llm: LlmClient,
    /// Channel for the first value passed to `emit`; `None` once consumed.
    result_tx: Option<oneshot::Sender<T>>,
    /// Receiving half of the agent's inbound message channel.
    message_rx: mpsc::Receiver<Message>,
    /// Whether `emit` has already been called.
    emitted: bool,
    /// Message currently being handled; the target of `reply`.
    current_message: Option<Message>,
    /// Registry of protocol sessions this agent participates in.
    session_registry: SharedSessionRegistry,
    /// Role recorded via `set_role`.
    agent_role: Option<String>,
}
154
155impl<T> AgentContext<T> {
156 fn new(
158 llm: LlmClient,
159 result_tx: oneshot::Sender<T>,
160 message_rx: mpsc::Receiver<Message>,
161 session_registry: SharedSessionRegistry,
162 ) -> Self {
163 Self {
164 llm,
165 result_tx: Some(result_tx),
166 message_rx,
167 emitted: false,
168 current_message: None,
169 session_registry,
170 agent_role: None,
171 }
172 }
173
174 pub fn set_role(&mut self, role: impl Into<String>) {
176 self.agent_role = Some(role.into());
177 }
178
179 #[must_use]
181 pub fn session_registry(&self) -> &SharedSessionRegistry {
182 &self.session_registry
183 }
184
185 pub fn emit(&mut self, value: T) -> SageResult<T>
190 where
191 T: Clone,
192 {
193 if self.emitted {
194 return Ok(value);
196 }
197 self.emitted = true;
198 if let Some(tx) = self.result_tx.take() {
199 let _ = tx.send(value.clone());
201 }
202 Ok(value)
203 }
204
205 pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
207 where
208 R: serde::de::DeserializeOwned,
209 {
210 self.llm.infer(prompt).await
211 }
212
213 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
215 self.llm.infer_string(prompt).await
216 }
217
218 pub async fn receive<M>(&mut self) -> SageResult<M>
223 where
224 M: serde::de::DeserializeOwned,
225 {
226 let msg = self
227 .message_rx
228 .recv()
229 .await
230 .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
231
232 self.current_message = Some(msg.clone());
234
235 serde_json::from_value(msg.payload)
236 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
237 }
238
239 #[cfg(not(target_arch = "wasm32"))]
243 pub async fn receive_timeout<M>(
244 &mut self,
245 timeout: std::time::Duration,
246 ) -> SageResult<Option<M>>
247 where
248 M: serde::de::DeserializeOwned,
249 {
250 match tokio::time::timeout(timeout, self.message_rx.recv()).await {
251 Ok(Some(msg)) => {
252 self.current_message = Some(msg.clone());
254
255 let value = serde_json::from_value(msg.payload)
256 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
257 Ok(Some(value))
258 }
259 Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
260 Err(_) => Ok(None), }
262 }
263
264 #[cfg(target_arch = "wasm32")]
268 pub async fn receive_timeout<M>(
269 &mut self,
270 timeout: std::time::Duration,
271 ) -> SageResult<Option<M>>
272 where
273 M: serde::de::DeserializeOwned,
274 {
275 use futures::future::{select, Either};
276 use std::pin::pin;
277
278 let recv_fut = pin!(self.message_rx.recv());
279 let sleep_fut = pin!(sage_runtime_web::sleep(timeout));
280
281 match select(recv_fut, sleep_fut).await {
282 Either::Left((Some(msg), _)) => {
283 self.current_message = Some(msg.clone());
284 let value = serde_json::from_value(msg.payload)
285 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
286 Ok(Some(value))
287 }
288 Either::Left((None, _)) => {
289 Err(SageError::Agent("Message channel closed".to_string()))
290 }
291 Either::Right((_, _)) => Ok(None), }
293 }
294
295 pub async fn receive_raw(&mut self) -> SageResult<Message> {
300 let msg = self
301 .message_rx
302 .recv()
303 .await
304 .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
305
306 self.current_message = Some(msg.clone());
308
309 Ok(msg)
310 }
311
312 pub fn set_current_message(&mut self, msg: Message) {
316 self.current_message = Some(msg);
317 }
318
319 pub fn clear_current_message(&mut self) {
321 self.current_message = None;
322 }
323
324 pub async fn reply<M: serde::Serialize>(&mut self, msg: M) -> SageResult<()> {
334 let current = self
335 .current_message
336 .as_ref()
337 .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
338
339 let sender = current
340 .sender
341 .as_ref()
342 .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
343
344 sender.send(msg).await
345 }
346
347 pub async fn reply_with_protocol<M: serde::Serialize>(
349 &mut self,
350 msg: M,
351 msg_type: &str,
352 role: &str,
353 ) -> SageResult<()> {
354 let current = self
355 .current_message
356 .as_ref()
357 .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
358
359 if let Some(session_id) = current.session_id {
361 let mut registry = self.session_registry.write().await;
362 if let Some(session) = registry.get_mut(&session_id) {
363 if !session.state.can_send(msg_type, role) {
365 return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
366 protocol: session.protocol.clone(),
367 expected: "valid reply".to_string(),
368 received: msg_type.to_string(),
369 state: session.state.state_name().to_string(),
370 }));
371 }
372 session.state.transition(msg_type)?;
374 }
375 }
376
377 let sender = current
378 .sender
379 .as_ref()
380 .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
381
382 sender.send(msg).await
383 }
384
385 pub async fn validate_protocol_receive(
387 &mut self,
388 msg_type: &str,
389 role: &str,
390 ) -> SageResult<()> {
391 let current = match &self.current_message {
392 Some(msg) => msg,
393 None => return Ok(()), };
395
396 if let Some(session_id) = current.session_id {
398 let mut registry = self.session_registry.write().await;
399 if let Some(session) = registry.get_mut(&session_id) {
400 if !session.state.can_receive(msg_type, role) {
402 return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
403 protocol: session.protocol.clone(),
404 expected: "valid message for current state".to_string(),
405 received: msg_type.to_string(),
406 state: session.state.state_name().to_string(),
407 }));
408 }
409 session.state.transition(msg_type)?;
411
412 if session.state.is_terminal() {
414 drop(registry);
415 self.session_registry.write().await.remove(&session_id);
416 }
417 }
418 }
419
420 Ok(())
421 }
422
423 pub async fn start_session(
425 &self,
426 protocol: String,
427 role: String,
428 state: Box<dyn crate::session::ProtocolStateMachine>,
429 partner: SenderHandle,
430 ) -> SessionId {
431 let mut registry = self.session_registry.write().await;
432 let session_id = registry.next_id();
433 registry.start_session(session_id, protocol, role, state, partner);
434 session_id
435 }
436
437 #[must_use]
439 pub fn current_message(&self) -> Option<&Message> {
440 self.current_message.as_ref()
441 }
442}
443
444#[cfg(not(target_arch = "wasm32"))]
452pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
453where
454 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
455 F: Future<Output = SageResult<T>> + Send,
456 T: Send + 'static,
457{
458 spawn_with_llm_config(agent, crate::llm::LlmConfig::from_env())
459}
460
461#[cfg(not(target_arch = "wasm32"))]
465pub fn spawn_with_llm_config<A, T, F>(agent: A, llm_config: crate::llm::LlmConfig) -> AgentHandle<T>
466where
467 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
468 F: Future<Output = SageResult<T>> + Send,
469 T: Send + 'static,
470{
471 let (result_tx, result_rx) = oneshot::channel();
472 let (message_tx, message_rx) = mpsc::channel(32);
473
474 let llm = LlmClient::new(llm_config);
475 let session_registry = crate::session::shared_registry();
476 let ctx = AgentContext::new(llm, result_tx, message_rx, session_registry);
477
478 let join = tokio::spawn(async move { agent(ctx).await });
479
480 drop(result_rx);
483
484 AgentHandle { join, message_tx }
485}
486
#[cfg(target_arch = "wasm32")]
/// Spawns `agent` as a local wasm task, using an LLM configuration read by
/// `LlmConfig::from_env`.
///
/// See `spawn_with_llm_config` for the details of the returned handle.
pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
where
    A: FnOnce(AgentContext<T>) -> F + 'static,
    F: Future<Output = SageResult<T>> + 'static,
    T: 'static,
{
    let env_config = crate::llm::LlmConfig::from_env();
    spawn_with_llm_config(agent, env_config)
}
504
#[cfg(target_arch = "wasm32")]
/// Spawns `agent` as a local wasm task with an explicit LLM configuration.
///
/// The agent body's result is forwarded through a oneshot channel that backs
/// the returned handle's `result`.
pub fn spawn_with_llm_config<A, T, F>(agent: A, llm_config: crate::llm::LlmConfig) -> AgentHandle<T>
where
    A: FnOnce(AgentContext<T>) -> F + 'static,
    F: Future<Output = SageResult<T>> + 'static,
    T: 'static,
{
    // Bounded mailbox: `AgentHandle::send` waits once this many messages are queued.
    const MAILBOX_CAPACITY: usize = 32;

    let (outcome_tx, outcome_rx) = oneshot::channel();
    // The emit receiver is kept only until this function returns; nothing reads it.
    let (emit_tx, _emit_rx) = oneshot::channel();
    let (message_tx, message_rx) = mpsc::channel(MAILBOX_CAPACITY);

    let ctx = AgentContext::new(
        LlmClient::new(llm_config),
        emit_tx,
        message_rx,
        crate::session::shared_registry(),
    );

    wasm_bindgen_futures::spawn_local(async move {
        let outcome = agent(ctx).await;
        // If the handle was dropped, nobody is waiting for the result.
        let _ = outcome_tx.send(outcome);
    });

    AgentHandle {
        result_rx: outcome_rx,
        message_tx,
    }
}
531
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    /// Structured payload used by the message-passing tests.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    #[tokio::test]
    async fn spawn_simple_agent() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let value = handle.result().await.expect("agent should succeed");
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move {
            let total: i64 = (1..=10).sum();
            ctx.emit(total)
        });

        let value = handle.result().await.expect("agent should succeed");
        assert_eq!(value, 55);
    }

    #[tokio::test]
    async fn agent_receives_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let task: TaskMessage = ctx.receive().await?;
            let summary = format!("Got task {}: {}", task.id, task.content);
            ctx.emit(summary)
        });

        let outgoing = TaskMessage {
            id: 42,
            content: "Hello".to_string(),
        };
        handle.send(outgoing).await.expect("send should succeed");

        let value = handle.result().await.expect("agent should succeed");
        assert_eq!(value, "Got task 42: Hello");
    }

    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let handle = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut total = 0;
            for _ in 0..3 {
                total += ctx.receive::<i32>().await?;
            }
            ctx.emit(total)
        });

        for n in [10, 20, 30] {
            handle.send(n).await.expect("send should succeed");
        }

        let value = handle.result().await.expect("agent should succeed");
        assert_eq!(value, 60);
    }

    #[tokio::test]
    async fn agent_receive_timeout() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let received: Option<i32> = ctx.receive_timeout(Duration::from_millis(10)).await?;
            let label = match received {
                Some(n) => format!("Got {n}"),
                None => "Timeout".to_string(),
            };
            ctx.emit(label)
        });

        // Nothing is ever sent, so the agent must observe the timeout.
        let value = handle.result().await.expect("agent should succeed");
        assert_eq!(value, "Timeout");
    }
}