Skip to main content

aonyx_agent/
runner.rs

1//! Main agent loop.
2//!
3//! Two entry points:
4//!
5//! - [`AgentRunner::run`] — drives a conversation forward and returns the
6//!   full transcript at the end. Convenient for tests and batch callers.
7//! - [`AgentRunner::run_streaming`] — same loop, but emits [`TurnEvent`]s on
8//!   a channel as they happen (delta text, tool start / end, iteration
9//!   boundaries, terminal `Done`). The interactive CLI uses this to render
10//!   tokens as they arrive and annotate tool activity inline.
11//!
12//! Inner steps:
13//! 1. Build a [`ChatRequest`] from the current message history and the
14//!    schemas of every registered tool.
15//! 2. Stream the response from the provider, accumulating text deltas and
16//!    collecting tool calls. Each delta becomes a
17//!    [`TurnEvent::AssistantDelta`].
//! 3. Append the assistant turn — its text plus any tool calls it requested —
//!    to the message history.
19//! 4. For each tool call, emit [`TurnEvent::ToolStart`], ask the
20//!    [`ApprovalPolicy`], invoke the handler, emit [`TurnEvent::ToolEnd`] or
21//!    [`TurnEvent::ToolRejected`], and append the result as a [`Role::Tool`]
22//!    message.
23//! 5. If the turn produced **no** tool calls, emit [`TurnEvent::Done`] and
24//!    return. Otherwise loop, bounded by `max_iterations`.
25
26use std::sync::Arc;
27
28use aonyx_core::{
29    AonyxError, ChatRequest, LlmProvider, Message, Result, Role, SafetyClass, ToolCall,
30    ToolHandler, ToolResult,
31};
32use aonyx_skills::{Skill, SkillEngine};
33use aonyx_tools::ToolRegistry;
34use futures::StreamExt;
35use serde_json::{json, Value};
36use tokio::sync::mpsc;
37
38use crate::approval::ApprovalPolicy;
39
40/// Streamed observation of what the runner is doing right now.
41#[derive(Debug, Clone)]
42pub enum TurnEvent {
43    /// Iteration index (1-based) about to start.
44    IterationStart(usize),
45    /// Incremental assistant text — render as soon as it arrives.
46    AssistantDelta(String),
47    /// The assistant emitted its final text for this iteration (no further
48    /// streaming until the next iteration or tool call).
49    AssistantMessageEnd,
50    /// The model requested a tool. The approval gate has not run yet.
51    ToolStart {
52        /// Tool name (matches `ToolHandler::name`).
53        name: String,
54        /// JSON arguments the model is sending.
55        args: Value,
56        /// Safety class of the tool.
57        class: SafetyClass,
58    },
59    /// A tool finished executing.
60    ToolEnd {
61        /// Tool name.
62        name: String,
63        /// `true` when the tool succeeded.
64        ok: bool,
65        /// Short one-line summary of the result (truncated).
66        summary: String,
67    },
68    /// A tool call was rejected by the approval policy and never executed.
69    ToolRejected {
70        /// Tool name.
71        name: String,
72        /// Safety class that caused the rejection.
73        class: SafetyClass,
74    },
75    /// The loop finished — model emitted no tool call, or the iteration cap
76    /// was hit.
77    Done {
78        /// Iterations consumed.
79        iterations: usize,
80        /// `true` when the loop bailed out at `max_iterations`.
81        max_iterations_hit: bool,
82    },
83}
84
85/// Outcome of a [`AgentRunner::run`] call.
86#[derive(Debug, Clone)]
87pub struct TurnResult {
88    /// The full message log at the end of the run (input + assistant + tool messages).
89    pub messages: Vec<Message>,
90    /// Number of provider turns consumed.
91    pub iterations: usize,
92    /// `true` when the loop bailed out at `max_iterations`.
93    pub max_iterations_hit: bool,
94}
95
96/// Drives a session forward, multi-turn, until the model emits no tool call
97/// or the iteration cap is reached.
98///
99/// `Clone` is intentional: `Arc<AgentRunner>` is awkward because every field
100/// is already cheaply cloneable, and the TUI spawns runner work onto its own
101/// task which needs an owned value.
102#[derive(Clone)]
103pub struct AgentRunner {
104    /// Active provider — shared behind an `Arc<Mutex<_>>` so the TUI
105    /// `/provider` command (Phase LL) can swap the whole backend live.
106    provider: Arc<std::sync::Mutex<Arc<dyn LlmProvider>>>,
107    tools: ToolRegistry,
108    skills: Vec<Skill>,
109    /// Skill ids the user has switched off for this session — shared
110    /// behind an `Arc<Mutex<_>>` so the TUI `/skills` panel (Phase X)
111    /// can flip them live and the next turn picks up the change.
112    disabled_skills: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
113    /// Pretty-printed (redacted) JSON of the most recent request sent
114    /// to the provider — surfaced by the TUI `/inspect` panel
115    /// (Phase Y). `None` until the first turn fires.
116    last_request: Arc<std::sync::Mutex<Option<String>>>,
117    project: Option<String>,
118    approval: ApprovalPolicy,
119    /// Active model id — shared behind an `Arc<Mutex<_>>` so the TUI
120    /// `/model` command (Phase EE) can swap it live and the next turn
121    /// (and `summarize`) picks up the change.
122    model: Arc<std::sync::Mutex<String>>,
123    max_iterations: usize,
124    /// Pre-fetch RAG context before each turn (retrieve-then-generate).
125    auto_retrieve: bool,
126    /// Chunks injected when [`Self::with_auto_retrieve`] is enabled.
127    auto_retrieve_top_k: usize,
128    /// Skip auto-retrieve for user messages shorter than this (chars).
129    auto_retrieve_min_len: usize,
130}
131
132impl AgentRunner {
133    /// Construct a runner with the V1 default policy ([`ApprovalPolicy::DenyDestructive`]).
134    pub fn new(
135        provider: Arc<dyn LlmProvider>,
136        tools: ToolRegistry,
137        model: impl Into<String>,
138    ) -> Self {
139        Self {
140            provider: Arc::new(std::sync::Mutex::new(provider)),
141            tools,
142            skills: Vec::new(),
143            disabled_skills: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
144            last_request: Arc::new(std::sync::Mutex::new(None)),
145            project: None,
146            approval: ApprovalPolicy::default(),
147            model: Arc::new(std::sync::Mutex::new(model.into())),
148            max_iterations: 10,
149            auto_retrieve: false,
150            auto_retrieve_top_k: 5,
151            auto_retrieve_min_len: 12,
152        }
153    }
154
155    /// Snapshot the active model id.
156    fn current_model(&self) -> String {
157        self.model.lock().map(|m| m.clone()).unwrap_or_default()
158    }
159
160    /// Share the live model handle so the TUI `/model` command can swap
161    /// the active model mid-session (Phase EE).
162    pub fn model_handle(&self) -> Arc<std::sync::Mutex<String>> {
163        Arc::clone(&self.model)
164    }
165
166    /// Share the live provider handle so the TUI `/provider` command
167    /// can swap the whole backend mid-session (Phase LL).
168    pub fn provider_handle(&self) -> Arc<std::sync::Mutex<Arc<dyn LlmProvider>>> {
169        Arc::clone(&self.provider)
170    }
171
172    /// Snapshot the active provider.
173    fn current_provider(&self) -> Arc<dyn LlmProvider> {
174        self.provider
175            .lock()
176            .map(|p| Arc::clone(&p))
177            .unwrap_or_else(|e| Arc::clone(&e.into_inner()))
178    }
179
180    /// Share a live skill-toggle set with the caller. Skill ids present
181    /// in the set are skipped during per-turn matching, letting the TUI
182    /// enable / disable skills mid-session (Phase X).
183    pub fn skill_toggle_handle(&self) -> Arc<std::sync::Mutex<std::collections::HashSet<String>>> {
184        Arc::clone(&self.disabled_skills)
185    }
186
187    /// Share a handle to the most-recent-request snapshot. The TUI
188    /// `/inspect` panel (Phase Y) reads the pretty-printed JSON written
189    /// here on every turn.
190    pub fn last_request_handle(&self) -> Arc<std::sync::Mutex<Option<String>>> {
191        Arc::clone(&self.last_request)
192    }
193
194    /// Override the approval policy.
195    pub fn with_approval(mut self, policy: ApprovalPolicy) -> Self {
196        self.approval = policy;
197        self
198    }
199
200    /// Override the per-turn iteration cap.
201    pub fn with_max_iterations(mut self, n: usize) -> Self {
202        self.max_iterations = n.max(1);
203        self
204    }
205
206    /// Register a skill catalogue. Active skills are matched per turn against
207    /// the latest user message + the (optional) project slug.
208    pub fn with_skills(mut self, skills: Vec<Skill>) -> Self {
209        self.skills = skills;
210        self
211    }
212
213    /// Set the project slug used for project-pattern skill triggers.
214    pub fn with_project(mut self, project: impl Into<String>) -> Self {
215        self.project = Some(project.into());
216        self
217    }
218
219    /// Enable retrieve-then-generate. Before each turn, pre-fetch RAG context
220    /// for the latest user message via the `rag_search` MCP tool and inject it
221    /// as a system block — helps weaker models that don't reliably call the
222    /// tool themselves on interrogative messages. The agent stays free to call
223    /// `read_document` / `find_related` to dig deeper; this only pre-loads.
224    /// No-op when no `…__rag_search` tool is registered, on slash-commands, or
225    /// on messages shorter than `min_len` chars. `top_k` is clamped to 1..=10.
226    pub fn with_auto_retrieve(mut self, enabled: bool, top_k: usize, min_len: usize) -> Self {
227        self.auto_retrieve = enabled;
228        self.auto_retrieve_top_k = top_k.clamp(1, 10);
229        self.auto_retrieve_min_len = min_len;
230        self
231    }
232
233    fn tools_schema(&self) -> Vec<Value> {
234        let mut names: Vec<&str> = self.tools.names().collect();
235        names.sort();
236        names
237            .into_iter()
238            .filter_map(|n| {
239                let h = self.tools.get(n)?;
240                let schema = h.schema();
241                let description = schema
242                    .get("description")
243                    .and_then(|v| v.as_str())
244                    .unwrap_or("")
245                    .to_string();
246                Some(json!({
247                    "name": n,
248                    "description": description,
249                    "input_schema": schema,
250                }))
251            })
252            .collect()
253    }
254
255    fn inject_active_skills(&self, messages: &mut Vec<Message>) {
256        if self.skills.is_empty() {
257            return;
258        }
259        let latest_user = messages
260            .iter()
261            .rev()
262            .find(|m| m.role == Role::User)
263            .map(|m| m.content.as_str())
264            .unwrap_or("");
265
266        // Phase X — drop skills the user toggled off before matching.
267        let disabled = self
268            .disabled_skills
269            .lock()
270            .map(|d| d.clone())
271            .unwrap_or_default();
272        let live_skills: Vec<Skill> = self
273            .skills
274            .iter()
275            .filter(|s| !disabled.contains(&s.id))
276            .cloned()
277            .collect();
278        if live_skills.is_empty() {
279            return;
280        }
281
282        let engine = SkillEngine::new(live_skills);
283        let active = engine.match_active(latest_user, self.project.as_deref());
284        if active.is_empty() {
285            return;
286        }
287        let block = active
288            .iter()
289            .map(|s| format!("# Skill: {}\n\n{}", s.name, s.body))
290            .collect::<Vec<_>>()
291            .join("\n\n");
292        messages.insert(0, Message::new(Role::System, block));
293    }
294
295    /// Retrieve-then-generate: pre-fetch RAG context for the latest user
296    /// message and insert it as a system block right before that message.
297    /// Best-effort — any failure (no tool, tool error, empty result) leaves
298    /// `messages` untouched so the turn proceeds normally.
299    async fn inject_auto_retrieve(&self, messages: &mut Vec<Message>) {
300        if !self.auto_retrieve {
301            return;
302        }
303        let Some(user_idx) = messages.iter().rposition(|m| m.role == Role::User) else {
304            return;
305        };
306        let query = messages[user_idx].content.trim().to_string();
307        // Skip slash-commands and very short messages ("ok", "merci", …).
308        if query.starts_with('/') || query.chars().count() < self.auto_retrieve_min_len {
309            return;
310        }
311        // MCP tools are named "<server>__<tool>" — match rag_search whatever
312        // the configured server is called.
313        let Some(tool_name) = self
314            .tools
315            .names()
316            .find(|n| *n == "rag_search" || n.ends_with("__rag_search"))
317            .map(|n| n.to_string())
318        else {
319            tracing::debug!("auto_retrieve: no rag_search tool registered; skipping");
320            return;
321        };
322        let Some(handler) = self.tools.get(&tool_name) else {
323            return;
324        };
325        let call = ToolCall {
326            id: "auto-retrieve".to_string(),
327            name: tool_name,
328            args: json!({ "query": query, "top_k": self.auto_retrieve_top_k }),
329        };
330        let output = match handler.invoke(call).await {
331            Ok(tr) if tr.error.is_none() => tr.output,
332            Ok(tr) => {
333                tracing::debug!(
334                    "auto_retrieve: rag_search returned an error: {:?}",
335                    tr.error
336                );
337                return;
338            }
339            Err(e) => {
340                tracing::debug!("auto_retrieve: rag_search invoke failed: {e}");
341                return;
342            }
343        };
344        let Some(body) = format_retrieved_context(&output, self.auto_retrieve_top_k) else {
345            return;
346        };
347        let block = format!(
348            "[Contexte RAG pré-chargé pour la question — cite la source (projet / fichier) \
349             si tu l'utilises ; tu peux approfondir avec read_document / find_related]\n\n{body}"
350        );
351        messages.insert(user_idx, Message::new(Role::System, block));
352    }
353
354    /// Run the loop and return the full transcript.
355    ///
356    /// Equivalent to [`AgentRunner::run_streaming`] with a discarded event
357    /// channel — convenient for tests and batch callers that don't need
358    /// progressive UI updates.
359    pub async fn run(&self, messages: Vec<Message>) -> Result<TurnResult> {
360        // Use a generous buffer so the synchronous sends inside the loop never
361        // block on a missing receiver. The receiver here drains and ignores
362        // every event.
363        let (tx, mut rx) = mpsc::channel::<TurnEvent>(256);
364        let drain = tokio::spawn(async move { while rx.recv().await.is_some() {} });
365        let result = self.run_streaming(messages, tx).await;
366        drain.await.ok();
367        result
368    }
369
370    /// Run the loop, emitting [`TurnEvent`]s on `events` as they happen.
371    ///
372    /// `events` is dropped when the function returns, which signals the
373    /// caller that the run is complete (a receive on the matching receiver
374    /// will then yield `None`).
375    pub async fn run_streaming(
376        &self,
377        mut messages: Vec<Message>,
378        events: mpsc::Sender<TurnEvent>,
379    ) -> Result<TurnResult> {
380        self.inject_active_skills(&mut messages);
381        self.inject_auto_retrieve(&mut messages).await;
382        let tools = self.tools_schema();
383        let mut iterations: usize = 0;
384
385        for i in 0..self.max_iterations {
386            iterations = i + 1;
387            let _ = events.send(TurnEvent::IterationStart(iterations)).await;
388
389            let req = ChatRequest {
390                model: self.current_model(),
391                messages: messages.clone(),
392                tools: tools.clone(),
393                temperature: None,
394                max_tokens: None,
395            };
396
397            // Phase Y — capture a redacted snapshot for `/inspect`
398            // before the request leaves. Best-effort: a serialization
399            // hiccup never blocks the turn.
400            if let Ok(mut slot) = self.last_request.lock() {
401                *slot = Some(redact_request_json(&req));
402            }
403
404            let (text, tool_calls) = self.consume_stream(req, &events).await?;
405
406            if tool_calls.is_empty() {
407                if !text.is_empty() {
408                    messages.push(Message::new(Role::Assistant, text));
409                }
410                let _ = events.send(TurnEvent::AssistantMessageEnd).await;
411                let _ = events
412                    .send(TurnEvent::Done {
413                        iterations,
414                        max_iterations_hit: false,
415                    })
416                    .await;
417                return Ok(TurnResult {
418                    messages,
419                    iterations,
420                    max_iterations_hit: false,
421                });
422            }
423
424            // The model requested tools — record the assistant turn carrying
425            // both its text and the tool calls, so the next iteration replays
426            // the request/response pair correctly to the provider.
427            messages.push(Message::assistant_tool_calls(text, tool_calls.clone()));
428            let _ = events.send(TurnEvent::AssistantMessageEnd).await;
429
430            for call in tool_calls {
431                let class = self
432                    .tools
433                    .get(&call.name)
434                    .map(|h| h.classify())
435                    .unwrap_or(SafetyClass::Safe);
436                let _ = events
437                    .send(TurnEvent::ToolStart {
438                        name: call.name.clone(),
439                        args: call.args.clone(),
440                        class,
441                    })
442                    .await;
443
444                let outcome = self.dispatch_tool(call.clone()).await;
445                let payload = match &outcome {
446                    Ok(tr) => {
447                        let _ = events
448                            .send(TurnEvent::ToolEnd {
449                                name: call.name.clone(),
450                                ok: true,
451                                summary: short_summary(&tr.output),
452                            })
453                            .await;
454                        format_tool_result(tr)
455                    }
456                    Err(AonyxError::ApprovalRejected(_)) => {
457                        let _ = events
458                            .send(TurnEvent::ToolRejected {
459                                name: call.name.clone(),
460                                class,
461                            })
462                            .await;
463                        format!("[approval rejected] {} ({:?})", call.name, class)
464                    }
465                    Err(e) => {
466                        let msg = format!("{e}");
467                        let _ = events
468                            .send(TurnEvent::ToolEnd {
469                                name: call.name.clone(),
470                                ok: false,
471                                summary: msg.clone(),
472                            })
473                            .await;
474                        format!("[tool error] {msg}")
475                    }
476                };
477                messages.push(Message::tool_result(call.id, payload));
478            }
479        }
480
481        let _ = events
482            .send(TurnEvent::Done {
483                iterations,
484                max_iterations_hit: true,
485            })
486            .await;
487        Ok(TurnResult {
488            messages,
489            iterations,
490            max_iterations_hit: true,
491        })
492    }
493
494    /// Summarize a slice of conversation into a single compact paragraph
495    /// (Phase BB). One-shot, tool-free, non-streaming — used by the TUI
496    /// auto-compaction to fold old turns into a system note.
497    pub async fn summarize(&self, history: &[Message]) -> Result<String> {
498        let transcript = history
499            .iter()
500            .map(|m| {
501                let who = match m.role {
502                    Role::System => "system",
503                    Role::User => "user",
504                    Role::Assistant => "assistant",
505                    Role::Tool => "tool",
506                };
507                format!("{who}: {}", m.content)
508            })
509            .collect::<Vec<_>>()
510            .join("\n\n");
511
512        let prompt = "You are compacting a conversation to save context. Summarize the \
513            exchange below concisely, preserving key facts, decisions, file paths, \
514            identifiers, and any open questions or TODOs. Omit pleasantries. Output \
515            only the summary prose — no preamble.";
516        let req = ChatRequest {
517            model: self.current_model(),
518            messages: vec![
519                Message::new(Role::System, prompt),
520                Message::new(Role::User, transcript),
521            ],
522            tools: Vec::new(),
523            temperature: Some(0.0),
524            max_tokens: Some(1024),
525        };
526
527        let provider = self.current_provider();
528        let mut stream = provider.chat_stream(req).await?;
529        let mut text = String::new();
530        while let Some(item) = stream.next().await {
531            let chunk = item?;
532            text.push_str(&chunk.delta_text);
533            if chunk.finished {
534                break;
535            }
536        }
537        Ok(text.trim().to_string())
538    }
539
540    async fn consume_stream(
541        &self,
542        req: ChatRequest,
543        events: &mpsc::Sender<TurnEvent>,
544    ) -> Result<(String, Vec<ToolCall>)> {
545        let provider = self.current_provider();
546        let mut stream = provider.chat_stream(req).await?;
547        let mut text = String::new();
548        let mut tool_calls: Vec<ToolCall> = Vec::new();
549
550        while let Some(item) = stream.next().await {
551            let chunk = item?;
552            if !chunk.delta_text.is_empty() {
553                let _ = events
554                    .send(TurnEvent::AssistantDelta(chunk.delta_text.clone()))
555                    .await;
556                text.push_str(&chunk.delta_text);
557            }
558            if let Some(tc) = chunk.tool_call {
559                tool_calls.push(tc);
560            }
561            if chunk.finished {
562                break;
563            }
564        }
565
566        Ok((text, tool_calls))
567    }
568
569    async fn dispatch_tool(&self, call: ToolCall) -> Result<ToolResult> {
570        let handler: Arc<dyn ToolHandler> = self
571            .tools
572            .get(&call.name)
573            .ok_or_else(|| AonyxError::Tool(format!("unknown tool: {}", call.name)))?;
574        let class = handler.classify();
575        if !self.approval.allow(&call, class).await {
576            return Err(AonyxError::ApprovalRejected(format!(
577                "{} ({:?})",
578                call.name, class
579            )));
580        }
581        handler.invoke(call).await
582    }
583}
584
585fn format_tool_result(tr: &ToolResult) -> String {
586    if let Some(err) = &tr.error {
587        return format!("[tool error] {err}");
588    }
589    match serde_json::to_string_pretty(&tr.output) {
590        Ok(s) => s,
591        Err(_) => tr.output.to_string(),
592    }
593}
594
595/// Serialize a [`ChatRequest`] to pretty JSON for the `/inspect` panel
596/// (Phase Y), eliding base64 image payloads so the snapshot stays
597/// readable (a single PNG can be hundreds of KB of base64).
598fn redact_request_json(req: &ChatRequest) -> String {
599    let mut value = match serde_json::to_value(req) {
600        Ok(v) => v,
601        Err(e) => return format!("(could not serialize request: {e})"),
602    };
603    if let Some(messages) = value.get_mut("messages").and_then(|m| m.as_array_mut()) {
604        for msg in messages.iter_mut() {
605            if let Some(atts) = msg.get_mut("attachments").and_then(|a| a.as_array_mut()) {
606                for att in atts.iter_mut() {
607                    if let Some(data) = att.get_mut("data") {
608                        if let Some(s) = data.as_str() {
609                            *data = Value::String(format!("<{} bytes base64 elided>", s.len()));
610                        }
611                    }
612                }
613            }
614        }
615    }
616    serde_json::to_string_pretty(&value).unwrap_or_else(|e| format!("(pretty-print failed: {e})"))
617}
618
619fn short_summary(value: &Value) -> String {
620    let raw = match value {
621        Value::String(s) => s.clone(),
622        other => serde_json::to_string(other).unwrap_or_default(),
623    };
624    let trimmed = raw.replace('\n', " ");
625    if trimmed.chars().count() > 120 {
626        let cut: String = trimmed.chars().take(120).collect();
627        format!("{cut}…")
628    } else {
629        trimmed
630    }
631}
632
633/// Format a `rag_search` tool result into a compact context block for
634/// [`AgentRunner::inject_auto_retrieve`]. Handles both a structured
635/// `{ "results": [{project, source, content}, …] }` payload and the MCP
636/// text-content case (where the server's JSON arrives as a single string —
637/// see `aonyx_mcp`'s `extract_call_result`). Returns `None` when there's
638/// nothing usable to inject.
639fn format_retrieved_context(output: &Value, top_k: usize) -> Option<String> {
640    if let Some(results) = output.get("results").and_then(|r| r.as_array()) {
641        return format_results_array(results, top_k);
642    }
643    if let Some(s) = output.as_str() {
644        // The MCP server's response often arrives as a JSON string.
645        if let Ok(parsed) = serde_json::from_str::<Value>(s) {
646            if let Some(results) = parsed.get("results").and_then(|r| r.as_array()) {
647                if let Some(block) = format_results_array(results, top_k) {
648                    return Some(block);
649                }
650            }
651        }
652        let trimmed = s.trim();
653        if trimmed.is_empty() {
654            return None;
655        }
656        return Some(cap(trimmed, 6000));
657    }
658    None
659}
660
661/// Render a `results` array (`{project, source, content}` objects) as a
662/// bulleted, source-attributed context block, capped per chunk.
663fn format_results_array(results: &[Value], top_k: usize) -> Option<String> {
664    let mut blocks = Vec::new();
665    for r in results.iter().take(top_k) {
666        let content = r
667            .get("content")
668            .and_then(|v| v.as_str())
669            .unwrap_or("")
670            .trim();
671        if content.is_empty() {
672            continue;
673        }
674        let project = r.get("project").and_then(|v| v.as_str()).unwrap_or("?");
675        let source = r.get("source").and_then(|v| v.as_str()).unwrap_or("?");
676        blocks.push(format!(
677            "- (projet {project} / {source})\n{}",
678            cap(content, 1200)
679        ));
680    }
681    if blocks.is_empty() {
682        None
683    } else {
684        Some(blocks.join("\n\n"))
685    }
686}
687
/// Limit `s` to `max_chars` Unicode scalar values, adding `…` if anything
/// was removed.
///
/// A single `char_indices` pass finds the byte offset of the first character
/// past the limit, so long inputs are not scanned twice.
fn cap(s: &str, max_chars: usize) -> String {
    match s.char_indices().nth(max_chars) {
        // Fewer than `max_chars + 1` chars: nothing to cut.
        None => s.to_string(),
        // `cut` is a char boundary, so the slice is always valid UTF-8.
        Some((cut, _)) => format!("{}…", &s[..cut]),
    }
}
698
699#[cfg(test)]
700mod tests {
701    use super::*;
702    use aonyx_core::{ChatChunk, ChatStream, Result as CoreResult};
703    use async_trait::async_trait;
704    use std::sync::Mutex;
705
706    /// Test double: each `chat_stream` call returns the next pre-canned chunk list.
707    struct FakeProvider {
708        queue: Mutex<Vec<Vec<ChatChunk>>>,
709    }
710
711    impl FakeProvider {
712        fn new(responses: Vec<Vec<ChatChunk>>) -> Self {
713            Self {
714                queue: Mutex::new(responses),
715            }
716        }
717    }
718
719    #[async_trait]
720    impl LlmProvider for FakeProvider {
721        fn name(&self) -> &str {
722            "fake"
723        }
724
725        async fn chat_stream(&self, _req: ChatRequest) -> CoreResult<ChatStream> {
726            let mut q = self.queue.lock().expect("queue poisoned");
727            let next = if q.is_empty() {
728                Vec::new()
729            } else {
730                q.remove(0)
731            };
732            let stream = futures::stream::iter(next.into_iter().map(Ok));
733            Ok(Box::pin(stream))
734        }
735    }
736
737    fn text_chunk(s: &str) -> ChatChunk {
738        ChatChunk {
739            delta_text: s.to_string(),
740            tool_call: None,
741            finished: false,
742        }
743    }
744
745    fn stop_chunk() -> ChatChunk {
746        ChatChunk {
747            delta_text: String::new(),
748            tool_call: None,
749            finished: true,
750        }
751    }
752
753    fn tool_chunk(name: &str, args: Value) -> ChatChunk {
754        ChatChunk {
755            delta_text: String::new(),
756            tool_call: Some(ToolCall {
757                id: format!("call-{name}"),
758                name: name.to_string(),
759                args,
760            }),
761            finished: false,
762        }
763    }
764
765    fn drain<T>(rx: &mut mpsc::Receiver<T>) -> Vec<T> {
766        let mut out = Vec::new();
767        while let Ok(ev) = rx.try_recv() {
768            out.push(ev);
769        }
770        out
771    }
772
773    fn always_on_skill(id: &str, body: &str) -> Skill {
774        let mut s = Skill {
775            id: id.to_string(),
776            name: id.to_string(),
777            enabled: true,
778            tools: Vec::new(),
779            trigger: Default::default(),
780            body: body.to_string(),
781        };
782        s.trigger.always_on = true;
783        s
784    }
785
786    #[tokio::test]
787    async fn summarize_collects_streamed_text() {
788        let provider = Arc::new(FakeProvider::new(vec![vec![
789            text_chunk("Summary: "),
790            text_chunk("user asked about X."),
791            stop_chunk(),
792        ]]));
793        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "any-model");
794        let history = vec![
795            Message::new(Role::User, "tell me about X"),
796            Message::new(Role::Assistant, "X is a thing"),
797        ];
798        let summary = runner.summarize(&history).await.unwrap();
799        assert_eq!(summary, "Summary: user asked about X.");
800    }
801
802    #[test]
803    fn redact_request_json_elides_image_payloads() {
804        use aonyx_core::Attachment;
805        let req = ChatRequest {
806            model: "claude-x".to_string(),
807            messages: vec![Message::with_attachments(
808                Role::User,
809                "look",
810                vec![Attachment::Image {
811                    media_type: "image/png".into(),
812                    data: "A".repeat(5000),
813                }],
814            )],
815            tools: vec![],
816            temperature: None,
817            max_tokens: None,
818        };
819        let json = redact_request_json(&req);
820        assert!(json.contains("claude-x"));
821        assert!(json.contains("image/png"));
822        // The 5000-char blob must be gone, replaced by the elision tag.
823        assert!(!json.contains(&"A".repeat(5000)));
824        assert!(json.contains("base64 elided"));
825    }
826
827    #[test]
828    fn redact_request_json_passes_text_only_requests_through() {
829        let req = ChatRequest {
830            model: "m".to_string(),
831            messages: vec![Message::new(Role::User, "plain text")],
832            tools: vec![],
833            temperature: None,
834            max_tokens: None,
835        };
836        let json = redact_request_json(&req);
837        assert!(json.contains("plain text"));
838    }
839
840    #[test]
841    fn inject_active_skills_adds_an_always_on_skill() {
842        let runner = AgentRunner::new(
843            Arc::new(FakeProvider::new(vec![])),
844            ToolRegistry::default_set(),
845            "any-model",
846        )
847        .with_skills(vec![always_on_skill("greeter", "ALWAYS GREET")]);
848        let mut messages = vec![Message::new(Role::User, "hi")];
849        runner.inject_active_skills(&mut messages);
850        assert_eq!(messages[0].role, Role::System);
851        assert!(messages[0].content.contains("ALWAYS GREET"));
852    }
853
854    #[test]
855    fn disabled_skill_is_not_injected() {
856        let runner = AgentRunner::new(
857            Arc::new(FakeProvider::new(vec![])),
858            ToolRegistry::default_set(),
859            "any-model",
860        )
861        .with_skills(vec![always_on_skill("greeter", "ALWAYS GREET")]);
862        // Toggle the skill off through the shared handle.
863        runner
864            .skill_toggle_handle()
865            .lock()
866            .unwrap()
867            .insert("greeter".to_string());
868        let mut messages = vec![Message::new(Role::User, "hi")];
869        runner.inject_active_skills(&mut messages);
870        // No system block injected; the lone user message stands.
871        assert_eq!(messages.len(), 1);
872        assert_eq!(messages[0].role, Role::User);
873    }
874
875    #[test]
876    fn re_enabling_a_skill_restores_injection() {
877        let runner = AgentRunner::new(
878            Arc::new(FakeProvider::new(vec![])),
879            ToolRegistry::default_set(),
880            "any-model",
881        )
882        .with_skills(vec![always_on_skill("greeter", "ALWAYS GREET")]);
883        let handle = runner.skill_toggle_handle();
884        handle.lock().unwrap().insert("greeter".to_string());
885        handle.lock().unwrap().remove("greeter");
886        let mut messages = vec![Message::new(Role::User, "hi")];
887        runner.inject_active_skills(&mut messages);
888        assert_eq!(messages[0].role, Role::System);
889        assert!(messages[0].content.contains("ALWAYS GREET"));
890    }
891
892    #[tokio::test]
893    async fn terminates_when_no_tool_calls() {
894        let provider = Arc::new(FakeProvider::new(vec![vec![
895            text_chunk("Hello, "),
896            text_chunk("world."),
897            stop_chunk(),
898        ]]));
899        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "any-model");
900        let res = runner
901            .run(vec![Message::new(Role::User, "hi")])
902            .await
903            .unwrap();
904        assert_eq!(res.iterations, 1);
905        assert!(!res.max_iterations_hit);
906        assert_eq!(res.messages.len(), 2);
907        assert_eq!(res.messages[1].role, Role::Assistant);
908        assert_eq!(res.messages[1].content, "Hello, world.");
909    }
910
911    #[tokio::test]
912    async fn loops_until_no_more_tool_calls() {
913        let dir = tempfile::tempdir().unwrap();
914        let path = dir.path().join("note.txt");
915        tokio::fs::write(&path, "hello").await.unwrap();
916
917        let provider = Arc::new(FakeProvider::new(vec![
918            // Turn 1: ask for fs_read, no text.
919            vec![
920                tool_chunk("fs_read", json!({ "path": path.to_string_lossy() })),
921                stop_chunk(),
922            ],
923            // Turn 2: produce final text, no tool call.
924            vec![text_chunk("read it."), stop_chunk()],
925        ]));
926        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "any-model");
927        let res = runner
928            .run(vec![Message::new(Role::User, "show me the file")])
929            .await
930            .unwrap();
931        assert_eq!(res.iterations, 2);
932        // User · Assistant(tool_use) · Tool result · Assistant(final).
933        // The assistant turn that requested the tool is now recorded with
934        // its tool_calls, and the result links back via tool_call_id.
935        let roles: Vec<_> = res.messages.iter().map(|m| m.role).collect();
936        assert_eq!(
937            roles,
938            vec![Role::User, Role::Assistant, Role::Tool, Role::Assistant]
939        );
940        assert_eq!(res.messages[1].tool_calls.len(), 1);
941        assert_eq!(res.messages[1].tool_calls[0].name, "fs_read");
942        assert!(res.messages[2].tool_call_id.is_some());
943        assert!(res.messages[2].content.contains("hello"));
944        assert_eq!(res.messages[3].content, "read it.");
945    }
946
947    #[tokio::test]
948    async fn respects_max_iterations() {
949        let provider = Arc::new(FakeProvider::new(vec![
950            vec![tool_chunk("git_status", json!({})), stop_chunk()],
951            vec![tool_chunk("git_status", json!({})), stop_chunk()],
952            vec![tool_chunk("git_status", json!({})), stop_chunk()],
953        ]));
954        let runner =
955            AgentRunner::new(provider, ToolRegistry::default_set(), "m").with_max_iterations(2);
956        let res = runner
957            .run(vec![Message::new(Role::User, "loop forever")])
958            .await
959            .unwrap();
960        assert_eq!(res.iterations, 2);
961        assert!(res.max_iterations_hit);
962    }
963
964    #[tokio::test]
965    async fn default_policy_blocks_destructive_writes() {
966        let dir = tempfile::tempdir().unwrap();
967        let path = dir.path().join("forbidden.txt");
968        let provider = Arc::new(FakeProvider::new(vec![vec![
969            tool_chunk(
970                "fs_write",
971                json!({ "path": path.to_string_lossy(), "content": "nope" }),
972            ),
973            stop_chunk(),
974        ]]));
975        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
976        let res = runner
977            .run(vec![Message::new(Role::User, "write to disk")])
978            .await
979            .unwrap();
980        let last = res.messages.last().unwrap();
981        assert_eq!(last.role, Role::Tool);
982        assert!(last.content.contains("approval rejected"));
983        assert!(!path.exists(), "file must not have been written");
984    }
985
986    #[tokio::test]
987    async fn auto_allow_lets_destructive_writes_through() {
988        let dir = tempfile::tempdir().unwrap();
989        let path = dir.path().join("ok.txt");
990        let provider = Arc::new(FakeProvider::new(vec![
991            vec![
992                tool_chunk(
993                    "fs_write",
994                    json!({ "path": path.to_string_lossy(), "content": "yes" }),
995                ),
996                stop_chunk(),
997            ],
998            vec![text_chunk("done"), stop_chunk()],
999        ]));
1000        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m")
1001            .with_approval(ApprovalPolicy::AutoAllow);
1002        let res = runner
1003            .run(vec![Message::new(Role::User, "write to disk")])
1004            .await
1005            .unwrap();
1006        assert_eq!(res.iterations, 2);
1007        assert_eq!(tokio::fs::read_to_string(&path).await.unwrap(), "yes");
1008    }
1009
1010    #[tokio::test]
1011    async fn run_streaming_emits_delta_events_in_order() {
1012        let provider = Arc::new(FakeProvider::new(vec![vec![
1013            text_chunk("Hello"),
1014            text_chunk(", "),
1015            text_chunk("world"),
1016            stop_chunk(),
1017        ]]));
1018        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
1019        let (tx, mut rx) = mpsc::channel::<TurnEvent>(64);
1020        runner
1021            .run_streaming(vec![Message::new(Role::User, "hi")], tx)
1022            .await
1023            .unwrap();
1024
1025        let events = drain(&mut rx);
1026        let deltas: Vec<_> = events
1027            .iter()
1028            .filter_map(|e| match e {
1029                TurnEvent::AssistantDelta(s) => Some(s.as_str()),
1030                _ => None,
1031            })
1032            .collect();
1033        assert_eq!(deltas, vec!["Hello", ", ", "world"]);
1034
1035        let has_done = events.iter().any(|e| {
1036            matches!(
1037                e,
1038                TurnEvent::Done {
1039                    max_iterations_hit: false,
1040                    ..
1041                }
1042            )
1043        });
1044        assert!(has_done);
1045    }
1046
1047    #[tokio::test]
1048    async fn run_streaming_announces_tool_start_and_end() {
1049        let dir = tempfile::tempdir().unwrap();
1050        let path = dir.path().join("hello.txt");
1051        tokio::fs::write(&path, "ok").await.unwrap();
1052
1053        let provider = Arc::new(FakeProvider::new(vec![
1054            vec![
1055                tool_chunk("fs_read", json!({ "path": path.to_string_lossy() })),
1056                stop_chunk(),
1057            ],
1058            vec![text_chunk("done"), stop_chunk()],
1059        ]));
1060        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
1061        let (tx, mut rx) = mpsc::channel::<TurnEvent>(64);
1062        runner
1063            .run_streaming(vec![Message::new(Role::User, "read it")], tx)
1064            .await
1065            .unwrap();
1066
1067        let events = drain(&mut rx);
1068        let start_seen = events
1069            .iter()
1070            .any(|e| matches!(e, TurnEvent::ToolStart { name, .. } if name == "fs_read"));
1071        let end_seen = events
1072            .iter()
1073            .any(|e| matches!(e, TurnEvent::ToolEnd { name, ok: true, .. } if name == "fs_read"));
1074        assert!(start_seen, "expected ToolStart for fs_read");
1075        assert!(end_seen, "expected successful ToolEnd for fs_read");
1076    }
1077
1078    #[tokio::test]
1079    async fn run_streaming_announces_tool_rejection() {
1080        let dir = tempfile::tempdir().unwrap();
1081        let path = dir.path().join("nope.txt");
1082        let provider = Arc::new(FakeProvider::new(vec![vec![
1083            tool_chunk(
1084                "fs_write",
1085                json!({ "path": path.to_string_lossy(), "content": "blocked" }),
1086            ),
1087            stop_chunk(),
1088        ]]));
1089        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
1090        let (tx, mut rx) = mpsc::channel::<TurnEvent>(64);
1091        runner
1092            .run_streaming(vec![Message::new(Role::User, "write please")], tx)
1093            .await
1094            .unwrap();
1095
1096        let events = drain(&mut rx);
1097        let rejected = events
1098            .iter()
1099            .any(|e| matches!(e, TurnEvent::ToolRejected { name, .. } if name == "fs_write"));
1100        assert!(rejected, "expected ToolRejected for fs_write");
1101    }
1102
1103    #[test]
1104    fn auto_retrieve_formats_structured_results() {
1105        let output = json!({
1106            "results": [
1107                {"project": "infra", "source": "ref.md", "content": "alpha fact"},
1108                {"project": "ovelo", "source": "notes.md", "content": "beta fact"},
1109            ]
1110        });
1111        let block = format_retrieved_context(&output, 5).expect("context block");
1112        assert!(block.contains("projet infra / ref.md"));
1113        assert!(block.contains("alpha fact"));
1114        assert!(block.contains("beta fact"));
1115    }
1116
1117    #[test]
1118    fn auto_retrieve_parses_json_string_payload() {
1119        // MCP servers commonly return their JSON as a single text-content string.
1120        let payload = json!({
1121            "results": [{"project": "p", "source": "s", "content": "gamma"}]
1122        })
1123        .to_string();
1124        let block = format_retrieved_context(&Value::String(payload), 5).expect("context block");
1125        assert!(block.contains("gamma"));
1126        assert!(block.contains("projet p / s"));
1127    }
1128
1129    #[test]
1130    fn auto_retrieve_uses_plain_text_payload() {
1131        let output = Value::String("just some prose context".to_string());
1132        assert_eq!(
1133            format_retrieved_context(&output, 5).unwrap(),
1134            "just some prose context"
1135        );
1136    }
1137
1138    #[test]
1139    fn auto_retrieve_top_k_limits_chunks() {
1140        let output = json!({
1141            "results": [
1142                {"project":"p","source":"a","content":"one"},
1143                {"project":"p","source":"b","content":"two"},
1144                {"project":"p","source":"c","content":"three"},
1145            ]
1146        });
1147        let block = format_retrieved_context(&output, 2).unwrap();
1148        assert!(block.contains("one") && block.contains("two"));
1149        assert!(!block.contains("three"), "top_k=2 must drop the 3rd chunk");
1150    }
1151
1152    #[test]
1153    fn auto_retrieve_none_on_empty() {
1154        assert!(format_retrieved_context(&json!({"results": []}), 5).is_none());
1155        assert!(format_retrieved_context(&Value::String(String::new()), 5).is_none());
1156    }
1157}