#[cfg(feature = "runtime")]
use crate::agent::Artifact;
use crate::errors::AgentError;
#[cfg(feature = "runtime")]
use crate::errors::AgentResult;
#[cfg(feature = "runtime")]
use crate::models::utils;
#[cfg(feature = "runtime")]
use crate::models::{Content, Role};
#[cfg(feature = "runtime")]
use crate::runtime::core::event_bus::TaskEventBus;
#[cfg(feature = "runtime")]
use crate::runtime::core::status_mapper;
#[cfg(feature = "runtime")]
use crate::runtime::task_manager::{TaskEvent, TaskManager};
#[cfg(feature = "runtime")]
use a2a_types::{TaskArtifactUpdateEvent, TaskState as A2ATaskState, TaskStatus};
#[cfg(feature = "runtime")]
use chrono::Utc;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "runtime")]
use std::sync::Arc;
/// Key/value store whose values are held as [`serde_json::Value`].
///
/// Values are serialized on [`save`](Self::save) and deserialized on
/// [`load`](Self::load), so any `Serialize` / `DeserializeOwned` type can be
/// stored under a string key.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct SessionState {
    #[serde(default)]
    data: HashMap<String, serde_json::Value>,
}

impl SessionState {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes `value` and stores it under `key`, replacing any previous entry.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ContextError`] if `value` cannot be serialized to JSON.
    pub fn save<T>(&mut self, key: &str, value: &T) -> Result<(), AgentError>
    where
        T: Serialize,
    {
        let serialized =
            serde_json::to_value(value).map_err(|e| AgentError::ContextError(e.to_string()))?;
        self.data.insert(key.to_string(), serialized);
        Ok(())
    }

    /// Loads the value stored under `key` and deserializes it as `T`.
    ///
    /// Returns `Ok(None)` when `key` is not present.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ContextError`] if the stored value cannot be
    /// deserialized as `T`.
    pub fn load<T>(&self, key: &str) -> Result<Option<T>, AgentError>
    where
        T: DeserializeOwned,
    {
        match self.data.get(key) {
            // `&serde_json::Value` implements `Deserializer`, so deserialize
            // straight from the borrow instead of deep-cloning the JSON tree.
            Some(value) => T::deserialize(value)
                .map(Some)
                .map_err(|e| AgentError::ContextError(e.to_string())),
            None => Ok(None),
        }
    }

    /// Removes `key`, returning its raw JSON value if it was present.
    pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
        self.data.remove(key)
    }

    /// Returns `true` if a value is stored under `key`.
    #[must_use]
    pub fn contains(&self, key: &str) -> bool {
        self.data.contains_key(key)
    }

    /// Number of stored entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
/// Key/value store plus one optional "slot" value, all held as
/// [`serde_json::Value`].
///
/// Keyed values go through [`save`](Self::save) / [`load`](Self::load); the
/// slot is managed separately with [`set_slot`](Self::set_slot),
/// [`slot`](Self::slot) and [`clear_slot`](Self::clear_slot).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TaskState {
    #[serde(default)]
    data: HashMap<String, serde_json::Value>,
    #[serde(default)]
    slot: Option<serde_json::Value>,
}

impl TaskState {
    /// Creates an empty state with no slot set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes `value` and stores it under `key`, replacing any previous entry.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ContextError`] if `value` cannot be serialized to JSON.
    pub fn save<T>(&mut self, key: &str, value: &T) -> Result<(), AgentError>
    where
        T: Serialize,
    {
        let serialized =
            serde_json::to_value(value).map_err(|e| AgentError::ContextError(e.to_string()))?;
        self.data.insert(key.to_string(), serialized);
        Ok(())
    }

    /// Loads the value stored under `key` and deserializes it as `T`.
    ///
    /// Returns `Ok(None)` when `key` is not present.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ContextError`] if the stored value cannot be
    /// deserialized as `T`.
    pub fn load<T>(&self, key: &str) -> Result<Option<T>, AgentError>
    where
        T: DeserializeOwned,
    {
        match self.data.get(key) {
            // Deserialize from the borrowed `&Value` (it implements
            // `Deserializer`) rather than deep-cloning it first.
            Some(value) => T::deserialize(value)
                .map(Some)
                .map_err(|e| AgentError::ContextError(e.to_string())),
            None => Ok(None),
        }
    }

    /// Removes `key`, returning its raw JSON value if it was present.
    pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
        self.data.remove(key)
    }

    /// Returns the current slot converted with
    /// `SkillSlot::from_value_unchecked`, or `None` if no slot is set.
    #[must_use]
    pub fn current_slot(&self) -> Option<crate::agent::SkillSlot> {
        self.slot
            .clone()
            .map(crate::agent::SkillSlot::from_value_unchecked)
    }

    /// Deserializes the current slot as `T`; `Ok(None)` if no slot is set.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::SkillSlot`] if the stored slot cannot be
    /// deserialized as `T`.
    pub fn slot<T>(&self) -> Result<Option<T>, AgentError>
    where
        T: DeserializeOwned,
    {
        match &self.slot {
            Some(value) => T::deserialize(value)
                .map(Some)
                .map_err(|e| AgentError::SkillSlot(e.to_string())),
            None => Ok(None),
        }
    }

    /// Serializes `slot` and makes it the current slot, replacing any previous one.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::SkillSlot`] if `slot` cannot be serialized to JSON;
    /// the previous slot is left untouched in that case.
    pub fn set_slot<T>(&mut self, slot: T) -> Result<(), AgentError>
    where
        T: Serialize,
    {
        let serialized =
            serde_json::to_value(slot).map_err(|e| AgentError::SkillSlot(e.to_string()))?;
        self.slot = Some(serialized);
        Ok(())
    }

    /// Clears the current slot.
    pub fn clear_slot(&mut self) {
        self.slot = None;
    }
}
/// Bundles a [`TaskState`] and a [`SessionState`] behind one handle.
#[derive(Debug, Default)]
pub struct State {
    task: TaskState,
    session: SessionState,
}

impl State {
    /// Creates a state with empty task and session parts.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a state from existing task and session parts.
    #[must_use]
    pub const fn with_states(task: TaskState, session: SessionState) -> Self {
        Self { task, session }
    }

    /// Mutable access to the task part.
    pub const fn task(&mut self) -> &mut TaskState {
        &mut self.task
    }

    /// Mutable access to the session part.
    pub const fn session(&mut self) -> &mut SessionState {
        &mut self.session
    }

    /// Shared access to the task part.
    #[must_use]
    pub const fn task_ref(&self) -> &TaskState {
        &self.task
    }

    /// Shared access to the session part.
    #[must_use]
    pub const fn session_ref(&self) -> &SessionState {
        &self.session
    }

    /// Splits the state back into its task and session parts.
    #[must_use]
    pub fn into_parts(self) -> (TaskState, SessionState) {
        (self.task, self.session)
    }

    /// Deserializes the task slot as `T`; shorthand for [`TaskState::slot`].
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::SkillSlot`] if the stored slot cannot be
    /// deserialized as `T`.
    pub fn slot<T>(&self) -> Result<Option<T>, AgentError>
    where
        T: DeserializeOwned,
    {
        self.task.slot()
    }

    /// Sets the task slot; shorthand for [`TaskState::set_slot`].
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::SkillSlot`] if `slot` cannot be serialized to JSON.
    pub fn set_slot<T>(&mut self, slot: T) -> Result<(), AgentError>
    where
        T: Serialize,
    {
        self.task.set_slot(slot)
    }

    /// Clears the task slot; shorthand for [`TaskState::clear_slot`].
    pub fn clear_slot(&mut self) {
        self.task.clear_slot();
    }

    /// Returns the task slot as a `SkillSlot`; shorthand for
    /// [`TaskState::current_slot`].
    #[must_use]
    pub fn current_slot(&self) -> Option<crate::agent::SkillSlot> {
        self.task.current_slot()
    }

    /// Stores an already-built `SkillSlot` as the task slot without
    /// re-serializing it.
    #[cfg(feature = "runtime")]
    pub(crate) fn set_pending_slot(&mut self, slot: crate::agent::SkillSlot) {
        self.task.slot = Some(slot.into_value());
    }

    /// Clears the task slot (runtime-internal alias of [`Self::clear_slot`]).
    #[cfg(feature = "runtime")]
    pub(crate) fn clear_pending_slot(&mut self) {
        self.task.clear_slot();
    }
}
/// Emits intermediate progress (status messages and partial artifacts) for a
/// single task.
///
/// Every event is first recorded through the task manager and then published
/// on the task event bus. A sender created with [`ProgressSender::noop`] has
/// neither, and its send methods return `Ok(())` without doing anything.
#[cfg(feature = "runtime")]
pub struct ProgressSender {
    auth: Option<AuthContext>,
    task_manager: Option<Arc<dyn TaskManager>>,
    task_id: String,
    context_id: String,
    event_bus: Option<Arc<TaskEventBus>>,
}

#[cfg(feature = "runtime")]
impl ProgressSender {
    pub(crate) fn new(
        auth: AuthContext,
        task_manager: Arc<dyn TaskManager>,
        event_bus: Arc<TaskEventBus>,
        context_id: impl Into<String>,
        task_id: impl Into<String>,
    ) -> Self {
        Self {
            auth: Some(auth),
            task_manager: Some(task_manager),
            context_id: context_id.into(),
            task_id: task_id.into(),
            event_bus: Some(event_bus),
        }
    }

    /// Creates a sender that discards every update.
    ///
    /// `const` to match the non-`runtime` build's `noop`, so code using it in
    /// a const context keeps compiling when the `runtime` feature is enabled.
    #[must_use]
    pub const fn noop() -> Self {
        Self {
            auth: None,
            task_manager: None,
            context_id: String::new(),
            task_id: String::new(),
            event_bus: None,
        }
    }

    /// Sends a `Working` status update carrying `message` (sent with the
    /// `Assistant` role).
    ///
    /// # Errors
    ///
    /// Propagates any error from the task manager's `add_task_event`. In that
    /// case the event is not published on the bus.
    pub async fn send_update(&self, message: impl Into<Content>) -> AgentResult<()> {
        let Some((auth, task_manager, event_bus)) = self.sink() else {
            return Ok(());
        };
        let now = Utc::now();
        let timestamp = ::pbjson_types::Timestamp {
            seconds: now.timestamp(),
            nanos: now.timestamp_subsec_nanos().cast_signed(),
        };
        let status = TaskStatus {
            state: A2ATaskState::Working as i32,
            timestamp: Some(timestamp),
            message: Some(utils::create_a2a_message(
                Some(&self.context_id),
                Some(&self.task_id),
                Role::Assistant,
                message.into(),
            )),
        };
        // NOTE(review): the trailing `false` is presumably the "final" flag of
        // the status update; confirm against `create_status_update_event`.
        let event = status_mapper::create_status_update_event(
            &self.task_id,
            &self.context_id,
            status,
            false,
        );
        Self::dispatch(auth, task_manager, event_bus, TaskEvent::StatusUpdate(event)).await
    }

    /// Sends `artifact` as an artifact update with `append` and `last_chunk`
    /// both set to `false`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the task manager's `add_task_event`. In that
    /// case the event is not published on the bus.
    pub async fn send_partial_artifact(&self, artifact: Artifact) -> AgentResult<()> {
        let Some((auth, task_manager, event_bus)) = self.sink() else {
            return Ok(());
        };
        let event = TaskArtifactUpdateEvent {
            task_id: self.task_id.clone(),
            context_id: self.context_id.clone(),
            artifact: Some(utils::artifact_to_a2a(&artifact)),
            append: false,
            last_chunk: false,
            metadata: None,
        };
        Self::dispatch(auth, task_manager, event_bus, TaskEvent::ArtifactUpdate(event)).await
    }

    /// Returns the auth context, task manager and event bus, or `None` if any
    /// of them is missing (as in a no-op sender).
    fn sink(&self) -> Option<(&AuthContext, &Arc<dyn TaskManager>, &Arc<TaskEventBus>)> {
        Some((
            self.auth.as_ref()?,
            self.task_manager.as_ref()?,
            self.event_bus.as_ref()?,
        ))
    }

    /// Records `task_event` with the task manager, then publishes it on the bus.
    async fn dispatch(
        auth: &AuthContext,
        task_manager: &Arc<dyn TaskManager>,
        event_bus: &Arc<TaskEventBus>,
        task_event: TaskEvent,
    ) -> AgentResult<()> {
        task_manager.add_task_event(auth, &task_event).await?;
        event_bus.publish(&task_event);
        Ok(())
    }
}
/// Stand-in progress sender used when the `runtime` feature is disabled.
///
/// It can only be built with [`ProgressSender::noop`], and every send method
/// ignores its argument and returns `Ok(())`.
#[cfg(not(feature = "runtime"))]
pub struct ProgressSender {
    _private: (),
}

#[cfg(not(feature = "runtime"))]
impl ProgressSender {
    /// Creates the no-op sender.
    #[must_use]
    pub const fn noop() -> Self {
        Self { _private: () }
    }

    /// Ignores `message` and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` mirrors the `runtime` build's signature.
    #[allow(
        clippy::unused_async,
        reason = "kept async so call sites are identical with and without the `runtime` feature"
    )]
    pub async fn send_update(
        &self,
        _message: impl Into<crate::models::Content>,
    ) -> Result<(), AgentError> {
        Ok(())
    }

    /// Ignores `artifact` and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` mirrors the `runtime` build's signature.
    #[allow(
        clippy::unused_async,
        reason = "kept async so call sites are identical with and without the `runtime` feature"
    )]
    pub async fn send_partial_artifact(
        &self,
        _artifact: crate::agent::Artifact,
    ) -> Result<(), AgentError> {
        Ok(())
    }
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AuthContext {
pub app_name: String,
pub user_name: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_state_save_load_roundtrip() {
        let mut session = SessionState::new();
        session.save("key", &42u32).expect("save");
        let value: Option<u32> = session.load("key").expect("load");
        assert_eq!(value, Some(42));
        let missing: Option<u32> = session.load("missing").expect("load");
        assert!(missing.is_none());
    }

    #[test]
    fn session_state_remove_and_contains() {
        let mut session = SessionState::new();
        assert!(!session.contains("key"));
        assert!(session.is_empty());
        session.save("key", &"value").expect("save");
        assert!(session.contains("key"));
        assert_eq!(session.len(), 1);
        session.remove("key");
        assert!(!session.contains("key"));
        assert!(session.is_empty());
    }

    #[test]
    fn task_state_save_load_roundtrip() {
        let mut task = TaskState::new();
        task.save("partial", &vec![1, 2, 3]).expect("save");
        let value: Option<Vec<i32>> = task.load("partial").expect("load");
        assert_eq!(value, Some(vec![1, 2, 3]));
    }

    #[test]
    fn task_state_slot_roundtrip() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        enum MySlot {
            NeedEmail,
            NeedPhone { name: String },
        }
        let mut task = TaskState::new();
        task.set_slot(MySlot::NeedEmail).expect("set slot");
        let slot: Option<MySlot> = task.slot().expect("get slot");
        assert_eq!(slot, Some(MySlot::NeedEmail));
        task.set_slot(MySlot::NeedPhone {
            name: "Alice".into(),
        })
        .expect("set slot");
        let slot: Option<MySlot> = task.slot().expect("get slot");
        assert_eq!(
            slot,
            Some(MySlot::NeedPhone {
                name: "Alice".into()
            })
        );
        task.clear_slot();
        let slot: Option<MySlot> = task.slot().expect("get slot");
        assert!(slot.is_none());
    }

    #[test]
    fn load_with_wrong_type_reports_context_error() {
        let mut session = SessionState::new();
        session.save("key", &"not a number").expect("save");
        let err = session.load::<u32>("key").expect_err("type mismatch");
        assert!(matches!(err, AgentError::ContextError(_)));

        let mut task = TaskState::new();
        task.save("key", &"not a number").expect("save");
        let err = task.load::<u32>("key").expect_err("type mismatch");
        assert!(matches!(err, AgentError::ContextError(_)));
    }

    #[test]
    fn slot_with_wrong_type_reports_skill_slot_error() {
        let mut task = TaskState::new();
        task.set_slot(42u32).expect("set slot");
        let err = task.slot::<String>().expect_err("type mismatch");
        assert!(matches!(err, AgentError::SkillSlot(_)));
    }

    #[test]
    fn state_slot_methods_target_task_state() {
        let mut state = State::new();
        let empty: Option<String> = state.slot().expect("get slot");
        assert!(empty.is_none());

        state.set_slot("need_email").expect("set slot");
        let slot: Option<String> = state.slot().expect("get slot");
        assert_eq!(slot.as_deref(), Some("need_email"));
        let via_task: Option<String> = state.task_ref().slot().expect("get slot");
        assert_eq!(via_task, slot);

        state.clear_slot();
        let cleared: Option<String> = state.task_ref().slot().expect("get slot");
        assert!(cleared.is_none());
    }

    #[test]
    fn state_provides_scoped_access() {
        let mut state = State::new();
        state.task().save("task_key", &"task_value").expect("save");
        let task_val: Option<String> = state.task().load("task_key").expect("load");
        assert_eq!(task_val, Some("task_value".to_string()));
        state
            .session()
            .save("session_key", &"session_value")
            .expect("save");
        let session_val: Option<String> = state.session().load("session_key").expect("load");
        assert_eq!(session_val, Some("session_value".to_string()));
        let task_missing: Option<String> = state.task().load("session_key").expect("load");
        assert!(task_missing.is_none());
    }

    #[test]
    fn state_with_existing_states() {
        let mut task = TaskState::new();
        task.save("key", &1).expect("save");
        let mut session = SessionState::new();
        session.save("key", &2).expect("save");
        let state = State::with_states(task, session);
        let task_val: Option<i32> = state.task_ref().load("key").expect("load");
        let session_val: Option<i32> = state.session_ref().load("key").expect("load");
        assert_eq!(task_val, Some(1));
        assert_eq!(session_val, Some(2));
    }

    #[test]
    fn state_into_parts() {
        let mut state = State::new();
        state.task().save("t", &1).expect("save");
        state.session().save("s", &2).expect("save");
        let (task, session) = state.into_parts();
        let t: Option<i32> = task.load("t").expect("load");
        let s: Option<i32> = session.load("s").expect("load");
        assert_eq!(t, Some(1));
        assert_eq!(s, Some(2));
    }
}