matrixcode-core 0.4.5

MatrixCode Agent Core - Pure logic, no UI
Documentation
//! Agent run loop and public methods.

use anyhow::Result;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use tokio::sync::mpsc;

use crate::approval::ApproveMode;
use crate::cancel::CancellationToken;
use crate::compress::{
    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
};
use crate::event::{AgentEvent, EventData, EventType};
use crate::prompt;
use crate::providers::{ChatRequest, Message, MessageContent, Role};
use crate::tools::ToolDefinition;

use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};

impl Agent {
    pub(crate) fn new(builder: AgentBuilder) -> Self {
        let event_tx = builder.event_tx.unwrap_or_else(|| {
            let (tx, _) = mpsc::channel(100);
            tx
        });

        Self {
            provider: builder.provider,
            model_name: builder.model_name,
            tools: builder.tools,
            messages: Vec::new(),
            system_prompt: builder.system_prompt,
            max_tokens: builder.max_tokens,
            think: builder.think,
            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
            event_tx,
            skills: builder.skills,
            profile: builder.profile,
            project_overview: builder.project_overview,
            memory_summary: builder.memory_summary,
            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
            cancel_token: None,
            compression_config: crate::compress::CompressionConfig::default(),
            ask_rx: None,
        }
    }

    /// Get event sender for streaming
    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
        self.event_tx.clone()
    }

    /// Set ask response channel (for TUI mode)
    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
        self.ask_rx = Some(rx);
    }

    /// Set cancellation token
    pub fn set_cancel_token(&mut self, token: CancellationToken) {
        self.cancel_token = Some(token);
    }

    /// Set approve mode at runtime
    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
        log::info!("Agent approve mode changed: {} -> {}", old, mode);
        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
    }

    /// Get a shared reference to the approve mode atomic.
    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
        self.approve_mode.clone()
    }

    /// Replace the internal approve mode with an externally-created shared atomic.
    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
        self.approve_mode = shared;
    }

    /// Update memory summary and rebuild system prompt.
    pub fn update_memory_summary(&mut self, summary: Option<String>) {
        self.memory_summary = summary;
        self.system_prompt = prompt::build_system_prompt(
            &self.profile,
            &self.skills,
            self.project_overview.as_deref(),
            self.memory_summary.as_deref(),
        );
    }

    /// Run chat loop with tool execution (streaming version).
    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
        self.emit(AgentEvent::session_started())?;

        self.messages.push(Message {
            role: Role::User,
            content: MessageContent::Text(user_input.clone()),
        });

        let mut iterations = 0;
        let mut should_continue = true;

        while should_continue && iterations < MAX_ITERATIONS {
            iterations += 1;

            if let Some(token) = &self.cancel_token
                && token.is_cancelled()
            {
                self.emit(AgentEvent::error(
                    "Operation cancelled".to_string(),
                    None,
                    None,
                ))?;
                break;
            }

            let tool_defs: Vec<ToolDefinition> =
                self.tools.iter().map(|t| t.definition()).collect();
            let request = ChatRequest {
                system: Some(self.system_prompt.clone()),
                messages: self.messages.clone(),
                max_tokens: self.max_tokens,
                tools: tool_defs,
                think: self.think,
                enable_caching: true,
                server_tools: Vec::new(),
            };

            let response = self.call_streaming(&request).await?;

            self.track_usage(&response.usage);

            crate::debug::debug_log().api_call(
                &self.model_name,
                response.usage.input_tokens,
                response.usage.cache_read_input_tokens > 0,
            );

            should_continue = self.process_response(&response).await?;

            let context_size = self.provider.context_size();
            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
            let estimated_tokens = estimate_total_tokens(&self.messages);

            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
                api_tokens
            } else {
                estimated_tokens
            };

            crate::debug::debug_log().log(
                "compression",
                &format!(
                    "check: api={}, estimated={}, using={}, context={}, threshold={}",
                    api_tokens,
                    estimated_tokens,
                    current_tokens,
                    context_size.unwrap_or(0),
                    self.compression_config.threshold
                ),
            );

            if should_compress(current_tokens, context_size, &self.compression_config) {
                self.emit(AgentEvent::progress("Compressing context...", None))?;

                let original_tokens = current_tokens;

                match compress_messages(
                    &self.messages,
                    CompressionStrategy::SlidingWindow,
                    &self.compression_config,
                ) {
                    Ok(compressed) => {
                        let compressed_tokens = estimate_total_tokens(&compressed);
                        self.messages = compressed;
                        self.total_input_tokens
                            .store(compressed_tokens as u64, Ordering::Relaxed);
                        self.last_input_tokens
                            .store(compressed_tokens as u64, Ordering::Relaxed);

                        let ratio = compressed_tokens as f32 / original_tokens as f32;
                        crate::debug::debug_log().compression(
                            original_tokens,
                            compressed_tokens,
                            ratio,
                        );

                        self.emit(AgentEvent::with_data(
                            EventType::CompressionCompleted,
                            EventData::Compression {
                                original_tokens: original_tokens as u64,
                                compressed_tokens: compressed_tokens as u64,
                                ratio: compressed_tokens as f32 / original_tokens as f32,
                            },
                        ))?;
                    }
                    Err(e) => {
                        self.emit(AgentEvent::progress(
                            format!("Compression failed: {}", e),
                            None,
                        ))?;
                    }
                }
            }
        }

        self.emit(AgentEvent::usage_with_cache(
            self.total_input_tokens.load(Ordering::Relaxed),
            self.total_output_tokens.load(Ordering::Relaxed),
            0,
            0,
        ))?;

        self.emit(AgentEvent::session_ended())?;

        Ok(Vec::new())
    }

    /// Restore message history (for session continue/resume)
    pub fn set_messages(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }

    /// Get current messages (for session saving)
    pub fn get_messages(&self) -> &[Message] {
        &self.messages
    }

    /// Get current token counts
    pub fn get_token_counts(&self) -> (u64, u64) {
        (
            self.total_input_tokens.load(Ordering::Relaxed),
            self.total_output_tokens.load(Ordering::Relaxed),
        )
    }

    /// Clear message history
    pub fn clear_history(&mut self) {
        self.messages.clear();
        self.total_input_tokens.store(0, Ordering::Relaxed);
        self.total_output_tokens.store(0, Ordering::Relaxed);
        self.last_input_tokens.store(0, Ordering::Relaxed);
    }

    /// Get message count
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }
}