Skip to main content

toolpath_claude/
query.rs

1use crate::types::{ContentPart, Conversation, ConversationEntry, HistoryEntry, MessageRole};
2use chrono::{DateTime, Utc};
3
4pub struct ConversationQuery<'a> {
5    conversation: &'a Conversation,
6}
7
8impl<'a> ConversationQuery<'a> {
9    pub fn new(conversation: &'a Conversation) -> Self {
10        Self { conversation }
11    }
12
13    pub fn by_role(&self, role: MessageRole) -> Vec<&'a ConversationEntry> {
14        self.conversation
15            .entries
16            .iter()
17            .filter(|e| e.message.as_ref().map(|m| m.role == role).unwrap_or(false))
18            .collect()
19    }
20
21    pub fn by_type(&self, entry_type: &str) -> Vec<&'a ConversationEntry> {
22        self.conversation
23            .entries
24            .iter()
25            .filter(|e| e.entry_type == entry_type)
26            .collect()
27    }
28
29    pub fn by_time_range(
30        &self,
31        start: DateTime<Utc>,
32        end: DateTime<Utc>,
33    ) -> Vec<&'a ConversationEntry> {
34        self.conversation
35            .entries
36            .iter()
37            .filter(|e| {
38                if let Ok(timestamp) = e.timestamp.parse::<DateTime<Utc>>() {
39                    timestamp >= start && timestamp <= end
40                } else {
41                    false
42                }
43            })
44            .collect()
45    }
46
47    pub fn tool_uses_by_name(&self, tool_name: &str) -> Vec<&'a ConversationEntry> {
48        self.conversation
49            .entries
50            .iter()
51            .filter(|e| {
52                if let Some(message) = &e.message
53                    && let Some(crate::types::MessageContent::Parts(parts)) = &message.content
54                {
55                    return parts.iter().any(|p| {
56                        if let ContentPart::ToolUse { name, .. } = p {
57                            name == tool_name
58                        } else {
59                            false
60                        }
61                    });
62                }
63                false
64            })
65            .collect()
66    }
67
68    pub fn contains_text(&self, search: &str) -> Vec<&'a ConversationEntry> {
69        let search_lower = search.to_lowercase();
70        self.conversation
71            .entries
72            .iter()
73            .filter(|e| {
74                if let Some(message) = &e.message {
75                    match &message.content {
76                        Some(crate::types::MessageContent::Text(text)) => {
77                            text.to_lowercase().contains(&search_lower)
78                        }
79                        Some(crate::types::MessageContent::Parts(parts)) => {
80                            parts.iter().any(|p| match p {
81                                ContentPart::Text { text } => {
82                                    text.to_lowercase().contains(&search_lower)
83                                }
84                                ContentPart::ToolResult { content, .. } => {
85                                    content.text().to_lowercase().contains(&search_lower)
86                                }
87                                _ => false,
88                            })
89                        }
90                        None => false,
91                    }
92                } else {
93                    false
94                }
95            })
96            .collect()
97    }
98
99    pub fn errors(&self) -> Vec<&'a ConversationEntry> {
100        self.conversation
101            .entries
102            .iter()
103            .filter(|e| {
104                if let Some(message) = &e.message
105                    && let Some(crate::types::MessageContent::Parts(parts)) = &message.content
106                {
107                    return parts.iter().any(|p| {
108                        if let ContentPart::ToolResult { is_error, .. } = p {
109                            *is_error
110                        } else {
111                            false
112                        }
113                    });
114                }
115                false
116            })
117            .collect()
118    }
119}
120
121pub struct HistoryQuery<'a> {
122    history: &'a [HistoryEntry],
123}
124
125impl<'a> HistoryQuery<'a> {
126    pub fn new(history: &'a [HistoryEntry]) -> Self {
127        Self { history }
128    }
129
130    pub fn by_project(&self, project: &str) -> Vec<&'a HistoryEntry> {
131        self.history
132            .iter()
133            .filter(|e| e.project.as_deref() == Some(project))
134            .collect()
135    }
136
137    pub fn by_session(&self, session_id: &str) -> Vec<&'a HistoryEntry> {
138        self.history
139            .iter()
140            .filter(|e| e.session_id.as_deref() == Some(session_id))
141            .collect()
142    }
143
144    pub fn by_time_range(&self, start: i64, end: i64) -> Vec<&'a HistoryEntry> {
145        self.history
146            .iter()
147            .filter(|e| e.timestamp >= start && e.timestamp <= end)
148            .collect()
149    }
150
151    pub fn contains_text(&self, search: &str) -> Vec<&'a HistoryEntry> {
152        let search_lower = search.to_lowercase();
153        self.history
154            .iter()
155            .filter(|e| e.display.to_lowercase().contains(&search_lower))
156            .collect()
157    }
158
159    pub fn recent(&self, count: usize) -> Vec<&'a HistoryEntry> {
160        let mut sorted: Vec<&'a HistoryEntry> = self.history.iter().collect();
161        sorted.sort_by_key(|e| std::cmp::Reverse(e.timestamp));
162        sorted.into_iter().take(count).collect()
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Conversation, ConversationEntry, Message, MessageContent};

    /// Builds a message with the given role and content; every other field is
    /// left unset.
    fn make_message(role: MessageRole, content: MessageContent) -> Message {
        Message {
            role,
            content: Some(content),
            model: None,
            id: None,
            message_type: None,
            stop_reason: None,
            stop_sequence: None,
            usage: None,
        }
    }

    /// Builds a non-sidechain entry in session "test" that carries `message`;
    /// every field not covered by a parameter is left unset.
    fn make_entry(
        uuid: &str,
        parent_uuid: Option<&str>,
        entry_type: &str,
        timestamp: &str,
        message: Message,
    ) -> ConversationEntry {
        ConversationEntry {
            parent_uuid: parent_uuid.map(|parent| parent.to_string()),
            is_sidechain: false,
            entry_type: entry_type.to_string(),
            uuid: uuid.to_string(),
            timestamp: timestamp.to_string(),
            session_id: Some("test".to_string()),
            message: Some(message),
            cwd: None,
            git_branch: None,
            version: None,
            user_type: None,
            request_id: None,
            tool_use_result: None,
            snapshot: None,
            message_id: None,
            extra: Default::default(),
        }
    }

    /// Wraps a single entry in a fresh conversation named "test".
    fn conversation_with(entry: ConversationEntry) -> Conversation {
        let mut conv = Conversation::new("test".to_string());
        conv.add_entry(entry);
        conv
    }

    /// Two-entry fixture: a user "Hello world" followed one second later by an
    /// assistant "Hi there".
    fn create_test_conversation() -> Conversation {
        let mut conv = Conversation::new("test".to_string());

        conv.add_entry(make_entry(
            "1",
            None,
            "user",
            "2024-01-01T00:00:00Z",
            make_message(
                MessageRole::User,
                MessageContent::Text("Hello world".to_string()),
            ),
        ));
        conv.add_entry(make_entry(
            "2",
            Some("1"),
            "assistant",
            "2024-01-01T00:00:01Z",
            make_message(
                MessageRole::Assistant,
                MessageContent::Text("Hi there".to_string()),
            ),
        ));
        conv
    }

    /// Collects the uuids of `entries` so a test can assert count and identity
    /// in one comparison.
    fn uuids<'e>(entries: &[&'e ConversationEntry]) -> Vec<&'e str> {
        entries.iter().map(|entry| entry.uuid.as_str()).collect()
    }

    #[test]
    fn test_query_by_role() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        assert_eq!(query.by_role(MessageRole::User).len(), 1);
        assert_eq!(query.by_role(MessageRole::Assistant).len(), 1);
    }

    #[test]
    fn test_query_contains_text() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        assert_eq!(uuids(&query.contains_text("Hello")), ["1"]);
        assert_eq!(uuids(&query.contains_text("Hi")), ["2"]);
    }

    #[test]
    fn test_query_by_type() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        assert_eq!(uuids(&query.by_type("user")), ["1"]);
        assert_eq!(uuids(&query.by_type("assistant")), ["2"]);
    }

    #[test]
    fn test_query_by_time_range() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        // Start and end coincide, so only an entry at exactly this instant can
        // match; this exercises the inclusive bounds.
        let instant = "2024-01-01T00:00:00Z".parse::<DateTime<Utc>>().unwrap();
        assert_eq!(uuids(&query.by_time_range(instant, instant)), ["1"]);
    }

    #[test]
    fn test_query_by_time_range_all() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        let start = "2023-01-01T00:00:00Z".parse::<DateTime<Utc>>().unwrap();
        let end = "2025-01-01T00:00:00Z".parse::<DateTime<Utc>>().unwrap();
        assert_eq!(query.by_time_range(start, end).len(), 2);
    }

    #[test]
    fn test_query_tool_uses_by_name() {
        // Conversation holding a single assistant entry that uses the "Read" tool.
        let tool_use = ContentPart::ToolUse {
            id: "t1".to_string(),
            name: "Read".to_string(),
            input: serde_json::Value::Object(Default::default()),
        };
        let conv = conversation_with(make_entry(
            "3",
            None,
            "assistant",
            "2024-01-01T00:00:02Z",
            make_message(MessageRole::Assistant, MessageContent::Parts(vec![tool_use])),
        ));

        let query = ConversationQuery::new(&conv);
        assert_eq!(query.tool_uses_by_name("Read").len(), 1);
        assert!(query.tool_uses_by_name("Write").is_empty());
    }

    #[test]
    fn test_query_errors() {
        let failed_result = ContentPart::ToolResult {
            tool_use_id: "t1".to_string(),
            content: crate::types::ToolResultContent::Text("failed!".to_string()),
            is_error: true,
        };
        let conv = conversation_with(make_entry(
            "e1",
            None,
            "assistant",
            "2024-01-01T00:00:00Z",
            make_message(
                MessageRole::Assistant,
                MessageContent::Parts(vec![failed_result]),
            ),
        ));

        let query = ConversationQuery::new(&conv);
        assert_eq!(query.errors().len(), 1);
    }

    #[test]
    fn test_query_errors_empty() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);
        assert!(query.errors().is_empty());
    }

    #[test]
    fn test_query_contains_text_case_insensitive() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        assert_eq!(query.contains_text("hello").len(), 1);
    }

    // ── HistoryQuery ───────────────────────────────────────────────────

    /// Builds a history entry with no pasted contents.
    fn make_history_entry(
        display: &str,
        timestamp: i64,
        project: &str,
        session_id: &str,
    ) -> HistoryEntry {
        HistoryEntry {
            display: display.to_string(),
            pasted_contents: Default::default(),
            timestamp,
            project: Some(project.to_string()),
            session_id: Some(session_id.to_string()),
        }
    }

    /// Three-entry fixture: two entries for `/project/a` in `session-1`
    /// (timestamps 1000 and 3000) around one for `/project/b` in `session-2`
    /// (timestamp 2000).
    fn create_test_history() -> Vec<HistoryEntry> {
        vec![
            make_history_entry("fix bug in auth", 1000, "/project/a", "session-1"),
            make_history_entry("add feature X", 2000, "/project/b", "session-2"),
            make_history_entry("refactor auth module", 3000, "/project/a", "session-1"),
        ]
    }

    /// Collects the timestamps of `entries` in the order they were returned.
    fn timestamps(entries: &[&HistoryEntry]) -> Vec<i64> {
        entries.iter().map(|entry| entry.timestamp).collect()
    }

    #[test]
    fn test_history_by_project() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.by_project("/project/a").len(), 2);
    }

    #[test]
    fn test_history_by_session() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        let matches = query.by_session("session-2");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].display, "add feature X");
    }

    #[test]
    fn test_history_by_time_range() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(timestamps(&query.by_time_range(1500, 2500)), [2000]);
    }

    #[test]
    fn test_history_contains_text() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.contains_text("auth").len(), 2);
    }

    #[test]
    fn test_history_contains_text_case_insensitive() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.contains_text("AUTH").len(), 2);
    }

    #[test]
    fn test_history_recent() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        // Most recent first
        assert_eq!(timestamps(&query.recent(2)), [3000, 2000]);
    }

    #[test]
    fn test_history_recent_more_than_available() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.recent(10).len(), 3);
    }
}