Skip to main content

sage_runtime/
agent.rs

//! Agent spawning, messaging, and lifecycle management.
2
3use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use crate::session::{ProtocolViolation, SenderHandle, SessionId, SharedSessionRegistry};
6use std::future::Future;
7use tokio::sync::{mpsc, oneshot};
8
9#[cfg(not(target_arch = "wasm32"))]
10use tokio::task::JoinHandle;
11
// ---------------------------------------------------------------------------
// AgentHandle — platform-specific fields
// ---------------------------------------------------------------------------
15
16/// Handle to a spawned agent.
17///
18/// This is returned by `spawn()` and can be awaited to get the agent's result.
19#[cfg(not(target_arch = "wasm32"))]
20pub struct AgentHandle<T> {
21    join: JoinHandle<SageResult<T>>,
22    message_tx: mpsc::Sender<Message>,
23}
24
/// Owning handle for an agent spawned on the browser event loop (WASM variant).
///
/// `spawn_local` hands back no join handle, so the agent's final result is
/// delivered through a oneshot channel instead of a `JoinHandle`.
#[cfg(target_arch = "wasm32")]
pub struct AgentHandle<T> {
    /// Receives the agent future's output once the local task finishes.
    result_rx: oneshot::Receiver<SageResult<T>>,
    /// Sending half of the agent's mailbox.
    message_tx: mpsc::Sender<Message>,
}
34
// ---------------------------------------------------------------------------
// AgentHandle::result() — platform-specific
// ---------------------------------------------------------------------------
38
39#[cfg(not(target_arch = "wasm32"))]
40impl<T> AgentHandle<T> {
41    /// Wait for the agent to complete and return its result.
42    pub async fn result(self) -> SageResult<T> {
43        self.join.await?
44    }
45}
46
#[cfg(target_arch = "wasm32")]
impl<T> AgentHandle<T> {
    /// Wait for the agent to finish and return what its future produced.
    ///
    /// # Errors
    ///
    /// Returns the agent's own error. Returns `SageError::Agent` if the result
    /// channel's sender was dropped before a result was sent.
    pub async fn result(self) -> SageResult<T> {
        match self.result_rx.await {
            Ok(outcome) => outcome,
            Err(_) => Err(SageError::Agent("Agent task dropped".to_string())),
        }
    }
}
56
// ---------------------------------------------------------------------------
// AgentHandle — shared methods (both platforms)
// ---------------------------------------------------------------------------
60
61impl<T> AgentHandle<T> {
62    /// Send a message to the agent.
63    ///
64    /// The message will be serialized to JSON and placed in the agent's mailbox.
65    pub async fn send<M>(&self, msg: M) -> SageResult<()>
66    where
67        M: serde::Serialize,
68    {
69        let message = Message::new(msg)?;
70        self.message_tx
71            .send(message)
72            .await
73            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
74    }
75
76    /// Send a pre-built message to the agent.
77    ///
78    /// This is used by generated code when the message needs additional metadata
79    /// (like type_name for protocol tracking).
80    pub async fn send_message(&self, message: Message) -> SageResult<()> {
81        self.message_tx
82            .send(message)
83            .await
84            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
85    }
86}
87
88/// A message that can be sent to an agent.
89#[derive(Debug, Clone)]
90pub struct Message {
91    /// The message payload as a JSON value.
92    pub payload: serde_json::Value,
93    /// Phase 3: Session ID for protocol tracking.
94    pub session_id: Option<SessionId>,
95    /// Phase 3: Handle for replying to this message.
96    pub sender: Option<SenderHandle>,
97    /// Phase 3: Type name for protocol validation.
98    pub type_name: Option<String>,
99}
100
101impl Message {
102    /// Create a new message from a serializable value.
103    pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
104        Ok(Self {
105            payload: serde_json::to_value(value)?,
106            session_id: None,
107            sender: None,
108            type_name: None,
109        })
110    }
111
112    /// Create a new message with session context.
113    pub fn with_session<T: serde::Serialize>(
114        value: T,
115        session_id: SessionId,
116        sender: SenderHandle,
117        type_name: impl Into<String>,
118    ) -> SageResult<Self> {
119        Ok(Self {
120            payload: serde_json::to_value(value)?,
121            session_id: Some(session_id),
122            sender: Some(sender),
123            type_name: Some(type_name.into()),
124        })
125    }
126
127    /// Set the type name for this message.
128    #[must_use]
129    pub fn with_type_name(mut self, type_name: impl Into<String>) -> Self {
130        self.type_name = Some(type_name.into());
131        self
132    }
133}
134
135/// Context provided to agent handlers.
136///
137/// This gives agents access to LLM inference and the ability to emit results.
138pub struct AgentContext<T> {
139    /// LLM client for inference calls.
140    pub llm: LlmClient,
141    /// Channel to send the result to the awaiter.
142    result_tx: Option<oneshot::Sender<T>>,
143    /// Channel to receive messages from other agents.
144    message_rx: mpsc::Receiver<Message>,
145    /// Whether emit has been called (prevents double-emit).
146    emitted: bool,
147    /// Phase 3: The current message being handled (for reply()).
148    current_message: Option<Message>,
149    /// Phase 3: Session registry for protocol tracking.
150    session_registry: SharedSessionRegistry,
151    /// Phase 3: The role this agent plays in protocols.
152    agent_role: Option<String>,
153}
154
155impl<T> AgentContext<T> {
156    /// Create a new agent context.
157    fn new(
158        llm: LlmClient,
159        result_tx: oneshot::Sender<T>,
160        message_rx: mpsc::Receiver<Message>,
161        session_registry: SharedSessionRegistry,
162    ) -> Self {
163        Self {
164            llm,
165            result_tx: Some(result_tx),
166            message_rx,
167            emitted: false,
168            current_message: None,
169            session_registry,
170            agent_role: None,
171        }
172    }
173
174    /// Set the role this agent plays in protocols.
175    pub fn set_role(&mut self, role: impl Into<String>) {
176        self.agent_role = Some(role.into());
177    }
178
179    /// Get the session registry.
180    #[must_use]
181    pub fn session_registry(&self) -> &SharedSessionRegistry {
182        &self.session_registry
183    }
184
185    /// Emit a value to the awaiter.
186    ///
187    /// This should be called once at the end of the agent's execution.
188    /// Calling emit multiple times is a no-op after the first call.
189    pub fn emit(&mut self, value: T) -> SageResult<T>
190    where
191        T: Clone,
192    {
193        if self.emitted {
194            // Already emitted, just return the value
195            return Ok(value);
196        }
197        self.emitted = true;
198        if let Some(tx) = self.result_tx.take() {
199            // Ignore send errors - the receiver may have been dropped
200            let _ = tx.send(value.clone());
201        }
202        Ok(value)
203    }
204
205    /// Call the LLM with a prompt and parse the response.
206    pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
207    where
208        R: serde::de::DeserializeOwned,
209    {
210        self.llm.infer(prompt).await
211    }
212
213    /// Call the LLM with a prompt and return the raw string response.
214    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
215        self.llm.infer_string(prompt).await
216    }
217
218    /// Receive a message from the agent's mailbox.
219    ///
220    /// This blocks until a message is available. The message is deserialized
221    /// into the specified type.
222    pub async fn receive<M>(&mut self) -> SageResult<M>
223    where
224        M: serde::de::DeserializeOwned,
225    {
226        let msg = self
227            .message_rx
228            .recv()
229            .await
230            .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
231
232        // Phase 3: Store current message for reply()
233        self.current_message = Some(msg.clone());
234
235        serde_json::from_value(msg.payload)
236            .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
237    }
238
239    /// Receive a message with a timeout.
240    ///
241    /// Returns `None` if the timeout expires before a message arrives.
242    #[cfg(not(target_arch = "wasm32"))]
243    pub async fn receive_timeout<M>(
244        &mut self,
245        timeout: std::time::Duration,
246    ) -> SageResult<Option<M>>
247    where
248        M: serde::de::DeserializeOwned,
249    {
250        match tokio::time::timeout(timeout, self.message_rx.recv()).await {
251            Ok(Some(msg)) => {
252                // Phase 3: Store current message for reply()
253                self.current_message = Some(msg.clone());
254
255                let value = serde_json::from_value(msg.payload)
256                    .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
257                Ok(Some(value))
258            }
259            Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
260            Err(_) => Ok(None), // Timeout
261        }
262    }
263
264    /// Receive a message with a timeout (WASM variant).
265    ///
266    /// Uses browser `setTimeout` for the timeout mechanism.
267    #[cfg(target_arch = "wasm32")]
268    pub async fn receive_timeout<M>(
269        &mut self,
270        timeout: std::time::Duration,
271    ) -> SageResult<Option<M>>
272    where
273        M: serde::de::DeserializeOwned,
274    {
275        use futures::future::{select, Either};
276        use std::pin::pin;
277
278        let recv_fut = pin!(self.message_rx.recv());
279        let sleep_fut = pin!(sage_runtime_web::sleep(timeout));
280
281        match select(recv_fut, sleep_fut).await {
282            Either::Left((Some(msg), _)) => {
283                self.current_message = Some(msg.clone());
284                let value = serde_json::from_value(msg.payload)
285                    .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
286                Ok(Some(value))
287            }
288            Either::Left((None, _)) => {
289                Err(SageError::Agent("Message channel closed".to_string()))
290            }
291            Either::Right((_, _)) => Ok(None), // Timeout
292        }
293    }
294
295    /// Receive the raw message from the agent's mailbox.
296    ///
297    /// This blocks until a message is available. Returns the full Message
298    /// including session context.
299    pub async fn receive_raw(&mut self) -> SageResult<Message> {
300        let msg = self
301            .message_rx
302            .recv()
303            .await
304            .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
305
306        // Store current message for reply()
307        self.current_message = Some(msg.clone());
308
309        Ok(msg)
310    }
311
312    /// Set the current message context (for use in message handlers).
313    ///
314    /// This is called by generated code when entering a message handler.
315    pub fn set_current_message(&mut self, msg: Message) {
316        self.current_message = Some(msg);
317    }
318
319    /// Clear the current message context (for use after message handlers).
320    pub fn clear_current_message(&mut self) {
321        self.current_message = None;
322    }
323
324    /// Phase 3: Reply to the current message.
325    ///
326    /// This sends a response back to the sender of the current message.
327    /// Can only be called inside a message handler.
328    ///
329    /// # Errors
330    ///
331    /// Returns an error if called outside a message handler or if
332    /// the current message has no sender handle.
333    pub async fn reply<M: serde::Serialize>(&mut self, msg: M) -> SageResult<()> {
334        let current = self
335            .current_message
336            .as_ref()
337            .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
338
339        let sender = current
340            .sender
341            .as_ref()
342            .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
343
344        sender.send(msg).await
345    }
346
347    /// Phase 3: Reply to the current message with protocol state validation.
348    pub async fn reply_with_protocol<M: serde::Serialize>(
349        &mut self,
350        msg: M,
351        msg_type: &str,
352        role: &str,
353    ) -> SageResult<()> {
354        let current = self
355            .current_message
356            .as_ref()
357            .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
358
359        // If message has a session, validate protocol state
360        if let Some(session_id) = current.session_id {
361            let mut registry = self.session_registry.write().await;
362            if let Some(session) = registry.get_mut(&session_id) {
363                // Validate that we can send this message type from our role
364                if !session.state.can_send(msg_type, role) {
365                    return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
366                        protocol: session.protocol.clone(),
367                        expected: "valid reply".to_string(),
368                        received: msg_type.to_string(),
369                        state: session.state.state_name().to_string(),
370                    }));
371                }
372                // Transition the state machine
373                session.state.transition(msg_type)?;
374            }
375        }
376
377        let sender = current
378            .sender
379            .as_ref()
380            .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
381
382        sender.send(msg).await
383    }
384
385    /// Phase 3: Validate incoming message against protocol state.
386    pub async fn validate_protocol_receive(
387        &mut self,
388        msg_type: &str,
389        role: &str,
390    ) -> SageResult<()> {
391        let current = match &self.current_message {
392            Some(msg) => msg,
393            None => return Ok(()), // No current message, nothing to validate
394        };
395
396        // If message has a session, validate protocol state
397        if let Some(session_id) = current.session_id {
398            let mut registry = self.session_registry.write().await;
399            if let Some(session) = registry.get_mut(&session_id) {
400                // Validate that we can receive this message type in our role
401                if !session.state.can_receive(msg_type, role) {
402                    return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
403                        protocol: session.protocol.clone(),
404                        expected: "valid message for current state".to_string(),
405                        received: msg_type.to_string(),
406                        state: session.state.state_name().to_string(),
407                    }));
408                }
409                // Transition the state machine
410                session.state.transition(msg_type)?;
411
412                // If protocol is complete, remove the session
413                if session.state.is_terminal() {
414                    drop(registry);
415                    self.session_registry.write().await.remove(&session_id);
416                }
417            }
418        }
419
420        Ok(())
421    }
422
423    /// Phase 3: Start a new protocol session.
424    pub async fn start_session(
425        &self,
426        protocol: String,
427        role: String,
428        state: Box<dyn crate::session::ProtocolStateMachine>,
429        partner: SenderHandle,
430    ) -> SessionId {
431        let mut registry = self.session_registry.write().await;
432        let session_id = registry.next_id();
433        registry.start_session(session_id, protocol, role, state, partner);
434        session_id
435    }
436
437    /// Get the current message being handled (if any).
438    #[must_use]
439    pub fn current_message(&self) -> Option<&Message> {
440        self.current_message.as_ref()
441    }
442}
443
// ---------------------------------------------------------------------------
// spawn — native (tokio::spawn, requires Send)
// ---------------------------------------------------------------------------
447
448/// Spawn an agent and return a handle to it.
449///
450/// The agent will run asynchronously in a separate task.
451#[cfg(not(target_arch = "wasm32"))]
452pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
453where
454    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
455    F: Future<Output = SageResult<T>> + Send,
456    T: Send + 'static,
457{
458    spawn_with_llm_config(agent, crate::llm::LlmConfig::from_env())
459}
460
461/// Spawn an agent with a custom LLM configuration.
462///
463/// This is used by effect handlers to configure per-agent LLM settings.
464#[cfg(not(target_arch = "wasm32"))]
465pub fn spawn_with_llm_config<A, T, F>(agent: A, llm_config: crate::llm::LlmConfig) -> AgentHandle<T>
466where
467    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
468    F: Future<Output = SageResult<T>> + Send,
469    T: Send + 'static,
470{
471    let (result_tx, result_rx) = oneshot::channel();
472    let (message_tx, message_rx) = mpsc::channel(32);
473
474    let llm = LlmClient::new(llm_config);
475    let session_registry = crate::session::shared_registry();
476    let ctx = AgentContext::new(llm, result_tx, message_rx, session_registry);
477
478    let join = tokio::spawn(async move { agent(ctx).await });
479
480    // We need to handle the result_rx somewhere, but for now we just let
481    // the result come from the JoinHandle
482    drop(result_rx);
483
484    AgentHandle { join, message_tx }
485}
486
// ---------------------------------------------------------------------------
// spawn — WASM (spawn_local, no Send bounds)
// ---------------------------------------------------------------------------
490
/// Spawn an agent on the browser event loop and return a handle to it.
///
/// Equivalent to [`spawn_with_llm_config`] called with
/// `crate::llm::LlmConfig::from_env()`. On WASM the agent runs via
/// `spawn_local`, so no `Send` bounds are required.
#[cfg(target_arch = "wasm32")]
pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
where
    A: FnOnce(AgentContext<T>) -> F + 'static,
    F: Future<Output = SageResult<T>> + 'static,
    T: 'static,
{
    let env_config = crate::llm::LlmConfig::from_env();
    spawn_with_llm_config(agent, env_config)
}
504
/// Spawn an agent with an explicit LLM configuration (WASM variant).
///
/// The agent runs via `wasm_bindgen_futures::spawn_local`. Its future's output
/// is forwarded to the handle through a oneshot channel. The mailbox holds up
/// to 32 messages.
#[cfg(target_arch = "wasm32")]
pub fn spawn_with_llm_config<A, T, F>(agent: A, llm_config: crate::llm::LlmConfig) -> AgentHandle<T>
where
    A: FnOnce(AgentContext<T>) -> F + 'static,
    F: Future<Output = SageResult<T>> + 'static,
    T: 'static,
{
    const MAILBOX_CAPACITY: usize = 32;

    let (result_tx, result_rx) = oneshot::channel();
    // Nothing consumes emitted values; this receiver is kept alive only until
    // this function returns.
    let (emit_tx, _emit_rx) = oneshot::channel();
    let (message_tx, message_rx) = mpsc::channel(MAILBOX_CAPACITY);

    let ctx = AgentContext::new(
        LlmClient::new(llm_config),
        emit_tx,
        message_rx,
        crate::session::shared_registry(),
    );

    wasm_bindgen_futures::spawn_local(async move {
        let outcome = agent(ctx).await;
        // If the handle is gone, the outcome is simply discarded.
        let _ = result_tx.send(outcome);
    });

    AgentHandle {
        result_rx,
        message_tx,
    }
}
531
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[tokio::test]
    async fn spawn_simple_agent() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move {
            let sum = (1..=10).sum();
            ctx.emit(sum)
        });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 55);
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    #[tokio::test]
    async fn agent_receives_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let msg: TaskMessage = ctx.receive().await?;
            ctx.emit(format!("Got task {}: {}", msg.id, msg.content))
        });

        handle
            .send(TaskMessage {
                id: 42,
                content: "Hello".to_string(),
            })
            .await
            .expect("send should succeed");

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Got task 42: Hello");
    }

    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let handle = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut sum = 0;
            for _ in 0..3 {
                let n: i32 = ctx.receive().await?;
                sum += n;
            }
            ctx.emit(sum)
        });

        for n in [10, 20, 30] {
            handle.send(n).await.expect("send should succeed");
        }

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 60);
    }

    #[tokio::test]
    async fn agent_receive_timeout() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let result: Option<i32> = ctx
                .receive_timeout(std::time::Duration::from_millis(10))
                .await?;
            match result {
                Some(n) => ctx.emit(format!("Got {n}")),
                None => ctx.emit("Timeout".to_string()),
            }
        });

        // Don't send anything, let it timeout
        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Timeout");
    }

    #[tokio::test]
    async fn agent_receive_timeout_delivers_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let result: Option<i32> = ctx
                .receive_timeout(std::time::Duration::from_secs(5))
                .await?;
            match result {
                Some(n) => ctx.emit(format!("Got {n}")),
                None => ctx.emit("Timeout".to_string()),
            }
        });

        // The message is queued immediately, so the long timeout never elapses.
        handle.send(7).await.expect("send should succeed");

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Got 7");
    }

    #[tokio::test]
    async fn receive_keeps_current_message_on_bad_payload() {
        let handle = spawn(|mut ctx: AgentContext<(bool, bool)>| async move {
            let parsed: SageResult<i32> = ctx.receive().await;
            let failed = parsed.is_err();
            let recorded = ctx.current_message().is_some();
            ctx.emit((failed, recorded))
        });

        handle
            .send("not a number")
            .await
            .expect("send should succeed");

        let (failed, recorded) = handle.result().await.expect("agent should succeed");
        assert!(failed, "a string payload must not deserialize as i32");
        assert!(recorded, "the message must still be recorded for reply()");
    }

    #[tokio::test]
    async fn reply_outside_handler_is_rejected() {
        let handle = spawn(|mut ctx: AgentContext<bool>| async move {
            let outcome = ctx.reply(1).await;
            ctx.emit(outcome.is_err())
        });

        assert!(handle.result().await.expect("agent should succeed"));
    }

    #[test]
    fn message_builders_set_metadata() {
        let msg = Message::new(TaskMessage {
            id: 1,
            content: "x".to_string(),
        })
        .expect("TaskMessage serializes")
        .with_type_name("TaskMessage");

        assert_eq!(msg.payload["id"], 1);
        assert_eq!(msg.payload["content"], "x");
        assert!(msg.session_id.is_none());
        assert!(msg.sender.is_none());
        assert_eq!(msg.type_name.as_deref(), Some("TaskMessage"));
    }
}