// llama_cpp_v3_agent_sdk/agent_loop.rs

use crate::conversation::Conversation;
use crate::error::AgentError;
use crate::inference::InferenceEngine;
use crate::permission::{PermissionRequest, PermissionTracker};
use crate::tool::{parse_tool_calls, ToolRegistry};
use llama_cpp_v3::{LlamaBatch, LlamaContext, LlamaSampler};
use std::sync::Arc;

/// Events emitted by the agent loop.
#[derive(Debug)]
pub enum AgentEvent {
    /// The agent is starting a new iteration.
    IterationStart { iteration: usize, max_iterations: usize },
    /// A text chunk was generated by the model.
    TextDelta(String),
    /// The agent is about to call a tool.
    ToolStart { name: String, arguments: String },
    /// A tool returned a result.
    ToolResult {
        name: String,
        success: bool,
        output: String,
    },
    /// Permission was requested and granted/denied.
    PermissionResult { tool: String, allowed: bool },
    /// Context was automatically compacted.
    ContextCompacted {
        messages_before: usize,
        messages_after: usize,
        prompt_tokens: usize,
        context_size: u32,
    },
    /// The agent completed its response.
    Completed { reason: CompletionReason },
    /// An error occurred during the loop.
    Error(String),
}

#[derive(Debug, Clone)]
pub enum CompletionReason {
    /// Model finished generating (no more tool calls).
    Done,
    /// Maximum iterations reached.
    MaxIterations,
    /// Model generated EOS token.
    EndOfSequence,
}
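
// Illustrative addition: a minimal `on_event` handler of the shape expected by
// `run_agent_loop` below. The printing behaviour is an arbitrary choice for the
// example; only the `AgentEvent` and `CompletionReason` types come from this module.
#[cfg(test)]
mod agent_event_example {
    use super::{AgentEvent, CompletionReason};

    fn print_event(event: AgentEvent) {
        match event {
            AgentEvent::TextDelta(chunk) => print!("{chunk}"),
            AgentEvent::ToolStart { name, arguments } => {
                println!("-> calling {name}({arguments})");
            }
            AgentEvent::ToolResult { name, success, output } => {
                println!("<- {name} [{}]: {output}", if success { "ok" } else { "err" });
            }
            AgentEvent::Completed { reason } => match reason {
                CompletionReason::Done => println!("done"),
                CompletionReason::MaxIterations => println!("stopped: max iterations"),
                CompletionReason::EndOfSequence => println!("stopped: end of sequence"),
            },
            other => println!("{other:?}"),
        }
    }

    #[test]
    fn handles_events() {
        print_event(AgentEvent::TextDelta("hello".to_string()));
        print_event(AgentEvent::Completed { reason: CompletionReason::Done });
    }
}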

/// Configuration for the agent loop.
pub struct AgentLoopConfig {
    /// Maximum number of iterations (tool-use rounds). 0 = unlimited.
    pub max_iterations: usize,
    /// Maximum tokens to generate per completion.
    pub max_tokens_per_completion: usize,
    /// Temperature for sampling.
    pub temperature: f32,
    /// Top-K sampling parameter.
    pub top_k: i32,
    /// Min-P sampling parameter.
    pub min_p: f32,
    /// Repetition penalty.
    pub repeat_penalty: f32,
    /// Enable automatic context compaction (default: true).
    pub auto_compact: bool,
    /// Percentage of context window that triggers compaction (0.0–1.0).
    pub compaction_threshold_pct: f32,
    /// Number of recent messages to keep when compacting.
    pub compaction_keep_recent: usize,
    /// Maximum tokens per batch during prompt encoding (default: 512).
    ///
    /// Large prompts are split into chunks of this size for decoding.
    /// Smaller values use less peak GPU memory; larger values may be
    /// faster on GPUs with plenty of VRAM.
    pub n_batch: usize,
    /// List of stop sequences. If the model generates one of these, it stops.
    pub stop_sequences: Vec<String>,
}

impl Default for AgentLoopConfig {
    fn default() -> Self {
        Self {
            max_iterations: 50,
            max_tokens_per_completion: 4096,
            temperature: 0.7,
            top_k: 40,
            min_p: 0.01,
            repeat_penalty: 1.0,
            auto_compact: true,
            compaction_threshold_pct: 0.75,
            compaction_keep_recent: 4,
            n_batch: 512,
            stop_sequences: Vec::new(),
        }
    }
}
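
// Illustrative addition: constructing an `AgentLoopConfig` by overriding a few
// fields with struct-update syntax. The chosen values are arbitrary examples.
#[cfg(test)]
mod agent_loop_config_example {
    use super::AgentLoopConfig;

    #[test]
    fn override_some_defaults() {
        let config = AgentLoopConfig {
            max_iterations: 10,
            temperature: 0.2,
            stop_sequences: vec!["</final>".to_string()],
            ..AgentLoopConfig::default()
        };
        assert_eq!(config.max_iterations, 10);
        // Fields that were not overridden keep their default values.
        assert_eq!(config.n_batch, 512);
        assert!(config.auto_compact);
    }
}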

// ─────────────────────────────────────────────────────────────────────────────
// KV Cache State — incremental prompt encoding
// ─────────────────────────────────────────────────────────────────────────────

/// Tracks the token sequence currently held in the KV cache.
///
/// By comparing a new prompt's token sequence against the cached one, the
/// agent loop can skip re-encoding the common prefix and only decode the
/// delta. This is especially beneficial for multi-turn conversations
/// where ~95% of the prompt is unchanged between turns.
pub struct KvCacheState {
    /// The token sequence currently in the KV cache.
    tokens: Vec<llama_cpp_sys_v3::llama_token>,
}

impl KvCacheState {
    pub fn new() -> Self {
        Self { tokens: Vec::new() }
    }

    /// Invalidate the cache (e.g. after compaction or history clear).
    pub fn invalidate(&mut self) {
        self.tokens.clear();
    }

    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}

impl Default for KvCacheState {
    fn default() -> Self {
        Self::new()
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Chunked prompt encoding
// ─────────────────────────────────────────────────────────────────────────────

/// Find the length of the common prefix between two token sequences.
fn common_prefix_len(
    a: &[llama_cpp_sys_v3::llama_token],
    b: &[llama_cpp_sys_v3::llama_token],
) -> usize {
    a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count()
}
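
// Illustrative addition: the prefix arithmetic that drives KV-cache reuse.
// Token values are made up, and this assumes `llama_token` is the usual
// integer alias exposed by the sys crate.
#[cfg(test)]
mod common_prefix_example {
    use super::common_prefix_len;

    #[test]
    fn only_the_divergent_suffix_needs_decoding() {
        let cached: Vec<llama_cpp_sys_v3::llama_token> = vec![1, 2, 3, 4];
        let new_prompt: Vec<llama_cpp_sys_v3::llama_token> = vec![1, 2, 3, 5, 6];

        let prefix = common_prefix_len(&cached, &new_prompt);
        assert_eq!(prefix, 3);
        // With 3 tokens reusable from the cache, only 2 tokens must be decoded.
        assert_eq!(new_prompt.len() - prefix, 2);
    }
}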

/// Decode a slice of tokens into the KV cache in chunks of `n_batch`.
///
/// `pos_offset` is the starting position in the KV cache (so that
/// incremental decodes append at the right position).
///
/// Chunking avoids OOM on smaller GPUs and improves throughput on most
/// hardware by matching the GPU's preferred batch size.
fn decode_tokens_chunked(
    lib: &Arc<llama_cpp_sys_v3::LlamaLib>,
    ctx: &mut LlamaContext,
    tokens: &[llama_cpp_sys_v3::llama_token],
    pos_offset: usize,
    n_batch: usize,
    total_prompt_len: usize,
) -> Result<(), AgentError> {
    if tokens.is_empty() {
        return Ok(());
    }

    let n_batch = n_batch.max(1);
    let n_tokens = tokens.len();
    let mut i = 0;

    while i < n_tokens {
        let end = (i + n_batch).min(n_tokens);
        let chunk = &tokens[i..end];
        let is_last_chunk = end == n_tokens;

        let mut batch = LlamaBatch::new(lib.clone(), chunk.len() as i32 + 1, 0, 1);

        for (j, &token) in chunk.iter().enumerate() {
            let pos = (pos_offset + i + j) as llama_cpp_sys_v3::llama_pos;
            // Only the very last token of the very last chunk needs logits
            let logits = is_last_chunk && (j == chunk.len() - 1)
                && (pos_offset + i + j == total_prompt_len - 1);
            batch.add(token, pos, &[0], logits);
        }

        ctx.decode(&batch)?;
        i = end;
    }

    Ok(())
}
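
// Illustrative addition: the chunk-boundary arithmetic used by
// `decode_tokens_chunked`. A 1300-token prompt with `n_batch = 512` is decoded
// as chunks of 512, 512 and 276 tokens, and only the last token of the last
// chunk requests logits. This re-derives the loop's arithmetic rather than
// calling into a real context.
#[cfg(test)]
mod chunking_example {
    #[test]
    fn chunk_sizes_for_a_long_prompt() {
        let n_tokens = 1300usize;
        let n_batch = 512usize;

        let mut chunk_sizes = Vec::new();
        let mut i = 0;
        while i < n_tokens {
            let end = (i + n_batch).min(n_tokens);
            chunk_sizes.push(end - i);
            i = end;
        }

        assert_eq!(chunk_sizes, vec![512, 512, 276]);
    }
}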

/// Encode a prompt into the KV cache, reusing the existing prefix.
///
/// Three cases are handled:
/// 1. **Exact prefix match** — KV cache is a prefix of the new prompt.
///    Only the delta (new tokens) is decoded.
/// 2. **Partial prefix match** — KV cache shares a common prefix but
///    diverges. Uses `kv_cache_seq_rm` to remove positions from the
///    divergence point onward, then decodes only the new suffix.
/// 3. **No match** — Full KV cache clear and re-encode.
///
/// Returns the total number of tokens now in the KV cache (= prompt length).
fn encode_prompt_incremental(
    lib: &Arc<llama_cpp_sys_v3::LlamaLib>,
    ctx: &mut LlamaContext,
    tokens: &[llama_cpp_sys_v3::llama_token],
    kv_cache: &mut KvCacheState,
    n_batch: usize,
) -> Result<usize, AgentError> {
    let prefix_len = common_prefix_len(&kv_cache.tokens, tokens);

    if prefix_len > 0 && prefix_len == kv_cache.tokens.len() {
        // Case 1: KV cache is an exact prefix — decode only the delta
        let delta = &tokens[prefix_len..];
        decode_tokens_chunked(lib, ctx, delta, prefix_len, n_batch, tokens.len())?;
    } else if prefix_len > 0 {
        // Case 2: Partial match — remove the divergent suffix from KV cache
        // and decode from the divergence point onward.
        //
        // kv_cache_seq_rm(seq_id=0, p0=prefix_len, p1=-1) removes all
        // positions from prefix_len to the end for sequence 0.
        ctx.kv_cache_seq_rm(0, prefix_len as llama_cpp_sys_v3::llama_pos, -1);

        let delta = &tokens[prefix_len..];
        decode_tokens_chunked(lib, ctx, delta, prefix_len, n_batch, tokens.len())?;
    } else {
        // Case 3: No common prefix — full re-encode
        ctx.kv_cache_clear();
        decode_tokens_chunked(lib, ctx, tokens, 0, n_batch, tokens.len())?;
    }

    // Update cache state to reflect the prompt tokens now in the KV cache
    kv_cache.tokens.clear();
    kv_cache.tokens.extend_from_slice(tokens);

    Ok(tokens.len())
}
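
// Illustrative addition: which of the three cases above a given cache/prompt
// pair falls into, using the same prefix test as `encode_prompt_incremental`.
// Token values are made up.
#[cfg(test)]
mod incremental_encoding_cases {
    use super::common_prefix_len;

    #[test]
    fn case_selection() {
        let cached: Vec<llama_cpp_sys_v3::llama_token> = vec![10, 20, 30];

        // Case 1: the cache is an exact prefix of the new prompt (appended turn).
        let grown: Vec<llama_cpp_sys_v3::llama_token> = vec![10, 20, 30, 40, 50];
        assert_eq!(common_prefix_len(&cached, &grown), cached.len());

        // Case 2: shared prefix, then divergence (edited history).
        let edited: Vec<llama_cpp_sys_v3::llama_token> = vec![10, 20, 99, 40];
        let prefix = common_prefix_len(&cached, &edited);
        assert!(prefix > 0 && prefix < cached.len());

        // Case 3: nothing in common, so a full re-encode is required.
        let fresh: Vec<llama_cpp_sys_v3::llama_token> = vec![7, 8, 9];
        assert_eq!(common_prefix_len(&cached, &fresh), 0);
    }
}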

// ─────────────────────────────────────────────────────────────────────────────
// Agent loop
// ─────────────────────────────────────────────────────────────────────────────

/// Run the core agentic loop with incremental, chunked prompt encoding.
pub fn run_agent_loop(
    engine: &InferenceEngine,
    ctx: &mut LlamaContext,
    conversation: &mut Conversation,
    tools: &ToolRegistry,
    permissions: &mut PermissionTracker,
    config: &AgentLoopConfig,
    kv_cache: &mut KvCacheState,
    mut on_event: impl FnMut(AgentEvent),
) -> Result<(), AgentError> {
    let lib = engine.lib();
    let model = engine.model();
    let n_ctx = engine.config.n_ctx;
    let max_iters = if config.max_iterations == 0 {
        usize::MAX
    } else {
        config.max_iterations
    };

    for iteration in 0..max_iters {
        on_event(AgentEvent::IterationStart {
            iteration: iteration + 1,
            max_iterations: config.max_iterations,
        });

        // 1. Format conversation → prompt → tokens
        let chat_messages = conversation.to_chat_messages();
        let template = engine.config.chat_template.as_deref();
        let prompt = model.apply_chat_template(template, &chat_messages, true)?;
        let tokens = model.tokenize(&prompt, false, true)?;

        // 2. Auto-compact if needed
        let tokens = if config.auto_compact
            && tokens.len() as f32 > n_ctx as f32 * config.compaction_threshold_pct
            && conversation.compactable_count(config.compaction_keep_recent) > 0
        {
            let messages_before = conversation.len();
            let prompt_tokens = tokens.len();

            kv_cache.invalidate();
            let summary = generate_compaction_summary(engine, ctx, conversation, config)?;
            conversation.compact(&summary, config.compaction_keep_recent);

            on_event(AgentEvent::ContextCompacted {
                messages_before,
                messages_after: conversation.len(),
                prompt_tokens,
                context_size: n_ctx,
            });

            let chat_messages = conversation.to_chat_messages();
            let template = engine.config.chat_template.as_deref();
            let prompt = model.apply_chat_template(template, &chat_messages, true)?;
            model.tokenize(&prompt, false, true)?
        } else {
            tokens
        };

        // 3. Encode prompt into KV cache (incremental + chunked)
        let n_cur = encode_prompt_incremental(
            &lib, ctx, &tokens, kv_cache, config.n_batch,
        )?;

        // 4. Generate completion token by token
        let sampler = build_sampler(lib.clone(), config);
        let vocab = model.get_vocab();
        let mut generated_text = String::new();
        let mut n_cur = n_cur;
        let mut generated_tokens: Vec<llama_cpp_sys_v3::llama_token> = Vec::new();

        let mut batch = LlamaBatch::new(lib.clone(), 2, 0, 1);

        for _ in 0..config.max_tokens_per_completion {
            let token = sampler.sample(ctx, -1);
            sampler.accept(token);

            if vocab.is_eog(token) {
                break;
            }

            let piece = model.token_to_piece(token);

            on_event(AgentEvent::TextDelta(piece.clone()));
            generated_text.push_str(&piece);
            generated_tokens.push(token);

            batch.clear();
            batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
            ctx.decode(&batch)?;
            n_cur += 1;

            // Check stop sequences
            if !config.stop_sequences.is_empty() {
                let mut should_stop = false;
                for stop in &config.stop_sequences {
                    if generated_text.ends_with(stop) {
                        should_stop = true;
                        break;
                    }
                }
                if should_stop {
                    break;
                }
            }
        }

        // 5. Update KV cache state: prompt + generated tokens
        kv_cache.tokens.extend_from_slice(&generated_tokens);

        // 6. Parse tool calls
        if tools.is_empty() {
            conversation.add_assistant(&generated_text, Vec::new());
            on_event(AgentEvent::Completed {
                reason: CompletionReason::Done,
            });
            return Ok(());
        }

        let (tool_calls, _text_parts) = parse_tool_calls(&generated_text);
        conversation.add_assistant(&generated_text, tool_calls.clone());

        // 7. If no tool calls, done
        if tool_calls.is_empty() {
            on_event(AgentEvent::Completed {
                reason: CompletionReason::Done,
            });
            return Ok(());
        }

        // 8. Execute tool calls
        for call in &tool_calls {
            let args_str =
                serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".to_string());

            on_event(AgentEvent::ToolStart {
                name: call.name.clone(),
                arguments: args_str.clone(),
            });

            let tool = tools.get(&call.name);
            if let Some(tool_impl) = tool {
                if tool_impl.requires_permission() {
                    let req = PermissionRequest {
                        tool_name: call.name.clone(),
                        description: format!("{}: {}", call.name, args_str),
                        dangerous: tool_impl.is_dangerous(&call.arguments),
                        arguments: call.arguments.clone(),
                    };

                    let allowed = permissions.check(&req);
                    on_event(AgentEvent::PermissionResult {
                        tool: call.name.clone(),
                        allowed,
                    });

                    if !allowed {
                        let result = crate::tool::ToolResult::err("Permission denied by user");
                        conversation.add_tool_result(call.clone(), result.clone());
                        on_event(AgentEvent::ToolResult {
                            name: call.name.clone(),
                            success: false,
                            output: result.output,
                        });
                        continue;
                    }
                }
            }

            let result = tools.execute(call);
            match result {
                Ok(result) => {
                    on_event(AgentEvent::ToolResult {
                        name: call.name.clone(),
                        success: result.success,
                        output: result.output.clone(),
                    });
                    conversation.add_tool_result(call.clone(), result);
                }
                Err(e) => {
                    let result =
                        crate::tool::ToolResult::err(format!("Tool execution error: {}", e));
                    on_event(AgentEvent::ToolResult {
                        name: call.name.clone(),
                        success: false,
                        output: result.output.clone(),
                    });
                    conversation.add_tool_result(call.clone(), result);
                }
            }
        }
    }

    on_event(AgentEvent::Completed {
        reason: CompletionReason::MaxIterations,
    });
    Ok(())
}

// ─────────────────────────────────────────────────────────────────────────────
// Context Compaction
// ─────────────────────────────────────────────────────────────────────────────

const COMPACTION_PROMPT: &str = "\
Summarize the following conversation history concisely. Preserve:
- The user's goals and what they asked for
- Key decisions and outcomes
- Important file paths, variable names, or technical details mentioned
- Current progress and what still needs to be done
- Any errors encountered and how they were resolved

Be concise but complete. Use bullet points. Do NOT include pleasantries or filler.

Conversation to summarize:
";

fn generate_compaction_summary(
    engine: &InferenceEngine,
    ctx: &mut LlamaContext,
    conversation: &Conversation,
    config: &AgentLoopConfig,
) -> Result<String, AgentError> {
    let model = engine.model();
    let lib = engine.lib();

    let start = if !conversation.messages().is_empty()
        && conversation.messages()[0].role == crate::conversation::Role::System
    {
        1
    } else {
        0
    };

    let total = conversation.messages().len();
    let keep_from = if total > config.compaction_keep_recent {
        total - config.compaction_keep_recent
    } else {
        start
    };
    let safe_cut = conversation.find_safe_cut_point(keep_from);

    if safe_cut <= start {
        return Ok(String::new());
    }

    let old_text = conversation.serialize_range(start, safe_cut);
    let summary_prompt = format!("{}{}", COMPACTION_PROMPT, old_text);

    let chat_messages = vec![
        llama_cpp_v3::ChatMessage {
            role: "system".to_string(),
            content: "You are a precise summarizer. Output only the summary, nothing else."
                .to_string(),
        },
        llama_cpp_v3::ChatMessage {
            role: "user".to_string(),
            content: summary_prompt,
        },
    ];

    let template = engine.config.chat_template.as_deref();
    let prompt = model.apply_chat_template(template, &chat_messages, true)?;
    let tokens = model.tokenize(&prompt, false, true)?;

    ctx.kv_cache_clear();

    // Use chunked encoding for the summarization prompt too
    decode_tokens_chunked(&lib, ctx, &tokens, 0, config.n_batch, tokens.len())?;

    let mut sampler = LlamaSampler::new_chain(lib.clone(), false);
    let greedy = LlamaSampler::new_greedy(lib.clone());
    sampler.add(greedy);

    let vocab = model.get_vocab();
    let mut summary = String::new();
    let mut n_cur = tokens.len();
    let max_summary_tokens = 512;

    let mut batch = LlamaBatch::new(lib.clone(), 2, 0, 1);

    for _ in 0..max_summary_tokens {
        let token = sampler.sample(ctx, -1);
        sampler.accept(token);

        if vocab.is_eog(token) {
            break;
        }

        let piece = model.token_to_piece(token);
        summary.push_str(&piece);

        batch.clear();
        batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
        ctx.decode(&batch)?;
        n_cur += 1;
    }

    Ok(summary.trim().to_string())
}
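
// Illustrative addition: the windowing arithmetic used by
// `generate_compaction_summary`. With 12 messages, a leading system message and
// `compaction_keep_recent = 4`, messages at indices 1..8 are candidates for
// summarization (before `find_safe_cut_point` adjusts the boundary) and the
// last 4 are kept verbatim. The numbers are made up for the example.
#[cfg(test)]
mod compaction_window_example {
    #[test]
    fn summarization_window() {
        let total = 12usize;
        let keep_recent = 4usize;
        let start = 1usize; // index 0 holds the system message

        let keep_from = if total > keep_recent { total - keep_recent } else { start };
        assert_eq!(keep_from, 8);
        // Messages summarized (before the safe-cut adjustment): indices 1..8.
        assert_eq!(keep_from - start, 7);
    }
}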

// ─────────────────────────────────────────────────────────────────────────────
// Sampler
// ─────────────────────────────────────────────────────────────────────────────

fn build_sampler(
    lib: Arc<llama_cpp_sys_v3::LlamaLib>,
    config: &AgentLoopConfig,
) -> LlamaSampler {
    let mut chain = LlamaSampler::new_chain(lib.clone(), false);

    if config.repeat_penalty != 1.0 {
        let penalties =
            LlamaSampler::new_penalties(lib.clone(), 64, config.repeat_penalty, 0.0, 0.0);
        chain.add(penalties);
    }

    if config.top_k > 0 {
        let top_k = LlamaSampler::new_top_k(lib.clone(), config.top_k);
        chain.add(top_k);
    }

    if config.min_p > 0.0 {
        let min_p = LlamaSampler::new_min_p(lib.clone(), config.min_p, 1);
        chain.add(min_p);
    }

    if config.temperature > 0.0 {
        let temp = LlamaSampler::new_temp(lib.clone(), config.temperature);
        chain.add(temp);
        let dist = LlamaSampler::new_dist(lib.clone(), 0);
        chain.add(dist);
    } else {
        let greedy = LlamaSampler::new_greedy(lib.clone());
        chain.add(greedy);
    }

    chain
}
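
// Illustrative addition: the sampler stages `build_sampler` chains together for
// the default config (top-k, then min-p, then temperature, then dist; the
// penalties stage is skipped because `repeat_penalty` is 1.0). This re-derives
// the branch logic as plain data instead of constructing real sampler objects.
#[cfg(test)]
mod sampler_chain_example {
    use super::AgentLoopConfig;

    #[test]
    fn default_config_stages() {
        let config = AgentLoopConfig::default();

        let mut stages = Vec::new();
        if config.repeat_penalty != 1.0 {
            stages.push("penalties");
        }
        if config.top_k > 0 {
            stages.push("top_k");
        }
        if config.min_p > 0.0 {
            stages.push("min_p");
        }
        if config.temperature > 0.0 {
            stages.push("temp");
            stages.push("dist");
        } else {
            stages.push("greedy");
        }

        assert_eq!(stages, vec!["top_k", "min_p", "temp", "dist"]);
    }
}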