Skip to main content

claude_code_sdk_rust/sessions/
store_fork.rs

1use crate::error::{ClaudeSDKError, Result};
2use crate::session_store::{
3    project_key_for_directory, SessionKey, SessionStoreEntry, SessionStoreHandle,
4};
5use std::collections::HashMap;
6use std::path::Path;
7
/// Outcome of a successful session fork.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ForkSessionResult {
    /// Identifier of the newly created (forked) session.
    pub session_id: String,
}
12
13pub async fn fork_session_via_store(
14    session_store: &SessionStoreHandle,
15    session_id: &str,
16    directory: Option<&str>,
17    up_to_message_id: Option<&str>,
18    title: Option<&str>,
19) -> Result<ForkSessionResult> {
20    validate_uuid("session_id", session_id)?;
21    if let Some(up_to_message_id) = up_to_message_id {
22        validate_uuid("up_to_message_id", up_to_message_id)?;
23    }
24
25    let project_key = project_key_for_directory(directory.map(Path::new));
26    let source_key = SessionKey {
27        project_key: project_key.clone(),
28        session_id: session_id.to_string(),
29        subpath: None,
30    };
31    let Some(entries) = session_store.load(source_key).await? else {
32        return Err(ClaudeSDKError::Session(format!(
33            "Session {session_id} not found"
34        )));
35    };
36    let fork_entries = fork_entries(entries, session_id, up_to_message_id, title)?;
37    let forked_session_id = fork_entries
38        .first()
39        .and_then(|entry| entry.get("sessionId").or_else(|| entry.get("session_id")))
40        .and_then(|value| value.as_str())
41        .ok_or_else(|| ClaudeSDKError::Session("fork produced no session id".to_string()))?
42        .to_string();
43
44    session_store
45        .append(
46            SessionKey {
47                project_key,
48                session_id: forked_session_id.clone(),
49                subpath: None,
50            },
51            fork_entries,
52        )
53        .await?;
54
55    Ok(ForkSessionResult {
56        session_id: forked_session_id,
57    })
58}
59
60fn fork_entries(
61    entries: Vec<SessionStoreEntry>,
62    source_session_id: &str,
63    up_to_message_id: Option<&str>,
64    title: Option<&str>,
65) -> Result<Vec<SessionStoreEntry>> {
66    let selected = select_entries(entries, up_to_message_id)?;
67    if selected.is_empty() {
68        return Err(ClaudeSDKError::Session(
69            "session has no messages to fork".to_string(),
70        ));
71    }
72
73    let forked_session_id = uuid::Uuid::new_v4().to_string();
74    let uuid_map = selected
75        .iter()
76        .filter_map(|entry| entry.get("uuid").and_then(|value| value.as_str()))
77        .map(|old| (old.to_string(), uuid::Uuid::new_v4().to_string()))
78        .collect::<HashMap<_, _>>();
79
80    let mut forked = selected
81        .into_iter()
82        .map(|entry| remap_entry(entry, source_session_id, &forked_session_id, &uuid_map))
83        .collect::<Vec<_>>();
84    if let Some(title) = title.map(str::trim).filter(|title| !title.is_empty()) {
85        forked.push(custom_title_entry(&forked_session_id, title));
86    }
87    Ok(forked)
88}
89
90fn select_entries(
91    entries: Vec<SessionStoreEntry>,
92    up_to_message_id: Option<&str>,
93) -> Result<Vec<SessionStoreEntry>> {
94    let Some(up_to_message_id) = up_to_message_id else {
95        return Ok(entries);
96    };
97    let mut selected = Vec::new();
98    let mut found = false;
99    for entry in entries {
100        let matches_target =
101            entry.get("uuid").and_then(|value| value.as_str()) == Some(up_to_message_id);
102        selected.push(entry);
103        if matches_target {
104            found = true;
105            break;
106        }
107    }
108    if found {
109        Ok(selected)
110    } else {
111        Err(ClaudeSDKError::Session(format!(
112            "Message {up_to_message_id} not found"
113        )))
114    }
115}
116
117fn remap_entry(
118    mut entry: SessionStoreEntry,
119    source_session_id: &str,
120    forked_session_id: &str,
121    uuid_map: &HashMap<String, String>,
122) -> SessionStoreEntry {
123    replace_uuid_field(&mut entry, "uuid", uuid_map);
124    replace_uuid_field(&mut entry, "parentUuid", uuid_map);
125    replace_uuid_field(&mut entry, "parent_uuid", uuid_map);
126    replace_uuid_field(&mut entry, "parent_tool_use_id", uuid_map);
127    entry.insert(
128        "sessionId".to_string(),
129        serde_json::json!(forked_session_id),
130    );
131    entry.insert(
132        "session_id".to_string(),
133        serde_json::json!(forked_session_id),
134    );
135    entry.insert(
136        "forkedFrom".to_string(),
137        serde_json::json!(source_session_id),
138    );
139    entry
140}
141
142fn replace_uuid_field(
143    entry: &mut SessionStoreEntry,
144    field: &str,
145    uuid_map: &HashMap<String, String>,
146) {
147    let Some(old) = entry.get(field).and_then(|value| value.as_str()) else {
148        return;
149    };
150    if let Some(new) = uuid_map.get(old) {
151        entry.insert(field.to_string(), serde_json::json!(new));
152    }
153}
154
155fn custom_title_entry(session_id: &str, title: &str) -> SessionStoreEntry {
156    let mut entry = serde_json::Map::new();
157    entry.insert("type".to_string(), serde_json::json!("custom-title"));
158    entry.insert("customTitle".to_string(), serde_json::json!(title));
159    entry.insert("sessionId".to_string(), serde_json::json!(session_id));
160    entry.insert("session_id".to_string(), serde_json::json!(session_id));
161    entry.insert(
162        "uuid".to_string(),
163        serde_json::json!(uuid::Uuid::new_v4().to_string()),
164    );
165    entry.insert(
166        "timestamp".to_string(),
167        serde_json::json!(chrono::Utc::now().to_rfc3339()),
168    );
169    entry
170}
171
172fn validate_uuid(name: &str, value: &str) -> Result<()> {
173    uuid::Uuid::parse_str(value)
174        .map(|_| ())
175        .map_err(|_| ClaudeSDKError::Session(format!("Invalid {name}: {value}")))
176}