Skip to main content

mermaid_cli/domain/
transition.rs

//! Helpers that enforce invariants during turn-state transitions.
//!
//! The reducer calls these so that the type system, rather than a
//! comment or a convention, guarantees that you can't transition to
//! the follow-up model call with missing tool outcomes, commit a
//! partial assistant message that is still streaming, or drop a
//! thinking signature that the next request needs.
//!
//! Everything here is synchronous and performs no I/O. The only
//! ambient input is the wall clock, which `start_generating` and
//! `commit_assistant_message` read to timestamp what they build.
10
11use std::time::SystemTime;
12
13use crate::models::tool_call::ToolCall as ModelToolCall;
14use crate::models::{ChatMessage, MessageRole};
15
16use super::action::{ActionDetails, ActionDisplay, ActionResult};
17use super::ids::{ToolCallId, TurnId};
18use super::runtime::ToolMetadata;
19use super::state::{GenPhase, PendingToolCall, ToolOutcome, TurnState};
20
21/// Flatten `Vec<Option<ToolOutcome>>` into `Vec<ToolOutcome>` iff
22/// every slot is populated. `None` means "still waiting on at least
23/// one tool" — the reducer stays in `ExecutingTools` and drops the
24/// event without state change.
25///
26/// This is the single gate between `ExecutingTools` and the follow-up
27/// `Generating`. It's impossible to bypass: there is no public
28/// constructor for `Vec<ToolOutcome>` elsewhere in the codebase, and
29/// the follow-up transition's builder function takes `Vec<ToolOutcome>`
30/// by value.
31pub fn try_complete_outcomes(outcomes: &[Option<ToolOutcome>]) -> Option<Vec<ToolOutcome>> {
32    let mut out = Vec::with_capacity(outcomes.len());
33    for slot in outcomes {
34        match slot {
35            Some(o) => out.push(o.clone()),
36            None => return None,
37        }
38    }
39    Some(out)
40}
41
42/// Write the outcome for a specific tool call ID into the slot
43/// carrying that call. Returns `true` if the slot was found and empty;
44/// `false` if the call isn't pending (stale event) or was already
45/// filled (duplicate event — first write wins).
46pub fn fill_outcome(
47    calls: &[PendingToolCall],
48    outcomes: &mut [Option<ToolOutcome>],
49    call_id: ToolCallId,
50    outcome: ToolOutcome,
51) -> bool {
52    debug_assert_eq!(
53        calls.len(),
54        outcomes.len(),
55        "calls and outcomes must be aligned"
56    );
57    let Some(idx) = calls.iter().position(|c| c.call_id == call_id) else {
58        return false;
59    };
60    if outcomes[idx].is_some() {
61        return false;
62    }
63    outcomes[idx] = Some(outcome);
64    true
65}
66
67/// Transition `Idle → Generating`. Always pure: the caller builds a
68/// `ChatRequest` separately and returns it to the reducer as a `Cmd`.
69pub fn start_generating(id: TurnId) -> TurnState {
70    TurnState::Generating {
71        id,
72        started: SystemTime::now(),
73        partial_text: String::new(),
74        partial_reasoning: String::new(),
75        tokens: 0,
76        phase: GenPhase::Sending,
77        thinking_signature: None,
78        pending_tool_calls: Vec::new(),
79    }
80}
81
82/// Transition `Generating → ExecutingTools`. Allocates `None` slots
83/// for every call so the invariant ("outcomes.len() == calls.len()")
84/// is upheld by construction.
85pub fn start_executing_tools(id: TurnId, calls: Vec<PendingToolCall>) -> TurnState {
86    let outcomes = vec![None; calls.len()];
87    TurnState::ExecutingTools {
88        id,
89        calls,
90        outcomes,
91    }
92}
93
94/// Build the committed assistant message from a `Generating` state's
95/// accumulated content. Safe to call with empty text (the model might
96/// have responded with only tool calls). Returns the message plus the
97/// thinking signature so the reducer can record it separately for
98/// Anthropic round-trip.
99pub fn commit_assistant_message(
100    partial_text: String,
101    partial_reasoning: String,
102    tool_calls: Vec<ModelToolCall>,
103    thinking_signature: Option<String>,
104) -> ChatMessage {
105    let thinking = if partial_reasoning.is_empty() {
106        None
107    } else {
108        Some(partial_reasoning)
109    };
110    let mut msg = ChatMessage {
111        role: MessageRole::Assistant,
112        content: partial_text,
113        timestamp: chrono::Local::now(),
114        kind: crate::models::ChatMessageKind::Normal,
115        metadata: None,
116        actions: Vec::new(),
117        thinking,
118        images: None,
119        tool_calls: if tool_calls.is_empty() {
120            None
121        } else {
122            Some(tool_calls)
123        },
124        tool_call_id: None,
125        tool_name: None,
126        thinking_signature: None,
127    };
128    if let Some(sig) = thinking_signature {
129        msg = msg.with_thinking_signature(sig);
130    }
131    msg
132}
133
134/// Build the follow-up `tool` role messages from completed outcomes.
135/// The OpenAI-compatible wire format requires (tool_call_id, tool_name,
136/// content) — we pull name from the original call.
137pub fn tool_result_messages(
138    calls: &[PendingToolCall],
139    outcomes: Vec<ToolOutcome>,
140) -> Vec<ChatMessage> {
141    debug_assert_eq!(calls.len(), outcomes.len());
142    calls
143        .iter()
144        .zip(outcomes)
145        .map(|(call, outcome)| {
146            let tool_call_id = call
147                .source
148                .id
149                .clone()
150                .unwrap_or_else(|| format!("call_{}", call.call_id.0));
151            ChatMessage::tool(
152                tool_call_id,
153                call.source.function.name.clone(),
154                outcome.as_tool_message_content(),
155            )
156        })
157        .collect()
158}
159
160/// Convert a completed tool outcome into an `ActionDisplay` entry
161/// attached to the assistant message that triggered the call. Used so
162/// the chat renderer can show "Read main.rs → 1,234 bytes" etc.
163pub fn action_display_for(call: &PendingToolCall, outcome: &ToolOutcome) -> ActionDisplay {
164    let (action_type, target) = display_info_for(call);
165    let duration = outcome.duration_secs;
166    let result = if outcome.is_success() {
167        ActionResult::Success {
168            output: outcome.output().to_string(),
169            images: outcome.images(),
170        }
171    } else {
172        ActionResult::Error {
173            error: outcome.error_message().unwrap_or("[cancelled]").to_string(),
174        }
175    };
176    let details = action_details_for(call, outcome, duration);
177    ActionDisplay {
178        action_type,
179        target,
180        result,
181        details,
182        duration_seconds: duration,
183        metadata: Some((*outcome.metadata).clone()),
184    }
185}
186
187fn action_details_for(
188    call: &PendingToolCall,
189    outcome: &ToolOutcome,
190    duration: Option<f64>,
191) -> ActionDetails {
192    if !outcome.is_success() {
193        return ActionDetails::Simple;
194    }
195
196    let name = call.source.function.name.as_str();
197    let args = &call.source.function.arguments;
198    match name {
199        "read_file" => {
200            let line_count = outcome
201                .metadata
202                .line_count
203                .or_else(|| metadata_line_count(&outcome.metadata.detail))
204                .unwrap_or_else(|| outcome.output().lines().count());
205            ActionDetails::Preview {
206                text: success_summary(
207                    format!("{} {} read", line_count, pluralize("line", line_count)),
208                    duration,
209                ),
210                line_count: Some(line_count),
211            }
212        },
213        "write_file" => {
214            let content = args
215                .get("content")
216                .and_then(|v| v.as_str())
217                .unwrap_or_default()
218                .to_string();
219            let line_count = outcome
220                .metadata
221                .line_count
222                .or_else(|| metadata_line_count(&outcome.metadata.detail))
223                .unwrap_or_else(|| content.lines().count());
224            ActionDetails::FileContent {
225                line_count,
226                content,
227            }
228        },
229        "web_search" => {
230            let result_count = outcome
231                .metadata
232                .result_count
233                .or_else(|| metadata_result_count(&outcome.metadata.detail))
234                .unwrap_or_else(|| count_search_results(outcome.output()));
235            ActionDetails::Preview {
236                text: success_summary(
237                    format!(
238                        "{} {} returned",
239                        result_count,
240                        pluralize("result", result_count)
241                    ),
242                    duration,
243                ),
244                line_count: None,
245            }
246        },
247        "web_fetch" => {
248            let line_count = outcome
249                .metadata
250                .line_count
251                .or_else(|| metadata_line_count(&outcome.metadata.detail))
252                .unwrap_or_else(|| outcome.output().lines().count());
253            ActionDetails::Preview {
254                text: success_summary(
255                    format!("{} {} fetched", line_count, pluralize("line", line_count)),
256                    duration,
257                ),
258                line_count: Some(line_count),
259            }
260        },
261        "execute_command" => ActionDetails::Preview {
262            text: command_success_summary(outcome, duration),
263            line_count: outcome
264                .metadata
265                .line_count
266                .or_else(|| metadata_line_count(&outcome.metadata.detail))
267                .or_else(|| Some(outcome.output().lines().count())),
268        },
269        "edit_file" => ActionDetails::Preview {
270            text: success_summary(outcome.summary.clone(), duration),
271            line_count: None,
272        },
273        _ => ActionDetails::Simple,
274    }
275}
276
277fn metadata_line_count(metadata: &ToolMetadata) -> Option<usize> {
278    match metadata {
279        ToolMetadata::ReadFile { line_count, .. }
280        | ToolMetadata::WriteFile { line_count, .. }
281        | ToolMetadata::WebFetch { line_count, .. } => Some(*line_count),
282        ToolMetadata::ExecuteCommand {
283            stdout_lines,
284            stderr_lines,
285            ..
286        } => Some(stdout_lines + stderr_lines),
287        _ => None,
288    }
289}
290
291fn metadata_result_count(metadata: &ToolMetadata) -> Option<usize> {
292    match metadata {
293        ToolMetadata::WebSearch { result_count, .. } => Some(*result_count),
294        _ => None,
295    }
296}
297
298fn success_summary(detail: String, duration: Option<f64>) -> String {
299    match duration {
300        Some(seconds) => format!("Success, {}, took {}", detail, format_duration(seconds)),
301        None => format!("Success, {}", detail),
302    }
303}
304
305fn command_success_summary(outcome: &ToolOutcome, duration: Option<f64>) -> String {
306    if outcome.metadata.process.is_none() {
307        return success_summary("command completed".to_string(), duration);
308    }
309
310    let mut lines = vec![success_summary(
311        "background process started".to_string(),
312        duration,
313    )];
314    for line in outcome.output().lines().skip(1) {
315        if line.starts_with("--- startup output ---") {
316            break;
317        }
318        if !line.trim().is_empty() {
319            lines.push(line.to_string());
320        }
321    }
322    lines.join("\n")
323}
324
/// Render a duration given in seconds as short display text.
///
/// * Under one second (after rounding to whole milliseconds): whole
///   milliseconds, never less than `1ms`, e.g. `800ms`.
/// * Under ten seconds (after rounding to tenths): one decimal place,
///   e.g. `1.2s`.
/// * Otherwise: whole seconds, e.g. `15s`.
///
/// The bucket is chosen from the rounded value, not the raw one, so a
/// value just below a boundary cannot print in the wrong unit or
/// precision: 0.9996 renders as `1.0s` rather than `1000ms`, and 9.96
/// renders as `10s` rather than `10.0s`.
fn format_duration(seconds: f64) -> String {
    let millis = (seconds * 1000.0).round();
    if millis < 1000.0 {
        // Clamp so a sub-millisecond (or non-positive) duration still
        // shows something visible.
        return format!("{}ms", millis.max(1.0) as u64);
    }

    // `tenths < 100` guarantees the one-decimal rendering stays below
    // "10.0".
    let tenths = (seconds * 10.0).round();
    if tenths < 100.0 {
        format!("{:.1}s", seconds)
    } else {
        format!("{}s", seconds.round() as u64)
    }
}
334
/// Return `word` unchanged when `count` is exactly 1, otherwise `word`
/// with an `s` appended. Zero counts as plural.
fn pluralize(word: &str, count: usize) -> String {
    let mut out = String::from(word);
    if count != 1 {
        out.push('s');
    }
    out
}
342
/// Count result entries in raw web-search output: lines that start
/// with `[` and contain `] Title:`, such as `[1] Title: Example`.
fn count_search_results(output: &str) -> usize {
    let mut hits = 0;
    for line in output.lines() {
        if line.starts_with('[') && line.contains("] Title:") {
            hits += 1;
        }
    }
    hits
}
349
350/// Best-effort name + target extraction from a tool call, for chat
351/// display ("Read src/main.rs", "Bash cargo test", etc). Matches on
352/// the wire-format tool name + arguments; unknown tools fall through
353/// to the raw function name.
354fn display_info_for(call: &PendingToolCall) -> (String, String) {
355    let name = call.source.function.name.as_str();
356    let args = &call.source.function.arguments;
357    let string_arg =
358        |k: &str| -> Option<String> { args.get(k).and_then(|v| v.as_str()).map(str::to_string) };
359    match name {
360        "read_file" => {
361            let target = string_arg("path")
362                .or_else(|| {
363                    args.get("paths")
364                        .and_then(|v| v.as_array())
365                        .map(|a| match a.len() {
366                            0 => "(no paths)".to_string(),
367                            1 => a[0].as_str().unwrap_or("").to_string(),
368                            n => format!("{} files", n),
369                        })
370                })
371                .unwrap_or_default();
372            ("Read".to_string(), target)
373        },
374        "write_file" => ("Write".to_string(), string_arg("path").unwrap_or_default()),
375        "edit_file" => ("Edit".to_string(), string_arg("path").unwrap_or_default()),
376        "delete_file" => ("Delete".to_string(), string_arg("path").unwrap_or_default()),
377        "create_directory" => (
378            "Bash".to_string(),
379            format!("mkdir -p {}", string_arg("path").unwrap_or_default()),
380        ),
381        "execute_command" => (
382            "Bash".to_string(),
383            string_arg("command").unwrap_or_default(),
384        ),
385        "web_search" => {
386            let target = string_arg("query")
387                .or_else(|| {
388                    args.get("queries")
389                        .and_then(|v| v.as_array())
390                        .map(|a| match a.len() {
391                            0 => "(no queries)".to_string(),
392                            1 => a[0]
393                                .get("query")
394                                .and_then(|q| q.as_str())
395                                .unwrap_or("")
396                                .to_string(),
397                            n => format!("{} queries", n),
398                        })
399                })
400                .unwrap_or_default();
401            ("Web Search".to_string(), target)
402        },
403        "web_fetch" => (
404            "Web Fetch".to_string(),
405            string_arg("url").unwrap_or_default(),
406        ),
407        "agent" => (
408            "Agent".to_string(),
409            string_arg("description").unwrap_or_default(),
410        ),
411        n if n.starts_with("mcp__") => {
412            let rest = &n[5..];
413            let target = rest.replacen("__", ":", 1);
414            ("MCP".to_string(), target)
415        },
416        _ => (name.to_string(), String::new()),
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{ManagedProcess, ManagedProcessStatus, ToolMetadata, ToolRunMetadata};
    use crate::models::tool_call::{FunctionCall, ToolCall as ModelToolCall};

    /// Pending call with an empty JSON-object argument payload.
    fn sample_call(id: u64, name: &str) -> PendingToolCall {
        sample_call_args(id, name, serde_json::json!({}))
    }

    /// Pending call whose wire id is `c<id>` and whose function carries
    /// the given name and arguments.
    fn sample_call_args(id: u64, name: &str, arguments: serde_json::Value) -> PendingToolCall {
        let function = FunctionCall {
            name: name.to_string(),
            arguments,
        };
        let source = ModelToolCall {
            id: Some(format!("c{}", id)),
            function,
        };
        PendingToolCall {
            call_id: ToolCallId(id),
            source,
        }
    }

    // --- action_display_for ---

    #[test]
    fn action_display_read_reports_line_count_and_duration() {
        let call = sample_call_args(1, "read_file", serde_json::json!({"path": "src/main.rs"}));
        let outcome = ToolOutcome::success("one\ntwo\nthree\n", "3 lines read", 1.25);

        let action = action_display_for(&call, &outcome);

        assert_eq!(action.action_type, "Read");
        match &action.details {
            ActionDetails::Preview { text, line_count } => {
                assert_eq!(*line_count, Some(3));
                assert!(text.contains("Success, 3 lines read"));
                assert!(text.contains("took 1.2s"));
            },
            other => panic!("expected preview details, got {:?}", other),
        }
    }

    #[test]
    fn action_display_write_carries_file_content_preview_data() {
        let arguments = serde_json::json!({"path": "petal/index.html", "content": "a\nb\n"});
        let call = sample_call_args(1, "write_file", arguments);
        let outcome =
            ToolOutcome::success("Wrote petal/index.html (2 lines)", "2 lines written", 0.05);

        let action = action_display_for(&call, &outcome);

        match &action.details {
            ActionDetails::FileContent {
                line_count,
                content,
            } => {
                assert_eq!(*line_count, 2);
                assert_eq!(content.as_str(), "a\nb\n");
            },
            other => panic!("expected file content details, got {:?}", other),
        }
    }

    #[test]
    fn action_display_web_search_reports_result_count() {
        let call = sample_call_args(1, "web_search", serde_json::json!({"query": "rust"}));
        let output = "[SEARCH_RESULTS]\n[1] Title: A\nURL: https://a.test\nContent:\nA\n---\n[2] Title: B\nURL: https://b.test\nContent:\nB\n---\n";
        let detail = ToolMetadata::WebSearch {
            queries: vec!["rust".to_string()],
            requested_count: 5,
            result_count: 2,
            sources: vec!["https://a.test".to_string(), "https://b.test".to_string()],
        };
        let run_metadata = ToolRunMetadata {
            detail,
            result_count: Some(2),
            ..ToolRunMetadata::default()
        };
        let outcome =
            ToolOutcome::success(output, "2 results returned", 15.2).with_metadata(run_metadata);

        let action = action_display_for(&call, &outcome);

        match &action.details {
            ActionDetails::Preview { text, .. } => {
                assert!(text.contains("Success, 2 results returned"));
                assert!(text.contains("took 15s"));
            },
            other => panic!("expected preview details, got {:?}", other),
        }
        let metadata = action.metadata.expect("metadata");
        assert_eq!(metadata.result_count, Some(2));
        assert_eq!(metadata.duration_secs, Some(15.2));
    }

    #[test]
    fn action_display_background_command_surfaces_pid_and_log() {
        let arguments = serde_json::json!({"command": "npm run dev", "mode": "background"});
        let call = sample_call_args(1, "execute_command", arguments);
        let output = "Background command started.\nPID: 123\nLog: /tmp/mermaid-bg.log\nReady: matched pattern \"Local:\"\nDetected URL: http://127.0.0.1:5173\n\n--- startup output ---\nLocal: http://127.0.0.1:5173";
        let detail = ToolMetadata::ExecuteCommand {
            command: "npm run dev".to_string(),
            working_dir: None,
            exit_code: None,
            timed_out: false,
            background: true,
            stdout_lines: 1,
            stderr_lines: 0,
            detected_urls: vec!["http://127.0.0.1:5173".to_string()],
            pid: Some(123),
            log_path: Some("/tmp/mermaid-bg.log".to_string()),
        };
        let managed = ManagedProcess {
            id: "bg-123".to_string(),
            pid: 123,
            command: "npm run dev".to_string(),
            cwd: None,
            log_path: "/tmp/mermaid-bg.log".to_string(),
            detected_url: Some("http://127.0.0.1:5173".to_string()),
            status: ManagedProcessStatus::Running,
        };
        let run_metadata = ToolRunMetadata {
            detail,
            process: Some(managed),
            ..ToolRunMetadata::default()
        };
        let outcome = ToolOutcome::success(output, "background process started", 0.8)
            .with_metadata(run_metadata);

        let action = action_display_for(&call, &outcome);

        match &action.details {
            ActionDetails::Preview { text, .. } => {
                assert!(text.contains("Success, background process started"));
                assert!(text.contains("PID: 123"));
                assert!(text.contains("Log: /tmp/mermaid-bg.log"));
                assert!(text.contains("Detected URL: http://127.0.0.1:5173"));
                assert!(!text.contains("startup output"));
            },
            other => panic!("expected preview details, got {:?}", other),
        }
        let metadata = action.metadata.expect("metadata");
        let process = metadata.process.expect("process metadata");
        assert_eq!(process.id, "bg-123");
        assert_eq!(process.pid, 123);
        assert_eq!(process.command, "npm run dev");
        assert_eq!(
            process.detected_url.as_deref(),
            Some("http://127.0.0.1:5173")
        );
    }

    // --- try_complete_outcomes ---

    #[test]
    fn try_complete_outcomes_returns_none_on_incomplete() {
        let slots = vec![Some(ToolOutcome::success("a", "a", 0.1)), None];
        assert!(try_complete_outcomes(&slots).is_none());
    }

    #[test]
    fn try_complete_outcomes_returns_vec_on_complete() {
        let slots = vec![
            Some(ToolOutcome::success("a", "a", 0.1)),
            Some(ToolOutcome::cancelled()),
        ];
        let completed = try_complete_outcomes(&slots).expect("every slot is filled");
        assert_eq!(completed.len(), 2);
    }

    // --- fill_outcome ---

    #[test]
    fn fill_outcome_writes_to_correct_slot() {
        let calls = vec![sample_call(1, "read_file"), sample_call(2, "write_file")];
        let mut slots = vec![None, None];

        let accepted = fill_outcome(&calls, &mut slots, ToolCallId(2), ToolOutcome::cancelled());

        assert!(accepted);
        assert!(slots[0].is_none());
        assert!(slots[1].is_some());
    }

    #[test]
    fn fill_outcome_stale_call_id_returns_false() {
        let calls = vec![sample_call(1, "read_file")];
        let mut slots = vec![None];

        let accepted = fill_outcome(
            &calls,
            &mut slots,
            ToolCallId(999),
            ToolOutcome::cancelled(),
        );

        assert!(!accepted);
        assert!(slots[0].is_none());
    }

    #[test]
    fn fill_outcome_duplicate_write_ignored() {
        let calls = vec![sample_call(1, "read_file")];
        let mut slots = vec![Some(ToolOutcome::success("first", "first", 0.0))];

        let accepted = fill_outcome(&calls, &mut slots, ToolCallId(1), ToolOutcome::cancelled());

        assert!(!accepted);
        let kept = slots[0]
            .as_ref()
            .filter(|outcome| outcome.is_success())
            .expect("original outcome was overwritten");
        assert_eq!(kept.output(), "first");
    }

    // --- state constructors ---

    #[test]
    fn start_generating_produces_fresh_sending_phase() {
        let TurnState::Generating {
            phase,
            tokens,
            partial_text,
            ..
        } = start_generating(TurnId(1))
        else {
            panic!("expected Generating");
        };
        assert_eq!(phase, GenPhase::Sending);
        assert_eq!(tokens, 0);
        assert!(partial_text.is_empty());
    }

    #[test]
    fn start_executing_tools_allocates_outcome_slots() {
        let pending = vec![
            sample_call(1, "a"),
            sample_call(2, "b"),
            sample_call(3, "c"),
        ];
        let TurnState::ExecutingTools {
            outcomes, calls, ..
        } = start_executing_tools(TurnId(1), pending)
        else {
            panic!("expected ExecutingTools");
        };
        assert_eq!(outcomes.len(), 3);
        assert_eq!(calls.len(), 3);
        assert!(outcomes.iter().all(|slot| slot.is_none()));
    }

    // --- message builders ---

    #[test]
    fn commit_assistant_message_preserves_thinking_signature() {
        let signature = Some("sig_abc".to_string());
        let message = commit_assistant_message(
            "hello".to_string(),
            "reasoning".to_string(),
            Vec::new(),
            signature,
        );
        assert_eq!(message.content, "hello");
        assert_eq!(message.thinking.as_deref(), Some("reasoning"));
        assert_eq!(message.thinking_signature.as_deref(), Some("sig_abc"));
    }

    #[test]
    fn commit_assistant_message_empty_reasoning_is_none() {
        let message = commit_assistant_message("hi".to_string(), String::new(), Vec::new(), None);
        assert!(message.thinking.is_none());
    }

    #[test]
    fn tool_result_messages_align_call_id_and_name() {
        let calls = vec![sample_call(1, "read_file"), sample_call(2, "write_file")];
        let outcomes = vec![
            ToolOutcome::success("contents", "contents", 0.1),
            ToolOutcome::cancelled(),
        ];

        let msgs = tool_result_messages(&calls, outcomes);

        assert_eq!(msgs.len(), 2);
        let (first, second) = (&msgs[0], &msgs[1]);
        assert_eq!(first.role, MessageRole::Tool);
        assert_eq!(first.tool_call_id.as_deref(), Some("c1"));
        assert_eq!(first.tool_name.as_deref(), Some("read_file"));
        assert_eq!(first.content, "contents");
        assert!(second.content.contains("cancelled"));
    }
}