Skip to main content

adk_rs/core/
context.rs

1//! Invocation and tool execution contexts.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use indexmap::IndexMap;
7use parking_lot::Mutex;
8use serde_json::Value;
9use uuid::Uuid;
10
11use crate::genai_types::Content;
12
13use crate::core::cancel::CancellationToken;
14use crate::core::run_config::RunConfig;
15use crate::core::services::{ArtifactService, CredentialService, MemoryService, SessionService};
16use crate::core::session::Session;
17use crate::core::state::StateDelta;
18
/// Identifies the surface an invocation was started from.
///
/// This is mostly informational. The runner consults it when deciding whether
/// a session should be created automatically.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InvocationOrigin {
    /// Programmatic API call. This is the default origin.
    #[default]
    Api,
    /// Command-line interface or interactive REPL.
    Cli,
    /// Web server / browser front end.
    Web,
}
31
32/// Per-invocation context carried through the agent run.
33#[derive(Clone)]
34pub struct InvocationContext {
35    /// App name.
36    pub app_name: String,
37    /// User id.
38    pub user_id: String,
39    /// Invocation id (unique per `runner.run` call).
40    pub invocation_id: String,
41    /// The session being mutated. Wrapped in [`Arc<Mutex>`] so the runner
42    /// and tool callbacks can both mutate state safely.
43    pub session: Arc<Mutex<Session>>,
44    /// Active session service.
45    pub session_service: Arc<dyn SessionService>,
46    /// Optional artifact service.
47    pub artifact_service: Option<Arc<dyn ArtifactService>>,
48    /// Optional memory service.
49    pub memory_service: Option<Arc<dyn MemoryService>>,
50    /// Optional credential service.
51    pub credential_service: Option<Arc<dyn CredentialService>>,
52    /// Per-invocation config.
53    pub run_config: RunConfig,
54    /// Origin of the invocation (CLI, web, API).
55    pub origin: InvocationOrigin,
56    /// User content for this invocation, if any.
57    pub user_content: Option<Content>,
58    /// Counter of LLM calls performed (capped by `RunConfig::max_llm_calls`).
59    pub llm_call_count: Arc<Mutex<u32>>,
60    /// Cooperative cancellation flag. Agents check this at safe points and
61    /// short-circuit cleanly when it flips. Flipped by
62    /// [`crate::runner::Runner::cancel`] or by the A2A `tasks/cancel`
63    /// handler — the same token plumbs through to both surfaces so cancels
64    /// initiated from either side reach the in-flight agent.
65    pub cancellation: CancellationToken,
66    /// Free-form attribute bag for plugins / agent-specific bookkeeping.
67    pub attributes: Arc<Mutex<HashMap<String, Value>>>,
68    /// Root of the agent tree for this invocation (set by the runner).
69    /// Agent transfer resolves targets from here, so an agent can reach
70    /// siblings and ancestors, not just its own subtree.
71    pub root_agent: Option<Arc<dyn crate::agents::BaseAgent>>,
72}
73
74impl std::fmt::Debug for InvocationContext {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("InvocationContext")
77            .field("app_name", &self.app_name)
78            .field("user_id", &self.user_id)
79            .field("invocation_id", &self.invocation_id)
80            .field("origin", &self.origin)
81            .finish()
82    }
83}
84
85impl InvocationContext {
86    /// Generate a fresh invocation id.
87    #[must_use]
88    pub fn new_id() -> String {
89        format!("inv-{}", Uuid::new_v4())
90    }
91
92    /// Increment the LLM call counter. Returns `Err` if the cap was already
93    /// reached.
94    pub fn check_and_inc_llm_call(&self) -> crate::error::Result<()> {
95        let mut n = self.llm_call_count.lock();
96        if let Some(cap) = self.run_config.max_llm_calls {
97            if *n >= cap {
98                return Err(crate::error::Error::config(format!(
99                    "max_llm_calls={cap} reached"
100                )));
101            }
102        }
103        *n += 1;
104        Ok(())
105    }
106
107    /// True if [`Self::cancellation`] has been flipped. Agents call this
108    /// at safe points (between iterations of the LLM↔tool loop, between
109    /// sub-agents) to halt cleanly.
110    #[must_use]
111    pub fn is_cancelled(&self) -> bool {
112        self.cancellation.is_cancelled()
113    }
114}
115
116/// Per-call context passed to a tool's `run` method.
117pub struct ToolContext {
118    /// Underlying invocation context.
119    pub invocation: Arc<InvocationContext>,
120    /// Function-call id (matches `FunctionCall::id`).
121    pub function_call_id: Option<String>,
122    /// State delta accumulator that the tool can mutate.
123    pub state_delta: StateDelta,
124    /// Artifact-version delta accumulator.
125    pub artifact_delta: IndexMap<String, u64>,
126    /// If set on return, the runner skips summarization of the response.
127    pub skip_summarization: bool,
128    /// If set on return, the runner transfers control to the named agent.
129    pub transfer_to_agent: Option<String>,
130    /// If set on return, the runner unwinds escalations.
131    pub escalate: bool,
132    /// If true, the tool returned a long-running operation handle.
133    pub long_running: bool,
134    /// Resolved credential, set by the runner before `run` when the tool
135    /// declared an `auth_config()`. Authenticated tools should read their
136    /// access token / API key from here rather than from the args.
137    pub auth_credential: Option<crate::auth::AuthCredential>,
138    /// The user's confirmation decision, set before `run` when the tool
139    /// requires confirmation and the user approved it (carries any payload
140    /// the user attached to the approval).
141    pub tool_confirmation: Option<crate::core::tool_confirmation::ToolConfirmation>,
142}
143
144impl ToolContext {
145    /// Construct.
146    pub fn new(invocation: Arc<InvocationContext>) -> Self {
147        Self {
148            invocation,
149            function_call_id: None,
150            state_delta: StateDelta::new(),
151            artifact_delta: IndexMap::new(),
152            skip_summarization: false,
153            transfer_to_agent: None,
154            escalate: false,
155            long_running: false,
156            auth_credential: None,
157            tool_confirmation: None,
158        }
159    }
160
161    /// Set the function-call id.
162    #[must_use]
163    pub fn with_function_call_id(mut self, id: impl Into<String>) -> Self {
164        self.function_call_id = Some(id.into());
165        self
166    }
167
168    /// Save an artifact via the configured artifact service. Returns the
169    /// new version. Errors if no artifact service is configured.
170    pub async fn save_artifact(
171        &mut self,
172        filename: &str,
173        part: crate::genai_types::Part,
174    ) -> crate::error::Result<u64> {
175        let svc = self
176            .invocation
177            .artifact_service
178            .as_ref()
179            .ok_or_else(|| crate::error::Error::config("no artifact service configured"))?;
180        let key = crate::core::artifact::ArtifactKey::new(
181            &self.invocation.app_name,
182            &self.invocation.user_id,
183            &self.invocation.session.lock().id,
184            filename,
185        );
186        let v = svc.save_artifact(key, part).await?;
187        self.artifact_delta.insert(filename.to_string(), v);
188        Ok(v)
189    }
190
191    /// Load an artifact via the configured artifact service.
192    pub async fn load_artifact(
193        &self,
194        filename: &str,
195        version: Option<u64>,
196    ) -> crate::error::Result<Option<crate::genai_types::Part>> {
197        let svc = self
198            .invocation
199            .artifact_service
200            .as_ref()
201            .ok_or_else(|| crate::error::Error::config("no artifact service configured"))?;
202        let key = crate::core::artifact::ArtifactKey::new(
203            &self.invocation.app_name,
204            &self.invocation.user_id,
205            &self.invocation.session.lock().id,
206            filename,
207        );
208        svc.load_artifact(key, version).await
209    }
210}