Skip to main content

albert_runtime/
conversation.rs

1use std::collections::BTreeMap;
2use std::fmt::{Display, Formatter};
3use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
4
5use serde::{Deserialize, Serialize};
6
7use crate::compact::{
8    compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
9};
10use crate::config::RuntimeFeatureConfig;
11use crate::hooks::{HookRunResult, HookRunner};
12use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
13use crate::session::{ContentBlock, ConversationMessage, Session};
14use crate::usage::{TokenUsage, UsageTracker};
15
/// Fallback auto-compaction threshold (in input tokens) used when the
/// environment variable is unset, unparsable, or zero.
const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 200_000;
/// Name of the environment variable that overrides the auto-compaction
/// threshold.
const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub struct ApiRequest {
21    pub system_prompt: Vec<String>,
22    pub messages: Vec<ConversationMessage>,
23}
24
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub enum AssistantEvent {
27    TextDelta(String),
28    ToolUse {
29        id: String,
30        name: String,
31        input: String,
32    },
33    TaskStarted {
34        id: String,
35        label: String,
36    },
37    TaskCompleted {
38        id: String,
39        success: bool,
40    },
41    Usage(TokenUsage),
42    MessageStop,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
46pub struct ToolResult {
47    pub output: String,
48    pub state: i8, // Ternary Intelligence Stack: +1 Success, 0 Neutral/Halt, -1 Failure
49}
50
51pub trait ApiClient {
52    fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
53}
54
55pub trait ToolExecutor {
56    fn execute(&mut self, tool_name: &str, input: &str) -> Result<ToolResult, ToolError>;
57    fn query_memory(&mut self, query: &str) -> Result<String, ToolError>;
58}
59
/// Error reported by a tool; carries only a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Builds an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    /// Writes the message verbatim, ignoring width/alignment flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
81
/// Error returned by the conversation runtime; carries only a human-readable
/// message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Builds an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for RuntimeError {
    /// Writes the message verbatim, ignoring width/alignment flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
103
104#[derive(Debug, Clone, PartialEq, Eq)]
105pub struct TurnSummary {
106    pub assistant_messages: Vec<ConversationMessage>,
107    pub tool_results: Vec<ConversationMessage>,
108    pub iterations: usize,
109    pub usage: TokenUsage,
110    pub auto_compaction: Option<AutoCompactionEvent>,
111}
112
/// Record of an automatic session compaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AutoCompactionEvent {
    /// How many messages the compaction removed; always non-zero when the
    /// event is produced by the runtime.
    pub removed_message_count: usize,
}
117
118pub struct ConversationRuntime<C, T> {
119    session: Session,
120    api_client: C,
121    tool_executor: T,
122    permission_policy: PermissionPolicy,
123    system_prompt: Vec<String>,
124    max_iterations: usize,
125    usage_tracker: UsageTracker,
126    hook_runner: HookRunner,
127    auto_compaction_input_tokens_threshold: u32,
128    cancel_token: Option<Arc<AtomicBool>>,
129}
130
131impl<C, T> ConversationRuntime<C, T>
132where
133    C: ApiClient,
134    T: ToolExecutor,
135{
136    #[must_use]
137    pub fn new(
138        session: Session,
139        api_client: C,
140        tool_executor: T,
141        permission_policy: PermissionPolicy,
142        system_prompt: Vec<String>,
143    ) -> Self {
144        Self::new_with_features(
145            session,
146            api_client,
147            tool_executor,
148            permission_policy,
149            system_prompt,
150            RuntimeFeatureConfig::default(),
151        )
152    }
153
154    #[must_use]
155    pub fn new_with_features(
156        session: Session,
157        api_client: C,
158        tool_executor: T,
159        permission_policy: PermissionPolicy,
160        system_prompt: Vec<String>,
161        feature_config: RuntimeFeatureConfig,
162    ) -> Self {
163        let usage_tracker = UsageTracker::from_session(&session);
164        Self {
165            session,
166            api_client,
167            tool_executor,
168            permission_policy,
169            system_prompt,
170            max_iterations: usize::MAX,
171            usage_tracker,
172            hook_runner: HookRunner::from_feature_config(&feature_config),
173            auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
174            cancel_token: None,
175        }
176    }
177
178    /// Wire up an external cancellation flag. When `true`, `run_turn` exits at
179    /// the next safe checkpoint (between API calls), so `handle.join()` returns
180    /// promptly and the main submit loop can continue.
181    pub fn set_cancel_token(&mut self, token: Arc<AtomicBool>) {
182        self.cancel_token = Some(token);
183    }
184
185    #[inline]
186    fn is_cancelled(&self) -> bool {
187        self.cancel_token.as_ref().map_or(false, |t| t.load(Ordering::Relaxed))
188    }
189
190    #[must_use]
191    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
192        self.max_iterations = max_iterations;
193        self
194    }
195
196    #[must_use]
197    pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
198        self.auto_compaction_input_tokens_threshold = threshold;
199        self
200    }
201
202    pub fn run_turn(
203    &mut self,
204    user_input: impl Into<String>,
205    mut prompter: Option<&mut dyn PermissionPrompter>,
206) -> Result<TurnSummary, RuntimeError> {
207    self.session
208        .messages
209        .push(ConversationMessage::user_text(user_input.into()));
210
211    let mut assistant_messages = Vec::new();
212    let mut tool_results = Vec::new();
213    let mut iterations = 0;
214
215    loop {
216        // Bail out at every safe checkpoint so handle.join() returns promptly on ESC.
217        if self.is_cancelled() {
218            return Err(RuntimeError::new("cancelled"));
219        }
220
221        iterations += 1;
222        if iterations > self.max_iterations {
223            return Err(RuntimeError::new(
224                "conversation loop exceeded the maximum number of iterations",
225            ));
226        }
227
228        let request = ApiRequest {
229            system_prompt: self.system_prompt.clone(),
230            messages: self.session.messages.clone(),
231        };
232        let events = self.api_client.stream(request)?;
233        // Check again immediately after the blocking HTTP call returns.
234        if self.is_cancelled() {
235            return Err(RuntimeError::new("cancelled"));
236        }
237        let (assistant_message, usage) = build_assistant_message(events)?;
238        if let Some(usage) = usage {
239            self.usage_tracker.record(usage);
240        }
241
242        let evaluation = get_consensus_evaluation(&assistant_message);
243
244        match evaluation {
245            1 => {
246                let pending_tool_uses = assistant_message
247                    .blocks
248                    .iter()
249                    .filter_map(|block| match block {
250                        ContentBlock::ToolUse { id, name, input } => {
251                            Some((id.clone(), name.clone(), input.clone()))
252                        }
253                        _ => None,
254                    })
255                    .collect::<Vec<_>>();
256
257                self.session.messages.push(assistant_message.clone());
258                assistant_messages.push(assistant_message);
259
260                if pending_tool_uses.is_empty() {
261                    break;
262                }
263
264                for (tool_use_id, tool_name, input) in pending_tool_uses {
265                    let permission_outcome = if let Some(prompt) = prompter.as_mut() {
266                        self.permission_policy
267                            .authorize(&tool_name, &input, Some(*prompt))
268                    } else {
269                        self.permission_policy.authorize(&tool_name, &input, None)
270                    };
271
272                    let result_message = match permission_outcome {
273                        PermissionOutcome::Allow | PermissionOutcome::AllowWithEdits { .. } => {
274                            let (effective_input, _is_human_edit) = match permission_outcome {
275                                PermissionOutcome::AllowWithEdits { new_input } => (new_input, true),
276                                _ => (input, false),
277                            };
278
279                            let pre_hook_result =
280                                self.hook_runner.run_pre_tool_use(&tool_name, &effective_input);
281                            if pre_hook_result.is_denied() {
282                                let deny_message =
283                                    format!("PreToolUse hook denied tool `{tool_name}`");
284                                ConversationMessage::tool_result(
285                                    tool_use_id,
286                                    tool_name,
287                                    format_hook_message(&pre_hook_result, &deny_message),
288                                    true,
289                                )
290                            } else {
291                                let (output, mut is_error, validation_state) =
292                                    match self.tool_executor.execute(&tool_name, &effective_input) {
293                                        Ok(res) => (res.output, res.state == -1, res.state),
294                                        Err(error) => {
295                                            let err_msg = error.to_string();
296                                            let reflection_prompt = format!(
297                                                "The tool '{}' failed with the following error: {}. Please analyze the error and provide a corrected tool call.",
298                                                tool_name, err_msg
299                                            );
300                                            self.session.messages.push(ConversationMessage::user_text(reflection_prompt));
301                                            return Ok(TurnSummary {
302                                                assistant_messages: assistant_messages.clone(),
303                                                tool_results: tool_results.clone(),
304                                                iterations,
305                                                usage: self.usage_tracker.cumulative_usage(),
306                                                auto_compaction: None,
307                                            });
308                                        },
309                                    };
310
311                                if validation_state == 0 {
312                                    // Neurosymbolic Gap Recovery: Try to resolve state 0 autonomously via local graph
313                                    let mut recovered = false;
314                                    let query_terms: Vec<&str> = effective_input.split(|c: char| !c.is_alphanumeric())
315                                        .filter(|s| s.len() > 3)
316                                        .collect();
317                                    
318                                    for term in query_terms {
319                                        // BET VM: @sparseskip - drop neutral paths and hit memory matrix
320                                        if let Ok(memory_context) = self.tool_executor.query_memory(term) {
321                                            if !memory_context.contains("[]") && memory_context.len() > 10 {
322                                                let recovery_prompt = format!(
323                                                    "AUTONOMOUS RECOVERY (State 0 -> +1):\n\
324                                                     Tool `{tool_name}` halted on ambiguous input. Found matching context in local knowledge graph for `{term}`:\n\
325                                                     {}\n\
326                                                     Please rewrite your tool call using this context to resolve the ambiguity.",
327                                                    memory_context
328                                                );
329                                                self.session.messages.push(ConversationMessage::user_text(recovery_prompt));
330                                                recovered = true;
331                                                break;
332                                            }
333                                        }
334                                    }
335
336                                    if recovered {
337                                        // Allow one more iteration to try and hit +1 state
338                                        continue; 
339                                    }
340
341                                    let halt_msg = format!("Tool `{tool_name}` requested manual authorization or clarification (State 0).");
342                                    let result_msg = ConversationMessage::tool_result(
343                                        tool_use_id,
344                                        tool_name,
345                                        halt_msg,
346                                        true,
347                                    );
348                                    self.session.messages.push(result_msg.clone());
349                                    tool_results.push(result_msg);
350                                    break; // Actually halt if recovery failed
351                                }
352
353                                let mut final_output = merge_hook_feedback(
354                                    pre_hook_result.messages(),
355                                    output,
356                                    false,
357                                );
358
359                                let post_hook_result = self.hook_runner.run_post_tool_use(
360                                    &tool_name,
361                                    &effective_input,
362                                    &final_output,
363                                    is_error,
364                                );
365                                if post_hook_result.is_denied() {
366                                    is_error = true;
367                                }
368                                final_output = merge_hook_feedback(
369                                    post_hook_result.messages(),
370                                    final_output,
371                                    post_hook_result.is_denied(),
372                                );
373
374                                ConversationMessage::tool_result(
375                                    tool_use_id,
376                                    tool_name,
377                                    final_output,
378                                    is_error,
379                                )
380                            }
381                        }
382                        PermissionOutcome::Deny { reason } => {
383                            ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
384                        }
385                    };
386                    self.session.messages.push(result_message.clone());
387                    tool_results.push(result_message);
388                }
389            }
390            0 => {
391                // Request disambiguation from the user
392                self.session.messages.push(ConversationMessage::user_text(
393                    "Could you please clarify your request?".to_string(),
394                ));
395                break;
396            }
397            -1 => {
398                // Generate an alternative plan
399                self.session.messages.push(ConversationMessage::user_text(
400                    "Let me try a different approach.".to_string(),
401                ));
402            }
403            _ => {
404                return Err(RuntimeError::new("invalid consensus evaluation"));
405            }
406        }
407    }
408
409    let auto_compaction = self.maybe_auto_compact();
410
411    Ok(TurnSummary {
412        assistant_messages,
413        tool_results,
414        iterations,
415        usage: self.usage_tracker.cumulative_usage(),
416        auto_compaction,
417    })
418}
419
420    #[must_use]
421    pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
422        compact_session(&self.session, config)
423    }
424
425    #[must_use]
426    pub fn estimated_tokens(&self) -> usize {
427        estimate_session_tokens(&self.session)
428    }
429
430    #[must_use]
431    pub fn usage(&self) -> &UsageTracker {
432        &self.usage_tracker
433    }
434
435    #[must_use]
436    pub fn session(&self) -> &Session {
437        &self.session
438    }
439
440    #[must_use]
441    pub fn into_session(self) -> Session {
442        self.session
443    }
444
445    fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
446        if self.usage_tracker.cumulative_usage().input_tokens
447            < self.auto_compaction_input_tokens_threshold
448        {
449            return None;
450        }
451
452        let result = compact_session(
453            &self.session,
454            CompactionConfig {
455                max_estimated_tokens: 0,
456                ..CompactionConfig::default()
457            },
458        );
459
460        if result.removed_message_count == 0 {
461            return None;
462        }
463
464        self.session = result.compacted_session;
465        Some(AutoCompactionEvent {
466            removed_message_count: result.removed_message_count,
467        })
468    }
469}
470
471#[must_use]
472pub fn auto_compaction_threshold_from_env() -> u32 {
473    parse_auto_compaction_threshold(
474        std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
475            .ok()
476            .as_deref(),
477    )
478}
479
480#[must_use]
481fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
482    value
483        .and_then(|raw| raw.trim().parse::<u32>().ok())
484        .filter(|threshold| *threshold > 0)
485        .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
486}
487
488fn get_consensus_evaluation(_reasoning: &ConversationMessage) -> i8 {
489    // For now, we'll just return +1 (Proceed) by default.
490    // This can be replaced with a more sophisticated evaluation logic later.
491    1
492}
493
494fn build_assistant_message(
495    events: Vec<AssistantEvent>,
496) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
497    let mut text = String::new();
498    let mut blocks = Vec::new();
499    let mut finished = false;
500    let mut usage = None;
501
502    for event in events {
503        match event {
504            AssistantEvent::TextDelta(delta) => text.push_str(&delta),
505            AssistantEvent::ToolUse { id, name, input } => {
506                flush_text_block(&mut text, &mut blocks);
507                blocks.push(ContentBlock::ToolUse { id, name, input });
508            }
509            AssistantEvent::TaskStarted { .. } | AssistantEvent::TaskCompleted { .. } => {
510                // Task events are handled by the TUI in real-time and don't affect 
511                // the static ConversationMessage structure.
512            }
513            AssistantEvent::Usage(u) => usage = Some(u),
514
515            AssistantEvent::MessageStop => {
516                finished = true;
517            }
518        }
519    }
520
521    flush_text_block(&mut text, &mut blocks);
522
523    if !finished {
524        return Err(RuntimeError::new(
525            "assistant stream ended without a message stop event",
526        ));
527    }
528
529    Ok((
530        ConversationMessage::assistant_with_usage(blocks, usage),
531        usage,
532    ))
533}
534
535fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
536    if !text.is_empty() {
537        blocks.push(ContentBlock::Text {
538            text: std::mem::take(text),
539        });
540    }
541}
542
543fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
544    if result.messages().is_empty() {
545        fallback.to_string()
546    } else {
547        result.messages().join("\n")
548    }
549}
550
/// Appends hook messages to a tool output.
///
/// With no messages the output is returned untouched. Otherwise a
/// `Hook feedback:` section (or `Hook feedback (denied):` when `denied`) is
/// built from the newline-joined messages; it follows the output after a
/// blank line, or stands alone when the output is blank.
fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
    if messages.is_empty() {
        return output;
    }

    let label = if denied {
        "Hook feedback (denied)"
    } else {
        "Hook feedback"
    };
    let feedback = format!("{label}:\n{}", messages.join("\n"));

    if output.trim().is_empty() {
        feedback
    } else {
        format!("{output}\n\n{feedback}")
    }
}
568
569type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
570
571#[derive(Default)]
572pub struct StaticToolExecutor {
573    handlers: BTreeMap<String, ToolHandler>,
574}
575
576impl StaticToolExecutor {
577    #[must_use]
578    pub fn new() -> Self {
579        Self::default()
580    }
581
582    #[must_use]
583    pub fn register(
584        mut self,
585        tool_name: impl Into<String>,
586        handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
587    ) -> Self {
588        self.handlers.insert(tool_name.into(), Box::new(handler));
589        self
590    }
591}
592
593impl ToolExecutor for StaticToolExecutor {
594    fn execute(&mut self, tool_name: &str, input: &str) -> Result<ToolResult, ToolError> {
595        self.handlers
596            .get_mut(tool_name)
597            .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
598            .map(|output| ToolResult { output, state: 1 })
599    }
600
601    fn query_memory(&mut self, _query: &str) -> Result<String, ToolError> {
602        Ok("[]".to_string())
603    }
604}
605
606#[cfg(test)]
607mod tests {
608    use super::{
609        parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent,
610        AutoCompactionEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
611        DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
612    };
613    use crate::compact::CompactionConfig;
614    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
615    use crate::permissions::{
616        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
617        PermissionRequest,
618    };
619    use crate::prompt::{ProjectContext, SystemPromptBuilder};
620    use crate::session::{ContentBlock, MessageRole, Session};
621    use crate::usage::TokenUsage;
622    use std::path::PathBuf;
623
624    struct ScriptedApiClient {
625        call_count: usize,
626    }
627
628    impl ApiClient for ScriptedApiClient {
629        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
630            self.call_count += 1;
631            match self.call_count {
632                1 => {
633                    assert!(request
634                        .messages
635                        .iter()
636                        .any(|message| message.role == MessageRole::User));
637                    Ok(vec![
638                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
639                        AssistantEvent::ToolUse {
640                            id: "tool-1".to_string(),
641                            name: "add".to_string(),
642                            input: "2,2".to_string(),
643                        },
644                        AssistantEvent::Usage(TokenUsage {
645                            input_tokens: 20,
646                            output_tokens: 6,
647                            cache_creation_input_tokens: 1,
648                            cache_read_input_tokens: 2,
649                        }),
650                        AssistantEvent::MessageStop,
651                    ])
652                }
653                2 => {
654                    let last_message = request
655                        .messages
656                        .last()
657                        .expect("tool result should be present");
658                    assert_eq!(last_message.role, MessageRole::Tool);
659                    Ok(vec![
660                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
661                        AssistantEvent::Usage(TokenUsage {
662                            input_tokens: 24,
663                            output_tokens: 4,
664                            cache_creation_input_tokens: 1,
665                            cache_read_input_tokens: 3,
666                        }),
667                        AssistantEvent::MessageStop,
668                    ])
669                }
670                _ => Err(RuntimeError::new("unexpected extra API call")),
671            }
672        }
673    }
674
675    struct PromptAllowOnce;
676
677    impl PermissionPrompter for PromptAllowOnce {
678        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
679            assert_eq!(request.tool_name, "add");
680            PermissionPromptDecision::Allow
681        }
682    }
683
684    #[test]
685    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
686        let api_client = ScriptedApiClient { call_count: 0 };
687        let tool_executor = StaticToolExecutor::new().register("add", |input| {
688            let total = input
689                .split(',')
690                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
691                .sum::<i32>();
692            Ok(total.to_string())
693        });
694        let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
695        let system_prompt = SystemPromptBuilder::new()
696            .with_project_context(ProjectContext {
697                cwd: PathBuf::from("/tmp/project"),
698                current_date: "2026-03-31".to_string(),
699                git_status: None,
700                git_diff: None,
701                instruction_files: Vec::new(),
702            })
703            .with_os("linux", "6.8")
704            .build();
705        let mut runtime = ConversationRuntime::new(
706            Session::new(),
707            api_client,
708            tool_executor,
709            permission_policy,
710            system_prompt,
711        );
712
713        let summary = runtime
714            .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
715            .expect("conversation loop should succeed");
716
717        assert_eq!(summary.iterations, 2);
718        assert_eq!(summary.assistant_messages.len(), 2);
719        assert_eq!(summary.tool_results.len(), 1);
720        assert_eq!(runtime.session().messages.len(), 4);
721        assert_eq!(summary.usage.output_tokens, 10);
722        assert_eq!(summary.auto_compaction, None);
723        assert!(matches!(
724            runtime.session().messages[1].blocks[1],
725            ContentBlock::ToolUse { .. }
726        ));
727        assert!(matches!(
728            runtime.session().messages[2].blocks[0],
729            ContentBlock::ToolResult {
730                is_error: false,
731                ..
732            }
733        ));
734    }
735
736    #[test]
737    fn records_denied_tool_results_when_prompt_rejects() {
738        struct RejectPrompter;
739        impl PermissionPrompter for RejectPrompter {
740            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
741                PermissionPromptDecision::Deny {
742                    reason: "not now".to_string(),
743                }
744            }
745        }
746
747        struct SingleCallApiClient;
748        impl ApiClient for SingleCallApiClient {
749            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
750                if request
751                    .messages
752                    .iter()
753                    .any(|message| message.role == MessageRole::Tool)
754                {
755                    return Ok(vec![
756                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
757                        AssistantEvent::MessageStop,
758                    ]);
759                }
760                Ok(vec![
761                    AssistantEvent::ToolUse {
762                        id: "tool-1".to_string(),
763                        name: "blocked".to_string(),
764                        input: "secret".to_string(),
765                    },
766                    AssistantEvent::MessageStop,
767                ])
768            }
769        }
770
771        let mut runtime = ConversationRuntime::new(
772            Session::new(),
773            SingleCallApiClient,
774            StaticToolExecutor::new(),
775            PermissionPolicy::new(PermissionMode::WorkspaceWrite),
776            vec!["system".to_string()],
777        );
778
779        let summary = runtime
780            .run_turn("use the tool", Some(&mut RejectPrompter))
781            .expect("conversation should continue after denied tool");
782
783        assert_eq!(summary.tool_results.len(), 1);
784        assert!(matches!(
785            &summary.tool_results[0].blocks[0],
786            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
787        ));
788    }
789
790    #[test]
791    fn denies_tool_use_when_pre_tool_hook_blocks() {
792        struct SingleCallApiClient;
793        impl ApiClient for SingleCallApiClient {
794            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
795                if request
796                    .messages
797                    .iter()
798                    .any(|message| message.role == MessageRole::Tool)
799                {
800                    return Ok(vec![
801                        AssistantEvent::TextDelta("blocked".to_string()),
802                        AssistantEvent::MessageStop,
803                    ]);
804                }
805                Ok(vec![
806                    AssistantEvent::ToolUse {
807                        id: "tool-1".to_string(),
808                        name: "blocked".to_string(),
809                        input: r#"{"path":"secret.txt"}"#.to_string(),
810                    },
811                    AssistantEvent::MessageStop,
812                ])
813            }
814        }
815
816        let mut runtime = ConversationRuntime::new_with_features(
817            Session::new(),
818            SingleCallApiClient,
819            StaticToolExecutor::new().register("blocked", |_input| {
820                panic!("tool should not execute when hook denies")
821            }),
822            PermissionPolicy::new(PermissionMode::DangerFullAccess),
823            vec!["system".to_string()],
824            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
825                vec![shell_snippet("printf 'blocked by hook'; exit 2")],
826                Vec::new(),
827            )),
828        );
829
830        let summary = runtime
831            .run_turn("use the tool", None)
832            .expect("conversation should continue after hook denial");
833
834        assert_eq!(summary.tool_results.len(), 1);
835        let ContentBlock::ToolResult {
836            is_error, output, ..
837        } = &summary.tool_results[0].blocks[0]
838        else {
839            panic!("expected tool result block");
840        };
841        assert!(
842            *is_error,
843            "hook denial should produce an error result: {output}"
844        );
845        assert!(
846            output.contains("denied tool") || output.contains("blocked by hook"),
847            "unexpected hook denial output: {output:?}"
848        );
849    }
850
851    #[test]
852    fn appends_post_tool_hook_feedback_to_tool_result() {
853        struct TwoCallApiClient {
854            calls: usize,
855        }
856
857        impl ApiClient for TwoCallApiClient {
858            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
859                self.calls += 1;
860                match self.calls {
861                    1 => Ok(vec![
862                        AssistantEvent::ToolUse {
863                            id: "tool-1".to_string(),
864                            name: "add".to_string(),
865                            input: r#"{"lhs":2,"rhs":2}"#.to_string(),
866                        },
867                        AssistantEvent::MessageStop,
868                    ]),
869                    2 => {
870                        assert!(request
871                            .messages
872                            .iter()
873                            .any(|message| message.role == MessageRole::Tool));
874                        Ok(vec![
875                            AssistantEvent::TextDelta("done".to_string()),
876                            AssistantEvent::MessageStop,
877                        ])
878                    }
879                    _ => Err(RuntimeError::new("unexpected extra API call")),
880                }
881            }
882        }
883
884        let mut runtime = ConversationRuntime::new_with_features(
885            Session::new(),
886            TwoCallApiClient { calls: 0 },
887            StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
888            PermissionPolicy::new(PermissionMode::DangerFullAccess),
889            vec!["system".to_string()],
890            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
891                vec![shell_snippet("printf 'pre hook ran'")],
892                vec![shell_snippet("printf 'post hook ran'")],
893            )),
894        );
895
896        let summary = runtime
897            .run_turn("use add", None)
898            .expect("tool loop succeeds");
899
900        assert_eq!(summary.tool_results.len(), 1);
901        let ContentBlock::ToolResult {
902            is_error, output, ..
903        } = &summary.tool_results[0].blocks[0]
904        else {
905            panic!("expected tool result block");
906        };
907        assert!(
908            !*is_error,
909            "post hook should preserve non-error result: {output:?}"
910        );
911        assert!(
912            output.contains("4"),
913            "tool output missing value: {output:?}"
914        );
915        assert!(
916            output.contains("pre hook ran"),
917            "tool output missing pre hook feedback: {output:?}"
918        );
919        assert!(
920            output.contains("post hook ran"),
921            "tool output missing post hook feedback: {output:?}"
922        );
923    }
924
925    #[test]
926    fn reconstructs_usage_tracker_from_restored_session() {
927        struct SimpleApi;
928        impl ApiClient for SimpleApi {
929            fn stream(
930                &mut self,
931                _request: ApiRequest,
932            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
933                Ok(vec![
934                    AssistantEvent::TextDelta("done".to_string()),
935                    AssistantEvent::MessageStop,
936                ])
937            }
938        }
939
940        let mut session = Session::new();
941        session
942            .messages
943            .push(crate::session::ConversationMessage::assistant_with_usage(
944                vec![ContentBlock::Text {
945                    text: "earlier".to_string(),
946                }],
947                Some(TokenUsage {
948                    input_tokens: 11,
949                    output_tokens: 7,
950                    cache_creation_input_tokens: 2,
951                    cache_read_input_tokens: 1,
952                }),
953            ));
954
955        let runtime = ConversationRuntime::new(
956            session,
957            SimpleApi,
958            StaticToolExecutor::new(),
959            PermissionPolicy::new(PermissionMode::DangerFullAccess),
960            vec!["system".to_string()],
961        );
962
963        assert_eq!(runtime.usage().turns(), 1);
964        assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
965    }
966
967    #[test]
968    fn compacts_session_after_turns() {
969        struct SimpleApi;
970        impl ApiClient for SimpleApi {
971            fn stream(
972                &mut self,
973                _request: ApiRequest,
974            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
975                Ok(vec![
976                    AssistantEvent::TextDelta("done".to_string()),
977                    AssistantEvent::MessageStop,
978                ])
979            }
980        }
981
982        let mut runtime = ConversationRuntime::new(
983            Session::new(),
984            SimpleApi,
985            StaticToolExecutor::new(),
986            PermissionPolicy::new(PermissionMode::DangerFullAccess),
987            vec!["system".to_string()],
988        );
989        runtime.run_turn("a", None).expect("turn a");
990        runtime.run_turn("b", None).expect("turn b");
991        runtime.run_turn("c", None).expect("turn c");
992
993        let result = runtime.compact(CompactionConfig {
994            preserve_recent_messages: 2,
995            max_estimated_tokens: 1,
996        });
997        assert!(result.summary.contains("Conversation summary"));
998        assert_eq!(
999            result.compacted_session.messages[0].role,
1000            MessageRole::System
1001        );
1002    }
1003
1004    #[cfg(windows)]
1005    fn shell_snippet(script: &str) -> String {
1006        script.replace('\'', "\"")
1007    }
1008
1009    #[cfg(not(windows))]
1010    fn shell_snippet(script: &str) -> String {
1011        script.to_string()
1012    }
1013
1014    #[test]
1015    fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
1016        struct SimpleApi;
1017        impl ApiClient for SimpleApi {
1018            fn stream(
1019                &mut self,
1020                _request: ApiRequest,
1021            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
1022                Ok(vec![
1023                    AssistantEvent::TextDelta("done".to_string()),
1024                    AssistantEvent::Usage(TokenUsage {
1025                        input_tokens: 120_000,
1026                        output_tokens: 4,
1027                        cache_creation_input_tokens: 0,
1028                        cache_read_input_tokens: 0,
1029                    }),
1030                    AssistantEvent::MessageStop,
1031                ])
1032            }
1033        }
1034
1035        let session = Session {
1036            version: 1,
1037            messages: vec![
1038                crate::session::ConversationMessage::user_text("one"),
1039                crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
1040                    text: "two".to_string(),
1041                }]),
1042                crate::session::ConversationMessage::user_text("three"),
1043                crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
1044                    text: "four".to_string(),
1045                }]),
1046            ],
1047        };
1048
1049        let mut runtime = ConversationRuntime::new(
1050            session,
1051            SimpleApi,
1052            StaticToolExecutor::new(),
1053            PermissionPolicy::new(PermissionMode::DangerFullAccess),
1054            vec!["system".to_string()],
1055        )
1056        .with_auto_compaction_input_tokens_threshold(100_000);
1057
1058        let summary = runtime
1059            .run_turn("trigger", None)
1060            .expect("turn should succeed");
1061
1062        assert_eq!(
1063            summary.auto_compaction,
1064            Some(AutoCompactionEvent {
1065                removed_message_count: 2,
1066            })
1067        );
1068        assert_eq!(runtime.session().messages[0].role, MessageRole::System);
1069    }
1070
1071    #[test]
1072    fn skips_auto_compaction_below_threshold() {
1073        struct SimpleApi;
1074        impl ApiClient for SimpleApi {
1075            fn stream(
1076                &mut self,
1077                _request: ApiRequest,
1078            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
1079                Ok(vec![
1080                    AssistantEvent::TextDelta("done".to_string()),
1081                    AssistantEvent::Usage(TokenUsage {
1082                        input_tokens: 99_999,
1083                        output_tokens: 4,
1084                        cache_creation_input_tokens: 0,
1085                        cache_read_input_tokens: 0,
1086                    }),
1087                    AssistantEvent::MessageStop,
1088                ])
1089            }
1090        }
1091
1092        let mut runtime = ConversationRuntime::new(
1093            Session::new(),
1094            SimpleApi,
1095            StaticToolExecutor::new(),
1096            PermissionPolicy::new(PermissionMode::DangerFullAccess),
1097            vec!["system".to_string()],
1098        )
1099        .with_auto_compaction_input_tokens_threshold(100_000);
1100
1101        let summary = runtime
1102            .run_turn("trigger", None)
1103            .expect("turn should succeed");
1104        assert_eq!(summary.auto_compaction, None);
1105        assert_eq!(runtime.session().messages.len(), 2);
1106    }
1107
1108    #[test]
1109    fn auto_compaction_threshold_defaults_and_parses_values() {
1110        assert_eq!(
1111            parse_auto_compaction_threshold(None),
1112            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
1113        );
1114        assert_eq!(parse_auto_compaction_threshold(Some("4321")), 4321);
1115        assert_eq!(
1116            parse_auto_compaction_threshold(Some("not-a-number")),
1117            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
1118        );
1119    }
1120}