use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::session_graph::SessionReadModel;
use crate::session_graph::build_active_read_replacement;
use crate::session_model::{ProtocolEvent, SessionEventRecord};
use crate::store::GraphCommitDelta;
use crate::{BaseRenderCache, Message, MessageSequence, SessionGraph, SessionNodeRecord};
#[derive(Debug)]
pub(super) struct TurnGraphEditor {
base_graph: Arc<SessionGraph>,
active_events: Arc<Vec<SessionEventRecord>>,
active_messages: MessageSequence,
append_builder: crate::session_graph::SessionGraphAppendBuilder,
appended_nodes: Vec<SessionNodeRecord>,
appended_node_indices: HashMap<String, usize>,
committed_node_ids: HashSet<String>,
}
impl TurnGraphEditor {
pub(super) fn new(
base_graph: Arc<SessionGraph>,
base_read_model: SessionReadModel,
agent_frame_id: crate::AgentFrameId,
) -> Self {
let append_builder = base_graph
.append_builder()
.with_agent_frame_id(agent_frame_id);
let active_messages = MessageSequence::from_base(base_read_model.messages);
Self {
committed_node_ids: base_graph
.nodes
.iter()
.map(|node| node.node_id.clone())
.collect(),
base_graph,
active_events: base_read_model.active_events,
active_messages,
append_builder,
appended_nodes: Vec::new(),
appended_node_indices: HashMap::new(),
}
}
pub(super) fn base_graph(&self) -> Arc<SessionGraph> {
Arc::clone(&self.base_graph)
}
pub(super) fn message_sequence(&self) -> MessageSequence {
self.active_messages.clone()
}
pub(super) fn read_model(&self) -> SessionReadModel {
SessionReadModel {
active_events: Arc::clone(&self.active_events),
messages: self.active_messages.shared(),
prompt_render_cache: Arc::new(BaseRenderCache::new()),
}
}
pub(super) fn append_protocol_events<I>(&mut self, events: I)
where
I: IntoIterator<Item = ProtocolEvent>,
{
let events = events.into_iter().collect::<Vec<_>>();
if events.is_empty() {
return;
}
let nodes = self.append_builder.append_protocol_events(events);
Arc::make_mut(&mut self.active_events).extend(nodes.iter().filter_map(|node| {
if let Some(SessionEventRecord::Protocol(event)) = node.event() {
Some(SessionEventRecord::Protocol(event.clone()))
} else {
None
}
}));
self.append_appended_nodes(nodes);
}
pub(super) fn message_delta_if_current_preserved<'a>(
&self,
next: impl IntoIterator<Item = &'a Message>,
) -> Option<Vec<Message>> {
let mut current = self.active_messages.iter();
let mut appended = Vec::new();
for message in next.into_iter().filter(|message| !message.is_transient()) {
if let Some(current_message) = current.next() {
if current_message.id != message.id {
return None;
}
} else {
appended.push(message.clone());
}
}
current.next().is_none().then_some(appended)
}
pub(super) fn append_active_conversation_messages(&mut self, messages: &[Message]) {
let appendable_messages = messages
.iter()
.filter(|message| !message.is_transient())
.cloned()
.collect::<Vec<_>>();
if appendable_messages.is_empty() {
return;
}
let nodes = self
.append_builder
.append_messages(appendable_messages.clone());
Arc::make_mut(&mut self.active_events)
.extend(nodes.iter().filter_map(|node| node.event().cloned()));
self.append_appended_nodes(nodes);
self.active_messages.extend(appendable_messages);
}
pub(super) fn replace_active_read_state(&mut self, messages: &[Message]) {
let active_path = self.active_path_nodes();
let replacement = build_active_read_replacement(
active_path,
self.append_builder.existing_node_ids(),
self.append_builder.agent_frame_id(),
messages,
);
self.append_builder.register_existing_node_ids(
replacement
.new_tail_nodes
.iter()
.map(|node| node.node_id.as_str()),
);
self.append_appended_nodes(replacement.new_tail_nodes);
self.append_builder
.set_leaf_node_id(replacement.leaf_node_id.clone());
self.active_events = Arc::new(replacement.active_events);
self.active_messages = MessageSequence::from_owned(replacement.active_messages);
}
pub(super) fn graph_commit(&self, graph_replace_required: bool) -> GraphCommitDelta {
if graph_replace_required {
return GraphCommitDelta::ReplaceFull(self.materialized_graph());
}
let nodes = self
.appended_nodes
.iter()
.filter(|node| !self.committed_node_ids.contains(&node.node_id))
.cloned()
.collect::<Vec<_>>();
if nodes.is_empty() {
GraphCommitDelta::Unchanged {
leaf_node_id: self.leaf_node_id(),
}
} else {
GraphCommitDelta::Append {
nodes,
leaf_node_id: self.leaf_node_id(),
}
}
}
pub(super) fn mark_node_ids_persisted<I>(&mut self, node_ids: I)
where
I: IntoIterator<Item = String>,
{
self.committed_node_ids.extend(node_ids);
}
pub(super) fn replace_persisted_node_ids<I>(&mut self, node_ids: I)
where
I: IntoIterator<Item = String>,
{
self.committed_node_ids = node_ids.into_iter().collect();
}
pub(super) fn into_session_graph(self) -> SessionGraph {
let leaf_node_id = self.leaf_node_id();
if self.appended_nodes.is_empty() {
return Arc::try_unwrap(self.base_graph).unwrap_or_else(|graph| graph.as_ref().clone());
}
match Arc::try_unwrap(self.base_graph) {
Ok(mut graph) => {
graph.extend_node_records(self.appended_nodes);
graph.set_leaf_node_id(leaf_node_id);
graph
}
Err(base_graph) => {
let mut nodes =
Vec::with_capacity(base_graph.nodes.len() + self.appended_nodes.len());
nodes.extend(base_graph.nodes.iter().cloned());
nodes.extend(self.appended_nodes);
SessionGraph::from_nodes(nodes, leaf_node_id)
}
}
}
fn leaf_node_id(&self) -> Option<String> {
self.append_builder.leaf_node_id().cloned()
}
fn materialized_graph(&self) -> SessionGraph {
if self.appended_nodes.is_empty() {
return self.base_graph.as_ref().clone();
}
let mut nodes = Vec::with_capacity(self.base_graph.nodes.len() + self.appended_nodes.len());
nodes.extend(self.base_graph.nodes.iter().cloned());
nodes.extend(self.appended_nodes.iter().cloned());
SessionGraph::from_nodes(nodes, self.leaf_node_id())
}
fn active_path_nodes(&self) -> Vec<&SessionNodeRecord> {
let mut path = Vec::new();
let mut current = self.leaf_node_id();
while let Some(node_id) = current {
let Some(node) = self
.appended_node_indices
.get(node_id.as_str())
.and_then(|idx| self.appended_nodes.get(*idx))
.or_else(|| self.base_graph.find_node(node_id.as_str()))
else {
break;
};
path.push(node);
current = node.parent_node_id.clone();
}
path.reverse();
path
}
fn append_appended_nodes(&mut self, nodes: Vec<SessionNodeRecord>) {
self.appended_node_indices.reserve(nodes.len());
self.appended_nodes.reserve(nodes.len());
for node in nodes {
self.appended_node_indices
.insert(node.node_id.clone(), self.appended_nodes.len());
self.appended_nodes.push(node);
}
}
}