use crate::identity::{AdkIdentity, AppName, ExecutionIdentity, InvocationId, SessionId, UserId};
use crate::{AdkError, Agent, Result, types::Content};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
/// Read-only view of the invocation that an agent or callback is running in.
///
/// Implementors provide the raw string accessors. The `try_*` helpers parse those strings into
/// the typed identifiers from `crate::identity`, converting any parse failure into [`AdkError`].
#[async_trait]
pub trait ReadonlyContext: Send + Sync {
    /// Raw identifier of the current invocation.
    fn invocation_id(&self) -> &str;
    /// Name of the agent currently executing.
    fn agent_name(&self) -> &str;
    /// Raw identifier of the user.
    fn user_id(&self) -> &str;
    /// Raw application name.
    fn app_name(&self) -> &str;
    /// Raw session identifier.
    fn session_id(&self) -> &str;
    /// Branch string of the invocation (copied verbatim into [`ExecutionIdentity::branch`]).
    fn branch(&self) -> &str;
    /// User content associated with this invocation.
    fn user_content(&self) -> &Content;

    /// Parses [`ReadonlyContext::app_name`] into an [`AppName`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_app_name(&self) -> Result<AppName> {
        AppName::try_from(self.app_name()).map_err(Into::into)
    }

    /// Parses [`ReadonlyContext::user_id`] into a [`UserId`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_user_id(&self) -> Result<UserId> {
        UserId::try_from(self.user_id()).map_err(Into::into)
    }

    /// Parses [`ReadonlyContext::session_id`] into a [`SessionId`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_session_id(&self) -> Result<SessionId> {
        SessionId::try_from(self.session_id()).map_err(Into::into)
    }

    /// Parses [`ReadonlyContext::invocation_id`] into an [`InvocationId`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_invocation_id(&self) -> Result<InvocationId> {
        InvocationId::try_from(self.invocation_id()).map_err(Into::into)
    }

    /// Builds the typed app/user/session identity.
    ///
    /// # Errors
    /// Returns the first parse failure, checking app name, then user id, then session id.
    fn try_identity(&self) -> Result<AdkIdentity> {
        // Parse in a fixed order so the reported error is deterministic.
        let app_name = self.try_app_name()?;
        let user_id = self.try_user_id()?;
        let session_id = self.try_session_id()?;
        Ok(AdkIdentity {
            app_name,
            user_id,
            session_id,
        })
    }

    /// Builds the full execution identity: the [`AdkIdentity`] plus invocation id, branch and
    /// agent name.
    ///
    /// # Errors
    /// Returns the first failure from [`ReadonlyContext::try_identity`] or
    /// [`ReadonlyContext::try_invocation_id`], in that order.
    fn try_execution_identity(&self) -> Result<ExecutionIdentity> {
        let adk = self.try_identity()?;
        let invocation_id = self.try_invocation_id()?;
        Ok(ExecutionIdentity {
            adk,
            invocation_id,
            branch: self.branch().to_owned(),
            agent_name: self.agent_name().to_owned(),
        })
    }
}
/// Maximum length of a state key, measured in bytes (not characters).
pub const MAX_STATE_KEY_LEN: usize = 256;

/// Checks whether `key` is acceptable as a state key.
///
/// Rules are checked in this order, and the first violated rule decides the error:
/// 1. the key must not be empty;
/// 2. it must be at most [`MAX_STATE_KEY_LEN`] bytes long;
/// 3. it must not contain `/`, `\` or `..` anywhere (so `"a..b"` is rejected as well);
/// 4. it must not contain NUL (`\0`) characters.
///
/// # Errors
/// Returns a static description of the first rule the key violates.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    let is_separator = |c: char| c == '/' || c == '\\';

    let problem = if key.is_empty() {
        Some("state key must not be empty")
    } else if key.len() > MAX_STATE_KEY_LEN {
        // NOTE: the message spells out 256 literally; keep it in sync with MAX_STATE_KEY_LEN.
        Some("state key exceeds maximum length of 256 bytes")
    } else if key.contains(is_separator) || key.contains("..") {
        Some("state key must not contain path separators or '..'")
    } else if key.contains('\0') {
        Some("state key must not contain null bytes")
    } else {
        None
    };

    problem.map_or(Ok(()), Err)
}
/// Mutable key/value state whose values are JSON [`Value`]s.
pub trait State: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &str) -> Option<Value>;
    /// Stores `value` under `key`.
    fn set(&mut self, key: String, value: Value);
    /// Returns every key/value pair as an owned map.
    fn all(&self) -> HashMap<String, Value>;
}
/// Read-only key/value state: lookup and enumeration, with no `set`.
pub trait ReadonlyState: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &str) -> Option<Value>;
    /// Returns every key/value pair as an owned map.
    fn all(&self) -> HashMap<String, Value>;
}
/// A conversation session: identity strings, key/value state and conversation history.
pub trait Session: Send + Sync {
    /// Raw session identifier.
    fn id(&self) -> &str;
    /// Raw application name.
    fn app_name(&self) -> &str;
    /// Raw user identifier.
    fn user_id(&self) -> &str;
    /// The session's key/value state.
    fn state(&self) -> &dyn State;
    /// The conversation history.
    fn conversation_history(&self) -> Vec<Content>;

    /// History as seen by a particular agent.
    ///
    /// The default ignores the agent name and returns
    /// [`Session::conversation_history`] unchanged.
    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
        self.conversation_history()
    }

    /// Appends content to the history.
    ///
    /// The default implementation discards the content. Implementors that need appended
    /// content to be retained must override it.
    fn append_to_history(&self, _content: Content) {
        // Deliberately a no-op; `_content` is simply dropped.
    }

    /// Parses [`Session::app_name`] into an [`AppName`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_app_name(&self) -> Result<AppName> {
        AppName::try_from(self.app_name()).map_err(Into::into)
    }

    /// Parses [`Session::user_id`] into a [`UserId`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_user_id(&self) -> Result<UserId> {
        UserId::try_from(self.user_id()).map_err(Into::into)
    }

    /// Parses [`Session::id`] into a [`SessionId`].
    ///
    /// # Errors
    /// Returns the `TryFrom` failure converted into [`AdkError`].
    fn try_session_id(&self) -> Result<SessionId> {
        SessionId::try_from(self.id()).map_err(Into::into)
    }

    /// Builds the typed app/user/session identity of this session.
    ///
    /// # Errors
    /// Returns the first parse failure, checking app name, then user id, then session id.
    fn try_identity(&self) -> Result<AdkIdentity> {
        let app_name = self.try_app_name()?;
        let user_id = self.try_user_id()?;
        let session_id = self.try_session_id()?;
        Ok(AdkIdentity {
            app_name,
            user_id,
            session_id,
        })
    }
}
/// Result of one tool call, as exposed to callbacks through `CallbackContext::tool_outcome`.
#[derive(Debug, Clone)]
pub struct ToolOutcome {
    /// Name of the tool that was called.
    pub tool_name: String,
    /// Arguments the tool was called with, as JSON.
    pub tool_args: serde_json::Value,
    /// Whether the call succeeded.
    pub success: bool,
    /// How long the call took.
    pub duration: std::time::Duration,
    /// Error description for a failed call (presumably `None` on success — confirm).
    pub error_message: Option<String>,
    /// Attempt counter for this call (presumably counts retries, base unknown — confirm).
    pub attempt: u32,
}
/// Context handed to callbacks: the [`ReadonlyContext`] view plus optional extras.
///
/// Only [`CallbackContext::artifacts`] is required. Every other accessor defaults to `None`.
#[async_trait]
pub trait CallbackContext: ReadonlyContext {
    /// Artifact store available to the callback, if any.
    fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;

    /// Outcome of the tool call this callback concerns. Defaults to `None`.
    fn tool_outcome(&self) -> Option<ToolOutcome> {
        None
    }

    /// Name of the tool being called. Defaults to `None`.
    fn tool_name(&self) -> Option<&str> {
        None
    }

    /// JSON input of the tool being called. Defaults to `None`.
    fn tool_input(&self) -> Option<&serde_json::Value> {
        None
    }

    /// Shared-state handle. Defaults to `None`.
    fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
        None
    }
}
/// A callback context for a specific tool call: wraps another [`CallbackContext`] and adds the
/// tool's name and JSON input.
pub struct ToolCallbackContext {
    /// The wrapped context.
    pub inner: Arc<dyn CallbackContext>,
    /// Name of the tool being called.
    pub tool_name: String,
    /// Input passed to the tool, as JSON.
    pub tool_input: serde_json::Value,
}

impl ToolCallbackContext {
    /// Wraps `inner` and records the tool's name and input alongside it.
    pub fn new(
        inner: Arc<dyn CallbackContext>,
        tool_name: String,
        tool_input: serde_json::Value,
    ) -> Self {
        Self {
            tool_name,
            tool_input,
            inner,
        }
    }
}
#[async_trait]
impl ReadonlyContext for ToolCallbackContext {
fn invocation_id(&self) -> &str {
self.inner.invocation_id()
}
fn agent_name(&self) -> &str {
self.inner.agent_name()
}
fn user_id(&self) -> &str {
self.inner.user_id()
}
fn app_name(&self) -> &str {
self.inner.app_name()
}
fn session_id(&self) -> &str {
self.inner.session_id()
}
fn branch(&self) -> &str {
self.inner.branch()
}
fn user_content(&self) -> &Content {
self.inner.user_content()
}
}
#[async_trait]
impl CallbackContext for ToolCallbackContext {
fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
self.inner.artifacts()
}
fn tool_outcome(&self) -> Option<ToolOutcome> {
self.inner.tool_outcome()
}
fn tool_name(&self) -> Option<&str> {
Some(&self.tool_name)
}
fn tool_input(&self) -> Option<&serde_json::Value> {
Some(&self.tool_input)
}
fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
self.inner.shared_state()
}
}
/// Full context of a running invocation: the callback context plus agent, memory, session
/// and run configuration.
#[async_trait]
pub trait InvocationContext: CallbackContext {
    /// The agent being run.
    fn agent(&self) -> Arc<dyn Agent>;
    /// Memory service, if one is available.
    fn memory(&self) -> Option<Arc<dyn Memory>>;
    /// The session this invocation belongs to.
    fn session(&self) -> &dyn Session;
    /// Configuration of the current run.
    fn run_config(&self) -> &RunConfig;
    /// Signals that the invocation should end (presumably observable through
    /// [`InvocationContext::ended`] — confirm with implementors).
    fn end_invocation(&self);
    /// Whether the invocation has ended.
    fn ended(&self) -> bool;

    /// Scopes granted to the user. Defaults to an empty list.
    fn user_scopes(&self) -> Vec<String> {
        Vec::new()
    }

    /// Request metadata as JSON values keyed by string. Defaults to an empty map.
    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
        HashMap::default()
    }

    /// Looks up a secret by name.
    ///
    /// The default reports "no such secret" (`Ok(None)`) rather than an error.
    async fn get_secret(&self, _name: &str) -> Result<Option<String>> {
        Ok(None)
    }
}
/// Named artifact storage, where each artifact is a [`crate::Part`].
#[async_trait]
pub trait Artifacts: Send + Sync {
    /// Saves `data` under `name` and returns an `i64` (presumably the stored version — confirm).
    async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
    /// Loads the artifact stored under `name`.
    async fn load(&self, name: &str) -> Result<crate::Part>;
    /// Lists the stored artifacts (presumably by name — confirm).
    async fn list(&self) -> Result<Vec<String>>;
}
/// Searchable long-term memory made of [`MemoryEntry`] values.
///
/// Only [`Memory::search`] is required:
/// - `health_check` defaults to success;
/// - `add` and `delete` default to a memory error;
/// - the `*_project` variants default to their unscoped counterparts.
#[async_trait]
pub trait Memory: Send + Sync {
    /// Returns the entries matching `query`.
    async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;

    /// Reports whether the store is usable. The default always returns `Ok(())`.
    async fn health_check(&self) -> Result<()> {
        Ok(())
    }

    /// Stores an entry.
    ///
    /// # Errors
    /// The default implementation always fails with `AdkError::memory("add not implemented")`.
    async fn add(&self, _entry: MemoryEntry) -> Result<()> {
        Err(AdkError::memory("add not implemented"))
    }

    /// Deletes entries matching a query and returns a `u64` (presumably the number of
    /// deleted entries — confirm).
    ///
    /// # Errors
    /// The default implementation always fails with `AdkError::memory("delete not implemented")`.
    async fn delete(&self, _query: &str) -> Result<u64> {
        Err(AdkError::memory("delete not implemented"))
    }

    /// Searches within a project.
    ///
    /// NOTE(review): the default ignores the project id and runs an unscoped
    /// [`Memory::search`]. A backend that does not override this can therefore return entries
    /// belonging to other projects; override it wherever project isolation matters.
    async fn search_in_project(&self, query: &str, _project_id: &str) -> Result<Vec<MemoryEntry>> {
        self.search(query).await
    }

    /// Stores an entry within a project.
    ///
    /// The default ignores the project id and delegates to [`Memory::add`].
    async fn add_to_project(&self, entry: MemoryEntry, _project_id: &str) -> Result<()> {
        self.add(entry).await
    }
}
/// Source of secrets looked up by name.
#[async_trait]
pub trait SecretService: Send + Sync {
    /// Returns the secret called `name`.
    ///
    /// Unlike `InvocationContext::get_secret`, this returns the value directly rather than an
    /// `Option`. How a missing secret is reported is up to the implementor.
    async fn get_secret(&self, name: &str) -> Result<String>;
}
/// One item stored in, or returned by, a [`Memory`] implementation.
#[derive(Debug, Clone)]
pub struct MemoryEntry {
    /// The remembered content.
    pub content: Content,
    /// Author of the content, as a free-form string.
    pub author: String,
}
/// Streaming mode of a run (see `RunConfig::streaming_mode`); defaults to `SSE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum StreamingMode {
    /// No streaming.
    None,
    /// Presumably server-sent events; this is the default variant.
    #[default]
    SSE,
    /// Presumably bidirectional streaming — confirm against the runner.
    Bidi,
}
/// Content-inclusion setting; defaults to `Default`.
///
/// NOTE(review): not referenced in this file; presumably controls whether prior conversation
/// contents are sent to the model — confirm against its users.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IncludeContents {
    /// Include no contents.
    None,
    /// The default inclusion behavior.
    #[default]
    Default,
}
/// Decision on a tool-confirmation request; serialized as `"approve"` or `"deny"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfirmationDecision {
    /// The tool call is approved.
    Approve,
    /// The tool call is denied.
    Deny,
}
/// Policy deciding which tool calls require confirmation.
///
/// Variants serialize in snake_case (`"never"`, `"always"`, `"per_tool"`); the default is `Never`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfirmationPolicy {
    /// No tool requires confirmation.
    #[default]
    Never,
    /// Every tool requires confirmation.
    Always,
    /// Only tools whose names are in the set require confirmation.
    PerTool(BTreeSet<String>),
}
impl ToolConfirmationPolicy {
pub fn requires_confirmation(&self, tool_name: &str) -> bool {
match self {
Self::Never => false,
Self::Always => true,
Self::PerTool(tools) => tools.contains(tool_name),
}
}
pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
let tool_name = tool_name.into();
match &mut self {
Self::Never => {
let mut tools = BTreeSet::new();
tools.insert(tool_name);
Self::PerTool(tools)
}
Self::Always => Self::Always,
Self::PerTool(tools) => {
tools.insert(tool_name);
self
}
}
}
}
/// Request asking for confirmation of a tool call.
///
/// Serialized with camelCase keys: `toolName`, `functionCallId` and `args`. The
/// `functionCallId` key is omitted entirely when it is `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfirmationRequest {
    /// Name of the tool awaiting confirmation.
    pub tool_name: String,
    /// Identifier of the function call, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call_id: Option<String>,
    /// Arguments of the pending call, as JSON.
    pub args: Value,
}
/// Per-run configuration.
#[derive(Debug, Clone)]
pub struct RunConfig {
    /// Streaming mode used for the run.
    pub streaming_mode: StreamingMode,
    /// Tool-confirmation decisions already made, keyed by a string (presumably the tool name
    /// or function-call id — confirm against callers).
    pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
    /// Identifier of cached content to reuse, if any.
    pub cached_content: Option<String>,
    /// Names of the agents this run may transfer to.
    pub transfer_targets: Vec<String>,
    /// Name of the parent agent, if any.
    pub parent_agent: Option<String>,
    /// Whether automatic caching is enabled.
    pub auto_cache: bool,
}

impl Default for RunConfig {
    /// Default settings:
    /// - `SSE` streaming;
    /// - no recorded decisions;
    /// - no cached content, transfer targets or parent agent;
    /// - `auto_cache` enabled.
    fn default() -> Self {
        let streaming_mode = StreamingMode::SSE;
        let auto_cache = true;
        Self {
            streaming_mode,
            auto_cache,
            tool_confirmation_decisions: HashMap::default(),
            transfer_targets: Vec::default(),
            cached_content: None,
            parent_agent: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_config_default() {
        let config = RunConfig::default();
        assert_eq!(config.streaming_mode, StreamingMode::SSE);
        assert!(config.tool_confirmation_decisions.is_empty());
        assert!(config.cached_content.is_none());
        assert!(config.transfer_targets.is_empty());
        assert!(config.parent_agent.is_none());
        assert!(config.auto_cache);
    }

    #[test]
    fn test_streaming_mode() {
        assert_eq!(StreamingMode::SSE, StreamingMode::SSE);
        assert_ne!(StreamingMode::SSE, StreamingMode::None);
        assert_ne!(StreamingMode::None, StreamingMode::Bidi);
    }

    #[test]
    fn test_enum_defaults() {
        assert_eq!(StreamingMode::default(), StreamingMode::SSE);
        assert_eq!(IncludeContents::default(), IncludeContents::Default);
        assert_eq!(ToolConfirmationPolicy::default(), ToolConfirmationPolicy::Never);
    }

    #[test]
    fn test_tool_confirmation_policy() {
        let policy = ToolConfirmationPolicy::default();
        assert!(!policy.requires_confirmation("search"));
        let policy = policy.with_tool("search");
        assert!(policy.requires_confirmation("search"));
        assert!(!policy.requires_confirmation("write_file"));
        assert!(ToolConfirmationPolicy::Always.requires_confirmation("any_tool"));
    }

    #[test]
    fn test_tool_confirmation_policy_with_tool_accumulates() {
        let policy = ToolConfirmationPolicy::default()
            .with_tool("a")
            .with_tool("b")
            .with_tool("a");
        assert!(policy.requires_confirmation("a"));
        assert!(policy.requires_confirmation("b"));
        assert!(!policy.requires_confirmation("c"));
    }

    #[test]
    fn test_tool_confirmation_policy_with_tool_keeps_always() {
        // `Always` already covers every tool, so adding one must not narrow it to `PerTool`.
        let policy = ToolConfirmationPolicy::Always.with_tool("search");
        assert_eq!(policy, ToolConfirmationPolicy::Always);
        assert!(policy.requires_confirmation("anything_else"));
    }

    #[test]
    fn test_tool_confirmation_decision_wire_format() {
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Approve).expect("decision serializes"),
            Value::String("approve".to_string())
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Deny).expect("decision serializes"),
            Value::String("deny".to_string())
        );
    }

    #[test]
    fn test_tool_confirmation_request_wire_format() {
        let request = ToolConfirmationRequest {
            tool_name: "search".to_string(),
            function_call_id: None,
            args: Value::Null,
        };
        let json = serde_json::to_value(&request).expect("request serializes");
        assert_eq!(json.get("toolName"), Some(&Value::String("search".to_string())));
        // `None` ids are skipped rather than serialized as null.
        assert!(json.get("functionCallId").is_none());
        assert_eq!(json.get("args"), Some(&Value::Null));
    }

    #[test]
    fn test_validate_state_key_valid() {
        assert!(validate_state_key("user_name").is_ok());
        assert!(validate_state_key("app:config").is_ok());
        assert!(validate_state_key("temp:data").is_ok());
        assert!(validate_state_key("a").is_ok());
        assert!(validate_state_key("a.b").is_ok());
    }

    #[test]
    fn test_validate_state_key_empty() {
        assert_eq!(validate_state_key(""), Err("state key must not be empty"));
    }

    #[test]
    fn test_validate_state_key_too_long() {
        let long_key = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert!(validate_state_key(&long_key).is_err());
    }

    #[test]
    fn test_validate_state_key_max_len_boundary() {
        // Exactly at the limit is accepted.
        assert!(validate_state_key(&"a".repeat(MAX_STATE_KEY_LEN)).is_ok());
        // The limit counts bytes: this key has fewer characters than the limit but more bytes.
        let multibyte = "é".repeat(MAX_STATE_KEY_LEN / 2 + 1);
        assert!(multibyte.chars().count() <= MAX_STATE_KEY_LEN);
        assert!(validate_state_key(&multibyte).is_err());
    }

    #[test]
    fn test_validate_state_key_path_traversal() {
        assert!(validate_state_key("../etc/passwd").is_err());
        assert!(validate_state_key("foo/bar").is_err());
        assert!(validate_state_key("foo\\bar").is_err());
        assert!(validate_state_key("..").is_err());
        // `..` is rejected anywhere in the key, not only as a path component.
        assert!(validate_state_key("a..b").is_err());
    }

    #[test]
    fn test_validate_state_key_null_byte() {
        assert!(validate_state_key("foo\0bar").is_err());
    }
}