1use crate::session::{Session, SessionManager};
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8
/// Options for [`fork_session`].
///
/// Build with [`ForkOptions::new`] (or `ForkOptions::default()`, which is
/// identical) and refine with the chained setters.
#[derive(Debug, Clone)]
pub struct ForkOptions {
    /// Message index the fork is taken at; `fork_session` treats `None` as 0.
    pub from_message_index: Option<usize>,
    /// Name for the new session; when `None`, `fork_session` derives one from
    /// the source session's name.
    pub name: Option<String>,
    /// When `true`, copy the messages from `from_message_index` onward; when
    /// `false`, copy only the messages before that index.
    pub include_future_messages: bool,
}

// `Default` is written by hand instead of derived so that it agrees with
// `new()`. A derived impl would set `include_future_messages` to `false`.
// Combined with the default index of 0, the fork would then copy no messages.
impl Default for ForkOptions {
    fn default() -> Self {
        Self {
            from_message_index: None,
            name: None,
            include_future_messages: true,
        }
    }
}

impl ForkOptions {
    /// Creates options that fork from index 0, keep the messages from that
    /// point on (i.e. all of them), and use the default fork name.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the message index the fork is taken at.
    #[must_use]
    pub fn from_message_index(mut self, index: usize) -> Self {
        self.from_message_index = Some(index);
        self
    }

    /// Sets an explicit name for the forked session.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Chooses whether messages from the fork index onward (`true`) or only
    /// the messages before it (`false`) are copied.
    #[must_use]
    pub fn include_future_messages(mut self, include: bool) -> Self {
        self.include_future_messages = include;
        self
    }
}
44
45#[derive(Debug, Clone, Default)]
47pub struct MergeOptions {
48 pub strategy: MergeStrategy,
50 pub keep_metadata: MetadataStrategy,
52}
53
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum MergeStrategy {
58 #[default]
60 Append,
61 Interleave,
63 Replace,
65}
66
67#[derive(Debug, Clone, Default, Serialize, Deserialize)]
69#[serde(rename_all = "snake_case")]
70pub enum MetadataStrategy {
71 #[default]
73 Target,
74 Source,
76 Merge,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, Default)]
82pub struct ForkMetadata {
83 pub parent_id: Option<String>,
85 pub fork_point: Option<usize>,
87 pub branches: Vec<String>,
89 pub fork_name: Option<String>,
91 pub merged_from: Vec<String>,
93}
94
95impl ForkMetadata {
96 pub const EXTENSION_NAME: &'static str = "fork";
97 pub const VERSION: &'static str = "v0";
98
99 pub fn from_session(session: &Session) -> Option<Self> {
101 session
102 .extension_data
103 .get_extension_state(Self::EXTENSION_NAME, Self::VERSION)
104 .and_then(|v| serde_json::from_value(v.clone()).ok())
105 }
106
107 pub fn to_extension_data(
109 &self,
110 extension_data: &mut crate::session::ExtensionData,
111 ) -> Result<()> {
112 let value = serde_json::to_value(self)?;
113 extension_data.set_extension_state(Self::EXTENSION_NAME, Self::VERSION, value);
114 Ok(())
115 }
116}
117
118pub async fn fork_session(source_session_id: &str, options: ForkOptions) -> Result<Session> {
120 let source_session = SessionManager::get_session(source_session_id, true).await?;
121
122 let from_index = options.from_message_index.unwrap_or(0);
123 let new_name = options
124 .name
125 .unwrap_or_else(|| format!("{} (fork)", source_session.name));
126
127 let new_session = SessionManager::create_session(
129 source_session.working_dir.clone(),
130 new_name.clone(),
131 source_session.session_type,
132 )
133 .await?;
134
135 if let Some(conversation) = &source_session.conversation {
137 let messages = conversation.messages();
138 let messages_to_copy = if options.include_future_messages {
139 messages
140 .iter()
141 .skip(from_index)
142 .cloned()
143 .collect::<Vec<_>>()
144 } else {
145 messages
146 .iter()
147 .take(from_index)
148 .cloned()
149 .collect::<Vec<_>>()
150 };
151
152 if !messages_to_copy.is_empty() {
153 let new_conversation =
154 crate::conversation::Conversation::new_unvalidated(messages_to_copy);
155 SessionManager::replace_conversation(&new_session.id, &new_conversation).await?;
156 }
157 }
158
159 let fork_metadata = ForkMetadata {
161 parent_id: Some(source_session_id.to_string()),
162 fork_point: Some(from_index),
163 fork_name: Some(new_name),
164 ..Default::default()
165 };
166
167 let mut new_extension_data = new_session.extension_data.clone();
168 fork_metadata.to_extension_data(&mut new_extension_data)?;
169
170 SessionManager::update_session(&new_session.id)
171 .extension_data(new_extension_data)
172 .apply()
173 .await?;
174
175 let mut source_fork_metadata = ForkMetadata::from_session(&source_session).unwrap_or_default();
177 source_fork_metadata.branches.push(new_session.id.clone());
178
179 let mut source_extension_data = source_session.extension_data.clone();
180 source_fork_metadata.to_extension_data(&mut source_extension_data)?;
181
182 SessionManager::update_session(source_session_id)
183 .extension_data(source_extension_data)
184 .apply()
185 .await?;
186
187 SessionManager::get_session(&new_session.id, true).await
188}
189
190pub async fn merge_sessions(
192 target_session_id: &str,
193 source_session_id: &str,
194 options: MergeOptions,
195) -> Result<Session> {
196 let target_session = SessionManager::get_session(target_session_id, true).await?;
197 let source_session = SessionManager::get_session(source_session_id, true).await?;
198
199 let target_messages = target_session
200 .conversation
201 .as_ref()
202 .map(|c| c.messages().to_vec())
203 .unwrap_or_default();
204
205 let source_messages = source_session
206 .conversation
207 .as_ref()
208 .map(|c| c.messages().to_vec())
209 .unwrap_or_default();
210
211 let merged_messages = match options.strategy {
213 MergeStrategy::Append => {
214 let mut messages = target_messages;
215 messages.extend(source_messages);
216 messages
217 }
218 MergeStrategy::Interleave => {
219 let mut messages = target_messages;
220 messages.extend(source_messages);
221 messages.sort_by_key(|m| m.created);
222 messages
223 }
224 MergeStrategy::Replace => source_messages,
225 };
226
227 if !merged_messages.is_empty() {
229 let merged_conversation =
230 crate::conversation::Conversation::new_unvalidated(merged_messages);
231 SessionManager::replace_conversation(target_session_id, &merged_conversation).await?;
232 }
233
234 let mut update_builder = SessionManager::update_session(target_session_id);
236
237 match options.keep_metadata {
238 MetadataStrategy::Source => {
239 update_builder = update_builder
240 .total_tokens(source_session.total_tokens)
241 .input_tokens(source_session.input_tokens)
242 .output_tokens(source_session.output_tokens);
243 }
244 MetadataStrategy::Merge => {
245 let merged_total = target_session
246 .total_tokens
247 .unwrap_or(0)
248 .saturating_add(source_session.total_tokens.unwrap_or(0));
249 let merged_input = target_session
250 .input_tokens
251 .unwrap_or(0)
252 .saturating_add(source_session.input_tokens.unwrap_or(0));
253 let merged_output = target_session
254 .output_tokens
255 .unwrap_or(0)
256 .saturating_add(source_session.output_tokens.unwrap_or(0));
257
258 update_builder = update_builder
259 .total_tokens(Some(merged_total))
260 .input_tokens(Some(merged_input))
261 .output_tokens(Some(merged_output));
262 }
263 MetadataStrategy::Target => {
264 }
266 }
267
268 let mut fork_metadata = ForkMetadata::from_session(&target_session).unwrap_or_default();
270 fork_metadata
271 .merged_from
272 .push(source_session_id.to_string());
273
274 let mut extension_data = target_session.extension_data.clone();
275 fork_metadata.to_extension_data(&mut extension_data)?;
276
277 update_builder
278 .extension_data(extension_data)
279 .apply()
280 .await?;
281
282 SessionManager::get_session(target_session_id, true).await
283}
284
285pub async fn get_session_branch_tree(session_id: &str) -> Result<SessionBranchTree> {
287 let session = SessionManager::get_session(session_id, false).await?;
288 let fork_metadata = ForkMetadata::from_session(&session).unwrap_or_default();
289
290 let parent = if let Some(parent_id) = &fork_metadata.parent_id {
291 SessionManager::get_session(parent_id, false).await.ok()
292 } else {
293 None
294 };
295
296 let mut branches = Vec::new();
297 for branch_id in &fork_metadata.branches {
298 if let Ok(branch) = SessionManager::get_session(branch_id, false).await {
299 branches.push(branch);
300 }
301 }
302
303 Ok(SessionBranchTree {
304 session,
305 parent,
306 branches,
307 })
308}
309
310#[derive(Debug)]
312pub struct SessionBranchTree {
313 pub session: Session,
314 pub parent: Option<Session>,
315 pub branches: Vec<Session>,
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fork_options_builder() {
        let options = ForkOptions::new()
            .from_message_index(5)
            .name("Test Fork")
            .include_future_messages(false);

        assert_eq!(options.from_message_index, Some(5));
        assert_eq!(options.name, Some("Test Fork".to_string()));
        assert!(!options.include_future_messages);
    }

    #[test]
    fn test_fork_options_new_defaults() {
        let options = ForkOptions::new();

        assert_eq!(options.from_message_index, None);
        assert_eq!(options.name, None);
        assert!(options.include_future_messages);
    }

    #[test]
    fn test_fork_metadata_serialization() {
        let metadata = ForkMetadata {
            parent_id: Some("parent_123".to_string()),
            fork_point: Some(10),
            branches: vec!["branch_1".to_string(), "branch_2".to_string()],
            fork_name: Some("My Fork".to_string()),
            merged_from: vec!["merged_1".to_string()],
        };

        let json = serde_json::to_string(&metadata).unwrap();
        let deserialized: ForkMetadata = serde_json::from_str(&json).unwrap();

        // Every field must survive the round trip, not just a subset.
        assert_eq!(deserialized.parent_id, metadata.parent_id);
        assert_eq!(deserialized.fork_point, metadata.fork_point);
        assert_eq!(deserialized.branches, metadata.branches);
        assert_eq!(deserialized.fork_name, metadata.fork_name);
        assert_eq!(deserialized.merged_from, metadata.merged_from);
    }

    #[test]
    fn test_strategies_serialize_snake_case() {
        assert_eq!(
            serde_json::to_string(&MergeStrategy::Interleave).unwrap(),
            "\"interleave\""
        );
        assert_eq!(
            serde_json::to_string(&MetadataStrategy::Merge).unwrap(),
            "\"merge\""
        );
    }
}