Skip to main content

awaken_runtime/runtime/agent_runtime/
mod.rs

1//! Agent runtime: top-level orchestrator for run management, routing, and control.
2
3mod active_registry;
4mod control;
5mod run_request;
6mod runner;
7
8use std::sync::Arc;
9
10use awaken_contract::contract::storage::ThreadRunStore;
11
12use crate::error::RuntimeError;
13#[cfg(feature = "a2a")]
14use crate::registry::composite::CompositeAgentSpecRegistry;
15use awaken_contract::contract::suspension::ToolCallResume;
16use futures::channel::mpsc;
17
18use crate::cancellation::CancellationToken;
19use crate::registry::{
20    AgentResolver, ExecutionResolver, LocalExecutionResolver, RegistryHandle, RegistrySet,
21    RegistrySnapshot,
22};
23
24pub use run_request::RunRequest;
25
26use active_registry::ActiveRunRegistry;
27
28pub(crate) type DecisionBatch = Vec<(String, ToolCallResume)>;
29
30// ---------------------------------------------------------------------------
31// RunHandle
32// ---------------------------------------------------------------------------
33
34/// Internal control handle for a running agent loop.
35///
36/// Stored in `ActiveRunRegistry` for the lifetime of a run.
37/// External control is exposed via `AgentRuntime::cancel()` / `send_decisions()`.
38#[derive(Clone)]
39pub(crate) struct RunHandle {
40    pub(crate) run_id: String,
41    cancellation_token: CancellationToken,
42    decision_tx: mpsc::UnboundedSender<DecisionBatch>,
43}
44
45impl RunHandle {
46    /// Cancel the running agent loop cooperatively.
47    pub(crate) fn cancel(&self) {
48        self.cancellation_token.cancel();
49    }
50
51    /// Send one or more tool call decisions to the running loop atomically.
52    pub(crate) fn send_decisions(
53        &self,
54        decisions: DecisionBatch,
55    ) -> Result<(), Box<mpsc::TrySendError<DecisionBatch>>> {
56        self.decision_tx.unbounded_send(decisions).map_err(Box::new)
57    }
58
59    /// Send a single tool call decision to the running loop.
60    pub(crate) fn send_decision(
61        &self,
62        call_id: String,
63        resume: ToolCallResume,
64    ) -> Result<(), Box<mpsc::TrySendError<DecisionBatch>>> {
65        self.send_decisions(vec![(call_id, resume)])
66    }
67}
68
69// ---------------------------------------------------------------------------
70// AgentRuntime
71// ---------------------------------------------------------------------------
72
73/// Top-level agent runtime. Manages runs across threads.
74///
75/// Provides methods for cancelling and sending decisions
76/// to active agent runs. Enforces one active run per thread.
77pub struct AgentRuntime {
78    pub(crate) resolver: Arc<dyn ExecutionResolver>,
79    pub(crate) storage: Option<Arc<dyn ThreadRunStore>>,
80    pub(crate) profile_store:
81        Option<Arc<dyn awaken_contract::contract::profile_store::ProfileStore>>,
82    pub(crate) active_runs: ActiveRunRegistry,
83    pub(crate) registry_handle: Option<RegistryHandle>,
84    #[cfg(feature = "a2a")]
85    composite_registry: Option<Arc<CompositeAgentSpecRegistry>>,
86}
87
88impl AgentRuntime {
89    pub fn new(resolver: Arc<dyn AgentResolver>) -> Self {
90        Self::new_with_execution_resolver(Arc::new(LocalExecutionResolver::new(resolver)))
91    }
92
93    pub fn new_with_execution_resolver(resolver: Arc<dyn ExecutionResolver>) -> Self {
94        Self {
95            resolver,
96            storage: None,
97            profile_store: None,
98            active_runs: ActiveRunRegistry::new(),
99            registry_handle: None,
100            #[cfg(feature = "a2a")]
101            composite_registry: None,
102        }
103    }
104
105    #[must_use]
106    pub fn with_registry_handle(mut self, handle: RegistryHandle) -> Self {
107        self.registry_handle = Some(handle);
108        self
109    }
110
111    #[must_use]
112    pub fn with_thread_run_store(mut self, store: Arc<dyn ThreadRunStore>) -> Self {
113        self.storage = Some(store);
114        self
115    }
116
117    #[must_use]
118    pub(crate) fn with_profile_store(
119        mut self,
120        store: Arc<dyn awaken_contract::contract::profile_store::ProfileStore>,
121    ) -> Self {
122        self.profile_store = Some(store);
123        self
124    }
125
126    pub fn resolver(&self) -> &dyn AgentResolver {
127        self.resolver.as_ref()
128    }
129
130    /// Return a cloned `Arc` of the agent resolver.
131    pub fn resolver_arc(&self) -> Arc<dyn AgentResolver> {
132        self.resolver.clone()
133    }
134
135    pub fn execution_resolver(&self) -> &dyn ExecutionResolver {
136        self.resolver.as_ref()
137    }
138
139    pub fn execution_resolver_arc(&self) -> Arc<dyn ExecutionResolver> {
140        self.resolver.clone()
141    }
142
143    pub fn registry_handle(&self) -> Option<RegistryHandle> {
144        self.registry_handle.clone()
145    }
146
147    pub fn registry_snapshot(&self) -> Option<RegistrySnapshot> {
148        self.registry_handle.as_ref().map(RegistryHandle::snapshot)
149    }
150
151    pub fn registry_version(&self) -> Option<u64> {
152        self.registry_handle.as_ref().map(RegistryHandle::version)
153    }
154
155    pub fn registry_set(&self) -> Option<RegistrySet> {
156        self.registry_snapshot()
157            .map(RegistrySnapshot::into_registries)
158    }
159
160    pub fn replace_registry_set(&self, registries: RegistrySet) -> Option<u64> {
161        self.registry_handle
162            .as_ref()
163            .map(|handle| handle.replace(registries))
164    }
165
166    #[cfg(feature = "a2a")]
167    #[must_use]
168    pub fn with_composite_registry(mut self, registry: Arc<CompositeAgentSpecRegistry>) -> Self {
169        self.composite_registry = Some(registry);
170        self
171    }
172
173    /// Return the composite registry, if one was configured.
174    #[cfg(feature = "a2a")]
175    pub fn composite_registry(&self) -> Option<&Arc<CompositeAgentSpecRegistry>> {
176        self.composite_registry.as_ref()
177    }
178
179    /// Initialize the runtime — discover remote agents.
180    /// Call this after `build()` to complete async initialization.
181    #[cfg(feature = "a2a")]
182    pub async fn initialize(&self) -> Result<(), RuntimeError> {
183        if let Some(composite) = &self.composite_registry {
184            composite
185                .discover()
186                .await
187                .map_err(|e| RuntimeError::ResolveFailed {
188                    message: format!("remote agent discovery failed: {e}"),
189                })?;
190        }
191        Ok(())
192    }
193
194    pub fn thread_run_store(&self) -> Option<&dyn ThreadRunStore> {
195        self.storage.as_deref()
196    }
197
198    /// Create a run handle pair (handle + internal channels).
199    ///
200    /// Returns (RunHandle for caller, CancellationToken for loop, decision_rx for loop).
201    pub(crate) fn create_run_channels(
202        &self,
203        run_id: String,
204    ) -> (
205        RunHandle,
206        CancellationToken,
207        mpsc::UnboundedReceiver<DecisionBatch>,
208    ) {
209        let token = CancellationToken::new();
210        let (tx, rx) = mpsc::unbounded();
211
212        let handle = RunHandle {
213            run_id,
214            cancellation_token: token.clone(),
215            decision_tx: tx,
216        };
217
218        (handle, token, rx)
219    }
220
221    /// Register an active run. Returns error if thread already has one.
222    ///
223    /// Uses atomic try-insert to avoid TOCTOU race between check and insert.
224    pub(crate) fn register_run(
225        &self,
226        thread_id: &str,
227        handle: RunHandle,
228    ) -> Result<(), RuntimeError> {
229        let run_id = handle.run_id.clone();
230        if !self.active_runs.register(&run_id, thread_id, handle) {
231            return Err(RuntimeError::ThreadAlreadyRunning {
232                thread_id: thread_id.to_string(),
233            });
234        }
235        Ok(())
236    }
237
238    /// Unregister an active run when it completes (by run_id).
239    pub(crate) fn unregister_run(&self, run_id: &str) {
240        self.active_runs.unregister(run_id);
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use awaken_contract::contract::suspension::{ResumeDecisionAction, ToolCallResume};
    use serde_json::Value;

    /// Resolver stub that rejects every agent id with `AgentNotFound`.
    struct StubResolver;
    impl crate::registry::AgentResolver for StubResolver {
        fn resolve(
            &self,
            agent_id: &str,
        ) -> Result<crate::registry::ResolvedAgent, crate::error::RuntimeError> {
            let missing = crate::error::RuntimeError::AgentNotFound {
                agent_id: agent_id.to_string(),
            };
            Err(missing)
        }
    }

    /// Runtime backed by [`StubResolver`] with nothing else configured.
    fn make_runtime() -> AgentRuntime {
        let resolver = Arc::new(StubResolver);
        AgentRuntime::new(resolver)
    }

    /// Minimal resume payload used by the decision-channel tests.
    fn make_resume() -> ToolCallResume {
        ToolCallResume {
            decision_id: "d1".into(),
            action: ResumeDecisionAction::Resume,
            result: Value::Null,
            reason: None,
            updated_at: 0,
        }
    }

    #[test]
    fn new_creates_runtime() {
        let runtime = make_runtime();
        assert!(runtime.storage.is_none());
        assert!(runtime.profile_store.is_none());
        assert!(runtime.registry_handle().is_none());
    }

    #[test]
    fn resolver_returns_ref() {
        let runtime = make_runtime();
        // The stub resolver rejects every id, so any lookup must fail.
        let failure = runtime.resolver().resolve("any").unwrap_err();
        assert!(
            matches!(failure, crate::error::RuntimeError::AgentNotFound { .. }),
            "expected AgentNotFound, got {failure:?}"
        );
    }

    #[test]
    fn resolver_arc_returns_clone() {
        let runtime = make_runtime();
        let shared = runtime.resolver_arc();
        let failure = shared.resolve("x").unwrap_err();
        assert!(matches!(
            failure,
            crate::error::RuntimeError::AgentNotFound { .. }
        ));
    }

    #[test]
    fn with_thread_run_store_sets_store() {
        let store = Arc::new(awaken_stores::InMemoryStore::new());
        let runtime = make_runtime().with_thread_run_store(store);
        assert!(runtime.thread_run_store().is_some());
    }

    #[test]
    fn thread_run_store_none_by_default() {
        let runtime = make_runtime();
        assert!(runtime.thread_run_store().is_none());
    }

    #[test]
    fn create_run_channels_returns_triple() {
        let runtime = make_runtime();
        let (handle, token, _rx) = runtime.create_run_channels("run-1".into());
        assert_eq!(handle.run_id, "run-1");
        assert!(!token.is_cancelled());
    }

    #[test]
    fn register_run_succeeds() {
        let runtime = make_runtime();
        let (handle, _token, _rx) = runtime.create_run_channels("run-1".into());
        let outcome = runtime.register_run("thread-1", handle);
        assert!(outcome.is_ok());
    }

    #[test]
    fn register_run_fails_for_same_thread() {
        let runtime = make_runtime();
        let (first, _, _rx1) = runtime.create_run_channels("run-1".into());
        let (second, _, _rx2) = runtime.create_run_channels("run-2".into());
        runtime.register_run("thread-1", first).unwrap();
        let failure = runtime.register_run("thread-1", second).unwrap_err();
        assert!(
            matches!(failure, RuntimeError::ThreadAlreadyRunning { ref thread_id } if thread_id == "thread-1"),
            "expected ThreadAlreadyRunning, got {failure:?}"
        );
    }

    #[test]
    fn unregister_run_allows_reregistration() {
        let runtime = make_runtime();
        let (first, _, _rx1) = runtime.create_run_channels("run-1".into());
        runtime.register_run("thread-1", first).unwrap();
        runtime.unregister_run("run-1");

        // With the first run gone, the same thread can take a new run.
        let (second, _, _rx2) = runtime.create_run_channels("run-2".into());
        let outcome = runtime.register_run("thread-1", second);
        assert!(outcome.is_ok());
    }

    #[test]
    fn run_handle_cancel() {
        let runtime = make_runtime();
        let (handle, token, _rx) = runtime.create_run_channels("run-1".into());
        assert!(!token.is_cancelled());
        handle.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn run_handle_send_decisions() {
        let runtime = make_runtime();
        let (handle, _token, mut rx) = runtime.create_run_channels("run-1".into());
        let outgoing = vec![("call-1".into(), make_resume())];
        handle.send_decisions(outgoing).unwrap();

        // The batch must arrive intact on the receiving side.
        let received = rx.try_recv().unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].0, "call-1");
    }

    #[test]
    fn run_handle_send_decision_single() {
        let runtime = make_runtime();
        let (handle, _token, mut rx) = runtime.create_run_channels("run-1".into());
        handle
            .send_decision("call-2".into(), make_resume())
            .unwrap();

        let received = rx.try_recv().unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].0, "call-2");
    }

    #[test]
    fn run_handle_send_decisions_closed_channel() {
        let runtime = make_runtime();
        let (handle, _token, rx) = runtime.create_run_channels("run-1".into());
        // Dropping the receiver closes the channel.
        drop(rx);

        let outcome = handle.send_decisions(vec![("call-1".into(), make_resume())]);
        assert!(outcome.is_err(), "send should fail when receiver is dropped");
    }
}
405}