Skip to main content

lash_core/
session_graph.rs

1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3use std::sync::{Arc, OnceLock};
4
5use crate::session_model::{ConversationRecord, ProtocolEvent, SessionEventRecord};
6use crate::{BaseRenderCache, Clock, Message, MessageRole, PromptUsage, TokenUsage};
7
/// Serializable contents of a session graph: its nodes plus the id of the
/// node that currently ends the active path.
///
/// Both fields default when absent from the input; `leaf_node_id` is
/// omitted from the serialized output when it is `None`.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct SessionGraphData {
    /// All nodes, in the order they were appended. Nodes refer to their
    /// parent by id (`SessionNodeRecord::parent_node_id`), so several
    /// branches can live in this one list.
    #[serde(default)]
    pub nodes: Vec<SessionNodeRecord>,
    /// Id of the node at the tip of the active path, if any. Not guaranteed
    /// to name a node present in `nodes` (see `SessionGraph::heal_orphaned_leaf`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub leaf_node_id: Option<String>,
}
15
16#[derive(Debug)]
17pub struct SessionGraph {
18    inner: Arc<SessionGraphData>,
19    cache: Arc<OnceLock<SessionGraphCache>>,
20}
21
22impl Default for SessionGraph {
23    fn default() -> Self {
24        Self {
25            inner: Arc::new(SessionGraphData::default()),
26            cache: Arc::new(OnceLock::new()),
27        }
28    }
29}
30
31impl Clone for SessionGraph {
32    fn clone(&self) -> Self {
33        Self {
34            inner: Arc::clone(&self.inner),
35            cache: Arc::clone(&self.cache),
36        }
37    }
38}
39
40impl serde::Serialize for SessionGraph {
41    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
42    where
43        S: serde::Serializer,
44    {
45        self.inner.serialize(serializer)
46    }
47}
48
49impl<'de> serde::Deserialize<'de> for SessionGraph {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: serde::Deserializer<'de>,
53    {
54        let inner = SessionGraphData::deserialize(deserializer)?;
55        Ok(Self {
56            inner: Arc::new(inner),
57            cache: Arc::new(OnceLock::new()),
58        })
59    }
60}
61
62impl Deref for SessionGraph {
63    type Target = SessionGraphData;
64
65    fn deref(&self) -> &Self::Target {
66        self.inner.as_ref()
67    }
68}
69
/// One persisted node of the session graph.
///
/// `payload` is `#[serde(flatten)]`ed, so its `kind` tag and variant fields
/// appear next to the fields below in the serialized form.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SessionNodeRecord {
    /// Id of this node; the key used by parent links and the leaf pointer.
    pub node_id: String,
    /// Id of the node this one was appended after; `None` for a root.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parent_node_id: Option<String>,
    /// What caused this node, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub caused_by: Option<crate::CausalRef>,
    /// Agent frame the node was appended under; `None` when appended
    /// without one (NOTE(review): presumably what "unscoped" means in the
    /// agent-frame read model — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_frame_id: Option<crate::AgentFrameId>,
    /// Append time as a string; the append paths in this file pass
    /// `timestamp_rfc3339()` output or a caller-supplied string.
    pub timestamp: String,
    /// Event or plugin content of the node.
    #[serde(flatten)]
    pub payload: SessionNodePayload,
}
83
/// Content for a node that has not been appended yet.
///
/// `SessionGraphAppendBuilder::append_drafts_at` turns each draft into a
/// `SessionNodeRecord`, assigning its id, parent, agent frame and timestamp.
#[derive(Clone, Debug)]
pub(crate) struct SessionNodeDraft {
    /// What the future node will carry.
    payload: SessionNodeDraftPayload,
    /// Explicit cause. For message drafts the builder falls back to a cause
    /// derived from the message origin when this is `None`.
    caused_by: Option<crate::CausalRef>,
}

/// Draft content, before conversion into a `SessionNodePayload`.
#[derive(Clone, Debug)]
enum SessionNodeDraftPayload {
    /// A message; stored as a conversation event record.
    Message(Message),
    /// Plugin data; `body` is wrapped in a `SharedJsonValue` when appended.
    Plugin {
        plugin_type: String,
        body: serde_json::Value,
    },
    /// A protocol event; stored as a protocol event record.
    ProtocolEvent(ProtocolEvent),
}
99
100impl SessionNodeDraft {
101    pub(crate) fn message(message: Message) -> Self {
102        Self {
103            payload: SessionNodeDraftPayload::Message(message),
104            caused_by: None,
105        }
106    }
107
108    pub(crate) fn plugin(plugin_type: impl Into<String>, body: serde_json::Value) -> Self {
109        Self {
110            payload: SessionNodeDraftPayload::Plugin {
111                plugin_type: plugin_type.into(),
112                body,
113            },
114            caused_by: None,
115        }
116    }
117
118    pub(crate) fn protocol_event(event: ProtocolEvent) -> Self {
119        Self {
120            payload: SessionNodeDraftPayload::ProtocolEvent(event),
121            caused_by: None,
122        }
123    }
124
125    pub(crate) fn event(event: SessionEventRecord) -> Self {
126        match event {
127            SessionEventRecord::Conversation(record) => Self::message(record.to_message()),
128            SessionEventRecord::Protocol(event) => Self::protocol_event(event),
129        }
130    }
131
132    pub(crate) fn with_caused_by(mut self, caused_by: Option<crate::CausalRef>) -> Self {
133        self.caused_by = caused_by;
134        self
135    }
136}
137
138#[derive(Clone, Debug, Default, PartialEq)]
139pub struct SharedJsonValue(pub Arc<serde_json::Value>);
140
141impl SharedJsonValue {
142    pub fn new(value: serde_json::Value) -> Self {
143        Self(Arc::new(value))
144    }
145
146    pub fn to_owned(&self) -> serde_json::Value {
147        self.0.as_ref().clone()
148    }
149}
150
151impl AsRef<serde_json::Value> for SharedJsonValue {
152    fn as_ref(&self) -> &serde_json::Value {
153        self.0.as_ref()
154    }
155}
156
157impl serde::Serialize for SharedJsonValue {
158    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
159    where
160        S: serde::Serializer,
161    {
162        self.0.serialize(serializer)
163    }
164}
165
166impl<'de> serde::Deserialize<'de> for SharedJsonValue {
167    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
168    where
169        D: serde::Deserializer<'de>,
170    {
171        let value = serde_json::Value::deserialize(deserializer)?;
172        Ok(Self::new(value))
173    }
174}
175
/// Content of a persisted node, internally tagged by a `kind` field whose
/// value is `"event"` or `"plugin"`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
// NOTE(review): `Event` is left unboxed despite the size difference,
// presumably to keep the common variant allocation-free — confirm.
#[allow(clippy::large_enum_variant)]
pub enum SessionNodePayload {
    /// A conversation or protocol event record.
    Event {
        event: SessionEventRecord,
    },
    /// Plugin data; `plugin_type` identifies the plugin and `body` is its
    /// JSON payload.
    Plugin {
        plugin_type: String,
        body: SharedJsonValue,
    },
}
188
/// Provider and model selection stored with a session.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct PersistedSessionConfig {
    /// Provider identifier (presumably the provider serving `model` — confirm).
    pub provider_id: String,
    /// Model specification.
    pub model: crate::ModelSpec,
}

/// Turn-level state stored with a session.
///
/// `turn_index` is required when deserializing; every other field falls
/// back to its `Default` when absent.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct PersistedTurnState {
    /// Turn counter. NOTE(review): not visible here whether this counts
    /// completed turns or names the next turn — confirm.
    pub turn_index: usize,
    /// Token usage totals (presumably accumulated across turns — confirm).
    #[serde(default)]
    pub token_usage: TokenUsage,
    /// Prompt usage of the latest request, when recorded; omitted from the
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_prompt_usage: Option<PromptUsage>,
    /// Protocol-level turn options.
    #[serde(default)]
    pub protocol_turn_options: crate::ProtocolTurnOptions,
}
205
/// A message node in a tree view of the graph's messages.
#[derive(Clone, Debug)]
pub struct SessionMessageTreeNode {
    /// Id of the graph node holding `message`.
    pub node_id: String,
    /// Id of the parent node in the message tree (NOTE(review): presumably
    /// the nearest message-bearing ancestor, skipping non-message nodes — confirm).
    pub parent_message_node_id: Option<String>,
    /// The message carried by this node.
    pub message: Message,
    /// The node's timestamp string.
    pub timestamp: String,
    /// Child message nodes.
    pub children: Vec<SessionMessageTreeNode>,
    /// NOTE(review): presumably whether this node lies on the active path — confirm.
    pub active: bool,
}

/// Result of `build_active_read_replacement`: how to rewrite the graph so
/// its active path reflects a new list of messages.
#[derive(Clone, Debug)]
pub(crate) struct ActiveReadReplacement {
    /// Leaf to install after the rewrite.
    pub(crate) leaf_node_id: Option<String>,
    /// Nodes to append to the graph.
    pub(crate) new_tail_nodes: Vec<SessionNodeRecord>,
    /// Events of the resulting active path (not consumed by
    /// `replace_active_read_state_scoped`, which rebuilds its cache instead).
    pub(crate) active_events: Vec<SessionEventRecord>,
    /// Messages of the resulting active path (likewise not consumed there).
    pub(crate) active_messages: Vec<Message>,
}

/// Snapshot of the active path as readers see it. All fields are `Arc`s,
/// so cloning is cheap.
#[derive(Clone, Debug)]
pub(crate) struct SessionReadModel {
    /// Events of the active-path nodes, in path order.
    pub(crate) active_events: Arc<Vec<SessionEventRecord>>,
    /// Non-transient active-path messages, first occurrence of each id only.
    pub(crate) messages: Arc<Vec<Message>>,
    /// Render cache for `messages`; for the unscoped read model this `Arc`
    /// is shared with the graph cache.
    pub(crate) prompt_render_cache: Arc<BaseRenderCache>,
}
230
/// Turns drafts into `SessionNodeRecord`s that form a chain hanging off
/// `leaf_node_id`, assigning ids and parent links as it goes.
#[derive(Clone, Debug)]
pub(crate) struct SessionGraphAppendBuilder {
    /// Node ids already in use. Passed to the id generators (presumably so
    /// they avoid collisions); each produced id is added here.
    existing_ids: HashSet<String>,
    /// Parent for the next produced node; advanced to each node produced.
    leaf_node_id: Option<String>,
    /// Agent frame stamped on every produced node.
    agent_frame_id: Option<crate::AgentFrameId>,
}
237
238impl SessionGraphAppendBuilder {
239    pub(crate) fn with_agent_frame_id(
240        mut self,
241        agent_frame_id: impl Into<crate::AgentFrameId>,
242    ) -> Self {
243        self.agent_frame_id = Some(agent_frame_id.into());
244        self
245    }
246
247    pub(crate) fn agent_frame_id(&self) -> Option<&str> {
248        self.agent_frame_id.as_deref()
249    }
250
251    pub(crate) fn leaf_node_id(&self) -> Option<&String> {
252        self.leaf_node_id.as_ref()
253    }
254
255    pub(crate) fn set_leaf_node_id(&mut self, leaf_node_id: Option<String>) {
256        self.leaf_node_id = leaf_node_id;
257    }
258
259    pub(crate) fn register_existing_node_ids<'a>(
260        &mut self,
261        node_ids: impl IntoIterator<Item = &'a str>,
262    ) {
263        self.existing_ids
264            .extend(node_ids.into_iter().map(ToOwned::to_owned));
265    }
266
267    pub(crate) fn existing_node_ids(&self) -> &HashSet<String> {
268        &self.existing_ids
269    }
270
271    pub(crate) fn append_messages_at<I>(
272        &mut self,
273        messages: I,
274        timestamp: String,
275    ) -> Vec<SessionNodeRecord>
276    where
277        I: IntoIterator<Item = Message>,
278    {
279        self.append_drafts_at(
280            messages.into_iter().map(SessionNodeDraft::message),
281            timestamp,
282        )
283    }
284
285    pub(crate) fn append_events_at<I>(
286        &mut self,
287        events: I,
288        timestamp: String,
289    ) -> Vec<SessionNodeRecord>
290    where
291        I: IntoIterator<Item = SessionEventRecord>,
292    {
293        self.append_drafts_at(events.into_iter().map(SessionNodeDraft::event), timestamp)
294    }
295
296    pub(crate) fn append_drafts_at<I>(
297        &mut self,
298        drafts: I,
299        timestamp: String,
300    ) -> Vec<SessionNodeRecord>
301    where
302        I: IntoIterator<Item = SessionNodeDraft>,
303    {
304        let mut nodes = Vec::new();
305        for draft in drafts {
306            let parent_node_id = self.leaf_node_id.clone();
307            let (node_id, caused_by, payload) = match draft.payload {
308                SessionNodeDraftPayload::Message(mut message) => {
309                    if message.id.is_empty() {
310                        message.id = fresh_node_id("m");
311                    }
312                    let node_id = unique_message_node_id(&message.id, &self.existing_ids);
313                    let caused_by = draft
314                        .caused_by
315                        .or_else(|| causal_ref_from_message_origin(&message.origin));
316                    (
317                        node_id,
318                        caused_by,
319                        SessionNodePayload::Event {
320                            event: SessionEventRecord::Conversation(
321                                ConversationRecord::from_message(message),
322                            ),
323                        },
324                    )
325                }
326                SessionNodeDraftPayload::Plugin { plugin_type, body } => {
327                    let node_id = fresh_semantic_node_id("plugin", &self.existing_ids);
328                    (
329                        node_id,
330                        draft.caused_by,
331                        SessionNodePayload::Plugin {
332                            plugin_type,
333                            body: SharedJsonValue::new(body),
334                        },
335                    )
336                }
337                SessionNodeDraftPayload::ProtocolEvent(event) => {
338                    let node_id = fresh_semantic_node_id("protocol", &self.existing_ids);
339                    (
340                        node_id,
341                        draft.caused_by,
342                        SessionNodePayload::Event {
343                            event: SessionEventRecord::Protocol(event),
344                        },
345                    )
346                }
347            };
348            self.existing_ids.insert(node_id.clone());
349            self.leaf_node_id = Some(node_id.clone());
350            nodes.push(SessionNodeRecord {
351                node_id,
352                parent_node_id,
353                caused_by,
354                agent_frame_id: self.agent_frame_id.clone(),
355                timestamp: timestamp.clone(),
356                payload,
357            });
358        }
359        nodes
360    }
361}
362
363#[derive(Debug, Clone)]
364struct SessionGraphCache {
365    by_id: HashMap<String, usize>,
366    active_path_indices: Vec<usize>,
367    active_events: Arc<Vec<SessionEventRecord>>,
368    active_messages: Arc<Vec<Message>>,
369    /// Index from `Message::id` to its position in `active_messages`,
370    /// kept in sync with the vec so dedup on append is O(1) instead of an
371    /// O(n) linear scan (which made long sessions quadratic in message
372    /// count).
373    active_message_ids: HashMap<String, usize>,
374    /// Memoized render of `active_messages`. Shared with every
375    /// `MessageSequence` built off this read model so the chat projector's
376    /// per-iteration `render_prompt` walk only happens once per turn.
377    /// Replaced (not invalidated in-place) whenever `active_messages`
378    /// changes — the `Arc` identity tracks the cache's validity.
379    prompt_render_cache: Arc<BaseRenderCache>,
380}
381
382impl SessionGraphCache {
383    fn build(graph: &SessionGraph) -> Self {
384        let by_id = graph
385            .nodes
386            .iter()
387            .enumerate()
388            .map(|(idx, node)| (node.node_id.clone(), idx))
389            .collect::<HashMap<_, _>>();
390        let mut active_path_indices = Vec::new();
391        let mut current = graph
392            .leaf_node_id
393            .as_ref()
394            .and_then(|node_id| by_id.get(node_id).copied());
395        while let Some(idx) = current {
396            active_path_indices.push(idx);
397            current = graph.nodes[idx]
398                .parent_node_id
399                .as_ref()
400                .and_then(|node_id| by_id.get(node_id).copied());
401        }
402        active_path_indices.reverse();
403
404        let mut cache = Self {
405            by_id,
406            active_path_indices,
407            active_events: Arc::new(Vec::new()),
408            active_messages: Arc::new(Vec::new()),
409            active_message_ids: HashMap::new(),
410            prompt_render_cache: Arc::new(BaseRenderCache::new()),
411        };
412        cache.rebuild_read_model(graph);
413        cache
414    }
415
416    fn rebuild_read_model(&mut self, graph: &SessionGraph) {
417        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
418        let mut active_message_ids: HashMap<String, usize> =
419            HashMap::with_capacity(self.active_path_indices.len());
420        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
421        for idx in &self.active_path_indices {
422            let node = &graph.nodes[*idx];
423            if let Some(event) = node.event() {
424                active_events.push(event.clone());
425            }
426            if let Some(message) = node.message() {
427                if !message.is_transient() && !active_message_ids.contains_key(&message.id) {
428                    active_message_ids.insert(message.id.clone(), active_messages.len());
429                    active_messages.push(message);
430                }
431                continue;
432            }
433        }
434        self.active_messages = Arc::new(active_messages);
435        self.active_message_ids = active_message_ids;
436        self.active_events = Arc::new(active_events);
437        self.prompt_render_cache = Arc::new(BaseRenderCache::new());
438    }
439
440    fn read_model_for_agent_frame(
441        &self,
442        graph: &SessionGraph,
443        frame_id: &str,
444        include_unscoped: bool,
445    ) -> SessionReadModel {
446        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
447        let mut active_message_ids = HashSet::new();
448        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
449        for idx in &self.active_path_indices {
450            let node = &graph.nodes[*idx];
451            if !node_belongs_to_agent_frame(node, frame_id, include_unscoped) {
452                continue;
453            }
454            if let Some(event) = node.event() {
455                active_events.push(event.clone());
456            }
457            if let Some(message) = node.message() {
458                if !message.is_transient() && active_message_ids.insert(message.id.clone()) {
459                    active_messages.push(message);
460                }
461                continue;
462            }
463        }
464        SessionReadModel {
465            active_events: Arc::new(active_events),
466            messages: Arc::new(active_messages),
467            prompt_render_cache: Arc::new(BaseRenderCache::new()),
468        }
469    }
470
471    fn append_node(
472        &mut self,
473        node_index: usize,
474        node: &SessionNodeRecord,
475        previous_leaf_node_id: Option<&str>,
476    ) {
477        self.by_id.insert(node.node_id.clone(), node_index);
478        let parent_matches_leaf = node.parent_node_id.as_deref() == previous_leaf_node_id;
479        if !parent_matches_leaf {
480            return;
481        }
482        self.active_path_indices.push(node_index);
483        if let Some(event) = node.event() {
484            Arc::make_mut(&mut self.active_events).push(event.clone());
485        }
486        if let Some(message) = node.message()
487            && !message.is_transient()
488            && !self.active_message_ids.contains_key(&message.id)
489        {
490            let messages = Arc::make_mut(&mut self.active_messages);
491            self.active_message_ids
492                .insert(message.id.clone(), messages.len());
493            messages.push(message);
494            self.prompt_render_cache = Arc::new(BaseRenderCache::new());
495        }
496    }
497
498    fn reserve_append_capacity(&mut self, additional_nodes: usize, additional_messages: usize) {
499        self.by_id.reserve(additional_nodes);
500        self.active_path_indices.reserve(additional_nodes);
501        if additional_messages > 0 {
502            Arc::make_mut(&mut self.active_messages).reserve(additional_messages);
503        }
504    }
505}
506
507impl SessionNodeRecord {
508    pub fn event(&self) -> Option<&SessionEventRecord> {
509        match &self.payload {
510            SessionNodePayload::Event { event } => Some(event),
511            SessionNodePayload::Plugin { .. } => None,
512        }
513    }
514
515    pub fn message(&self) -> Option<Message> {
516        match self.event()? {
517            SessionEventRecord::Conversation(record) => Some(record.to_message()),
518            _ => None,
519        }
520    }
521
522    pub fn plugin(&self) -> Option<(&str, &serde_json::Value)> {
523        match &self.payload {
524            SessionNodePayload::Event { .. } => None,
525            SessionNodePayload::Plugin { plugin_type, body } => {
526                Some((plugin_type.as_str(), body.as_ref()))
527            }
528        }
529    }
530
531    pub fn plugin_body<T>(&self) -> Option<T>
532    where
533        T: for<'de> serde::Deserialize<'de>,
534    {
535        let (_, body) = self.plugin()?;
536        T::deserialize(body).ok()
537    }
538}
539
540impl SessionGraph {
541    pub fn append_active_read_delta(&mut self, messages: &[Message]) {
542        self.append_active_read_delta_scoped(None, messages);
543    }
544
545    pub fn append_active_read_delta_for_agent_frame(
546        &mut self,
547        agent_frame_id: &str,
548        messages: &[Message],
549    ) {
550        self.append_active_read_delta_scoped(Some(agent_frame_id), messages);
551    }
552
553    fn append_active_read_delta_scoped(
554        &mut self,
555        agent_frame_id: Option<&str>,
556        messages: &[Message],
557    ) {
558        let appendable_messages = {
559            let read_model = agent_frame_id
560                .map(|frame_id| self.read_model_for_agent_frame(frame_id, false))
561                .unwrap_or_else(|| self.read_model());
562            let mut seen_message_ids = read_model
563                .messages
564                .iter()
565                .map(|message| message.id.as_str())
566                .collect::<HashSet<_>>();
567            messages
568                .iter()
569                .filter(|message| {
570                    !message.is_transient() && seen_message_ids.insert(message.id.as_str())
571                })
572                .cloned()
573                .collect::<Vec<_>>()
574        };
575
576        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len());
577        self.append_message_batch_scoped(agent_frame_id, appendable_messages);
578    }
579
580    pub(crate) fn append_active_conversation_messages_for_agent_frame(
581        &mut self,
582        agent_frame_id: &str,
583        messages: &[Message],
584    ) {
585        self.append_active_conversation_messages_scoped(Some(agent_frame_id), messages);
586    }
587
588    fn append_active_conversation_messages_scoped(
589        &mut self,
590        agent_frame_id: Option<&str>,
591        messages: &[Message],
592    ) {
593        self.append_active_conversation_messages_scoped_at(
594            agent_frame_id,
595            messages,
596            crate::SystemClock.timestamp_rfc3339(),
597        );
598    }
599
600    pub(crate) fn append_active_conversation_messages_for_agent_frame_at(
601        &mut self,
602        agent_frame_id: &str,
603        messages: &[Message],
604        timestamp: String,
605    ) {
606        self.append_active_conversation_messages_scoped_at(
607            Some(agent_frame_id),
608            messages,
609            timestamp,
610        );
611    }
612
613    fn append_active_conversation_messages_scoped_at(
614        &mut self,
615        agent_frame_id: Option<&str>,
616        messages: &[Message],
617        timestamp: String,
618    ) {
619        let appendable_messages = messages
620            .iter()
621            .filter(|message| !message.is_transient())
622            .cloned()
623            .collect::<Vec<_>>();
624        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len());
625        self.append_message_batch_scoped_at(agent_frame_id, appendable_messages, timestamp);
626    }
627
628    pub fn from_nodes(nodes: Vec<SessionNodeRecord>, leaf_node_id: Option<String>) -> Self {
629        Self {
630            inner: Arc::new(SessionGraphData {
631                nodes,
632                leaf_node_id,
633            }),
634            cache: Arc::new(OnceLock::new()),
635        }
636    }
637
638    pub(crate) fn append_builder(&self) -> SessionGraphAppendBuilder {
639        SessionGraphAppendBuilder {
640            existing_ids: self.nodes.iter().map(|node| node.node_id.clone()).collect(),
641            leaf_node_id: self.leaf_node_id.clone(),
642            agent_frame_id: None,
643        }
644    }
645
    /// Points this handle at a fresh, unbuilt cache. Clones of the graph
    /// that share the old cache `Arc` keep it. The new one is built on the
    /// next call to `cache()`.
    fn invalidate_cache(&mut self) {
        self.cache = Arc::new(OnceLock::new());
    }

    /// Mutable access to the graph data for arbitrary edits.
    ///
    /// Invalidates the cache first, because the caller may change anything.
    /// `Arc::make_mut` clones the data if another handle shares it.
    fn data_mut(&mut self) -> &mut SessionGraphData {
        self.invalidate_cache();
        Arc::make_mut(&mut self.inner)
    }

    /// Pre-grows storage for an upcoming incremental append of
    /// `additional_nodes` nodes, `additional_messages` of them messages.
    ///
    /// Does nothing when `additional_nodes` is zero. Otherwise it:
    /// - makes the cache handle exclusive (keeping an already-built cache);
    /// - reserves node capacity in this handle's own copy of the data;
    /// - forwards the reservation to the cache when that cache is exclusive
    ///   and built.
    fn reserve_append_capacity(&mut self, additional_nodes: usize, additional_messages: usize) {
        if additional_nodes == 0 {
            return;
        }
        self.detach_initialized_cache_for_append();
        // `make_mut` clones shared data first, so the reservation lands on
        // the copy this handle will append to.
        Arc::make_mut(&mut self.inner)
            .nodes
            .reserve(additional_nodes);
        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
            && let Some(cache) = cache_lock.get_mut()
        {
            cache.reserve_append_capacity(additional_nodes, additional_messages);
        }
    }

    /// Makes this handle the sole owner of its cache `Arc`, so an append can
    /// update the cache in place without affecting other clones.
    ///
    /// - Already exclusive: left as is.
    /// - Shared and built: the built cache is cloned into a new, exclusive
    ///   `OnceLock`, so incremental updates can continue from it.
    /// - Shared but not built: replaced by a fresh, unbuilt cache.
    fn detach_initialized_cache_for_append(&mut self) {
        if Arc::get_mut(&mut self.cache).is_some() {
            return;
        }
        let Some(cache) = self.cache.get().cloned() else {
            self.invalidate_cache();
            return;
        };
        let lock = OnceLock::new();
        // `lock` was just created empty, so this `set` cannot fail.
        let _ = lock.set(cache);
        self.cache = Arc::new(lock);
    }

    /// Returns the cache, building it from the current data on first access.
    fn cache(&self) -> &SessionGraphCache {
        self.cache.get_or_init(|| SessionGraphCache::build(self))
    }
686
687    fn append_message_batch_scoped(
688        &mut self,
689        agent_frame_id: Option<&str>,
690        messages: Vec<Message>,
691    ) {
692        self.append_message_batch_scoped_at(
693            agent_frame_id,
694            messages,
695            crate::SystemClock.timestamp_rfc3339(),
696        );
697    }
698
699    fn append_message_batch_scoped_at(
700        &mut self,
701        agent_frame_id: Option<&str>,
702        messages: Vec<Message>,
703        timestamp: String,
704    ) {
705        if messages.is_empty() {
706            return;
707        }
708        self.append_node_drafts_scoped_at(
709            agent_frame_id,
710            messages.into_iter().map(SessionNodeDraft::message),
711            timestamp,
712        );
713    }
714
715    fn append_prebuilt_nodes(&mut self, nodes: Vec<SessionNodeRecord>) {
716        if nodes.is_empty() {
717            return;
718        }
719
720        self.detach_initialized_cache_for_append();
721        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
722            && let Some(cache) = cache_lock.get_mut()
723        {
724            let data = Arc::make_mut(&mut self.inner);
725            for node in nodes {
726                let previous_leaf = data.leaf_node_id.clone();
727                let node_id = node.node_id.clone();
728                data.nodes.push(node);
729                cache.append_node(
730                    data.nodes.len() - 1,
731                    data.nodes.last().expect("just appended graph node"),
732                    previous_leaf.as_deref(),
733                );
734                data.leaf_node_id = Some(node_id);
735            }
736            return;
737        }
738
739        let data = self.data_mut();
740        for node in nodes {
741            data.leaf_node_id = Some(node.node_id.clone());
742            data.nodes.push(node);
743        }
744    }
745
746    pub fn append_message(&mut self, message: Message) -> String {
747        self.append_node_draft(SessionNodeDraft::message(message))
748    }
749
750    pub fn append_plugin(
751        &mut self,
752        plugin_type: impl Into<String>,
753        body: serde_json::Value,
754    ) -> String {
755        self.append_node_draft(SessionNodeDraft::plugin(plugin_type, body))
756    }
757
758    pub fn active_path_nodes(&self) -> Vec<&SessionNodeRecord> {
759        self.cache()
760            .active_path_indices
761            .iter()
762            .map(|idx| &self.nodes[*idx])
763            .collect()
764    }
765
766    pub(crate) fn read_model(&self) -> SessionReadModel {
767        let cache = self.cache();
768        SessionReadModel {
769            active_events: Arc::clone(&cache.active_events),
770            messages: Arc::clone(&cache.active_messages),
771            prompt_render_cache: Arc::clone(&cache.prompt_render_cache),
772        }
773    }
774
775    pub(crate) fn read_model_for_agent_frame(
776        &self,
777        frame_id: &str,
778        include_unscoped: bool,
779    ) -> SessionReadModel {
780        if frame_id.is_empty() {
781            return self.read_model();
782        }
783        self.cache()
784            .read_model_for_agent_frame(self, frame_id, include_unscoped)
785    }
786
787    pub fn append_protocol_event(&mut self, event: ProtocolEvent) -> String {
788        self.append_node_draft(SessionNodeDraft::protocol_event(event))
789    }
790
791    pub(crate) fn append_node_draft(&mut self, draft: SessionNodeDraft) -> String {
792        self.append_node_drafts([draft])
793            .into_iter()
794            .next()
795            .expect("single draft append must create one node")
796    }
797
798    pub(crate) fn append_node_drafts<I>(&mut self, drafts: I) -> Vec<String>
799    where
800        I: IntoIterator<Item = SessionNodeDraft>,
801    {
802        self.append_node_drafts_scoped_at(None, drafts, crate::SystemClock.timestamp_rfc3339())
803    }
804
805    pub(crate) fn append_node_drafts_for_agent_frame_at<I>(
806        &mut self,
807        agent_frame_id: &str,
808        drafts: I,
809        timestamp: String,
810    ) -> Vec<String>
811    where
812        I: IntoIterator<Item = SessionNodeDraft>,
813    {
814        self.append_node_drafts_scoped_at(Some(agent_frame_id), drafts, timestamp)
815    }
816
817    fn append_node_drafts_scoped_at<I>(
818        &mut self,
819        agent_frame_id: Option<&str>,
820        drafts: I,
821        timestamp: String,
822    ) -> Vec<String>
823    where
824        I: IntoIterator<Item = SessionNodeDraft>,
825    {
826        let mut builder = self.append_builder();
827        if let Some(agent_frame_id) = agent_frame_id {
828            builder = builder.with_agent_frame_id(agent_frame_id.to_string());
829        }
830        let nodes = builder.append_drafts_at(drafts, timestamp);
831        let node_ids = nodes
832            .iter()
833            .map(|node| node.node_id.clone())
834            .collect::<Vec<_>>();
835        self.append_prebuilt_nodes(nodes);
836        node_ids
837    }
838
839    pub fn user_message_count(&self) -> usize {
840        self.nodes
841            .iter()
842            .filter_map(SessionNodeRecord::message)
843            .filter(|message| matches!(message.role, MessageRole::User))
844            .count()
845    }
846
847    pub fn first_user_message(&self) -> String {
848        self.nodes
849            .iter()
850            .filter_map(SessionNodeRecord::message)
851            .find(|message| matches!(message.role, MessageRole::User))
852            .map(|message| first_message_search_text(&message))
853            .unwrap_or_default()
854    }
855
856    pub fn branch_to(&mut self, node_id: Option<String>) {
857        self.data_mut().leaf_node_id = node_id;
858    }
859
860    pub fn set_leaf_node_id(&mut self, node_id: Option<String>) {
861        self.data_mut().leaf_node_id = node_id;
862    }
863
864    pub fn push_node_record(&mut self, node: SessionNodeRecord) {
865        self.data_mut().nodes.push(node);
866    }
867
868    pub fn extend_node_records<I>(&mut self, nodes: I)
869    where
870        I: IntoIterator<Item = SessionNodeRecord>,
871    {
872        self.data_mut().nodes.extend(nodes);
873    }
874
875    /// Append nodes that extend the current active path, advancing the
876    /// leaf to the last node and updating the cache incrementally
877    /// instead of invalidating it. Use this when the appended nodes are
878    /// genuinely new descendants of the current leaf — e.g. the
879    /// turn-driver merging turn-local graph editor deltas into the base graph.
880    /// Use `extend_node_records` + `set_leaf_node_id` for store-side
881    /// replay paths that don't follow the active-path append shape.
882    pub fn extend_active_path(&mut self, nodes: Vec<SessionNodeRecord>) {
883        self.append_prebuilt_nodes(nodes);
884    }
885
886    pub fn active_path_contains(&self, node_id: &str) -> bool {
887        self.active_path_nodes()
888            .into_iter()
889            .any(|node| node.node_id == node_id)
890    }
891
892    /// If `leaf_node_id` points to a node that no longer exists in
893    /// `self.nodes` (e.g. after compaction rewrote the graph, or a
894    /// stored session referenced a node that was later purged), fall
895    /// back to the most recent message node. Returns `true` if the
896    /// leaf was repaired. Call this on load paths where an orphan
897    /// leaf would project to an empty transcript and silently drop
898    /// the user's history.
899    pub fn heal_orphaned_leaf(&mut self) -> bool {
900        if let Some(leaf) = self.leaf_node_id.as_ref()
901            && self.find_node(leaf).is_none()
902        {
903            let fallback = self
904                .nodes
905                .iter()
906                .rev()
907                .find(|node| node.message().is_some())
908                .map(|node| node.node_id.clone());
909            self.data_mut().leaf_node_id = fallback;
910            return true;
911        }
912        false
913    }
914
915    pub fn fork_current_path(&self) -> SessionGraph {
916        let path = self.active_path_nodes();
917        SessionGraph::from_nodes(
918            path.into_iter().cloned().collect(),
919            self.leaf_node_id.clone(),
920        )
921    }
922
923    pub fn find_node(&self, node_id: &str) -> Option<&SessionNodeRecord> {
924        self.cache()
925            .by_id
926            .get(node_id)
927            .and_then(|idx| self.nodes.get(*idx))
928    }
929
930    pub fn node_index(&self, node_id: &str) -> Option<usize> {
931        self.cache().by_id.get(node_id).copied()
932    }
933
934    pub fn replace_active_read_state(&mut self, messages: &[Message]) {
935        self.replace_active_read_state_scoped(None, messages);
936    }
937
938    pub fn replace_active_read_state_for_agent_frame(
939        &mut self,
940        agent_frame_id: &str,
941        messages: &[Message],
942    ) {
943        self.replace_active_read_state_scoped(Some(agent_frame_id), messages);
944    }
945
946    fn replace_active_read_state_scoped(
947        &mut self,
948        agent_frame_id: Option<&str>,
949        messages: &[Message],
950    ) {
951        let current_nodes = self.active_path_nodes();
952        let existing_ids = self
953            .nodes
954            .iter()
955            .map(|node| node.node_id.clone())
956            .collect::<HashSet<_>>();
957        let replacement = build_active_read_replacement(
958            current_nodes,
959            &existing_ids,
960            agent_frame_id,
961            messages,
962            crate::SystemClock.timestamp_rfc3339(),
963        );
964        let data = self.data_mut();
965        data.leaf_node_id = replacement.leaf_node_id;
966        data.nodes.extend(replacement.new_tail_nodes);
967    }
968
969    pub fn from_active_read_state(messages: &[Message]) -> Self {
970        let mut graph = Self::default();
971        graph.replace_active_read_state(messages);
972        graph
973    }
974
975    pub fn message_tree(&self) -> Vec<SessionMessageTreeNode> {
976        let active_message_ids = self
977            .active_path_nodes()
978            .into_iter()
979            .filter_map(|node| node.message().map(|message| message.id.clone()))
980            .collect::<HashSet<_>>();
981
982        let message_nodes = self
983            .nodes
984            .iter()
985            .filter_map(|node| {
986                let message = node.message()?.clone();
987                let parent_message_node_id =
988                    self.nearest_message_ancestor(node.parent_node_id.as_deref());
989                Some(SessionMessageTreeNode {
990                    node_id: node.node_id.clone(),
991                    parent_message_node_id,
992                    message,
993                    timestamp: node.timestamp.clone(),
994                    children: Vec::new(),
995                    active: active_message_ids.contains(&node.node_id),
996                })
997            })
998            .collect::<Vec<_>>();
999
1000        build_tree(message_nodes)
1001    }
1002
1003    fn nearest_message_ancestor(&self, node_id: Option<&str>) -> Option<String> {
1004        let by_id = self
1005            .nodes
1006            .iter()
1007            .map(|node| (node.node_id.as_str(), node))
1008            .collect::<HashMap<_, _>>();
1009        let mut current = node_id.and_then(|id| by_id.get(id).copied());
1010        while let Some(node) = current {
1011            if node.message().is_some() {
1012                return Some(node.node_id.clone());
1013            }
1014            current = node
1015                .parent_node_id
1016                .as_deref()
1017                .and_then(|parent| by_id.get(parent).copied());
1018        }
1019        None
1020    }
1021}
1022
1023fn build_tree(mut nodes: Vec<SessionMessageTreeNode>) -> Vec<SessionMessageTreeNode> {
1024    let mut children_by_parent = HashMap::<Option<String>, Vec<SessionMessageTreeNode>>::new();
1025    for node in nodes.drain(..) {
1026        children_by_parent
1027            .entry(node.parent_message_node_id.clone())
1028            .or_default()
1029            .push(node);
1030    }
1031    let mut roots = build_tree_children(None, &mut children_by_parent);
1032    sort_tree(&mut roots);
1033    roots
1034}
1035
1036fn sort_tree(nodes: &mut [SessionMessageTreeNode]) {
1037    nodes.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
1038    for node in nodes {
1039        sort_tree(&mut node.children);
1040    }
1041}
1042
1043fn build_tree_children(
1044    parent_id: Option<String>,
1045    children_by_parent: &mut HashMap<Option<String>, Vec<SessionMessageTreeNode>>,
1046) -> Vec<SessionMessageTreeNode> {
1047    let mut children = children_by_parent.remove(&parent_id).unwrap_or_default();
1048    for child in &mut children {
1049        child.children = build_tree_children(Some(child.node_id.clone()), children_by_parent);
1050    }
1051    children
1052}
1053
1054fn node_belongs_to_agent_frame(
1055    node: &SessionNodeRecord,
1056    frame_id: &str,
1057    include_unscoped: bool,
1058) -> bool {
1059    match node.agent_frame_id.as_deref() {
1060        Some(node_frame_id) => node_frame_id == frame_id,
1061        None => include_unscoped,
1062    }
1063}
1064
1065pub(crate) fn build_active_read_replacement<'a>(
1066    current_nodes: impl IntoIterator<Item = &'a SessionNodeRecord>,
1067    existing_node_ids: &HashSet<String>,
1068    agent_frame_id: Option<&str>,
1069    messages: &[Message],
1070    timestamp: String,
1071) -> ActiveReadReplacement {
1072    let target = messages
1073        .iter()
1074        .filter(|message| !message.is_transient())
1075        .collect::<Vec<_>>();
1076
1077    let mut active_events = Vec::new();
1078    let mut active_messages = Vec::new();
1079    let mut active_message_ids = HashSet::new();
1080    let mut seen_active_read_keys = HashSet::new();
1081    let mut target_idx = 0usize;
1082    let mut leaf_node_id = None;
1083    for node in current_nodes {
1084        if node
1085            .message()
1086            .map(|message| message.is_transient())
1087            .unwrap_or(false)
1088        {
1089            continue;
1090        }
1091        if let Some(key) = recognized_active_read_key(node) {
1092            if !seen_active_read_keys.insert(key.clone()) {
1093                continue;
1094            }
1095            let Some(target_item) = target.get(target_idx) else {
1096                break;
1097            };
1098            if key != format!("message:{}", target_item.id) {
1099                break;
1100            }
1101            push_active_read_node(
1102                node,
1103                &mut active_events,
1104                &mut active_messages,
1105                &mut active_message_ids,
1106            );
1107            leaf_node_id = Some(node.node_id.clone());
1108            target_idx += 1;
1109        } else {
1110            push_active_read_node(
1111                node,
1112                &mut active_events,
1113                &mut active_messages,
1114                &mut active_message_ids,
1115            );
1116            leaf_node_id = Some(node.node_id.clone());
1117        }
1118    }
1119
1120    let mut new_node_ids = HashSet::new();
1121    let mut new_tail_nodes = Vec::new();
1122
1123    for message in target.into_iter().skip(target_idx) {
1124        let parent_node_id = leaf_node_id.clone();
1125        let node_id =
1126            unique_message_node_id_for_replacement(&message.id, existing_node_ids, &new_node_ids);
1127        let node = SessionNodeRecord {
1128            node_id,
1129            parent_node_id,
1130            caused_by: causal_ref_from_message_origin(&message.origin),
1131            agent_frame_id: agent_frame_id.map(ToOwned::to_owned),
1132            timestamp: timestamp.clone(),
1133            payload: SessionNodePayload::Event {
1134                event: SessionEventRecord::Conversation(ConversationRecord::from_message(
1135                    message.clone(),
1136                )),
1137            },
1138        };
1139        new_node_ids.insert(node.node_id.clone());
1140        leaf_node_id = Some(node.node_id.clone());
1141        push_active_read_node(
1142            &node,
1143            &mut active_events,
1144            &mut active_messages,
1145            &mut active_message_ids,
1146        );
1147        new_tail_nodes.push(node);
1148    }
1149
1150    ActiveReadReplacement {
1151        leaf_node_id,
1152        new_tail_nodes,
1153        active_events,
1154        active_messages,
1155    }
1156}
1157
1158fn push_active_read_node(
1159    node: &SessionNodeRecord,
1160    active_events: &mut Vec<SessionEventRecord>,
1161    active_messages: &mut Vec<Message>,
1162    active_message_ids: &mut HashSet<String>,
1163) {
1164    if let Some(event) = node.event() {
1165        active_events.push(event.clone());
1166    }
1167    if let Some(message) = node.message()
1168        && !message.is_transient()
1169        && active_message_ids.insert(message.id.clone())
1170    {
1171        active_messages.push(message);
1172    }
1173}
1174
1175fn recognized_active_read_key(node: &SessionNodeRecord) -> Option<String> {
1176    match &node.payload {
1177        SessionNodePayload::Event { event } => match event {
1178            SessionEventRecord::Conversation(record) => Some(format!("message:{}", record.id)),
1179            _ => None,
1180        },
1181        SessionNodePayload::Plugin { .. } => None,
1182    }
1183}
1184
1185fn causal_ref_from_message_origin(
1186    origin: &Option<crate::MessageOrigin>,
1187) -> Option<crate::CausalRef> {
1188    let Some(crate::MessageOrigin::Process {
1189        process_id,
1190        sequence,
1191        ..
1192    }) = origin
1193    else {
1194        return None;
1195    };
1196    Some(crate::CausalRef::ProcessEvent {
1197        process_id: process_id.clone(),
1198        sequence: *sequence,
1199    })
1200}
1201
1202fn fresh_semantic_node_id(prefix: &str, existing_ids: &HashSet<String>) -> String {
1203    loop {
1204        let candidate = format!("{prefix}:{}", uuid::Uuid::new_v4().simple());
1205        if !existing_ids.contains(&candidate) {
1206            return candidate;
1207        }
1208    }
1209}
1210
/// Picks a node id for message `message_id` that does not collide with `existing_ids`.
///
/// Candidates are tried in this order, and the first free one wins:
/// 1. the bare message id;
/// 2. `message:{id}`;
/// 3. `message:{id}:2`, `message:{id}:3`, and so on.
fn unique_message_node_id(message_id: &str, existing_ids: &HashSet<String>) -> String {
    if !existing_ids.contains(message_id) {
        return message_id.to_string();
    }
    let base = format!("message:{message_id}");
    std::iter::once(base.clone())
        .chain((2..).map(|suffix| format!("{base}:{suffix}")))
        .find(|candidate| !existing_ids.contains(candidate))
        .unwrap_or_else(|| unreachable!("message node id space exhausted"))
}
1227
/// Like `unique_message_node_id`, but a candidate must also avoid `new_ids`: ids already
/// minted earlier in the same replacement pass.
///
/// Candidates are tried in this order, and the first free one wins:
/// 1. the bare message id;
/// 2. `message:{id}`;
/// 3. `message:{id}:2`, `message:{id}:3`, and so on.
fn unique_message_node_id_for_replacement(
    message_id: &str,
    existing_ids: &HashSet<String>,
    new_ids: &HashSet<String>,
) -> String {
    let is_free =
        |candidate: &str| !(existing_ids.contains(candidate) || new_ids.contains(candidate));
    if is_free(message_id) {
        return message_id.to_string();
    }
    let base = format!("message:{message_id}");
    if is_free(&base) {
        return base;
    }
    let mut suffix = 2;
    loop {
        let candidate = format!("{base}:{suffix}");
        if is_free(&candidate) {
            return candidate;
        }
        suffix += 1;
    }
}
1248
1249fn fresh_node_id(prefix: &str) -> String {
1250    format!("{prefix}{}", uuid::Uuid::new_v4().simple())
1251}
1252
1253fn first_message_search_text(message: &Message) -> String {
1254    message
1255        .parts
1256        .iter()
1257        .filter_map(|part| match part.kind {
1258            crate::PartKind::ToolCall | crate::PartKind::ToolResult => None,
1259            crate::PartKind::Image => Some("[Image attached]".to_string()),
1260            _ => (!part.content.trim().is_empty()).then(|| part.content.clone()),
1261        })
1262        .collect::<Vec<_>>()
1263        .join("\n\n")
1264        .trim()
1265        .to_string()
1266}
1267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Part, PartKind, PruneState, shared_parts};

    /// Single-part plain-text message whose part id is derived from the message id.
    fn text_message(id: &str, role: MessageRole, content: &str) -> Message {
        let body = Part {
            id: format!("{id}.p0"),
            kind: PartKind::Text,
            content: content.to_string(),
            attachment: None,
            tool_call_id: None,
            tool_name: None,
            tool_replay: None,
            prune_state: PruneState::Intact,
            reasoning_meta: None,
            response_meta: None,
        };
        Message {
            id: id.to_string(),
            role,
            parts: shared_parts(vec![body]),
            origin: None,
        }
    }

    /// Minimal typed protocol event used to exercise non-message appends.
    fn protocol_event() -> ProtocolEvent {
        let payload = serde_json::json!({"step": "started"});
        ProtocolEvent::typed("test_protocol", payload).expect("protocol event serializes")
    }

    #[test]
    fn typed_append_node_ids_use_semantic_prefixes() {
        let mut graph = SessionGraph::default();

        let message_node = graph.append_message(text_message("m1", MessageRole::User, "hello"));
        let protocol_node = graph.append_protocol_event(protocol_event());
        let plugin_node = graph.append_plugin("example", serde_json::json!({"ok": true}));

        assert_eq!(message_node, "m1");
        assert!(protocol_node.starts_with("protocol:"));
        assert!(plugin_node.starts_with("plugin:"));
    }

    #[test]
    fn active_read_replacement_persists_messages_only() {
        let graph = SessionGraph::from_active_read_state(&[text_message(
            "m1",
            MessageRole::User,
            "hello",
        )]);

        assert_eq!(graph.nodes.len(), 1);
        let first_event = graph.nodes[0].event();
        assert!(matches!(
            first_event,
            Some(SessionEventRecord::Conversation(_))
        ));
    }

    #[test]
    fn graph_writers_do_not_put_active_read_events_under_plugin_ids() {
        let mut graph = SessionGraph::default();
        graph.append_message(text_message("m1", MessageRole::User, "hello"));
        graph.append_protocol_event(protocol_event());
        graph.append_plugin("example", serde_json::json!({"ok": true}));

        for node in &graph.nodes {
            let under_plugin_id = node.node_id.starts_with("plugin:");
            match node.event() {
                Some(SessionEventRecord::Conversation(_)) => {
                    assert!(!under_plugin_id, "{node:?}");
                }
                Some(SessionEventRecord::Protocol(_)) => {
                    assert!(node.node_id.starts_with("protocol:"), "{node:?}");
                }
                None => {
                    assert!(under_plugin_id, "{node:?}");
                }
            }
        }
    }
}