1use std::sync::Arc;
27
28use aonyx_core::{
29 AonyxError, ChatRequest, LlmProvider, Message, Result, Role, SafetyClass, ToolCall,
30 ToolHandler, ToolResult,
31};
32use aonyx_skills::{Skill, SkillEngine};
33use aonyx_tools::ToolRegistry;
34use futures::StreamExt;
35use serde_json::{json, Value};
36use tokio::sync::mpsc;
37
38use crate::approval::ApprovalPolicy;
39
/// Progress notifications emitted by [`AgentRunner::run_streaming`] while a turn runs.
#[derive(Debug, Clone)]
pub enum TurnEvent {
    /// A new model round-trip is starting; carries the 1-based iteration number.
    IterationStart(usize),
    /// A non-empty fragment of assistant text, forwarded as it arrives from the stream.
    AssistantDelta(String),
    /// The assistant message of the current iteration has finished streaming.
    AssistantMessageEnd,
    /// A tool call requested by the model is about to be dispatched.
    ToolStart {
        /// Tool name as requested by the model.
        name: String,
        /// Arguments supplied by the model.
        args: Value,
        /// Safety class reported by the tool handler (`Safe` if the tool is unknown).
        class: SafetyClass,
    },
    /// A dispatched tool call has finished.
    ToolEnd {
        /// Tool name as requested by the model.
        name: String,
        /// Whether the call is reported as successful; `false` when dispatch failed.
        ok: bool,
        /// Short, human-readable summary of the output or of the error.
        summary: String,
    },
    /// The approval policy refused to run a tool call.
    ToolRejected {
        /// Tool name as requested by the model.
        name: String,
        /// Safety class that triggered the rejection.
        class: SafetyClass,
    },
    /// The turn has ended.
    Done {
        /// Number of model round-trips performed.
        iterations: usize,
        /// `true` when the loop stopped because the iteration cap was reached.
        max_iterations_hit: bool,
    },
}
84
/// Outcome of a completed agent turn.
#[derive(Debug, Clone)]
pub struct TurnResult {
    /// Conversation after the turn: the input messages plus any injected system
    /// context, assistant messages and tool results appended during the loop.
    pub messages: Vec<Message>,
    /// Number of model round-trips performed.
    pub iterations: usize,
    /// `true` when the turn stopped because the iteration cap was reached.
    pub max_iterations_hit: bool,
}
95
/// Drives the model <-> tool loop for a conversation turn.
///
/// The provider, model name, disabled-skill set and last-request slot are held
/// behind shared `Arc<Mutex<_>>` handles, so clones of a runner (and holders of
/// the `*_handle()` accessors) observe the same runtime switches.
#[derive(Clone)]
pub struct AgentRunner {
    /// Active LLM backend; swappable at runtime through `provider_handle()`.
    provider: Arc<std::sync::Mutex<Arc<dyn LlmProvider>>>,
    /// Tools the model may call.
    tools: ToolRegistry,
    /// Skills considered for injection on every turn.
    skills: Vec<Skill>,
    /// Ids of skills disabled at runtime (shared via `skill_toggle_handle()`).
    disabled_skills: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
    /// Redacted JSON of the most recent outgoing request (shared via
    /// `last_request_handle()`); `None` until the first request is sent.
    last_request: Arc<std::sync::Mutex<Option<String>>>,
    /// Optional project name passed to skill matching.
    project: Option<String>,
    /// Policy consulted before every tool dispatch.
    approval: ApprovalPolicy,
    /// Model identifier sent with each request; swappable via `model_handle()`.
    model: Arc<std::sync::Mutex<String>>,
    /// Upper bound on model round-trips per turn.
    max_iterations: usize,
    /// Whether to prefetch RAG context for the latest user message.
    auto_retrieve: bool,
    /// Number of chunks requested from / kept out of the RAG search.
    auto_retrieve_top_k: usize,
    /// Minimum user-message length, in characters, that triggers auto-retrieval.
    auto_retrieve_min_len: usize,
}
131
132impl AgentRunner {
133 pub fn new(
135 provider: Arc<dyn LlmProvider>,
136 tools: ToolRegistry,
137 model: impl Into<String>,
138 ) -> Self {
139 Self {
140 provider: Arc::new(std::sync::Mutex::new(provider)),
141 tools,
142 skills: Vec::new(),
143 disabled_skills: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
144 last_request: Arc::new(std::sync::Mutex::new(None)),
145 project: None,
146 approval: ApprovalPolicy::default(),
147 model: Arc::new(std::sync::Mutex::new(model.into())),
148 max_iterations: 10,
149 auto_retrieve: false,
150 auto_retrieve_top_k: 5,
151 auto_retrieve_min_len: 12,
152 }
153 }
154
155 fn current_model(&self) -> String {
157 self.model.lock().map(|m| m.clone()).unwrap_or_default()
158 }
159
160 pub fn model_handle(&self) -> Arc<std::sync::Mutex<String>> {
163 Arc::clone(&self.model)
164 }
165
166 pub fn provider_handle(&self) -> Arc<std::sync::Mutex<Arc<dyn LlmProvider>>> {
169 Arc::clone(&self.provider)
170 }
171
172 fn current_provider(&self) -> Arc<dyn LlmProvider> {
174 self.provider
175 .lock()
176 .map(|p| Arc::clone(&p))
177 .unwrap_or_else(|e| Arc::clone(&e.into_inner()))
178 }
179
180 pub fn skill_toggle_handle(&self) -> Arc<std::sync::Mutex<std::collections::HashSet<String>>> {
184 Arc::clone(&self.disabled_skills)
185 }
186
187 pub fn last_request_handle(&self) -> Arc<std::sync::Mutex<Option<String>>> {
191 Arc::clone(&self.last_request)
192 }
193
194 pub fn with_approval(mut self, policy: ApprovalPolicy) -> Self {
196 self.approval = policy;
197 self
198 }
199
200 pub fn with_max_iterations(mut self, n: usize) -> Self {
202 self.max_iterations = n.max(1);
203 self
204 }
205
206 pub fn with_skills(mut self, skills: Vec<Skill>) -> Self {
209 self.skills = skills;
210 self
211 }
212
213 pub fn with_project(mut self, project: impl Into<String>) -> Self {
215 self.project = Some(project.into());
216 self
217 }
218
219 pub fn with_auto_retrieve(mut self, enabled: bool, top_k: usize, min_len: usize) -> Self {
227 self.auto_retrieve = enabled;
228 self.auto_retrieve_top_k = top_k.clamp(1, 10);
229 self.auto_retrieve_min_len = min_len;
230 self
231 }
232
233 fn tools_schema(&self) -> Vec<Value> {
234 let mut names: Vec<&str> = self.tools.names().collect();
235 names.sort();
236 names
237 .into_iter()
238 .filter_map(|n| {
239 let h = self.tools.get(n)?;
240 let schema = h.schema();
241 let description = schema
242 .get("description")
243 .and_then(|v| v.as_str())
244 .unwrap_or("")
245 .to_string();
246 Some(json!({
247 "name": n,
248 "description": description,
249 "input_schema": schema,
250 }))
251 })
252 .collect()
253 }
254
255 fn inject_active_skills(&self, messages: &mut Vec<Message>) {
256 if self.skills.is_empty() {
257 return;
258 }
259 let latest_user = messages
260 .iter()
261 .rev()
262 .find(|m| m.role == Role::User)
263 .map(|m| m.content.as_str())
264 .unwrap_or("");
265
266 let disabled = self
268 .disabled_skills
269 .lock()
270 .map(|d| d.clone())
271 .unwrap_or_default();
272 let live_skills: Vec<Skill> = self
273 .skills
274 .iter()
275 .filter(|s| !disabled.contains(&s.id))
276 .cloned()
277 .collect();
278 if live_skills.is_empty() {
279 return;
280 }
281
282 let engine = SkillEngine::new(live_skills);
283 let active = engine.match_active(latest_user, self.project.as_deref());
284 if active.is_empty() {
285 return;
286 }
287 let block = active
288 .iter()
289 .map(|s| format!("# Skill: {}\n\n{}", s.name, s.body))
290 .collect::<Vec<_>>()
291 .join("\n\n");
292 messages.insert(0, Message::new(Role::System, block));
293 }
294
295 async fn inject_auto_retrieve(&self, messages: &mut Vec<Message>) {
300 if !self.auto_retrieve {
301 return;
302 }
303 let Some(user_idx) = messages.iter().rposition(|m| m.role == Role::User) else {
304 return;
305 };
306 let query = messages[user_idx].content.trim().to_string();
307 if query.starts_with('/') || query.chars().count() < self.auto_retrieve_min_len {
309 return;
310 }
311 let Some(tool_name) = self
314 .tools
315 .names()
316 .find(|n| *n == "rag_search" || n.ends_with("__rag_search"))
317 .map(|n| n.to_string())
318 else {
319 tracing::debug!("auto_retrieve: no rag_search tool registered; skipping");
320 return;
321 };
322 let Some(handler) = self.tools.get(&tool_name) else {
323 return;
324 };
325 let call = ToolCall {
326 id: "auto-retrieve".to_string(),
327 name: tool_name,
328 args: json!({ "query": query, "top_k": self.auto_retrieve_top_k }),
329 };
330 let output = match handler.invoke(call).await {
331 Ok(tr) if tr.error.is_none() => tr.output,
332 Ok(tr) => {
333 tracing::debug!(
334 "auto_retrieve: rag_search returned an error: {:?}",
335 tr.error
336 );
337 return;
338 }
339 Err(e) => {
340 tracing::debug!("auto_retrieve: rag_search invoke failed: {e}");
341 return;
342 }
343 };
344 let Some(body) = format_retrieved_context(&output, self.auto_retrieve_top_k) else {
345 return;
346 };
347 let block = format!(
348 "[Contexte RAG pré-chargé pour la question — cite la source (projet / fichier) \
349 si tu l'utilises ; tu peux approfondir avec read_document / find_related]\n\n{body}"
350 );
351 messages.insert(user_idx, Message::new(Role::System, block));
352 }
353
354 pub async fn run(&self, messages: Vec<Message>) -> Result<TurnResult> {
360 let (tx, mut rx) = mpsc::channel::<TurnEvent>(256);
364 let drain = tokio::spawn(async move { while rx.recv().await.is_some() {} });
365 let result = self.run_streaming(messages, tx).await;
366 drain.await.ok();
367 result
368 }
369
370 pub async fn run_streaming(
376 &self,
377 mut messages: Vec<Message>,
378 events: mpsc::Sender<TurnEvent>,
379 ) -> Result<TurnResult> {
380 self.inject_active_skills(&mut messages);
381 self.inject_auto_retrieve(&mut messages).await;
382 let tools = self.tools_schema();
383 let mut iterations: usize = 0;
384
385 for i in 0..self.max_iterations {
386 iterations = i + 1;
387 let _ = events.send(TurnEvent::IterationStart(iterations)).await;
388
389 let req = ChatRequest {
390 model: self.current_model(),
391 messages: messages.clone(),
392 tools: tools.clone(),
393 temperature: None,
394 max_tokens: None,
395 };
396
397 if let Ok(mut slot) = self.last_request.lock() {
401 *slot = Some(redact_request_json(&req));
402 }
403
404 let (text, tool_calls) = self.consume_stream(req, &events).await?;
405
406 if tool_calls.is_empty() {
407 if !text.is_empty() {
408 messages.push(Message::new(Role::Assistant, text));
409 }
410 let _ = events.send(TurnEvent::AssistantMessageEnd).await;
411 let _ = events
412 .send(TurnEvent::Done {
413 iterations,
414 max_iterations_hit: false,
415 })
416 .await;
417 return Ok(TurnResult {
418 messages,
419 iterations,
420 max_iterations_hit: false,
421 });
422 }
423
424 messages.push(Message::assistant_tool_calls(text, tool_calls.clone()));
428 let _ = events.send(TurnEvent::AssistantMessageEnd).await;
429
430 for call in tool_calls {
431 let class = self
432 .tools
433 .get(&call.name)
434 .map(|h| h.classify())
435 .unwrap_or(SafetyClass::Safe);
436 let _ = events
437 .send(TurnEvent::ToolStart {
438 name: call.name.clone(),
439 args: call.args.clone(),
440 class,
441 })
442 .await;
443
444 let outcome = self.dispatch_tool(call.clone()).await;
445 let payload = match &outcome {
446 Ok(tr) => {
447 let _ = events
448 .send(TurnEvent::ToolEnd {
449 name: call.name.clone(),
450 ok: true,
451 summary: short_summary(&tr.output),
452 })
453 .await;
454 format_tool_result(tr)
455 }
456 Err(AonyxError::ApprovalRejected(_)) => {
457 let _ = events
458 .send(TurnEvent::ToolRejected {
459 name: call.name.clone(),
460 class,
461 })
462 .await;
463 format!("[approval rejected] {} ({:?})", call.name, class)
464 }
465 Err(e) => {
466 let msg = format!("{e}");
467 let _ = events
468 .send(TurnEvent::ToolEnd {
469 name: call.name.clone(),
470 ok: false,
471 summary: msg.clone(),
472 })
473 .await;
474 format!("[tool error] {msg}")
475 }
476 };
477 messages.push(Message::tool_result(call.id, payload));
478 }
479 }
480
481 let _ = events
482 .send(TurnEvent::Done {
483 iterations,
484 max_iterations_hit: true,
485 })
486 .await;
487 Ok(TurnResult {
488 messages,
489 iterations,
490 max_iterations_hit: true,
491 })
492 }
493
494 pub async fn summarize(&self, history: &[Message]) -> Result<String> {
498 let transcript = history
499 .iter()
500 .map(|m| {
501 let who = match m.role {
502 Role::System => "system",
503 Role::User => "user",
504 Role::Assistant => "assistant",
505 Role::Tool => "tool",
506 };
507 format!("{who}: {}", m.content)
508 })
509 .collect::<Vec<_>>()
510 .join("\n\n");
511
512 let prompt = "You are compacting a conversation to save context. Summarize the \
513 exchange below concisely, preserving key facts, decisions, file paths, \
514 identifiers, and any open questions or TODOs. Omit pleasantries. Output \
515 only the summary prose — no preamble.";
516 let req = ChatRequest {
517 model: self.current_model(),
518 messages: vec![
519 Message::new(Role::System, prompt),
520 Message::new(Role::User, transcript),
521 ],
522 tools: Vec::new(),
523 temperature: Some(0.0),
524 max_tokens: Some(1024),
525 };
526
527 let provider = self.current_provider();
528 let mut stream = provider.chat_stream(req).await?;
529 let mut text = String::new();
530 while let Some(item) = stream.next().await {
531 let chunk = item?;
532 text.push_str(&chunk.delta_text);
533 if chunk.finished {
534 break;
535 }
536 }
537 Ok(text.trim().to_string())
538 }
539
540 async fn consume_stream(
541 &self,
542 req: ChatRequest,
543 events: &mpsc::Sender<TurnEvent>,
544 ) -> Result<(String, Vec<ToolCall>)> {
545 let provider = self.current_provider();
546 let mut stream = provider.chat_stream(req).await?;
547 let mut text = String::new();
548 let mut tool_calls: Vec<ToolCall> = Vec::new();
549
550 while let Some(item) = stream.next().await {
551 let chunk = item?;
552 if !chunk.delta_text.is_empty() {
553 let _ = events
554 .send(TurnEvent::AssistantDelta(chunk.delta_text.clone()))
555 .await;
556 text.push_str(&chunk.delta_text);
557 }
558 if let Some(tc) = chunk.tool_call {
559 tool_calls.push(tc);
560 }
561 if chunk.finished {
562 break;
563 }
564 }
565
566 Ok((text, tool_calls))
567 }
568
569 async fn dispatch_tool(&self, call: ToolCall) -> Result<ToolResult> {
570 let handler: Arc<dyn ToolHandler> = self
571 .tools
572 .get(&call.name)
573 .ok_or_else(|| AonyxError::Tool(format!("unknown tool: {}", call.name)))?;
574 let class = handler.classify();
575 if !self.approval.allow(&call, class).await {
576 return Err(AonyxError::ApprovalRejected(format!(
577 "{} ({:?})",
578 call.name, class
579 )));
580 }
581 handler.invoke(call).await
582 }
583}
584
585fn format_tool_result(tr: &ToolResult) -> String {
586 if let Some(err) = &tr.error {
587 return format!("[tool error] {err}");
588 }
589 match serde_json::to_string_pretty(&tr.output) {
590 Ok(s) => s,
591 Err(_) => tr.output.to_string(),
592 }
593}
594
595fn redact_request_json(req: &ChatRequest) -> String {
599 let mut value = match serde_json::to_value(req) {
600 Ok(v) => v,
601 Err(e) => return format!("(could not serialize request: {e})"),
602 };
603 if let Some(messages) = value.get_mut("messages").and_then(|m| m.as_array_mut()) {
604 for msg in messages.iter_mut() {
605 if let Some(atts) = msg.get_mut("attachments").and_then(|a| a.as_array_mut()) {
606 for att in atts.iter_mut() {
607 if let Some(data) = att.get_mut("data") {
608 if let Some(s) = data.as_str() {
609 *data = Value::String(format!("<{} bytes base64 elided>", s.len()));
610 }
611 }
612 }
613 }
614 }
615 }
616 serde_json::to_string_pretty(&value).unwrap_or_else(|e| format!("(pretty-print failed: {e})"))
617}
618
619fn short_summary(value: &Value) -> String {
620 let raw = match value {
621 Value::String(s) => s.clone(),
622 other => serde_json::to_string(other).unwrap_or_default(),
623 };
624 let trimmed = raw.replace('\n', " ");
625 if trimmed.chars().count() > 120 {
626 let cut: String = trimmed.chars().take(120).collect();
627 format!("{cut}…")
628 } else {
629 trimmed
630 }
631}
632
633fn format_retrieved_context(output: &Value, top_k: usize) -> Option<String> {
640 if let Some(results) = output.get("results").and_then(|r| r.as_array()) {
641 return format_results_array(results, top_k);
642 }
643 if let Some(s) = output.as_str() {
644 if let Ok(parsed) = serde_json::from_str::<Value>(s) {
646 if let Some(results) = parsed.get("results").and_then(|r| r.as_array()) {
647 if let Some(block) = format_results_array(results, top_k) {
648 return Some(block);
649 }
650 }
651 }
652 let trimmed = s.trim();
653 if trimmed.is_empty() {
654 return None;
655 }
656 return Some(cap(trimmed, 6000));
657 }
658 None
659}
660
661fn format_results_array(results: &[Value], top_k: usize) -> Option<String> {
664 let mut blocks = Vec::new();
665 for r in results.iter().take(top_k) {
666 let content = r
667 .get("content")
668 .and_then(|v| v.as_str())
669 .unwrap_or("")
670 .trim();
671 if content.is_empty() {
672 continue;
673 }
674 let project = r.get("project").and_then(|v| v.as_str()).unwrap_or("?");
675 let source = r.get("source").and_then(|v| v.as_str()).unwrap_or("?");
676 blocks.push(format!(
677 "- (projet {project} / {source})\n{}",
678 cap(content, 1200)
679 ));
680 }
681 if blocks.is_empty() {
682 None
683 } else {
684 Some(blocks.join("\n\n"))
685 }
686}
687
/// Truncate `s` to at most `max_chars` Unicode scalar values, appending `…` when
/// anything was cut. A string that already fits is returned unchanged (owned).
fn cap(s: &str, max_chars: usize) -> String {
    // A char at index `max_chars` exists only when `s` is too long. Its byte offset
    // is then a valid char boundary, so slicing there cannot split a code point.
    match s.char_indices().nth(max_chars) {
        None => s.to_owned(),
        Some((cut_at, _)) => format!("{}…", &s[..cut_at]),
    }
}
698
#[cfg(test)]
mod tests {
    use super::*;
    use aonyx_core::{ChatChunk, ChatStream, Result as CoreResult};
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// Scripted provider: each `chat_stream` call pops the next canned response.
    struct FakeProvider {
        // Responses in call order; an exhausted queue yields an empty stream.
        queue: Mutex<Vec<Vec<ChatChunk>>>,
    }

    impl FakeProvider {
        fn new(responses: Vec<Vec<ChatChunk>>) -> Self {
            Self {
                queue: Mutex::new(responses),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for FakeProvider {
        fn name(&self) -> &str {
            "fake"
        }

        async fn chat_stream(&self, _req: ChatRequest) -> CoreResult<ChatStream> {
            let mut q = self.queue.lock().expect("queue poisoned");
            let next = if q.is_empty() {
                Vec::new()
            } else {
                q.remove(0)
            };
            let stream = futures::stream::iter(next.into_iter().map(Ok));
            Ok(Box::pin(stream))
        }
    }

    /// A chunk carrying only assistant text.
    fn text_chunk(s: &str) -> ChatChunk {
        ChatChunk {
            delta_text: s.to_string(),
            tool_call: None,
            finished: false,
        }
    }

    /// The terminal chunk of a response.
    fn stop_chunk() -> ChatChunk {
        ChatChunk {
            delta_text: String::new(),
            tool_call: None,
            finished: true,
        }
    }

    /// A chunk carrying a single tool call with id `call-<name>`.
    fn tool_chunk(name: &str, args: Value) -> ChatChunk {
        ChatChunk {
            delta_text: String::new(),
            tool_call: Some(ToolCall {
                id: format!("call-{name}"),
                name: name.to_string(),
                args,
            }),
            finished: false,
        }
    }

    /// Collect every event already buffered in `rx` without waiting.
    fn drain<T>(rx: &mut mpsc::Receiver<T>) -> Vec<T> {
        let mut out = Vec::new();
        while let Ok(ev) = rx.try_recv() {
            out.push(ev);
        }
        out
    }

    /// A skill whose trigger is `always_on`, with `name` equal to `id`.
    fn always_on_skill(id: &str, body: &str) -> Skill {
        let mut s = Skill {
            id: id.to_string(),
            name: id.to_string(),
            enabled: true,
            tools: Vec::new(),
            trigger: Default::default(),
            body: body.to_string(),
        };
        s.trigger.always_on = true;
        s
    }

    #[tokio::test]
    async fn summarize_collects_streamed_text() {
        let provider = Arc::new(FakeProvider::new(vec![vec![
            text_chunk("Summary: "),
            text_chunk("user asked about X."),
            stop_chunk(),
        ]]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "any-model");
        let history = vec![
            Message::new(Role::User, "tell me about X"),
            Message::new(Role::Assistant, "X is a thing"),
        ];
        let summary = runner.summarize(&history).await.unwrap();
        assert_eq!(summary, "Summary: user asked about X.");
    }

    #[test]
    fn redact_request_json_elides_image_payloads() {
        use aonyx_core::Attachment;
        let req = ChatRequest {
            model: "claude-x".to_string(),
            messages: vec![Message::with_attachments(
                Role::User,
                "look",
                vec![Attachment::Image {
                    media_type: "image/png".into(),
                    data: "A".repeat(5000),
                }],
            )],
            tools: vec![],
            temperature: None,
            max_tokens: None,
        };
        let json = redact_request_json(&req);
        assert!(json.contains("claude-x"));
        assert!(json.contains("image/png"));
        assert!(!json.contains(&"A".repeat(5000)));
        // The payload is replaced by a length placeholder, not just dropped.
        assert!(json.contains("base64 elided"));
    }

    #[test]
    fn redact_request_json_passes_text_only_requests_through() {
        let req = ChatRequest {
            model: "m".to_string(),
            messages: vec![Message::new(Role::User, "plain text")],
            tools: vec![],
            temperature: None,
            max_tokens: None,
        };
        let json = redact_request_json(&req);
        assert!(json.contains("plain text"));
    }

    #[test]
    fn inject_active_skills_adds_an_always_on_skill() {
        let runner = AgentRunner::new(
            Arc::new(FakeProvider::new(vec![])),
            ToolRegistry::default_set(),
            "any-model",
        )
        .with_skills(vec![always_on_skill("greeter", "ALWAYS GREET")]);
        let mut messages = vec![Message::new(Role::User, "hi")];
        runner.inject_active_skills(&mut messages);
        assert_eq!(messages[0].role, Role::System);
        assert!(messages[0].content.contains("ALWAYS GREET"));
    }

    #[test]
    fn disabled_skill_is_not_injected() {
        let runner = AgentRunner::new(
            Arc::new(FakeProvider::new(vec![])),
            ToolRegistry::default_set(),
            "any-model",
        )
        .with_skills(vec![always_on_skill("greeter", "ALWAYS GREET")]);
        // Disable the skill through the shared handle, as a UI toggle would.
        runner
            .skill_toggle_handle()
            .lock()
            .unwrap()
            .insert("greeter".to_string());
        let mut messages = vec![Message::new(Role::User, "hi")];
        runner.inject_active_skills(&mut messages);
        assert_eq!(messages.len(), 1);
        // No system block was prepended.
        assert_eq!(messages[0].role, Role::User);
    }

    #[test]
    fn re_enabling_a_skill_restores_injection() {
        let runner = AgentRunner::new(
            Arc::new(FakeProvider::new(vec![])),
            ToolRegistry::default_set(),
            "any-model",
        )
        .with_skills(vec![always_on_skill("greeter", "ALWAYS GREET")]);
        let handle = runner.skill_toggle_handle();
        handle.lock().unwrap().insert("greeter".to_string());
        handle.lock().unwrap().remove("greeter");
        let mut messages = vec![Message::new(Role::User, "hi")];
        runner.inject_active_skills(&mut messages);
        assert_eq!(messages[0].role, Role::System);
        assert!(messages[0].content.contains("ALWAYS GREET"));
    }

    #[tokio::test]
    async fn terminates_when_no_tool_calls() {
        let provider = Arc::new(FakeProvider::new(vec![vec![
            text_chunk("Hello, "),
            text_chunk("world."),
            stop_chunk(),
        ]]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "any-model");
        let res = runner
            .run(vec![Message::new(Role::User, "hi")])
            .await
            .unwrap();
        assert_eq!(res.iterations, 1);
        assert!(!res.max_iterations_hit);
        assert_eq!(res.messages.len(), 2);
        assert_eq!(res.messages[1].role, Role::Assistant);
        assert_eq!(res.messages[1].content, "Hello, world.");
    }

    #[tokio::test]
    async fn loops_until_no_more_tool_calls() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("note.txt");
        tokio::fs::write(&path, "hello").await.unwrap();

        let provider = Arc::new(FakeProvider::new(vec![
            vec![
                // Iteration 1: the model asks to read the file.
                tool_chunk("fs_read", json!({ "path": path.to_string_lossy() })),
                stop_chunk(),
            ],
            vec![text_chunk("read it."), stop_chunk()],
            // Iteration 2 above: plain text, which ends the turn.
        ]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "any-model");
        let res = runner
            .run(vec![Message::new(Role::User, "show me the file")])
            .await
            .unwrap();
        assert_eq!(res.iterations, 2);
        let roles: Vec<_> = res.messages.iter().map(|m| m.role).collect();
        // user prompt -> assistant tool-call message -> tool result -> final answer.
        assert_eq!(
            roles,
            vec![Role::User, Role::Assistant, Role::Tool, Role::Assistant]
        );
        assert_eq!(res.messages[1].tool_calls.len(), 1);
        assert_eq!(res.messages[1].tool_calls[0].name, "fs_read");
        assert!(res.messages[2].tool_call_id.is_some());
        assert!(res.messages[2].content.contains("hello"));
        assert_eq!(res.messages[3].content, "read it.");
    }

    #[tokio::test]
    async fn respects_max_iterations() {
        let provider = Arc::new(FakeProvider::new(vec![
            vec![tool_chunk("git_status", json!({})), stop_chunk()],
            vec![tool_chunk("git_status", json!({})), stop_chunk()],
            vec![tool_chunk("git_status", json!({})), stop_chunk()],
        ]));
        let runner =
            AgentRunner::new(provider, ToolRegistry::default_set(), "m").with_max_iterations(2);
        let res = runner
            .run(vec![Message::new(Role::User, "loop forever")])
            .await
            .unwrap();
        assert_eq!(res.iterations, 2);
        assert!(res.max_iterations_hit);
    }

    #[tokio::test]
    async fn default_policy_blocks_destructive_writes() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("forbidden.txt");
        let provider = Arc::new(FakeProvider::new(vec![vec![
            tool_chunk(
                "fs_write",
                json!({ "path": path.to_string_lossy(), "content": "nope" }),
            ),
            stop_chunk(),
        ]]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
        let res = runner
            .run(vec![Message::new(Role::User, "write to disk")])
            .await
            .unwrap();
        let last = res.messages.last().unwrap();
        assert_eq!(last.role, Role::Tool);
        assert!(last.content.contains("approval rejected"));
        assert!(!path.exists(), "file must not have been written");
    }

    #[tokio::test]
    async fn auto_allow_lets_destructive_writes_through() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("ok.txt");
        let provider = Arc::new(FakeProvider::new(vec![
            vec![
                tool_chunk(
                    "fs_write",
                    json!({ "path": path.to_string_lossy(), "content": "yes" }),
                ),
                stop_chunk(),
            ],
            vec![text_chunk("done"), stop_chunk()],
        ]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m")
            .with_approval(ApprovalPolicy::AutoAllow);
        let res = runner
            .run(vec![Message::new(Role::User, "write to disk")])
            .await
            .unwrap();
        assert_eq!(res.iterations, 2);
        assert_eq!(tokio::fs::read_to_string(&path).await.unwrap(), "yes");
    }

    #[tokio::test]
    async fn run_streaming_emits_delta_events_in_order() {
        let provider = Arc::new(FakeProvider::new(vec![vec![
            text_chunk("Hello"),
            text_chunk(", "),
            text_chunk("world"),
            stop_chunk(),
        ]]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
        // Capacity 64 is enough to buffer every event of this short turn.
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(64);
        runner
            .run_streaming(vec![Message::new(Role::User, "hi")], tx)
            .await
            .unwrap();

        let events = drain(&mut rx);
        let deltas: Vec<_> = events
            .iter()
            .filter_map(|e| match e {
                TurnEvent::AssistantDelta(s) => Some(s.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(deltas, vec!["Hello", ", ", "world"]);

        let has_done = events.iter().any(|e| {
            matches!(
                e,
                TurnEvent::Done {
                    max_iterations_hit: false,
                    ..
                }
            )
        });
        assert!(has_done);
    }

    #[tokio::test]
    async fn run_streaming_announces_tool_start_and_end() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("hello.txt");
        tokio::fs::write(&path, "ok").await.unwrap();

        let provider = Arc::new(FakeProvider::new(vec![
            vec![
                tool_chunk("fs_read", json!({ "path": path.to_string_lossy() })),
                stop_chunk(),
            ],
            vec![text_chunk("done"), stop_chunk()],
        ]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(64);
        runner
            .run_streaming(vec![Message::new(Role::User, "read it")], tx)
            .await
            .unwrap();

        let events = drain(&mut rx);
        let start_seen = events
            .iter()
            .any(|e| matches!(e, TurnEvent::ToolStart { name, .. } if name == "fs_read"));
        let end_seen = events
            .iter()
            .any(|e| matches!(e, TurnEvent::ToolEnd { name, ok: true, .. } if name == "fs_read"));
        assert!(start_seen, "expected ToolStart for fs_read");
        assert!(end_seen, "expected successful ToolEnd for fs_read");
    }

    #[tokio::test]
    async fn run_streaming_announces_tool_rejection() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nope.txt");
        let provider = Arc::new(FakeProvider::new(vec![vec![
            tool_chunk(
                "fs_write",
                json!({ "path": path.to_string_lossy(), "content": "blocked" }),
            ),
            stop_chunk(),
        ]]));
        let runner = AgentRunner::new(provider, ToolRegistry::default_set(), "m");
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(64);
        runner
            .run_streaming(vec![Message::new(Role::User, "write please")], tx)
            .await
            .unwrap();

        let events = drain(&mut rx);
        let rejected = events
            .iter()
            .any(|e| matches!(e, TurnEvent::ToolRejected { name, .. } if name == "fs_write"));
        assert!(rejected, "expected ToolRejected for fs_write");
    }

    #[test]
    fn auto_retrieve_formats_structured_results() {
        let output = json!({
            "results": [
                {"project": "infra", "source": "ref.md", "content": "alpha fact"},
                {"project": "ovelo", "source": "notes.md", "content": "beta fact"},
            ]
        });
        let block = format_retrieved_context(&output, 5).expect("context block");
        assert!(block.contains("projet infra / ref.md"));
        assert!(block.contains("alpha fact"));
        assert!(block.contains("beta fact"));
    }

    #[test]
    fn auto_retrieve_parses_json_string_payload() {
        let payload = json!({
            // Same structured payload, but delivered as a JSON-encoded string.
            "results": [{"project": "p", "source": "s", "content": "gamma"}]
        })
        .to_string();
        let block = format_retrieved_context(&Value::String(payload), 5).expect("context block");
        assert!(block.contains("gamma"));
        assert!(block.contains("projet p / s"));
    }

    #[test]
    fn auto_retrieve_uses_plain_text_payload() {
        let output = Value::String("just some prose context".to_string());
        assert_eq!(
            format_retrieved_context(&output, 5).unwrap(),
            "just some prose context"
        );
    }

    #[test]
    fn auto_retrieve_top_k_limits_chunks() {
        let output = json!({
            "results": [
                {"project":"p","source":"a","content":"one"},
                {"project":"p","source":"b","content":"two"},
                {"project":"p","source":"c","content":"three"},
            ]
        });
        let block = format_retrieved_context(&output, 2).unwrap();
        assert!(block.contains("one") && block.contains("two"));
        assert!(!block.contains("three"), "top_k=2 must drop the 3rd chunk");
    }

    #[test]
    fn auto_retrieve_none_on_empty() {
        assert!(format_retrieved_context(&json!({"results": []}), 5).is_none());
        assert!(format_retrieved_context(&Value::String(String::new()), 5).is_none());
    }
}