Skip to main content

lash_core/
session_graph.rs

1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3use std::sync::{Arc, OnceLock};
4
5use crate::session_model::{ConversationRecord, ProtocolEvent, SessionEventRecord};
6use crate::{BaseRenderCache, Clock, Message, MessageRole, PromptUsage, TokenUsage};
7
8#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
9pub struct SessionGraphData {
10    #[serde(default)]
11    pub nodes: Vec<SessionNodeRecord>,
12    #[serde(default, skip_serializing_if = "Option::is_none")]
13    pub leaf_node_id: Option<String>,
14}
15
16#[derive(Debug)]
17pub struct SessionGraph {
18    inner: Arc<SessionGraphData>,
19    cache: Arc<OnceLock<SessionGraphCache>>,
20}
21
22impl Default for SessionGraph {
23    fn default() -> Self {
24        Self {
25            inner: Arc::new(SessionGraphData::default()),
26            cache: Arc::new(OnceLock::new()),
27        }
28    }
29}
30
31impl Clone for SessionGraph {
32    fn clone(&self) -> Self {
33        Self {
34            inner: Arc::clone(&self.inner),
35            cache: Arc::clone(&self.cache),
36        }
37    }
38}
39
40impl serde::Serialize for SessionGraph {
41    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
42    where
43        S: serde::Serializer,
44    {
45        self.inner.serialize(serializer)
46    }
47}
48
49impl<'de> serde::Deserialize<'de> for SessionGraph {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: serde::Deserializer<'de>,
53    {
54        let inner = SessionGraphData::deserialize(deserializer)?;
55        Ok(Self {
56            inner: Arc::new(inner),
57            cache: Arc::new(OnceLock::new()),
58        })
59    }
60}
61
62impl Deref for SessionGraph {
63    type Target = SessionGraphData;
64
65    fn deref(&self) -> &Self::Target {
66        self.inner.as_ref()
67    }
68}
69
70#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
71pub struct SessionNodeRecord {
72    pub node_id: String,
73    #[serde(default, skip_serializing_if = "Option::is_none")]
74    pub parent_node_id: Option<String>,
75    #[serde(default, skip_serializing_if = "Option::is_none")]
76    pub caused_by: Option<crate::CausalRef>,
77    #[serde(default, skip_serializing_if = "Option::is_none")]
78    pub agent_frame_id: Option<crate::AgentFrameId>,
79    pub timestamp: String,
80    #[serde(flatten)]
81    pub payload: SessionNodePayload,
82}
83
84#[derive(Clone, Debug)]
85pub(crate) struct SessionNodeDraft {
86    payload: SessionNodeDraftPayload,
87    caused_by: Option<crate::CausalRef>,
88}
89
90#[derive(Clone, Debug)]
91enum SessionNodeDraftPayload {
92    Message(Message),
93    Plugin {
94        plugin_type: String,
95        body: serde_json::Value,
96    },
97    ProtocolEvent(ProtocolEvent),
98}
99
100impl SessionNodeDraft {
101    pub(crate) fn message(message: Message) -> Self {
102        Self {
103            payload: SessionNodeDraftPayload::Message(message),
104            caused_by: None,
105        }
106    }
107
108    pub(crate) fn plugin(plugin_type: impl Into<String>, body: serde_json::Value) -> Self {
109        Self {
110            payload: SessionNodeDraftPayload::Plugin {
111                plugin_type: plugin_type.into(),
112                body,
113            },
114            caused_by: None,
115        }
116    }
117
118    pub(crate) fn protocol_event(event: ProtocolEvent) -> Self {
119        Self {
120            payload: SessionNodeDraftPayload::ProtocolEvent(event),
121            caused_by: None,
122        }
123    }
124
125    pub(crate) fn with_caused_by(mut self, caused_by: Option<crate::CausalRef>) -> Self {
126        self.caused_by = caused_by;
127        self
128    }
129}
130
131#[derive(Clone, Debug, Default, PartialEq)]
132pub struct SharedJsonValue(pub Arc<serde_json::Value>);
133
134impl SharedJsonValue {
135    pub fn new(value: serde_json::Value) -> Self {
136        Self(Arc::new(value))
137    }
138
139    pub fn to_owned(&self) -> serde_json::Value {
140        self.0.as_ref().clone()
141    }
142}
143
144impl AsRef<serde_json::Value> for SharedJsonValue {
145    fn as_ref(&self) -> &serde_json::Value {
146        self.0.as_ref()
147    }
148}
149
150impl serde::Serialize for SharedJsonValue {
151    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
152    where
153        S: serde::Serializer,
154    {
155        self.0.serialize(serializer)
156    }
157}
158
159impl<'de> serde::Deserialize<'de> for SharedJsonValue {
160    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
161    where
162        D: serde::Deserializer<'de>,
163    {
164        let value = serde_json::Value::deserialize(deserializer)?;
165        Ok(Self::new(value))
166    }
167}
168
169#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
170#[serde(tag = "kind", rename_all = "snake_case")]
171#[allow(clippy::large_enum_variant)]
172pub enum SessionNodePayload {
173    Event {
174        event: SessionEventRecord,
175    },
176    Plugin {
177        plugin_type: String,
178        body: SharedJsonValue,
179    },
180}
181
182#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
183pub struct PersistedSessionConfig {
184    pub provider_id: String,
185    pub model: crate::ModelSpec,
186}
187
188#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
189pub struct PersistedTurnState {
190    pub turn_index: usize,
191    #[serde(default)]
192    pub token_usage: TokenUsage,
193    #[serde(default, skip_serializing_if = "Option::is_none")]
194    pub last_prompt_usage: Option<PromptUsage>,
195    #[serde(default)]
196    pub protocol_turn_options: crate::ProtocolTurnOptions,
197}
198
199#[derive(Clone, Debug)]
200pub struct SessionMessageTreeNode {
201    pub node_id: String,
202    pub parent_message_node_id: Option<String>,
203    pub message: Message,
204    pub timestamp: String,
205    pub children: Vec<SessionMessageTreeNode>,
206    pub active: bool,
207}
208
209#[derive(Clone, Debug)]
210pub(crate) struct ActiveReadReplacement {
211    pub(crate) leaf_node_id: Option<String>,
212    pub(crate) new_tail_nodes: Vec<SessionNodeRecord>,
213    pub(crate) active_events: Vec<SessionEventRecord>,
214    pub(crate) active_messages: Vec<Message>,
215}
216
217#[derive(Clone, Debug)]
218pub(crate) struct SessionReadModel {
219    pub(crate) active_events: Arc<Vec<SessionEventRecord>>,
220    pub(crate) messages: Arc<Vec<Message>>,
221    pub(crate) prompt_render_cache: Arc<BaseRenderCache>,
222}
223
224#[derive(Clone, Debug)]
225pub(crate) struct SessionGraphAppendBuilder {
226    existing_ids: HashSet<String>,
227    leaf_node_id: Option<String>,
228    agent_frame_id: Option<crate::AgentFrameId>,
229}
230
231impl SessionGraphAppendBuilder {
232    pub(crate) fn with_agent_frame_id(
233        mut self,
234        agent_frame_id: impl Into<crate::AgentFrameId>,
235    ) -> Self {
236        self.agent_frame_id = Some(agent_frame_id.into());
237        self
238    }
239
240    pub(crate) fn agent_frame_id(&self) -> Option<&str> {
241        self.agent_frame_id.as_deref()
242    }
243
244    pub(crate) fn leaf_node_id(&self) -> Option<&String> {
245        self.leaf_node_id.as_ref()
246    }
247
248    pub(crate) fn set_leaf_node_id(&mut self, leaf_node_id: Option<String>) {
249        self.leaf_node_id = leaf_node_id;
250    }
251
252    pub(crate) fn register_existing_node_ids<'a>(
253        &mut self,
254        node_ids: impl IntoIterator<Item = &'a str>,
255    ) {
256        self.existing_ids
257            .extend(node_ids.into_iter().map(ToOwned::to_owned));
258    }
259
260    pub(crate) fn existing_node_ids(&self) -> &HashSet<String> {
261        &self.existing_ids
262    }
263
264    pub(crate) fn append_messages_at<I>(
265        &mut self,
266        messages: I,
267        timestamp: String,
268    ) -> Vec<SessionNodeRecord>
269    where
270        I: IntoIterator<Item = Message>,
271    {
272        self.append_drafts_at(
273            messages.into_iter().map(SessionNodeDraft::message),
274            timestamp,
275        )
276    }
277
278    pub(crate) fn append_protocol_events_at<I>(
279        &mut self,
280        events: I,
281        timestamp: String,
282    ) -> Vec<SessionNodeRecord>
283    where
284        I: IntoIterator<Item = ProtocolEvent>,
285    {
286        self.append_drafts_at(
287            events.into_iter().map(SessionNodeDraft::protocol_event),
288            timestamp,
289        )
290    }
291
292    pub(crate) fn append_drafts_at<I>(
293        &mut self,
294        drafts: I,
295        timestamp: String,
296    ) -> Vec<SessionNodeRecord>
297    where
298        I: IntoIterator<Item = SessionNodeDraft>,
299    {
300        let mut nodes = Vec::new();
301        for draft in drafts {
302            let parent_node_id = self.leaf_node_id.clone();
303            let (node_id, caused_by, payload) = match draft.payload {
304                SessionNodeDraftPayload::Message(mut message) => {
305                    if message.id.is_empty() {
306                        message.id = fresh_node_id("m");
307                    }
308                    let node_id = unique_message_node_id(&message.id, &self.existing_ids);
309                    let caused_by = draft
310                        .caused_by
311                        .or_else(|| causal_ref_from_message_origin(&message.origin));
312                    (
313                        node_id,
314                        caused_by,
315                        SessionNodePayload::Event {
316                            event: SessionEventRecord::Conversation(
317                                ConversationRecord::from_message(message),
318                            ),
319                        },
320                    )
321                }
322                SessionNodeDraftPayload::Plugin { plugin_type, body } => {
323                    let node_id = fresh_semantic_node_id("plugin", &self.existing_ids);
324                    (
325                        node_id,
326                        draft.caused_by,
327                        SessionNodePayload::Plugin {
328                            plugin_type,
329                            body: SharedJsonValue::new(body),
330                        },
331                    )
332                }
333                SessionNodeDraftPayload::ProtocolEvent(event) => {
334                    let node_id = fresh_semantic_node_id("protocol", &self.existing_ids);
335                    (
336                        node_id,
337                        draft.caused_by,
338                        SessionNodePayload::Event {
339                            event: SessionEventRecord::Protocol(event),
340                        },
341                    )
342                }
343            };
344            self.existing_ids.insert(node_id.clone());
345            self.leaf_node_id = Some(node_id.clone());
346            nodes.push(SessionNodeRecord {
347                node_id,
348                parent_node_id,
349                caused_by,
350                agent_frame_id: self.agent_frame_id.clone(),
351                timestamp: timestamp.clone(),
352                payload,
353            });
354        }
355        nodes
356    }
357}
358
359#[derive(Debug, Clone)]
360struct SessionGraphCache {
361    by_id: HashMap<String, usize>,
362    active_path_indices: Vec<usize>,
363    active_events: Arc<Vec<SessionEventRecord>>,
364    active_messages: Arc<Vec<Message>>,
365    /// Index from `Message::id` to its position in `active_messages`,
366    /// kept in sync with the vec so dedup on append is O(1) instead of an
367    /// O(n) linear scan (which made long sessions quadratic in message
368    /// count).
369    active_message_ids: HashMap<String, usize>,
370    /// Memoized render of `active_messages`. Shared with every
371    /// `MessageSequence` built off this read model so the chat projector's
372    /// per-iteration `render_prompt` walk only happens once per turn.
373    /// Replaced (not invalidated in-place) whenever `active_messages`
374    /// changes — the `Arc` identity tracks the cache's validity.
375    prompt_render_cache: Arc<BaseRenderCache>,
376}
377
378impl SessionGraphCache {
379    fn build(graph: &SessionGraph) -> Self {
380        let by_id = graph
381            .nodes
382            .iter()
383            .enumerate()
384            .map(|(idx, node)| (node.node_id.clone(), idx))
385            .collect::<HashMap<_, _>>();
386        let mut active_path_indices = Vec::new();
387        let mut current = graph
388            .leaf_node_id
389            .as_ref()
390            .and_then(|node_id| by_id.get(node_id).copied());
391        while let Some(idx) = current {
392            active_path_indices.push(idx);
393            current = graph.nodes[idx]
394                .parent_node_id
395                .as_ref()
396                .and_then(|node_id| by_id.get(node_id).copied());
397        }
398        active_path_indices.reverse();
399
400        let mut cache = Self {
401            by_id,
402            active_path_indices,
403            active_events: Arc::new(Vec::new()),
404            active_messages: Arc::new(Vec::new()),
405            active_message_ids: HashMap::new(),
406            prompt_render_cache: Arc::new(BaseRenderCache::new()),
407        };
408        cache.rebuild_read_model(graph);
409        cache
410    }
411
412    fn rebuild_read_model(&mut self, graph: &SessionGraph) {
413        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
414        let mut active_message_ids: HashMap<String, usize> =
415            HashMap::with_capacity(self.active_path_indices.len());
416        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
417        for idx in &self.active_path_indices {
418            let node = &graph.nodes[*idx];
419            if let Some(event) = node.event() {
420                active_events.push(event.clone());
421            }
422            if let Some(message) = node.message() {
423                if !message.is_transient() && !active_message_ids.contains_key(&message.id) {
424                    active_message_ids.insert(message.id.clone(), active_messages.len());
425                    active_messages.push(message);
426                }
427                continue;
428            }
429        }
430        self.active_messages = Arc::new(active_messages);
431        self.active_message_ids = active_message_ids;
432        self.active_events = Arc::new(active_events);
433        self.prompt_render_cache = Arc::new(BaseRenderCache::new());
434    }
435
436    fn read_model_for_agent_frame(
437        &self,
438        graph: &SessionGraph,
439        frame_id: &str,
440        include_unscoped: bool,
441    ) -> SessionReadModel {
442        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
443        let mut active_message_ids = HashSet::new();
444        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
445        for idx in &self.active_path_indices {
446            let node = &graph.nodes[*idx];
447            if !node_belongs_to_agent_frame(node, frame_id, include_unscoped) {
448                continue;
449            }
450            if let Some(event) = node.event() {
451                active_events.push(event.clone());
452            }
453            if let Some(message) = node.message() {
454                if !message.is_transient() && active_message_ids.insert(message.id.clone()) {
455                    active_messages.push(message);
456                }
457                continue;
458            }
459        }
460        SessionReadModel {
461            active_events: Arc::new(active_events),
462            messages: Arc::new(active_messages),
463            prompt_render_cache: Arc::new(BaseRenderCache::new()),
464        }
465    }
466
467    fn append_node(
468        &mut self,
469        node_index: usize,
470        node: &SessionNodeRecord,
471        previous_leaf_node_id: Option<&str>,
472    ) {
473        self.by_id.insert(node.node_id.clone(), node_index);
474        let parent_matches_leaf = node.parent_node_id.as_deref() == previous_leaf_node_id;
475        if !parent_matches_leaf {
476            return;
477        }
478        self.active_path_indices.push(node_index);
479        if let Some(event) = node.event() {
480            Arc::make_mut(&mut self.active_events).push(event.clone());
481        }
482        if let Some(message) = node.message()
483            && !message.is_transient()
484            && !self.active_message_ids.contains_key(&message.id)
485        {
486            let messages = Arc::make_mut(&mut self.active_messages);
487            self.active_message_ids
488                .insert(message.id.clone(), messages.len());
489            messages.push(message);
490            self.prompt_render_cache = Arc::new(BaseRenderCache::new());
491        }
492    }
493
494    fn reserve_append_capacity(&mut self, additional_nodes: usize, additional_messages: usize) {
495        self.by_id.reserve(additional_nodes);
496        self.active_path_indices.reserve(additional_nodes);
497        if additional_messages > 0 {
498            Arc::make_mut(&mut self.active_messages).reserve(additional_messages);
499        }
500    }
501}
502
503impl SessionNodeRecord {
504    pub fn event(&self) -> Option<&SessionEventRecord> {
505        match &self.payload {
506            SessionNodePayload::Event { event } => Some(event),
507            SessionNodePayload::Plugin { .. } => None,
508        }
509    }
510
511    pub fn message(&self) -> Option<Message> {
512        match self.event()? {
513            SessionEventRecord::Conversation(record) => Some(record.to_message()),
514            _ => None,
515        }
516    }
517
518    pub fn plugin(&self) -> Option<(&str, &serde_json::Value)> {
519        match &self.payload {
520            SessionNodePayload::Event { .. } => None,
521            SessionNodePayload::Plugin { plugin_type, body } => {
522                Some((plugin_type.as_str(), body.as_ref()))
523            }
524        }
525    }
526
527    pub fn plugin_body<T>(&self) -> Option<T>
528    where
529        T: for<'de> serde::Deserialize<'de>,
530    {
531        let (_, body) = self.plugin()?;
532        T::deserialize(body).ok()
533    }
534}
535
536impl SessionGraph {
537    pub fn append_active_read_delta(&mut self, messages: &[Message]) {
538        self.append_active_read_delta_scoped(None, messages);
539    }
540
541    pub fn append_active_read_delta_for_agent_frame(
542        &mut self,
543        agent_frame_id: &str,
544        messages: &[Message],
545    ) {
546        self.append_active_read_delta_scoped(Some(agent_frame_id), messages);
547    }
548
549    fn append_active_read_delta_scoped(
550        &mut self,
551        agent_frame_id: Option<&str>,
552        messages: &[Message],
553    ) {
554        let appendable_messages = {
555            let read_model = agent_frame_id
556                .map(|frame_id| self.read_model_for_agent_frame(frame_id, false))
557                .unwrap_or_else(|| self.read_model());
558            let mut seen_message_ids = read_model
559                .messages
560                .iter()
561                .map(|message| message.id.as_str())
562                .collect::<HashSet<_>>();
563            messages
564                .iter()
565                .filter(|message| {
566                    !message.is_transient() && seen_message_ids.insert(message.id.as_str())
567                })
568                .cloned()
569                .collect::<Vec<_>>()
570        };
571
572        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len());
573        self.append_message_batch_scoped(agent_frame_id, appendable_messages);
574    }
575
576    pub(crate) fn append_active_conversation_messages_for_agent_frame(
577        &mut self,
578        agent_frame_id: &str,
579        messages: &[Message],
580    ) {
581        self.append_active_conversation_messages_scoped(Some(agent_frame_id), messages);
582    }
583
584    fn append_active_conversation_messages_scoped(
585        &mut self,
586        agent_frame_id: Option<&str>,
587        messages: &[Message],
588    ) {
589        self.append_active_conversation_messages_scoped_at(
590            agent_frame_id,
591            messages,
592            crate::SystemClock.timestamp_rfc3339(),
593        );
594    }
595
596    pub(crate) fn append_active_conversation_messages_for_agent_frame_at(
597        &mut self,
598        agent_frame_id: &str,
599        messages: &[Message],
600        timestamp: String,
601    ) {
602        self.append_active_conversation_messages_scoped_at(
603            Some(agent_frame_id),
604            messages,
605            timestamp,
606        );
607    }
608
609    fn append_active_conversation_messages_scoped_at(
610        &mut self,
611        agent_frame_id: Option<&str>,
612        messages: &[Message],
613        timestamp: String,
614    ) {
615        let appendable_messages = messages
616            .iter()
617            .filter(|message| !message.is_transient())
618            .cloned()
619            .collect::<Vec<_>>();
620        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len());
621        self.append_message_batch_scoped_at(agent_frame_id, appendable_messages, timestamp);
622    }
623
624    pub fn from_nodes(nodes: Vec<SessionNodeRecord>, leaf_node_id: Option<String>) -> Self {
625        Self {
626            inner: Arc::new(SessionGraphData {
627                nodes,
628                leaf_node_id,
629            }),
630            cache: Arc::new(OnceLock::new()),
631        }
632    }
633
634    pub(crate) fn append_builder(&self) -> SessionGraphAppendBuilder {
635        SessionGraphAppendBuilder {
636            existing_ids: self.nodes.iter().map(|node| node.node_id.clone()).collect(),
637            leaf_node_id: self.leaf_node_id.clone(),
638            agent_frame_id: None,
639        }
640    }
641
642    fn invalidate_cache(&mut self) {
643        self.cache = Arc::new(OnceLock::new());
644    }
645
646    fn data_mut(&mut self) -> &mut SessionGraphData {
647        self.invalidate_cache();
648        Arc::make_mut(&mut self.inner)
649    }
650
651    fn reserve_append_capacity(&mut self, additional_nodes: usize, additional_messages: usize) {
652        if additional_nodes == 0 {
653            return;
654        }
655        self.detach_initialized_cache_for_append();
656        Arc::make_mut(&mut self.inner)
657            .nodes
658            .reserve(additional_nodes);
659        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
660            && let Some(cache) = cache_lock.get_mut()
661        {
662            cache.reserve_append_capacity(additional_nodes, additional_messages);
663        }
664    }
665
666    fn detach_initialized_cache_for_append(&mut self) {
667        if Arc::get_mut(&mut self.cache).is_some() {
668            return;
669        }
670        let Some(cache) = self.cache.get().cloned() else {
671            self.invalidate_cache();
672            return;
673        };
674        let lock = OnceLock::new();
675        let _ = lock.set(cache);
676        self.cache = Arc::new(lock);
677    }
678
679    fn cache(&self) -> &SessionGraphCache {
680        self.cache.get_or_init(|| SessionGraphCache::build(self))
681    }
682
683    fn append_message_batch_scoped(
684        &mut self,
685        agent_frame_id: Option<&str>,
686        messages: Vec<Message>,
687    ) {
688        self.append_message_batch_scoped_at(
689            agent_frame_id,
690            messages,
691            crate::SystemClock.timestamp_rfc3339(),
692        );
693    }
694
695    fn append_message_batch_scoped_at(
696        &mut self,
697        agent_frame_id: Option<&str>,
698        messages: Vec<Message>,
699        timestamp: String,
700    ) {
701        if messages.is_empty() {
702            return;
703        }
704        self.append_node_drafts_scoped_at(
705            agent_frame_id,
706            messages.into_iter().map(SessionNodeDraft::message),
707            timestamp,
708        );
709    }
710
711    fn append_prebuilt_nodes(&mut self, nodes: Vec<SessionNodeRecord>) {
712        if nodes.is_empty() {
713            return;
714        }
715
716        self.detach_initialized_cache_for_append();
717        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
718            && let Some(cache) = cache_lock.get_mut()
719        {
720            let data = Arc::make_mut(&mut self.inner);
721            for node in nodes {
722                let previous_leaf = data.leaf_node_id.clone();
723                let node_id = node.node_id.clone();
724                data.nodes.push(node);
725                cache.append_node(
726                    data.nodes.len() - 1,
727                    data.nodes.last().expect("just appended graph node"),
728                    previous_leaf.as_deref(),
729                );
730                data.leaf_node_id = Some(node_id);
731            }
732            return;
733        }
734
735        let data = self.data_mut();
736        for node in nodes {
737            data.leaf_node_id = Some(node.node_id.clone());
738            data.nodes.push(node);
739        }
740    }
741
742    pub fn append_message(&mut self, message: Message) -> String {
743        self.append_node_draft(SessionNodeDraft::message(message))
744    }
745
746    pub fn append_plugin(
747        &mut self,
748        plugin_type: impl Into<String>,
749        body: serde_json::Value,
750    ) -> String {
751        self.append_node_draft(SessionNodeDraft::plugin(plugin_type, body))
752    }
753
754    pub fn active_path_nodes(&self) -> Vec<&SessionNodeRecord> {
755        self.cache()
756            .active_path_indices
757            .iter()
758            .map(|idx| &self.nodes[*idx])
759            .collect()
760    }
761
762    pub(crate) fn read_model(&self) -> SessionReadModel {
763        let cache = self.cache();
764        SessionReadModel {
765            active_events: Arc::clone(&cache.active_events),
766            messages: Arc::clone(&cache.active_messages),
767            prompt_render_cache: Arc::clone(&cache.prompt_render_cache),
768        }
769    }
770
771    pub(crate) fn read_model_for_agent_frame(
772        &self,
773        frame_id: &str,
774        include_unscoped: bool,
775    ) -> SessionReadModel {
776        if frame_id.is_empty() {
777            return self.read_model();
778        }
779        self.cache()
780            .read_model_for_agent_frame(self, frame_id, include_unscoped)
781    }
782
783    pub fn append_protocol_event(&mut self, event: ProtocolEvent) -> String {
784        self.append_node_draft(SessionNodeDraft::protocol_event(event))
785    }
786
787    pub(crate) fn append_node_draft(&mut self, draft: SessionNodeDraft) -> String {
788        self.append_node_drafts([draft])
789            .into_iter()
790            .next()
791            .expect("single draft append must create one node")
792    }
793
794    pub(crate) fn append_node_drafts<I>(&mut self, drafts: I) -> Vec<String>
795    where
796        I: IntoIterator<Item = SessionNodeDraft>,
797    {
798        self.append_node_drafts_scoped_at(None, drafts, crate::SystemClock.timestamp_rfc3339())
799    }
800
801    pub(crate) fn append_node_drafts_for_agent_frame_at<I>(
802        &mut self,
803        agent_frame_id: &str,
804        drafts: I,
805        timestamp: String,
806    ) -> Vec<String>
807    where
808        I: IntoIterator<Item = SessionNodeDraft>,
809    {
810        self.append_node_drafts_scoped_at(Some(agent_frame_id), drafts, timestamp)
811    }
812
813    fn append_node_drafts_scoped_at<I>(
814        &mut self,
815        agent_frame_id: Option<&str>,
816        drafts: I,
817        timestamp: String,
818    ) -> Vec<String>
819    where
820        I: IntoIterator<Item = SessionNodeDraft>,
821    {
822        let mut builder = self.append_builder();
823        if let Some(agent_frame_id) = agent_frame_id {
824            builder = builder.with_agent_frame_id(agent_frame_id.to_string());
825        }
826        let nodes = builder.append_drafts_at(drafts, timestamp);
827        let node_ids = nodes
828            .iter()
829            .map(|node| node.node_id.clone())
830            .collect::<Vec<_>>();
831        self.append_prebuilt_nodes(nodes);
832        node_ids
833    }
834
835    pub fn user_message_count(&self) -> usize {
836        self.nodes
837            .iter()
838            .filter_map(SessionNodeRecord::message)
839            .filter(|message| matches!(message.role, MessageRole::User))
840            .count()
841    }
842
843    pub fn first_user_message(&self) -> String {
844        self.nodes
845            .iter()
846            .filter_map(SessionNodeRecord::message)
847            .find(|message| matches!(message.role, MessageRole::User))
848            .map(|message| first_message_search_text(&message))
849            .unwrap_or_default()
850    }
851
852    pub fn branch_to(&mut self, node_id: Option<String>) {
853        self.data_mut().leaf_node_id = node_id;
854    }
855
856    pub fn set_leaf_node_id(&mut self, node_id: Option<String>) {
857        self.data_mut().leaf_node_id = node_id;
858    }
859
860    pub fn push_node_record(&mut self, node: SessionNodeRecord) {
861        self.data_mut().nodes.push(node);
862    }
863
864    pub fn extend_node_records<I>(&mut self, nodes: I)
865    where
866        I: IntoIterator<Item = SessionNodeRecord>,
867    {
868        self.data_mut().nodes.extend(nodes);
869    }
870
871    /// Append nodes that extend the current active path, advancing the
872    /// leaf to the last node and updating the cache incrementally
873    /// instead of invalidating it. Use this when the appended nodes are
874    /// genuinely new descendants of the current leaf — e.g. the
875    /// turn-driver merging turn-local graph editor deltas into the base graph.
876    /// Use `extend_node_records` + `set_leaf_node_id` for store-side
877    /// replay paths that don't follow the active-path append shape.
878    pub fn extend_active_path(&mut self, nodes: Vec<SessionNodeRecord>) {
879        self.append_prebuilt_nodes(nodes);
880    }
881
882    pub fn active_path_contains(&self, node_id: &str) -> bool {
883        self.active_path_nodes()
884            .into_iter()
885            .any(|node| node.node_id == node_id)
886    }
887
888    /// If `leaf_node_id` points to a node that no longer exists in
889    /// `self.nodes` (e.g. after compaction rewrote the graph, or a
890    /// stored session referenced a node that was later purged), fall
891    /// back to the most recent message node. Returns `true` if the
892    /// leaf was repaired. Call this on load paths where an orphan
893    /// leaf would project to an empty transcript and silently drop
894    /// the user's history.
895    pub fn heal_orphaned_leaf(&mut self) -> bool {
896        if let Some(leaf) = self.leaf_node_id.as_ref()
897            && self.find_node(leaf).is_none()
898        {
899            let fallback = self
900                .nodes
901                .iter()
902                .rev()
903                .find(|node| node.message().is_some())
904                .map(|node| node.node_id.clone());
905            self.data_mut().leaf_node_id = fallback;
906            return true;
907        }
908        false
909    }
910
911    pub fn fork_current_path(&self) -> SessionGraph {
912        let path = self.active_path_nodes();
913        SessionGraph::from_nodes(
914            path.into_iter().cloned().collect(),
915            self.leaf_node_id.clone(),
916        )
917    }
918
919    pub fn find_node(&self, node_id: &str) -> Option<&SessionNodeRecord> {
920        self.cache()
921            .by_id
922            .get(node_id)
923            .and_then(|idx| self.nodes.get(*idx))
924    }
925
926    pub fn node_index(&self, node_id: &str) -> Option<usize> {
927        self.cache().by_id.get(node_id).copied()
928    }
929
930    pub fn replace_active_read_state(&mut self, messages: &[Message]) {
931        self.replace_active_read_state_scoped(None, messages);
932    }
933
934    pub fn replace_active_read_state_for_agent_frame(
935        &mut self,
936        agent_frame_id: &str,
937        messages: &[Message],
938    ) {
939        self.replace_active_read_state_scoped(Some(agent_frame_id), messages);
940    }
941
942    fn replace_active_read_state_scoped(
943        &mut self,
944        agent_frame_id: Option<&str>,
945        messages: &[Message],
946    ) {
947        let current_nodes = self.active_path_nodes();
948        let existing_ids = self
949            .nodes
950            .iter()
951            .map(|node| node.node_id.clone())
952            .collect::<HashSet<_>>();
953        let replacement = build_active_read_replacement(
954            current_nodes,
955            &existing_ids,
956            agent_frame_id,
957            messages,
958            crate::SystemClock.timestamp_rfc3339(),
959        );
960        let data = self.data_mut();
961        data.leaf_node_id = replacement.leaf_node_id;
962        data.nodes.extend(replacement.new_tail_nodes);
963    }
964
965    pub fn from_active_read_state(messages: &[Message]) -> Self {
966        let mut graph = Self::default();
967        graph.replace_active_read_state(messages);
968        graph
969    }
970
971    pub fn message_tree(&self) -> Vec<SessionMessageTreeNode> {
972        let active_message_ids = self
973            .active_path_nodes()
974            .into_iter()
975            .filter_map(|node| node.message().map(|message| message.id.clone()))
976            .collect::<HashSet<_>>();
977
978        let message_nodes = self
979            .nodes
980            .iter()
981            .filter_map(|node| {
982                let message = node.message()?.clone();
983                let parent_message_node_id =
984                    self.nearest_message_ancestor(node.parent_node_id.as_deref());
985                Some(SessionMessageTreeNode {
986                    node_id: node.node_id.clone(),
987                    parent_message_node_id,
988                    message,
989                    timestamp: node.timestamp.clone(),
990                    children: Vec::new(),
991                    active: active_message_ids.contains(&node.node_id),
992                })
993            })
994            .collect::<Vec<_>>();
995
996        build_tree(message_nodes)
997    }
998
999    fn nearest_message_ancestor(&self, node_id: Option<&str>) -> Option<String> {
1000        let by_id = self
1001            .nodes
1002            .iter()
1003            .map(|node| (node.node_id.as_str(), node))
1004            .collect::<HashMap<_, _>>();
1005        let mut current = node_id.and_then(|id| by_id.get(id).copied());
1006        while let Some(node) = current {
1007            if node.message().is_some() {
1008                return Some(node.node_id.clone());
1009            }
1010            current = node
1011                .parent_node_id
1012                .as_deref()
1013                .and_then(|parent| by_id.get(parent).copied());
1014        }
1015        None
1016    }
1017}
1018
1019fn build_tree(mut nodes: Vec<SessionMessageTreeNode>) -> Vec<SessionMessageTreeNode> {
1020    let mut children_by_parent = HashMap::<Option<String>, Vec<SessionMessageTreeNode>>::new();
1021    for node in nodes.drain(..) {
1022        children_by_parent
1023            .entry(node.parent_message_node_id.clone())
1024            .or_default()
1025            .push(node);
1026    }
1027    let mut roots = build_tree_children(None, &mut children_by_parent);
1028    sort_tree(&mut roots);
1029    roots
1030}
1031
1032fn sort_tree(nodes: &mut [SessionMessageTreeNode]) {
1033    nodes.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
1034    for node in nodes {
1035        sort_tree(&mut node.children);
1036    }
1037}
1038
1039fn build_tree_children(
1040    parent_id: Option<String>,
1041    children_by_parent: &mut HashMap<Option<String>, Vec<SessionMessageTreeNode>>,
1042) -> Vec<SessionMessageTreeNode> {
1043    let mut children = children_by_parent.remove(&parent_id).unwrap_or_default();
1044    for child in &mut children {
1045        child.children = build_tree_children(Some(child.node_id.clone()), children_by_parent);
1046    }
1047    children
1048}
1049
1050fn node_belongs_to_agent_frame(
1051    node: &SessionNodeRecord,
1052    frame_id: &str,
1053    include_unscoped: bool,
1054) -> bool {
1055    match node.agent_frame_id.as_deref() {
1056        Some(node_frame_id) => node_frame_id == frame_id,
1057        None => include_unscoped,
1058    }
1059}
1060
/// Reconcile the current active path with a desired message list.
///
/// The result reuses the longest already-matching prefix and describes the
/// tail that must be appended. Only the non-transient entries of `messages`
/// are targets.
///
/// # Walk over `current_nodes`
///
/// The nodes are presumably in root-to-leaf order (TODO: confirm against
/// callers).
///
/// - Nodes whose message is transient are skipped and never become the leaf.
/// - Conversation nodes are compared, in order, against the next target by
///   message id. A message id already seen earlier on the path is skipped.
///   The walk stops at the first mismatch, or at a conversation node once
///   the targets are used up.
/// - All other nodes (non-conversation events, plugin payloads) on the
///   reused prefix are kept as they are.
///
/// # New tail nodes
///
/// Each target the walk did not consume becomes a new conversation node:
/// - parented on the previous leaf, so the new nodes form a chain;
/// - stamped with `agent_frame_id`, `timestamp`, and a causal ref derived
///   from the message origin;
/// - given a node id that collides with neither `existing_node_ids` nor any
///   other node created in this call.
///
/// # Returns
///
/// - the resulting leaf: the last kept or created node, or `None` if there
///   is none;
/// - the new tail nodes;
/// - the events and de-duplicated non-transient messages of every kept and
///   created node, in walk order.
pub(crate) fn build_active_read_replacement<'a>(
    current_nodes: impl IntoIterator<Item = &'a SessionNodeRecord>,
    existing_node_ids: &HashSet<String>,
    agent_frame_id: Option<&str>,
    messages: &[Message],
    timestamp: String,
) -> ActiveReadReplacement {
    // Transient messages are excluded from the targets, mirroring the skip
    // applied to transient nodes on the path below.
    let target = messages
        .iter()
        .filter(|message| !message.is_transient())
        .collect::<Vec<_>>();

    let mut active_events = Vec::new();
    let mut active_messages = Vec::new();
    let mut active_message_ids = HashSet::new();
    // Conversation keys already consumed on the path; a repeat is skipped.
    let mut seen_active_read_keys = HashSet::new();
    // Index of the next target message the path has to match.
    let mut target_idx = 0usize;
    // Last node kept from the path, and later the last node created.
    let mut leaf_node_id = None;
    for node in current_nodes {
        // Transient messages on the path are ignored and do not advance the leaf.
        if node
            .message()
            .map(|message| message.is_transient())
            .unwrap_or(false)
        {
            continue;
        }
        if let Some(key) = recognized_active_read_key(node) {
            // Deduplication runs before the target check: a repeated message id
            // neither consumes a target nor advances the leaf.
            if !seen_active_read_keys.insert(key.clone()) {
                continue;
            }
            // The path has more conversation messages than the target: stop reusing.
            let Some(target_item) = target.get(target_idx) else {
                break;
            };
            // Divergence point. The rest of the path, including any later
            // non-conversation nodes, is left out of the reused prefix.
            if key != format!("message:{}", target_item.id) {
                break;
            }
            push_active_read_node(
                node,
                &mut active_events,
                &mut active_messages,
                &mut active_message_ids,
            );
            leaf_node_id = Some(node.node_id.clone());
            target_idx += 1;
        } else {
            // Non-conversation nodes on the reused prefix are carried over as-is.
            push_active_read_node(
                node,
                &mut active_events,
                &mut active_messages,
                &mut active_message_ids,
            );
            leaf_node_id = Some(node.node_id.clone());
        }
    }

    // Ids minted in this call, so later targets cannot reuse them.
    let mut new_node_ids = HashSet::new();
    let mut new_tail_nodes = Vec::new();

    // Turn every unmatched target into a new node chained onto the current leaf.
    for message in target.into_iter().skip(target_idx) {
        let parent_node_id = leaf_node_id.clone();
        let node_id =
            unique_message_node_id_for_replacement(&message.id, existing_node_ids, &new_node_ids);
        let node = SessionNodeRecord {
            node_id,
            parent_node_id,
            caused_by: causal_ref_from_message_origin(&message.origin),
            agent_frame_id: agent_frame_id.map(ToOwned::to_owned),
            timestamp: timestamp.clone(),
            payload: SessionNodePayload::Event {
                event: SessionEventRecord::Conversation(ConversationRecord::from_message(
                    message.clone(),
                )),
            },
        };
        new_node_ids.insert(node.node_id.clone());
        leaf_node_id = Some(node.node_id.clone());
        push_active_read_node(
            &node,
            &mut active_events,
            &mut active_messages,
            &mut active_message_ids,
        );
        new_tail_nodes.push(node);
    }

    ActiveReadReplacement {
        leaf_node_id,
        new_tail_nodes,
        active_events,
        active_messages,
    }
}
1153
1154fn push_active_read_node(
1155    node: &SessionNodeRecord,
1156    active_events: &mut Vec<SessionEventRecord>,
1157    active_messages: &mut Vec<Message>,
1158    active_message_ids: &mut HashSet<String>,
1159) {
1160    if let Some(event) = node.event() {
1161        active_events.push(event.clone());
1162    }
1163    if let Some(message) = node.message()
1164        && !message.is_transient()
1165        && active_message_ids.insert(message.id.clone())
1166    {
1167        active_messages.push(message);
1168    }
1169}
1170
1171fn recognized_active_read_key(node: &SessionNodeRecord) -> Option<String> {
1172    match &node.payload {
1173        SessionNodePayload::Event { event } => match event {
1174            SessionEventRecord::Conversation(record) => Some(format!("message:{}", record.id)),
1175            _ => None,
1176        },
1177        SessionNodePayload::Plugin { .. } => None,
1178    }
1179}
1180
1181fn causal_ref_from_message_origin(
1182    origin: &Option<crate::MessageOrigin>,
1183) -> Option<crate::CausalRef> {
1184    let Some(crate::MessageOrigin::Process {
1185        process_id,
1186        sequence,
1187        ..
1188    }) = origin
1189    else {
1190        return None;
1191    };
1192    Some(crate::CausalRef::ProcessEvent {
1193        process_id: process_id.clone(),
1194        sequence: *sequence,
1195    })
1196}
1197
1198fn fresh_semantic_node_id(prefix: &str, existing_ids: &HashSet<String>) -> String {
1199    loop {
1200        let candidate = format!("{prefix}:{}", uuid::Uuid::new_v4().simple());
1201        if !existing_ids.contains(&candidate) {
1202            return candidate;
1203        }
1204    }
1205}
1206
/// Pick a node id for a message that is not already in `existing_ids`.
///
/// Candidates are tried in this order, and the first free one wins:
/// 1. the bare message id;
/// 2. `message:<id>`;
/// 3. `message:<id>:2`, `message:<id>:3`, and so on.
fn unique_message_node_id(message_id: &str, existing_ids: &HashSet<String>) -> String {
    let is_free = |candidate: &str| !existing_ids.contains(candidate);

    if is_free(message_id) {
        return message_id.to_owned();
    }
    let prefixed = format!("message:{message_id}");
    if is_free(prefixed.as_str()) {
        return prefixed;
    }
    (2..)
        .map(|suffix| format!("{prefixed}:{suffix}"))
        .find(|candidate| is_free(candidate.as_str()))
        .unwrap_or_else(|| unreachable!("message node id space exhausted"))
}
1223
/// Like `unique_message_node_id`, but a candidate also counts as taken when
/// it appears in `new_ids` (ids minted earlier in the same replacement).
///
/// Candidates are tried in this order, and the first free one wins:
/// 1. the bare message id;
/// 2. `message:<id>`;
/// 3. `message:<id>:2`, `message:<id>:3`, and so on.
fn unique_message_node_id_for_replacement(
    message_id: &str,
    existing_ids: &HashSet<String>,
    new_ids: &HashSet<String>,
) -> String {
    let is_taken =
        |candidate: &str| existing_ids.contains(candidate) || new_ids.contains(candidate);

    if !is_taken(message_id) {
        return message_id.to_owned();
    }
    let prefixed = format!("message:{message_id}");
    if !is_taken(prefixed.as_str()) {
        return prefixed;
    }
    (2..)
        .map(|suffix| format!("{prefixed}:{suffix}"))
        .find(|candidate| !is_taken(candidate.as_str()))
        .unwrap_or_else(|| unreachable!("message node id space exhausted"))
}
1244
1245fn fresh_node_id(prefix: &str) -> String {
1246    format!("{prefix}{}", uuid::Uuid::new_v4().simple())
1247}
1248
1249fn first_message_search_text(message: &Message) -> String {
1250    message
1251        .parts
1252        .iter()
1253        .filter_map(|part| match part.kind {
1254            crate::PartKind::ToolCall | crate::PartKind::ToolResult => None,
1255            crate::PartKind::Image => Some("[Image attached]".to_string()),
1256            _ => (!part.content.trim().is_empty()).then(|| part.content.clone()),
1257        })
1258        .collect::<Vec<_>>()
1259        .join("\n\n")
1260        .trim()
1261        .to_string()
1262}
1263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Part, PartKind, PruneState, shared_parts};

    /// Build a single-part plain-text message whose part id is derived from `id`.
    fn text_message(id: &str, role: MessageRole, content: &str) -> Message {
        Message {
            id: id.to_string(),
            role,
            parts: shared_parts(vec![Part {
                id: format!("{id}.p0"),
                kind: PartKind::Text,
                content: content.to_string(),
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            }]),
            origin: None,
        }
    }

    /// A minimal typed protocol event for exercising protocol-node writers.
    fn protocol_event() -> ProtocolEvent {
        ProtocolEvent::typed("test_protocol", serde_json::json!({"step": "started"}))
            .expect("protocol event serializes")
    }

    #[test]
    fn typed_append_node_ids_use_semantic_prefixes() {
        let mut graph = SessionGraph::default();

        let message_id = graph.append_message(text_message("m1", MessageRole::User, "hello"));
        let protocol_id = graph.append_protocol_event(protocol_event());
        let plugin_id = graph.append_plugin("example", serde_json::json!({"ok": true}));

        assert_eq!(message_id, "m1");
        assert!(protocol_id.starts_with("protocol:"));
        assert!(plugin_id.starts_with("plugin:"));
    }

    #[test]
    fn active_read_replacement_persists_messages_only() {
        let message = text_message("m1", MessageRole::User, "hello");
        let graph = SessionGraph::from_active_read_state(&[message]);

        assert_eq!(graph.nodes.len(), 1);
        assert!(matches!(
            graph.nodes[0].event(),
            Some(SessionEventRecord::Conversation(_))
        ));
    }

    #[test]
    fn graph_writers_do_not_put_active_read_events_under_plugin_ids() {
        let mut graph = SessionGraph::default();
        graph.append_message(text_message("m1", MessageRole::User, "hello"));
        graph.append_protocol_event(protocol_event());
        graph.append_plugin("example", serde_json::json!({"ok": true}));

        for node in &graph.nodes {
            match node.event() {
                Some(SessionEventRecord::Conversation(_)) => {
                    assert!(!node.node_id.starts_with("plugin:"), "{:?}", node);
                }
                Some(SessionEventRecord::Protocol(_)) => {
                    assert!(node.node_id.starts_with("protocol:"), "{:?}", node);
                }
                None => {
                    assert!(node.node_id.starts_with("plugin:"), "{:?}", node);
                }
            }
        }
    }

    /// Colliding message ids fall back to `message:<id>`, then to numbered suffixes.
    #[test]
    fn unique_message_node_id_prefixes_then_suffixes_on_collision() {
        let mut taken = HashSet::new();
        assert_eq!(unique_message_node_id("m1", &taken), "m1");
        taken.insert("m1".to_string());
        assert_eq!(unique_message_node_id("m1", &taken), "message:m1");
        taken.insert("message:m1".to_string());
        assert_eq!(unique_message_node_id("m1", &taken), "message:m1:2");
        taken.insert("message:m1:2".to_string());
        assert_eq!(unique_message_node_id("m1", &taken), "message:m1:3");
    }

    /// Replacement ids must avoid both pre-existing ids and ids minted
    /// earlier in the same replacement.
    #[test]
    fn replacement_node_ids_avoid_existing_and_newly_minted_ids() {
        let existing = HashSet::from(["m1".to_string()]);
        let mut minted = HashSet::new();

        let first = unique_message_node_id_for_replacement("m1", &existing, &minted);
        assert_eq!(first, "message:m1");
        minted.insert(first);
        assert_eq!(
            unique_message_node_id_for_replacement("m1", &existing, &minted),
            "message:m1:2"
        );
        assert_eq!(
            unique_message_node_id_for_replacement("m2", &existing, &minted),
            "m2"
        );
    }

    /// A dangling leaf is re-pointed at the last message node; a valid one is untouched.
    #[test]
    fn heal_orphaned_leaf_repoints_dangling_leaf_to_last_message() {
        let mut graph = SessionGraph::default();
        graph.append_message(text_message("m1", MessageRole::User, "hello"));
        graph.append_message(text_message("m2", MessageRole::User, "again"));
        assert!(!graph.heal_orphaned_leaf(), "a valid leaf must be left untouched");

        graph.data_mut().leaf_node_id = Some("missing".to_string());
        assert!(graph.heal_orphaned_leaf());
        assert_eq!(graph.leaf_node_id.as_deref(), Some("m2"));
        assert!(!graph.heal_orphaned_leaf(), "healing must be idempotent");
    }
}