Skip to main content

limit_cli/tui/commands/
branch.rs

//! Branching commands for tree-based sessions.
//!
//! Provides commands to branch from a specific entry and to list every branch
//! (root-to-leaf path) in the session tree.
4
5use super::CommandContext;
6use crate::error::CliError;
7
8/// Branch from a specific entry in the session
9///
10/// This moves the leaf pointer to the specified entry, creating a new branch.
11/// Subsequent appends will be children of this entry.
12pub fn branch_from(ctx: &mut CommandContext, entry_id: &str) -> Result<String, CliError> {
13    let session_manager = ctx
14        .session_manager
15        .lock()
16        .map_err(|e| CliError::ConfigError(format!("Failed to lock session manager: {}", e)))?;
17
18    let mut tree = session_manager.load_tree_session(&ctx.session_id)?;
19    let new_leaf = tree.branch_from(entry_id)?;
20
21    // Save updated tree
22    session_manager.save_tree_session(&ctx.session_id, &tree)?;
23
24    Ok(new_leaf)
25}
26
27/// List all branches (paths from root to leaves)
28///
29/// Returns information about all leaf nodes in the session tree,
30/// including their depth (message count) and whether they're the current branch.
31pub fn list_branches(ctx: &CommandContext) -> Result<Vec<BranchInfo>, CliError> {
32    let session_manager = ctx
33        .session_manager
34        .lock()
35        .map_err(|e| CliError::ConfigError(format!("Failed to lock session manager: {}", e)))?;
36
37    let tree = session_manager.load_tree_session(&ctx.session_id)?;
38
39    // Find all leaf nodes
40    let entries = tree.entries();
41    let parent_ids: std::collections::HashSet<_> = entries
42        .iter()
43        .filter_map(|e| e.parent_id.as_ref())
44        .cloned()
45        .collect();
46
47    let leaves: Vec<_> = entries
48        .iter()
49        .filter(|e| !parent_ids.contains(&e.id))
50        .collect();
51
52    let branches: Vec<BranchInfo> = leaves
53        .iter()
54        .map(|leaf| {
55            let depth = count_depth(&tree, &leaf.id);
56            BranchInfo {
57                leaf_id: leaf.id.clone(),
58                depth,
59                is_current: leaf.id == tree.leaf_id(),
60            }
61        })
62        .collect();
63
64    Ok(branches)
65}
66
67/// Count the depth of a leaf node (number of messages from root to leaf)
68fn count_depth(tree: &crate::session_tree::SessionTree, leaf_id: &str) -> usize {
69    let context = tree.build_context(leaf_id).unwrap_or_default();
70    context.len()
71}
72
/// Summary of one branch (a root-to-leaf path) in the session tree.
///
/// Derives `PartialEq`/`Eq` so branch listings can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BranchInfo {
    /// ID of the leaf entry that ends this branch.
    pub leaf_id: String,
    /// Number of messages on the path from the root to this leaf, counting both ends.
    pub depth: usize,
    /// Whether this leaf is the session tree's current leaf.
    pub is_current: bool,
}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::SessionManager;
    use crate::session_tree::{SerializableMessage, SessionEntry, SessionEntryType};
    use limit_llm::{Message, Role};
    use tempfile::tempdir;

    /// Build a user-message entry with a fixed timestamp for tests.
    fn create_test_entry(id: &str, parent_id: Option<&str>, content: &str) -> SessionEntry {
        SessionEntry {
            id: id.to_string(),
            parent_id: parent_id.map(|s| s.to_string()),
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(Message {
                    role: Role::User,
                    content: Some(limit_llm::MessageContent::text(content)),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                }),
            },
        }
    }

    /// Append a linear chain of entries to the tree session `session_id`.
    ///
    /// Tree structure after this call:
    ///   root -> a -> b
    ///
    /// Tests that need more than one branch append further entries themselves.
    fn create_test_tree_with_branches(session_manager: &SessionManager, session_id: &str) {
        let root = create_test_entry("root", None, "root content");
        let a = create_test_entry("a", Some("root"), "a content");
        let b = create_test_entry("b", Some("a"), "b content");

        session_manager
            .append_tree_entry(session_id, &root)
            .unwrap();
        session_manager.append_tree_entry(session_id, &a).unwrap();
        session_manager.append_tree_entry(session_id, &b).unwrap();
    }

    #[test]
    fn test_count_depth() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("session.db");
        let sessions_dir = dir.path().join("sessions");

        let session_manager = SessionManager::with_paths(db_path, sessions_dir).unwrap();
        let session_id = session_manager.create_new_session().unwrap();
        session_manager
            .create_tree_session(&session_id, "/test".to_string())
            .unwrap();

        create_test_tree_with_branches(&session_manager, &session_id);

        let tree = session_manager.load_tree_session(&session_id).unwrap();

        // root depth = 1
        assert_eq!(count_depth(&tree, "root"), 1);
        // a depth = 2 (root -> a)
        assert_eq!(count_depth(&tree, "a"), 2);
        // b depth = 3 (root -> a -> b)
        assert_eq!(count_depth(&tree, "b"), 3);
    }

    #[test]
    fn test_branch_from() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("session.db");
        let sessions_dir = dir.path().join("sessions");

        let session_manager = SessionManager::with_paths(db_path, sessions_dir).unwrap();
        let session_id = session_manager.create_new_session().unwrap();
        session_manager
            .create_tree_session(&session_id, "/test".to_string())
            .unwrap();

        create_test_tree_with_branches(&session_manager, &session_id);

        let mut tree = session_manager.load_tree_session(&session_id).unwrap();
        let branch_id = tree.branch_from("a").unwrap();
        assert_eq!(branch_id, "a");

        let new_entry = create_test_entry("new", Some("a"), "new content");
        tree.append(new_entry).unwrap();

        // New branch: root -> a -> new
        let context = tree.build_context("new").unwrap();
        assert_eq!(context.len(), 3);
        // Branching must not disturb the existing branch: root -> a -> b
        assert_eq!(tree.build_context("b").unwrap().len(), 3);
    }

    #[test]
    fn test_list_branches() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("session.db");
        let sessions_dir = dir.path().join("sessions");

        let session_manager = SessionManager::with_paths(db_path, sessions_dir).unwrap();
        let session_id = session_manager.create_new_session().unwrap();
        session_manager
            .create_tree_session(&session_id, "/test".to_string())
            .unwrap();

        create_test_tree_with_branches(&session_manager, &session_id);

        // Add another branch from root. The tree is now:
        //   root -> a -> b
        //     \
        //      -> d
        let d = create_test_entry("d", Some("root"), "d content");
        session_manager.append_tree_entry(&session_id, &d).unwrap();

        let tree = session_manager.load_tree_session(&session_id).unwrap();

        // Mirrors the leaf detection in `list_branches`: a leaf is any entry
        // that is nobody's parent.
        let entries = tree.entries();
        let parent_ids: std::collections::HashSet<&str> = entries
            .iter()
            .filter_map(|e| e.parent_id.as_deref())
            .collect();

        let leaves: Vec<_> = entries
            .iter()
            .filter(|e| !parent_ids.contains(e.id.as_str()))
            .collect();

        assert_eq!(leaves.len(), 2);

        let leaf_ids: Vec<_> = leaves.iter().map(|l| l.id.as_str()).collect();
        assert!(leaf_ids.contains(&"b"));
        assert!(leaf_ids.contains(&"d"));

        // Branch depths count entries from root to leaf, inclusive.
        assert_eq!(count_depth(&tree, "b"), 3);
        assert_eq!(count_depth(&tree, "d"), 2);
    }

    #[test]
    fn test_branch_info() {
        let info = BranchInfo {
            leaf_id: "test-branch".to_string(),
            depth: 5,
            is_current: true,
        };

        assert_eq!(info.leaf_id, "test-branch");
        assert_eq!(info.depth, 5);
        assert!(info.is_current);
    }
}