1use std::collections::HashMap;
22
/// Coarse phase a conversation trajectory is in at a given turn.
///
/// Each variant has a canonical lowercase name (see [`TrajectoryPhase::as_str`])
/// that [`TrajectoryPhase::from_str`] accepts back case-insensitively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrajectoryPhase {
    /// Canonical name `"exploration"`.
    Exploration,
    /// Canonical name `"consolidation"`.
    Consolidation,
    /// Canonical name `"synthesis"`.
    Synthesis,
    /// Canonical name `"debugging"`.
    Debugging,
    /// Canonical name `"planning"`.
    Planning,
}

/// Former name of [`TrajectoryPhase`], kept so existing callers still compile.
#[deprecated(since = "0.2.0", note = "Use TrajectoryPhase instead")]
pub type ConversationPhase = TrajectoryPhase;

impl TrajectoryPhase {
    /// Returns the canonical lowercase name of this phase.
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Planning => "planning",
            Self::Debugging => "debugging",
            Self::Synthesis => "synthesis",
            Self::Consolidation => "consolidation",
            Self::Exploration => "exploration",
        }
    }

    /// Parses a phase name, ignoring case.
    ///
    /// Returns `None` when `s` (after lowercasing) is not one of the canonical
    /// names produced by [`TrajectoryPhase::as_str`].
    // Kept as an inherent, Option-returning method for API compatibility rather
    // than an implementation of `std::str::FromStr`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        [
            Self::Exploration,
            Self::Consolidation,
            Self::Synthesis,
            Self::Debugging,
            Self::Planning,
        ]
        .iter()
        .copied()
        .find(|candidate| candidate.as_str() == wanted)
    }
}
83
/// Cheap lexical features extracted from the text of a single conversation turn.
///
/// All detection is heuristic (substring and line-prefix checks), not parsing.
#[derive(Debug, Clone, Default)]
pub struct TurnFeatures {
    /// Caller-supplied identifier of the turn.
    pub turn_id: u64,
    /// Speaker role exactly as passed in (e.g. `"user"` or `"assistant"`).
    pub role: String,
    /// Length of the content in bytes (not characters).
    pub content_length: usize,
    /// Number of `?` characters in the content.
    pub question_count: usize,
    /// Number of triple-backtick fences divided by two (rounded down).
    pub code_block_count: usize,
    /// Number of lines that look like markdown bullet or numbered list items.
    pub list_item_count: usize,
    /// Whether an error-related keyword occurs (case-insensitive substring).
    pub has_error_keywords: bool,
    /// Whether a decision/conclusion keyword occurs (case-insensitive substring).
    pub has_decision_keywords: bool,
    /// Whether a planning keyword occurs (case-insensitive substring).
    pub has_planning_keywords: bool,
    /// Whether a source-file extension such as `.rs` or `.py` is mentioned.
    pub has_file_references: bool,
    /// Number of whitespace-separated words.
    pub word_count: usize,
    /// Mean words per non-empty fragment when splitting on `.`, `!` and `?`;
    /// `0.0` when there are no such fragments.
    pub avg_sentence_length: f32,
}

impl TurnFeatures {
    /// Extracts [`TurnFeatures`] from the raw text of one turn.
    ///
    /// Keyword checks are case-insensitive substring matches. File-extension
    /// checks are case-sensitive and only count an extension that ends at a
    /// non-word character (or the end of the text). As a result, hostnames and
    /// longer suffixes such as `www.google.com` or `index.html` are not mistaken
    /// for `.go` / `.h` source files.
    pub fn from_content(turn_id: u64, role: &str, content: &str) -> Self {
        // Lowercased once so every keyword test below is case-insensitive.
        let content_lower = content.to_lowercase();

        let question_count = content.chars().filter(|&c| c == '?').count();

        // Each fenced block has an opening and a closing fence; an unmatched
        // trailing fence is dropped by the integer division.
        let code_block_count = content.matches("```").count() / 2;

        let list_item_count = content
            .lines()
            .filter(|line| Self::is_list_item(line))
            .count();

        // Plain substring matches: e.g. "bug" also fires on "debug".
        let error_keywords = [
            "error", "exception", "traceback", "failed", "doesn't work",
            "bug", "fix", "broken", "crash", "panic", "undefined",
        ];
        let has_error_keywords = error_keywords.iter().any(|k| content_lower.contains(k));

        let decision_keywords = [
            "decision:", "let's go with", "we'll use", "decided to",
            "the approach", "final answer", "conclusion", "summary",
        ];
        let has_decision_keywords = decision_keywords.iter().any(|k| content_lower.contains(k));

        let planning_keywords = [
            "phase", "roadmap", "step", "plan", "milestone", "todo",
            "first,", "second,", "third,", "then", "next step",
        ];
        let has_planning_keywords = planning_keywords.iter().any(|k| content_lower.contains(k));

        // `.tsx`, `.jsx` and `.hpp` are listed explicitly because the boundary
        // check no longer lets `.ts` / `.js` / `.h` match them as prefixes.
        let file_extensions = [
            ".rs", ".py", ".ts", ".tsx", ".js", ".jsx", ".go", ".java", ".cpp", ".h", ".hpp",
        ];
        let has_file_references = file_extensions
            .iter()
            .any(|ext| Self::mentions_extension(content, ext));

        let word_count = content.split_whitespace().count();

        let sentences: Vec<&str> = content
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();
        let avg_sentence_length = if sentences.is_empty() {
            0.0
        } else {
            let total_words: usize = sentences.iter().map(|s| s.split_whitespace().count()).sum();
            total_words as f32 / sentences.len() as f32
        };

        Self {
            turn_id,
            role: role.to_string(),
            content_length: content.len(),
            question_count,
            code_block_count,
            list_item_count,
            has_error_keywords,
            has_decision_keywords,
            has_planning_keywords,
            has_file_references,
            word_count,
            avg_sentence_length,
        }
    }

    /// Returns true for list-looking lines (after trimming): `- item`, `* item`,
    /// or one or more ASCII digits followed by `.` or `)` (`1.`, `10)`, ...).
    fn is_list_item(line: &str) -> bool {
        let trimmed = line.trim();
        if trimmed.starts_with("- ") || trimmed.starts_with("* ") {
            return true;
        }
        let after_digits = trimmed.trim_start_matches(|c: char| c.is_ascii_digit());
        // At least one digit must have been stripped, then the list delimiter.
        after_digits.len() < trimmed.len()
            && (after_digits.starts_with('.') || after_digits.starts_with(')'))
    }

    /// Returns true if `ext` occurs in `content` and some occurrence is not
    /// immediately followed by an alphanumeric character or `_`.
    fn mentions_extension(content: &str, ext: &str) -> bool {
        content.match_indices(ext).any(|(start, _)| {
            // `ext` is ASCII, so `start + ext.len()` is a char boundary.
            content[start + ext.len()..]
                .chars()
                .next()
                .map_or(true, |c| !(c.is_alphanumeric() || c == '_'))
        })
    }
}
190
191#[derive(Debug, Clone)]
196pub struct PhaseTransition {
197 pub turn_id: u64,
199 pub from_phase: Option<TrajectoryPhase>,
201 pub to_phase: TrajectoryPhase,
203 pub confidence: f32,
205}
206
/// Tunable thresholds for heuristic phase inference.
#[derive(Debug, Clone)]
pub struct PhaseConfig {
    /// Minimum winning score for a single-turn inference to be reported.
    pub min_confidence: f32,
    /// Content length (bytes) below which a turn counts as short.
    pub short_turn_threshold: usize,
    /// Content length (bytes) above which a turn counts as long.
    pub long_turn_threshold: usize,
    /// Question count at or above which a turn counts as question-heavy.
    pub high_question_threshold: usize,
    /// Version label of this configuration.
    pub version: String,
}

impl Default for PhaseConfig {
    /// The "v1.0" settings: confidence floor 0.3, short < 200 bytes,
    /// long > 1000 bytes, and 2+ questions counted as question-heavy.
    fn default() -> Self {
        let version = String::from("v1.0");
        PhaseConfig {
            min_confidence: 0.3,
            short_turn_threshold: 200,
            long_turn_threshold: 1000,
            high_question_threshold: 2,
            version,
        }
    }
}
233
234#[derive(Debug, Clone)]
236pub struct PhaseInferencer {
237 config: PhaseConfig,
238}
239
240impl PhaseInferencer {
241 pub fn new() -> Self {
243 Self {
244 config: PhaseConfig::default(),
245 }
246 }
247
248 pub fn with_config(config: PhaseConfig) -> Self {
250 Self { config }
251 }
252
253 pub fn infer_single(&self, features: &TurnFeatures) -> Option<(TrajectoryPhase, f32)> {
257 let mut scores: HashMap<TrajectoryPhase, f32> = HashMap::new();
258
259 if features.has_error_keywords {
261 *scores.entry(TrajectoryPhase::Debugging).or_default() += 0.6;
262 }
263
264 if features.has_planning_keywords {
266 *scores.entry(TrajectoryPhase::Planning).or_default() += 0.4;
267 }
268 if features.list_item_count > 3 {
269 *scores.entry(TrajectoryPhase::Planning).or_default() += 0.3;
270 }
271
272 if features.has_decision_keywords {
274 *scores.entry(TrajectoryPhase::Synthesis).or_default() += 0.5;
275 }
276
277 if features.content_length > self.config.long_turn_threshold {
279 *scores.entry(TrajectoryPhase::Consolidation).or_default() += 0.3;
280 }
281 if features.code_block_count > 0 {
282 *scores.entry(TrajectoryPhase::Consolidation).or_default() += 0.3;
283 }
284 if features.has_file_references {
285 *scores.entry(TrajectoryPhase::Consolidation).or_default() += 0.2;
286 }
287
288 if features.role == "user" {
290 if features.question_count >= self.config.high_question_threshold {
291 *scores.entry(TrajectoryPhase::Exploration).or_default() += 0.5;
292 }
293 if features.content_length < self.config.short_turn_threshold {
294 *scores.entry(TrajectoryPhase::Exploration).or_default() += 0.2;
295 }
296 }
297
298 scores.into_iter()
300 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
301 .filter(|(_, score)| *score >= self.config.min_confidence)
302 }
303
304 pub fn infer_sequence(&self, turns: &[TurnFeatures]) -> Vec<(u64, TrajectoryPhase, f32)> {
308 let mut results = Vec::with_capacity(turns.len());
309
310 for (i, features) in turns.iter().enumerate() {
311 let window_start = i.saturating_sub(2);
313 let window_end = (i + 3).min(turns.len());
314 let window = &turns[window_start..window_end];
315
316 let window_questions: usize = window.iter().map(|t| t.question_count).sum();
318 let window_errors = window.iter().filter(|t| t.has_error_keywords).count();
319 let window_code: usize = window.iter().map(|t| t.code_block_count).sum();
320
321 let (mut phase, mut confidence) = self.infer_single(features)
323 .unwrap_or((TrajectoryPhase::Exploration, 0.3));
324
325 if window_errors > 1 {
327 if phase != TrajectoryPhase::Debugging {
328 phase = TrajectoryPhase::Debugging;
329 confidence = 0.5;
330 } else {
331 confidence += 0.1;
332 }
333 }
334
335 if window_questions > 3 && phase != TrajectoryPhase::Debugging {
337 phase = TrajectoryPhase::Exploration;
338 confidence = 0.5;
339 }
340
341 if window_code > 2 && phase == TrajectoryPhase::Exploration {
342 phase = TrajectoryPhase::Consolidation;
343 confidence = 0.4;
344 }
345
346 results.push((features.turn_id, phase, confidence.min(1.0)));
347 }
348
349 results
350 }
351
352 pub fn detect_transitions(&self, turns: &[TurnFeatures]) -> Vec<PhaseTransition> {
354 let phases = self.infer_sequence(turns);
355 let mut transitions = Vec::new();
356
357 let mut prev_phase: Option<TrajectoryPhase> = None;
358 for (turn_id, phase, confidence) in phases {
359 if prev_phase != Some(phase) {
360 transitions.push(PhaseTransition {
361 turn_id,
362 from_phase: prev_phase,
363 to_phase: phase,
364 confidence,
365 });
366 }
367 prev_phase = Some(phase);
368 }
369
370 transitions
371 }
372
373 #[inline]
375 pub fn version(&self) -> &str {
376 &self.config.version
377 }
378}
379
380impl Default for PhaseInferencer {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Every phase variant, for exhaustive round-trip checks.
    const ALL_PHASES: [TrajectoryPhase; 5] = [
        TrajectoryPhase::Exploration,
        TrajectoryPhase::Consolidation,
        TrajectoryPhase::Synthesis,
        TrajectoryPhase::Debugging,
        TrajectoryPhase::Planning,
    ];

    #[test]
    fn test_phase_from_str() {
        assert_eq!(TrajectoryPhase::from_str("exploration"), Some(TrajectoryPhase::Exploration));
        assert_eq!(TrajectoryPhase::from_str("DEBUGGING"), Some(TrajectoryPhase::Debugging));
        assert_eq!(TrajectoryPhase::from_str("unknown"), None);
    }

    #[test]
    fn test_phase_round_trip() {
        for phase in ALL_PHASES.iter() {
            assert_eq!(TrajectoryPhase::from_str(phase.as_str()), Some(*phase));
        }
    }

    #[test]
    fn test_feature_extraction() {
        let content = "What is the error? Can you fix it?\n\n```python\nprint('hello')\n```";
        let features = TurnFeatures::from_content(1, "user", content);

        assert_eq!(features.question_count, 2);
        assert_eq!(features.code_block_count, 1);
        assert!(features.has_error_keywords);
    }

    #[test]
    fn test_exploration_inference() {
        let inferencer = PhaseInferencer::new();
        let features = TurnFeatures::from_content(1, "user", "What is this? How does it work?");

        let (phase, _) = inferencer.infer_single(&features).unwrap();
        assert_eq!(phase, TrajectoryPhase::Exploration);
    }

    #[test]
    fn test_debugging_inference() {
        let inferencer = PhaseInferencer::new();
        let features = TurnFeatures::from_content(1, "user", "I'm getting this error: Traceback... Exception raised");

        let (phase, _) = inferencer.infer_single(&features).unwrap();
        assert_eq!(phase, TrajectoryPhase::Debugging);
    }

    #[test]
    fn test_planning_inference() {
        let inferencer = PhaseInferencer::new();
        let content = r#"
Here's the roadmap:
1. First, implement the parser
2. Second, add validation
3. Third, write tests
4. Finally, deploy
"#;
        let features = TurnFeatures::from_content(1, "assistant", content);

        let (phase, _) = inferencer.infer_single(&features).unwrap();
        assert_eq!(phase, TrajectoryPhase::Planning);
    }

    #[test]
    fn test_sequence_inference() {
        let inferencer = PhaseInferencer::new();
        let turns = vec![
            TurnFeatures::from_content(1, "user", "What does this code do?"),
            TurnFeatures::from_content(2, "assistant", "This code implements a parser..."),
            TurnFeatures::from_content(3, "user", "I get this error: Exception"),
            TurnFeatures::from_content(4, "assistant", "The error is caused by..."),
        ];

        let results = inferencer.infer_sequence(&turns);
        assert_eq!(results.len(), 4);
        let ids: Vec<u64> = results.iter().map(|r| r.0).collect();
        assert_eq!(ids, vec![1, 2, 3, 4]);
        for &(_, _, confidence) in &results {
            assert!(confidence > 0.0 && confidence <= 1.0);
        }
        // The user's error report, flanked by another error turn, is Debugging.
        assert_eq!(results[2].1, TrajectoryPhase::Debugging);

        let transitions = inferencer.detect_transitions(&turns);
        assert!(!transitions.is_empty());
        assert_eq!(transitions[0].turn_id, 1);
        assert_eq!(transitions[0].from_phase, None);
        // Transitions chain: each starts where the previous ended, and always changes phase.
        for pair in transitions.windows(2) {
            assert_eq!(pair[1].from_phase, Some(pair[0].to_phase));
            assert_ne!(pair[1].from_phase, Some(pair[1].to_phase));
        }
    }
}