use crate::conversation::Conversation;
use crate::error::AgentError;
use crate::inference::InferenceEngine;
use crate::permission::{PermissionRequest, PermissionTracker};
use crate::tool::{parse_tool_calls, ToolRegistry};
use llama_cpp_v3::{LlamaBatch, LlamaContext, LlamaSampler};
use std::sync::Arc;
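/// Progress events emitted by [`run_agent_loop`] through its `on_event` callback,
/// e.g. for streaming output to a UI or for logging.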
#[derive(Debug)]
pub enum AgentEvent {
IterationStart { iteration: usize, max_iterations: usize },
TextDelta(String),
ToolStart { name: String, arguments: String },
ToolResult {
name: String,
success: bool,
output: String,
},
PermissionResult { tool: String, allowed: bool },
ContextCompacted {
messages_before: usize,
messages_after: usize,
prompt_tokens: usize,
context_size: u32,
},
Completed { reason: CompletionReason },
Error(String),
}
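/// Why the agent loop finished.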
#[derive(Debug, Clone)]
pub enum CompletionReason {
Done,
MaxIterations,
EndOfSequence,
}
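/// Configuration for [`run_agent_loop`]: iteration and token budgets, sampling
/// parameters, automatic context compaction, and the decode batch size.
/// A `max_iterations` of 0 means no iteration limit; a `temperature` of 0.0 or
/// below selects greedy decoding.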
pub struct AgentLoopConfig {
pub max_iterations: usize,
pub max_tokens_per_completion: usize,
pub temperature: f32,
pub top_k: i32,
pub min_p: f32,
pub repeat_penalty: f32,
pub auto_compact: bool,
pub compaction_threshold_pct: f32,
pub compaction_keep_recent: usize,
pub n_batch: usize,
pub stop_sequences: Vec<String>,
}
impl Default for AgentLoopConfig {
fn default() -> Self {
Self {
max_iterations: 50,
max_tokens_per_completion: 4096,
temperature: 0.7,
top_k: 40,
min_p: 0.01,
repeat_penalty: 1.0,
auto_compact: true,
compaction_threshold_pct: 0.75,
compaction_keep_recent: 4,
n_batch: 512,
stop_sequences: Vec::new(),
}
}
}
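/// Mirror of the tokens currently held in the context's KV cache, used to find the
/// longest shared prefix with the next prompt so only the new suffix is re-decoded.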
pub struct KvCacheState {
tokens: Vec<llama_cpp_sys_v3::llama_token>,
}
impl KvCacheState {
pub fn new() -> Self {
Self { tokens: Vec::new() }
}
pub fn invalidate(&mut self) {
self.tokens.clear();
}
pub fn len(&self) -> usize {
self.tokens.len()
}
pub fn is_empty(&self) -> bool {
self.tokens.is_empty()
}
}
impl Default for KvCacheState {
fn default() -> Self {
Self::new()
}
}
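/// Length of the longest common prefix of two token sequences.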
fn common_prefix_len(
a: &[llama_cpp_sys_v3::llama_token],
b: &[llama_cpp_sys_v3::llama_token],
) -> usize {
a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count()
}
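/// Decodes `tokens` into the context in chunks of at most `n_batch`, placing them at
/// positions starting from `pos_offset`. Logits are requested only for the final
/// token of the overall prompt (`total_prompt_len - 1`).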
fn decode_tokens_chunked(
lib: &Arc<llama_cpp_sys_v3::LlamaLib>,
ctx: &mut LlamaContext,
tokens: &[llama_cpp_sys_v3::llama_token],
pos_offset: usize,
n_batch: usize,
total_prompt_len: usize,
) -> Result<(), AgentError> {
if tokens.is_empty() {
return Ok(());
}
let n_batch = n_batch.max(1);
let n_tokens = tokens.len();
let mut i = 0;
while i < n_tokens {
let end = (i + n_batch).min(n_tokens);
let chunk = &tokens[i..end];
let is_last_chunk = end == n_tokens;
let mut batch = LlamaBatch::new(lib.clone(), chunk.len() as i32 + 1, 0, 1);
for (j, &token) in chunk.iter().enumerate() {
let pos = (pos_offset + i + j) as llama_cpp_sys_v3::llama_pos;
            // Request logits only for the very last token of the full prompt; all other
            // positions are decoded purely to populate the KV cache.
            let logits = is_last_chunk
                && (j == chunk.len() - 1)
                && (pos_offset + i + j == total_prompt_len - 1);
batch.add(token, pos, &[0], logits);
}
ctx.decode(&batch)?;
i = end;
}
Ok(())
}
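/// Encodes a prompt while reusing whatever prefix is already present in the KV
/// cache, and updates `kv_cache` to reflect the new contents. Returns the prompt
/// length in tokens, i.e. the position at which generation continues.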
fn encode_prompt_incremental(
lib: &Arc<llama_cpp_sys_v3::LlamaLib>,
ctx: &mut LlamaContext,
tokens: &[llama_cpp_sys_v3::llama_token],
kv_cache: &mut KvCacheState,
n_batch: usize,
) -> Result<usize, AgentError> {
let prefix_len = common_prefix_len(&kv_cache.tokens, tokens);
    if prefix_len > 0 && prefix_len == kv_cache.tokens.len() {
        // Everything already in the cache is still valid: decode only the new suffix.
        let delta = &tokens[prefix_len..];
        decode_tokens_chunked(lib, ctx, delta, prefix_len, n_batch, tokens.len())?;
    } else if prefix_len > 0 {
        // The prompt diverges partway through the cached tokens: drop the stale tail
        // from the KV cache, then decode everything after the shared prefix.
        ctx.kv_cache_seq_rm(0, prefix_len as llama_cpp_sys_v3::llama_pos, -1);
        let delta = &tokens[prefix_len..];
        decode_tokens_chunked(lib, ctx, delta, prefix_len, n_batch, tokens.len())?;
    } else {
        // No shared prefix: start from an empty cache and decode the whole prompt.
        ctx.kv_cache_clear();
        decode_tokens_chunked(lib, ctx, tokens, 0, n_batch, tokens.len())?;
    }
    // Record the new cache contents (the full prompt); the caller appends generated
    // tokens as they are decoded.
    kv_cache.tokens.clear();
    kv_cache.tokens.extend_from_slice(tokens);
Ok(tokens.len())
}
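/// Drives the agent loop: render the conversation into a prompt, generate a
/// completion, parse and execute any tool calls (subject to permission checks),
/// feed the results back into the conversation, and repeat until the model stops
/// calling tools or the iteration budget is exhausted. Progress is reported via
/// `on_event`, and the context is compacted automatically once the prompt exceeds
/// `compaction_threshold_pct` of the context window.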
pub fn run_agent_loop(
engine: &InferenceEngine,
ctx: &mut LlamaContext,
conversation: &mut Conversation,
tools: &ToolRegistry,
permissions: &mut PermissionTracker,
config: &AgentLoopConfig,
kv_cache: &mut KvCacheState,
mut on_event: impl FnMut(AgentEvent),
) -> Result<(), AgentError> {
let lib = engine.lib();
let model = engine.model();
let n_ctx = engine.config.n_ctx;
    // A `max_iterations` of 0 means "no limit".
    let max_iters = if config.max_iterations == 0 {
usize::MAX
} else {
config.max_iterations
};
for iteration in 0..max_iters {
on_event(AgentEvent::IterationStart {
iteration: iteration + 1,
max_iterations: config.max_iterations,
});
let chat_messages = conversation.to_chat_messages();
let template = engine.config.chat_template.as_deref();
let prompt = model.apply_chat_template(template, &chat_messages, true)?;
let tokens = model.tokenize(&prompt, false, true)?;
        // If the prompt has grown past the configured fraction of the context window
        // and there are older messages that can be folded away, summarize them,
        // rebuild the prompt, and re-tokenize.
        let tokens = if config.auto_compact
&& tokens.len() as f32 > n_ctx as f32 * config.compaction_threshold_pct
&& conversation.compactable_count(config.compaction_keep_recent) > 0
{
let messages_before = conversation.len();
let prompt_tokens = tokens.len();
kv_cache.invalidate();
let summary = generate_compaction_summary(engine, ctx, conversation, config)?;
conversation.compact(&summary, config.compaction_keep_recent);
on_event(AgentEvent::ContextCompacted {
messages_before,
messages_after: conversation.len(),
prompt_tokens,
context_size: n_ctx,
});
let chat_messages = conversation.to_chat_messages();
let template = engine.config.chat_template.as_deref();
let prompt = model.apply_chat_template(template, &chat_messages, true)?;
model.tokenize(&prompt, false, true)?
} else {
tokens
};
        // `n_cur` is the next decode position, starting right after the encoded prompt.
        let mut n_cur =
            encode_prompt_incremental(&lib, ctx, &tokens, kv_cache, config.n_batch)?;
        let sampler = build_sampler(lib.clone(), config);
        let vocab = model.get_vocab();
        let mut generated_text = String::new();
let mut generated_tokens: Vec<llama_cpp_sys_v3::llama_token> = Vec::new();
let mut batch = LlamaBatch::new(lib.clone(), 2, 0, 1);
        // Sample a token, stream it to the caller, then decode it so the KV cache is
        // ready for the next sample.
        for _ in 0..config.max_tokens_per_completion {
let token = sampler.sample(ctx, -1);
sampler.accept(token);
if vocab.is_eog(token) {
break;
}
let piece = model.token_to_piece(token);
on_event(AgentEvent::TextDelta(piece.clone()));
generated_text.push_str(&piece);
generated_tokens.push(token);
batch.clear();
batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
ctx.decode(&batch)?;
n_cur += 1;
            // Stop early once the generated text ends with any configured stop sequence.
            if config
                .stop_sequences
                .iter()
                .any(|stop| generated_text.ends_with(stop.as_str()))
            {
                break;
            }
}
        // The generated tokens were decoded into the KV cache one at a time above, so
        // record them for prefix reuse on the next iteration.
        kv_cache.tokens.extend_from_slice(&generated_tokens);
if tools.is_empty() {
conversation.add_assistant(&generated_text, Vec::new());
on_event(AgentEvent::Completed {
reason: CompletionReason::Done,
});
return Ok(());
}
let (tool_calls, _text_parts) = parse_tool_calls(&generated_text);
conversation.add_assistant(&generated_text, tool_calls.clone());
if tool_calls.is_empty() {
on_event(AgentEvent::Completed {
reason: CompletionReason::Done,
});
return Ok(());
}
for call in &tool_calls {
let args_str =
serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".to_string());
on_event(AgentEvent::ToolStart {
name: call.name.clone(),
arguments: args_str.clone(),
});
            if let Some(tool_impl) = tools.get(&call.name) {
if tool_impl.requires_permission() {
let req = PermissionRequest {
tool_name: call.name.clone(),
description: format!("{}: {}", call.name, args_str),
dangerous: tool_impl.is_dangerous(&call.arguments),
arguments: call.arguments.clone(),
};
let allowed = permissions.check(&req);
on_event(AgentEvent::PermissionResult {
tool: call.name.clone(),
allowed,
});
if !allowed {
let result = crate::tool::ToolResult::err("Permission denied by user");
conversation.add_tool_result(call.clone(), result.clone());
on_event(AgentEvent::ToolResult {
name: call.name.clone(),
success: false,
output: result.output,
});
continue;
}
}
}
let result = tools.execute(call);
match result {
Ok(result) => {
on_event(AgentEvent::ToolResult {
name: call.name.clone(),
success: result.success,
output: result.output.clone(),
});
conversation.add_tool_result(call.clone(), result);
}
Err(e) => {
let result =
crate::tool::ToolResult::err(format!("Tool execution error: {}", e));
on_event(AgentEvent::ToolResult {
name: call.name.clone(),
success: false,
output: result.output.clone(),
});
conversation.add_tool_result(call.clone(), result);
}
}
}
}
on_event(AgentEvent::Completed {
reason: CompletionReason::MaxIterations,
});
Ok(())
}
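/// Instruction prepended to the serialized older history when asking the model for
/// a compaction summary.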
const COMPACTION_PROMPT: &str = "\
Summarize the following conversation history concisely. Preserve:
- The user's goals and what they asked for
- Key decisions and outcomes
- Important file paths, variable names, or technical details mentioned
- Current progress and what still needs to be done
- Any errors encountered and how they were resolved
Be concise but complete. Use bullet points. Do NOT include pleasantries or filler.
Conversation to summarize:
";
/// Generates a summary of the older portion of the conversation (everything before
/// roughly the `compaction_keep_recent` most recent messages, excluding the leading
/// system message) using greedy decoding. Clears the context's KV cache.
fn generate_compaction_summary(
engine: &InferenceEngine,
ctx: &mut LlamaContext,
conversation: &Conversation,
config: &AgentLoopConfig,
) -> Result<String, AgentError> {
let model = engine.model();
let lib = engine.lib();
    // Keep the leading system message (if any) out of the summarized range.
    let start = if !conversation.messages().is_empty()
&& conversation.messages()[0].role == crate::conversation::Role::System
{
1
} else {
0
};
let total = conversation.messages().len();
let keep_from = if total > config.compaction_keep_recent {
total - config.compaction_keep_recent
} else {
start
};
let safe_cut = conversation.find_safe_cut_point(keep_from);
if safe_cut <= start {
return Ok(String::new());
}
let old_text = conversation.serialize_range(start, safe_cut);
let summary_prompt = format!("{}{}", COMPACTION_PROMPT, old_text);
let chat_messages = vec![
llama_cpp_v3::ChatMessage {
role: "system".to_string(),
content: "You are a precise summarizer. Output only the summary, nothing else."
.to_string(),
},
llama_cpp_v3::ChatMessage {
role: "user".to_string(),
content: summary_prompt,
},
];
let template = engine.config.chat_template.as_deref();
let prompt = model.apply_chat_template(template, &chat_messages, true)?;
let tokens = model.tokenize(&prompt, false, true)?;
    // Summarization runs against a cleared context; the caller invalidates its own
    // KV-cache bookkeeping before calling this function.
    ctx.kv_cache_clear();
    decode_tokens_chunked(&lib, ctx, &tokens, 0, config.n_batch, tokens.len())?;
    // Greedy decoding keeps the summary deterministic.
    let mut sampler = LlamaSampler::new_chain(lib.clone(), false);
    sampler.add(LlamaSampler::new_greedy(lib.clone()));
let vocab = model.get_vocab();
let mut summary = String::new();
let mut n_cur = tokens.len();
let max_summary_tokens = 512;
let mut batch = LlamaBatch::new(lib.clone(), 2, 0, 1);
for _ in 0..max_summary_tokens {
let token = sampler.sample(ctx, -1);
sampler.accept(token);
if vocab.is_eog(token) {
break;
}
let piece = model.token_to_piece(token);
summary.push_str(&piece);
batch.clear();
batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
ctx.decode(&batch)?;
n_cur += 1;
}
Ok(summary.trim().to_string())
}
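/// Builds the sampler chain for agent completions from the sampling parameters in
/// `config`.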
fn build_sampler(
lib: Arc<llama_cpp_sys_v3::LlamaLib>,
config: &AgentLoopConfig,
) -> LlamaSampler {
    // Samplers are applied in the order they are added to the chain:
    // penalties -> top-k -> min-p -> temperature + dist (or greedy when temperature <= 0).
    let mut chain = LlamaSampler::new_chain(lib.clone(), false);
    if config.repeat_penalty != 1.0 {
        // A penalty of exactly 1.0 is a no-op, so only add the sampler when it matters.
        let penalties =
            LlamaSampler::new_penalties(lib.clone(), 64, config.repeat_penalty, 0.0, 0.0);
chain.add(penalties);
}
if config.top_k > 0 {
let top_k = LlamaSampler::new_top_k(lib.clone(), config.top_k);
chain.add(top_k);
}
if config.min_p > 0.0 {
let min_p = LlamaSampler::new_min_p(lib.clone(), config.min_p, 1);
chain.add(min_p);
}
if config.temperature > 0.0 {
let temp = LlamaSampler::new_temp(lib.clone(), config.temperature);
chain.add(temp);
let dist = LlamaSampler::new_dist(lib.clone(), 0);
chain.add(dist);
} else {
let greedy = LlamaSampler::new_greedy(lib.clone());
chain.add(greedy);
}
chain
}