Skip to main content

aster/session/
fork.rs

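//! Session Fork/Branch Support
//!
//! Provides functionality for forking sessions, merging one session into
//! another, and inspecting the resulting branch tree. Fork relationships are
//! persisted in each session's extension data as [`ForkMetadata`].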
4
5use crate::session::{Session, SessionManager};
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8
/// Options controlling how [`fork_session`] builds a new session branch.
///
/// Construct with [`ForkOptions::new`] (or [`Default::default`], which is
/// identical) and refine with the builder methods.
#[derive(Debug, Clone)]
pub struct ForkOptions {
    /// Message index marking the fork point; `None` is treated as `0`.
    ///
    /// With `include_future_messages == true` the fork receives the messages
    /// from this index onward (so `0` copies the whole conversation). With
    /// `false` it receives only the messages before this index.
    pub from_message_index: Option<usize>,
    /// Name for the new forked session; when `None`, `"<source name> (fork)"`
    /// is used.
    pub name: Option<String>,
    /// Whether to copy the messages at and after the fork point (`true`, the
    /// default) or only the messages before it (`false`).
    pub include_future_messages: bool,
}

impl Default for ForkOptions {
    /// Same as [`ForkOptions::new`].
    ///
    /// Implemented by hand because a derived `Default` would set
    /// `include_future_messages` to `false`. Combined with the implicit fork
    /// point of `0`, that would produce a fork with no messages at all.
    fn default() -> Self {
        Self::new()
    }
}

impl ForkOptions {
    /// Create options that copy the entire conversation under the default name.
    pub fn new() -> Self {
        Self {
            from_message_index: None,
            name: None,
            include_future_messages: true,
        }
    }

    /// Set the message index used as the fork point.
    #[must_use]
    pub fn from_message_index(mut self, index: usize) -> Self {
        self.from_message_index = Some(index);
        self
    }

    /// Set an explicit name for the forked session.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Choose whether messages at/after the fork point (`true`) or before it
    /// (`false`) are copied.
    #[must_use]
    pub fn include_future_messages(mut self, include: bool) -> Self {
        self.include_future_messages = include;
        self
    }
}
44
45/// Merge options for combining sessions
46#[derive(Debug, Clone, Default)]
47pub struct MergeOptions {
48    /// Merge strategy
49    pub strategy: MergeStrategy,
50    /// Metadata preservation strategy
51    pub keep_metadata: MetadataStrategy,
52}
53
54/// Strategy for merging session messages
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum MergeStrategy {
58    /// Append source messages to target
59    #[default]
60    Append,
61    /// Interleave messages by timestamp
62    Interleave,
63    /// Replace target messages with source
64    Replace,
65}
66
67/// Strategy for preserving metadata during merge
68#[derive(Debug, Clone, Default, Serialize, Deserialize)]
69#[serde(rename_all = "snake_case")]
70pub enum MetadataStrategy {
71    /// Keep target session's metadata
72    #[default]
73    Target,
74    /// Use source session's metadata
75    Source,
76    /// Merge metadata from both sessions
77    Merge,
78}
79
80/// Fork metadata stored in extension_data
81#[derive(Debug, Clone, Serialize, Deserialize, Default)]
82pub struct ForkMetadata {
83    /// Parent session ID (if this is a fork)
84    pub parent_id: Option<String>,
85    /// Message index where fork occurred
86    pub fork_point: Option<usize>,
87    /// Child session IDs (branches)
88    pub branches: Vec<String>,
89    /// Fork name/description
90    pub fork_name: Option<String>,
91    /// Sessions merged into this one
92    pub merged_from: Vec<String>,
93}
94
95impl ForkMetadata {
96    pub const EXTENSION_NAME: &'static str = "fork";
97    pub const VERSION: &'static str = "v0";
98
99    /// Get fork metadata from session extension data
100    pub fn from_session(session: &Session) -> Option<Self> {
101        session
102            .extension_data
103            .get_extension_state(Self::EXTENSION_NAME, Self::VERSION)
104            .and_then(|v| serde_json::from_value(v.clone()).ok())
105    }
106
107    /// Save fork metadata to session extension data
108    pub fn to_extension_data(
109        &self,
110        extension_data: &mut crate::session::ExtensionData,
111    ) -> Result<()> {
112        let value = serde_json::to_value(self)?;
113        extension_data.set_extension_state(Self::EXTENSION_NAME, Self::VERSION, value);
114        Ok(())
115    }
116}
117
118/// Fork a session to create a new branch
119pub async fn fork_session(source_session_id: &str, options: ForkOptions) -> Result<Session> {
120    let source_session = SessionManager::get_session(source_session_id, true).await?;
121
122    let from_index = options.from_message_index.unwrap_or(0);
123    let new_name = options
124        .name
125        .unwrap_or_else(|| format!("{} (fork)", source_session.name));
126
127    // Create new session
128    let new_session = SessionManager::create_session(
129        source_session.working_dir.clone(),
130        new_name.clone(),
131        source_session.session_type,
132    )
133    .await?;
134
135    // Copy messages based on options
136    if let Some(conversation) = &source_session.conversation {
137        let messages = conversation.messages();
138        let messages_to_copy = if options.include_future_messages {
139            messages
140                .iter()
141                .skip(from_index)
142                .cloned()
143                .collect::<Vec<_>>()
144        } else {
145            messages
146                .iter()
147                .take(from_index)
148                .cloned()
149                .collect::<Vec<_>>()
150        };
151
152        if !messages_to_copy.is_empty() {
153            let new_conversation =
154                crate::conversation::Conversation::new_unvalidated(messages_to_copy);
155            SessionManager::replace_conversation(&new_session.id, &new_conversation).await?;
156        }
157    }
158
159    // Set fork metadata on new session
160    let fork_metadata = ForkMetadata {
161        parent_id: Some(source_session_id.to_string()),
162        fork_point: Some(from_index),
163        fork_name: Some(new_name),
164        ..Default::default()
165    };
166
167    let mut new_extension_data = new_session.extension_data.clone();
168    fork_metadata.to_extension_data(&mut new_extension_data)?;
169
170    SessionManager::update_session(&new_session.id)
171        .extension_data(new_extension_data)
172        .apply()
173        .await?;
174
175    // Update source session's branches list
176    let mut source_fork_metadata = ForkMetadata::from_session(&source_session).unwrap_or_default();
177    source_fork_metadata.branches.push(new_session.id.clone());
178
179    let mut source_extension_data = source_session.extension_data.clone();
180    source_fork_metadata.to_extension_data(&mut source_extension_data)?;
181
182    SessionManager::update_session(source_session_id)
183        .extension_data(source_extension_data)
184        .apply()
185        .await?;
186
187    SessionManager::get_session(&new_session.id, true).await
188}
189
190/// Merge one session into another
191pub async fn merge_sessions(
192    target_session_id: &str,
193    source_session_id: &str,
194    options: MergeOptions,
195) -> Result<Session> {
196    let target_session = SessionManager::get_session(target_session_id, true).await?;
197    let source_session = SessionManager::get_session(source_session_id, true).await?;
198
199    let target_messages = target_session
200        .conversation
201        .as_ref()
202        .map(|c| c.messages().to_vec())
203        .unwrap_or_default();
204
205    let source_messages = source_session
206        .conversation
207        .as_ref()
208        .map(|c| c.messages().to_vec())
209        .unwrap_or_default();
210
211    // Merge messages based on strategy
212    let merged_messages = match options.strategy {
213        MergeStrategy::Append => {
214            let mut messages = target_messages;
215            messages.extend(source_messages);
216            messages
217        }
218        MergeStrategy::Interleave => {
219            let mut messages = target_messages;
220            messages.extend(source_messages);
221            messages.sort_by_key(|m| m.created);
222            messages
223        }
224        MergeStrategy::Replace => source_messages,
225    };
226
227    // Update conversation
228    if !merged_messages.is_empty() {
229        let merged_conversation =
230            crate::conversation::Conversation::new_unvalidated(merged_messages);
231        SessionManager::replace_conversation(target_session_id, &merged_conversation).await?;
232    }
233
234    // Update metadata based on strategy
235    let mut update_builder = SessionManager::update_session(target_session_id);
236
237    match options.keep_metadata {
238        MetadataStrategy::Source => {
239            update_builder = update_builder
240                .total_tokens(source_session.total_tokens)
241                .input_tokens(source_session.input_tokens)
242                .output_tokens(source_session.output_tokens);
243        }
244        MetadataStrategy::Merge => {
245            let merged_total = target_session
246                .total_tokens
247                .unwrap_or(0)
248                .saturating_add(source_session.total_tokens.unwrap_or(0));
249            let merged_input = target_session
250                .input_tokens
251                .unwrap_or(0)
252                .saturating_add(source_session.input_tokens.unwrap_or(0));
253            let merged_output = target_session
254                .output_tokens
255                .unwrap_or(0)
256                .saturating_add(source_session.output_tokens.unwrap_or(0));
257
258            update_builder = update_builder
259                .total_tokens(Some(merged_total))
260                .input_tokens(Some(merged_input))
261                .output_tokens(Some(merged_output));
262        }
263        MetadataStrategy::Target => {
264            // Keep target metadata, no changes needed
265        }
266    }
267
268    // Record merge in fork metadata
269    let mut fork_metadata = ForkMetadata::from_session(&target_session).unwrap_or_default();
270    fork_metadata
271        .merged_from
272        .push(source_session_id.to_string());
273
274    let mut extension_data = target_session.extension_data.clone();
275    fork_metadata.to_extension_data(&mut extension_data)?;
276
277    update_builder
278        .extension_data(extension_data)
279        .apply()
280        .await?;
281
282    SessionManager::get_session(target_session_id, true).await
283}
284
285/// Get the branch tree for a session
286pub async fn get_session_branch_tree(session_id: &str) -> Result<SessionBranchTree> {
287    let session = SessionManager::get_session(session_id, false).await?;
288    let fork_metadata = ForkMetadata::from_session(&session).unwrap_or_default();
289
290    let parent = if let Some(parent_id) = &fork_metadata.parent_id {
291        SessionManager::get_session(parent_id, false).await.ok()
292    } else {
293        None
294    };
295
296    let mut branches = Vec::new();
297    for branch_id in &fork_metadata.branches {
298        if let Ok(branch) = SessionManager::get_session(branch_id, false).await {
299            branches.push(branch);
300        }
301    }
302
303    Ok(SessionBranchTree {
304        session,
305        parent,
306        branches,
307    })
308}
309
310/// Session branch tree structure
311#[derive(Debug)]
312pub struct SessionBranchTree {
313    pub session: Session,
314    pub parent: Option<Session>,
315    pub branches: Vec<Session>,
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fork_options_builder() {
        let options = ForkOptions::new()
            .from_message_index(5)
            .name("Test Fork")
            .include_future_messages(false);

        assert_eq!(options.from_message_index, Some(5));
        assert_eq!(options.name.as_deref(), Some("Test Fork"));
        assert!(!options.include_future_messages);
    }

    #[test]
    fn test_fork_options_new_defaults() {
        let options = ForkOptions::new();

        assert_eq!(options.from_message_index, None);
        assert_eq!(options.name, None);
        assert!(options.include_future_messages);
    }

    #[test]
    fn test_merge_options_default() {
        let options = MergeOptions::default();

        assert!(matches!(options.strategy, MergeStrategy::Append));
        assert!(matches!(options.keep_metadata, MetadataStrategy::Target));
    }

    #[test]
    fn test_strategy_wire_format_is_snake_case() {
        assert_eq!(
            serde_json::to_string(&MergeStrategy::Interleave).unwrap(),
            "\"interleave\""
        );
        assert_eq!(
            serde_json::to_string(&MetadataStrategy::Merge).unwrap(),
            "\"merge\""
        );
    }

    #[test]
    fn test_fork_metadata_serialization() {
        let metadata = ForkMetadata {
            parent_id: Some("parent_123".to_string()),
            fork_point: Some(10),
            branches: vec!["branch_1".to_string(), "branch_2".to_string()],
            fork_name: Some("My Fork".to_string()),
            merged_from: vec!["merged_1".to_string()],
        };

        let json = serde_json::to_string(&metadata).unwrap();
        let deserialized: ForkMetadata = serde_json::from_str(&json).unwrap();

        // Check every field so a dropped or renamed field cannot slip through.
        assert_eq!(deserialized.parent_id, metadata.parent_id);
        assert_eq!(deserialized.fork_point, metadata.fork_point);
        assert_eq!(deserialized.branches, metadata.branches);
        assert_eq!(deserialized.fork_name, metadata.fork_name);
        assert_eq!(deserialized.merged_from, metadata.merged_from);
    }
}