Skip to main content

lash_core/
session_graph.rs

1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3use std::sync::{Arc, OnceLock};
4
5use chrono::Utc;
6use sha2::Digest;
7
8use crate::session_model::{ConversationRecord, SessionEventRecord, ToolEvent};
9use crate::{
10    BaseRenderCache, ExecutionMode, Message, MessageRole, PromptUsage, StandardContextApproach,
11    TokenUsage, ToolCallRecord,
12};
13
14#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
15pub struct SessionGraphData {
16    #[serde(default)]
17    pub nodes: Vec<SessionNodeRecord>,
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub leaf_node_id: Option<String>,
20}
21
22#[derive(Debug)]
23pub struct SessionGraph {
24    inner: Arc<SessionGraphData>,
25    cache: Arc<OnceLock<SessionGraphCache>>,
26}
27
28impl Default for SessionGraph {
29    fn default() -> Self {
30        Self {
31            inner: Arc::new(SessionGraphData::default()),
32            cache: Arc::new(OnceLock::new()),
33        }
34    }
35}
36
37impl Clone for SessionGraph {
38    fn clone(&self) -> Self {
39        Self {
40            inner: Arc::clone(&self.inner),
41            cache: Arc::clone(&self.cache),
42        }
43    }
44}
45
46impl serde::Serialize for SessionGraph {
47    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
48    where
49        S: serde::Serializer,
50    {
51        self.inner.serialize(serializer)
52    }
53}
54
55impl<'de> serde::Deserialize<'de> for SessionGraph {
56    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
57    where
58        D: serde::Deserializer<'de>,
59    {
60        let inner = SessionGraphData::deserialize(deserializer)?;
61        Ok(Self {
62            inner: Arc::new(inner),
63            cache: Arc::new(OnceLock::new()),
64        })
65    }
66}
67
68impl Deref for SessionGraph {
69    type Target = SessionGraphData;
70
71    fn deref(&self) -> &Self::Target {
72        self.inner.as_ref()
73    }
74}
75
76#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
77pub struct SessionNodeRecord {
78    pub node_id: String,
79    #[serde(default, skip_serializing_if = "Option::is_none")]
80    pub parent_node_id: Option<String>,
81    pub timestamp: String,
82    #[serde(flatten)]
83    pub payload: SessionNodePayload,
84}
85
86#[derive(Clone, Debug, Default, PartialEq)]
87pub struct SharedJsonValue(pub Arc<serde_json::Value>);
88
89impl SharedJsonValue {
90    pub fn new(value: serde_json::Value) -> Self {
91        Self(Arc::new(value))
92    }
93
94    pub fn to_owned(&self) -> serde_json::Value {
95        self.0.as_ref().clone()
96    }
97}
98
99impl AsRef<serde_json::Value> for SharedJsonValue {
100    fn as_ref(&self) -> &serde_json::Value {
101        self.0.as_ref()
102    }
103}
104
105impl serde::Serialize for SharedJsonValue {
106    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107    where
108        S: serde::Serializer,
109    {
110        self.0.serialize(serializer)
111    }
112}
113
114impl<'de> serde::Deserialize<'de> for SharedJsonValue {
115    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116    where
117        D: serde::Deserializer<'de>,
118    {
119        let value = serde_json::Value::deserialize(deserializer)?;
120        Ok(Self::new(value))
121    }
122}
123
124#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
125#[serde(tag = "kind", rename_all = "snake_case")]
126#[allow(clippy::large_enum_variant)]
127pub enum SessionNodePayload {
128    Event {
129        event: SessionEventRecord,
130    },
131    Plugin {
132        plugin_type: String,
133        body: SharedJsonValue,
134    },
135}
136
137#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
138pub struct PersistedSessionConfig {
139    pub provider_id: String,
140    pub configured_model: String,
141    pub context_window: u64,
142    pub execution_mode: ExecutionMode,
143    #[serde(default, skip_serializing_if = "Option::is_none")]
144    pub standard_context_approach: Option<StandardContextApproach>,
145    #[serde(default, skip_serializing_if = "Option::is_none")]
146    pub model_variant: Option<String>,
147}
148
149#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
150pub struct PersistedTurnState {
151    pub turn_index: usize,
152    #[serde(default)]
153    pub token_usage: TokenUsage,
154    #[serde(default, skip_serializing_if = "Option::is_none")]
155    pub last_prompt_usage: Option<PromptUsage>,
156    #[serde(default)]
157    pub mode_turn_options: crate::ModeTurnOptions,
158}
159
160#[derive(Clone, Debug)]
161pub struct SessionMessageTreeNode {
162    pub node_id: String,
163    pub parent_message_node_id: Option<String>,
164    pub message: Message,
165    pub timestamp: String,
166    pub children: Vec<SessionMessageTreeNode>,
167    pub active: bool,
168}
169
170#[derive(Clone)]
171enum ActiveReadItem<'a> {
172    Message(&'a Message),
173    ToolCall {
174        stable_key: String,
175        record: &'a ToolCallRecord,
176    },
177}
178
179#[derive(Debug)]
180pub(crate) struct ActiveReadReplacement {
181    pub(crate) leaf_node_id: Option<String>,
182    pub(crate) new_tail_nodes: Vec<SessionNodeRecord>,
183    pub(crate) active_events: Vec<SessionEventRecord>,
184    pub(crate) active_messages: Vec<Message>,
185    pub(crate) active_tool_calls: Vec<ToolCallRecord>,
186}
187
188#[derive(Clone, Debug)]
189pub(crate) struct SessionReadModel {
190    pub(crate) active_events: Arc<Vec<SessionEventRecord>>,
191    pub(crate) messages: Arc<Vec<Message>>,
192    pub(crate) tool_calls: Arc<Vec<ToolCallRecord>>,
193    pub(crate) prompt_render_cache: Arc<BaseRenderCache>,
194}
195
196#[derive(Clone, Debug)]
197pub(crate) struct SessionGraphAppendBuilder {
198    existing_ids: HashSet<String>,
199    leaf_node_id: Option<String>,
200}
201
202impl SessionGraphAppendBuilder {
203    pub(crate) fn leaf_node_id(&self) -> Option<&String> {
204        self.leaf_node_id.as_ref()
205    }
206
207    pub(crate) fn set_leaf_node_id(&mut self, leaf_node_id: Option<String>) {
208        self.leaf_node_id = leaf_node_id;
209    }
210
211    pub(crate) fn register_existing_node_ids<'a>(
212        &mut self,
213        node_ids: impl IntoIterator<Item = &'a str>,
214    ) {
215        self.existing_ids
216            .extend(node_ids.into_iter().map(ToOwned::to_owned));
217    }
218
219    pub(crate) fn existing_node_ids(&self) -> &HashSet<String> {
220        &self.existing_ids
221    }
222
223    pub(crate) fn append_messages<I>(&mut self, messages: I) -> Vec<SessionNodeRecord>
224    where
225        I: IntoIterator<Item = Message>,
226    {
227        let mut nodes = Vec::new();
228        for mut message in messages {
229            if message.id.is_empty() {
230                message.id = fresh_node_id("m");
231            }
232            let node_id = unique_message_node_id(&message.id, &self.existing_ids);
233            self.existing_ids.insert(node_id.clone());
234            let parent_node_id = self.leaf_node_id.clone();
235            self.leaf_node_id = Some(node_id.clone());
236            nodes.push(SessionNodeRecord {
237                node_id,
238                parent_node_id,
239                timestamp: Utc::now().to_rfc3339(),
240                payload: SessionNodePayload::Event {
241                    event: SessionEventRecord::Conversation(ConversationRecord::from_message(
242                        message,
243                    )),
244                },
245            });
246        }
247        nodes
248    }
249
250    pub(crate) fn append_tool_call_records<I>(&mut self, records: I) -> Vec<SessionNodeRecord>
251    where
252        I: IntoIterator<Item = ToolCallRecord>,
253    {
254        let mut nodes = Vec::new();
255        for record in records {
256            let stable_key = stable_tool_call_key(&record);
257            let node_id = unique_plugin_node_id(&stable_key, &self.existing_ids);
258            self.existing_ids.insert(node_id.clone());
259            let parent_node_id = self.leaf_node_id.clone();
260            self.leaf_node_id = Some(node_id.clone());
261            nodes.push(SessionNodeRecord {
262                node_id,
263                parent_node_id,
264                timestamp: Utc::now().to_rfc3339(),
265                payload: SessionNodePayload::Event {
266                    event: SessionEventRecord::Tool(ToolEvent::Invocation { stable_key, record }),
267                },
268            });
269        }
270        nodes
271    }
272
273    pub(crate) fn append_event_record(
274        &mut self,
275        event: SessionEventRecord,
276    ) -> Vec<SessionNodeRecord> {
277        let node_id = unique_plugin_node_id(&fresh_node_id("e"), &self.existing_ids);
278        self.existing_ids.insert(node_id.clone());
279        let parent_node_id = self.leaf_node_id.clone();
280        self.leaf_node_id = Some(node_id.clone());
281        vec![SessionNodeRecord {
282            node_id,
283            parent_node_id,
284            timestamp: Utc::now().to_rfc3339(),
285            payload: SessionNodePayload::Event { event },
286        }]
287    }
288}
289
290#[derive(Debug, Clone)]
291struct SessionGraphCache {
292    by_id: HashMap<String, usize>,
293    active_path_indices: Vec<usize>,
294    active_events: Arc<Vec<SessionEventRecord>>,
295    active_messages: Arc<Vec<Message>>,
296    /// Index from `Message::id` to its position in `active_messages`,
297    /// kept in sync with the vec so dedup on append is O(1) instead of an
298    /// O(n) linear scan (which made long sessions quadratic in message
299    /// count).
300    active_message_ids: HashMap<String, usize>,
301    active_tool_calls: Arc<Vec<ToolCallRecord>>,
302    /// Memoized render of `active_messages`. Shared with every
303    /// `MessageSequence` built off this read model so the chat projector's
304    /// per-iteration `render_prompt` walk only happens once per turn.
305    /// Replaced (not invalidated in-place) whenever `active_messages`
306    /// changes — the `Arc` identity tracks the cache's validity.
307    prompt_render_cache: Arc<BaseRenderCache>,
308}
309
310impl SessionGraphCache {
311    fn build(graph: &SessionGraph) -> Self {
312        let by_id = graph
313            .nodes
314            .iter()
315            .enumerate()
316            .map(|(idx, node)| (node.node_id.clone(), idx))
317            .collect::<HashMap<_, _>>();
318        let mut active_path_indices = Vec::new();
319        let mut current = graph
320            .leaf_node_id
321            .as_ref()
322            .and_then(|node_id| by_id.get(node_id).copied());
323        while let Some(idx) = current {
324            active_path_indices.push(idx);
325            current = graph.nodes[idx]
326                .parent_node_id
327                .as_ref()
328                .and_then(|node_id| by_id.get(node_id).copied());
329        }
330        active_path_indices.reverse();
331
332        let mut cache = Self {
333            by_id,
334            active_path_indices,
335            active_events: Arc::new(Vec::new()),
336            active_messages: Arc::new(Vec::new()),
337            active_message_ids: HashMap::new(),
338            active_tool_calls: Arc::new(Vec::new()),
339            prompt_render_cache: Arc::new(BaseRenderCache::new()),
340        };
341        cache.rebuild_read_model(graph);
342        cache
343    }
344
345    fn rebuild_read_model(&mut self, graph: &SessionGraph) {
346        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
347        let mut active_message_ids: HashMap<String, usize> =
348            HashMap::with_capacity(self.active_path_indices.len());
349        let mut active_tool_calls = Vec::with_capacity(self.active_path_indices.len());
350        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
351        for idx in &self.active_path_indices {
352            let node = &graph.nodes[*idx];
353            if let Some(event) = node.event() {
354                active_events.push(event.clone());
355            }
356            if let Some(message) = node.message() {
357                if !message.is_transient() && !active_message_ids.contains_key(&message.id) {
358                    active_message_ids.insert(message.id.clone(), active_messages.len());
359                    active_messages.push(message);
360                }
361                continue;
362            }
363            if let Some(event) = node.event()
364                && let SessionEventRecord::Tool(ToolEvent::Invocation { record, .. }) = event
365            {
366                active_tool_calls.push(record.clone());
367                continue;
368            }
369        }
370        self.active_messages = Arc::new(active_messages);
371        self.active_message_ids = active_message_ids;
372        self.active_events = Arc::new(active_events);
373        self.active_tool_calls = Arc::new(active_tool_calls);
374        self.prompt_render_cache = Arc::new(BaseRenderCache::new());
375    }
376
377    fn append_node(
378        &mut self,
379        node_index: usize,
380        node: &SessionNodeRecord,
381        previous_leaf_node_id: Option<&str>,
382    ) {
383        self.by_id.insert(node.node_id.clone(), node_index);
384        let parent_matches_leaf = node.parent_node_id.as_deref() == previous_leaf_node_id;
385        if !parent_matches_leaf {
386            return;
387        }
388        self.active_path_indices.push(node_index);
389        if let Some(event) = node.event() {
390            Arc::make_mut(&mut self.active_events).push(event.clone());
391        }
392        if let Some(message) = node.message() {
393            if !message.is_transient() && !self.active_message_ids.contains_key(&message.id) {
394                let messages = Arc::make_mut(&mut self.active_messages);
395                self.active_message_ids
396                    .insert(message.id.clone(), messages.len());
397                messages.push(message);
398                self.prompt_render_cache = Arc::new(BaseRenderCache::new());
399            }
400            return;
401        }
402        if let Some(event) = node.event()
403            && let SessionEventRecord::Tool(ToolEvent::Invocation { record, .. }) = event
404        {
405            Arc::make_mut(&mut self.active_tool_calls).push(record.clone());
406        }
407    }
408
409    fn reserve_append_capacity(
410        &mut self,
411        additional_nodes: usize,
412        additional_messages: usize,
413        additional_tool_calls: usize,
414    ) {
415        self.by_id.reserve(additional_nodes);
416        self.active_path_indices.reserve(additional_nodes);
417        if additional_messages > 0 {
418            Arc::make_mut(&mut self.active_messages).reserve(additional_messages);
419        }
420        if additional_tool_calls > 0 {
421            Arc::make_mut(&mut self.active_tool_calls).reserve(additional_tool_calls);
422        }
423    }
424}
425
426impl SessionNodeRecord {
427    pub fn event(&self) -> Option<&SessionEventRecord> {
428        match &self.payload {
429            SessionNodePayload::Event { event } => Some(event),
430            SessionNodePayload::Plugin { .. } => None,
431        }
432    }
433
434    pub fn message(&self) -> Option<Message> {
435        match self.event()? {
436            SessionEventRecord::Conversation(record) => Some(record.to_message()),
437            _ => None,
438        }
439    }
440
441    pub fn plugin(&self) -> Option<(&str, &serde_json::Value)> {
442        match &self.payload {
443            SessionNodePayload::Event { .. } => None,
444            SessionNodePayload::Plugin { plugin_type, body } => {
445                Some((plugin_type.as_str(), body.as_ref()))
446            }
447        }
448    }
449
450    pub fn plugin_body<T>(&self) -> Option<T>
451    where
452        T: for<'de> serde::Deserialize<'de>,
453    {
454        let (_, body) = self.plugin()?;
455        T::deserialize(body).ok()
456    }
457}
458
459impl SessionGraph {
460    pub fn append_active_read_delta(
461        &mut self,
462        messages: &[Message],
463        tool_calls: &[ToolCallRecord],
464    ) {
465        let appendable_messages = {
466            let read_model = self.read_model();
467            let mut seen_message_ids = read_model
468                .messages
469                .iter()
470                .map(|message| message.id.as_str())
471                .collect::<HashSet<_>>();
472            messages
473                .iter()
474                .filter(|message| {
475                    !message.is_transient() && seen_message_ids.insert(message.id.as_str())
476                })
477                .cloned()
478                .collect::<Vec<_>>()
479        };
480        let read_model = self.read_model();
481        let mut seen_tool_call_keys = read_model
482            .tool_calls
483            .iter()
484            .map(|record| tool_call_active_read_key(&stable_tool_call_key(record), record))
485            .collect::<HashSet<_>>();
486        let appendable_tool_calls = tool_calls
487            .iter()
488            .filter_map(|record| {
489                let stable_key = stable_tool_call_key(record);
490                let active_read_key = tool_call_active_read_key(&stable_key, record);
491                seen_tool_call_keys
492                    .insert(active_read_key)
493                    .then_some((stable_key, record))
494            })
495            .collect::<Vec<_>>();
496
497        self.reserve_append_capacity(
498            appendable_messages.len() + appendable_tool_calls.len(),
499            appendable_messages.len(),
500            appendable_tool_calls.len(),
501        );
502        self.append_message_batch(appendable_messages);
503
504        for (stable_key, record) in appendable_tool_calls {
505            self.append_event(SessionEventRecord::Tool(ToolEvent::Invocation {
506                stable_key,
507                record: record.clone(),
508            }));
509        }
510    }
511
512    pub(crate) fn append_active_conversation_messages(&mut self, messages: &[Message]) {
513        let appendable_messages = messages
514            .iter()
515            .filter(|message| !message.is_transient())
516            .cloned()
517            .collect::<Vec<_>>();
518        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len(), 0);
519        self.append_message_batch(appendable_messages);
520    }
521
522    pub fn from_nodes(nodes: Vec<SessionNodeRecord>, leaf_node_id: Option<String>) -> Self {
523        Self {
524            inner: Arc::new(SessionGraphData {
525                nodes,
526                leaf_node_id,
527            }),
528            cache: Arc::new(OnceLock::new()),
529        }
530    }
531
532    pub(crate) fn append_builder(&self) -> SessionGraphAppendBuilder {
533        SessionGraphAppendBuilder {
534            existing_ids: self.nodes.iter().map(|node| node.node_id.clone()).collect(),
535            leaf_node_id: self.leaf_node_id.clone(),
536        }
537    }
538
539    fn invalidate_cache(&mut self) {
540        self.cache = Arc::new(OnceLock::new());
541    }
542
543    fn data_mut(&mut self) -> &mut SessionGraphData {
544        self.invalidate_cache();
545        Arc::make_mut(&mut self.inner)
546    }
547
548    fn reserve_append_capacity(
549        &mut self,
550        additional_nodes: usize,
551        additional_messages: usize,
552        additional_tool_calls: usize,
553    ) {
554        if additional_nodes == 0 {
555            return;
556        }
557        self.detach_initialized_cache_for_append();
558        Arc::make_mut(&mut self.inner)
559            .nodes
560            .reserve(additional_nodes);
561        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
562            && let Some(cache) = cache_lock.get_mut()
563        {
564            cache.reserve_append_capacity(
565                additional_nodes,
566                additional_messages,
567                additional_tool_calls,
568            );
569        }
570    }
571
572    fn detach_initialized_cache_for_append(&mut self) {
573        if Arc::get_mut(&mut self.cache).is_some() {
574            return;
575        }
576        let Some(cache) = self.cache.get().cloned() else {
577            self.invalidate_cache();
578            return;
579        };
580        let lock = OnceLock::new();
581        let _ = lock.set(cache);
582        self.cache = Arc::new(lock);
583    }
584
585    fn cache(&self) -> &SessionGraphCache {
586        self.cache.get_or_init(|| SessionGraphCache::build(self))
587    }
588
589    fn append_message_batch(&mut self, messages: Vec<Message>) {
590        if messages.is_empty() {
591            return;
592        }
593
594        let messages = messages.into_iter();
595        let mut existing_ids = self
596            .nodes
597            .iter()
598            .map(|node| node.node_id.clone())
599            .collect::<HashSet<_>>();
600        let mut parent_node_id = self.leaf_node_id.clone();
601        let mut nodes = Vec::with_capacity(messages.len());
602        for mut message in messages {
603            if message.id.is_empty() {
604                message.id = fresh_node_id("m");
605            }
606            let node_id = unique_message_node_id(&message.id, &existing_ids);
607            existing_ids.insert(node_id.clone());
608            nodes.push(SessionNodeRecord {
609                node_id: node_id.clone(),
610                parent_node_id,
611                timestamp: Utc::now().to_rfc3339(),
612                payload: SessionNodePayload::Event {
613                    event: SessionEventRecord::Conversation(ConversationRecord::from_message(
614                        message,
615                    )),
616                },
617            });
618            parent_node_id = Some(node_id);
619        }
620
621        self.append_prebuilt_nodes(nodes);
622    }
623
624    fn append_prebuilt_nodes(&mut self, nodes: Vec<SessionNodeRecord>) {
625        if nodes.is_empty() {
626            return;
627        }
628
629        self.detach_initialized_cache_for_append();
630        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
631            && let Some(cache) = cache_lock.get_mut()
632        {
633            let data = Arc::make_mut(&mut self.inner);
634            for node in nodes {
635                let previous_leaf = data.leaf_node_id.clone();
636                let node_id = node.node_id.clone();
637                data.nodes.push(node);
638                cache.append_node(
639                    data.nodes.len() - 1,
640                    data.nodes.last().expect("just appended graph node"),
641                    previous_leaf.as_deref(),
642                );
643                data.leaf_node_id = Some(node_id);
644            }
645            return;
646        }
647
648        let data = self.data_mut();
649        for node in nodes {
650            data.leaf_node_id = Some(node.node_id.clone());
651            data.nodes.push(node);
652        }
653    }
654
655    pub fn append_message(&mut self, mut message: Message) -> String {
656        if message.id.is_empty() {
657            message.id = fresh_node_id("m");
658        }
659        let existing_ids = self
660            .nodes
661            .iter()
662            .map(|node| node.node_id.clone())
663            .collect::<HashSet<_>>();
664        let node_id = unique_message_node_id(&message.id, &existing_ids);
665        let previous_leaf = self.leaf_node_id.clone();
666        let parent_node_id = previous_leaf.clone();
667        let node = SessionNodeRecord {
668            node_id: node_id.clone(),
669            parent_node_id,
670            timestamp: Utc::now().to_rfc3339(),
671            payload: SessionNodePayload::Event {
672                event: SessionEventRecord::Conversation(ConversationRecord::from_message(message)),
673            },
674        };
675        self.detach_initialized_cache_for_append();
676        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
677            && let Some(cache) = cache_lock.get_mut()
678        {
679            let data = Arc::make_mut(&mut self.inner);
680            data.nodes.push(node);
681            cache.append_node(
682                data.nodes.len() - 1,
683                data.nodes.last().expect("just appended graph node"),
684                previous_leaf.as_deref(),
685            );
686            data.leaf_node_id = Some(node_id.clone());
687            return node_id;
688        }
689        let data = self.data_mut();
690        data.nodes.push(node);
691        data.leaf_node_id = Some(node_id.clone());
692        node_id
693    }
694
695    pub fn append_plugin(
696        &mut self,
697        plugin_type: impl Into<String>,
698        body: serde_json::Value,
699    ) -> String {
700        let node_id = fresh_node_id("x");
701        let previous_leaf = self.leaf_node_id.clone();
702        let parent_node_id = previous_leaf.clone();
703        let node = SessionNodeRecord {
704            node_id: node_id.clone(),
705            parent_node_id,
706            timestamp: Utc::now().to_rfc3339(),
707            payload: SessionNodePayload::Plugin {
708                plugin_type: plugin_type.into(),
709                body: SharedJsonValue::new(body),
710            },
711        };
712        self.detach_initialized_cache_for_append();
713        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
714            && let Some(cache) = cache_lock.get_mut()
715        {
716            let data = Arc::make_mut(&mut self.inner);
717            data.nodes.push(node);
718            cache.append_node(
719                data.nodes.len() - 1,
720                data.nodes.last().expect("just appended graph node"),
721                previous_leaf.as_deref(),
722            );
723            data.leaf_node_id = Some(node_id.clone());
724            return node_id;
725        }
726        let data = self.data_mut();
727        data.nodes.push(node);
728        data.leaf_node_id = Some(node_id.clone());
729        node_id
730    }
731
732    pub fn active_path_nodes(&self) -> Vec<&SessionNodeRecord> {
733        self.cache()
734            .active_path_indices
735            .iter()
736            .map(|idx| &self.nodes[*idx])
737            .collect()
738    }
739
740    pub(crate) fn read_model(&self) -> SessionReadModel {
741        let cache = self.cache();
742        SessionReadModel {
743            active_events: Arc::clone(&cache.active_events),
744            messages: Arc::clone(&cache.active_messages),
745            tool_calls: Arc::clone(&cache.active_tool_calls),
746            prompt_render_cache: Arc::clone(&cache.prompt_render_cache),
747        }
748    }
749
750    pub fn replace_active_tool_calls(&mut self, tool_calls: &[ToolCallRecord]) {
751        let messages = Arc::clone(&self.cache().active_messages);
752        self.replace_active_read_state(messages.as_slice(), tool_calls);
753    }
754
755    pub fn append_event(&mut self, event: SessionEventRecord) -> String {
756        let node_id = fresh_node_id("e");
757        let previous_leaf = self.leaf_node_id.clone();
758        let node = SessionNodeRecord {
759            node_id: node_id.clone(),
760            parent_node_id: previous_leaf.clone(),
761            timestamp: Utc::now().to_rfc3339(),
762            payload: SessionNodePayload::Event { event },
763        };
764        self.detach_initialized_cache_for_append();
765        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
766            && let Some(cache) = cache_lock.get_mut()
767        {
768            let data = Arc::make_mut(&mut self.inner);
769            data.nodes.push(node);
770            cache.append_node(
771                data.nodes.len() - 1,
772                data.nodes.last().expect("just appended graph node"),
773                previous_leaf.as_deref(),
774            );
775            data.leaf_node_id = Some(node_id.clone());
776            return node_id;
777        }
778        let data = self.data_mut();
779        data.nodes.push(node);
780        data.leaf_node_id = Some(node_id.clone());
781        node_id
782    }
783
784    pub fn append_events<I>(&mut self, events: I) -> Vec<String>
785    where
786        I: IntoIterator<Item = SessionEventRecord>,
787    {
788        events
789            .into_iter()
790            .map(|event| self.append_event(event))
791            .collect()
792    }
793
794    pub fn user_message_count(&self) -> usize {
795        self.nodes
796            .iter()
797            .filter_map(SessionNodeRecord::message)
798            .filter(|message| matches!(message.role, MessageRole::User))
799            .count()
800    }
801
802    pub fn first_user_message(&self) -> String {
803        self.nodes
804            .iter()
805            .filter_map(SessionNodeRecord::message)
806            .find(|message| matches!(message.role, MessageRole::User))
807            .map(|message| first_message_search_text(&message))
808            .unwrap_or_default()
809    }
810
811    pub fn branch_to(&mut self, node_id: Option<String>) {
812        self.data_mut().leaf_node_id = node_id;
813    }
814
815    pub fn set_leaf_node_id(&mut self, node_id: Option<String>) {
816        self.data_mut().leaf_node_id = node_id;
817    }
818
819    pub fn push_node_record(&mut self, node: SessionNodeRecord) {
820        self.data_mut().nodes.push(node);
821    }
822
823    pub fn extend_node_records<I>(&mut self, nodes: I)
824    where
825        I: IntoIterator<Item = SessionNodeRecord>,
826    {
827        self.data_mut().nodes.extend(nodes);
828    }
829
830    /// Append nodes that extend the current active path, advancing the
831    /// leaf to the last node and updating the cache incrementally
832    /// instead of invalidating it. Use this when the appended nodes are
833    /// genuinely new descendants of the current leaf — e.g. the
834    /// turn-driver merging `TurnGraphOverlay` deltas into the base graph.
835    /// Use `extend_node_records` + `set_leaf_node_id` for store-side
836    /// replay paths that don't follow the active-path append shape.
837    pub fn extend_active_path(&mut self, nodes: Vec<SessionNodeRecord>) {
838        self.append_prebuilt_nodes(nodes);
839    }
840
841    pub fn active_path_contains(&self, node_id: &str) -> bool {
842        self.active_path_nodes()
843            .into_iter()
844            .any(|node| node.node_id == node_id)
845    }
846
847    /// If `leaf_node_id` points to a node that no longer exists in
848    /// `self.nodes` (e.g. after compaction rewrote the graph, or a
849    /// stored session referenced a node that was later purged), fall
850    /// back to the most recent message node. Returns `true` if the
851    /// leaf was repaired. Call this on load paths where an orphan
852    /// leaf would project to an empty transcript and silently drop
853    /// the user's history.
854    pub fn heal_orphaned_leaf(&mut self) -> bool {
855        if let Some(leaf) = self.leaf_node_id.as_ref()
856            && self.find_node(leaf).is_none()
857        {
858            let fallback = self
859                .nodes
860                .iter()
861                .rev()
862                .find(|node| node.message().is_some())
863                .map(|node| node.node_id.clone());
864            self.data_mut().leaf_node_id = fallback;
865            return true;
866        }
867        false
868    }
869
870    pub fn fork_current_path(&self) -> SessionGraph {
871        let path = self.active_path_nodes();
872        SessionGraph::from_nodes(
873            path.into_iter().cloned().collect(),
874            self.leaf_node_id.clone(),
875        )
876    }
877
878    pub fn find_node(&self, node_id: &str) -> Option<&SessionNodeRecord> {
879        self.cache()
880            .by_id
881            .get(node_id)
882            .and_then(|idx| self.nodes.get(*idx))
883    }
884
885    pub fn node_index(&self, node_id: &str) -> Option<usize> {
886        self.cache().by_id.get(node_id).copied()
887    }
888
889    pub fn replace_active_read_state(
890        &mut self,
891        messages: &[Message],
892        tool_calls: &[ToolCallRecord],
893    ) {
894        let current_nodes = self.active_path_nodes();
895        let existing_ids = self
896            .nodes
897            .iter()
898            .map(|node| node.node_id.clone())
899            .collect::<HashSet<_>>();
900        let replacement =
901            build_active_read_replacement(current_nodes, &existing_ids, messages, tool_calls);
902        let data = self.data_mut();
903        data.leaf_node_id = replacement.leaf_node_id;
904        data.nodes.extend(replacement.new_tail_nodes);
905    }
906
907    pub fn from_active_read_state(messages: &[Message], tool_calls: &[ToolCallRecord]) -> Self {
908        let mut graph = Self::default();
909        graph.replace_active_read_state(messages, tool_calls);
910        graph
911    }
912
913    pub fn message_tree(&self) -> Vec<SessionMessageTreeNode> {
914        let active_message_ids = self
915            .active_path_nodes()
916            .into_iter()
917            .filter_map(|node| node.message().map(|message| message.id.clone()))
918            .collect::<HashSet<_>>();
919
920        let message_nodes = self
921            .nodes
922            .iter()
923            .filter_map(|node| {
924                let message = node.message()?.clone();
925                let parent_message_node_id =
926                    self.nearest_message_ancestor(node.parent_node_id.as_deref());
927                Some(SessionMessageTreeNode {
928                    node_id: node.node_id.clone(),
929                    parent_message_node_id,
930                    message,
931                    timestamp: node.timestamp.clone(),
932                    children: Vec::new(),
933                    active: active_message_ids.contains(&node.node_id),
934                })
935            })
936            .collect::<Vec<_>>();
937
938        build_tree(message_nodes)
939    }
940
941    fn nearest_message_ancestor(&self, node_id: Option<&str>) -> Option<String> {
942        let by_id = self
943            .nodes
944            .iter()
945            .map(|node| (node.node_id.as_str(), node))
946            .collect::<HashMap<_, _>>();
947        let mut current = node_id.and_then(|id| by_id.get(id).copied());
948        while let Some(node) = current {
949            if node.message().is_some() {
950                return Some(node.node_id.clone());
951            }
952            current = node
953                .parent_node_id
954                .as_deref()
955                .and_then(|parent| by_id.get(parent).copied());
956        }
957        None
958    }
959}
960
961fn build_tree(mut nodes: Vec<SessionMessageTreeNode>) -> Vec<SessionMessageTreeNode> {
962    let mut children_by_parent = HashMap::<Option<String>, Vec<SessionMessageTreeNode>>::new();
963    for node in nodes.drain(..) {
964        children_by_parent
965            .entry(node.parent_message_node_id.clone())
966            .or_default()
967            .push(node);
968    }
969    let mut roots = build_tree_children(None, &mut children_by_parent);
970    sort_tree(&mut roots);
971    roots
972}
973
974fn sort_tree(nodes: &mut [SessionMessageTreeNode]) {
975    nodes.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
976    for node in nodes {
977        sort_tree(&mut node.children);
978    }
979}
980
981fn build_tree_children(
982    parent_id: Option<String>,
983    children_by_parent: &mut HashMap<Option<String>, Vec<SessionMessageTreeNode>>,
984) -> Vec<SessionMessageTreeNode> {
985    let mut children = children_by_parent.remove(&parent_id).unwrap_or_default();
986    for child in &mut children {
987        child.children = build_tree_children(Some(child.node_id.clone()), children_by_parent);
988    }
989    children
990}
991
992fn build_active_read_items<'a>(
993    messages: &'a [Message],
994    tool_calls: &'a [ToolCallRecord],
995) -> Vec<ActiveReadItem<'a>> {
996    let mut first_message_for_call = HashMap::<String, usize>::new();
997    let active_messages = messages
998        .iter()
999        .filter(|message| !message.is_transient())
1000        .collect::<Vec<_>>();
1001    for (idx, message) in active_messages.iter().enumerate() {
1002        for part in message.parts.iter() {
1003            if let Some(call_id) = &part.tool_call_id {
1004                first_message_for_call.entry(call_id.clone()).or_insert(idx);
1005            }
1006        }
1007    }
1008
1009    let mut anchored = HashMap::<usize, Vec<ActiveReadItem<'a>>>::new();
1010    for record in tool_calls {
1011        let stable_key = stable_tool_call_key(record);
1012        let anchor = record
1013            .call_id
1014            .as_ref()
1015            .and_then(|call_id| first_message_for_call.get(call_id).copied())
1016            .unwrap_or_else(|| active_messages.len().saturating_sub(1));
1017        anchored
1018            .entry(anchor)
1019            .or_default()
1020            .push(ActiveReadItem::ToolCall { stable_key, record });
1021    }
1022
1023    let mut out = Vec::new();
1024    for (idx, message) in active_messages.iter().enumerate() {
1025        out.push(ActiveReadItem::Message(message));
1026        if let Some(items) = anchored.remove(&idx) {
1027            out.extend(items);
1028        }
1029    }
1030    out
1031}
1032
/// Reconcile a walk of existing nodes with the desired read state
/// (`messages` plus `tool_calls`) without mutating any graph.
///
/// The desired state is first flattened by `build_active_read_items` into an
/// ordered list of target items. `current_nodes` is then walked; in
/// `replace_active_read_state` it is the active path, presumably ordered root
/// to leaf (TODO confirm). The walk keeps every node up to the first point
/// where the path and the targets diverge:
///
/// * Transient message nodes are skipped outright.
/// * Nodes with a recognized key (conversation messages and tool
///   invocations) must match the next target item's key, or the walk stops.
///   A recognized key already seen earlier in the walk is skipped instead of
///   being matched again.
/// * Nodes without a recognized key (plugin payloads and other event kinds)
///   are kept without consuming a target item.
///
/// Keys come from `recognized_active_read_key` / `active_read_item_key`.
/// Messages are compared by id only (`message:{id}`), so a message whose
/// content changed under the same id keeps the existing node. Tool calls are
/// compared by stable key plus the JSON encoding of the whole record.
///
/// Every target item that was not matched becomes a new node. The new nodes
/// form a chain, each parented to the previous one, starting from the last
/// kept node. New node ids are unique against `existing_node_ids` and
/// against each other, and are timestamped with the current UTC time.
///
/// Returns:
/// * the new leaf: the last kept or created node, or `None` if there is none;
/// * the new nodes, which the caller must append to the graph;
/// * the events, messages and tool calls projected from every kept and
///   created node via `push_active_read_node`.
pub(crate) fn build_active_read_replacement<'a>(
    current_nodes: impl IntoIterator<Item = &'a SessionNodeRecord>,
    existing_node_ids: &HashSet<String>,
    messages: &[Message],
    tool_calls: &[ToolCallRecord],
) -> ActiveReadReplacement {
    let target = build_active_read_items(messages, tool_calls);

    let mut active_events = Vec::new();
    let mut active_messages = Vec::new();
    let mut active_message_ids = HashSet::new();
    let mut active_tool_calls = Vec::new();
    let mut seen_active_read_keys = HashSet::new();
    // Number of target items matched by the kept prefix so far.
    let mut target_idx = 0usize;
    let mut leaf_node_id = None;
    for node in current_nodes {
        // Transient messages are excluded here, mirroring
        // `build_active_read_items`, which filters them out of the targets.
        if node
            .message()
            .map(|message| message.is_transient())
            .unwrap_or(false)
        {
            continue;
        }
        if let Some(key) = recognized_active_read_key(node) {
            // Duplicate of a key already seen in this walk: skip it without
            // consuming a target or moving the leaf.
            if !seen_active_read_keys.insert(key.clone()) {
                continue;
            }
            // Every target is already matched; recognized nodes from here on
            // are dropped from the path.
            let Some(target_item) = target.get(target_idx) else {
                break;
            };
            // First divergence: keep the prefix so far and rebuild the rest.
            if key != active_read_item_key(target_item) {
                break;
            }
            push_active_read_node(
                node,
                &mut active_events,
                &mut active_messages,
                &mut active_message_ids,
                &mut active_tool_calls,
            );
            leaf_node_id = Some(node.node_id.clone());
            target_idx += 1;
        } else {
            // Unrecognized nodes are carried along with the kept prefix.
            push_active_read_node(
                node,
                &mut active_events,
                &mut active_messages,
                &mut active_message_ids,
                &mut active_tool_calls,
            );
            leaf_node_id = Some(node.node_id.clone());
        }
    }

    // Ids minted in this call, so that new nodes never collide with each other.
    let mut new_node_ids = HashSet::new();
    let mut new_tail_nodes = Vec::new();

    // Materialize the unmatched targets as a fresh chain hanging off the last
    // kept node.
    for item in target.into_iter().skip(target_idx) {
        let parent_node_id = leaf_node_id.clone();
        let node = match item {
            ActiveReadItem::Message(message) => {
                let node_id = unique_message_node_id_for_replacement(
                    &message.id,
                    existing_node_ids,
                    &new_node_ids,
                );
                SessionNodeRecord {
                    node_id,
                    parent_node_id,
                    timestamp: Utc::now().to_rfc3339(),
                    payload: SessionNodePayload::Event {
                        event: SessionEventRecord::Conversation(ConversationRecord::from_message(
                            message.clone(),
                        )),
                    },
                }
            }
            ActiveReadItem::ToolCall { stable_key, record } => {
                let node_id = unique_plugin_node_id_for_replacement(
                    &stable_key,
                    existing_node_ids,
                    &new_node_ids,
                );
                SessionNodeRecord {
                    node_id,
                    parent_node_id,
                    timestamp: Utc::now().to_rfc3339(),
                    payload: SessionNodePayload::Event {
                        event: SessionEventRecord::Tool(ToolEvent::Invocation {
                            stable_key,
                            record: record.clone(),
                        }),
                    },
                }
            }
        };
        new_node_ids.insert(node.node_id.clone());
        leaf_node_id = Some(node.node_id.clone());
        push_active_read_node(
            &node,
            &mut active_events,
            &mut active_messages,
            &mut active_message_ids,
            &mut active_tool_calls,
        );
        new_tail_nodes.push(node);
    }

    ActiveReadReplacement {
        leaf_node_id,
        new_tail_nodes,
        active_events,
        active_messages,
        active_tool_calls,
    }
}
1149
1150fn push_active_read_node(
1151    node: &SessionNodeRecord,
1152    active_events: &mut Vec<SessionEventRecord>,
1153    active_messages: &mut Vec<Message>,
1154    active_message_ids: &mut HashSet<String>,
1155    active_tool_calls: &mut Vec<ToolCallRecord>,
1156) {
1157    if let Some(event) = node.event() {
1158        active_events.push(event.clone());
1159    }
1160    if let Some(message) = node.message() {
1161        if !message.is_transient() && active_message_ids.insert(message.id.clone()) {
1162            active_messages.push(message);
1163        }
1164        return;
1165    }
1166    if let Some(SessionEventRecord::Tool(ToolEvent::Invocation { record, .. })) = node.event() {
1167        active_tool_calls.push(record.clone());
1168    }
1169}
1170
1171fn recognized_active_read_key(node: &SessionNodeRecord) -> Option<String> {
1172    match &node.payload {
1173        SessionNodePayload::Event { event } => match event {
1174            SessionEventRecord::Conversation(record) => Some(format!("message:{}", record.id)),
1175            SessionEventRecord::Tool(ToolEvent::Invocation { stable_key, record }) => {
1176                Some(tool_call_active_read_key(stable_key, record))
1177            }
1178            _ => None,
1179        },
1180        SessionNodePayload::Plugin { .. } => None,
1181    }
1182}
1183
1184fn active_read_item_key(item: &ActiveReadItem<'_>) -> String {
1185    match item {
1186        ActiveReadItem::Message(message) => format!("message:{}", message.id),
1187        ActiveReadItem::ToolCall { stable_key, record } => {
1188            tool_call_active_read_key(stable_key, record)
1189        }
1190    }
1191}
1192
1193fn tool_call_active_read_key(stable_key: &str, record: &ToolCallRecord) -> String {
1194    let fingerprint = serde_json::to_string(record).unwrap_or_default();
1195    format!("tool_call:{stable_key}:{fingerprint}")
1196}
1197
1198pub(crate) fn tool_call_record_active_read_key(record: &ToolCallRecord) -> String {
1199    let stable_key = stable_tool_call_key(record);
1200    tool_call_active_read_key(&stable_key, record)
1201}
1202
1203fn unique_plugin_node_id(stable_key: &str, existing_ids: &HashSet<String>) -> String {
1204    let base = format!("plugin:{}", stable_key);
1205    if !existing_ids.contains(&base) {
1206        return base;
1207    }
1208    loop {
1209        let candidate = format!("plugin:{}:{}", stable_key, uuid::Uuid::new_v4());
1210        if !existing_ids.contains(&candidate) {
1211            return candidate;
1212        }
1213    }
1214}
1215
1216fn unique_plugin_node_id_for_replacement(
1217    stable_key: &str,
1218    existing_ids: &HashSet<String>,
1219    new_ids: &HashSet<String>,
1220) -> String {
1221    let base = format!("plugin:{}", stable_key);
1222    if !existing_ids.contains(&base) && !new_ids.contains(&base) {
1223        return base;
1224    }
1225    loop {
1226        let candidate = format!("plugin:{}:{}", stable_key, uuid::Uuid::new_v4());
1227        if !existing_ids.contains(&candidate) && !new_ids.contains(&candidate) {
1228            return candidate;
1229        }
1230    }
1231}
1232
/// Pick a node id for a message that is not in `existing_ids`.
///
/// Candidates are tried in order: the message id itself, then
/// `message:{id}`, then `message:{id}:2`, `message:{id}:3`, and so on.
fn unique_message_node_id(message_id: &str, existing_ids: &HashSet<String>) -> String {
    let is_free = |candidate: &str| !existing_ids.contains(candidate);
    if is_free(message_id) {
        return message_id.to_owned();
    }
    let base = format!("message:{message_id}");
    if is_free(base.as_str()) {
        return base;
    }
    (2..)
        .map(|suffix| format!("{base}:{suffix}"))
        .find(|candidate| is_free(candidate.as_str()))
        .unwrap_or_else(|| unreachable!("message node id space exhausted"))
}
1249
/// Like `unique_message_node_id`, but an id also counts as taken when it is
/// in `new_ids` (ids minted earlier in the same replacement).
fn unique_message_node_id_for_replacement(
    message_id: &str,
    existing_ids: &HashSet<String>,
    new_ids: &HashSet<String>,
) -> String {
    let taken = |candidate: &str| existing_ids.contains(candidate) || new_ids.contains(candidate);
    if !taken(message_id) {
        return message_id.to_owned();
    }
    let base = format!("message:{message_id}");
    if !taken(base.as_str()) {
        return base;
    }
    (2..)
        .map(|suffix| format!("{base}:{suffix}"))
        .find(|candidate| !taken(candidate.as_str()))
        .unwrap_or_else(|| unreachable!("message node id space exhausted"))
}
1270
1271fn fresh_node_id(prefix: &str) -> String {
1272    format!("{prefix}{}", uuid::Uuid::new_v4().simple())
1273}
1274
1275fn stable_tool_call_key(record: &ToolCallRecord) -> String {
1276    if let Some(call_id) = record
1277        .call_id
1278        .as_ref()
1279        .filter(|call_id| !call_id.is_empty())
1280    {
1281        return call_id.clone();
1282    }
1283    let raw = serde_json::to_vec(&(record.tool.clone(), &record.args, &record.output))
1284        .unwrap_or_else(|_| b"tool-call".to_vec());
1285    let digest = sha2::Sha256::digest(raw);
1286    format!("anon-{}", &format!("{digest:x}")[..12])
1287}
1288
1289fn first_message_search_text(message: &Message) -> String {
1290    message
1291        .parts
1292        .iter()
1293        .filter_map(|part| match part.kind {
1294            crate::PartKind::ToolCall | crate::PartKind::ToolResult => None,
1295            crate::PartKind::Image => Some("[Image attached]".to_string()),
1296            _ => (!part.content.trim().is_empty()).then(|| part.content.clone()),
1297        })
1298        .collect::<Vec<_>>()
1299        .join("\n\n")
1300        .trim()
1301        .to_string()
1302}