// lash-core 0.1.0-alpha.1
//
// Sans-IO turn machine and runtime kernel for the lash agent runtime.
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::session_graph::SessionReadModel;
use crate::session_graph::build_active_read_replacement;
use crate::session_graph::tool_call_record_active_read_key;
use crate::session_model::SessionEventRecord;
use crate::store::GraphCommitDelta;
use crate::{
    BaseRenderCache, Message, MessageSequence, SessionGraph, SessionNodeRecord, ToolCallRecord,
};

/// Turn-scoped, append-only view over a persisted [`SessionGraph`].
///
/// New nodes produced during a turn are accumulated here (not written into
/// `base_graph`) together with an incrementally maintained read model
/// (events, messages, tool calls) for the active path.
#[derive(Debug)]
pub(super) struct TurnGraphOverlay {
    /// Graph the overlay was created from; only consumed in `into_session_graph`.
    base_graph: Arc<SessionGraph>,
    /// Events on the active path, including events appended during this turn.
    active_events: Arc<Vec<SessionEventRecord>>,
    /// Messages on the active path, including appended ones.
    active_messages: MessageSequence,
    /// Tool call records on the active path; used for duplicate detection and
    /// exposed by reference via `graph_tool_calls`.
    graph_tool_calls: Vec<ToolCallRecord>,
    /// Same records as `graph_tool_calls`, kept behind an `Arc` so they can be
    /// shared cheaply via `tool_calls_arc` / `read_model`.
    read_tool_calls: Arc<Vec<ToolCallRecord>>,
    /// Allocates new nodes and tracks the current leaf node id.
    append_builder: crate::session_graph::SessionGraphAppendBuilder,
    /// Nodes created through this overlay, in append order.
    appended_nodes: Vec<SessionNodeRecord>,
    /// Maps a node id to its index in `appended_nodes`.
    appended_node_indices: HashMap<String, usize>,
    /// Node ids known to be persisted; seeded from `base_graph` and updated by
    /// `mark_node_ids_persisted` / `replace_persisted_node_ids`.
    committed_node_ids: HashSet<String>,
}

impl TurnGraphOverlay {
    pub(super) fn new(base_graph: Arc<SessionGraph>, base_read_model: SessionReadModel) -> Self {
        let append_builder = base_graph.append_builder();
        let active_messages = MessageSequence::from_base(base_read_model.messages);
        let graph_tool_calls = base_read_model.tool_calls.as_ref().clone();
        Self {
            committed_node_ids: base_graph
                .nodes
                .iter()
                .map(|node| node.node_id.clone())
                .collect(),
            base_graph,
            active_events: base_read_model.active_events,
            active_messages,
            graph_tool_calls,
            read_tool_calls: base_read_model.tool_calls,
            append_builder,
            appended_nodes: Vec::new(),
            appended_node_indices: HashMap::new(),
        }
    }

    pub(super) fn base_graph(&self) -> Arc<SessionGraph> {
        Arc::clone(&self.base_graph)
    }

    pub(super) fn message_sequence(&self) -> MessageSequence {
        self.active_messages.clone()
    }

    pub(super) fn tool_calls_arc(&self) -> Arc<Vec<ToolCallRecord>> {
        Arc::clone(&self.read_tool_calls)
    }

    pub(super) fn graph_tool_calls(&self) -> &[ToolCallRecord] {
        self.graph_tool_calls.as_slice()
    }

    #[allow(dead_code)]
    pub(super) fn read_model(&self) -> SessionReadModel {
        SessionReadModel {
            active_events: Arc::clone(&self.active_events),
            messages: self.active_messages.shared(),
            tool_calls: Arc::clone(&self.read_tool_calls),
            prompt_render_cache: Arc::new(BaseRenderCache::new()),
        }
    }

    pub(super) fn record_tool_calls<I>(&mut self, records: I)
    where
        I: IntoIterator<Item = ToolCallRecord>,
    {
        let mut seen_tool_call_keys = self
            .graph_tool_calls
            .iter()
            .map(tool_call_record_active_read_key)
            .collect::<HashSet<_>>();
        self.append_events(records.into_iter().filter_map(|record| {
            let stable_key = tool_call_record_active_read_key(&record);
            if seen_tool_call_keys.insert(stable_key.clone()) {
                Some(SessionEventRecord::Tool(
                    crate::session_model::ToolEvent::Invocation { stable_key, record },
                ))
            } else {
                tracing::debug!(
                    stable_key,
                    call_id = record.call_id.as_deref(),
                    tool = record.tool.as_str(),
                    "skipping duplicate tool call record"
                );
                None
            }
        }));
    }

    pub(super) fn append_events<I>(&mut self, events: I)
    where
        I: IntoIterator<Item = SessionEventRecord>,
    {
        for event in events {
            let node = self.append_builder.append_event_record(event.clone());
            self.append_appended_nodes(node);
            Arc::make_mut(&mut self.active_events).push(event.clone());
            match event {
                SessionEventRecord::Conversation(record) => {
                    self.active_messages.push(record.to_message());
                }
                SessionEventRecord::Tool(crate::session_model::ToolEvent::Invocation {
                    record,
                    ..
                }) => {
                    self.graph_tool_calls.push(record.clone());
                    Arc::make_mut(&mut self.read_tool_calls).push(record);
                }
                _ => {}
            }
        }
    }

    pub(super) fn message_delta_if_current_preserved<'a>(
        &self,
        next: impl IntoIterator<Item = &'a Message>,
    ) -> Option<Vec<Message>> {
        let mut current = self.active_messages.iter();
        let mut appended = Vec::new();
        for message in next.into_iter().filter(|message| !message.is_transient()) {
            if let Some(current_message) = current.next() {
                if current_message.id != message.id {
                    return None;
                }
            } else {
                appended.push(message.clone());
            }
        }
        current.next().is_none().then_some(appended)
    }

    pub(super) fn append_active_conversation_messages(&mut self, messages: &[Message]) {
        let appendable_messages = messages
            .iter()
            .filter(|message| !message.is_transient())
            .cloned()
            .collect::<Vec<_>>();
        if appendable_messages.is_empty() {
            return;
        }

        let nodes = self
            .append_builder
            .append_messages(appendable_messages.clone());
        Arc::make_mut(&mut self.active_events)
            .extend(nodes.iter().filter_map(|node| node.event().cloned()));
        self.append_appended_nodes(nodes);
        self.active_messages.extend(appendable_messages);
    }

    pub(super) fn append_active_read_delta(
        &mut self,
        messages: &[Message],
        tool_calls: &[ToolCallRecord],
    ) {
        let appendable_messages = {
            let mut seen_message_ids = self
                .active_messages
                .iter()
                .map(|message| message.id.as_str())
                .collect::<HashSet<_>>();
            messages
                .iter()
                .filter(|message| {
                    !message.is_transient() && seen_message_ids.insert(message.id.as_str())
                })
                .cloned()
                .collect::<Vec<_>>()
        };
        let mut seen_tool_call_keys = self
            .graph_tool_calls
            .iter()
            .map(tool_call_record_active_read_key)
            .collect::<HashSet<_>>();
        let appendable_tool_calls = tool_calls
            .iter()
            .filter(|record| seen_tool_call_keys.insert(tool_call_record_active_read_key(record)))
            .cloned()
            .collect::<Vec<_>>();

        if !appendable_messages.is_empty() {
            let nodes = self
                .append_builder
                .append_messages(appendable_messages.clone());
            Arc::make_mut(&mut self.active_events)
                .extend(nodes.iter().filter_map(|node| node.event().cloned()));
            self.append_appended_nodes(nodes);
            self.active_messages.extend(appendable_messages);
        }
        if !appendable_tool_calls.is_empty() {
            let nodes = self
                .append_builder
                .append_tool_call_records(appendable_tool_calls.clone());
            Arc::make_mut(&mut self.active_events)
                .extend(nodes.iter().filter_map(|node| node.event().cloned()));
            self.append_appended_nodes(nodes);
            self.graph_tool_calls.extend(appendable_tool_calls.clone());
            Arc::make_mut(&mut self.read_tool_calls).extend(appendable_tool_calls);
        }
    }

    pub(super) fn replace_active_read_state(
        &mut self,
        messages: &[Message],
        tool_calls: &[ToolCallRecord],
    ) {
        let active_path = self.active_path_nodes();
        let replacement = build_active_read_replacement(
            active_path,
            self.append_builder.existing_node_ids(),
            messages,
            tool_calls,
        );
        self.append_builder.register_existing_node_ids(
            replacement
                .new_tail_nodes
                .iter()
                .map(|node| node.node_id.as_str()),
        );
        self.append_appended_nodes(replacement.new_tail_nodes);
        self.append_builder
            .set_leaf_node_id(replacement.leaf_node_id.clone());
        self.active_events = Arc::new(replacement.active_events);
        self.active_messages = MessageSequence::from_owned(replacement.active_messages);
        self.graph_tool_calls = replacement.active_tool_calls.clone();
        self.read_tool_calls = Arc::new(replacement.active_tool_calls);
    }

    pub(super) fn graph_commit(&self, graph_replace_required: bool) -> GraphCommitDelta {
        if graph_replace_required {
            return GraphCommitDelta::ReplaceFull(self.materialized_graph());
        }

        let nodes = self
            .appended_nodes
            .iter()
            .filter(|node| !self.committed_node_ids.contains(&node.node_id))
            .cloned()
            .collect::<Vec<_>>();
        if nodes.is_empty() {
            GraphCommitDelta::Unchanged {
                leaf_node_id: self.leaf_node_id(),
            }
        } else {
            GraphCommitDelta::Append {
                nodes,
                leaf_node_id: self.leaf_node_id(),
            }
        }
    }

    pub(super) fn mark_node_ids_persisted<I>(&mut self, node_ids: I)
    where
        I: IntoIterator<Item = String>,
    {
        self.committed_node_ids.extend(node_ids);
    }

    pub(super) fn replace_persisted_node_ids<I>(&mut self, node_ids: I)
    where
        I: IntoIterator<Item = String>,
    {
        self.committed_node_ids = node_ids.into_iter().collect();
    }

    pub(super) fn into_session_graph(self) -> SessionGraph {
        let leaf_node_id = self.leaf_node_id();
        if self.appended_nodes.is_empty() {
            return Arc::try_unwrap(self.base_graph).unwrap_or_else(|graph| graph.as_ref().clone());
        }
        match Arc::try_unwrap(self.base_graph) {
            Ok(mut graph) => {
                graph.extend_node_records(self.appended_nodes);
                graph.set_leaf_node_id(leaf_node_id);
                graph
            }
            Err(base_graph) => {
                let mut nodes =
                    Vec::with_capacity(base_graph.nodes.len() + self.appended_nodes.len());
                nodes.extend(base_graph.nodes.iter().cloned());
                nodes.extend(self.appended_nodes);
                SessionGraph::from_nodes(nodes, leaf_node_id)
            }
        }
    }

    fn leaf_node_id(&self) -> Option<String> {
        self.append_builder.leaf_node_id().cloned()
    }

    fn materialized_graph(&self) -> SessionGraph {
        if self.appended_nodes.is_empty() {
            return self.base_graph.as_ref().clone();
        }
        let mut nodes = Vec::with_capacity(self.base_graph.nodes.len() + self.appended_nodes.len());
        nodes.extend(self.base_graph.nodes.iter().cloned());
        nodes.extend(self.appended_nodes.iter().cloned());
        SessionGraph::from_nodes(nodes, self.leaf_node_id())
    }

    fn active_path_nodes(&self) -> Vec<&SessionNodeRecord> {
        let mut path = Vec::new();
        let mut current = self.leaf_node_id();
        while let Some(node_id) = current {
            let Some(node) = self
                .appended_node_indices
                .get(node_id.as_str())
                .and_then(|idx| self.appended_nodes.get(*idx))
                .or_else(|| self.base_graph.find_node(node_id.as_str()))
            else {
                break;
            };
            path.push(node);
            current = node.parent_node_id.clone();
        }
        path.reverse();
        path
    }

    fn append_appended_nodes(&mut self, nodes: Vec<SessionNodeRecord>) {
        self.appended_node_indices.reserve(nodes.len());
        self.appended_nodes.reserve(nodes.len());
        for node in nodes {
            self.appended_node_indices
                .insert(node.node_id.clone(), self.appended_nodes.len());
            self.appended_nodes.push(node);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolCallOutput, session_model::ToolEvent};

    /// Builds a successful `lookup` tool call record identified by `call_id`.
    fn tool_record(call_id: &str) -> ToolCallRecord {
        let args = serde_json::json!({"q": "x"});
        let output = ToolCallOutput::success(serde_json::json!({"answer": "y"}));
        ToolCallRecord {
            call_id: Some(String::from(call_id)),
            tool: String::from("lookup"),
            args,
            output,
            duration_ms: 3,
        }
    }

    #[test]
    fn record_tool_calls_skips_duplicate_stable_keys() {
        let empty_graph = Arc::new(SessionGraph::default());
        let initial_read_model = empty_graph.read_model();
        let mut overlay = TurnGraphOverlay::new(empty_graph, initial_read_model);
        let duplicate = tool_record("call-1");

        // The same record twice must produce exactly one invocation.
        overlay.record_tool_calls([duplicate.clone(), duplicate]);

        let snapshot = overlay.read_model();
        assert_eq!(snapshot.tool_calls.len(), 1);

        let finished_graph = overlay.into_session_graph();
        let mut invocations = Vec::new();
        for node in finished_graph.active_path_nodes() {
            if let Some(SessionEventRecord::Tool(ToolEvent::Invocation { record, .. })) =
                node.event()
            {
                invocations.push(record);
            }
        }
        assert_eq!(invocations.len(), 1);
        assert_eq!(invocations[0].call_id.as_deref(), Some("call-1"));
    }
}