use crate::runtime::activity::ActivityManager;
use crate::runtime::run::delta::RunDelta;
use crate::runtime::run::RunIdentity;
use crate::runtime::state::SerializedStateAction;
use crate::runtime::suspended_calls_from_state;
use crate::runtime::tool_call::ToolCallContext;
use crate::runtime::tool_call::{CallerContext, SuspendedCall};
use crate::thread::Message;
use crate::RunPolicy;
use serde_json::Value;
use std::sync::{Arc, Mutex};
use tirea_state::{
apply_patches_with_registry, get_at_path, parse_path, DeltaTracked, DocCell, LatticeRegistry,
Op, State, TireaResult, TrackedPatch,
};
/// Run-scoped working state for a single thread.
///
/// Holds the state the run started from (`thread_base`) together with the
/// messages, patches, and serialized state actions accumulated while the run
/// progresses. The current state is obtained by replaying the recorded patches
/// on top of the base (see `snapshot`), and everything recorded since the last
/// drain can be collected with `take_delta`.
pub struct RunContext {
    /// State the context was constructed with; snapshots replay
    /// `thread_patches` on top of this value.
    thread_base: Value,
    /// Full message history. Messages supplied at construction form the
    /// baseline; only messages added afterwards are reported as delta.
    messages: DeltaTracked<Arc<Message>>,
    /// Patches recorded during this run. The tracker starts empty, so every
    /// recorded patch counts as delta.
    thread_patches: DeltaTracked<TrackedPatch>,
    /// Serialized state actions recorded during this run; surfaced only
    /// through `take_delta`.
    serialized_state_actions: DeltaTracked<SerializedStateAction>,
    /// Policy for this run; exposed via `run_policy()` and lent to tool call
    /// contexts.
    run_policy: RunPolicy,
    /// Thread/run/agent identity; `thread_id()` reads from here.
    run_identity: RunIdentity,
    /// Document cell initialised from the construction-time state and lent to
    /// tool call contexts.
    // NOTE(review): nothing in this file updates `doc` when patches are added,
    // so tool calls may observe the base state rather than the latest
    // snapshot — confirm this is intended.
    doc: DocCell,
    /// Persisted version; `None` until `set_version` is called.
    version: Option<u64>,
    /// Timestamp supplied by the most recent `set_version` call that provided one.
    version_timestamp: Option<u64>,
    /// Registry passed to `apply_patches_with_registry` when building snapshots.
    lattice_registry: Arc<LatticeRegistry>,
}
impl RunContext {
    /// Creates a context for `thread_id` over `state` and the initial
    /// `messages`, using a freshly constructed [`LatticeRegistry`].
    ///
    /// The initial messages form the baseline and are not reported as delta.
    pub fn new(
        thread_id: impl Into<String>,
        state: Value,
        messages: Vec<Arc<Message>>,
        run_policy: RunPolicy,
    ) -> Self {
        let fresh_registry = Arc::new(LatticeRegistry::new());
        Self::with_registry(thread_id, state, messages, run_policy, fresh_registry)
    }

    /// Like [`RunContext::new`], but merges patches using the supplied
    /// `lattice_registry`.
    pub fn with_registry(
        thread_id: impl Into<String>,
        state: Value,
        messages: Vec<Arc<Message>>,
        run_policy: RunPolicy,
        lattice_registry: Arc<LatticeRegistry>,
    ) -> Self {
        let owned_id: String = thread_id.into();
        let identity = RunIdentity::for_thread(owned_id);
        Self::with_registry_and_identity(state, messages, run_policy, identity, lattice_registry)
    }

    /// Fully explicit constructor: caller supplies the identity and registry.
    ///
    /// No patches or state actions are recorded yet and no version is set.
    pub fn with_registry_and_identity(
        state: Value,
        messages: Vec<Arc<Message>>,
        run_policy: RunPolicy,
        run_identity: RunIdentity,
        lattice_registry: Arc<LatticeRegistry>,
    ) -> Self {
        Self {
            // The document cell gets its own copy of the state; `thread_base`
            // keeps the original so snapshots can replay patches on top of it.
            doc: DocCell::new(state.clone()),
            thread_base: state,
            messages: DeltaTracked::new(messages),
            thread_patches: DeltaTracked::empty(),
            serialized_state_actions: DeltaTracked::empty(),
            run_policy,
            run_identity,
            version: None,
            version_timestamp: None,
            lattice_registry,
        }
    }

    /// Thread id taken from the current run identity.
    pub fn thread_id(&self) -> &str {
        &self.run_identity.thread_id
    }

    /// Policy governing this run.
    pub fn run_policy(&self) -> &RunPolicy {
        &self.run_policy
    }

    /// Identity (thread/run/agent) of this run.
    pub fn run_identity(&self) -> &RunIdentity {
        &self.run_identity
    }

    /// Replaces the run identity; `thread_id()` follows the new value.
    pub fn set_run_identity(&mut self, run_identity: RunIdentity) {
        self.run_identity = run_identity;
    }

    /// Persisted version, or `0` when no version has been set.
    pub fn version(&self) -> u64 {
        self.version.unwrap_or_default()
    }

    /// Records `version`. A `Some` timestamp replaces the stored one; `None`
    /// keeps whatever timestamp was stored previously.
    pub fn set_version(&mut self, version: u64, timestamp: Option<u64>) {
        self.version = Some(version);
        self.version_timestamp = timestamp.or(self.version_timestamp);
    }

    /// Timestamp associated with the stored version, if one was ever supplied.
    pub fn version_timestamp(&self) -> Option<u64> {
        self.version_timestamp
    }

    /// Suspended tool calls extracted from the current snapshot.
    ///
    /// Best effort: if the snapshot cannot be built, the error is discarded
    /// and an empty map is returned.
    pub fn suspended_calls(&self) -> std::collections::HashMap<String, SuspendedCall> {
        match self.snapshot() {
            Ok(current) => suspended_calls_from_state(&current),
            Err(_) => Default::default(),
        }
    }

    /// Full message history: the initial messages plus everything added since.
    pub fn messages(&self) -> &[Arc<Message>] {
        self.messages.as_slice()
    }

    /// Baseline message count reported by the message tracker.
    // NOTE(review): confirm whether `DeltaTracked::initial_count` advances
    // after `take_delta` or always reflects construction time.
    pub fn initial_message_count(&self) -> usize {
        self.messages.initial_count()
    }

    /// Appends one message; it will be part of the next delta.
    pub fn add_message(&mut self, msg: Arc<Message>) {
        self.messages.push(msg);
    }

    /// Appends several messages in order; they will be part of the next delta.
    pub fn add_messages(&mut self, batch: Vec<Arc<Message>>) {
        self.messages.extend(batch);
    }

    /// State the context was constructed with (no run patches applied).
    pub fn thread_base(&self) -> &Value {
        &self.thread_base
    }

    /// Records one patch; it is included in subsequent snapshots and the next delta.
    pub fn add_thread_patch(&mut self, patch: TrackedPatch) {
        self.thread_patches.push(patch);
    }

    /// Records several patches in order.
    pub fn add_thread_patches(&mut self, batch: Vec<TrackedPatch>) {
        self.thread_patches.extend(batch);
    }

    /// All patches recorded during this run, including already-taken ones.
    pub fn thread_patches(&self) -> &[TrackedPatch] {
        self.thread_patches.as_slice()
    }

    /// Records serialized state actions; they are returned by the next `take_delta`.
    pub fn add_serialized_state_actions(&mut self, actions: Vec<SerializedStateAction>) {
        self.serialized_state_actions.extend(actions);
    }

    /// Current state: `thread_base` with every recorded patch, in recording
    /// order, passed to `apply_patches_with_registry`.
    ///
    /// # Errors
    ///
    /// Propagates any error from applying the patches.
    pub fn snapshot(&self) -> TireaResult<Value> {
        let recorded = self.thread_patches.as_slice();
        if recorded.is_empty() {
            // Nothing to replay: the base is the current state.
            return Ok(self.thread_base.clone());
        }
        apply_patches_with_registry(
            &self.thread_base,
            recorded.iter().map(|tracked| tracked.patch()),
            &self.lattice_registry,
        )
    }

    /// Deserializes `T` from the current snapshot at `T::PATH`.
    ///
    /// A missing path is treated as `null`, so the result depends on whether
    /// `T` accepts `null`.
    ///
    /// # Errors
    ///
    /// Fails if the snapshot cannot be built or `T::from_value` rejects the value.
    pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
        let root = self.snapshot()?;
        let segments = parse_path(T::PATH);
        let node = get_at_path(&root, &segments).unwrap_or(&Value::Null);
        T::from_value(node)
    }

    /// Deserializes `T` from the current snapshot at an explicit `path`.
    ///
    /// A missing path is treated as `null`.
    ///
    /// # Errors
    ///
    /// Fails if the snapshot cannot be built or `T::from_value` rejects the value.
    pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
        let root = self.snapshot()?;
        let segments = parse_path(path);
        let node = get_at_path(&root, &segments).unwrap_or(&Value::Null);
        T::from_value(node)
    }

    /// Drains the messages, patches, and state actions recorded since the
    /// previous call. The full history remains available via the accessors.
    pub fn take_delta(&mut self) -> RunDelta {
        let messages = self.messages.take_delta();
        let patches = self.thread_patches.take_delta();
        let state_actions = self.serialized_state_actions.take_delta();
        RunDelta {
            messages,
            patches,
            state_actions,
        }
    }

    /// `true` if any messages, patches, or state actions are waiting to be taken.
    pub fn has_delta(&self) -> bool {
        let messages_pending = self.messages.has_delta();
        let patches_pending = self.thread_patches.has_delta();
        let actions_pending = self.serialized_state_actions.has_delta();
        messages_pending || patches_pending || actions_pending
    }

    /// Builds a [`ToolCallContext`] bound to this context's document cell and
    /// run policy.
    ///
    /// The tool context also receives a clone of the run identity and a
    /// [`CallerContext`] carrying the thread id, the optional run and agent
    /// ids, and a copy of the current message list (the `Arc`s are cloned, not
    /// the messages).
    pub fn tool_call_context<'ctx>(
        &'ctx self,
        ops: &'ctx Mutex<Vec<Op>>,
        call_id: impl Into<String>,
        source: impl Into<String>,
        pending_messages: &'ctx Mutex<Vec<Arc<Message>>>,
        activity_manager: Arc<dyn ActivityManager>,
    ) -> ToolCallContext<'ctx> {
        let caller_thread = Some(self.thread_id().to_owned());
        let caller_run = self.run_identity.run_id_opt().map(ToOwned::to_owned);
        let caller_agent = self.run_identity.agent_id_opt().map(ToOwned::to_owned);
        let history = self.messages().to_vec();
        let caller_context = CallerContext::new(caller_thread, caller_run, caller_agent, history);
        let identity = self.run_identity.clone();

        ToolCallContext::new(
            &self.doc,
            ops,
            call_id,
            source,
            &self.run_policy,
            pending_messages,
            activity_manager,
        )
        .with_run_identity(identity)
        .with_caller_context(caller_context)
    }
}
impl RunContext {
    /// Builds a context from a persisted thread using a freshly constructed
    /// [`LatticeRegistry`] and an identity derived from the thread id.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Thread::rebuild_state`.
    pub fn from_thread(
        thread: &crate::thread::Thread,
        run_policy: RunPolicy,
    ) -> Result<Self, tirea_state::TireaError> {
        Self::from_thread_with_registry(thread, run_policy, Arc::new(LatticeRegistry::new()))
    }

    /// Like [`RunContext::from_thread`], but with a caller-supplied registry.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Thread::rebuild_state`.
    pub fn from_thread_with_registry(
        thread: &crate::thread::Thread,
        run_policy: RunPolicy,
        lattice_registry: Arc<LatticeRegistry>,
    ) -> Result<Self, tirea_state::TireaError> {
        let identity = RunIdentity::for_thread(thread.id.clone());
        Self::from_thread_with_registry_and_identity(thread, run_policy, identity, lattice_registry)
    }

    /// Builds a context from a persisted thread with an explicit identity
    /// and registry.
    ///
    /// Missing thread / parent-thread ids in `run_identity` are filled in from
    /// `thread`; ids already present are kept. The thread's persisted patches
    /// are folded into the base state via `rebuild_state`, its messages become
    /// the initial (non-delta) messages, and its version metadata is copied
    /// when a version is present.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Thread::rebuild_state`.
    pub fn from_thread_with_registry_and_identity(
        thread: &crate::thread::Thread,
        run_policy: RunPolicy,
        mut run_identity: RunIdentity,
        lattice_registry: Arc<LatticeRegistry>,
    ) -> Result<Self, tirea_state::TireaError> {
        if run_identity.thread_id_opt().is_none() {
            run_identity.thread_id.clone_from(&thread.id);
        }
        if run_identity.parent_thread_id_opt().is_none() {
            run_identity
                .parent_thread_id
                .clone_from(&thread.parent_thread_id);
        }

        let state = thread.rebuild_state()?;
        let mut ctx = Self::with_registry_and_identity(
            state,
            thread.messages.clone(),
            run_policy,
            run_identity,
            lattice_registry,
        );

        if let Some(persisted_version) = thread.metadata.version {
            ctx.set_version(persisted_version, thread.metadata.version_timestamp);
        }
        Ok(ctx)
    }

    /// Registry used to merge patches when building snapshots.
    pub fn lattice_registry(&self) -> &Arc<LatticeRegistry> {
        &self.lattice_registry
    }
}
impl std::fmt::Debug for RunContext {
    // Compact summary: thread id, message and patch counts, and whether an
    // un-taken delta remains. Message and patch contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("RunContext");
        summary.field("thread_id", &self.thread_id());
        summary.field("messages", &self.messages.len());
        summary.field("thread_patches", &self.thread_patches.len());
        summary.field("has_delta", &self.has_delta());
        summary.finish()
    }
}
// Unit tests for `RunContext`: delta tracking of messages and patches, typed
// snapshots, construction from a `Thread`, and version metadata.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Patch};

    // Messages supplied at construction are baseline, not delta; draining an
    // empty delta leaves the history intact.
    #[test]
    fn new_context_has_no_delta() {
        let msgs = vec![Arc::new(Message::user("hi"))];
        let mut ctx = RunContext::new("t-1", json!({}), msgs, RunPolicy::default());
        assert!(!ctx.has_delta());
        let delta = ctx.take_delta();
        assert!(delta.is_empty());
        assert_eq!(ctx.messages().len(), 1);
    }

    // Messages added after construction show up in the delta; taking the
    // delta clears the pending flag but keeps the messages.
    #[test]
    fn add_message_creates_delta() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.add_message(Arc::new(Message::user("hello")));
        ctx.add_message(Arc::new(Message::assistant("hi")));
        assert!(ctx.has_delta());
        let delta = ctx.take_delta();
        assert_eq!(delta.messages.len(), 2);
        assert!(delta.patches.is_empty());
        assert!(!ctx.has_delta());
        assert_eq!(ctx.messages().len(), 2);
    }

    // A recorded patch is reported as delta exactly once.
    #[test]
    fn add_patch_creates_delta() {
        let mut ctx = RunContext::new("t-1", json!({"a": 1}), vec![], RunPolicy::default());
        let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(2))));
        ctx.add_thread_patch(patch);
        assert!(ctx.has_delta());
        let delta = ctx.take_delta();
        assert_eq!(delta.patches.len(), 1);
        assert!(!ctx.has_delta());
    }

    // Each `take_delta` returns only what was added since the previous call.
    #[test]
    fn multiple_deltas() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.add_message(Arc::new(Message::user("a")));
        let d1 = ctx.take_delta();
        assert_eq!(d1.messages.len(), 1);
        ctx.add_message(Arc::new(Message::user("b")));
        ctx.add_message(Arc::new(Message::user("c")));
        let d2 = ctx.take_delta();
        assert_eq!(d2.messages.len(), 2);
        let d3 = ctx.take_delta();
        assert!(d3.is_empty());
    }

    // Pre-existing messages never enter the delta, while later additions do.
    #[test]
    fn initial_messages_excluded_from_delta() {
        let initial = vec![
            Arc::new(Message::user("pre-existing-1")),
            Arc::new(Message::assistant("pre-existing-2")),
        ];
        let mut ctx = RunContext::new("t-1", json!({}), initial, RunPolicy::default());
        assert!(!ctx.has_delta());
        let delta = ctx.take_delta();
        assert!(delta.messages.is_empty());
        assert_eq!(ctx.messages().len(), 2);
        ctx.add_message(Arc::new(Message::user("run-added")));
        let delta = ctx.take_delta();
        assert_eq!(delta.messages.len(), 1);
        assert_eq!(delta.messages[0].content, "run-added");
        assert_eq!(ctx.messages().len(), 3);
    }

    // The patch tracker starts empty, so every patch recorded during the run
    // is part of the delta.
    #[test]
    fn all_patches_are_delta() {
        let mut ctx = RunContext::new("t-1", json!({"a": 0}), vec![], RunPolicy::default());
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("a"), json!(1))),
        ));
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("a"), json!(2))),
        ));
        let delta = ctx.take_delta();
        assert_eq!(delta.patches.len(), 2, "all run patches should be in delta");
    }

    // Successive deltas do not overlap, and draining them does not remove
    // anything from the full message/patch history.
    #[test]
    fn consecutive_take_delta_non_overlapping() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.add_message(Arc::new(Message::user("m1")));
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("x"), json!(1))),
        ));
        let d1 = ctx.take_delta();
        assert_eq!(d1.messages.len(), 1);
        assert_eq!(d1.patches.len(), 1);
        ctx.add_message(Arc::new(Message::user("m2")));
        ctx.add_message(Arc::new(Message::user("m3")));
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("y"), json!(2))),
        ));
        let d2 = ctx.take_delta();
        assert_eq!(d2.messages.len(), 2);
        assert_eq!(d2.patches.len(), 1);
        let d3 = ctx.take_delta();
        assert!(d3.is_empty());
        assert_eq!(ctx.messages().len(), 3);
        assert_eq!(ctx.thread_patches().len(), 2);
    }

    // `snapshot_of` reads from the fixture type's own `PATH`; the state is
    // placed under `__test_fixture`, presumably the value of
    // `TestFixtureState::PATH`.
    #[test]
    fn snapshot_of_deserializes_at_canonical_path() {
        use crate::testing::TestFixtureState;
        let ctx = RunContext::new(
            "t-1",
            json!({"__test_fixture": {"label": null}}),
            vec![],
            RunPolicy::default(),
        );
        let ctrl: TestFixtureState = ctx.snapshot_of().unwrap();
        assert!(ctrl.label.is_none());
    }

    // `snapshot_at` reads from the caller-supplied path instead of `T::PATH`.
    #[test]
    fn snapshot_at_deserializes_at_explicit_path() {
        use crate::testing::TestFixtureState;
        let ctx = RunContext::new(
            "t-1",
            json!({"custom": {"label": null}}),
            vec![],
            RunPolicy::default(),
        );
        let ctrl: TestFixtureState = ctx.snapshot_at("custom").unwrap();
        assert!(ctrl.label.is_none());
    }

    // A missing path resolves to `null`, which the fixture type is expected
    // to reject, so the typed snapshot fails.
    #[test]
    fn snapshot_of_returns_error_for_missing_path() {
        use crate::testing::TestFixtureState;
        let ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        assert!(ctx.snapshot_of::<TestFixtureState>().is_err());
    }

    // Persisted thread patches are folded into `thread_base`; the run itself
    // starts with no recorded patches.
    #[test]
    fn from_thread_rebuilds_existing_patches() {
        use crate::thread::Thread;
        let mut thread = Thread::with_initial_state("t-1", json!({"counter": 0}));
        thread.patches.push(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("counter"), json!(5))),
        ));
        let ctx = RunContext::from_thread(&thread, RunPolicy::default()).unwrap();
        assert_eq!(ctx.thread_base()["counter"], 5);
        assert!(ctx.thread_patches().is_empty());
        assert_eq!(ctx.snapshot().unwrap()["counter"], 5);
    }

    // Version and timestamp from thread metadata are copied into the context.
    #[test]
    fn from_thread_carries_version_metadata() {
        use crate::thread::Thread;
        let mut thread = Thread::new("t-1");
        thread.metadata.version = Some(42);
        thread.metadata.version_timestamp = Some(1700000000);
        let ctx = RunContext::from_thread(&thread, RunPolicy::default()).unwrap();
        assert_eq!(ctx.version(), 42);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
    }

    // A persisted patch that cannot be applied (presumably because `Append`
    // targets a number rather than an array) makes construction fail instead
    // of silently dropping the patch.
    #[test]
    fn from_thread_broken_patch_returns_error() {
        use crate::thread::Thread;
        let mut thread = Thread::with_initial_state("t-1", json!({"x": 1}));
        thread.patches.push(TrackedPatch::new(Patch::with_ops(vec![
            tirea_state::Op::Append {
                path: path!("x"),
                value: json!(999),
            },
        ])));
        let result = RunContext::from_thread(&thread, RunPolicy::default());
        assert!(
            result.is_err(),
            "broken patch should cause from_thread to fail"
        );
    }

    // Without `set_version`, the version reads as 0 and there is no timestamp.
    #[test]
    fn version_defaults_to_zero() {
        let ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        assert_eq!(ctx.version(), 0);
        assert_eq!(ctx.version_timestamp(), None);
    }

    // `set_version` with a `None` timestamp keeps the previously stored one.
    #[test]
    fn set_version_updates_correctly() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.set_version(5, Some(1700000000));
        assert_eq!(ctx.version(), 5);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
        ctx.set_version(6, None);
        assert_eq!(ctx.version(), 6);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
    }
}