Skip to main content

lash_core/
session_graph.rs

1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3use std::sync::{Arc, OnceLock};
4
5use chrono::Utc;
6
7use crate::session_model::{ConversationRecord, ProtocolEvent, SessionEventRecord};
8use crate::{BaseRenderCache, Message, MessageRole, PromptUsage, TokenUsage};
9
10#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
11pub struct SessionGraphData {
12    #[serde(default)]
13    pub nodes: Vec<SessionNodeRecord>,
14    #[serde(default, skip_serializing_if = "Option::is_none")]
15    pub leaf_node_id: Option<String>,
16}
17
18#[derive(Debug)]
19pub struct SessionGraph {
20    inner: Arc<SessionGraphData>,
21    cache: Arc<OnceLock<SessionGraphCache>>,
22}
23
24impl Default for SessionGraph {
25    fn default() -> Self {
26        Self {
27            inner: Arc::new(SessionGraphData::default()),
28            cache: Arc::new(OnceLock::new()),
29        }
30    }
31}
32
33impl Clone for SessionGraph {
34    fn clone(&self) -> Self {
35        Self {
36            inner: Arc::clone(&self.inner),
37            cache: Arc::clone(&self.cache),
38        }
39    }
40}
41
42impl serde::Serialize for SessionGraph {
43    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44    where
45        S: serde::Serializer,
46    {
47        self.inner.serialize(serializer)
48    }
49}
50
51impl<'de> serde::Deserialize<'de> for SessionGraph {
52    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
53    where
54        D: serde::Deserializer<'de>,
55    {
56        let inner = SessionGraphData::deserialize(deserializer)?;
57        Ok(Self {
58            inner: Arc::new(inner),
59            cache: Arc::new(OnceLock::new()),
60        })
61    }
62}
63
64impl Deref for SessionGraph {
65    type Target = SessionGraphData;
66
67    fn deref(&self) -> &Self::Target {
68        self.inner.as_ref()
69    }
70}
71
72#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
73pub struct SessionNodeRecord {
74    pub node_id: String,
75    #[serde(default, skip_serializing_if = "Option::is_none")]
76    pub parent_node_id: Option<String>,
77    #[serde(default, skip_serializing_if = "Option::is_none")]
78    pub caused_by: Option<crate::CausalRef>,
79    #[serde(default, skip_serializing_if = "Option::is_none")]
80    pub agent_frame_id: Option<crate::AgentFrameId>,
81    pub timestamp: String,
82    #[serde(flatten)]
83    pub payload: SessionNodePayload,
84}
85
86#[derive(Clone, Debug)]
87pub(crate) struct SessionNodeDraft {
88    payload: SessionNodeDraftPayload,
89    caused_by: Option<crate::CausalRef>,
90}
91
92#[derive(Clone, Debug)]
93enum SessionNodeDraftPayload {
94    Message(Message),
95    Plugin {
96        plugin_type: String,
97        body: serde_json::Value,
98    },
99    ProtocolEvent(ProtocolEvent),
100}
101
102impl SessionNodeDraft {
103    pub(crate) fn message(message: Message) -> Self {
104        Self {
105            payload: SessionNodeDraftPayload::Message(message),
106            caused_by: None,
107        }
108    }
109
110    pub(crate) fn plugin(plugin_type: impl Into<String>, body: serde_json::Value) -> Self {
111        Self {
112            payload: SessionNodeDraftPayload::Plugin {
113                plugin_type: plugin_type.into(),
114                body,
115            },
116            caused_by: None,
117        }
118    }
119
120    pub(crate) fn protocol_event(event: ProtocolEvent) -> Self {
121        Self {
122            payload: SessionNodeDraftPayload::ProtocolEvent(event),
123            caused_by: None,
124        }
125    }
126
127    pub(crate) fn with_caused_by(mut self, caused_by: Option<crate::CausalRef>) -> Self {
128        self.caused_by = caused_by;
129        self
130    }
131}
132
133#[derive(Clone, Debug, Default, PartialEq)]
134pub struct SharedJsonValue(pub Arc<serde_json::Value>);
135
136impl SharedJsonValue {
137    pub fn new(value: serde_json::Value) -> Self {
138        Self(Arc::new(value))
139    }
140
141    pub fn to_owned(&self) -> serde_json::Value {
142        self.0.as_ref().clone()
143    }
144}
145
146impl AsRef<serde_json::Value> for SharedJsonValue {
147    fn as_ref(&self) -> &serde_json::Value {
148        self.0.as_ref()
149    }
150}
151
152impl serde::Serialize for SharedJsonValue {
153    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
154    where
155        S: serde::Serializer,
156    {
157        self.0.serialize(serializer)
158    }
159}
160
161impl<'de> serde::Deserialize<'de> for SharedJsonValue {
162    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
163    where
164        D: serde::Deserializer<'de>,
165    {
166        let value = serde_json::Value::deserialize(deserializer)?;
167        Ok(Self::new(value))
168    }
169}
170
171#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
172#[serde(tag = "kind", rename_all = "snake_case")]
173#[allow(clippy::large_enum_variant)]
174pub enum SessionNodePayload {
175    Event {
176        event: SessionEventRecord,
177    },
178    Plugin {
179        plugin_type: String,
180        body: SharedJsonValue,
181    },
182}
183
/// Provider id and model spec stored alongside a session.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct PersistedSessionConfig {
    pub provider_id: String,
    pub model: crate::ModelSpec,
}

/// Turn counter and usage figures stored alongside a session.
///
/// Every field except `turn_index` falls back to its default when it is
/// missing from the serialized form.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct PersistedTurnState {
    pub turn_index: usize,
    #[serde(default)]
    pub token_usage: TokenUsage,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_prompt_usage: Option<PromptUsage>,
    #[serde(default)]
    pub protocol_turn_options: crate::ProtocolTurnOptions,
}

/// One message node of the tree produced by `SessionGraph::message_tree`.
#[derive(Clone, Debug)]
pub struct SessionMessageTreeNode {
    /// Id of the graph node that holds the message.
    pub node_id: String,
    /// Node id of the closest ancestor that also holds a message, if any.
    pub parent_message_node_id: Option<String>,
    pub message: Message,
    pub timestamp: String,
    /// Child message nodes; `message_tree` creates this empty and leaves
    /// filling it to `build_tree`.
    pub children: Vec<SessionMessageTreeNode>,
    /// Set by `message_tree` to mark nodes it considers part of the active path.
    pub active: bool,
}

/// Result of `build_active_read_replacement`: the leaf to switch to and the
/// nodes to append so the active path reflects a given message list.
///
/// NOTE(review): `active_events` / `active_messages` are not read by
/// `SessionGraph::replace_active_read_state_scoped` in this file.
#[derive(Clone, Debug)]
pub(crate) struct ActiveReadReplacement {
    pub(crate) leaf_node_id: Option<String>,
    pub(crate) new_tail_nodes: Vec<SessionNodeRecord>,
    pub(crate) active_events: Vec<SessionEventRecord>,
    pub(crate) active_messages: Vec<Message>,
}

/// Snapshot of a path's events and messages plus the prompt-render cache for
/// those messages. All fields are `Arc`s, so cloning a read model is cheap.
#[derive(Clone, Debug)]
pub(crate) struct SessionReadModel {
    pub(crate) active_events: Arc<Vec<SessionEventRecord>>,
    pub(crate) messages: Arc<Vec<Message>>,
    pub(crate) prompt_render_cache: Arc<BaseRenderCache>,
}
225
226#[derive(Clone, Debug)]
227pub(crate) struct SessionGraphAppendBuilder {
228    existing_ids: HashSet<String>,
229    leaf_node_id: Option<String>,
230    agent_frame_id: Option<crate::AgentFrameId>,
231}
232
233impl SessionGraphAppendBuilder {
234    pub(crate) fn with_agent_frame_id(
235        mut self,
236        agent_frame_id: impl Into<crate::AgentFrameId>,
237    ) -> Self {
238        self.agent_frame_id = Some(agent_frame_id.into());
239        self
240    }
241
242    pub(crate) fn agent_frame_id(&self) -> Option<&str> {
243        self.agent_frame_id.as_deref()
244    }
245
246    pub(crate) fn leaf_node_id(&self) -> Option<&String> {
247        self.leaf_node_id.as_ref()
248    }
249
250    pub(crate) fn set_leaf_node_id(&mut self, leaf_node_id: Option<String>) {
251        self.leaf_node_id = leaf_node_id;
252    }
253
254    pub(crate) fn register_existing_node_ids<'a>(
255        &mut self,
256        node_ids: impl IntoIterator<Item = &'a str>,
257    ) {
258        self.existing_ids
259            .extend(node_ids.into_iter().map(ToOwned::to_owned));
260    }
261
262    pub(crate) fn existing_node_ids(&self) -> &HashSet<String> {
263        &self.existing_ids
264    }
265
266    pub(crate) fn append_messages<I>(&mut self, messages: I) -> Vec<SessionNodeRecord>
267    where
268        I: IntoIterator<Item = Message>,
269    {
270        self.append_drafts(messages.into_iter().map(SessionNodeDraft::message))
271    }
272
273    pub(crate) fn append_protocol_events<I>(&mut self, events: I) -> Vec<SessionNodeRecord>
274    where
275        I: IntoIterator<Item = ProtocolEvent>,
276    {
277        self.append_drafts(events.into_iter().map(SessionNodeDraft::protocol_event))
278    }
279
280    pub(crate) fn append_drafts<I>(&mut self, drafts: I) -> Vec<SessionNodeRecord>
281    where
282        I: IntoIterator<Item = SessionNodeDraft>,
283    {
284        let mut nodes = Vec::new();
285        for draft in drafts {
286            let parent_node_id = self.leaf_node_id.clone();
287            let (node_id, caused_by, payload) = match draft.payload {
288                SessionNodeDraftPayload::Message(mut message) => {
289                    if message.id.is_empty() {
290                        message.id = fresh_node_id("m");
291                    }
292                    let node_id = unique_message_node_id(&message.id, &self.existing_ids);
293                    let caused_by = draft
294                        .caused_by
295                        .or_else(|| causal_ref_from_message_origin(&message.origin));
296                    (
297                        node_id,
298                        caused_by,
299                        SessionNodePayload::Event {
300                            event: SessionEventRecord::Conversation(
301                                ConversationRecord::from_message(message),
302                            ),
303                        },
304                    )
305                }
306                SessionNodeDraftPayload::Plugin { plugin_type, body } => {
307                    let node_id = fresh_semantic_node_id("plugin", &self.existing_ids);
308                    (
309                        node_id,
310                        draft.caused_by,
311                        SessionNodePayload::Plugin {
312                            plugin_type,
313                            body: SharedJsonValue::new(body),
314                        },
315                    )
316                }
317                SessionNodeDraftPayload::ProtocolEvent(event) => {
318                    let node_id = fresh_semantic_node_id("protocol", &self.existing_ids);
319                    (
320                        node_id,
321                        draft.caused_by,
322                        SessionNodePayload::Event {
323                            event: SessionEventRecord::Protocol(event),
324                        },
325                    )
326                }
327            };
328            self.existing_ids.insert(node_id.clone());
329            self.leaf_node_id = Some(node_id.clone());
330            nodes.push(SessionNodeRecord {
331                node_id,
332                parent_node_id,
333                caused_by,
334                agent_frame_id: self.agent_frame_id.clone(),
335                timestamp: Utc::now().to_rfc3339(),
336                payload,
337            });
338        }
339        nodes
340    }
341}
342
343#[derive(Debug, Clone)]
344struct SessionGraphCache {
345    by_id: HashMap<String, usize>,
346    active_path_indices: Vec<usize>,
347    active_events: Arc<Vec<SessionEventRecord>>,
348    active_messages: Arc<Vec<Message>>,
349    /// Index from `Message::id` to its position in `active_messages`,
350    /// kept in sync with the vec so dedup on append is O(1) instead of an
351    /// O(n) linear scan (which made long sessions quadratic in message
352    /// count).
353    active_message_ids: HashMap<String, usize>,
354    /// Memoized render of `active_messages`. Shared with every
355    /// `MessageSequence` built off this read model so the chat projector's
356    /// per-iteration `render_prompt` walk only happens once per turn.
357    /// Replaced (not invalidated in-place) whenever `active_messages`
358    /// changes — the `Arc` identity tracks the cache's validity.
359    prompt_render_cache: Arc<BaseRenderCache>,
360}
361
362impl SessionGraphCache {
363    fn build(graph: &SessionGraph) -> Self {
364        let by_id = graph
365            .nodes
366            .iter()
367            .enumerate()
368            .map(|(idx, node)| (node.node_id.clone(), idx))
369            .collect::<HashMap<_, _>>();
370        let mut active_path_indices = Vec::new();
371        let mut current = graph
372            .leaf_node_id
373            .as_ref()
374            .and_then(|node_id| by_id.get(node_id).copied());
375        while let Some(idx) = current {
376            active_path_indices.push(idx);
377            current = graph.nodes[idx]
378                .parent_node_id
379                .as_ref()
380                .and_then(|node_id| by_id.get(node_id).copied());
381        }
382        active_path_indices.reverse();
383
384        let mut cache = Self {
385            by_id,
386            active_path_indices,
387            active_events: Arc::new(Vec::new()),
388            active_messages: Arc::new(Vec::new()),
389            active_message_ids: HashMap::new(),
390            prompt_render_cache: Arc::new(BaseRenderCache::new()),
391        };
392        cache.rebuild_read_model(graph);
393        cache
394    }
395
396    fn rebuild_read_model(&mut self, graph: &SessionGraph) {
397        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
398        let mut active_message_ids: HashMap<String, usize> =
399            HashMap::with_capacity(self.active_path_indices.len());
400        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
401        for idx in &self.active_path_indices {
402            let node = &graph.nodes[*idx];
403            if let Some(event) = node.event() {
404                active_events.push(event.clone());
405            }
406            if let Some(message) = node.message() {
407                if !message.is_transient() && !active_message_ids.contains_key(&message.id) {
408                    active_message_ids.insert(message.id.clone(), active_messages.len());
409                    active_messages.push(message);
410                }
411                continue;
412            }
413        }
414        self.active_messages = Arc::new(active_messages);
415        self.active_message_ids = active_message_ids;
416        self.active_events = Arc::new(active_events);
417        self.prompt_render_cache = Arc::new(BaseRenderCache::new());
418    }
419
420    fn read_model_for_agent_frame(
421        &self,
422        graph: &SessionGraph,
423        frame_id: &str,
424        include_unscoped: bool,
425    ) -> SessionReadModel {
426        let mut active_messages = Vec::with_capacity(self.active_path_indices.len());
427        let mut active_message_ids = HashSet::new();
428        let mut active_events = Vec::with_capacity(self.active_path_indices.len());
429        for idx in &self.active_path_indices {
430            let node = &graph.nodes[*idx];
431            if !node_belongs_to_agent_frame(node, frame_id, include_unscoped) {
432                continue;
433            }
434            if let Some(event) = node.event() {
435                active_events.push(event.clone());
436            }
437            if let Some(message) = node.message() {
438                if !message.is_transient() && active_message_ids.insert(message.id.clone()) {
439                    active_messages.push(message);
440                }
441                continue;
442            }
443        }
444        SessionReadModel {
445            active_events: Arc::new(active_events),
446            messages: Arc::new(active_messages),
447            prompt_render_cache: Arc::new(BaseRenderCache::new()),
448        }
449    }
450
451    fn append_node(
452        &mut self,
453        node_index: usize,
454        node: &SessionNodeRecord,
455        previous_leaf_node_id: Option<&str>,
456    ) {
457        self.by_id.insert(node.node_id.clone(), node_index);
458        let parent_matches_leaf = node.parent_node_id.as_deref() == previous_leaf_node_id;
459        if !parent_matches_leaf {
460            return;
461        }
462        self.active_path_indices.push(node_index);
463        if let Some(event) = node.event() {
464            Arc::make_mut(&mut self.active_events).push(event.clone());
465        }
466        if let Some(message) = node.message()
467            && !message.is_transient()
468            && !self.active_message_ids.contains_key(&message.id)
469        {
470            let messages = Arc::make_mut(&mut self.active_messages);
471            self.active_message_ids
472                .insert(message.id.clone(), messages.len());
473            messages.push(message);
474            self.prompt_render_cache = Arc::new(BaseRenderCache::new());
475        }
476    }
477
478    fn reserve_append_capacity(&mut self, additional_nodes: usize, additional_messages: usize) {
479        self.by_id.reserve(additional_nodes);
480        self.active_path_indices.reserve(additional_nodes);
481        if additional_messages > 0 {
482            Arc::make_mut(&mut self.active_messages).reserve(additional_messages);
483        }
484    }
485}
486
487impl SessionNodeRecord {
488    pub fn event(&self) -> Option<&SessionEventRecord> {
489        match &self.payload {
490            SessionNodePayload::Event { event } => Some(event),
491            SessionNodePayload::Plugin { .. } => None,
492        }
493    }
494
495    pub fn message(&self) -> Option<Message> {
496        match self.event()? {
497            SessionEventRecord::Conversation(record) => Some(record.to_message()),
498            _ => None,
499        }
500    }
501
502    pub fn plugin(&self) -> Option<(&str, &serde_json::Value)> {
503        match &self.payload {
504            SessionNodePayload::Event { .. } => None,
505            SessionNodePayload::Plugin { plugin_type, body } => {
506                Some((plugin_type.as_str(), body.as_ref()))
507            }
508        }
509    }
510
511    pub fn plugin_body<T>(&self) -> Option<T>
512    where
513        T: for<'de> serde::Deserialize<'de>,
514    {
515        let (_, body) = self.plugin()?;
516        T::deserialize(body).ok()
517    }
518}
519
520impl SessionGraph {
521    pub fn append_active_read_delta(&mut self, messages: &[Message]) {
522        self.append_active_read_delta_scoped(None, messages);
523    }
524
525    pub fn append_active_read_delta_for_agent_frame(
526        &mut self,
527        agent_frame_id: &str,
528        messages: &[Message],
529    ) {
530        self.append_active_read_delta_scoped(Some(agent_frame_id), messages);
531    }
532
533    fn append_active_read_delta_scoped(
534        &mut self,
535        agent_frame_id: Option<&str>,
536        messages: &[Message],
537    ) {
538        let appendable_messages = {
539            let read_model = agent_frame_id
540                .map(|frame_id| self.read_model_for_agent_frame(frame_id, false))
541                .unwrap_or_else(|| self.read_model());
542            let mut seen_message_ids = read_model
543                .messages
544                .iter()
545                .map(|message| message.id.as_str())
546                .collect::<HashSet<_>>();
547            messages
548                .iter()
549                .filter(|message| {
550                    !message.is_transient() && seen_message_ids.insert(message.id.as_str())
551                })
552                .cloned()
553                .collect::<Vec<_>>()
554        };
555
556        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len());
557        self.append_message_batch_scoped(agent_frame_id, appendable_messages);
558    }
559
560    pub(crate) fn append_active_conversation_messages_for_agent_frame(
561        &mut self,
562        agent_frame_id: &str,
563        messages: &[Message],
564    ) {
565        self.append_active_conversation_messages_scoped(Some(agent_frame_id), messages);
566    }
567
568    fn append_active_conversation_messages_scoped(
569        &mut self,
570        agent_frame_id: Option<&str>,
571        messages: &[Message],
572    ) {
573        let appendable_messages = messages
574            .iter()
575            .filter(|message| !message.is_transient())
576            .cloned()
577            .collect::<Vec<_>>();
578        self.reserve_append_capacity(appendable_messages.len(), appendable_messages.len());
579        self.append_message_batch_scoped(agent_frame_id, appendable_messages);
580    }
581
582    pub fn from_nodes(nodes: Vec<SessionNodeRecord>, leaf_node_id: Option<String>) -> Self {
583        Self {
584            inner: Arc::new(SessionGraphData {
585                nodes,
586                leaf_node_id,
587            }),
588            cache: Arc::new(OnceLock::new()),
589        }
590    }
591
592    pub(crate) fn append_builder(&self) -> SessionGraphAppendBuilder {
593        SessionGraphAppendBuilder {
594            existing_ids: self.nodes.iter().map(|node| node.node_id.clone()).collect(),
595            leaf_node_id: self.leaf_node_id.clone(),
596            agent_frame_id: None,
597        }
598    }
599
600    fn invalidate_cache(&mut self) {
601        self.cache = Arc::new(OnceLock::new());
602    }
603
604    fn data_mut(&mut self) -> &mut SessionGraphData {
605        self.invalidate_cache();
606        Arc::make_mut(&mut self.inner)
607    }
608
609    fn reserve_append_capacity(&mut self, additional_nodes: usize, additional_messages: usize) {
610        if additional_nodes == 0 {
611            return;
612        }
613        self.detach_initialized_cache_for_append();
614        Arc::make_mut(&mut self.inner)
615            .nodes
616            .reserve(additional_nodes);
617        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
618            && let Some(cache) = cache_lock.get_mut()
619        {
620            cache.reserve_append_capacity(additional_nodes, additional_messages);
621        }
622    }
623
624    fn detach_initialized_cache_for_append(&mut self) {
625        if Arc::get_mut(&mut self.cache).is_some() {
626            return;
627        }
628        let Some(cache) = self.cache.get().cloned() else {
629            self.invalidate_cache();
630            return;
631        };
632        let lock = OnceLock::new();
633        let _ = lock.set(cache);
634        self.cache = Arc::new(lock);
635    }
636
637    fn cache(&self) -> &SessionGraphCache {
638        self.cache.get_or_init(|| SessionGraphCache::build(self))
639    }
640
641    fn append_message_batch_scoped(
642        &mut self,
643        agent_frame_id: Option<&str>,
644        messages: Vec<Message>,
645    ) {
646        if messages.is_empty() {
647            return;
648        }
649        self.append_node_drafts_scoped(
650            agent_frame_id,
651            messages.into_iter().map(SessionNodeDraft::message),
652        );
653    }
654
655    fn append_prebuilt_nodes(&mut self, nodes: Vec<SessionNodeRecord>) {
656        if nodes.is_empty() {
657            return;
658        }
659
660        self.detach_initialized_cache_for_append();
661        if let Some(cache_lock) = Arc::get_mut(&mut self.cache)
662            && let Some(cache) = cache_lock.get_mut()
663        {
664            let data = Arc::make_mut(&mut self.inner);
665            for node in nodes {
666                let previous_leaf = data.leaf_node_id.clone();
667                let node_id = node.node_id.clone();
668                data.nodes.push(node);
669                cache.append_node(
670                    data.nodes.len() - 1,
671                    data.nodes.last().expect("just appended graph node"),
672                    previous_leaf.as_deref(),
673                );
674                data.leaf_node_id = Some(node_id);
675            }
676            return;
677        }
678
679        let data = self.data_mut();
680        for node in nodes {
681            data.leaf_node_id = Some(node.node_id.clone());
682            data.nodes.push(node);
683        }
684    }
685
686    pub fn append_message(&mut self, message: Message) -> String {
687        self.append_node_draft(SessionNodeDraft::message(message))
688    }
689
690    pub fn append_plugin(
691        &mut self,
692        plugin_type: impl Into<String>,
693        body: serde_json::Value,
694    ) -> String {
695        self.append_node_draft(SessionNodeDraft::plugin(plugin_type, body))
696    }
697
698    pub fn active_path_nodes(&self) -> Vec<&SessionNodeRecord> {
699        self.cache()
700            .active_path_indices
701            .iter()
702            .map(|idx| &self.nodes[*idx])
703            .collect()
704    }
705
706    pub(crate) fn read_model(&self) -> SessionReadModel {
707        let cache = self.cache();
708        SessionReadModel {
709            active_events: Arc::clone(&cache.active_events),
710            messages: Arc::clone(&cache.active_messages),
711            prompt_render_cache: Arc::clone(&cache.prompt_render_cache),
712        }
713    }
714
715    pub(crate) fn read_model_for_agent_frame(
716        &self,
717        frame_id: &str,
718        include_unscoped: bool,
719    ) -> SessionReadModel {
720        if frame_id.is_empty() {
721            return self.read_model();
722        }
723        self.cache()
724            .read_model_for_agent_frame(self, frame_id, include_unscoped)
725    }
726
727    pub fn append_protocol_event(&mut self, event: ProtocolEvent) -> String {
728        self.append_node_draft(SessionNodeDraft::protocol_event(event))
729    }
730
731    pub(crate) fn append_node_draft(&mut self, draft: SessionNodeDraft) -> String {
732        self.append_node_drafts([draft])
733            .into_iter()
734            .next()
735            .expect("single draft append must create one node")
736    }
737
738    pub(crate) fn append_node_drafts<I>(&mut self, drafts: I) -> Vec<String>
739    where
740        I: IntoIterator<Item = SessionNodeDraft>,
741    {
742        self.append_node_drafts_scoped(None, drafts)
743    }
744
745    pub(crate) fn append_node_drafts_for_agent_frame<I>(
746        &mut self,
747        agent_frame_id: &str,
748        drafts: I,
749    ) -> Vec<String>
750    where
751        I: IntoIterator<Item = SessionNodeDraft>,
752    {
753        self.append_node_drafts_scoped(Some(agent_frame_id), drafts)
754    }
755
756    fn append_node_drafts_scoped<I>(
757        &mut self,
758        agent_frame_id: Option<&str>,
759        drafts: I,
760    ) -> Vec<String>
761    where
762        I: IntoIterator<Item = SessionNodeDraft>,
763    {
764        let mut builder = self.append_builder();
765        if let Some(agent_frame_id) = agent_frame_id {
766            builder = builder.with_agent_frame_id(agent_frame_id.to_string());
767        }
768        let nodes = builder.append_drafts(drafts);
769        let node_ids = nodes
770            .iter()
771            .map(|node| node.node_id.clone())
772            .collect::<Vec<_>>();
773        self.append_prebuilt_nodes(nodes);
774        node_ids
775    }
776
777    pub fn user_message_count(&self) -> usize {
778        self.nodes
779            .iter()
780            .filter_map(SessionNodeRecord::message)
781            .filter(|message| matches!(message.role, MessageRole::User))
782            .count()
783    }
784
785    pub fn first_user_message(&self) -> String {
786        self.nodes
787            .iter()
788            .filter_map(SessionNodeRecord::message)
789            .find(|message| matches!(message.role, MessageRole::User))
790            .map(|message| first_message_search_text(&message))
791            .unwrap_or_default()
792    }
793
794    pub fn branch_to(&mut self, node_id: Option<String>) {
795        self.data_mut().leaf_node_id = node_id;
796    }
797
798    pub fn set_leaf_node_id(&mut self, node_id: Option<String>) {
799        self.data_mut().leaf_node_id = node_id;
800    }
801
802    pub fn push_node_record(&mut self, node: SessionNodeRecord) {
803        self.data_mut().nodes.push(node);
804    }
805
806    pub fn extend_node_records<I>(&mut self, nodes: I)
807    where
808        I: IntoIterator<Item = SessionNodeRecord>,
809    {
810        self.data_mut().nodes.extend(nodes);
811    }
812
813    /// Append nodes that extend the current active path, advancing the
814    /// leaf to the last node and updating the cache incrementally
815    /// instead of invalidating it. Use this when the appended nodes are
816    /// genuinely new descendants of the current leaf — e.g. the
817    /// turn-driver merging turn-local graph editor deltas into the base graph.
818    /// Use `extend_node_records` + `set_leaf_node_id` for store-side
819    /// replay paths that don't follow the active-path append shape.
820    pub fn extend_active_path(&mut self, nodes: Vec<SessionNodeRecord>) {
821        self.append_prebuilt_nodes(nodes);
822    }
823
824    pub fn active_path_contains(&self, node_id: &str) -> bool {
825        self.active_path_nodes()
826            .into_iter()
827            .any(|node| node.node_id == node_id)
828    }
829
830    /// If `leaf_node_id` points to a node that no longer exists in
831    /// `self.nodes` (e.g. after compaction rewrote the graph, or a
832    /// stored session referenced a node that was later purged), fall
833    /// back to the most recent message node. Returns `true` if the
834    /// leaf was repaired. Call this on load paths where an orphan
835    /// leaf would project to an empty transcript and silently drop
836    /// the user's history.
837    pub fn heal_orphaned_leaf(&mut self) -> bool {
838        if let Some(leaf) = self.leaf_node_id.as_ref()
839            && self.find_node(leaf).is_none()
840        {
841            let fallback = self
842                .nodes
843                .iter()
844                .rev()
845                .find(|node| node.message().is_some())
846                .map(|node| node.node_id.clone());
847            self.data_mut().leaf_node_id = fallback;
848            return true;
849        }
850        false
851    }
852
853    pub fn fork_current_path(&self) -> SessionGraph {
854        let path = self.active_path_nodes();
855        SessionGraph::from_nodes(
856            path.into_iter().cloned().collect(),
857            self.leaf_node_id.clone(),
858        )
859    }
860
861    pub fn find_node(&self, node_id: &str) -> Option<&SessionNodeRecord> {
862        self.cache()
863            .by_id
864            .get(node_id)
865            .and_then(|idx| self.nodes.get(*idx))
866    }
867
868    pub fn node_index(&self, node_id: &str) -> Option<usize> {
869        self.cache().by_id.get(node_id).copied()
870    }
871
872    pub fn replace_active_read_state(&mut self, messages: &[Message]) {
873        self.replace_active_read_state_scoped(None, messages);
874    }
875
876    pub fn replace_active_read_state_for_agent_frame(
877        &mut self,
878        agent_frame_id: &str,
879        messages: &[Message],
880    ) {
881        self.replace_active_read_state_scoped(Some(agent_frame_id), messages);
882    }
883
884    fn replace_active_read_state_scoped(
885        &mut self,
886        agent_frame_id: Option<&str>,
887        messages: &[Message],
888    ) {
889        let current_nodes = self.active_path_nodes();
890        let existing_ids = self
891            .nodes
892            .iter()
893            .map(|node| node.node_id.clone())
894            .collect::<HashSet<_>>();
895        let replacement =
896            build_active_read_replacement(current_nodes, &existing_ids, agent_frame_id, messages);
897        let data = self.data_mut();
898        data.leaf_node_id = replacement.leaf_node_id;
899        data.nodes.extend(replacement.new_tail_nodes);
900    }
901
902    pub fn from_active_read_state(messages: &[Message]) -> Self {
903        let mut graph = Self::default();
904        graph.replace_active_read_state(messages);
905        graph
906    }
907
908    pub fn message_tree(&self) -> Vec<SessionMessageTreeNode> {
909        let active_message_ids = self
910            .active_path_nodes()
911            .into_iter()
912            .filter_map(|node| node.message().map(|message| message.id.clone()))
913            .collect::<HashSet<_>>();
914
915        let message_nodes = self
916            .nodes
917            .iter()
918            .filter_map(|node| {
919                let message = node.message()?.clone();
920                let parent_message_node_id =
921                    self.nearest_message_ancestor(node.parent_node_id.as_deref());
922                Some(SessionMessageTreeNode {
923                    node_id: node.node_id.clone(),
924                    parent_message_node_id,
925                    message,
926                    timestamp: node.timestamp.clone(),
927                    children: Vec::new(),
928                    active: active_message_ids.contains(&node.node_id),
929                })
930            })
931            .collect::<Vec<_>>();
932
933        build_tree(message_nodes)
934    }
935
936    fn nearest_message_ancestor(&self, node_id: Option<&str>) -> Option<String> {
937        let by_id = self
938            .nodes
939            .iter()
940            .map(|node| (node.node_id.as_str(), node))
941            .collect::<HashMap<_, _>>();
942        let mut current = node_id.and_then(|id| by_id.get(id).copied());
943        while let Some(node) = current {
944            if node.message().is_some() {
945                return Some(node.node_id.clone());
946            }
947            current = node
948                .parent_node_id
949                .as_deref()
950                .and_then(|parent| by_id.get(parent).copied());
951        }
952        None
953    }
954}
955
956fn build_tree(mut nodes: Vec<SessionMessageTreeNode>) -> Vec<SessionMessageTreeNode> {
957    let mut children_by_parent = HashMap::<Option<String>, Vec<SessionMessageTreeNode>>::new();
958    for node in nodes.drain(..) {
959        children_by_parent
960            .entry(node.parent_message_node_id.clone())
961            .or_default()
962            .push(node);
963    }
964    let mut roots = build_tree_children(None, &mut children_by_parent);
965    sort_tree(&mut roots);
966    roots
967}
968
969fn sort_tree(nodes: &mut [SessionMessageTreeNode]) {
970    nodes.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
971    for node in nodes {
972        sort_tree(&mut node.children);
973    }
974}
975
976fn build_tree_children(
977    parent_id: Option<String>,
978    children_by_parent: &mut HashMap<Option<String>, Vec<SessionMessageTreeNode>>,
979) -> Vec<SessionMessageTreeNode> {
980    let mut children = children_by_parent.remove(&parent_id).unwrap_or_default();
981    for child in &mut children {
982        child.children = build_tree_children(Some(child.node_id.clone()), children_by_parent);
983    }
984    children
985}
986
987fn node_belongs_to_agent_frame(
988    node: &SessionNodeRecord,
989    frame_id: &str,
990    include_unscoped: bool,
991) -> bool {
992    match node.agent_frame_id.as_deref() {
993        Some(node_frame_id) => node_frame_id == frame_id,
994        None => include_unscoped,
995    }
996}
997
/// Reconciles an existing active path with a desired list of messages.
///
/// Transient messages are ignored on both sides. `current_nodes` is walked in order, and nodes
/// are kept while the path's conversation nodes match the non-transient `messages` one-for-one
/// by message id:
///
/// - nodes for which `recognized_active_read_key` yields no key (protocol events, plugin
///   payloads) are kept whenever the walk reaches them;
/// - a conversation node whose message id already appeared earlier in the walk is skipped without
///   consuming a target message;
/// - the walk stops at the first conversation node that does not match the next target message,
///   or at one that arrives after every target message has been matched. Nodes after that point
///   are not kept.
///
/// Each target message the walk did not match becomes a new conversation node, chained after
/// the last kept (or previously created) node. Each such node:
///
/// - carries `agent_frame_id` (if any);
/// - gets a `caused_by` computed by `causal_ref_from_message_origin`;
/// - gets an id that collides neither with `existing_node_ids` nor with ids created earlier in
///   this call.
///
/// The result contains:
///
/// - the resulting leaf id (`None` when nothing was kept or created);
/// - the new nodes the caller must append;
/// - the events along the resulting path;
/// - the messages along the resulting path, de-duplicated by message id.
pub(crate) fn build_active_read_replacement<'a>(
    current_nodes: impl IntoIterator<Item = &'a SessionNodeRecord>,
    existing_node_ids: &HashSet<String>,
    agent_frame_id: Option<&str>,
    messages: &[Message],
) -> ActiveReadReplacement {
    // Desired end state: the caller's messages without transient ones.
    let target = messages
        .iter()
        .filter(|message| !message.is_transient())
        .collect::<Vec<_>>();

    let mut active_events = Vec::new();
    let mut active_messages = Vec::new();
    let mut active_message_ids = HashSet::new();
    let mut seen_active_read_keys = HashSet::new();
    // Index of the next target message still waiting for a matching path node.
    let mut target_idx = 0usize;
    let mut leaf_node_id = None;
    for node in current_nodes {
        // Transient messages on the existing path are neither kept nor matched.
        if node
            .message()
            .map(|message| message.is_transient())
            .unwrap_or(false)
        {
            continue;
        }
        if let Some(key) = recognized_active_read_key(node) {
            // A repeated message id on the path is skipped and consumes no target message.
            if !seen_active_read_keys.insert(key.clone()) {
                continue;
            }
            // All target messages are matched; the remainder of the path is dropped.
            let Some(target_item) = target.get(target_idx) else {
                break;
            };
            // Same key shape as `recognized_active_read_key`. A mismatch is the divergence point:
            // the remaining targets are appended after the current leaf.
            if key != format!("message:{}", target_item.id) {
                break;
            }
            push_active_read_node(
                node,
                &mut active_events,
                &mut active_messages,
                &mut active_message_ids,
            );
            leaf_node_id = Some(node.node_id.clone());
            target_idx += 1;
        } else {
            // Non-conversation nodes reached before divergence are kept and may become the leaf.
            push_active_read_node(
                node,
                &mut active_events,
                &mut active_messages,
                &mut active_message_ids,
            );
            leaf_node_id = Some(node.node_id.clone());
        }
    }

    // Ids minted in this call; they are not yet in `existing_node_ids`.
    let mut new_node_ids = HashSet::new();
    let mut new_tail_nodes = Vec::new();

    // Append one node per target message the existing path did not cover, each child of the last.
    for message in target.into_iter().skip(target_idx) {
        let parent_node_id = leaf_node_id.clone();
        let node_id =
            unique_message_node_id_for_replacement(&message.id, existing_node_ids, &new_node_ids);
        let node = SessionNodeRecord {
            node_id,
            parent_node_id,
            caused_by: causal_ref_from_message_origin(&message.origin),
            agent_frame_id: agent_frame_id.map(ToOwned::to_owned),
            timestamp: Utc::now().to_rfc3339(),
            payload: SessionNodePayload::Event {
                event: SessionEventRecord::Conversation(ConversationRecord::from_message(
                    message.clone(),
                )),
            },
        };
        new_node_ids.insert(node.node_id.clone());
        leaf_node_id = Some(node.node_id.clone());
        push_active_read_node(
            &node,
            &mut active_events,
            &mut active_messages,
            &mut active_message_ids,
        );
        new_tail_nodes.push(node);
    }

    ActiveReadReplacement {
        leaf_node_id,
        new_tail_nodes,
        active_events,
        active_messages,
    }
}
1089
1090fn push_active_read_node(
1091    node: &SessionNodeRecord,
1092    active_events: &mut Vec<SessionEventRecord>,
1093    active_messages: &mut Vec<Message>,
1094    active_message_ids: &mut HashSet<String>,
1095) {
1096    if let Some(event) = node.event() {
1097        active_events.push(event.clone());
1098    }
1099    if let Some(message) = node.message()
1100        && !message.is_transient()
1101        && active_message_ids.insert(message.id.clone())
1102    {
1103        active_messages.push(message);
1104    }
1105}
1106
1107fn recognized_active_read_key(node: &SessionNodeRecord) -> Option<String> {
1108    match &node.payload {
1109        SessionNodePayload::Event { event } => match event {
1110            SessionEventRecord::Conversation(record) => Some(format!("message:{}", record.id)),
1111            _ => None,
1112        },
1113        SessionNodePayload::Plugin { .. } => None,
1114    }
1115}
1116
1117fn causal_ref_from_message_origin(
1118    origin: &Option<crate::MessageOrigin>,
1119) -> Option<crate::CausalRef> {
1120    let Some(crate::MessageOrigin::Process {
1121        process_id,
1122        sequence,
1123        ..
1124    }) = origin
1125    else {
1126        return None;
1127    };
1128    Some(crate::CausalRef::ProcessEvent {
1129        process_id: process_id.clone(),
1130        sequence: *sequence,
1131    })
1132}
1133
1134fn fresh_semantic_node_id(prefix: &str, existing_ids: &HashSet<String>) -> String {
1135    loop {
1136        let candidate = format!("{prefix}:{}", uuid::Uuid::new_v4().simple());
1137        if !existing_ids.contains(&candidate) {
1138            return candidate;
1139        }
1140    }
1141}
1142
/// Returns a node id for a message node that avoids every id in `existing_ids`.
///
/// Candidates are tried in this order:
/// 1. the bare `message_id`;
/// 2. `message:{message_id}`;
/// 3. `message:{message_id}:2`, `message:{message_id}:3`, and so on.
fn unique_message_node_id(message_id: &str, existing_ids: &HashSet<String>) -> String {
    first_free_message_node_id(message_id, |candidate| existing_ids.contains(candidate))
}

/// Like [`unique_message_node_id`], but also avoids the ids in `new_ids`.
///
/// `new_ids` holds ids allocated earlier in the same batch that are not yet part of
/// `existing_ids`.
fn unique_message_node_id_for_replacement(
    message_id: &str,
    existing_ids: &HashSet<String>,
    new_ids: &HashSet<String>,
) -> String {
    first_free_message_node_id(message_id, |candidate| {
        existing_ids.contains(candidate) || new_ids.contains(candidate)
    })
}

/// Candidate sequence shared by the two functions above; `is_taken` reports whether an id is
/// already in use.
///
/// Keeping a single implementation guarantees both allocators produce the same id scheme.
fn first_free_message_node_id(message_id: &str, is_taken: impl Fn(&str) -> bool) -> String {
    if !is_taken(message_id) {
        return message_id.to_string();
    }
    let base = format!("message:{message_id}");
    if !is_taken(base.as_str()) {
        return base;
    }
    // The taken set is finite, so the suffix search always finds a free id.
    (2u64..)
        .map(|suffix| format!("{base}:{suffix}"))
        .find(|candidate| !is_taken(candidate.as_str()))
        .expect("message node id space exhausted")
}
1180
1181fn fresh_node_id(prefix: &str) -> String {
1182    format!("{prefix}{}", uuid::Uuid::new_v4().simple())
1183}
1184
1185fn first_message_search_text(message: &Message) -> String {
1186    message
1187        .parts
1188        .iter()
1189        .filter_map(|part| match part.kind {
1190            crate::PartKind::ToolCall | crate::PartKind::ToolResult => None,
1191            crate::PartKind::Image => Some("[Image attached]".to_string()),
1192            _ => (!part.content.trim().is_empty()).then(|| part.content.clone()),
1193        })
1194        .collect::<Vec<_>>()
1195        .join("\n\n")
1196        .trim()
1197        .to_string()
1198}
1199
// Unit tests for node-id allocation and active-read-state replacement.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Part, PartKind, PruneState, shared_parts};

    /// Builds a message with the given `id` and `role`.
    ///
    /// The message has one intact text part (id `{id}.p0`) holding `content`, and no origin.
    fn text_message(id: &str, role: MessageRole, content: &str) -> Message {
        Message {
            id: id.to_string(),
            role,
            parts: shared_parts(vec![Part {
                id: format!("{id}.p0"),
                kind: PartKind::Text,
                content: content.to_string(),
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            }]),
            origin: None,
        }
    }

    /// Builds a small typed protocol event, used as a non-message payload.
    fn protocol_event() -> ProtocolEvent {
        ProtocolEvent::typed("test_protocol", serde_json::json!({"step": "started"}))
            .expect("protocol event serializes")
    }

    /// Checks the ids each appender returns:
    /// - a message node reuses the (free) message id;
    /// - protocol and plugin nodes get `protocol:` and `plugin:` prefixed ids.
    #[test]
    fn typed_append_node_ids_use_semantic_prefixes() {
        let mut graph = SessionGraph::default();

        let message_id = graph.append_message(text_message("m1", MessageRole::User, "hello"));
        let protocol_id = graph.append_protocol_event(protocol_event());
        let plugin_id = graph.append_plugin("example", serde_json::json!({"ok": true}));

        assert_eq!(message_id, "m1");
        assert!(protocol_id.starts_with("protocol:"));
        assert!(plugin_id.starts_with("plugin:"));
    }

    /// Replacing the read state of an empty graph with one message yields exactly one
    /// conversation node.
    #[test]
    fn active_read_replacement_persists_messages_only() {
        let message = text_message("m1", MessageRole::User, "hello");
        let graph = SessionGraph::from_active_read_state(&[message]);

        assert_eq!(graph.nodes.len(), 1);
        assert!(matches!(
            graph.nodes[0].event(),
            Some(SessionEventRecord::Conversation(_))
        ));
    }

    /// Checks id prefixes across node kinds:
    /// - conversation nodes never use a `plugin:` id;
    /// - protocol nodes use `protocol:` ids;
    /// - event-less (plugin) nodes use `plugin:` ids.
    #[test]
    fn graph_writers_do_not_put_active_read_events_under_plugin_ids() {
        let mut graph = SessionGraph::default();
        graph.append_message(text_message("m1", MessageRole::User, "hello"));
        graph.append_protocol_event(protocol_event());
        graph.append_plugin("example", serde_json::json!({"ok": true}));

        for node in &graph.nodes {
            match node.event() {
                Some(SessionEventRecord::Conversation(_)) => {
                    assert!(!node.node_id.starts_with("plugin:"), "{:?}", node);
                }
                Some(SessionEventRecord::Protocol(_)) => {
                    assert!(node.node_id.starts_with("protocol:"), "{:?}", node);
                }
                None => {
                    assert!(node.node_id.starts_with("plugin:"), "{:?}", node);
                }
            }
        }
    }
}