1use std::sync::Arc;
2
3use serde::Deserialize;
4use serde_json::Value;
5
6use crate::error::{AgnoError, Result};
7use crate::governance::{AccessController, Action, Principal, Role as GovernanceRole};
8use crate::hooks::{AgentHook, ConfirmationHandler};
9use crate::knowledge::Retriever;
10use crate::llm::{LanguageModel, ModelCompletion};
11use crate::memory::ConversationMemory;
12use crate::message::{Message, Role};
13use crate::metrics::{MetricsTracker, RunGuard};
14use crate::telemetry::{TelemetryCollector, TelemetryLabels};
15use crate::tool::ToolRegistry;
16
17#[derive(Debug, Deserialize, PartialEq)]
19#[serde(tag = "action", rename_all = "snake_case")]
20pub enum AgentDirective {
21 Respond { content: String },
22 CallTool { name: String, arguments: Value },
23}
24
25pub struct Agent<M: LanguageModel> {
27 system_prompt: String,
28 model: Arc<M>,
29 tools: ToolRegistry,
30 memory: ConversationMemory,
31 max_steps: usize,
32 input_schema: Option<serde_json::Value>,
33 output_schema: Option<serde_json::Value>,
34 hooks: Vec<Arc<dyn AgentHook>>,
35 retriever: Option<Arc<dyn Retriever>>,
36 require_tool_confirmation: bool,
37 confirmation_handler: Option<Arc<dyn ConfirmationHandler>>,
38 access_control: Option<Arc<AccessController>>,
39 principal: Principal,
40 metrics: Option<MetricsTracker>,
41 telemetry: Option<TelemetryCollector>,
42 streaming: bool,
43 workflow_label: Option<String>,
44}
45
46impl<M: LanguageModel> Agent<M> {
47 pub fn new(model: Arc<M>) -> Self {
48 Self {
49 system_prompt: "You are a helpful agent.".to_string(),
50 model,
51 tools: ToolRegistry::new(),
52 memory: ConversationMemory::default(),
53 max_steps: 6,
54 input_schema: None,
55 output_schema: None,
56 hooks: Vec::new(),
57 retriever: None,
58 require_tool_confirmation: false,
59 confirmation_handler: None,
60 access_control: None,
61 principal: Principal {
62 id: "anonymous".into(),
63 role: GovernanceRole::User,
64 tenant: None,
65 },
66 metrics: None,
67 telemetry: None,
68 streaming: false,
69 workflow_label: None,
70 }
71 }
72
73 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
74 self.system_prompt = prompt.into();
75 self
76 }
77
78 pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
79 self.tools = tools;
80 self
81 }
82
83 pub fn with_memory(mut self, memory: ConversationMemory) -> Self {
84 self.memory = memory;
85 self
86 }
87
88 pub fn with_access_control(mut self, controller: Arc<AccessController>) -> Self {
89 self.access_control = Some(controller);
90 self
91 }
92
93 pub fn with_principal(mut self, principal: Principal) -> Self {
94 self.principal = principal;
95 self
96 }
97
98 pub fn with_metrics(mut self, metrics: MetricsTracker) -> Self {
99 self.metrics = Some(metrics);
100 self
101 }
102
103 pub fn with_telemetry(mut self, telemetry: TelemetryCollector) -> Self {
104 self.telemetry = Some(telemetry);
105 self
106 }
107
108 pub fn with_workflow_label(mut self, workflow: impl Into<String>) -> Self {
109 self.workflow_label = Some(workflow.into());
110 self
111 }
112
113 pub fn with_input_schema(mut self, schema: serde_json::Value) -> Self {
114 self.input_schema = Some(schema);
115 self
116 }
117
118 pub fn with_output_schema(mut self, schema: serde_json::Value) -> Self {
119 self.output_schema = Some(schema);
120 self
121 }
122
123 pub fn with_hook(mut self, hook: Arc<dyn AgentHook>) -> Self {
124 self.hooks.push(hook);
125 self
126 }
127
128 pub fn with_retriever(mut self, retriever: Arc<dyn Retriever>) -> Self {
129 self.retriever = Some(retriever);
130 self
131 }
132
133 pub fn require_tool_confirmation(mut self, handler: Arc<dyn ConfirmationHandler>) -> Self {
134 self.require_tool_confirmation = true;
135 self.confirmation_handler = Some(handler);
136 self
137 }
138
139 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
140 self.max_steps = max_steps.max(1);
141 self
142 }
143
144 pub fn with_streaming(mut self, streaming: bool) -> Self {
145 self.streaming = streaming;
146 self
147 }
148
149 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
150 &mut self.tools
151 }
152
153 pub fn tool_names(&self) -> Vec<String> {
154 self.tools.names()
155 }
156
157 pub fn set_principal(&mut self, principal: Principal) {
158 self.principal = principal;
159 }
160
161 pub fn attach_access_control(&mut self, controller: Arc<AccessController>) {
162 self.access_control = Some(controller);
163 }
164
165 pub fn attach_metrics(&mut self, metrics: MetricsTracker) {
166 self.metrics = Some(metrics);
167 }
168
169 pub fn attach_telemetry(&mut self, telemetry: TelemetryCollector) {
170 self.telemetry = Some(telemetry);
171 }
172
173 pub fn memory(&self) -> &ConversationMemory {
174 &self.memory
175 }
176
177 pub fn sync_memory_from(&mut self, memory: &ConversationMemory) {
178 self.memory = memory.clone();
179 }
180
181 pub fn take_memory_snapshot(&self) -> ConversationMemory {
182 self.memory.clone()
183 }
184
185 pub async fn respond(&mut self, user_input: impl Into<String>) -> Result<String> {
187 let principal = self.principal.clone();
188 self.respond_for(principal, user_input).await
189 }
190
191 pub async fn respond_for(
192 &mut self,
193 principal: Principal,
194 user_input: impl Into<String>,
195 ) -> Result<String> {
196 if let Some(ctrl) = &self.access_control {
197 if !ctrl.authorize(&principal, &Action::SendMessage) {
198 return Err(AgnoError::Protocol(
199 "principal not authorized to send messages".into(),
200 ));
201 }
202 }
203
204 let base_labels = TelemetryLabels {
205 tenant: principal.tenant.clone(),
206 tool: None,
207 workflow: self.workflow_label.clone(),
208 };
209 if let Some(telemetry) = &self.telemetry {
210 telemetry.record(
211 "user_message",
212 serde_json::json!({"principal": principal.id.clone(), "tenant": principal.tenant}),
213 base_labels.clone(),
214 );
215 }
216
217 let mut run_guard: Option<RunGuard> = self
218 .metrics
219 .as_ref()
220 .map(|m| m.start_run(base_labels.clone()));
221 self.memory.push(Message::user(user_input));
222
223 for _ in 0..self.max_steps {
224 let contexts = self.retrieve_contexts().await?;
225 let system_prompt = self.build_system_message(&contexts)?;
226 let mut request_messages = vec![Message::system(system_prompt)];
227 request_messages.extend(self.memory.iter().cloned());
228 let snapshot: Vec<Message> = request_messages.clone();
229 for hook in &self.hooks {
230 hook.before_model(snapshot.as_slice()).await?;
231 }
232 let completion = self
233 .model
234 .complete_chat(&request_messages, &self.tools.describe(), self.streaming)
235 .await?;
236 for hook in &self.hooks {
237 let serialized = serde_json::to_string(&completion)
238 .unwrap_or_else(|_| "<unserializable>".into());
239 hook.after_model(&serialized).await?;
240 }
241
242 if !completion.tool_calls.is_empty() {
243 for mut call in completion.tool_calls {
244 if call.id.is_none() {
245 call.id = Some(format!("call-{}", self.memory.len()));
246 }
247 if let Some(ctrl) = &self.access_control {
248 if !ctrl.authorize(&principal, &Action::CallTool(call.name.clone())) {
249 if let Some(guard) = run_guard.as_mut() {
250 guard.record_failure(Some(call.name.clone()));
251 }
252 return Err(AgnoError::Protocol(format!(
253 "principal `{}` not allowed to call tool `{}`",
254 principal.id, call.name
255 )));
256 }
257 }
258 if self.require_tool_confirmation {
259 if let Some(handler) = &self.confirmation_handler {
260 let approved = handler.confirm_tool_call(&call).await?;
261 if !approved {
262 self.memory.push(Message::assistant(format!(
263 "Tool call `{}` rejected by guardrail",
264 call.name
265 )));
266 continue;
267 }
268 }
269 }
270 if let Some(guard) = run_guard.as_mut() {
271 guard.record_tool_call(call.name.clone());
272 }
273 let call_id = call.id.clone();
274 self.memory.push(Message {
275 role: Role::Assistant,
276 content: format!("Calling tool `{}`", call.name),
277 tool_call: Some(call.clone()),
278 tool_result: None,
279 attachments: Vec::new(),
280 });
281
282 for hook in &self.hooks {
283 hook.before_tool_call(
284 self.memory
285 .iter()
286 .last()
287 .unwrap()
288 .tool_call
289 .as_ref()
290 .unwrap(),
291 )
292 .await?;
293 }
294 let output = match self.tools.call(&call.name, call.arguments.clone()).await {
295 Ok(value) => value,
296 Err(err) => {
297 if let Some(guard) = run_guard.as_mut() {
298 guard.record_failure(Some(call.name.clone()));
299 }
300 if let Some(telemetry) = &self.telemetry {
301 telemetry.record_failure(
302 format!("tool::{}", call.name),
303 format!("{err}"),
304 0,
305 base_labels.clone().with_tool(call.name.clone()),
306 );
307 }
308 return Err(err);
309 }
310 };
311 let result_message =
312 Message::tool_with_call(&call.name, output, call_id.clone());
313 for hook in &self.hooks {
314 if let Some(result) = result_message.tool_result.as_ref() {
315 hook.after_tool_result(result).await?;
316 }
317 }
318 self.memory.push(result_message);
319 }
320 continue;
321 }
322
323 match completion {
324 ModelCompletion {
325 content: Some(content),
326 tool_calls,
327 } if tool_calls.is_empty() => {
328 self.memory.push(Message::assistant(&content));
329 if let Some(guard) = run_guard.take() {
330 guard.finish(true);
331 }
332 return Ok(content);
333 }
334 _ => {
335 if let Some(guard) = run_guard.as_mut() {
336 guard.record_failure(None);
337 }
338 return Err(AgnoError::Protocol(
339 "Model response missing content and tool calls".into(),
340 ));
341 }
342 }
343 }
344
345 if let Some(guard) = run_guard {
346 guard.finish(false);
347 }
348
349 Err(AgnoError::Protocol(
350 "Agent reached the step limit without returning a response".into(),
351 ))
352 }
353
354 async fn retrieve_contexts(&self) -> Result<Vec<String>> {
355 if let Some(retriever) = &self.retriever {
356 return Ok(retriever
357 .retrieve(
358 self.memory
359 .iter()
360 .rev()
361 .find(|m| m.role == Role::User)
362 .map(|m| m.content.as_str())
363 .unwrap_or_default(),
364 3,
365 )
366 .await
367 .unwrap_or_default());
368 }
369 Ok(Vec::new())
370 }
371
372 fn build_system_message(&self, contexts: &[String]) -> Result<String> {
373 let mut prompt = String::new();
374 prompt.push_str(&self.system_prompt);
375 prompt.push_str("\n\nWhen a tool is relevant, call it with appropriate JSON arguments. Return a direct response when no tool is needed.\n");
376 if let Some(schema) = &self.input_schema {
377 prompt.push_str(&format!(
378 "User input is expected to follow this JSON shape: {}\n\n",
379 schema
380 ));
381 }
382 if let Some(schema) = &self.output_schema {
383 prompt.push_str(&format!(
384 "When responding directly, conform to this output schema: {}\n",
385 schema
386 ));
387 }
388 if self.tools.names().is_empty() {
389 prompt.push_str("No tools are available.\n");
390 } else {
391 prompt.push_str("Available tools:\n");
392 for tool in self.tools.describe() {
393 prompt.push_str(&format!("- {}: {}", tool.name, tool.description));
394 if let Some(params) = &tool.parameters {
395 prompt.push_str(&format!(" (parameters: {})", params));
396 }
397 prompt.push('\n');
398 }
399 }
400 if !contexts.is_empty() {
401 prompt.push_str("\nContext snippets:\n");
402 for ctx in contexts {
403 prompt.push_str("- ");
404 prompt.push_str(ctx);
405 prompt.push('\n');
406 }
407 }
408
409 Ok(prompt)
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    use crate::tool::Tool;
    use crate::StubModel;

    /// Test tool that hands its JSON input straight back as the result.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the `text` field back"
        }

        async fn call(&self, input: Value) -> Result<Value> {
            Ok(input)
        }
    }

    /// A plain `respond` directive is returned verbatim and leaves exactly the
    /// user message and the assistant answer in memory.
    #[tokio::test]
    async fn returns_llm_response_without_tools() {
        let script = vec![r#"{"action":"respond","content":"Hello!"}"#.into()];
        let mut agent = Agent::new(StubModel::new(script));

        let answer = agent.respond("hi").await.unwrap();

        assert_eq!(answer, "Hello!");
        assert_eq!(agent.memory().len(), 2);
    }

    /// A tool call followed by a response yields four memory entries: user
    /// message, tool-call announcement, tool result, and the final answer.
    #[tokio::test]
    async fn executes_tool_then_replies() {
        let script = vec![
            r#"{"action":"call_tool","name":"echo","arguments":{"text":"ping"}}"#.into(),
            r#"{"action":"respond","content":"Echoed your request."}"#.into(),
        ];
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);

        let mut agent = Agent::new(StubModel::new(script)).with_tools(registry);

        let answer = agent.respond("say ping").await.unwrap();

        assert_eq!(answer, "Echoed your request.");
        assert_eq!(agent.memory().len(), 4);
    }

    /// Registered tools are listed, with their description, in the system
    /// message.
    #[tokio::test]
    async fn includes_tool_metadata_in_prompt() {
        /// Test tool that advertises a parameter schema.
        struct DescribingTool;

        #[async_trait]
        impl Tool for DescribingTool {
            fn name(&self) -> &str {
                "describe"
            }

            fn description(&self) -> &str {
                "Replies with metadata"
            }

            fn parameters(&self) -> Option<Value> {
                Some(serde_json::json!({"type":"object","properties":{"id":{"type":"string"}}}))
            }

            async fn call(&self, _input: Value) -> Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        let script = vec![r#"{"action":"respond","content":"done"}"#.into()];
        let mut registry = ToolRegistry::new();
        registry.register(DescribingTool);

        let agent = Agent::new(StubModel::new(script)).with_tools(registry);
        let system_message = agent.build_system_message(&[]).unwrap();

        assert!(system_message.contains("Replies with metadata"));
        assert!(system_message.contains("Available tools"));
    }
}
499}