use crate::conversation::Conversation;
use crate::error::AgentError;
use crate::inference::InferenceEngine;
use crate::permission::{PermissionRequest, PermissionTracker};
use crate::tool::{parse_tool_calls, ToolRegistry};
use llama_cpp_v3::{LlamaBatch, LlamaContext, LlamaSampler};
use std::sync::Arc;

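/// Events streamed to the caller's `on_event` callback while
/// [`run_agent_loop`] executes.
///
/// A minimal handler sketch (the arms shown are illustrative; match
/// whichever events your frontend cares about):
///
/// ```ignore
/// let on_event = |event: AgentEvent| match event {
///     AgentEvent::TextDelta(piece) => print!("{piece}"),
///     AgentEvent::ToolStart { name, .. } => eprintln!("[tool: {name}]"),
///     AgentEvent::Completed { reason } => eprintln!("[done: {reason:?}]"),
///     _ => {}
/// };
/// ```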
#[derive(Debug)]
pub enum AgentEvent {
    /// A new loop iteration has begun (`iteration` is 1-based).
    IterationStart { iteration: usize, max_iterations: usize },
    /// A piece of streamed assistant text.
    TextDelta(String),
    /// A tool call is about to be executed.
    ToolStart { name: String, arguments: String },
    /// A tool call finished, successfully or not.
    ToolResult {
        name: String,
        success: bool,
        output: String,
    },
    /// The permission tracker allowed or denied a tool call.
    PermissionResult { tool: String, allowed: bool },
    /// Older history was summarized to fit the context window.
    ContextCompacted {
        messages_before: usize,
        messages_after: usize,
        prompt_tokens: usize,
        context_size: u32,
    },
    /// The loop finished.
    Completed { reason: CompletionReason },
    /// A non-fatal error was reported.
    Error(String),
}
38
39#[derive(Debug, Clone)]
40pub enum CompletionReason {
41 Done,
43 MaxIterations,
45 EndOfSequence,
47}
48
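/// Tuning knobs for [`run_agent_loop`].
///
/// A minimal sketch of overriding a few defaults via struct update syntax
/// (the values shown are illustrative, not recommendations):
///
/// ```ignore
/// let config = AgentLoopConfig {
///     max_iterations: 10,
///     temperature: 0.2,
///     stop_sequences: vec!["</tool_call>".to_string()],
///     ..AgentLoopConfig::default()
/// };
/// ```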
pub struct AgentLoopConfig {
    /// Maximum number of loop iterations; `0` means unlimited.
    pub max_iterations: usize,
    /// Token budget for each completion within an iteration.
    pub max_tokens_per_completion: usize,
    /// Sampling temperature; `0.0` or below selects greedy sampling.
    pub temperature: f32,
    /// Top-k sampling cutoff; `0` or below disables it.
    pub top_k: i32,
    /// Min-p sampling cutoff; `0.0` or below disables it.
    pub min_p: f32,
    /// Repetition penalty; `1.0` disables it.
    pub repeat_penalty: f32,
    /// Automatically summarize old history when the prompt nears the
    /// context limit.
    pub auto_compact: bool,
    /// Fraction of the context window (0.0 to 1.0) at which compaction
    /// triggers.
    pub compaction_threshold_pct: f32,
    /// Number of most recent messages kept verbatim when compacting.
    pub compaction_keep_recent: usize,
    /// Number of prompt tokens decoded per batch chunk.
    pub n_batch: usize,
    /// Sequences that stop generation when the generated text ends with
    /// one of them; the matching text is kept in the output.
    pub stop_sequences: Vec<String>,
}

impl Default for AgentLoopConfig {
    fn default() -> Self {
        Self {
            max_iterations: 50,
            max_tokens_per_completion: 4096,
            temperature: 0.7,
            top_k: 40,
            min_p: 0.01,
            repeat_penalty: 1.0,
            auto_compact: true,
            compaction_threshold_pct: 0.75,
            compaction_keep_recent: 4,
            n_batch: 512,
            stop_sequences: Vec::new(),
        }
    }
}

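/// Host-side mirror of the tokens currently resident in the context's
/// KV cache. [`encode_prompt_incremental`] diffs the next prompt against
/// this mirror so only the changed suffix is decoded again.
///
/// A minimal sketch (assumes the cache is otherwise managed by the loop;
/// invalidate whenever the real KV cache is changed behind its back):
///
/// ```ignore
/// let mut kv_cache = KvCacheState::new();
/// assert!(kv_cache.is_empty());
/// // After anything else touches the context's KV cache directly:
/// kv_cache.invalidate();
/// assert_eq!(kv_cache.len(), 0);
/// ```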
pub struct KvCacheState {
    /// Tokens believed to be in the context's KV cache, in order.
    tokens: Vec<llama_cpp_sys_v3::llama_token>,
}

impl KvCacheState {
    pub fn new() -> Self {
        Self { tokens: Vec::new() }
    }

    /// Forget the mirrored tokens; the next prompt is decoded from scratch.
    pub fn invalidate(&mut self) {
        self.tokens.clear();
    }

    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}

impl Default for KvCacheState {
    fn default() -> Self {
        Self::new()
    }
}

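/// Length of the longest shared prefix of two token slices, e.g.
/// `common_prefix_len(&[1, 2, 3], &[1, 2, 9])` is `2`.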
fn common_prefix_len(
    a: &[llama_cpp_sys_v3::llama_token],
    b: &[llama_cpp_sys_v3::llama_token],
) -> usize {
    a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count()
}

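/// Decode `tokens` into the context in chunks of at most `n_batch`.
/// `pos_offset` is the KV-cache position of the first token; logits are
/// requested only for the last token of the full prompt
/// (`total_prompt_len - 1`), which is the position the sampler reads.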
fn decode_tokens_chunked(
    lib: &Arc<llama_cpp_sys_v3::LlamaLib>,
    ctx: &mut LlamaContext,
    tokens: &[llama_cpp_sys_v3::llama_token],
    pos_offset: usize,
    n_batch: usize,
    total_prompt_len: usize,
) -> Result<(), AgentError> {
    if tokens.is_empty() {
        return Ok(());
    }

    let n_batch = n_batch.max(1);
    let n_tokens = tokens.len();
    let mut i = 0;

    while i < n_tokens {
        let end = (i + n_batch).min(n_tokens);
        let chunk = &tokens[i..end];
        let is_last_chunk = end == n_tokens;

        let mut batch = LlamaBatch::new(lib.clone(), chunk.len() as i32 + 1, 0, 1);

        for (j, &token) in chunk.iter().enumerate() {
            let pos = (pos_offset + i + j) as llama_cpp_sys_v3::llama_pos;
            // Only the final token of the full prompt needs logits.
            let logits = is_last_chunk
                && (j == chunk.len() - 1)
                && (pos_offset + i + j == total_prompt_len - 1);
            batch.add(token, pos, &[0], logits);
        }

        ctx.decode(&batch)?;
        i = end;
    }

    Ok(())
}

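/// Bring the context's KV cache in sync with `tokens` while decoding as
/// little as possible. Three cases:
///
/// 1. The cache is exactly a prefix of the new prompt: decode only the
///    new suffix.
/// 2. The cache shares a proper prefix but then diverges (e.g. after
///    compaction): evict the divergent tail and decode from there.
/// 3. No shared prefix: clear the cache and decode the whole prompt.
///
/// Returns the prompt length, i.e. the position where generation resumes.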
fn encode_prompt_incremental(
    lib: &Arc<llama_cpp_sys_v3::LlamaLib>,
    ctx: &mut LlamaContext,
    tokens: &[llama_cpp_sys_v3::llama_token],
    kv_cache: &mut KvCacheState,
    n_batch: usize,
) -> Result<usize, AgentError> {
    let prefix_len = common_prefix_len(&kv_cache.tokens, tokens);

    if prefix_len > 0 && prefix_len == kv_cache.tokens.len() {
        // The cache is a pure prefix of the new prompt: decode only the
        // new suffix.
        let delta = &tokens[prefix_len..];
        decode_tokens_chunked(lib, ctx, delta, prefix_len, n_batch, tokens.len())?;
    } else if prefix_len > 0 {
        // The prompt diverges from the cache partway through. Evict every
        // cached position from the divergence point onward, then decode
        // the rest of the new prompt.
        ctx.kv_cache_seq_rm(0, prefix_len as llama_cpp_sys_v3::llama_pos, -1);

        let delta = &tokens[prefix_len..];
        decode_tokens_chunked(lib, ctx, delta, prefix_len, n_batch, tokens.len())?;
    } else {
        // Nothing reusable: rebuild the cache from scratch.
        ctx.kv_cache_clear();
        decode_tokens_chunked(lib, ctx, tokens, 0, n_batch, tokens.len())?;
    }

    // Record what the KV cache now holds.
    kv_cache.tokens.clear();
    kv_cache.tokens.extend_from_slice(tokens);

    Ok(tokens.len())
}

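/// Run the agent loop to completion: render the prompt, generate a
/// completion, execute any tool calls, append the results, and repeat
/// until the model stops calling tools or a limit is hit.
///
/// A minimal calling sketch; constructing the engine, context,
/// conversation, tool registry, and permission tracker is assumed to
/// happen elsewhere in the crate:
///
/// ```ignore
/// let mut kv_cache = KvCacheState::new();
/// let config = AgentLoopConfig::default();
/// run_agent_loop(
///     &engine,
///     &mut ctx,
///     &mut conversation,
///     &tools,
///     &mut permissions,
///     &config,
///     &mut kv_cache,
///     |event| {
///         if let AgentEvent::TextDelta(piece) = event {
///             print!("{piece}");
///         }
///     },
/// )?;
/// ```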
pub fn run_agent_loop(
    engine: &InferenceEngine,
    ctx: &mut LlamaContext,
    conversation: &mut Conversation,
    tools: &ToolRegistry,
    permissions: &mut PermissionTracker,
    config: &AgentLoopConfig,
    kv_cache: &mut KvCacheState,
    mut on_event: impl FnMut(AgentEvent),
) -> Result<(), AgentError> {
    let lib = engine.lib();
    let model = engine.model();
    let n_ctx = engine.config.n_ctx;
    // `max_iterations == 0` means unlimited.
    let max_iters = if config.max_iterations == 0 {
        usize::MAX
    } else {
        config.max_iterations
    };

    for iteration in 0..max_iters {
        on_event(AgentEvent::IterationStart {
            iteration: iteration + 1,
            max_iterations: config.max_iterations,
        });

        // Render the conversation through the chat template and tokenize.
        let chat_messages = conversation.to_chat_messages();
        let template = engine.config.chat_template.as_deref();
        let prompt = model.apply_chat_template(template, &chat_messages, true)?;
        let tokens = model.tokenize(&prompt, false, true)?;

        // Compact once the prompt crosses the configured share of the
        // context window and there is still history that can be summarized.
        let tokens = if config.auto_compact
            && tokens.len() as f32 > n_ctx as f32 * config.compaction_threshold_pct
            && conversation.compactable_count(config.compaction_keep_recent) > 0
        {
            let messages_before = conversation.len();
            let prompt_tokens = tokens.len();

            kv_cache.invalidate();
            let summary = generate_compaction_summary(engine, ctx, conversation, config)?;
            conversation.compact(&summary, config.compaction_keep_recent);

            on_event(AgentEvent::ContextCompacted {
                messages_before,
                messages_after: conversation.len(),
                prompt_tokens,
                context_size: n_ctx,
            });

            // Re-render the now-compacted conversation.
            let chat_messages = conversation.to_chat_messages();
            let template = engine.config.chat_template.as_deref();
            let prompt = model.apply_chat_template(template, &chat_messages, true)?;
            model.tokenize(&prompt, false, true)?
        } else {
            tokens
        };

        // Sync the KV cache with the new prompt; generation resumes at
        // the returned position.
        let mut n_cur = encode_prompt_incremental(
            &lib, ctx, &tokens, kv_cache, config.n_batch,
        )?;

        // Build a fresh sampler chain for this completion.
        let sampler = build_sampler(lib.clone(), config);
        let vocab = model.get_vocab();
        let mut generated_text = String::new();
        let mut generated_tokens: Vec<llama_cpp_sys_v3::llama_token> = Vec::new();

        // Single-token batch reused for every decode step below.
        let mut batch = LlamaBatch::new(lib.clone(), 2, 0, 1);

        for _ in 0..config.max_tokens_per_completion {
            // Sample from the logits of the most recently decoded token.
            let token = sampler.sample(ctx, -1);
            sampler.accept(token);

            if vocab.is_eog(token) {
                break;
            }

            let piece = model.token_to_piece(token);

            on_event(AgentEvent::TextDelta(piece.clone()));
            generated_text.push_str(&piece);
            generated_tokens.push(token);

            batch.clear();
            batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
            ctx.decode(&batch)?;
            n_cur += 1;

            // Stop as soon as the output ends with any configured stop
            // sequence; the matching text is kept in the output.
            if config
                .stop_sequences
                .iter()
                .any(|stop| generated_text.ends_with(stop))
            {
                break;
            }
        }

        // The sampled tokens were decoded into the KV cache above; keep
        // the mirror in sync so the next iteration can reuse them.
        kv_cache.tokens.extend_from_slice(&generated_tokens);

        if tools.is_empty() {
            // No tools registered: a single completion is the final answer.
            conversation.add_assistant(&generated_text, Vec::new());
            on_event(AgentEvent::Completed {
                reason: CompletionReason::Done,
            });
            return Ok(());
        }

        let (tool_calls, _text_parts) = parse_tool_calls(&generated_text);
        conversation.add_assistant(&generated_text, tool_calls.clone());

        if tool_calls.is_empty() {
            // The model answered without calling any tool: we're done.
            on_event(AgentEvent::Completed {
                reason: CompletionReason::Done,
            });
            return Ok(());
        }

        for call in &tool_calls {
            // Report arguments to the UI as compact JSON.
            let args_str =
                serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".to_string());

            on_event(AgentEvent::ToolStart {
                name: call.name.clone(),
                arguments: args_str.clone(),
            });

            if let Some(tool_impl) = tools.get(&call.name) {
                if tool_impl.requires_permission() {
                    let req = PermissionRequest {
                        tool_name: call.name.clone(),
                        description: format!("{}: {}", call.name, args_str),
                        dangerous: tool_impl.is_dangerous(&call.arguments),
                        arguments: call.arguments.clone(),
                    };

                    let allowed = permissions.check(&req);
                    on_event(AgentEvent::PermissionResult {
                        tool: call.name.clone(),
                        allowed,
                    });

                    if !allowed {
                        let result = crate::tool::ToolResult::err("Permission denied by user");
                        conversation.add_tool_result(call.clone(), result.clone());
                        on_event(AgentEvent::ToolResult {
                            name: call.name.clone(),
                            success: false,
                            output: result.output,
                        });
                        continue;
                    }
                }
            }

            match tools.execute(call) {
                Ok(result) => {
                    on_event(AgentEvent::ToolResult {
                        name: call.name.clone(),
                        success: result.success,
                        output: result.output.clone(),
                    });
                    conversation.add_tool_result(call.clone(), result);
                }
                Err(e) => {
                    let result =
                        crate::tool::ToolResult::err(format!("Tool execution error: {}", e));
                    on_event(AgentEvent::ToolResult {
                        name: call.name.clone(),
                        success: false,
                        output: result.output.clone(),
                    });
                    conversation.add_tool_result(call.clone(), result);
                }
            }
        }
    }

    on_event(AgentEvent::Completed {
        reason: CompletionReason::MaxIterations,
    });
    Ok(())
}

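/// Instruction prepended to the serialized history when asking the model
/// to summarize older messages during compaction.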
const COMPACTION_PROMPT: &str = "\
Summarize the following conversation history concisely. Preserve:
- The user's goals and what they asked for
- Key decisions and outcomes
- Important file paths, variable names, or technical details mentioned
- Current progress and what still needs to be done
- Any errors encountered and how they were resolved

Be concise but complete. Use bullet points. Do NOT include pleasantries or filler.

Conversation to summarize:
";

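/// Summarize the compactable prefix of the conversation with a greedy
/// decode, returning the summary text (empty when nothing can be cut).
///
/// This clobbers the context's KV cache, so callers must invalidate their
/// [`KvCacheState`] mirror around the call, as [`run_agent_loop`] does.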
fn generate_compaction_summary(
    engine: &InferenceEngine,
    ctx: &mut LlamaContext,
    conversation: &Conversation,
    config: &AgentLoopConfig,
) -> Result<String, AgentError> {
    let model = engine.model();
    let lib = engine.lib();

    // Keep the system message out of the summarized range.
    let start = if !conversation.messages().is_empty()
        && conversation.messages()[0].role == crate::conversation::Role::System
    {
        1
    } else {
        0
    };

    let total = conversation.messages().len();
    let keep_from = if total > config.compaction_keep_recent {
        total - config.compaction_keep_recent
    } else {
        start
    };
    // Let the conversation pick a cut point that is safe to summarize up to.
    let safe_cut = conversation.find_safe_cut_point(keep_from);

    if safe_cut <= start {
        return Ok(String::new());
    }

    let old_text = conversation.serialize_range(start, safe_cut);
    let summary_prompt = format!("{}{}", COMPACTION_PROMPT, old_text);

    let chat_messages = vec![
        llama_cpp_v3::ChatMessage {
            role: "system".to_string(),
            content: "You are a precise summarizer. Output only the summary, nothing else."
                .to_string(),
        },
        llama_cpp_v3::ChatMessage {
            role: "user".to_string(),
            content: summary_prompt,
        },
    ];

    let template = engine.config.chat_template.as_deref();
    let prompt = model.apply_chat_template(template, &chat_messages, true)?;
    let tokens = model.tokenize(&prompt, false, true)?;

    ctx.kv_cache_clear();

    // Decode the summarization prompt from scratch.
    decode_tokens_chunked(&lib, ctx, &tokens, 0, config.n_batch, tokens.len())?;

    // Greedy decoding keeps the summary deterministic.
    let mut sampler = LlamaSampler::new_chain(lib.clone(), false);
    let greedy = LlamaSampler::new_greedy(lib.clone());
    sampler.add(greedy);

    let vocab = model.get_vocab();
    let mut summary = String::new();
    let mut n_cur = tokens.len();
    let max_summary_tokens = 512;

    let mut batch = LlamaBatch::new(lib.clone(), 2, 0, 1);

    for _ in 0..max_summary_tokens {
        let token = sampler.sample(ctx, -1);
        sampler.accept(token);

        if vocab.is_eog(token) {
            break;
        }

        let piece = model.token_to_piece(token);
        summary.push_str(&piece);

        batch.clear();
        batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
        ctx.decode(&batch)?;
        n_cur += 1;
    }

    Ok(summary.trim().to_string())
}

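/// Assemble the sampler chain from the config. Stages run in order:
/// repeat penalty (if enabled), top-k, min-p, then either temperature
/// plus random sampling, or plain greedy when `temperature <= 0`.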
fn build_sampler(
    lib: Arc<llama_cpp_sys_v3::LlamaLib>,
    config: &AgentLoopConfig,
) -> LlamaSampler {
    let mut chain = LlamaSampler::new_chain(lib.clone(), false);

    if config.repeat_penalty != 1.0 {
        let penalties =
            LlamaSampler::new_penalties(lib.clone(), 64, config.repeat_penalty, 0.0, 0.0);
        chain.add(penalties);
    }

    if config.top_k > 0 {
        let top_k = LlamaSampler::new_top_k(lib.clone(), config.top_k);
        chain.add(top_k);
    }

    if config.min_p > 0.0 {
        let min_p = LlamaSampler::new_min_p(lib.clone(), config.min_p, 1);
        chain.add(min_p);
    }

    if config.temperature > 0.0 {
        let temp = LlamaSampler::new_temp(lib.clone(), config.temperature);
        chain.add(temp);
        let dist = LlamaSampler::new_dist(lib.clone(), 0);
        chain.add(dist);
    } else {
        let greedy = LlamaSampler::new_greedy(lib.clone());
        chain.add(greedy);
    }

    chain
}